Просмотр исходного кода

[Pal/Linux-SGX] Measure all memory except the heap

Before this change some important memory areas, for example TCS and TLS
were not measured. With this change all mapped enclave memory is
measured with one exception. Since the EEXTEND hashing is rather slow
the heap is not measured. Instead it gets zeroed on enclave startup.

Closes #505.
Simon Gaiser 6 лет назад
Родитель
Сommit
014a938767

+ 16 - 0
Pal/src/host/Linux-SGX/db_main.c

@@ -214,6 +214,22 @@ void pal_linux_main(char * uptr_args, uint64_t args_size,
         return;
     }
 
+    /* Zero the heap. We need to take care to not zero the exec area. */
+
+    void* zero1_start = sec_info.heap_min;
+    void* zero1_end = sec_info.heap_max;
+
+    void* zero2_start = sec_info.heap_max;
+    void* zero2_end = sec_info.heap_max;
+
+    if (sec_info.exec_addr != NULL) {
+        zero1_end = MIN(zero1_end, sec_info.exec_addr);
+        zero2_start = MIN(zero2_start, sec_info.exec_addr + sec_info.exec_size);
+    }
+
+    memset(zero1_start, 0, zero1_end - zero1_start);
+    memset(zero2_start, 0, zero2_end - zero2_start);
+
     /* relocate PAL itself */
     pal_map.l_addr = elf_machine_load_address();
     pal_map.l_name = ENCLAVE_FILENAME;

+ 7 - 0
Pal/src/host/Linux-SGX/generated-offsets.c

@@ -30,6 +30,7 @@ void dummy(void)
     OFFSET_T(SGX_GPR_RFLAGS, sgx_arch_gpr_t, rflags);
     OFFSET_T(SGX_GPR_RIP, sgx_arch_gpr_t, rip);
     OFFSET_T(SGX_GPR_EXITINFO, sgx_arch_gpr_t, exitinfo);
+    DEFINE(SGX_GPR_SIZE, sizeof(sgx_arch_gpr_t));
 
     /* sgx_context_t */
     OFFSET_T(SGX_CONTEXT_RAX, sgx_context_t, rax);
@@ -70,6 +71,12 @@ void dummy(void)
     OFFSET(SGX_READY_FOR_EXCEPTIONS, enclave_tls, ready_for_exceptions);
 
     /* sgx_arch_tcs_t */
+    OFFSET_T(TCS_OSSA, sgx_arch_tcs_t, ossa);
+    OFFSET_T(TCS_NSSA, sgx_arch_tcs_t, nssa);
+    OFFSET_T(TCS_OENTRY, sgx_arch_tcs_t, oentry);
+    OFFSET_T(TCS_OGSBASGX, sgx_arch_tcs_t, ogsbasgx);
+    OFFSET_T(TCS_FSLIMIT, sgx_arch_tcs_t, fslimit);
+    OFFSET_T(TCS_GSLIMIT, sgx_arch_tcs_t, gslimit);
     DEFINE(TCS_SIZE, sizeof(sgx_arch_tcs_t));
 
     /* sgx_arch_attributes_t */

+ 5 - 8
Pal/src/host/Linux-SGX/sgx_main.c

@@ -217,7 +217,7 @@ int load_enclave_binary (sgx_arch_secs_t * secs, int fd,
         if (zeroend > zeropage) {
             ret = add_pages_to_enclave(secs, (void *) base + zeropage, NULL,
                                        zeroend - zeropage,
-                                       SGX_PAGE_REG, c->prot, 1, "bss");
+                                       SGX_PAGE_REG, c->prot, false, "bss");
             if (ret < 0)
                 return ret;
         }
@@ -332,22 +332,19 @@ int initialize_enclave (struct pal_enclave * enclave)
                  0, ALLOC_ALIGNUP(manifest_size),
                  PROT_READ, SGX_PAGE_REG);
     struct mem_area * ssa_area =
-        set_area("ssa", true, false, -1, 0,
+        set_area("ssa", false, false, -1, 0,
                  enclave->thread_num * enclave->ssaframesize * SSAFRAMENUM,
                  PROT_READ|PROT_WRITE, SGX_PAGE_REG);
-    /* XXX: TCS should be part of measurement */
     struct mem_area * tcs_area =
-        set_area("tcs", true, false, -1, 0, enclave->thread_num * pagesize,
+        set_area("tcs", false, false, -1, 0, enclave->thread_num * pagesize,
                  0, SGX_PAGE_TCS);
-    /* XXX: TLS should be part of measurement */
     struct mem_area * tls_area =
-        set_area("tls", true, false, -1, 0, enclave->thread_num * pagesize,
+        set_area("tls", false, false, -1, 0, enclave->thread_num * pagesize,
                  PROT_READ|PROT_WRITE, SGX_PAGE_REG);
 
-    /* XXX: the enclave stack should be part of measurement */
     struct mem_area * stack_areas = &areas[area_num];
     for (int t = 0 ; t < enclave->thread_num ; t++)
-        set_area("stack", true, false, -1, 0, ENCLAVE_STACK_SIZE,
+        set_area("stack", false, false, -1, 0, ENCLAVE_STACK_SIZE,
                  PROT_READ|PROT_WRITE, SGX_PAGE_REG);
 
     struct mem_area * pal_area =

+ 133 - 38
Pal/src/host/Linux-SGX/signer/pal-sgx-sign

@@ -25,6 +25,8 @@ enclave_heap_min = DEFAULT_HEAP_MIN
 
 """ Utilities """
 
+ZERO_PAGE = "\0" * PAGESIZE
+
 def roundup(addr):
     remaining = addr % PAGESIZE
     if remaining:
@@ -301,13 +303,15 @@ def get_loadcmds(filename):
     return loadcmds
 
 class MemoryArea:
-    def __init__(self, desc, file=None, addr=None, size=None, flags=None):
+    def __init__(self, desc, file=None, content=None, addr=None, size=None, flags=None, measure=True):
         self.desc = desc
         self.file = file
+        self.content = content
         self.addr = addr
         self.size = size
         self.flags = flags
         self.is_binary = False
+        self.measure = measure
 
         if file:
             loadcmds = get_loadcmds(file)
@@ -352,6 +356,81 @@ def get_memory_areas(manifest, attr, args):
                                 flags=PAGEINFO_W|PAGEINFO_REG))
     return areas
 
+def find_areas(areas, desc):
+    return filter(lambda area: area.desc == desc, areas)
+
+def find_area(areas, desc, allow_none=False):
+    matching = find_areas(areas, desc)
+
+    if len(matching) == 0 and allow_none:
+        return None
+
+    if len(matching) != 1:
+        raise KeyError("Could not find exactly one MemoryArea '{}'".format(desc))
+
+    return matching[0]
+
+def entry_point(elf_path):
+    env = os.environ
+    env['LC_ALL'] = 'C'
+    out = subprocess.check_output(['readelf', '-l', '--', elf_path], env = env)
+    for line in out.splitlines():
+        if line.startswith("Entry point "):
+            return parse_int(line[12:])
+    raise ValueError("Could not find entry point of elf file")
+
+def baseaddr():
+    if enclave_heap_min == 0:
+        return ENCLAVE_HIGH_ADDRESS
+    else:
+        return 0
+
+def gen_area_content(attr, areas):
+    exec_area = find_area(areas, 'exec', True)
+    pal_area = find_area(areas, 'pal')
+    ssa_area = find_area(areas, 'ssa')
+    tcs_area = find_area(areas, 'tcs')
+    tls_area = find_area(areas, 'tls')
+    stacks = find_areas(areas, 'stack')
+
+    tcs_data = bytearray(tcs_area.size)
+    def set_tcs_field(t, offset, pack_fmt, value):
+        struct.pack_into(pack_fmt, tcs_data, t * TCS_SIZE + offset, value)
+
+    tls_data = bytearray(tls_area.size)
+    def set_tls_field(t, offset, value):
+        struct.pack_into('<Q', tls_data, t * PAGESIZE + offset, value)
+
+    enclave_heap_max = pal_area.addr - MEMORY_GAP
+
+    # Sanity check that we measure everything except the heap which is zeroed
+    # on enclave startup.
+    for area in areas:
+        if area.addr + area.size <= enclave_heap_min or area.addr >= enclave_heap_max or area is exec_area:
+            if not area.measure:
+                raise ValueError("Memory area, which is not the heap, is not measured")
+        elif area.desc != 'free':
+            raise ValueError("Unexpected memory area is in heap range")
+
+    for t in range(0, attr['thread_num']):
+        ssa_offset = ssa_area.addr + SSAFRAMESIZE * SSAFRAMENUM * t;
+        ssa = baseaddr() + ssa_offset
+        set_tcs_field(t, TCS_OSSA, '<Q', ssa_offset)
+        set_tcs_field(t, TCS_NSSA, '<L', SSAFRAMENUM)
+        set_tcs_field(t, TCS_OENTRY, '<Q', pal_area.addr + entry_point(pal_area.file))
+        set_tcs_field(t, TCS_OGSBASGX, '<Q', tls_area.addr + PAGESIZE * t)
+        set_tcs_field(t, TCS_FSLIMIT, '<L', 0xfff)
+        set_tcs_field(t, TCS_GSLIMIT, '<L', 0xfff)
+
+        set_tls_field(t, SGX_ENCLAVE_SIZE, attr['enclave_size'])
+        set_tls_field(t, SGX_TCS_OFFSET, tcs_area.addr + TCS_SIZE * t)
+        set_tls_field(t, SGX_INITIAL_STACK_OFFSET, stacks[t].addr + stacks[t].size)
+        set_tls_field(t, SGX_SSA, ssa)
+        set_tls_field(t, SGX_GPR, ssa + SSAFRAMESIZE - SGX_GPR_SIZE)
+
+    tcs_area.content = tcs_data
+    tls_area.content = tls_data
+
 def populate_memory_areas(manifest, attr, areas):
     populating = attr['enclave_size']
 
@@ -372,13 +451,17 @@ def populate_memory_areas(manifest, attr, areas):
         if area.addr + area.size < populating:
             addr = area.addr + area.size
             free_areas.append(MemoryArea('free', addr=addr, size=populating - addr,
-                                flags=PAGEINFO_R|PAGEINFO_W|PAGEINFO_X|PAGEINFO_REG))
+                                flags=PAGEINFO_R|PAGEINFO_W|PAGEINFO_X|PAGEINFO_REG,
+                                measure=False))
             populating = area.addr
 
     if populating > enclave_heap_min:
         free_areas.append(MemoryArea('free', addr=enclave_heap_min,
                                      size=populating - enclave_heap_min,
-                                     flags=PAGEINFO_R|PAGEINFO_W|PAGEINFO_X|PAGEINFO_REG))
+                                     flags=PAGEINFO_R|PAGEINFO_W|PAGEINFO_X|PAGEINFO_REG,
+                                     measure=False))
+
+    gen_area_content(attr, areas)
 
     return areas + free_areas
 
@@ -392,9 +475,22 @@ def generate_measurement(attr, areas):
         data = struct.pack("<8sQQ40s", "EADD", offset, flags, "")
         digest.update(data)
 
-    def do_eextend(digest, offset):
+    def do_eextend(digest, offset, content):
+        if len(content) != 256:
+            raise ValueError("Exactly 256 bytes expected")
+
         data = struct.pack("<8sQ48s", "EEXTEND", offset, "")
         digest.update(data)
+        digest.update(content)
+
+    def include_page(digest, offset, flags, content, measure):
+        if len(content) != PAGESIZE:
+            raise ValueError("Exactly one page expected")
+
+        do_eadd(digest, offset, flags)
+        if measure:
+            for i in range(0, PAGESIZE, 256):
+                do_eextend(digest, offset + i, content[i:i + 256])
 
     mrenclave = hashlib.sha256()
     do_ecreate(mrenclave, attr['enclave_size'])
@@ -428,42 +524,34 @@ def generate_measurement(attr, areas):
         f_size = roundup(offset + filesize) - f_addr
         m_size = roundup(addr + memsize) - m_addr
 
-        print_area(m_addr, f_size, flags, desc, True)
-        if f_size < m_size:
-            print_area(m_addr + f_size, m_size - f_size, flags, "bss", False)
+        print_area(m_addr, m_size, flags, desc, True)
 
         for pg in range(m_addr, m_addr + m_size, PAGESIZE):
-            do_eadd(digest, pg, flags)
-
-            if (pg >= m_addr + f_size):
-                continue
-
-            for er in range(pg, pg + PAGESIZE, 256):
-                do_eextend(digest, er)
-                start = er - m_addr + f_addr
-                end = start + 256
-                start_zero = ""
-                if start < offset:
-                    if offset - start >= 256:
-                        start_zero = chr(0) * 256
-                    else:
-                        start_zero = chr(0) * (offset - start)
-                end_zero = ""
-                if end > offset + filesize:
-                    if end - offset - filesize >= 256:
-                        end_zero = chr(0) * 256
-                    else:
-                        end_zero = chr(0) * (end - offset - filesize)
-                start += len(start_zero)
-                end -= len(end_zero)
-                if start < end:
-                    f.seek(start)
-                    data = f.read(end - start)
+            start = pg - m_addr + f_addr
+            end = start + PAGESIZE
+            start_zero = ""
+            if start < offset:
+                if offset - start >= PAGESIZE:
+                    start_zero = ZERO_PAGE
                 else:
-                    data = ""
-                if len(start_zero + data + end_zero) != 256:
+                    start_zero = chr(0) * (offset - start)
+            end_zero = ""
+            if end > offset + filesize:
+                if end - offset - filesize >= PAGESIZE:
+                    end_zero = ZERO_PAGE
+                else:
+                    end_zero = chr(0) * (end - offset - filesize)
+            start += len(start_zero)
+            end -= len(end_zero)
+            if start < end:
+                f.seek(start)
+                data = f.read(end - start)
+            else:
+                data = ""
+            if len(start_zero + data + end_zero) != PAGESIZE:
                     raise Exception("wrong calculation")
-                digest.update(start_zero + data + end_zero)
+
+            include_page(digest, pg, flags, start_zero + data + end_zero, True)
 
     for area in areas:
         if area.file:
@@ -497,8 +585,15 @@ def generate_measurement(attr, areas):
                               area.desc, area.flags)
         else:
             for a in range(area.addr, area.addr + area.size, PAGESIZE):
-                do_eadd(mrenclave, a, area.flags)
-            print_area(area.addr, area.size, area.flags, area.desc, False)
+                data = ZERO_PAGE
+                if area.content is not None:
+                    start = a - area.addr
+                    end = start + PAGESIZE
+                    data = area.content[start:end]
+
+                include_page(mrenclave, a, area.flags, data, area.measure)
+
+            print_area(area.addr, area.size, area.flags, area.desc, area.measure)
 
     return mrenclave.digest()