ソースを参照

[LibOS] Support of test_user_{memory/string} in SGX PAL

Previously, test_user_{memory/string} failed on the Linux-SGX PAL.
These functions check whether a user-supplied buffer/string is accessible
inside LibOS. There was only one option: set up a segfault handler
to catch memory errors and access one byte of each page in the user-supplied
buffer/string. If some byte is inaccessible, an exception is raised
and caught by the segfault handler. The handler checks the faulting
address and redirects back to the function, which then returns with failure.

This option doesn't work under SGX because there is no trusted field
that contains the faulting address. SGX v1 doesn't have such a field at
all; SGX v2 introduces a field in the SSA.MISC region, but only at
4K-page granularity.

This commit adds a second option, used for the SGX PAL: the functions consult
LibOS's internal VMA bookkeeping. If the buffer/string is not inside a VMA,
these functions report failure. This option is slightly slower than the first
one since it requires locking and traversing a list of VMAs.
Dmitrii Kuvaiskii 5 年 前
コミット
e307eca021

+ 30 - 9
LibOS/shim/include/shim_internal.h

@@ -764,6 +764,19 @@ extern const char ** initial_envp;
 #define ALIGN_DOWN(addr)    \
     ((__typeof__(addr)) (((unsigned long) addr) & allocmask))
 
+void get_brk_region (void ** start, void ** end, void ** current);
+
+int reset_brk (void);
+int init_brk_region (void * brk_region);
+int init_heap (void);
+int init_internal_map (void);
+int init_loader (void);
+int init_manifest (PAL_HANDLE manifest_handle);
+/* Both return true when the range is found inaccessible, false when no problem is detected */
+bool test_user_memory (void * addr, size_t size, bool write);
+bool test_user_string (const char * addr);
+
+#ifdef __x86_64__
 #define switch_stack(stack_top)                                         \
     ({                                                                  \
         void * _rsp, * _rbp;                                            \
@@ -785,16 +798,24 @@ static_always_inline void * current_stack(void)
     return _rsp;
 }
 
-void get_brk_region (void ** start, void ** end, void ** current);
+static_always_inline bool __range_not_ok(unsigned long addr, unsigned long size) {
+    /* Returns true iff addr + size overflows unsigned long, i.e. the end of
+     * the range wraps around the address space (this check is x86-64
+     * specific). Unsigned addition wraps, so an overflow happened exactly
+     * when the sum is smaller than one of its operands. */
+    unsigned long range_end = addr + size;
+    return range_end < size;
+}
 
-int reset_brk (void);
-int init_brk_region (void * brk_region);
-int init_heap (void);
-int init_internal_map (void);
-int init_loader (void);
-int init_manifest (PAL_HANDLE manifest_handle);
+/* Cheap pointer sanity check: false if [addr, addr+size) is definitely invalid
+ * (its end wraps around), true if it may be valid. Args are cast as a whole. */
+#define access_ok(addr, size)                                       \
+    ({                                                              \
+        !__range_not_ok((unsigned long)(addr), (unsigned long)(size)); \
+    })
 
-bool test_user_memory (void * addr, size_t size, bool write);
-bool test_user_string (const char * addr);
+#else
+# error "Unsupported architecture"
+#endif /* __x86_64__ */
 
 #endif /* _PAL_INTERNAL_H_ */

+ 3 - 0
LibOS/shim/include/shim_vma.h

@@ -133,6 +133,9 @@ int lookup_vma (void * addr, struct shim_vma_val * vma);
 int lookup_overlap_vma (void * addr, uint64_t length,
                         struct shim_vma_val * vma);
 
+/* True if [addr, addr+length) lies entirely within a single VMA (valid memory region) */
+bool is_in_one_vma (void * addr, size_t length);
+
 /*
  * Looking for an unmapped space and then adding the corresponding bookkeeping
  * (more info in bookkeep/shim_vma.c).

+ 77 - 14
LibOS/shim/src/bookkeep/shim_signal.c

@@ -300,6 +300,38 @@ ret_exception:
     DkExceptionReturn(event);
 }
 
+/*
+ * is_sgx_pal() caches strcmp_static(PAL_CB(host_type), "Linux-SGX"), assumed to
+ * be nonzero on a match (TODO confirm); test_user_memory / test_user_string
+ * use it to choose between two strategies:
+ *
+ * - For Linux-SGX, the faulting address is not propagated in the memfault
+ *   exception (SGX v1 does not write the address in the SSA frame, SGX v2
+ *   writes it only at a granularity of 4K pages). Thus, we cannot rely on
+ *   exception handling to compare against tcb.test_range.start/end.
+ *   Instead, take the VMA lock and check [addr, addr+size) against the VMAs.
+ *
+ * - For other PALs, we touch one byte of each page in [addr, addr+size).
+ *   If some byte is not addressable, an exception is raised. memfault_upcall
+ *   handles this exception and resumes execution from ret_fault.
+ *
+ * The second option is faster in the fault-free case but cannot be used under
+ * the SGX PAL. We use the best option for each PAL for now. */
+static bool is_sgx_pal(void) {
+    static struct atomic_int sgx_pal = { .counter = 0 };
+    static struct atomic_int inited  = { .counter = 0 };
+    /* Compute the answer until 'inited' is observed set; afterwards reuse the cached sgx_pal */
+    if (!atomic_read(&inited)) {
+        /* Ensure that sgx_pal is written before inited is set */
+        atomic_set(&sgx_pal, strcmp_static(PAL_CB(host_type), "Linux-SGX"));
+        mb();
+        atomic_set(&inited, 1);
+    }
+    mb(); /* NOTE(review): presumably orders the read of inited before the read of sgx_pal -- confirm */
+
+    return atomic_read(&sgx_pal) != 0;
+}
+
 /*
  * 'test_user_memory' and 'test_user_string' are helper functions for testing
  * if a user-given buffer or data structure is readable / writable (according
@@ -308,28 +340,33 @@ ret_exception:
  * guarantee further corruption of the buffer, or if the buffer is unmapped
  * with a concurrent system call. The purpose of these functions is simply for
  * the compatibility with programs that rely on the error numbers, such as the
- * LTP test suite.
- */
+ * LTP test suite. */
 bool test_user_memory (void * addr, size_t size, bool write)
 {
     if (!size)
         return false;
 
+    if (!access_ok(addr, size))
+        return true;
+
+    /* SGX path: check if [addr, addr+size) is addressable (in some VMA) */
+    if (is_sgx_pal())
+        return !is_in_one_vma(addr, size);
+
+    /* Non-SGX path: check if [addr, addr+size) is addressable by touching
+     * a byte of each page; invalid access will be caught in memfault_upcall */
     shim_tcb_t * tcb = shim_get_tls();
     assert(tcb && tcb->tp);
     __disable_preempt(tcb);
 
-    if (addr + size - 1 < addr)
-        size = (void *) 0x0 - addr;
-
-    bool has_fault = true;
+    bool  has_fault = true;
 
     /* Add the memory region to the watch list. This is not racy because
      * each thread has its own record. */
     assert(!tcb->test_range.cont_addr);
     tcb->test_range.cont_addr = &&ret_fault;
     tcb->test_range.start = addr;
-    tcb->test_range.end = addr + size - 1;
+    tcb->test_range.end   = addr + size - 1;
 
     /* Try to read or write into one byte inside each page */
     void * tmp = addr;
@@ -359,6 +396,32 @@ ret_fault:
  */
 bool test_user_string (const char * addr)
 {
+    if (!access_ok(addr, 1))
+        return true;
+
+    size_t size, maxlen;
+    const char * next = ALIGN_UP(addr + 1);
+
+    /* SGX path: check if [addr, addr+size) is addressable (in some VMA). */
+    if (is_sgx_pal()) {
+        /* We don't know length but using unprotected strlen() is dangerous
+         * so we check string in chunks of 4K pages. */
+        do {
+            maxlen = next - addr;
+
+            if (!access_ok(addr, maxlen) || !is_in_one_vma((void*) addr, maxlen))
+                return true;
+
+            size = strnlen(addr, maxlen);
+            addr = next;
+            next = ALIGN_UP(addr + 1);
+        } while (size == maxlen);
+
+        return false;
+    }
+
+    /* Non-SGX path: check if [addr, addr+size) is addressable by touching
+     * a byte of each page; invalid access will be caught in memfault_upcall. */
     shim_tcb_t * tcb = shim_get_tls();
     assert(tcb && tcb->tp);
     __disable_preempt(tcb);
@@ -368,22 +431,22 @@ bool test_user_string (const char * addr)
     assert(!tcb->test_range.cont_addr);
     tcb->test_range.cont_addr = &&ret_fault;
 
-    /* Test one page at a time. */
-    const char * next = ALIGN_UP(addr + 1);
     do {
         /* Add the memory region to the watch list. This is not racy because
          * each thread has its own record. */
         tcb->test_range.start = (void *) addr;
         tcb->test_range.end = (void *) (next - 1);
-        *(volatile char *) addr; /* try to read one byte from the page */
 
-        /* If the string ends in this page, exit the loop. */
-        if (strnlen(addr, next - addr) < next - addr)
-            break;
+        maxlen = next - addr;
+
+        if (!access_ok(addr, maxlen))
+            return true;
+        *(volatile char *) addr; /* try to read one byte from the page */
 
+        size = strnlen(addr, maxlen);
         addr = next;
         next = ALIGN_UP(addr + 1);
-    } while (addr < next);
+    } while (size == maxlen);
 
     has_fault = false; /* All accesses have passed. Nothing wrong. */
 

+ 20 - 0
LibOS/shim/src/bookkeep/shim_vma.c

@@ -121,6 +121,7 @@ static LOCKTYPE vma_list_lock;
 static inline bool test_vma_equal (struct shim_vma * vma,
                                    void * s, void * e)
 {
+    assert(s < e); /* caller must pass a start strictly below the end */
     return vma->start == s && vma->end == e;
 }
 
@@ -130,6 +131,7 @@ static inline bool test_vma_equal (struct shim_vma * vma,
 static inline bool test_vma_contain (struct shim_vma * vma,
                                      void * s, void * e)
 {
+    assert(s < e); /* rejects empty ranges and ranges whose end wrapped below the start */
     return vma->start <= s && vma->end >= e;
 }
 
@@ -139,6 +141,7 @@ static inline bool test_vma_contain (struct shim_vma * vma,
 static inline bool test_vma_startin (struct shim_vma * vma,
                                      void * s, void * e)
 {
+    assert(s < e); /* caller must pass a start strictly below the end */
     return vma->start >= s && vma->start < e;
 }
 
@@ -148,6 +151,7 @@ static inline bool test_vma_startin (struct shim_vma * vma,
 static inline bool test_vma_endin (struct shim_vma * vma,
                                    void * s, void * e)
 {
+    assert(s < e); /* caller must pass a start strictly below the end */
     return vma->end > s && vma->end <= e;
 }
 
@@ -157,6 +161,7 @@ static inline bool test_vma_endin (struct shim_vma * vma,
 static inline bool test_vma_overlap (struct shim_vma * vma,
                                      void * s, void * e)
 {
+    assert(s < e);
     return test_vma_contain(vma, s, s + 1) ||
            test_vma_contain(vma, e - 1, e) ||
            test_vma_startin(vma, s, e);
@@ -984,6 +989,21 @@ int lookup_overlap_vma (void * addr, uint64_t length,
     return 0;
 }
 
+bool is_in_one_vma (void * addr, size_t length)
+{
+    struct shim_vma* vma;
+    /* True iff some entry of vma_list passes test_vma_contain() for [addr, addr+length) */
+    lock(vma_list_lock);
+    listp_for_each_entry(vma, &vma_list, list)
+        if (test_vma_contain(vma, addr, addr + length)) {
+            unlock(vma_list_lock); /* drop the list lock before the early return */
+            return true;
+        }
+    /* No entry matched. NOTE(review): length == 0 fails assert(s < e) in test_vma_contain */
+    unlock(vma_list_lock);
+    return false;
+}
+
 int dump_all_vmas (struct shim_vma_val * vmas, size_t max_count)
 {
     struct shim_vma_val * val = vmas;

+ 16 - 0
LibOS/shim/test/regression/30_stat.py

@@ -0,0 +1,16 @@
+import os, sys, mmap
+from regression import Regression
+
+loader = sys.argv[1]
+
+# Running stat
+regression = Regression(loader, "stat_invalid_args")
+
+regression.add_check(name="Stat with invalid arguments",
+    check=lambda res: "stat(invalid-path-ptr) correctly returns error" in res[0].out and \
+                      "stat(invalid-buf-ptr) correctly returns error" in res[0].out and \
+                      "lstat(invalid-path-ptr) correctly returns error" in res[0].out and \
+                      "lstat(invalid-buf-ptr) correctly returns error" in res[0].out)
+
+rv = regression.run_checks()
+if rv: sys.exit(rv)

+ 37 - 0
LibOS/shim/test/regression/stat_invalid_args.c

@@ -0,0 +1,37 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <unistd.h>
+#include <errno.h>
+
+int main (int argc, char** argv) {
+    /* Regression test: every call below receives exactly one invalid
+     * pointer ((void*)-1) together with one valid argument. A confirmation
+     * line is printed for each call that returns -1 with errno == EFAULT;
+     * nothing is printed for a call that behaves differently. */
+    struct stat st;
+
+    char* const valid_path = argv[0];
+    char* const bogus_path = (void*)-1;
+
+    struct stat* const valid_buf = &st;
+    struct stat* const bogus_buf = (void*)-1;
+
+    /* stat(): bad path pointer first, then bad output buffer */
+    if (stat(bogus_path, valid_buf) == -1 && errno == EFAULT)
+        printf("stat(invalid-path-ptr) correctly returns error\n");
+
+    if (stat(valid_path, bogus_buf) == -1 && errno == EFAULT)
+        printf("stat(invalid-buf-ptr) correctly returns error\n");
+
+    /* lstat(): the same two cases */
+    if (lstat(bogus_path, valid_buf) == -1 && errno == EFAULT)
+        printf("lstat(invalid-path-ptr) correctly returns error\n");
+
+    if (lstat(valid_path, bogus_buf) == -1 && errno == EFAULT)
+        printf("lstat(invalid-buf-ptr) correctly returns error\n");
+
+    /* The exit status is 0 regardless of which lines were printed */
+    return 0;
+}