diff options
Diffstat (limited to 'src/socket_wrapper.c')
-rw-r--r-- | src/socket_wrapper.c | 273 |
1 files changed, 235 insertions, 38 deletions
diff --git a/src/socket_wrapper.c b/src/socket_wrapper.c index 9c16e1a..6b67224 100644 --- a/src/socket_wrapper.c +++ b/src/socket_wrapper.c @@ -188,7 +188,17 @@ enum swrap_dbglvl_e { #define SOCKET_INFO_CONTAINER(si) \ (struct socket_info_container *)(si) -#define SWRAP_DLIST_ADD(list,item) do { \ +#define SWRAP_LOCK_SI(si) do { \ + struct socket_info_container *sic = SOCKET_INFO_CONTAINER(si); \ + pthread_mutex_lock(&sic->mutex); \ +} while(0) + +#define SWRAP_UNLOCK_SI(si) do { \ + struct socket_info_container *sic = SOCKET_INFO_CONTAINER(si); \ + pthread_mutex_unlock(&sic->mutex); \ +} while(0) + +#define DLIST_ADD(list, item) do { \ if (!(list)) { \ (item)->prev = NULL; \ (item)->next = NULL; \ @@ -201,7 +211,13 @@ enum swrap_dbglvl_e { } \ } while (0) -#define SWRAP_DLIST_REMOVE(list,item) do { \ +#define SWRAP_DLIST_ADD(list, item) do { \ + SWRAP_LOCK(list); \ + DLIST_ADD(list, item); \ + SWRAP_UNLOCK(list); \ +} while (0) + +#define DLIST_REMOVE(list, item) do { \ if ((list) == (item)) { \ (list) = (item)->next; \ if (list) { \ @@ -219,10 +235,15 @@ enum swrap_dbglvl_e { (item)->next = NULL; \ } while (0) -#define SWRAP_DLIST_ADD_AFTER(list, item, el) \ -do { \ +#define SWRAP_DLIST_REMOVE(list,item) do { \ + SWRAP_LOCK(list); \ + DLIST_REMOVE(list, item); \ + SWRAP_UNLOCK(list); \ +} while (0) + +#define DLIST_ADD_AFTER(list, item, el) do { \ if ((list) == NULL || (el) == NULL) { \ - SWRAP_DLIST_ADD(list, item); \ + DLIST_ADD(list, item); \ } else { \ (item)->prev = (el); \ (item)->next = (el)->next; \ @@ -233,6 +254,12 @@ do { \ } \ } while (0) +#define SWRAP_DLIST_ADD_AFTER(list, item, el) do { \ + SWRAP_LOCK(list); \ + DLIST_ADD_AFTER(list, item, el); \ + SWRAP_UNLOCK(list); \ +} while (0) + #if defined(HAVE_GETTIMEOFDAY_TZ) || defined(HAVE_GETTIMEOFDAY_TZ_VOID) #define swrapGetTimeOfDay(tval) gettimeofday(tval,NULL) #else @@ -333,6 +360,7 @@ struct socket_info_container struct socket_info info; unsigned int refcount; int next_free; + pthread_mutex_t mutex; }; static struct socket_info_container *sockets; @@ -348,6 +376,26 @@ static struct socket_info_fd *socket_fds; /* The mutex for accessing the global libc.symbols */ static pthread_mutex_t libc_symbol_binding_mutex = PTHREAD_MUTEX_INITIALIZER; +/* The mutex for syncronizing the port selection during swrap_auto_bind() */ +static pthread_mutex_t autobind_start_mutex = PTHREAD_MUTEX_INITIALIZER; + +/* + * Global mutex to guard the initialization of array of socket_info structures. + */ +static pthread_mutex_t sockets_mutex = PTHREAD_MUTEX_INITIALIZER; + +/* + * Global mutex to protect modification of the socket_fds linked + * list structure by different threads within a process. + */ +static pthread_mutex_t socket_fds_mutex = PTHREAD_MUTEX_INITIALIZER; + +/* + * Global mutex to synchronize the query for first free index in array of + * socket_info structures by different threads within a process. + */ +static pthread_mutex_t first_free_mutex = PTHREAD_MUTEX_INITIALIZER; + /* Function prototypes */ bool socket_wrapper_enabled(void); @@ -1321,7 +1369,10 @@ static void socket_wrapper_init_sockets(void) { size_t i; + SWRAP_LOCK(sockets); + if (sockets != NULL) { + SWRAP_UNLOCK(sockets); return; } @@ -1333,17 +1384,24 @@ static void socket_wrapper_init_sockets(void) if (sockets == NULL) { SWRAP_LOG(SWRAP_LOG_ERROR, "Failed to allocate sockets array.\n"); + SWRAP_UNLOCK(sockets); exit(-1); } + SWRAP_LOCK(first_free); + first_free = 0; for (i = 0; i < max_sockets; i++) { swrap_set_next_free(&sockets[i].info, i+1); + sockets[i].mutex = (pthread_mutex_t)PTHREAD_MUTEX_INITIALIZER; } /* mark the end of the free list */ swrap_set_next_free(&sockets[max_sockets-1].info, -1); + + SWRAP_UNLOCK(first_free); + SWRAP_UNLOCK(sockets); } bool socket_wrapper_enabled(void) @@ -1377,25 +1435,33 @@ static unsigned int socket_wrapper_default_iface(void) static int swrap_add_socket_info(struct socket_info *si_input) { struct socket_info *si = NULL; - int si_index; + int si_index = -1; if (si_input == NULL) { errno = EINVAL; return -1; } + SWRAP_LOCK(first_free); if (first_free == -1) { errno = ENFILE; - return -1; + goto out; } si_index = first_free; si = swrap_get_socket_info(si_index); + SWRAP_LOCK_SI(si); + first_free = swrap_get_next_free(si); *si = *si_input; swrap_inc_refcount(si); + SWRAP_UNLOCK_SI(si); + +out: + SWRAP_UNLOCK(first_free); + return si_index; } @@ -1795,13 +1861,17 @@ static struct socket_info_fd *find_socket_info_fd(int fd) { struct socket_info_fd *f; + SWRAP_LOCK(socket_fds); + for (f = socket_fds; f; f = f->next) { if (f->fd == fd) { - return f; + break; } } - return NULL; + SWRAP_UNLOCK(socket_fds); + + return f; } static int find_socket_info_index(int fd) @@ -1934,6 +2004,9 @@ static void swrap_remove_stale(int fd) si = swrap_get_socket_info(si_index); + SWRAP_LOCK(first_free); + SWRAP_LOCK_SI(si); + SWRAP_DLIST_REMOVE(socket_fds, fi); swrap_set_next_free(si, first_free); @@ -1943,12 +2016,16 @@ static void swrap_remove_stale(int fd) free(fi); if (swrap_get_refcount(si) > 0) { - return; + goto out; } if (si->un_addr.sun_path[0] != '\0') { unlink(si->un_addr.sun_path); } + +out: + SWRAP_UNLOCK_SI(si); + SWRAP_UNLOCK(first_free); } static int sockaddr_convert_to_un(struct socket_info *si, @@ -3101,16 +3178,26 @@ static int swrap_accept(int s, #endif } + + /* + * prevent parent_si from being altered / closed + * while we read it + */ + SWRAP_LOCK_SI(parent_si); + /* * assume out sockaddr have the same size as the in parent * socket family */ in_addr.sa_socklen = socket_length(parent_si->family); if (in_addr.sa_socklen <= 0) { + SWRAP_UNLOCK_SI(parent_si); errno = EINVAL; return -1; } + SWRAP_UNLOCK_SI(parent_si); + #ifdef HAVE_ACCEPT4 ret = libc_accept4(s, &un_addr.sa.s, &un_addr.sa_socklen, flags); #else @@ -3127,6 +3214,8 @@ static int swrap_accept(int s, fd = ret; + SWRAP_LOCK_SI(parent_si); + ret = sockaddr_convert_from_un(parent_si, &un_addr.sa.un, un_addr.sa_socklen, @@ -3134,6 +3223,7 @@ static int swrap_accept(int s, &in_addr.sa.s, &in_addr.sa_socklen); if (ret == -1) { + SWRAP_UNLOCK_SI(parent_si); close(fd); return ret; } @@ -3147,6 +3237,8 @@ static int swrap_accept(int s, child_si->is_server = 1; child_si->connected = 1; + SWRAP_UNLOCK_SI(parent_si); + child_si->peername = (struct swrap_address) { .sa_socklen = in_addr.sa_socklen, }; @@ -3197,9 +3289,11 @@ static int swrap_accept(int s, if (addr != NULL) { struct socket_info *si = swrap_get_socket_info(idx); + SWRAP_LOCK_SI(si); swrap_pcap_dump_packet(si, addr, SWRAP_ACCEPT_SEND, NULL, 0); swrap_pcap_dump_packet(si, addr, SWRAP_ACCEPT_RECV, NULL, 0); swrap_pcap_dump_packet(si, addr, SWRAP_ACCEPT_ACK, NULL, 0); + SWRAP_UNLOCK_SI(si); } return fd; @@ -3241,6 +3335,8 @@ static int swrap_auto_bind(int fd, struct socket_info *si, int family) int port; struct stat st; + SWRAP_LOCK(autobind_start); + if (autobind_start_init != 1) { autobind_start_init = 1; autobind_start = getpid(); @@ -3359,6 +3455,7 @@ static int swrap_auto_bind(int fd, struct socket_info *si, int family) ret = 0; done: + SWRAP_UNLOCK(autobind_start); return ret; } @@ -3380,6 +3477,8 @@ static int swrap_connect(int s, const struct sockaddr *serv_addr, return libc_connect(s, serv_addr, addrlen); } + SWRAP_LOCK_SI(si); + if (si->bound == 0) { ret = swrap_auto_bind(s, si, serv_addr->sa_family); if (ret == -1) { @@ -3463,6 +3562,7 @@ static int swrap_connect(int s, const struct sockaddr *serv_addr, } done: + SWRAP_UNLOCK_SI(si); return ret; } @@ -3491,6 +3591,8 @@ static int swrap_bind(int s, const struct sockaddr *myaddr, socklen_t addrlen) return libc_bind(s, myaddr, addrlen); } + SWRAP_LOCK_SI(si); + switch (si->family) { case AF_INET: { const struct sockaddr_in *sin; @@ -3538,14 +3640,16 @@ static int swrap_bind(int s, const struct sockaddr *myaddr, socklen_t addrlen) if (bind_error != 0) { errno = bind_error; - return -1; + ret = -1; + goto out; } #if 0 /* FIXME */ in_use = check_addr_port_in_use(myaddr, addrlen); if (in_use) { errno = EADDRINUSE; - return -1; + ret = -1; + goto out; } #endif @@ -3559,7 +3663,7 @@ static int swrap_bind(int s, const struct sockaddr *myaddr, socklen_t addrlen) 1, &si->bcast); if (ret == -1) { - return -1; + goto out; } unlink(un_addr.sa.un.sun_path); @@ -3574,6 +3678,9 @@ static int swrap_bind(int s, const struct sockaddr *myaddr, socklen_t addrlen) si->bound = 1; } +out: + SWRAP_UNLOCK_SI(si); + return ret; } @@ -3677,16 +3784,21 @@ static int swrap_listen(int s, int backlog) return libc_listen(s, backlog); } + SWRAP_LOCK_SI(si); + if (si->bound == 0) { ret = swrap_auto_bind(s, si, si->family); if (ret == -1) { errno = EADDRINUSE; - return ret; + goto out; } } ret = libc_listen(s, backlog); +out: + SWRAP_UNLOCK_SI(si); + return ret; } @@ -3853,26 +3965,34 @@ static int swrap_getpeername(int s, struct sockaddr *name, socklen_t *addrlen) { struct socket_info *si = find_socket_info(s); socklen_t len; + int ret = -1; if (!si) { return libc_getpeername(s, name, addrlen); } + SWRAP_LOCK_SI(si); + if (si->peername.sa_socklen == 0) { errno = ENOTCONN; - return -1; + goto out; } len = MIN(*addrlen, si->peername.sa_socklen); if (len == 0) { - return 0; + ret = 0; + goto out; } memcpy(name, &si->peername.sa.ss, len); *addrlen = si->peername.sa_socklen; - return 0; + ret = 0; +out: + SWRAP_UNLOCK_SI(si); + + return ret; } #ifdef HAVE_ACCEPT_PSOCKLEN_T @@ -3892,20 +4012,28 @@ static int swrap_getsockname(int s, struct sockaddr *name, socklen_t *addrlen) { struct socket_info *si = find_socket_info(s); socklen_t len; + int ret = -1; if (!si) { return libc_getsockname(s, name, addrlen); } + SWRAP_LOCK_SI(si); + len = MIN(*addrlen, si->myname.sa_socklen); if (len == 0) { - return 0; + ret = 0; + goto out; } memcpy(name, &si->myname.sa.ss, len); *addrlen = si->myname.sa_socklen; - return 0; + ret = 0; +out: + SWRAP_UNLOCK_SI(si); + + return ret; } #ifdef HAVE_ACCEPT_PSOCKLEN_T @@ -3941,6 +4069,8 @@ static int swrap_getsockopt(int s, int level, int optname, optlen); } + SWRAP_LOCK_SI(si); + if (level == SOL_SOCKET) { switch (optname) { #ifdef SO_DOMAIN @@ -4023,6 +4153,7 @@ static int swrap_getsockopt(int s, int level, int optname, ret = -1; done: + SWRAP_UNLOCK_SI(si); return ret; } @@ -4061,6 +4192,8 @@ static int swrap_setsockopt(int s, int level, int optname, optlen); } + SWRAP_LOCK_SI(si); + if (level == IPPROTO_TCP) { switch (optname) { #ifdef TCP_NODELAY @@ -4125,6 +4258,7 @@ static int swrap_setsockopt(int s, int level, int optname, } done: + SWRAP_UNLOCK_SI(si); return ret; } @@ -4149,6 +4283,8 @@ static int swrap_vioctl(int s, unsigned long int r, va_list va) return libc_vioctl(s, r, va); } + SWRAP_LOCK_SI(si); + va_copy(ap, va); rc = libc_vioctl(s, r, va); @@ -4167,6 +4303,7 @@ static int swrap_vioctl(int s, unsigned long int r, va_list va) va_end(ap); + SWRAP_UNLOCK_SI(si); return rc; } @@ -4472,7 +4609,7 @@ static ssize_t swrap_sendmsg_before(int fd, int *bcast) { size_t i, len = 0; - ssize_t ret; + ssize_t ret = -1; if (to_un) { *to_un = NULL; @@ -4484,13 +4621,15 @@ static ssize_t swrap_sendmsg_before(int fd, *bcast = 0; } + SWRAP_LOCK_SI(si); + switch (si->type) { case SOCK_STREAM: { unsigned long mtu; if (!si->connected) { errno = ENOTCONN; - return -1; + goto out; } if (msg->msg_iovlen == 0) { @@ -4533,14 +4672,14 @@ static ssize_t swrap_sendmsg_before(int fd, if (msg_name == NULL) { errno = ENOTCONN; - return -1; + goto out; } ret = sockaddr_convert_to_un(si, msg_name, msg->msg_namelen, tmp_un, 0, bcast); if (ret == -1) { - return -1; + goto out; } if (to_un) { @@ -4558,11 +4697,11 @@ static ssize_t swrap_sendmsg_before(int fd, if (ret == -1) { if (errno == ENOTSOCK) { swrap_remove_stale(fd); - return -ENOTSOCK; + ret = -ENOTSOCK; } else { SWRAP_LOG(SWRAP_LOG_ERROR, "swrap_sendmsg_before failed"); - return -1; } + goto out; } } @@ -4577,7 +4716,7 @@ static ssize_t swrap_sendmsg_before(int fd, 0, NULL); if (ret == -1) { - return -1; + goto out; } ret = libc_connect(fd, @@ -4590,14 +4729,14 @@ static ssize_t swrap_sendmsg_before(int fd, } if (ret == -1) { - return ret; + goto out; } si->defer_connect = 0; break; default: errno = EHOSTUNREACH; - return -1; + goto out; } #ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL @@ -4608,7 +4747,7 @@ static ssize_t swrap_sendmsg_before(int fd, ret = swrap_sendmsg_filter_cmsghdr(msg, &cmbuf, &cmlen); if (ret < 0) { free(cmbuf); - return -1; + goto out; } if (cmlen == 0) { @@ -4622,7 +4761,11 @@ static ssize_t swrap_sendmsg_before(int fd, } #endif - return 0; + ret = 0; +out: + SWRAP_UNLOCK_SI(si); + + return ret; } static void swrap_sendmsg_after(int fd, @@ -4676,6 +4819,8 @@ static void swrap_sendmsg_after(int fd, } len = ofs; + SWRAP_LOCK_SI(si); + switch (si->type) { case SOCK_STREAM: if (ret == -1) { @@ -4699,6 +4844,8 @@ static void swrap_sendmsg_after(int fd, break; } + SWRAP_UNLOCK_SI(si); + free(buf); errno = saved_errno; } @@ -4709,7 +4856,9 @@ static int swrap_recvmsg_before(int fd, struct iovec *tmp_iov) { size_t i, len = 0; - ssize_t ret; + int ret = -1; + + SWRAP_LOCK_SI(si); (void)fd; /* unused */ @@ -4718,7 +4867,7 @@ static int swrap_recvmsg_before(int fd, unsigned int mtu; if (!si->connected) { errno = ENOTCONN; - return -1; + goto out; } if (msg->msg_iovlen == 0) { @@ -4746,7 +4895,7 @@ static int swrap_recvmsg_before(int fd, case SOCK_DGRAM: if (msg->msg_name == NULL) { errno = EINVAL; - return -1; + goto out; } if (msg->msg_iovlen == 0) { @@ -4764,21 +4913,25 @@ static int swrap_recvmsg_before(int fd, */ if (errno == ENOTSOCK) { swrap_remove_stale(fd); - return -ENOTSOCK; + ret = -ENOTSOCK; } else { SWRAP_LOG(SWRAP_LOG_ERROR, "swrap_recvmsg_before failed"); - return -1; } + goto out; } } break; default: errno = EHOSTUNREACH; - return -1; + goto out; } - return 0; + ret = 0; +out: + SWRAP_UNLOCK_SI(si); + + return ret; } static int swrap_recvmsg_after(int fd, @@ -4810,6 +4963,8 @@ static int swrap_recvmsg_after(int fd, avail += msg->msg_iov[i].iov_len; } + SWRAP_LOCK_SI(si); + /* Convert the socket address before we leave */ if (si->type == SOCK_DGRAM && un_addr != NULL) { rc = sockaddr_convert_from_un(si, @@ -4838,6 +4993,7 @@ static int swrap_recvmsg_after(int fd, buf = (uint8_t *)malloc(remain); if (buf == NULL) { /* we just not capture the packet */ + SWRAP_UNLOCK_SI(si); errno = saved_errno; return -1; } @@ -4895,11 +5051,13 @@ done: msg->msg_control != NULL) { rc = swrap_msghdr_add_socket_info(si, msg); if (rc < 0) { + SWRAP_UNLOCK_SI(si); return -1; } } #endif + SWRAP_UNLOCK_SI(si); return rc; } @@ -5071,11 +5229,16 @@ static ssize_t swrap_sendto(int s, const void *buf, size_t len, int flags, un_addr.sa_socklen); } + SWRAP_LOCK_SI(si); + swrap_pcap_dump_packet(si, to, SWRAP_SENDTO, buf, len); + SWRAP_UNLOCK_SI(si); + return len; } + SWRAP_LOCK_SI(si); /* * If it is a dgram socket and we are connected, don't include the * 'to' address. @@ -5096,6 +5259,8 @@ static ssize_t swrap_sendto(int s, const void *buf, size_t len, int flags, msg.msg_namelen); } + SWRAP_UNLOCK_SI(si); + swrap_sendmsg_after(s, si, &msg, to, ret); return ret; @@ -5427,6 +5592,8 @@ static ssize_t swrap_recvmsg(int s, struct msghdr *omsg, int flags) #endif omsg->msg_iovlen = msg.msg_iovlen; + SWRAP_LOCK_SI(si); + /* * From the manpage: * @@ -5446,6 +5613,8 @@ static ssize_t swrap_recvmsg(int s, struct msghdr *omsg, int flags) omsg->msg_namelen = msg.msg_namelen; } + SWRAP_UNLOCK_SI(si); + return ret; } @@ -5481,12 +5650,17 @@ static ssize_t swrap_sendmsg(int s, const struct msghdr *omsg, int flags) ZERO_STRUCT(msg); + SWRAP_LOCK_SI(si); + if (si->connected == 0) { msg.msg_name = omsg->msg_name; /* optional address */ msg.msg_namelen = omsg->msg_namelen; /* size of address */ } msg.msg_iov = omsg->msg_iov; /* scatter/gather array */ msg.msg_iovlen = omsg->msg_iovlen; /* # elements in msg_iov */ + + SWRAP_UNLOCK_SI(si); + #ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL if (msg.msg_controllen > 0 && msg.msg_control != NULL) { /* omsg is a const so use a local buffer for modifications */ @@ -5552,9 +5726,13 @@ static ssize_t swrap_sendmsg(int s, const struct msghdr *omsg, int flags) libc_sendmsg(s, &msg, flags); } + SWRAP_LOCK_SI(si); + swrap_pcap_dump_packet(si, to, SWRAP_SENDTO, buf, len); free(buf); + SWRAP_UNLOCK_SI(si); + return len; } @@ -5696,6 +5874,9 @@ static int swrap_close(int fd) si_index = fi->si_index; si = swrap_get_socket_info(si_index); + SWRAP_LOCK(first_free); + SWRAP_LOCK_SI(si); + SWRAP_DLIST_REMOVE(socket_fds, fi); ret = libc_close(fd); @@ -5709,7 +5890,7 @@ static int swrap_close(int fd) if (swrap_get_refcount(si) > 0) { /* there are still references left */ - return ret; + goto out; } if (si->myname.sa_socklen > 0 && si->peername.sa_socklen > 0) { @@ -5725,6 +5906,10 @@ static int swrap_close(int fd) unlink(si->un_addr.sun_path); } +out: + SWRAP_UNLOCK_SI(si); + SWRAP_UNLOCK(first_free); + return ret; } @@ -5763,9 +5948,13 @@ static int swrap_dup(int fd) return -1; } + SWRAP_LOCK_SI(si); + swrap_inc_refcount(si); fi->si_index = src_fi->si_index; + SWRAP_UNLOCK_SI(si); + /* Make sure we don't have an entry for the fd */ swrap_remove_stale(fi->fd); @@ -5824,9 +6013,13 @@ static int swrap_dup2(int fd, int newfd) return -1; } + SWRAP_LOCK_SI(si); + swrap_inc_refcount(si); fi->si_index = src_fi->si_index; + SWRAP_UNLOCK_SI(si); + /* Make sure we don't have an entry for the fd */ swrap_remove_stale(fi->fd); @@ -5872,9 +6065,13 @@ static int swrap_vfcntl(int fd, int cmd, va_list va) return -1; } + SWRAP_LOCK_SI(si); + swrap_inc_refcount(si); fi->si_index = src_fi->si_index; + SWRAP_UNLOCK_SI(si); + /* Make sure we don't have an entry for the fd */ swrap_remove_stale(fi->fd); |