Browse Source

bugfixes: TCP/UDP sockets fail to return addresses when running in enclaves

Chia-Che Tsai 7 years ago
parent
commit
8cc0920bf4

+ 5 - 2
LibOS/shim/src/sys/shim_socket.c

@@ -61,7 +61,7 @@
 #define TCP_CONGESTION      13  /* Congestion control algorithm.  */
 #define TCP_MD5SIG          14  /* TCP MD5 Signature (RFC2385) */
 
-#define SOCK_URI_SIZE   64
+#define SOCK_URI_SIZE   108
 
 static int rebase_on_lo __attribute_migratable = -1;
 
@@ -207,7 +207,8 @@ static int inet_translate_addr (int domain, char * uri, int count,
 
     if (domain == AF_INET6) {
         unsigned short * ad = (void *) &addr->addr.v6.s6_addr;
-        int bytes = snprintf(uri, count, "[%x:%x:%x:%x:%x:%x:%x:%x]:%u",
+        int bytes = snprintf(uri, count,
+                             "[%04x:%04x:%x:%04x:%04x:%04x:%04x:%04x]:%u",
                              ad[0], ad[1], ad[2], ad[3],
                              ad[4], ad[5], ad[6], ad[7], addr->ext_port);
         return bytes == count ? -ENAMETOOLONG : bytes;
@@ -557,6 +558,8 @@ static int inet_parse_addr (int domain, int type, const char * uri,
 
         port_str++;
         next_str = strchr(port_str, ':');
+        if (next_str)
+            next_str++;
 
         struct addr_inet * addr = round ? conn : bind;
 

+ 37 - 25
Pal/src/host/Linux-SGX/db_sockets.c

@@ -72,7 +72,7 @@ static int inet_parse_uri (char ** uri, struct sockaddr * addr, unsigned int * a
 
     if (tmp[0] == '[') {
         /* for IPv6, the address will be in the form of
-           "[xx:xx:xx:xx:xx:xx:xx:xx]:port". */
+           "[xxxx:xxxx:xxxx:xxxx:xxxx:xxxx:xxxx:xxxx]:port". */
         struct sockaddr_in6 * addr_in6 = (struct sockaddr_in6 *) addr;
 
         slen = sizeof(struct sockaddr_in6);
@@ -154,11 +154,11 @@ static int inet_create_uri (char * uri, int count, struct sockaddr * addr,
             return PAL_ERROR_INVAL;
 
         struct sockaddr_in6 * addr_in6 = (struct sockaddr_in6 *) addr;
-        short * addr = (short *) &addr_in6->sin6_addr.s6_addr;
+        unsigned short * addr = (unsigned short *) &addr_in6->sin6_addr.s6_addr;
 
         /* for IPv6, the address will be in the form of
-           "[xx:xx:xx:xx:xx:xx:xx:xx]:port". */
-        len = snprintf(uri, count, "[%x:%x:%x:%x:%x:%x:%x:%x]:%u",
+           "[xxxx:xxxx:xxxx:xxxx:xxxx:xxxx:xxxx:xxxx]:port". */
+        len = snprintf(uri, count, "[%04x:%04x:%04x:%04x:%04x:%04x:%04x:%04x]:%u",
                        addr[0], addr[1], addr[2], addr[3],
                        addr[4], addr[5], addr[6], addr[7],
                        __ntohs(addr_in6->sin6_port));
@@ -166,6 +166,9 @@ static int inet_create_uri (char * uri, int count, struct sockaddr * addr,
         return -PAL_ERROR_INVAL;
     }
 
+    if (len >= count)
+        return -PAL_ERROR_TOOLONG;
+
     return len;
 }
 
@@ -607,11 +610,14 @@ static int udp_receivebyaddr (PAL_HANDLE handle, int offset, int len,
     if (bytes < 0)
         return bytes;
 
-    if (addrlen < 5)
+    char * addr_uri = strcpy_static(addr, "udp:", addrlen);
+    if (!addr_uri)
         return -PAL_ERROR_OVERFLOW;
 
-    memcpy(addr, "udp:", 5);
-    inet_create_uri(addr + 4, addrlen - 4, &conn_addr, conn_addrlen);
+    int ret = inet_create_uri(addr_uri, addr + addrlen - addr_uri, &conn_addr,
+                              conn_addrlen);
+    if (ret < 0)
+        return ret;
 
     return bytes;
 }
@@ -652,8 +658,11 @@ static int udp_sendbyaddr (PAL_HANDLE handle, int offset, int len,
     if (!strpartcmp_static(addr, "udp:"))
         return -PAL_ERROR_INVAL;
 
-    char * addrbuf = __alloca(addrlen - 3);
-    memcpy(addrbuf, addr + 4, addrlen - 3);
+    addr    += static_strlen("udp:");
+    addrlen -= static_strlen("udp:");
+
+    char * addrbuf = __alloca(addrlen);
+    memcpy(addrbuf, addr, addrlen);
 
     struct sockaddr conn_addr;
     unsigned int conn_addrlen = sizeof(struct sockaddr);
@@ -730,23 +739,11 @@ static int socket_attrquerybyhdl (PAL_HANDLE handle, PAL_STREAM_ATTR  * attr)
     if (handle->sock.fd == PAL_IDX_POISON)
         return -PAL_ERROR_BADHANDLE;
 
-    int fd = handle->sock.fd, ret;
-
-    memset(attr, 0, sizeof(PAL_STREAM_ATTR));
-
-    attr->disconnected = HANDLE_HDR(handle)->flags & ERROR(0);
-
-    if (handle->sock.conn) {
-        /* try use ioctl FIONEAD to get the size of socket */
-        ret = ocall_fionread(fd);
-        if (ret >= 0)
-            attr->pending_size = ret;
-    }
-
-    attr->readable  = (attr->pending_size > 0);
-    attr->writeable = HANDLE_HDR(handle)->flags & WRITEABLE(0);
-
+    attr->handle_type           = HANDLE_HDR(handle)->type;
+    attr->disconnected          = HANDLE_HDR(handle)->flags & ERROR(0);
     attr->nonblocking           = handle->sock.nonblocking;
+    attr->writeable             = HANDLE_HDR(handle)->flags & WRITEABLE(0);
+    attr->pending_size          = 0; /* fill in later */
     attr->socket.linger         = handle->sock.linger;
     attr->socket.receivebuf     = handle->sock.receivebuf;
     attr->socket.sendbuf        = handle->sock.sendbuf;
@@ -755,6 +752,21 @@ static int socket_attrquerybyhdl (PAL_HANDLE handle, PAL_STREAM_ATTR  * attr)
     attr->socket.tcp_cork       = handle->sock.tcp_cork;
     attr->socket.tcp_keepalive  = handle->sock.tcp_keepalive;
     attr->socket.tcp_nodelay    = handle->sock.tcp_nodelay;
+
+    int fd = handle->sock.fd, ret;
+
+    if (handle->sock.conn) {
+        /* try use ioctl FIONEAD to get the size of socket */
+        ret = ocall_fionread(fd);
+        if (ret < 0)
+            return ret;
+
+        attr->pending_size = ret;
+        attr->readable = !!attr->pending_size > 0;
+    } else {
+        attr->readable = !attr->disconnected;
+    }
+
     return 0;
 }
 

+ 2 - 0
Pal/src/host/Linux-SGX/enclave_ocalls.c

@@ -611,6 +611,8 @@ int ocall_sock_recv (int sockfd, void * buf, unsigned int count,
 
         COPY_FROM_USER(buf, ms->ms_buf, retval);
         COPY_FROM_USER(addr, ms->ms_addr, ms->ms_addrlen);
+        if (addrlen)
+            *addrlen = ms->ms_addrlen;
     }
     OCALL_EXIT();
     return retval;

+ 1 - 1
Pal/src/host/Linux-SGX/linux_types.h

@@ -41,7 +41,7 @@ typedef unsigned short int sa_family_t;
 
 struct sockaddr {
     sa_family_t sa_family;
-    char sa_data[14];
+    char sa_data[128 - sizeof(unsigned short)];
 };
 
 #ifndef AF_UNIX

+ 3 - 0
Pal/src/host/Linux-SGX/sgx_enclave.c

@@ -421,6 +421,9 @@ static int sgx_ocall_sock_recv(void * pms)
                          ms->ms_sockfd, ms->ms_buf, ms->ms_count, 0,
                          addr, addr ? &addrlen : NULL);
 
+    if (!IS_ERR(ret) && addr)
+        ms->ms_addrlen = addrlen;
+
     return IS_ERR(ret) ? unix_to_pal_error(ERRNO(ret)) : ret;
 }
 

+ 1 - 1
Pal/src/host/Linux-SGX/sgx_framework.c

@@ -205,7 +205,7 @@ int add_pages_to_enclave(sgx_arch_secs_t * secs,
             break;
     }
 
-    param.addr = (uint64_t) addr;
+    param.addr = secs->baseaddr + (uint64_t) addr;
     param.user_addr = (uint64_t) user_addr;
     param.size = size;
     param.secinfo = &secinfo;

+ 30 - 21
Pal/src/host/Linux/db_sockets.c

@@ -75,7 +75,7 @@ static int inet_parse_uri (char ** uri, struct sockaddr * addr, int * addrlen)
 
     if (tmp[0] == '[') {
         /* for IPv6, the address will be in the form of
-           "[xx:xx:xx:xx:xx:xx:xx:xx]:port". */
+           "[xxxx:xxxx:xxxx:xxxx:xxxx:xxxx:xxxx:xxxx]:port". */
         struct sockaddr_in6 * addr_in6 = (struct sockaddr_in6 *) addr;
 
         slen = sizeof(struct sockaddr_in6);
@@ -157,11 +157,11 @@ static int inet_create_uri (char * uri, int count, struct sockaddr * addr,
             return PAL_ERROR_INVAL;
 
         struct sockaddr_in6 * addr_in6 = (struct sockaddr_in6 *) addr;
-        short * addr = (short *) &addr_in6->sin6_addr.s6_addr;
+        unsigned short * addr = (unsigned short *) &addr_in6->sin6_addr.s6_addr;
 
         /* for IPv6, the address will be in the form of
-           "[xx:xx:xx:xx:xx:xx:xx:xx]:port". */
-        len = snprintf(uri, count, "[%x:%x:%x:%x:%x:%x:%x:%x]:%u",
+           "[xxxx:xxxx:xxxx:xxxx:xxxx:xxxx:xxxx:xxxx]:port". */
+        len = snprintf(uri, count, "[%04x:%04x:%04x:%04x:%04x:%04x:%04x:%04x]:%u",
                        addr[0], addr[1], addr[2], addr[3],
                        addr[4], addr[5], addr[6], addr[7],
                        __ntohs(addr_in6->sin6_port));
@@ -169,6 +169,9 @@ static int inet_create_uri (char * uri, int count, struct sockaddr * addr,
         return -PAL_ERROR_INVAL;
     }
 
+    if (len >= count)
+        return -PAL_ERROR_TOOLONG;
+
     return len;
 }
 
@@ -808,8 +811,10 @@ static int udp_receivebyaddr (PAL_HANDLE handle, int offset, int len,
     if (!addr_uri)
         return -PAL_ERROR_OVERFLOW;
 
-    inet_create_uri(addr_uri, addr + addrlen - addr_uri, &conn_addr,
-                    hdr.msg_namelen);
+    int ret = inet_create_uri(addr_uri, addr + addrlen - addr_uri, &conn_addr,
+                              hdr.msg_namelen);
+    if (ret < 0)
+        return ret;
 
     return bytes;
 }
@@ -868,8 +873,11 @@ static int udp_sendbyaddr (PAL_HANDLE handle, int offset, int len,
     if (!strpartcmp_static(addr, "udp:"))
         return -PAL_ERROR_INVAL;
 
-    char * addrbuf = __alloca(addrlen - 3);
-    memcpy(addrbuf, addr + 4, addrlen - 3);
+    addr    += static_strlen("udp:");
+    addrlen -= static_strlen("udp:");
+
+    char * addrbuf = __alloca(addrlen);
+    memcpy(addrbuf, addr, addrlen);
 
     struct sockaddr conn_addr;
     int conn_addrlen;
@@ -976,6 +984,20 @@ static int socket_attrquerybyhdl (PAL_HANDLE handle, PAL_STREAM_ATTR  * attr)
     if (handle->sock.fd == PAL_IDX_POISON)
         return -PAL_ERROR_BADHANDLE;
 
+    attr->handle_type           = HANDLE_HDR(handle)->type;
+    attr->disconnected          = HANDLE_HDR(handle)->flags & ERROR(0);
+    attr->nonblocking           = handle->sock.nonblocking;
+    attr->writeable             = HANDLE_HDR(handle)->flags & WRITEABLE(0);
+    attr->pending_size          = 0; /* fill in later */
+    attr->socket.linger         = handle->sock.linger;
+    attr->socket.receivebuf     = handle->sock.receivebuf;
+    attr->socket.sendbuf        = handle->sock.sendbuf;
+    attr->socket.receivetimeout = handle->sock.receivetimeout;
+    attr->socket.sendtimeout    = handle->sock.sendtimeout;
+    attr->socket.tcp_cork       = handle->sock.tcp_cork;
+    attr->socket.tcp_keepalive  = handle->sock.tcp_keepalive;
+    attr->socket.tcp_nodelay    = handle->sock.tcp_nodelay;
+
     int fd = handle->sock.fd, ret, val;
 
     if (handle->sock.conn) {
@@ -990,19 +1012,6 @@ static int socket_attrquerybyhdl (PAL_HANDLE handle, PAL_STREAM_ATTR  * attr)
         attr->readable = !attr->disconnected;
     }
 
-    attr->handle_type           = HANDLE_HDR(handle)->type;
-    attr->disconnected          = HANDLE_HDR(handle)->flags & ERROR(0);
-    attr->nonblocking           = handle->sock.nonblocking;
-    attr->writeable             = HANDLE_HDR(handle)->flags & WRITEABLE(0);
-    attr->socket.linger         = handle->sock.linger;
-    attr->socket.receivebuf     = handle->sock.receivebuf;
-    attr->socket.sendbuf        = handle->sock.sendbuf;
-    attr->socket.receivetimeout = handle->sock.receivetimeout;
-    attr->socket.sendtimeout    = handle->sock.sendtimeout;
-    attr->socket.tcp_cork       = handle->sock.tcp_cork;
-    attr->socket.tcp_keepalive  = handle->sock.tcp_keepalive;
-    attr->socket.tcp_nodelay    = handle->sock.tcp_nodelay;
-
     return 0;
 }