瀏覽代碼

[LibOS] Allow all getsockopt() before bind()

Previously in Graphene, some getsockopt() syscalls, e.g. TCP_NODELAY,
failed because an underlying PAL handle was't created for the LibOS
handle until bind() was called. Thus, a sequence of accept() and
getsockopt() failed. This commit fixes this by returning default
socket options (possibly augmented with setsockopt values). Test
case is also provided.
Dmitrii Kuvaiskii 4 年之前
父節點
當前提交
eb4849d23a
共有 3 個文件被更改,包括 170 次插入135 次删除
  1. 133 103
      LibOS/shim/src/sys/shim_socket.c
  2. 35 31
      LibOS/shim/test/regression/getsockopt.c
  3. 2 1
      LibOS/shim/test/regression/test_libos.py

+ 133 - 103
LibOS/shim/src/sys/shim_socket.c

@@ -1537,62 +1537,33 @@ struct __kernel_linger {
     int l_linger;
 };
 
-static int __do_setsockopt(struct shim_handle* hdl, int level, int optname, char* optval,
-                           int optlen, PAL_STREAM_ATTR* attr) {
-    // Issue 754 - https://github.com/oscarlab/graphene/issues/754
-    __UNUSED(optlen);
-
-    int intval     = *((int*)optval);
-    PAL_BOL bolval = intval ? PAL_TRUE : PAL_FALSE;
-
-    if (level == SOL_SOCKET) {
-        switch (optname) {
-            case SO_ACCEPTCONN:
-            case SO_DOMAIN:
-            case SO_ERROR:
-            case SO_PROTOCOL:
-            case SO_TYPE:
-                return -EPERM;
-            case SO_KEEPALIVE:
-            case SO_LINGER:
-            case SO_RCVBUF:
-            case SO_SNDBUF:
-            case SO_RCVTIMEO:
-            case SO_SNDTIMEO:
-            case SO_REUSEADDR:
-                goto query;
-            default:
-                goto unknown;
-        }
-    }
-
-    if (level == SOL_TCP) {
-        switch (optname) {
-            case TCP_CORK:
-            case TCP_NODELAY:
-                goto query;
-            default:
-                goto unknown;
-        }
-    }
-
-unknown:
-    return -ENOPROTOOPT;
+static void __populate_addr_with_defaults(PAL_STREAM_ATTR* attr) {
+    /* Linux default recv/send buffer sizes for new sockets */
+    attr->socket.receivebuf     = 212992;
+    attr->socket.sendbuf        = 212992;
+
+    attr->socket.linger         = 0;
+    attr->socket.receivetimeout = 0;
+    attr->socket.sendtimeout    = 0;
+    attr->socket.tcp_cork       = PAL_FALSE;
+    attr->socket.tcp_keepalive  = PAL_FALSE;
+    attr->socket.tcp_nodelay    = PAL_FALSE;
+}
 
-query:
-    if (!attr) {
-        attr = __alloca(sizeof(PAL_STREAM_ATTR));
+static bool __update_attr(PAL_STREAM_ATTR* attr, int level, int optname, char* optval, int optlen) {
+    __UNUSED(optlen);
+    assert(attr);
 
-        if (!DkStreamAttributesQueryByHandle(hdl->pal_handle, attr))
-            return -PAL_ERRNO;
-    }
+    bool need_set_attr = false;
+    int intval         = *((int*)optval);
+    PAL_BOL bolval     = intval ? PAL_TRUE : PAL_FALSE;
 
     if (level == SOL_SOCKET) {
         switch (optname) {
             case SO_KEEPALIVE:
                 if (bolval != attr->socket.tcp_keepalive) {
                     attr->socket.tcp_keepalive = bolval;
-                    goto set;
+                    need_set_attr = true;
                 }
                 break;
             case SO_LINGER: {
@@ -1600,35 +1571,36 @@ query:
                 int linger                = l->l_onoff ? l->l_linger : 0;
                 if (linger != (int)attr->socket.linger) {
                     attr->socket.linger = linger;
-                    goto set;
+                    need_set_attr = true;
                 }
                 break;
             }
             case SO_RCVBUF:
                 if (intval != (int)attr->socket.receivebuf) {
                     attr->socket.receivebuf = intval;
-                    goto set;
+                    need_set_attr = true;
                 }
                 break;
             case SO_SNDBUF:
                 if (intval != (int)attr->socket.sendbuf) {
                     attr->socket.sendbuf = intval;
-                    goto set;
+                    need_set_attr = true;
                 }
                 break;
             case SO_RCVTIMEO:
                 if (intval != (int)attr->socket.receivetimeout) {
                     attr->socket.receivetimeout = intval;
-                    goto set;
+                    need_set_attr = true;
                 }
                 break;
             case SO_SNDTIMEO:
                 if (intval != (int)attr->socket.sendtimeout) {
                     attr->socket.sendtimeout = intval;
-                    goto set;
+                    need_set_attr = true;
                 }
                 break;
             case SO_REUSEADDR:
+                /* PAL always does REUSEADDR, no need to check or update */
                 break;
         }
     }
@@ -1638,23 +1610,65 @@ query:
             case TCP_CORK:
                 if (bolval != attr->socket.tcp_cork) {
                     attr->socket.tcp_cork = bolval;
-                    goto set;
+                    need_set_attr = true;
                 }
                 break;
             case TCP_NODELAY:
                 if (bolval != attr->socket.tcp_nodelay) {
                     attr->socket.tcp_nodelay = bolval;
-                    goto set;
+                    need_set_attr = true;
                 }
                 break;
         }
     }
 
-    return 0;
+    return need_set_attr;
+}
 
-set:
-    if (!DkStreamAttributesSetByHandle(hdl->pal_handle, attr))
-        return -PAL_ERRNO;
+static int __do_setsockopt(struct shim_handle* hdl, int level, int optname, char* optval,
+                           int optlen, PAL_STREAM_ATTR* attr) {
+    // Issue 754 - https://github.com/oscarlab/graphene/issues/754
+    __UNUSED(optlen);
+
+    if (level != SOL_SOCKET && level != SOL_TCP)
+        return -ENOPROTOOPT;
+
+    if (level == SOL_SOCKET) {
+        switch (optname) {
+            case SO_ACCEPTCONN:
+            case SO_DOMAIN:
+            case SO_ERROR:
+            case SO_PROTOCOL:
+            case SO_TYPE:
+                return -EPERM;
+            case SO_KEEPALIVE:
+            case SO_LINGER:
+            case SO_RCVBUF:
+            case SO_SNDBUF:
+            case SO_RCVTIMEO:
+            case SO_SNDTIMEO:
+            case SO_REUSEADDR:
+                break;
+            default:
+                return -ENOPROTOOPT;
+        }
+    }
+
+    if (level == SOL_TCP && optname != TCP_CORK && optname != TCP_NODELAY)
+        return -ENOPROTOOPT;
+
+    PAL_STREAM_ATTR local_attr;
+    if (!attr) {
+        attr = &local_attr;
+        if (!DkStreamAttributesQueryByHandle(hdl->pal_handle, attr))
+            return -PAL_ERRNO;
+    }
+
+    bool need_set_attr = __update_attr(attr, level, optname, optval, optlen);
+    if (need_set_attr) {
+        if (!DkStreamAttributesSetByHandle(hdl->pal_handle, attr))
+            return -PAL_ERRNO;
+    }
 
     return 0;
 }
@@ -1756,6 +1770,9 @@ int shim_do_getsockopt(int fd, int level, int optname, char* optval, int* optlen
 
     int* intval = (int*)optval;
 
+    if (level != SOL_SOCKET && level != SOL_TCP)
+        goto unknown;
+
     if (level == SOL_SOCKET) {
         switch (optname) {
             case SO_ACCEPTCONN:
@@ -1789,7 +1806,7 @@ int shim_do_getsockopt(int fd, int level, int optname, char* optval, int* optlen
             case SO_RCVTIMEO:
             case SO_SNDTIMEO:
             case SO_REUSEADDR:
-                goto query;
+                break;
             default:
                 goto unknown;
         }
@@ -1799,68 +1816,81 @@ int shim_do_getsockopt(int fd, int level, int optname, char* optval, int* optlen
         switch (optname) {
             case TCP_CORK:
             case TCP_NODELAY:
-                goto query;
+                break;
             default:
                 goto unknown;
         }
     }
 
-unknown:
-    ret = -ENOPROTOOPT;
-    goto out;
-
-query:
-    {
-        PAL_STREAM_ATTR attr;
+    /* at this point, we need to query PAL to get current attributes of hdl */
+    PAL_STREAM_ATTR attr;
 
+    if (!hdl->pal_handle) {
+        /* it is possible that there is no underlying PAL handle for hdl, e.g., socket() before
+         * bind(); in this case, augment default attrs with pending_options and skip quering PAL */
+        __populate_addr_with_defaults(&attr);
+
+        struct shim_sock_option* o = sock->pending_options;
+        while (o) {
+            __update_attr(&attr, o->level, o->optname, o->optval, o->optlen);
+            o = o->next;
+        }
+    } else {
+        /* query PAL to get current attributes */
         if (!DkStreamAttributesQueryByHandle(hdl->pal_handle, &attr)) {
             ret = -PAL_ERRNO;
             goto out;
         }
+    }
 
-        if (level == SOL_SOCKET) {
-            switch (optname) {
-                case SO_KEEPALIVE:
-                    *intval = attr.socket.tcp_keepalive ? 1 : 0;
-                    break;
-                case SO_LINGER: {
-                    struct __kernel_linger* l = (struct __kernel_linger*)optval;
-                    l->l_onoff                = attr.socket.linger ? 1 : 0;
-                    l->l_linger               = attr.socket.linger;
-                    break;
-                }
-                case SO_RCVBUF:
-                    *intval = attr.socket.receivebuf;
-                    break;
-                case SO_SNDBUF:
-                    *intval = attr.socket.sendbuf;
-                    break;
-                case SO_RCVTIMEO:
-                    *intval = attr.socket.receivetimeout;
-                    break;
-                case SO_SNDTIMEO:
-                    *intval = attr.socket.sendtimeout;
-                    break;
-                case SO_REUSEADDR:
-                    *intval = 1;
-                    break;
+    if (level == SOL_SOCKET) {
+        switch (optname) {
+            case SO_KEEPALIVE:
+                *intval = attr.socket.tcp_keepalive ? 1 : 0;
+                break;
+            case SO_LINGER: {
+                struct __kernel_linger* l = (struct __kernel_linger*)optval;
+                l->l_onoff                = attr.socket.linger ? 1 : 0;
+                l->l_linger               = attr.socket.linger;
+                break;
             }
+            case SO_RCVBUF:
+                *intval = attr.socket.receivebuf;
+                break;
+            case SO_SNDBUF:
+                *intval = attr.socket.sendbuf;
+                break;
+            case SO_RCVTIMEO:
+                *intval = attr.socket.receivetimeout;
+                break;
+            case SO_SNDTIMEO:
+                *intval = attr.socket.sendtimeout;
+                break;
+            case SO_REUSEADDR:
+                *intval = 1;
+                break;
         }
+    }
 
-        if (level == SOL_TCP) {
-            switch (optname) {
-                case TCP_CORK:
-                    *intval = attr.socket.tcp_cork ? 1 : 0;
-                    break;
-                case TCP_NODELAY:
-                    *intval = attr.socket.tcp_nodelay ? 1 : 0;
-                    break;
-            }
+    if (level == SOL_TCP) {
+        switch (optname) {
+            case TCP_CORK:
+                *intval = attr.socket.tcp_cork ? 1 : 0;
+                break;
+            case TCP_NODELAY:
+                *intval = attr.socket.tcp_nodelay ? 1 : 0;
+                break;
         }
     }
 
+    ret = 0;
+
 out:
     unlock(&hdl->lock);
     put_handle(hdl);
     return ret;
+
+unknown:
+    ret = -ENOPROTOOPT;
+    goto out;
 }

+ 35 - 31
LibOS/shim/test/regression/getsockopt.c

@@ -1,46 +1,50 @@
-/* Unit test for issue #92.
- * Example for use of getsockopt with SO_TYPE
- * taken from here: http://alas.matf.bg.ac.rs/manuals/lspe/snode=103.html
- */
+/* Unit test for issues #92 and #644 */
+
 #include <assert.h>
 #include <errno.h>
+#include <netinet/tcp.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <sys/socket.h>
 
 int main(int argc, char** argv) {
-    int z;
-    int s       = -1; /* Socket */
-    int so_type = -1; /* Socket type */
+    int ret;
     socklen_t optlen; /* Option length */
-    int rv = 0;
-
-    /*
-     * Create a TCP/IP socket to use:
-     */
-    s = socket(PF_INET, SOCK_STREAM, 0);
-    if (s == -1) {
-        printf("socket(2) error %d", errno);
-        exit(-1);
+
+    int fd = socket(PF_INET, SOCK_STREAM, 0);
+    if (fd < 0) {
+        perror("socket failed");
+        return 1;
+    }
+
+    int so_type;
+    optlen = sizeof(so_type);
+    ret = getsockopt(fd, SOL_SOCKET, SO_TYPE, &so_type, &optlen);
+    if (ret < 0) {
+        perror("getsockopt(SOL_SOCKET, SO_TYPE) failed");
+        return 1;
     }
 
-    /*
-     * Get socket option SO_SNDBUF:
-     */
-    optlen = sizeof so_type;
-    z      = getsockopt(s, SOL_SOCKET, SO_TYPE, &so_type, &optlen);
-    if (z) {
-        printf("getsockopt(s,SOL_SOCKET,SO_TYPE) %d", errno);
-        exit(-1);
+    if (optlen != sizeof(so_type) || so_type != SOCK_STREAM) {
+        fprintf(stderr, "getsockopt(SOL_SOCKET, SO_TYPE) failed\n");
+        return 1;
+    }
+
+    printf("getsockopt: Got socket type OK\n");
+
+    int so_flags = 1;
+    optlen = sizeof(so_flags);
+    ret = getsockopt(fd, SOL_TCP, TCP_NODELAY, (void*)&so_flags, &optlen);
+    if (ret < 0) {
+        perror("getsockopt(SOL_TCP, TCP_NODELAY) failed");
+        return 1;
     }
 
-    assert(optlen == sizeof so_type);
-    if (so_type == SOCK_STREAM) {
-        printf("getsockopt: Got socket type OK\n");
-    } else {
-        printf("getsockopt: Got socket type failed\n");
-        rv = -1;
+    if (optlen != sizeof(so_flags) || (so_flags != 0 && so_flags != 1)) {
+        fprintf(stderr, "getsockopt(SOL_TCP, TCP_NODELAY) failed\n");
+        return 1;
     }
 
-    return rv;
+    printf("getsockopt: Got TCP_NODELAY flag OK\n");
+    return 0;
 }

+ 2 - 1
LibOS/shim/test/regression/test_libos.py

@@ -414,8 +414,9 @@ class TC_40_FileSystem(RegressionTestCase):
 
 class TC_80_Socket(RegressionTestCase):
     def test_000_getsockopt(self):
-        stdout, stderr = self.run_binary(['getsockopt'])
+        stdout, _ = self.run_binary(['getsockopt'])
         self.assertIn('getsockopt: Got socket type OK', stdout)
+        self.assertIn('getsockopt: Got TCP_NODELAY flag OK', stdout)
 
     def test_010_epoll_wait_timeout(self):
         stdout, stderr = self.run_binary(['epoll_wait_timeout', '8000'],