From b6909fc91d5d16ced108700cbd5cd8a13481a6c4 Mon Sep 17 00:00:00 2001 From: Anoop C S Date: Tue, 28 Mar 2017 07:13:47 +0000 Subject: swrap: Implement thread safety using pthread mutexes Added a new mutex variable to socket_info structure along with new macros for locking and unlocking mutex corresponding to each socket_info entry. Apart from individual mutex defined in socket_info structure, 4 new mutexes are added to protect the concurrent access of globally used swrap parameters from different threads. All other individual wrappers and helper routines are also made capable of acquiring relevant mutex locks before operating on such global parameters. Pair-Programmed-With: Michael Adam Signed-off-by: Anoop C S Signed-off-by: Michael Adam Reviewed-by: Andreas Schneider --- src/socket_wrapper.c | 273 ++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 235 insertions(+), 38 deletions(-) (limited to 'src') diff --git a/src/socket_wrapper.c b/src/socket_wrapper.c index 9c16e1a..6b67224 100644 --- a/src/socket_wrapper.c +++ b/src/socket_wrapper.c @@ -188,7 +188,17 @@ enum swrap_dbglvl_e { #define SOCKET_INFO_CONTAINER(si) \ (struct socket_info_container *)(si) -#define SWRAP_DLIST_ADD(list,item) do { \ +#define SWRAP_LOCK_SI(si) do { \ + struct socket_info_container *sic = SOCKET_INFO_CONTAINER(si); \ + pthread_mutex_lock(&sic->mutex); \ +} while(0) + +#define SWRAP_UNLOCK_SI(si) do { \ + struct socket_info_container *sic = SOCKET_INFO_CONTAINER(si); \ + pthread_mutex_unlock(&sic->mutex); \ +} while(0) + +#define DLIST_ADD(list, item) do { \ if (!(list)) { \ (item)->prev = NULL; \ (item)->next = NULL; \ @@ -201,7 +211,13 @@ enum swrap_dbglvl_e { } \ } while (0) -#define SWRAP_DLIST_REMOVE(list,item) do { \ +#define SWRAP_DLIST_ADD(list, item) do { \ + SWRAP_LOCK(list); \ + DLIST_ADD(list, item); \ + SWRAP_UNLOCK(list); \ +} while (0) + +#define DLIST_REMOVE(list, item) do { \ if ((list) == (item)) { \ (list) = (item)->next; \ if (list) { \ @@ -219,10 +235,15 @@ enum swrap_dbglvl_e { (item)->next = NULL; \ } while (0) -#define SWRAP_DLIST_ADD_AFTER(list, item, el) \ -do { \ +#define SWRAP_DLIST_REMOVE(list,item) do { \ + SWRAP_LOCK(list); \ + DLIST_REMOVE(list, item); \ + SWRAP_UNLOCK(list); \ +} while (0) + +#define DLIST_ADD_AFTER(list, item, el) do { \ if ((list) == NULL || (el) == NULL) { \ - SWRAP_DLIST_ADD(list, item); \ + DLIST_ADD(list, item); \ } else { \ (item)->prev = (el); \ (item)->next = (el)->next; \ @@ -233,6 +254,12 @@ do { \ } \ } while (0) +#define SWRAP_DLIST_ADD_AFTER(list, item, el) do { \ + SWRAP_LOCK(list); \ + DLIST_ADD_AFTER(list, item, el); \ + SWRAP_UNLOCK(list); \ +} while (0) + #if defined(HAVE_GETTIMEOFDAY_TZ) || defined(HAVE_GETTIMEOFDAY_TZ_VOID) #define swrapGetTimeOfDay(tval) gettimeofday(tval,NULL) #else @@ -333,6 +360,7 @@ struct socket_info_container struct socket_info info; unsigned int refcount; int next_free; + pthread_mutex_t mutex; }; static struct socket_info_container *sockets; @@ -348,6 +376,26 @@ static struct socket_info_fd *socket_fds; /* The mutex for accessing the global libc.symbols */ static pthread_mutex_t libc_symbol_binding_mutex = PTHREAD_MUTEX_INITIALIZER; +/* The mutex for syncronizing the port selection during swrap_auto_bind() */ +static pthread_mutex_t autobind_start_mutex = PTHREAD_MUTEX_INITIALIZER; + +/* + * Global mutex to guard the initialization of array of socket_info structures. + */ +static pthread_mutex_t sockets_mutex = PTHREAD_MUTEX_INITIALIZER; + +/* + * Global mutex to protect modification of the socket_fds linked + * list structure by different threads within a process. + */ +static pthread_mutex_t socket_fds_mutex = PTHREAD_MUTEX_INITIALIZER; + +/* + * Global mutex to synchronize the query for first free index in array of + * socket_info structures by different threads within a process. + */ +static pthread_mutex_t first_free_mutex = PTHREAD_MUTEX_INITIALIZER; + /* Function prototypes */ bool socket_wrapper_enabled(void); @@ -1321,7 +1369,10 @@ static void socket_wrapper_init_sockets(void) { size_t i; + SWRAP_LOCK(sockets); + if (sockets != NULL) { + SWRAP_UNLOCK(sockets); return; } @@ -1333,17 +1384,24 @@ static void socket_wrapper_init_sockets(void) if (sockets == NULL) { SWRAP_LOG(SWRAP_LOG_ERROR, "Failed to allocate sockets array.\n"); + SWRAP_UNLOCK(sockets); exit(-1); } + SWRAP_LOCK(first_free); + first_free = 0; for (i = 0; i < max_sockets; i++) { swrap_set_next_free(&sockets[i].info, i+1); + sockets[i].mutex = (pthread_mutex_t)PTHREAD_MUTEX_INITIALIZER; } /* mark the end of the free list */ swrap_set_next_free(&sockets[max_sockets-1].info, -1); + + SWRAP_UNLOCK(first_free); + SWRAP_UNLOCK(sockets); } bool socket_wrapper_enabled(void) @@ -1377,25 +1435,33 @@ static unsigned int socket_wrapper_default_iface(void) static int swrap_add_socket_info(struct socket_info *si_input) { struct socket_info *si = NULL; - int si_index; + int si_index = -1; if (si_input == NULL) { errno = EINVAL; return -1; } + SWRAP_LOCK(first_free); if (first_free == -1) { errno = ENFILE; - return -1; + goto out; } si_index = first_free; si = swrap_get_socket_info(si_index); + SWRAP_LOCK_SI(si); + first_free = swrap_get_next_free(si); *si = *si_input; swrap_inc_refcount(si); + SWRAP_UNLOCK_SI(si); + +out: + SWRAP_UNLOCK(first_free); + return si_index; } @@ -1795,13 +1861,17 @@ static struct socket_info_fd *find_socket_info_fd(int fd) { struct socket_info_fd *f; + SWRAP_LOCK(socket_fds); + for (f = socket_fds; f; f = f->next) { if (f->fd == fd) { - return f; + break; } } - return NULL; + SWRAP_UNLOCK(socket_fds); + + return f; } static int find_socket_info_index(int fd) @@ -1934,6 +2004,9 @@ static void swrap_remove_stale(int fd) si = swrap_get_socket_info(si_index); + SWRAP_LOCK(first_free); + SWRAP_LOCK_SI(si); + SWRAP_DLIST_REMOVE(socket_fds, fi); swrap_set_next_free(si, first_free); @@ -1943,12 +2016,16 @@ static void swrap_remove_stale(int fd) free(fi); if (swrap_get_refcount(si) > 0) { - return; + goto out; } if (si->un_addr.sun_path[0] != '\0') { unlink(si->un_addr.sun_path); } + +out: + SWRAP_UNLOCK_SI(si); + SWRAP_UNLOCK(first_free); } static int sockaddr_convert_to_un(struct socket_info *si, @@ -3101,16 +3178,26 @@ static int swrap_accept(int s, #endif } + + /* + * prevent parent_si from being altered / closed + * while we read it + */ + SWRAP_LOCK_SI(parent_si); + /* * assume out sockaddr have the same size as the in parent * socket family */ in_addr.sa_socklen = socket_length(parent_si->family); if (in_addr.sa_socklen <= 0) { + SWRAP_UNLOCK_SI(parent_si); errno = EINVAL; return -1; } + SWRAP_UNLOCK_SI(parent_si); + #ifdef HAVE_ACCEPT4 ret = libc_accept4(s, &un_addr.sa.s, &un_addr.sa_socklen, flags); #else @@ -3127,6 +3214,8 @@ static int swrap_accept(int s, fd = ret; + SWRAP_LOCK_SI(parent_si); + ret = sockaddr_convert_from_un(parent_si, &un_addr.sa.un, un_addr.sa_socklen, @@ -3134,6 +3223,7 @@ static int swrap_accept(int s, &in_addr.sa.s, &in_addr.sa_socklen); if (ret == -1) { + SWRAP_UNLOCK_SI(parent_si); close(fd); return ret; } @@ -3147,6 +3237,8 @@ static int swrap_accept(int s, child_si->is_server = 1; child_si->connected = 1; + SWRAP_UNLOCK_SI(parent_si); + child_si->peername = (struct swrap_address) { .sa_socklen = in_addr.sa_socklen, }; @@ -3197,9 +3289,11 @@ static int swrap_accept(int s, if (addr != NULL) { struct socket_info *si = swrap_get_socket_info(idx); + SWRAP_LOCK_SI(si); swrap_pcap_dump_packet(si, addr, SWRAP_ACCEPT_SEND, NULL, 0); swrap_pcap_dump_packet(si, addr, SWRAP_ACCEPT_RECV, NULL, 0); swrap_pcap_dump_packet(si, addr, SWRAP_ACCEPT_ACK, NULL, 0); + SWRAP_UNLOCK_SI(si); } return fd; @@ -3241,6 +3335,8 @@ static int swrap_auto_bind(int fd, struct socket_info *si, int family) int port; struct stat st; + SWRAP_LOCK(autobind_start); + if (autobind_start_init != 1) { autobind_start_init = 1; autobind_start = getpid(); @@ -3359,6 +3455,7 @@ static int swrap_auto_bind(int fd, struct socket_info *si, int family) ret = 0; done: + SWRAP_UNLOCK(autobind_start); return ret; } @@ -3380,6 +3477,8 @@ static int swrap_connect(int s, const struct sockaddr *serv_addr, return libc_connect(s, serv_addr, addrlen); } + SWRAP_LOCK_SI(si); + if (si->bound == 0) { ret = swrap_auto_bind(s, si, serv_addr->sa_family); if (ret == -1) { @@ -3463,6 +3562,7 @@ static int swrap_connect(int s, const struct sockaddr *serv_addr, } done: + SWRAP_UNLOCK_SI(si); return ret; } @@ -3491,6 +3591,8 @@ static int swrap_bind(int s, const struct sockaddr *myaddr, socklen_t addrlen) return libc_bind(s, myaddr, addrlen); } + SWRAP_LOCK_SI(si); + switch (si->family) { case AF_INET: { const struct sockaddr_in *sin; @@ -3538,14 +3640,16 @@ static int swrap_bind(int s, const struct sockaddr *myaddr, socklen_t addrlen) if (bind_error != 0) { errno = bind_error; - return -1; + ret = -1; + goto out; } #if 0 /* FIXME */ in_use = check_addr_port_in_use(myaddr, addrlen); if (in_use) { errno = EADDRINUSE; - return -1; + ret = -1; + goto out; } #endif @@ -3559,7 +3663,7 @@ static int swrap_bind(int s, const struct sockaddr *myaddr, socklen_t addrlen) 1, &si->bcast); if (ret == -1) { - return -1; + goto out; } unlink(un_addr.sa.un.sun_path); @@ -3574,6 +3678,9 @@ static int swrap_bind(int s, const struct sockaddr *myaddr, socklen_t addrlen) si->bound = 1; } +out: + SWRAP_UNLOCK_SI(si); + return ret; } @@ -3677,16 +3784,21 @@ static int swrap_listen(int s, int backlog) return libc_listen(s, backlog); } + SWRAP_LOCK_SI(si); + if (si->bound == 0) { ret = swrap_auto_bind(s, si, si->family); if (ret == -1) { errno = EADDRINUSE; - return ret; + goto out; } } ret = libc_listen(s, backlog); +out: + SWRAP_UNLOCK_SI(si); + return ret; } @@ -3853,26 +3965,34 @@ static int swrap_getpeername(int s, struct sockaddr *name, socklen_t *addrlen) { struct socket_info *si = find_socket_info(s); socklen_t len; + int ret = -1; if (!si) { return libc_getpeername(s, name, addrlen); } + SWRAP_LOCK_SI(si); + if (si->peername.sa_socklen == 0) { errno = ENOTCONN; - return -1; + goto out; } len = MIN(*addrlen, si->peername.sa_socklen); if (len == 0) { - return 0; + ret = 0; + goto out; } memcpy(name, &si->peername.sa.ss, len); *addrlen = si->peername.sa_socklen; - return 0; + ret = 0; +out: + SWRAP_UNLOCK_SI(si); + + return ret; } #ifdef HAVE_ACCEPT_PSOCKLEN_T @@ -3892,20 +4012,28 @@ static int swrap_getsockname(int s, struct sockaddr *name, socklen_t *addrlen) { struct socket_info *si = find_socket_info(s); socklen_t len; + int ret = -1; if (!si) { return libc_getsockname(s, name, addrlen); } + SWRAP_LOCK_SI(si); + len = MIN(*addrlen, si->myname.sa_socklen); if (len == 0) { - return 0; + ret = 0; + goto out; } memcpy(name, &si->myname.sa.ss, len); *addrlen = si->myname.sa_socklen; - return 0; + ret = 0; +out: + SWRAP_UNLOCK_SI(si); + + return ret; } #ifdef HAVE_ACCEPT_PSOCKLEN_T @@ -3941,6 +4069,8 @@ static int swrap_getsockopt(int s, int level, int optname, optlen); } + SWRAP_LOCK_SI(si); + if (level == SOL_SOCKET) { switch (optname) { #ifdef SO_DOMAIN @@ -4023,6 +4153,7 @@ static int swrap_getsockopt(int s, int level, int optname, ret = -1; done: + SWRAP_UNLOCK_SI(si); return ret; } @@ -4061,6 +4192,8 @@ static int swrap_setsockopt(int s, int level, int optname, optlen); } + SWRAP_LOCK_SI(si); + if (level == IPPROTO_TCP) { switch (optname) { #ifdef TCP_NODELAY @@ -4125,6 +4258,7 @@ static int swrap_setsockopt(int s, int level, int optname, } done: + SWRAP_UNLOCK_SI(si); return ret; } @@ -4149,6 +4283,8 @@ static int swrap_vioctl(int s, unsigned long int r, va_list va) return libc_vioctl(s, r, va); } + SWRAP_LOCK_SI(si); + va_copy(ap, va); rc = libc_vioctl(s, r, va); @@ -4167,6 +4303,7 @@ static int swrap_vioctl(int s, unsigned long int r, va_list va) va_end(ap); + SWRAP_UNLOCK_SI(si); return rc; } @@ -4472,7 +4609,7 @@ static ssize_t swrap_sendmsg_before(int fd, int *bcast) { size_t i, len = 0; - ssize_t ret; + ssize_t ret = -1; if (to_un) { *to_un = NULL; @@ -4484,13 +4621,15 @@ static ssize_t swrap_sendmsg_before(int fd, *bcast = 0; } + SWRAP_LOCK_SI(si); + switch (si->type) { case SOCK_STREAM: { unsigned long mtu; if (!si->connected) { errno = ENOTCONN; - return -1; + goto out; } if (msg->msg_iovlen == 0) { @@ -4533,14 +4672,14 @@ static ssize_t swrap_sendmsg_before(int fd, if (msg_name == NULL) { errno = ENOTCONN; - return -1; + goto out; } ret = sockaddr_convert_to_un(si, msg_name, msg->msg_namelen, tmp_un, 0, bcast); if (ret == -1) { - return -1; + goto out; } if (to_un) { @@ -4558,11 +4697,11 @@ static ssize_t swrap_sendmsg_before(int fd, if (ret == -1) { if (errno == ENOTSOCK) { swrap_remove_stale(fd); - return -ENOTSOCK; + ret = -ENOTSOCK; } else { SWRAP_LOG(SWRAP_LOG_ERROR, "swrap_sendmsg_before failed"); - return -1; } + goto out; } } @@ -4577,7 +4716,7 @@ static ssize_t swrap_sendmsg_before(int fd, 0, NULL); if (ret == -1) { - return -1; + goto out; } ret = libc_connect(fd, @@ -4590,14 +4729,14 @@ static ssize_t swrap_sendmsg_before(int fd, } if (ret == -1) { - return ret; + goto out; } si->defer_connect = 0; break; default: errno = EHOSTUNREACH; - return -1; + goto out; } #ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL @@ -4608,7 +4747,7 @@ static ssize_t swrap_sendmsg_before(int fd, ret = swrap_sendmsg_filter_cmsghdr(msg, &cmbuf, &cmlen); if (ret < 0) { free(cmbuf); - return -1; + goto out; } if (cmlen == 0) { @@ -4622,7 +4761,11 @@ static ssize_t swrap_sendmsg_before(int fd, } #endif - return 0; + ret = 0; +out: + SWRAP_UNLOCK_SI(si); + + return ret; } static void swrap_sendmsg_after(int fd, @@ -4676,6 +4819,8 @@ static void swrap_sendmsg_after(int fd, } len = ofs; + SWRAP_LOCK_SI(si); + switch (si->type) { case SOCK_STREAM: if (ret == -1) { @@ -4699,6 +4844,8 @@ static void swrap_sendmsg_after(int fd, break; } + SWRAP_UNLOCK_SI(si); + free(buf); errno = saved_errno; } @@ -4709,7 +4856,9 @@ static int swrap_recvmsg_before(int fd, struct iovec *tmp_iov) { size_t i, len = 0; - ssize_t ret; + int ret = -1; + + SWRAP_LOCK_SI(si); (void)fd; /* unused */ @@ -4718,7 +4867,7 @@ static int swrap_recvmsg_before(int fd, unsigned int mtu; if (!si->connected) { errno = ENOTCONN; - return -1; + goto out; } if (msg->msg_iovlen == 0) { @@ -4746,7 +4895,7 @@ static int swrap_recvmsg_before(int fd, case SOCK_DGRAM: if (msg->msg_name == NULL) { errno = EINVAL; - return -1; + goto out; } if (msg->msg_iovlen == 0) { @@ -4764,21 +4913,25 @@ static int swrap_recvmsg_before(int fd, */ if (errno == ENOTSOCK) { swrap_remove_stale(fd); - return -ENOTSOCK; + ret = -ENOTSOCK; } else { SWRAP_LOG(SWRAP_LOG_ERROR, "swrap_recvmsg_before failed"); - return -1; } + goto out; } } break; default: errno = EHOSTUNREACH; - return -1; + goto out; } - return 0; + ret = 0; +out: + SWRAP_UNLOCK_SI(si); + + return ret; } static int swrap_recvmsg_after(int fd, @@ -4810,6 +4963,8 @@ static int swrap_recvmsg_after(int fd, avail += msg->msg_iov[i].iov_len; } + SWRAP_LOCK_SI(si); + /* Convert the socket address before we leave */ if (si->type == SOCK_DGRAM && un_addr != NULL) { rc = sockaddr_convert_from_un(si, @@ -4838,6 +4993,7 @@ static int swrap_recvmsg_after(int fd, buf = (uint8_t *)malloc(remain); if (buf == NULL) { /* we just not capture the packet */ + SWRAP_UNLOCK_SI(si); errno = saved_errno; return -1; } @@ -4895,11 +5051,13 @@ done: msg->msg_control != NULL) { rc = swrap_msghdr_add_socket_info(si, msg); if (rc < 0) { + SWRAP_UNLOCK_SI(si); return -1; } } #endif + SWRAP_UNLOCK_SI(si); return rc; } @@ -5071,11 +5229,16 @@ static ssize_t swrap_sendto(int s, const void *buf, size_t len, int flags, un_addr.sa_socklen); } + SWRAP_LOCK_SI(si); + swrap_pcap_dump_packet(si, to, SWRAP_SENDTO, buf, len); + SWRAP_UNLOCK_SI(si); + return len; } + SWRAP_LOCK_SI(si); /* * If it is a dgram socket and we are connected, don't include the * 'to' address. @@ -5096,6 +5259,8 @@ static ssize_t swrap_sendto(int s, const void *buf, size_t len, int flags, msg.msg_namelen); } + SWRAP_UNLOCK_SI(si); + swrap_sendmsg_after(s, si, &msg, to, ret); return ret; @@ -5427,6 +5592,8 @@ static ssize_t swrap_recvmsg(int s, struct msghdr *omsg, int flags) #endif omsg->msg_iovlen = msg.msg_iovlen; + SWRAP_LOCK_SI(si); + /* * From the manpage: * @@ -5446,6 +5613,8 @@ static ssize_t swrap_recvmsg(int s, struct msghdr *omsg, int flags) omsg->msg_namelen = msg.msg_namelen; } + SWRAP_UNLOCK_SI(si); + return ret; } @@ -5481,12 +5650,17 @@ static ssize_t swrap_sendmsg(int s, const struct msghdr *omsg, int flags) ZERO_STRUCT(msg); + SWRAP_LOCK_SI(si); + if (si->connected == 0) { msg.msg_name = omsg->msg_name; /* optional address */ msg.msg_namelen = omsg->msg_namelen; /* size of address */ } msg.msg_iov = omsg->msg_iov; /* scatter/gather array */ msg.msg_iovlen = omsg->msg_iovlen; /* # elements in msg_iov */ + + SWRAP_UNLOCK_SI(si); + #ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL if (msg.msg_controllen > 0 && msg.msg_control != NULL) { /* omsg is a const so use a local buffer for modifications */ @@ -5552,9 +5726,13 @@ static ssize_t swrap_sendmsg(int s, const struct msghdr *omsg, int flags) libc_sendmsg(s, &msg, flags); } + SWRAP_LOCK_SI(si); + swrap_pcap_dump_packet(si, to, SWRAP_SENDTO, buf, len); free(buf); + SWRAP_UNLOCK_SI(si); + return len; } @@ -5696,6 +5874,9 @@ static int swrap_close(int fd) si_index = fi->si_index; si = swrap_get_socket_info(si_index); + SWRAP_LOCK(first_free); + SWRAP_LOCK_SI(si); + SWRAP_DLIST_REMOVE(socket_fds, fi); ret = libc_close(fd); @@ -5709,7 +5890,7 @@ static int swrap_close(int fd) if (swrap_get_refcount(si) > 0) { /* there are still references left */ - return ret; + goto out; } if (si->myname.sa_socklen > 0 && si->peername.sa_socklen > 0) { @@ -5725,6 +5906,10 @@ static int swrap_close(int fd) unlink(si->un_addr.sun_path); } +out: + SWRAP_UNLOCK_SI(si); + SWRAP_UNLOCK(first_free); + return ret; } @@ -5763,9 +5948,13 @@ static int swrap_dup(int fd) return -1; } + SWRAP_LOCK_SI(si); + swrap_inc_refcount(si); fi->si_index = src_fi->si_index; + SWRAP_UNLOCK_SI(si); + /* Make sure we don't have an entry for the fd */ swrap_remove_stale(fi->fd); @@ -5824,9 +6013,13 @@ static int swrap_dup2(int fd, int newfd) return -1; } + SWRAP_LOCK_SI(si); + swrap_inc_refcount(si); fi->si_index = src_fi->si_index; + SWRAP_UNLOCK_SI(si); + /* Make sure we don't have an entry for the fd */ swrap_remove_stale(fi->fd); @@ -5872,9 +6065,13 @@ static int swrap_vfcntl(int fd, int cmd, va_list va) return -1; } + SWRAP_LOCK_SI(si); + swrap_inc_refcount(si); fi->si_index = src_fi->si_index; + SWRAP_UNLOCK_SI(si); + /* Make sure we don't have an entry for the fd */ swrap_remove_stale(fi->fd); -- cgit v1.2.3