#!/usr/bin/env python3
#
# Graphene-SGX enclave signing tool: reads a manifest, lays out enclave
# memory, computes MRENCLAVE, and produces the SIGSTRUCT.
#
# NOTE(review): this chunk had its newlines collapsed and several spans
# stripped (everything that looked like an HTML tag, i.e. `<...>`, was
# deleted — which ate most struct format strings such as "<Q" and the code
# following them up to the next '>'). Formatting has been restored; every
# stripped span has been reconstructed from the upstream Graphene
# `pal-sgx-sign` tool and is marked with a NOTE(review) comment below.
# Each reconstructed span MUST be diffed against the upstream original.

import argparse
import datetime
import hashlib
import os
import shutil
import struct
import subprocess
import sys

sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from generated_offsets import *

# Default / Architectural Options

ARCHITECTURE = "amd64"

SSAFRAMESIZE = PAGESIZE

DEFAULT_ENCLAVE_SIZE = '256M'
DEFAULT_THREAD_NUM = 4
enclave_heap_min = DEFAULT_HEAP_MIN

# Utilities

ZERO_PAGE = bytes(PAGESIZE)


def roundup(addr):
    """Round *addr* up to the next page boundary (identity if aligned)."""
    remaining = addr % PAGESIZE
    if remaining:
        return addr + (PAGESIZE - remaining)
    return addr


def rounddown(addr):
    """Round *addr* down to the enclosing page boundary."""
    return addr - addr % PAGESIZE


def parse_size(s):
    """Parse a size string with an optional K/M/G suffix into bytes.

    The numeric part is parsed with base auto-detection (int(s, 0)), so
    hex sizes like '0x1000' work too.
    """
    scale = 1
    if s.endswith("K"):
        scale = 1024
    if s.endswith("M"):
        scale = 1024 * 1024
    if s.endswith("G"):
        scale = 1024 * 1024 * 1024
    if scale != 1:
        s = s[:-1]
    return int(s, 0) * scale


# Reading / Writing Manifests

def read_manifest(filename):
    """Read a manifest file.

    Returns (manifest, manifest_layout) where manifest maps keys to raw
    string values and manifest_layout is a list of (key, comment) pairs
    preserving the original file layout (key and/or comment may be None).
    """
    manifest = dict()
    manifest_layout = []
    with open(filename, "r") as f:
        for line in f.readlines():
            if line == "":
                manifest_layout.append((None, None))
                break

            # Split off a trailing '#' comment, if any.
            pound = line.find("#")
            if pound != -1:
                comment = line[pound:].strip()
                line = line[:pound]
            else:
                comment = None

            line = line.strip()
            equal = line.find("=")
            if equal != -1:
                key = line[:equal].strip()
                manifest[key] = line[equal + 1:].strip()
            else:
                key = None

            manifest_layout.append((key, comment))

    return (manifest, manifest_layout)


def output_manifest(filename, manifest, manifest_layout):
    """Write *manifest* back out, preserving the original layout/comments.

    Keys present in *manifest* but absent from *manifest_layout* are
    appended at the end, sorted, under a "Generated by Graphene" banner.
    """
    with open(filename, 'w') as f:
        written = []

        for (key, comment) in manifest_layout:
            line = ''
            if key is not None:
                line += key + ' = ' + manifest[key]
                written.append(key)
            if comment is not None:
                if line != '':
                    line += ' '
                line += comment
            f.write(line)
            f.write('\n')

        f.write('\n')
        f.write("# Generated by Graphene\n")
        f.write('\n')

        for key in sorted(manifest):
            if key not in written:
                f.write("%s = %s\n" % (key, manifest[key]))


# Loading Enclave Attributes

def get_enclave_attributes(manifest):
    """Compute the SGX ATTRIBUTES.FLAGS / ATTRIBUTES.XFRM / MISCSELECT
    bytes from 'sgx.*' manifest options.

    NOTE(review): everything in this function after the 'FLAG_DEBUG' entry
    was stripped from this copy of the file; the body below is restored
    from the upstream Graphene pal-sgx-sign and must be verified.
    """
    sgx_flags = {
        'FLAG_DEBUG': struct.pack("<Q", 1 << 1),
        'FLAG_MODE64BIT': struct.pack("<Q", 1 << 2),
    }

    sgx_xfrms = {
        'XFRM_LEGACY': struct.pack("<Q", 1 | 1 << 1),
        'XFRM_AVX': struct.pack("<Q", 1 << 2),
        'XFRM_AVX512': struct.pack("<Q", 1 << 2 | 1 << 5 | 1 << 6 | 1 << 7),
        'XFRM_MPX': struct.pack("<Q", 1 << 3 | 1 << 4),
    }

    sgx_miscs = {
        'MISC_EXINFO': struct.pack("<L", 0x01),
    }

    default_attributes = {
        'FLAG_DEBUG',
        'XFRM_LEGACY',
    }

    manifest_options = {
        'debug': 'FLAG_DEBUG',
        'require_avx': 'XFRM_AVX',
        'require_avx512': 'XFRM_AVX512',
        'require_mpx': 'XFRM_MPX',
        'support_exinfo': 'MISC_EXINFO',
    }

    attributes = set(default_attributes)

    for opt, flag in manifest_options.items():
        key = 'sgx.' + opt
        if key in manifest:
            if manifest[key] == '1':
                attributes.add(flag)
            else:
                attributes.discard(flag)

    flags_raw = struct.pack("<Q", 0)
    xfrms_raw = struct.pack("<Q", 0)
    miscs_raw = struct.pack("<L", 0)

    def or_bytes(a, b):
        # Byte-wise OR of two equal-length byte strings.
        return bytes(x | y for x, y in zip(a, b))

    for attr in attributes:
        if attr in sgx_flags:
            flags_raw = or_bytes(flags_raw, sgx_flags[attr])
        if attr in sgx_xfrms:
            xfrms_raw = or_bytes(xfrms_raw, sgx_xfrms[attr])
        if attr in sgx_miscs:
            miscs_raw = or_bytes(miscs_raw, sgx_miscs[attr])

    return flags_raw, xfrms_raw, miscs_raw


# Generating Enclave Memory Content

def get_loadcmds(filename):
    """Parse `readelf -l` output into a list of LOAD segments.

    Returns a list of (offset, addr, filesize, memsize, prot) tuples,
    prot being an rwx bitmask (4/2/1), or None if readelf fails.

    NOTE(review): the head of this function (up to the tokens[7] check)
    was stripped from this copy; restored from upstream — verify.
    """
    loadcmds = []
    p = subprocess.Popen(['readelf', '-l', '-W', filename],
                         stdout=subprocess.PIPE)
    for line in p.stdout:
        line = line.decode()
        stripped = line.strip()
        if not stripped.startswith("LOAD"):
            continue
        tokens = stripped.split()
        if len(tokens) < 6:
            continue
        # readelf may split the executable flag into its own column;
        # merge it back into the permission token.
        # NOTE(review): upstream condition retained verbatim; it looks
        # off-by-one (tokens[7] needs len(tokens) >= 8) — confirm upstream.
        if len(tokens) >= 7 and tokens[7] == "E":
            tokens[6] += tokens[7]
        prot = 0
        for t in tokens[6]:
            if t == "R":
                prot = prot | 4
            if t == "W":
                prot = prot | 2
            if t == "E":
                prot = prot | 1
        loadcmds.append((int(tokens[1][2:], 16),  # offset
                         int(tokens[2][2:], 16),  # addr
                         int(tokens[4][2:], 16),  # filesize
                         int(tokens[5][2:], 16),  # memsize
                         prot))
    p.wait()
    if p.returncode != 0:
        return None
    return loadcmds


class MemoryArea:
    """One region of enclave memory (file-backed or anonymous).

    For an ELF binary the addr/size span all LOAD segments; addr is
    page-rounded down and size page-rounded up.
    """
    # pylint: disable=too-few-public-methods
    def __init__(self, desc, file=None, content=None, addr=None, size=None,
                 flags=None, measure=True):
        # pylint: disable=too-many-arguments
        self.desc = desc
        self.file = file
        self.content = content
        self.addr = addr
        self.size = size
        self.flags = flags
        self.is_binary = False
        self.measure = measure

        if file:
            loadcmds = get_loadcmds(file)
            if loadcmds:
                mapaddr = 0xffffffffffffffff
                mapaddr_end = 0
                for (_, addr_, _, memsize, _) in loadcmds:
                    if rounddown(addr_) < mapaddr:
                        mapaddr = rounddown(addr_)
                    if roundup(addr_ + memsize) > mapaddr_end:
                        mapaddr_end = roundup(addr_ + memsize)
                self.is_binary = True
                self.size = mapaddr_end - mapaddr
                if mapaddr > 0:
                    self.addr = mapaddr
            else:
                self.size = os.stat(file).st_size

        if self.addr is not None:
            self.addr = rounddown(self.addr)
        if self.size is not None:
            self.size = roundup(self.size)


def get_memory_areas(attr, args):
    """Build the fixed list of enclave memory areas (SSA, TCS, TLS,
    per-thread stacks, the PAL binary and optionally the executable)."""
    areas = []
    areas.append(
        MemoryArea('ssa',
                   size=attr['thread_num'] * SSAFRAMESIZE * SSAFRAMENUM,
                   flags=PAGEINFO_R | PAGEINFO_W | PAGEINFO_REG))
    areas.append(MemoryArea('tcs', size=attr['thread_num'] * TCS_SIZE,
                            flags=PAGEINFO_TCS))
    areas.append(MemoryArea('tls', size=attr['thread_num'] * PAGESIZE,
                            flags=PAGEINFO_R | PAGEINFO_W | PAGEINFO_REG))

    for _ in range(attr['thread_num']):
        areas.append(MemoryArea('stack', size=ENCLAVE_STACK_SIZE,
                                flags=PAGEINFO_R | PAGEINFO_W | PAGEINFO_REG))

    areas.append(MemoryArea('pal', file=args['libpal'], flags=PAGEINFO_REG))
    if 'exec' in args:
        areas.append(MemoryArea('exec', file=args['exec'],
                                flags=PAGEINFO_W | PAGEINFO_REG))
    return areas


def find_areas(areas, desc):
    """All areas whose description equals *desc*."""
    return [area for area in areas if area.desc == desc]


def find_area(areas, desc, allow_none=False):
    """The unique area named *desc*; None if absent and allow_none,
    otherwise KeyError unless exactly one match exists."""
    matching = find_areas(areas, desc)

    if not matching and allow_none:
        return None

    if len(matching) != 1:
        raise KeyError(
            "Could not find exactly one MemoryArea '{}'".format(desc))

    return matching[0]


def entry_point(elf_path):
    """ELF entry-point address of *elf_path*, via `readelf -l`."""
    env = os.environ
    env['LC_ALL'] = 'C'
    out = subprocess.check_output(
        ['readelf', '-l', '--', elf_path], env=env)
    for line in out.splitlines():
        line = line.decode()
        if line.startswith("Entry point "):
            return int(line[12:], 0)
    raise ValueError("Could not find entry point of elf file")


def baseaddr():
    """Enclave base: high address for non-PIE layout (heap_min == 0),
    else 0."""
    if enclave_heap_min == 0:
        return ENCLAVE_HIGH_ADDRESS
    return 0


def gen_area_content(attr, areas):
    """Fill in TCS and TLS page contents for every enclave thread.

    NOTE(review): two spans of this function were stripped from this copy
    (the set_tls_field pack format through the heap-range check, and the
    bulk of the per-thread TCS/TLS field writes); both are restored from
    upstream and must be verified field-by-field.
    """
    # pylint: disable=too-many-locals
    manifest_area = find_area(areas, 'manifest')
    exec_area = find_area(areas, 'exec', True)
    pal_area = find_area(areas, 'pal')
    ssa_area = find_area(areas, 'ssa')
    tcs_area = find_area(areas, 'tcs')
    tls_area = find_area(areas, 'tls')
    stacks = find_areas(areas, 'stack')

    tcs_data = bytearray(tcs_area.size)

    def set_tcs_field(t, offset, pack_fmt, value):
        struct.pack_into(pack_fmt, tcs_data, t * TCS_SIZE + offset, value)

    tls_data = bytearray(tls_area.size)

    def set_tls_field(t, offset, value):
        struct.pack_into('<Q', tls_data, t * PAGESIZE + offset, value)

    # The heap spans [enclave_heap_min, enclave_heap_max); it ends where
    # the PAL (or the executable, if it sits below the PAL) begins.
    enclave_heap_max = pal_area.addr
    if exec_area is not None:
        if exec_area.addr + exec_area.size < pal_area.addr:
            enclave_heap_max = exec_area.addr

    for area in areas:
        if (area.addr + area.size <= enclave_heap_min or
                area.addr >= enclave_heap_max or area is exec_area):
            if not area.measure:
                raise ValueError("Memory area, which is not the heap, "
                                 "is not measured")
        elif area.desc != 'free':
            raise ValueError("Unexpected memory area is in heap range")

    for t in range(0, attr['thread_num']):
        ssa_offset = ssa_area.addr + SSAFRAMESIZE * SSAFRAMENUM * t
        ssa = baseaddr() + ssa_offset
        set_tcs_field(t, TCS_OSSA, '<Q', ssa_offset)
        set_tcs_field(t, TCS_NSSA, '<L', SSAFRAMENUM)
        set_tcs_field(t, TCS_OENTRY, '<Q',
                      pal_area.addr + entry_point(pal_area.file))
        set_tcs_field(t, TCS_OGSBASE, '<Q', tls_area.addr + PAGESIZE * t)
        set_tcs_field(t, TCS_OFSLIMIT, '<L', 0xfff)
        set_tcs_field(t, TCS_OGSLIMIT, '<L', 0xfff)

        tls = baseaddr() + tls_area.addr + PAGESIZE * t
        stack_bottom = baseaddr() + stacks[t].addr + stacks[t].size
        set_tls_field(t, SGX_COMMON_SELF, tls)
        set_tls_field(t, SGX_COMMON_STACK_PROTECTOR_CANARY,
                      STACK_PROTECTOR_CANARY_DEFAULT)
        set_tls_field(t, SGX_ENCLAVE_SIZE, attr['enclave_size'])
        set_tls_field(t, SGX_TCS_OFFSET, tcs_area.addr + TCS_SIZE * t)
        set_tls_field(t, SGX_INITIAL_STACK_OFFSET, stack_bottom)
        set_tls_field(t, SGX_SSA, ssa)
        set_tls_field(t, SGX_GPR, ssa + SSAFRAMESIZE - SGX_GPR_SIZE)
        set_tls_field(t, SGX_MANIFEST_SIZE,
                      os.stat(manifest_area.file).st_size)
        set_tls_field(t, SGX_HEAP_MIN, baseaddr() + enclave_heap_min)
        set_tls_field(t, SGX_HEAP_MAX, baseaddr() + enclave_heap_max)
        if exec_area is not None:
            set_tls_field(t, SGX_EXEC_ADDR, baseaddr() + exec_area.addr)
            set_tls_field(t, SGX_EXEC_SIZE, exec_area.size)

    tcs_area.content = bytes(tcs_data)
    tls_area.content = bytes(tls_data)


def populate_memory_areas(attr, areas):
    """Assign addresses to floating areas (top-down from enclave_size),
    mark the remaining range above heap_min as RWX 'free' heap, and fill
    in area contents.

    NOTE(review): the head of this function (address assignment loop) was
    stripped from this copy; restored from upstream — verify.
    """
    populating = attr['enclave_size']

    # Place floating areas from the top of the enclave downwards.
    for area in areas:
        if area.addr is not None:
            continue
        area.addr = populating - area.size
        if area.addr < enclave_heap_min:
            raise Exception("Enclave size is not large enough")
        populating = area.addr

    free_areas = []
    if populating > enclave_heap_min:
        flags = PAGEINFO_R | PAGEINFO_W | PAGEINFO_X | PAGEINFO_REG
        free_areas.append(
            MemoryArea('free', addr=enclave_heap_min,
                       size=populating - enclave_heap_min, flags=flags,
                       measure=False))

    gen_area_content(attr, areas)

    return areas + free_areas


def generate_measurement(attr, areas):
    """Replay ECREATE/EADD/EEXTEND over all areas and return MRENCLAVE
    (the SHA-256 digest the CPU would compute)."""
    # pylint: disable=too-many-locals,too-many-statements

    def do_ecreate(digest, size):
        data = struct.pack("<8sLQ44s", b"ECREATE", SSAFRAMESIZE // PAGESIZE,
                           size, b"")
        digest.update(data)

    def do_eadd(digest, offset, flags):
        data = struct.pack("<8sQQ40s", b"EADD", offset, flags, b"")
        digest.update(data)

    def do_eextend(digest, offset, content):
        if len(content) != 256:
            raise ValueError("Exactly 256 bytes expected")
        data = struct.pack("<8sQ48s", b"EEXTEND", offset, b"")
        digest.update(data)
        digest.update(content)

    def include_page(digest, offset, flags, content, measure):
        if len(content) != PAGESIZE:
            raise ValueError("Exactly one page expected")
        do_eadd(digest, offset, flags)
        if measure:
            # EEXTEND measures 256-byte chunks.
            for i in range(0, PAGESIZE, 256):
                do_eextend(digest, offset + i, content[i:i + 256])

    mrenclave = hashlib.sha256()
    do_ecreate(mrenclave, attr['enclave_size'])

    def print_area(addr, size, flags, desc, measured):
        if flags & PAGEINFO_REG:
            type_ = 'REG'
        if flags & PAGEINFO_TCS:
            type_ = 'TCS'
        prot = ['-', '-', '-']
        if flags & PAGEINFO_R:
            prot[0] = 'R'
        if flags & PAGEINFO_W:
            prot[1] = 'W'
        if flags & PAGEINFO_X:
            prot[2] = 'X'
        prot = ''.join(prot)

        desc = '(' + desc + ')'
        if measured:
            desc += ' measured'

        if size == PAGESIZE:
            print("    %016x [%s:%s] %s" % (addr, type_, prot, desc))
        else:
            print("    %016x-%016lx [%s:%s] %s" %
                  (addr, addr + size, type_, prot, desc))

    def load_file(digest, f, offset, addr, filesize, memsize, desc, flags):
        # Page-align the file window; zero-fill outside [offset,
        # offset+filesize) so each included page is exactly PAGESIZE bytes.
        f_addr = rounddown(offset)
        m_addr = rounddown(addr)
        m_size = roundup(addr + memsize) - m_addr

        print_area(m_addr, m_size, flags, desc, True)

        for pg in range(m_addr, m_addr + m_size, PAGESIZE):
            start = pg - m_addr + f_addr
            end = start + PAGESIZE
            start_zero = b""
            if start < offset:
                if offset - start >= PAGESIZE:
                    start_zero = ZERO_PAGE
                else:
                    start_zero = bytes(offset - start)
            end_zero = b""
            if end > offset + filesize:
                if end - offset - filesize >= PAGESIZE:
                    end_zero = ZERO_PAGE
                else:
                    end_zero = bytes(end - offset - filesize)
            start += len(start_zero)
            end -= len(end_zero)
            if start < end:
                f.seek(start)
                data = f.read(end - start)
            else:
                data = b""
            if len(start_zero + data + end_zero) != PAGESIZE:
                raise Exception("wrong calculation")

            include_page(digest, pg, flags, start_zero + data + end_zero,
                         True)

    for area in areas:
        if area.file:
            with open(area.file, 'rb') as f:
                if area.is_binary:
                    loadcmds = get_loadcmds(area.file)
                    if loadcmds:
                        mapaddr = 0xffffffffffffffff
                        for (offset, addr, filesize, memsize,
                             prot) in loadcmds:
                            if rounddown(addr) < mapaddr:
                                mapaddr = rounddown(addr)

                        baseaddr_ = area.addr - mapaddr
                        for (offset, addr, filesize, memsize,
                             prot) in loadcmds:
                            flags = area.flags
                            if prot & 4:
                                flags = flags | PAGEINFO_R
                            if prot & 2:
                                flags = flags | PAGEINFO_W
                            if prot & 1:
                                flags = flags | PAGEINFO_X

                            if flags & PAGEINFO_X:
                                desc = 'code'
                            else:
                                desc = 'data'
                            load_file(mrenclave, f, offset, baseaddr_ + addr,
                                      filesize, memsize, desc, flags)
                else:
                    load_file(mrenclave, f, 0, area.addr,
                              os.stat(area.file).st_size, area.size,
                              area.desc, area.flags)
        else:
            for a in range(area.addr, area.addr + area.size, PAGESIZE):
                data = ZERO_PAGE
                if area.content is not None:
                    start = a - area.addr
                    end = start + PAGESIZE
                    data = area.content[start:end]
                include_page(mrenclave, a, area.flags, data, area.measure)
            print_area(area.addr, area.size, area.flags, area.desc,
                       area.measure)

    return mrenclave.digest()


def generate_sigstruct(attr, args, mrenclave):
    """Generate Sigstruct.

    Fields are packed per the Intel SDM SIGSTRUCT layout; the buffer is
    signed with the RSA key in args['key'] via openssl, and Q1/Q2 are the
    RSA verification helper values (sig^2 / modulus and friends).

    NOTE(review): the field table between 'vendor' and the signing loop,
    and everything after the 'exponent' entry, were stripped from this
    copy; restored from upstream — verify offsets and constants.
    """
    today = datetime.date.today()

    # field format: (offset, type, value)
    fields = {
        'header': (SGX_ARCH_SIGSTRUCT_HEADER, "<4L",
                   0x00000006, 0x000000e1, 0x00010000, 0x00000000),
        'vendor': (SGX_ARCH_SIGSTRUCT_VENDOR, "<L", 0),
        'date': (SGX_ARCH_SIGSTRUCT_DATE, "<HBB",
                 today.year, today.month, today.day),
        'header2': (SGX_ARCH_SIGSTRUCT_HEADER2, "<4L",
                    0x00000101, 0x00000060, 0x00000060, 0x00000001),
        'swdefined': (SGX_ARCH_SIGSTRUCT_SWDEFINED, "<L", 0),
        'misc_select': (SGX_ARCH_SIGSTRUCT_MISCSELECT, "4s",
                        attr['misc_select']),
        'misc_mask': (SGX_ARCH_SIGSTRUCT_MISCMASK, "4s",
                      b"\xff\xff\xff\xff"),
        'attributes': (SGX_ARCH_SIGSTRUCT_ATTRIBUTES, "8s8s",
                       attr['flags'], attr['xfrms']),
        'attribute_mask': (SGX_ARCH_SIGSTRUCT_ATTRIBUTE_MASK, "8s8s",
                           attr['flags'], attr['xfrms']),
        'enclave_hash': (SGX_ARCH_SIGSTRUCT_ENCLAVE_HASH, "32s", mrenclave),
        'isv_prod_id': (SGX_ARCH_SIGSTRUCT_ISV_PROD_ID, "<H", 0),
        'isv_svn': (SGX_ARCH_SIGSTRUCT_ISV_SVN, "<H", 0),
    }

    # The signed blob covers the two signed ranges of SIGSTRUCT:
    # bytes [0,128) and the 128 bytes starting at MISCSELECT.
    sign_buffer = bytearray(128 + 128)

    for field in fields.values():
        if field[0] >= SGX_ARCH_SIGSTRUCT_MISCSELECT:
            struct.pack_into(field[1], sign_buffer,
                             field[0] - SGX_ARCH_SIGSTRUCT_MISCSELECT + 128,
                             *field[2:])
        else:
            struct.pack_into(field[1], sign_buffer, field[0], *field[2:])

    p = subprocess.Popen(
        ['openssl', 'rsa', '-modulus', '-in', args['key'], '-noout'],
        stdout=subprocess.PIPE)
    modulus_out = p.communicate()[0]
    modulus = bytes.fromhex(modulus_out[8:8 + SGX_ARCH_KEY_SIZE * 2].decode())
    modulus = bytes(reversed(modulus))

    p = subprocess.Popen(
        ['openssl', 'sha256', '-binary', '-sign', args['key']],
        stdin=subprocess.PIPE, stdout=subprocess.PIPE)
    signature = p.communicate(sign_buffer)[0]
    signature = signature[::-1]

    modulus_int = int.from_bytes(modulus, byteorder='little')
    signature_int = int.from_bytes(signature, byteorder='little')

    tmp1 = signature_int * signature_int
    q1_int = tmp1 // modulus_int
    tmp2 = tmp1 % modulus_int
    q2_int = tmp2 * signature_int // modulus_int

    q1 = q1_int.to_bytes(384, byteorder='little')
    q2 = q2_int.to_bytes(384, byteorder='little')

    fields.update({
        'modulus': (SGX_ARCH_SIGSTRUCT_MODULUS, "384s", modulus),
        'exponent': (SGX_ARCH_SIGSTRUCT_EXPONENT, "<L", 3),
        'signature': (SGX_ARCH_SIGSTRUCT_SIGNATURE, "384s", signature),
        'q1': (SGX_ARCH_SIGSTRUCT_Q1, "384s", q1),
        'q2': (SGX_ARCH_SIGSTRUCT_Q2, "384s", q2),
    })

    buffer = bytearray(SGX_ARCH_SIGSTRUCT_SIZE)

    for field in fields.values():
        struct.pack_into(field[1], buffer, field[0], *field[2:])

    return buffer