#!/usr/bin/env python3 import argparse import datetime import functools import hashlib import os import struct import subprocess import sys sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) import generated_offsets as offs # pylint: disable=import-error,wrong-import-position # pylint: enable=invalid-name # Default / Architectural Options ARCHITECTURE = "amd64" SSAFRAMESIZE = offs.PAGESIZE DEFAULT_ENCLAVE_SIZE = '256M' DEFAULT_THREAD_NUM = 4 ENCLAVE_HEAP_MIN = offs.DEFAULT_HEAP_MIN # Utilities ZERO_PAGE = bytes(offs.PAGESIZE) def roundup(addr): remaining = addr % offs.PAGESIZE if remaining: return addr + (offs.PAGESIZE - remaining) return addr def rounddown(addr): return addr - addr % offs.PAGESIZE def parse_size(value): scale = 1 if value.endswith("K"): scale = 1024 if value.endswith("M"): scale = 1024 * 1024 if value.endswith("G"): scale = 1024 * 1024 * 1024 if scale != 1: value = value[:-1] return int(value, 0) * scale # Reading / Writing Manifests def read_manifest(filename): manifest = dict() manifest_layout = [] with open(filename, "r") as file: for line in file: if line == "": manifest_layout.append((None, None)) break pound = line.find("#") if pound != -1: comment = line[pound:].strip() line = line[:pound] else: comment = None line = line.strip() equal = line.find("=") if equal != -1: key = line[:equal].strip() manifest[key] = line[equal + 1:].strip() else: key = None manifest_layout.append((key, comment)) return (manifest, manifest_layout) def exec_sig_manifest(args, manifest): if 'exec' not in args or args.get('depend'): if 'loader.exec' in manifest: args['exec'] = resolve_manifest_uri(args['manifest'], manifest['loader.exec']) if 'sgx.sigfile' in manifest: args['sigfile'] = resolve_uri(manifest['sgx.sigfile'], check_exist=False) else: sigfile = args['output'] for ext in ['.manifest.sgx.d', '.manifest.sgx', '.manifest']: if sigfile.endswith(ext): sigfile = sigfile[:-len(ext)] break args['sigfile'] = sigfile + '.sig' manifest['sgx.sigfile'] = 'file:' + os.path.basename(args['sigfile']) if args.get('libpal', None) is None: if 'sgx.enclave_pal_file' in manifest: args['libpal'] = resolve_manifest_uri(args['manifest'], manifest['sgx.enclave_pal_file']) else: print("Either --libpal or sgx.enclave_pal_file must be given", file=sys.stderr) return 1 return 0 def output_manifest(filename, manifest, manifest_layout): with open(filename, 'w') as file: written = [] file.write('# DO NOT MODIFY. THIS FILE WAS AUTO-GENERATED.\n\n') for (key, comment) in manifest_layout: line = '' if key is not None: line += key + ' = ' + manifest[key] written.append(key) if comment is not None: if line != '': line += ' ' line += comment file.write(line) file.write('\n') file.write('\n') file.write('# Generated by Graphene\n') file.write('\n') for key in sorted(manifest): if key not in written: file.write("%s = %s\n" % (key, manifest[key])) # Loading Enclave Attributes def get_enclave_attributes(manifest): sgx_flags = { 'FLAG_DEBUG': struct.pack("= 7 and tokens[7] == "E": tokens[6] += tokens[7] prot = 0 for token in tokens[6]: if token == "R": prot = prot | 4 if token == "W": prot = prot | 2 if token == "E": prot = prot | 1 loadcmds.append((int(tokens[1][2:], 16), # offset int(tokens[2][2:], 16), # addr int(tokens[4][2:], 16), # filesize int(tokens[5][2:], 16), # memsize prot)) proc.wait() if proc.returncode != 0: return None return loadcmds class MemoryArea: # pylint: disable=too-few-public-methods,too-many-instance-attributes def __init__(self, desc, file=None, content=None, addr=None, size=None, flags=None, measure=True): # pylint: disable=too-many-arguments self.desc = desc self.file = file self.content = content self.addr = addr self.size = size self.flags = flags self.is_binary = False self.measure = measure if file: loadcmds = get_loadcmds(file) if loadcmds: mapaddr = 0xffffffffffffffff mapaddr_end = 0 for (_, addr_, _, memsize, _) in loadcmds: if rounddown(addr_) < mapaddr: mapaddr = rounddown(addr_) if roundup(addr_ + memsize) > mapaddr_end: mapaddr_end = roundup(addr_ + memsize) self.is_binary = True self.size = mapaddr_end - mapaddr if mapaddr > 0: self.addr = mapaddr else: self.size = os.stat(file).st_size if self.addr is not None: self.addr = rounddown(self.addr) if self.size is not None: self.size = roundup(self.size) def get_memory_areas(attr, args): areas = [] areas.append( MemoryArea('ssa', size=attr['thread_num'] * SSAFRAMESIZE * offs.SSAFRAMENUM, flags=PAGEINFO_R | PAGEINFO_W | PAGEINFO_REG)) areas.append(MemoryArea('tcs', size=attr['thread_num'] * offs.TCS_SIZE, flags=PAGEINFO_TCS)) areas.append(MemoryArea('tls', size=attr['thread_num'] * offs.PAGESIZE, flags=PAGEINFO_R | PAGEINFO_W | PAGEINFO_REG)) for _ in range(attr['thread_num']): areas.append(MemoryArea('stack', size=offs.ENCLAVE_STACK_SIZE, flags=PAGEINFO_R | PAGEINFO_W | PAGEINFO_REG)) for _ in range(attr['thread_num']): areas.append(MemoryArea('sig_stack', size=offs.ENCLAVE_SIG_STACK_SIZE, flags=PAGEINFO_R | PAGEINFO_W | PAGEINFO_REG)) areas.append(MemoryArea('pal', file=args['libpal'], flags=PAGEINFO_REG)) if 'exec' in args: areas.append(MemoryArea('exec', file=args['exec'], flags=PAGEINFO_W | PAGEINFO_REG)) return areas def find_areas(areas, desc): return [area for area in areas if area.desc == desc] def find_area(areas, desc, allow_none=False): matching = find_areas(areas, desc) if not matching and allow_none: return None if len(matching) != 1: raise KeyError( "Could not find exactly one MemoryArea '{}'".format(desc)) return matching[0] def entry_point(elf_path): env = os.environ env['LC_ALL'] = 'C' out = subprocess.check_output( ['readelf', '-l', '--', elf_path], env=env) for line in out.splitlines(): line = line.decode() if line.startswith("Entry point "): return int(line[12:], 0) raise ValueError("Could not find entry point of elf file") def baseaddr(): if ENCLAVE_HEAP_MIN == 0: return offs.ENCLAVE_HIGH_ADDRESS return 0 def gen_area_content(attr, areas): # pylint: disable=too-many-locals manifest_area = find_area(areas, 'manifest') exec_area = find_area(areas, 'exec', True) pal_area = find_area(areas, 'pal') ssa_area = find_area(areas, 'ssa') tcs_area = find_area(areas, 'tcs') tls_area = find_area(areas, 'tls') stacks = find_areas(areas, 'stack') sig_stacks = find_areas(areas, 'sig_stack') tcs_data = bytearray(tcs_area.size) def set_tcs_field(t, offset, pack_fmt, value): struct.pack_into(pack_fmt, tcs_data, t * offs.TCS_SIZE + offset, value) tls_data = bytearray(tls_area.size) def set_tls_field(t, offset, value): struct.pack_into('= enclave_heap_max or area is exec_area): if not area.measure: raise ValueError("Memory area, which is not the heap, " "is not measured") elif area.desc != 'free': raise ValueError("Unexpected memory area is in heap range") for t in range(0, attr['thread_num']): ssa_offset = ssa_area.addr + SSAFRAMESIZE * offs.SSAFRAMENUM * t ssa = baseaddr() + ssa_offset set_tcs_field(t, offs.TCS_OSSA, ' ENCLAVE_HEAP_MIN: flags = PAGEINFO_R | PAGEINFO_W | PAGEINFO_X | PAGEINFO_REG free_areas.append( MemoryArea('free', addr=ENCLAVE_HEAP_MIN, size=populating - ENCLAVE_HEAP_MIN, flags=flags, measure=False)) gen_area_content(attr, areas) return areas + free_areas def generate_measurement(attr, areas): # pylint: disable=too-many-statements,too-many-branches,too-many-locals def do_ecreate(digest, size): data = struct.pack("<8sLQ44s", b"ECREATE", SSAFRAMESIZE // offs.PAGESIZE, size, b"") digest.update(data) def do_eadd(digest, offset, flags): data = struct.pack("<8sQQ40s", b"EADD", offset, flags, b"") digest.update(data) def do_eextend(digest, offset, content): if len(content) != 256: raise ValueError("Exactly 256 bytes expected") data = struct.pack("<8sQ48s", b"EEXTEND", offset, b"") digest.update(data) digest.update(content) def include_page(digest, offset, flags, content, measure): if len(content) != offs.PAGESIZE: raise ValueError("Exactly one page expected") do_eadd(digest, offset, flags) if measure: for i in range(0, offs.PAGESIZE, 256): do_eextend(digest, offset + i, content[i:i + 256]) mrenclave = hashlib.sha256() do_ecreate(mrenclave, attr['enclave_size']) def print_area(addr, size, flags, desc, measured): if flags & PAGEINFO_REG: type_ = 'REG' if flags & PAGEINFO_TCS: type_ = 'TCS' prot = ['-', '-', '-'] if flags & PAGEINFO_R: prot[0] = 'R' if flags & PAGEINFO_W: prot[1] = 'W' if flags & PAGEINFO_X: prot[2] = 'X' prot = ''.join(prot) desc = '(' + desc + ')' if measured: desc += ' measured' if size == offs.PAGESIZE: print(" %016x [%s:%s] %s" % (addr, type_, prot, desc)) else: print(" %016x-%016lx [%s:%s] %s" % (addr, addr + size, type_, prot, desc)) def load_file(digest, file, offset, addr, filesize, memsize, desc, flags): # pylint: disable=too-many-arguments f_addr = rounddown(offset) m_addr = rounddown(addr) m_size = roundup(addr + memsize) - m_addr print_area(m_addr, m_size, flags, desc, True) for page in range(m_addr, m_addr + m_size, offs.PAGESIZE): start = page - m_addr + f_addr end = start + offs.PAGESIZE start_zero = b"" if start < offset: if offset - start >= offs.PAGESIZE: start_zero = ZERO_PAGE else: start_zero = bytes(offset - start) end_zero = b"" if end > offset + filesize: if end - offset - filesize >= offs.PAGESIZE: end_zero = ZERO_PAGE else: end_zero = bytes(end - offset - filesize) start += len(start_zero) end -= len(end_zero) if start < end: file.seek(start) data = file.read(end - start) else: data = b"" if len(start_zero + data + end_zero) != offs.PAGESIZE: raise Exception("wrong calculation") include_page(digest, page, flags, start_zero + data + end_zero, True) for area in areas: if area.file: with open(area.file, 'rb') as file: if area.is_binary: loadcmds = get_loadcmds(area.file) if loadcmds: mapaddr = 0xffffffffffffffff for (offset, addr, filesize, memsize, prot) in loadcmds: if rounddown(addr) < mapaddr: mapaddr = rounddown(addr) baseaddr_ = area.addr - mapaddr for (offset, addr, filesize, memsize, prot) in loadcmds: flags = area.flags if prot & 4: flags = flags | PAGEINFO_R if prot & 2: flags = flags | PAGEINFO_W if prot & 1: flags = flags | PAGEINFO_X if flags & PAGEINFO_X: desc = 'code' else: desc = 'data' load_file(mrenclave, file, offset, baseaddr_ + addr, filesize, memsize, desc, flags) else: load_file(mrenclave, file, 0, area.addr, os.stat(area.file).st_size, area.size, area.desc, area.flags) else: for addr in range(area.addr, area.addr + area.size, offs.PAGESIZE): data = ZERO_PAGE if area.content is not None: start = addr - area.addr end = start + offs.PAGESIZE data = area.content[start:end] include_page(mrenclave, addr, area.flags, data, area.measure) print_area(area.addr, area.size, area.flags, area.desc, area.measure) return mrenclave.digest() def generate_sigstruct(attr, args, mrenclave): '''Generate Sigstruct. field format: (offset, type, value) ''' # pylint: disable=too-many-locals fields = { 'header': (offs.SGX_ARCH_ENCLAVE_CSS_HEADER, "<4L", 0x00000006, 0x000000e1, 0x00010000, 0x00000000), 'module_vendor': (offs.SGX_ARCH_ENCLAVE_CSS_MODULE_VENDOR, "= offs.SGX_ARCH_ENCLAVE_CSS_MISC_SELECT: struct.pack_into(field[1], sign_buffer, field[0] - offs.SGX_ARCH_ENCLAVE_CSS_MISC_SELECT + 128, *field[2:]) else: struct.pack_into(field[1], sign_buffer, field[0], *field[2:]) proc = subprocess.Popen( ['openssl', 'rsa', '-modulus', '-in', args['key'], '-noout'], stdout=subprocess.PIPE) modulus_out, _ = proc.communicate() modulus = bytes.fromhex(modulus_out[8:8+offs.SE_KEY_SIZE*2].decode()) modulus = bytes(reversed(modulus)) proc = subprocess.Popen( ['openssl', 'sha256', '-binary', '-sign', args['key']], stdin=subprocess.PIPE, stdout=subprocess.PIPE) signature, _ = proc.communicate(sign_buffer) signature = signature[::-1] modulus_int = int.from_bytes(modulus, byteorder='little') signature_int = int.from_bytes(signature, byteorder='little') tmp1 = signature_int * signature_int q1_int = tmp1 // modulus_int tmp2 = tmp1 % modulus_int q2_int = tmp2 * signature_int // modulus_int q1 = q1_int.to_bytes(384, byteorder='little') # pylint: disable=invalid-name q2 = q2_int.to_bytes(384, byteorder='little') # pylint: disable=invalid-name fields.update({ 'modulus': (offs.SGX_ARCH_ENCLAVE_CSS_MODULUS, "384s", modulus), 'exponent': (offs.SGX_ARCH_ENCLAVE_CSS_EXPONENT, "