瀏覽代碼

[LibOS] Add VMA_TAINTED in mprotect

borysp 4 年之前
父節點
當前提交
12efb484f3

+ 4 - 1
LibOS/shim/src/bookkeep/shim_vma.c

@@ -733,13 +733,16 @@ static int __bkeep_mprotect (struct shim_vma * prev,
             /* If [start, end) contains the VMA, just update its protection. */
             if (start <= cur->start && cur->end <= end) {
                 cur->prot = prot;
+                if (cur->file && (prot & PROT_WRITE)) {
+                    cur->flags |= VMA_TAINTED;
+                }
             } else {
                 /* Create a new VMA for the protected area */
                 new = __get_new_vma();
                 new->start = cur->start > start ? cur->start : start;
                 new->end   = cur->end < end ? cur->end : end;
                 new->prot  = prot;
-                new->flags = cur->flags;
+                new->flags = cur->flags | ((cur->file && (prot & PROT_WRITE)) ? VMA_TAINTED : 0);
                 new->file  = cur->file;
                 if (new->file) {
                     get_handle(new->file);

+ 3 - 3
LibOS/shim/src/sys/shim_mmap.c

@@ -48,12 +48,12 @@ void* shim_do_mmap(void* addr, size_t length, int prot, int flags, int fd, off_t
     if (fd >= 0 && !IS_ALLOC_ALIGNED(offset))
         return (void*)-EINVAL;
 
-    if (!length || !access_ok(addr, length))
-        return (void*)-EINVAL;
-
     if (!IS_ALLOC_ALIGNED(length))
         length = ALLOC_ALIGN_UP(length);
 
+    if (!length || !access_ok(addr, length))
+        return (void*)-EINVAL;
+
     /* ignore MAP_32BIT when MAP_FIXED is set */
     if ((flags & (MAP_32BIT | MAP_FIXED)) == (MAP_32BIT | MAP_FIXED))
         flags &= ~MAP_32BIT;

+ 80 - 0
LibOS/shim/test/regression/mprotect_file_fork.c

@@ -0,0 +1,80 @@
+#include <err.h>
+#include <errno.h>
+#include <fcntl.h>
+#include <stdio.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <unistd.h>
+
+#define FNAME "/tmp/test"
+
+#define VAL 0xff
+
+int main(void) {
+    int fd;
+    void *ptr;
+
+    if (mkdir("/tmp", S_IRWXU | S_IRWXG | S_IRWXO) < 0 && errno != EEXIST) {
+        err(1, "mkdir");
+    }
+    if (unlink(FNAME) < 0 && errno != ENOENT) {
+        err(1, "unlink");
+    }
+
+    fd = open(FNAME, O_CREAT | O_EXCL | O_RDWR, S_IRUSR | S_IWUSR);
+    if (fd < 0) {
+        err(1, "open");
+    }
+    if (ftruncate(fd, 0x10) < 0) {
+        err(1, "ftruncate");
+    }
+
+    ptr = mmap(NULL, 0x1000, PROT_READ, MAP_PRIVATE, fd, 0);
+    if (ptr == MAP_FAILED) {
+        err(1, "mmap");
+    }
+
+    if (close(fd) < 0) {
+        err(1, "close");
+    }
+
+    if (mprotect(ptr, 0x1000, PROT_READ | PROT_WRITE) < 0) {
+        err(1, "mprotect");
+    }
+
+    *(int*)ptr = VAL;
+
+    pid_t p = fork();
+    if (p < 0) {
+        err(1, "fork");
+    }
+
+    if (p == 0) {
+        // child
+        if (*(int*)ptr != VAL) {
+            printf("EXPECTED: 0x%x\nGOT     : 0x%x\n", VAL, *(int*)ptr);
+            return 1;
+        }
+        return 0;
+    }
+
+    // parent
+    int st = 0;
+    if (wait(&st) < 0) {
+        err(1, "wait");
+    }
+
+    if (unlink(FNAME) < 0) {
+        err(1, "unlink");
+    }
+
+    if (!WIFEXITED(st) || WEXITSTATUS(st) != 0) {
+        printf("abnormal child termination: %d\n", st);
+        return 1;
+    }
+
+    puts("Test successful!");
+    return 0;
+}

+ 6 - 1
LibOS/shim/test/regression/test_libos.py

@@ -316,7 +316,7 @@ class TC_30_Syscall(RegressionTestCase):
         self.assertIn('mmap test 5 passed', stdout)
         self.assertIn('mmap test 8 passed', stdout)
 
-    def test_52_large_mmap(self):
+    def test_052_large_mmap(self):
         stdout, stderr = self.run_binary(['large-mmap'], timeout=480)
 
         # Ftruncate
@@ -326,6 +326,11 @@ class TC_30_Syscall(RegressionTestCase):
         self.assertIn('large-mmap: mmap 1 completed OK', stdout)
         self.assertIn('large-mmap: mmap 2 completed OK', stdout)
 
+    def test_053_mprotect_file_fork(self):
+        stdout, _ = self.run_binary(['mprotect_file_fork'])
+
+        self.assertIn('Test successful!', stdout)
+
     @unittest.skip('sigaltstack isn\'t correctly implemented')
     def test_060_sigaltstack(self):
         stdout, stderr = self.run_binary(['sigaltstack'])