Просмотр исходного кода

[LibOS] Allow MSG_PEEK on recv()

Previously, Graphene failed if recv() was called with the MSG_PEEK flag. This
caused many TLS-based applications to fail, including Nginx, Apache, and
Lighttpd in SSL/TLS mode. This commit adds emulation of MSG_PEEK at the LibOS
level. A simple TCP test case is provided.
Dmitrii Kuvaiskii 4 года назад
Родитель
Коммит
5784d97375

+ 11 - 1
LibOS/shim/include/shim_handle.h

@@ -154,6 +154,8 @@ struct shim_pipe_handle {
 #define AF_INET  PF_INET
 #define AF_INET6 PF_INET6
 
+#define SOCK_URI_SIZE 108
+
 enum shim_sock_state {
     SOCK_CREATED,
     SOCK_BOUND,
@@ -202,7 +204,15 @@ struct shim_sock_handle {
         int optname;
         int optlen;
         char optval[];
-    } * pending_options;
+    }* pending_options;
+
+    /* Per-socket buffer used to emulate MSG_PEEK: it keeps bytes that were already read from the
+     * underlying stream but have not yet been consumed by a non-peeking receive. The unread data
+     * occupies `buf[start .. end)`.
+     * NOTE(review): users appear to rely on `start <= end <= size`; nothing here enforces it. */
+    struct shim_peek_buffer {
+        size_t size;             /* total size (capacity) of buffer `buf` */
+        size_t start;            /* beginning of buffered but yet unread data in `buf` */
+        size_t end;              /* end of buffered but yet unread data in `buf` */
+        char uri[SOCK_URI_SIZE]; /* cached URI for recvfrom(udp_socket) case */
+        char buf[];              /* peek buffer of size `size` */
+    }* peek_buffer;
 };
 
 struct shim_dirent {

+ 8 - 0
LibOS/shim/include/shim_types.h

@@ -332,6 +332,14 @@ struct __kernel_ustat
   };
 
 /* bits/socket.h */
+/* Message flags accepted in the `flags` argument of the recv() family. Each enumerator is also
+ * defined as a macro expanding to itself so that preprocessor checks such as `#ifdef MSG_PEEK`
+ * see the name as defined. */
+enum
+{
+    MSG_OOB  = 0x01, /* Process out-of-band data. */
+    MSG_PEEK = 0x02, /* Peek at incoming messages. */
+#define MSG_OOB MSG_OOB
+#define MSG_PEEK MSG_PEEK
+};
+
 struct msghdr {
     void *msg_name;         /* Address to send to/receive from.  */
     socklen_t msg_namelen;  /* Length of address data.  */

+ 11 - 0
LibOS/shim/src/bookkeep/shim_handle.c

@@ -466,6 +466,11 @@ void put_handle(struct shim_handle* hdl) {
         } else {
             if (hdl->fs && hdl->fs->fs_ops && hdl->fs->fs_ops->close)
                 hdl->fs->fs_ops->close(hdl);
+
+            if (hdl->type == TYPE_SOCK && hdl->info.sock.peek_buffer) {
+                free(hdl->info.sock.peek_buffer);
+                hdl->info.sock.peek_buffer = NULL;
+            }
         }
 
         delete_from_epoll_handles(hdl);
@@ -742,6 +747,12 @@ BEGIN_CP_FUNC(handle) {
         if (hdl->type == TYPE_EPOLL)
             DO_CP(epoll_item, &hdl->info.epoll.fds, &new_hdl->info.epoll.fds);
 
+        if (hdl->type == TYPE_SOCK) {
+            /* no support for multiple processes sharing options/peek buffer of the socket */
+            new_hdl->info.sock.pending_options = NULL;
+            new_hdl->info.sock.peek_buffer     = NULL;
+        }
+
         INIT_LISTP(&new_hdl->epolls);
 
         unlock(&hdl->lock);

+ 6 - 0
LibOS/shim/src/sys/shim_open.c

@@ -65,6 +65,12 @@ size_t shim_do_read (int fd, void * buf, size_t count)
     if (!hdl)
         return -EBADF;
 
+    /* sockets may read from LibOS buffer due to MSG_PEEK, so need to call socket-specific recv */
+    if (hdl->type == TYPE_SOCK) {
+        put_handle(hdl);
+        return shim_do_recvfrom(fd, buf, count, 0, NULL, NULL);
+    }
+
     int ret = do_handle_read(hdl, buf, count);
     put_handle(hdl);
     return ret;

+ 106 - 18
LibOS/shim/src/sys/shim_socket.c

@@ -56,8 +56,6 @@
 
 #define AF_UNSPEC 0
 
-#define SOCK_URI_SIZE 108
-
 static int rebase_on_lo __attribute_migratable = -1;
 
 static size_t minimal_addrlen(int domain) {
@@ -1127,9 +1125,8 @@ ssize_t shim_do_sendmmsg(int sockfd, struct mmsghdr* msg, size_t vlen, int flags
 
 static ssize_t do_recvmsg(int fd, struct iovec* bufs, int nbufs, int flags, struct sockaddr* addr,
                           socklen_t* addrlen) {
-    /* TODO handle flags properly. For now, explicitly return an error. */
-    if (flags) {
-        debug("recvmsg()/recvmmsg()/recvfrom(): flags parameter unsupported.\n");
+    if (flags & ~MSG_PEEK) {
+        debug("recvmsg()/recvmmsg()/recvfrom(): unknown flag (only MSG_PEEK is supported).\n");
         return -EOPNOTSUPP;
     }
 
@@ -1137,6 +1134,7 @@ static ssize_t do_recvmsg(int fd, struct iovec* bufs, int nbufs, int flags, stru
     if (!hdl)
         return -EBADF;
 
+    struct shim_peek_buffer* peek_buffer = NULL;
     int ret = -ENOTSOCK;
     if (hdl->type != TYPE_SOCK)
         goto out;
@@ -1159,13 +1157,16 @@ static ssize_t do_recvmsg(int fd, struct iovec* bufs, int nbufs, int flags, stru
     if (!bufs || test_user_memory(bufs, sizeof(*bufs) * nbufs, false))
         goto out;
 
+    size_t expected_size = 0;
     for (int i = 0; i < nbufs; i++) {
         if (!bufs[i].iov_base || test_user_memory(bufs[i].iov_base, bufs[i].iov_len, true))
             goto out;
+        expected_size += bufs[i].iov_len;
     }
 
     lock(&hdl->lock);
-
+    peek_buffer        = sock->peek_buffer;
+    sock->peek_buffer  = NULL;
     PAL_HANDLE pal_hdl = hdl->pal_handle;
     char* uri          = NULL;
 
@@ -1192,19 +1193,77 @@ static ssize_t do_recvmsg(int fd, struct iovec* bufs, int nbufs, int flags, stru
 
     unlock(&hdl->lock);
 
+    if (flags & MSG_PEEK) {
+        if (!peek_buffer) {
+            /* create new peek buffer with expected read size */
+            peek_buffer = malloc(sizeof(*peek_buffer) + expected_size);
+            if (!peek_buffer) {
+                ret = -ENOMEM;
+                lock(&hdl->lock);
+                goto out_locked;
+            }
+            peek_buffer->size  = expected_size;
+            peek_buffer->start = 0;
+            peek_buffer->end   = 0;
+        } else {
+            /* realloc peek buffer to accommodate expected read size */
+            if (expected_size > peek_buffer->size - peek_buffer->start) {
+                size_t expand = expected_size - (peek_buffer->size - peek_buffer->start);
+                struct shim_peek_buffer* old_peek_buffer = peek_buffer;
+                peek_buffer = malloc(sizeof(*peek_buffer) + old_peek_buffer->size + expand);
+                if (!peek_buffer) {
+                    ret = -ENOMEM;
+                    lock(&hdl->lock);
+                    goto out_locked;
+                }
+                memcpy(peek_buffer, old_peek_buffer, sizeof(*peek_buffer) + old_peek_buffer->size);
+                peek_buffer->size += expand;
+                free(old_peek_buffer);
+            }
+        }
+
+        if (expected_size > peek_buffer->end - peek_buffer->start) {
+            /* fill peek buffer if this MSG_PEEK read request cannot be satisfied with data already
+             * present in peek buffer; note that buffer can hold expected read size at this point */
+            size_t left_to_read = expected_size - (peek_buffer->end - peek_buffer->start);
+            PAL_NUM pal_ret = DkStreamRead(pal_hdl, /*offset=*/0, left_to_read,
+                                           &peek_buffer->buf[peek_buffer->end],
+                                           uri, uri ? SOCK_URI_SIZE : 0);
+            if (pal_ret == PAL_STREAM_ERROR) {
+                ret = (PAL_NATIVE_ERRNO == PAL_ERROR_STREAMNOTEXIST) ? -ECONNABORTED : -PAL_ERRNO;
+                lock(&hdl->lock);
+                goto out_locked;
+            }
+
+            peek_buffer->end += pal_ret;
+            if (uri)
+                memcpy(peek_buffer->uri, uri, SOCK_URI_SIZE);
+        }
+    }
+
+    ret = 0;
+
     bool address_received = false;
-    int bytes             = 0;
-    ret                   = 0;
+    size_t total_bytes    = 0;
 
     for (int i = 0; i < nbufs; i++) {
-        PAL_NUM pal_ret = DkStreamRead(pal_hdl, 0, bufs[i].iov_len, bufs[i].iov_base, uri, uri ? SOCK_URI_SIZE : 0);
-
-        if (pal_ret == PAL_STREAM_ERROR) {
-            ret = (PAL_NATIVE_ERRNO == PAL_ERROR_STREAMNOTEXIST) ? -ECONNABORTED : -PAL_ERRNO;
-            break;
+        size_t iov_bytes = 0;
+        if (peek_buffer) {
+            /* some data left to read from peek buffer */
+            assert(total_bytes < peek_buffer->end - peek_buffer->start);
+            iov_bytes = MIN(bufs[i].iov_len, peek_buffer->end - peek_buffer->start - total_bytes);
+            memcpy(bufs[i].iov_base, &peek_buffer->buf[peek_buffer->start + total_bytes], iov_bytes);
+            uri = peek_buffer->uri;
+        } else {
+            PAL_NUM pal_ret = DkStreamRead(pal_hdl, 0, bufs[i].iov_len, bufs[i].iov_base, uri, uri ? SOCK_URI_SIZE : 0);
+            if (pal_ret == PAL_STREAM_ERROR) {
+                ret = (PAL_NATIVE_ERRNO == PAL_ERROR_STREAMNOTEXIST) ? -ECONNABORTED : -PAL_ERRNO;
+                break;
+            }
+            iov_bytes = pal_ret;
         }
 
-        bytes += pal_ret;
+        total_bytes += iov_bytes;
 
         if (addr && !address_received) {
             if (sock->domain == AF_UNIX) {
@@ -1238,23 +1297,52 @@ static ssize_t do_recvmsg(int fd, struct iovec* bufs, int nbufs, int flags, stru
 
         /* gap in iovecs is not allowed, return a partial read to user; it is the responsibility of
          * user application to deal with partial reads */
-        if (pal_ret < bufs[i].iov_len)
+        if (iov_bytes < bufs[i].iov_len)
+            break;
+
+        /* we read from peek_buffer and exhausted it, return a partial read to user; it is the
+         * responsibility of user application to deal with partial reads */
+        if (peek_buffer && total_bytes == peek_buffer->end - peek_buffer->start)
             break;
     }
 
-    if (bytes)
-        ret = bytes;
+    if (total_bytes)
+        ret = total_bytes;
     if (ret < 0) {
         lock(&hdl->lock);
         goto out_locked;
     }
+
+    if (!(flags & MSG_PEEK) && peek_buffer) {
+        /* we read from peek buffer without MSG_PEEK, need to "remove" this read data */
+        peek_buffer->start += total_bytes;
+        if (peek_buffer->start == peek_buffer->end) {
+            /* we may have exhausted peek buffer, free it to not leak memory */
+            free(peek_buffer);
+            peek_buffer = NULL;
+        }
+    }
+
+    if (peek_buffer) {
+        /* there is non-exhausted peek buffer for this socket, update socket's data */
+        lock(&hdl->lock);
+
+        /* we assume it is impossible for other thread to update this socket's peek buffer (i.e.,
+         * only single thread works on a particular socket); if some real-world program actually has
+         * two threads working on one socket, then we need to fix "grab the lock twice" logic */
+        assert(!sock->peek_buffer);
+
+        sock->peek_buffer = peek_buffer;
+        unlock(&hdl->lock);
+    }
+
     goto out;
 
 out_locked:
     if (ret < 0)
         sock->error = -ret;
-
     unlock(&hdl->lock);
+    free(peek_buffer);
 out:
     put_handle(hdl);
     return ret;

+ 16 - 8
LibOS/shim/test/regression/poll_many_types.c

@@ -1,3 +1,4 @@
+#include <errno.h>
 #include <fcntl.h>
 #include <poll.h>
 #include <stdio.h>
@@ -14,27 +15,34 @@ int main(int argc, char** argv) {
     int pipefd[2];
     ret = pipe(pipefd);
     if (ret < 0) {
-        perror("pipe creation failed");
+        perror("pipe creation");
         return 1;
     }
+
     /* write something into write end of pipe so read end becomes pollable */
-    ret = write(pipefd[1], string, (strlen(string) + 1));
-    if (ret < 0) {
-        perror("pipe write failed");
-        return 1;
+    ssize_t written = 0;
+    while (written < sizeof(string)) {
+        ssize_t n;
+        if ((n = write(pipefd[1], string + written, sizeof(string) - written)) < 0) {
+            if (errno == EINTR || errno == EAGAIN)
+                continue;
+            perror("pipe write");
+            return 1;
+        }
+        written += n;
     }
 
     /* type 2: regular file */
     int filefd = open(argv[0], O_RDONLY);
     if (filefd < 0) {
-        perror("file open failed");
+        perror("file open");
         return 1;
     }
 
     /* type 3: dev file */
     int devfd = open("/dev/urandom", O_RDONLY);
     if (devfd < 0) {
-        perror("dev/urandom open failed");
+        perror("dev/urandom open");
         return 1;
     }
 
@@ -46,7 +54,7 @@ int main(int argc, char** argv) {
 
     ret = poll(infds, 3, -1);
     if (ret <= 0) {
-        perror("poll with POLLIN failed");
+        perror("poll with POLLIN");
         return 1;
     }
     printf("poll(POLLIN) returned %d file descriptors\n", ret);

+ 219 - 0
LibOS/shim/test/regression/tcp_msg_peek.c

@@ -0,0 +1,219 @@
+#include <arpa/inet.h>
+#include <errno.h>
+#include <fcntl.h>
+#include <netinet/in.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <unistd.h>
+
+#define SRV_IP "127.0.0.1"
+#define PORT 11111
+#define BUFLEN 512
+
+enum { SINGLE, PARALLEL } mode = PARALLEL;
+int pipefds[2];
+
+/* Server side of the test. Listens on PORT on all interfaces, tells the client (over the pipe,
+ * PARALLEL mode only) that the listening socket is ready, accepts a single connection, sends a
+ * greeting and closes the connection. Any failure prints a diagnostic and exits with status 1. */
+void server(void) {
+    int listening_socket, client_socket;
+    struct sockaddr_in address;
+    socklen_t addrlen;
+
+    if ((listening_socket = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
+        perror("socket");
+        exit(1);
+    }
+
+    int enable = 1;
+    if (setsockopt(listening_socket, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable)) < 0) {
+        perror("setsockopt");
+        exit(1);
+    }
+
+    memset(&address, 0, sizeof(address));
+    address.sin_family      = AF_INET;
+    address.sin_port        = htons(PORT);
+    address.sin_addr.s_addr = htonl(INADDR_ANY);
+
+    if (bind(listening_socket, (struct sockaddr*)&address, sizeof(address)) < 0) {
+        perror("bind");
+        exit(1);
+    }
+
+    if (listen(listening_socket, 3) < 0) {
+        perror("listen");
+        exit(1);
+    }
+
+    if (mode == PARALLEL) {
+        /* the server only writes to the pipe, so drop the read end */
+        if (close(pipefds[0]) < 0) {
+            perror("close of pipe");
+            exit(1);
+        }
+
+        char byte = 0;
+
+        /* notify the client that it may connect; a failed write() leaves -1 in `written`, so the
+         * loop condition must also cover negative values, otherwise `continue` on EINTR/EAGAIN
+         * would leave the loop without ever writing the byte */
+        ssize_t written = 0;
+        while (written <= 0) {
+            if ((written = write(pipefds[1], &byte, sizeof(byte))) < 0) {
+                if (errno == EINTR || errno == EAGAIN)
+                    continue;
+                perror("write on pipe");
+                exit(1);
+            }
+        }
+    }
+
+    addrlen       = sizeof(address);
+    client_socket = accept(listening_socket, (struct sockaddr*)&address, &addrlen);
+
+    if (client_socket < 0) {
+        perror("accept");
+        exit(1);
+    }
+
+    if (close(listening_socket) < 0) {
+        perror("close of listening socket");
+        exit(1);
+    }
+
+    puts("[server] client is connected...");
+
+    char buffer[] = "Hello from server!\n";
+
+    /* send the whole array (including the terminating NUL), coping with partial sends; the
+     * counter is unsigned so that the comparison with sizeof() does not mix signedness */
+    size_t sent = 0;
+    while (sent < sizeof(buffer)) {
+        ssize_t n;
+        if ((n = sendto(client_socket, buffer + sent, sizeof(buffer) - sent, 0, 0, 0)) < 0) {
+            if (errno == EINTR || errno == EAGAIN)
+                continue;
+            perror("sendto to client");
+            exit(1);
+        }
+        sent += (size_t)n;
+    }
+
+    if (close(client_socket) < 0) {
+        perror("close of client socket");
+        exit(1);
+    }
+
+    puts("[server] done");
+}
+
+/* Receives from `server_socket` into `buf` (capacity `len`) and returns the number of bytes
+ * stored. Without MSG_PEEK the socket is drained until recv() reports 0 bytes; with MSG_PEEK a
+ * single successful recv() is performed. Exits the process on an unrecoverable recv() error. */
+static ssize_t client_recv(int server_socket, char* buf, size_t len, int flags) {
+    ssize_t total = 0;
+
+    for (;;) {
+        ssize_t chunk = recv(server_socket, buf + total, len - total, flags);
+        if (chunk < 0) {
+            if (errno == EINTR || errno == EAGAIN)
+                continue;
+            perror("client recv");
+            exit(1);
+        }
+
+        total += chunk;
+
+        /* stop once nothing more arrives; a peeking recv is issued only once since peeking does
+         * not consume data and repeating it would deliver the same bytes again */
+        if (chunk == 0 || (flags & MSG_PEEK))
+            break;
+    }
+
+    return total;
+}
+
+/* Client side of the test. In PARALLEL mode it first waits for the server's readiness byte on
+ * the pipe, then connects to SRV_IP:PORT, receives the greeting once with MSG_PEEK and once
+ * without, and finally reports how many bytes are still unread. Any failure prints a diagnostic
+ * and exits with status 1. */
+void client(void) {
+    int server_socket;
+    struct sockaddr_in address;
+    char buffer[BUFLEN];
+    ssize_t count;
+
+    if (mode == PARALLEL) {
+        /* the client only reads from the pipe, so drop the write end */
+        if (close(pipefds[1]) < 0) {
+            perror("close of pipe");
+            exit(1);
+        }
+
+        char byte = 0;
+
+        /* wait for the readiness byte, retrying only on transient errors */
+        ssize_t received;
+        do {
+            received = read(pipefds[0], &byte, sizeof(byte));
+        } while (received < 0 && (errno == EINTR || errno == EAGAIN));
+
+        if (received < 0) {
+            perror("read on pipe");
+            exit(1);
+        }
+        if (received == 0) {
+            /* end-of-file: the write end was closed without sending the byte (e.g., the server
+             * exited early); retrying would spin forever */
+            fprintf(stderr, "read on pipe: closed before server signaled readiness\n");
+            exit(1);
+        }
+    }
+
+    if ((server_socket = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
+        perror("socket");
+        exit(1);
+    }
+
+    memset(&address, 0, sizeof(address));
+    address.sin_family = AF_INET;
+    address.sin_port   = htons((PORT));
+    if (inet_aton(SRV_IP, &address.sin_addr) == 0) {
+        perror("inet_aton");
+        exit(1);
+    }
+
+    if (connect(server_socket, (struct sockaddr*)&address, sizeof(address)) < 0) {
+        perror("connect");
+        exit(1);
+    }
+
+    printf("[client] receiving with MSG_PEEK: ");
+    count = client_recv(server_socket, buffer, sizeof(buffer), MSG_PEEK);
+    fwrite(buffer, count, 1, stdout);
+
+    printf("[client] receiving without MSG_PEEK: ");
+    count = client_recv(server_socket, buffer, sizeof(buffer), 0);
+    fwrite(buffer, count, 1, stdout);
+
+    printf("[client] checking how many bytes are left unread: ");
+    count = client_recv(server_socket, buffer, sizeof(buffer), 0);
+    /* `count` is signed (ssize_t), so it must be printed with %zd rather than %zu */
+    printf("%zd\n", count);
+
+    if (close(server_socket) < 0) {
+        perror("close of server socket");
+        exit(1);
+    }
+
+    puts("[client] done");
+}
+
+/* Entry point. With argument "client" or "server" only that side runs (SINGLE mode). Without
+ * arguments the process forks: the child runs the client, the parent runs the server, and a pipe
+ * is used so the client connects only after the server is listening. Returns 1 if the pipe or
+ * the child process cannot be created, 0 otherwise. */
+int main(int argc, char** argv) {
+    if (argc > 1) {
+        if (strcmp(argv[1], "client") == 0) {
+            mode = SINGLE;
+            client();
+            return 0;
+        }
+
+        if (strcmp(argv[1], "server") == 0) {
+            mode = SINGLE;
+            server();
+            return 0;
+        }
+    } else {
+        if (pipe(pipefds) < 0) {
+            perror("pipe");
+            return 1;
+        }
+
+        int pid = fork();
+        if (pid < 0) {
+            /* without this check a failed fork() would run the server with no client */
+            perror("fork");
+            return 1;
+        }
+
+        if (pid == 0) {
+            client();
+        } else {
+            server();
+        }
+    }
+
+    return 0;
+}

+ 8 - 0
LibOS/shim/test/regression/test_libos.py

@@ -473,3 +473,11 @@ class TC_80_Socket(RegressionTestCase):
         self.assertIn('Data: This is packet 7', stdout)
         self.assertIn('Data: This is packet 8', stdout)
         self.assertIn('Data: This is packet 9', stdout)
+
+    # Runs the `tcp_msg_peek` regression binary and checks its stdout: the client must print the
+    # server's greeting both for the MSG_PEEK receive and for the following normal receive,
+    # report 0 bytes left unread afterwards, and both client and server must print "done".
+    def test_300_socket_tcp_msg_peek(self):
+        stdout, _ = self.run_binary(['tcp_msg_peek'], timeout=50)
+        self.assertIn('[client] receiving with MSG_PEEK: Hello from server!', stdout)
+        self.assertIn('[client] receiving without MSG_PEEK: Hello from server!', stdout)
+        self.assertIn('[client] checking how many bytes are left unread: 0', stdout)
+        self.assertIn('[client] done', stdout)
+        self.assertIn('[server] done', stdout)