#!/usr/bin/env python import os import sys import re import datetime import struct import subprocess import hashlib import binascii import shutil """ Default / Architectural Options """ ARCHITECTURE = "amd64" PAGESIZE = 4096 MEMORY_GAP = PAGESIZE TCSSIZE = PAGESIZE SSAFRAMESIZE = PAGESIZE SSAFRAMENUM = 2 ENCLAVE_STACK_SIZE = PAGESIZE * 16 DEFAULT_ENCLAVE_SIZE = '256M' DEFAULT_THREAD_NUM = 4 ENCLAVE_HEAP_MIN = 0x10000 """ Utilities """ def roundup(addr): remaining = addr % PAGESIZE if remaining: return addr + (PAGESIZE - remaining) else: return addr def rounddown(addr): return addr - addr % PAGESIZE def roundup_data(data): return data + '\0' * (roundup(len(data)) - len(data)) def int_to_bytes(i): b = "" l = 0 while i > 0: b = b + chr(i % 256) i = i // 256 l = l + 1 return b def bytes_to_int(b): i = 0 for c in b: i = i * 256 + ord(c) return i def parse_int(s): if len(s) > 2 and s.startswith("0x"): return int(s[2:], 16) if len(s) > 1 and s.startswith("0"): return int(s[1:], 8) return int(s) def parse_size(s): scale = 1 if s.endswith("K"): scale = 1024 if s.endswith("M"): scale = 1024 * 1024 if s.endswith("G"): scale = 1024 * 1024 * 1024 if scale != 1: s = s[:-1] return parse_int(s) * scale """ Reading / Writing Manifests """ def read_manifest(filename): manifest = dict() manifest_layout = [] with open(filename, "r") as f: for line in f.readlines(): if line == "": manifest_layout.append((None, None)) break pound = line.find("#") if pound != -1: comment = line[pound:].strip() line = line[:pound] else: comment = None line = line.strip() equal = line.find("=") if equal != -1: key = line[:equal].strip() manifest[key] = line[equal + 1:].strip() else: key = None manifest_layout.append((key, comment)) return (manifest, manifest_layout) def output_manifest(filename, manifest, manifest_layout): with open(filename, 'w') as f: written = [] for (key, comment) in manifest_layout: line = '' if key is not None: line += key + ' = ' + manifest[key] written.append(key) if comment is not 
None: if line != '': line += ' ' line += comment print >>f, line print >>f print >>f, "# Generated by Graphene" print >>f for key in sorted(manifest.keys()): if key not in written: print >>f, key, '=', manifest[key] """ Loading Enclave Attributes """ def get_enclave_attributes(manifest): sgx_flags = { 'FLAG_DEBUG' : struct.pack("= 7 and tokens[7] == "E": tokens[6] += tokens[7] prot = 0 for t in tokens[6]: if t == "R": prot = prot | 4 if t == "W": prot = prot | 2 if t == "E": prot = prot | 1 loadcmds.append((int(tokens[1][2:], 16), # offset int(tokens[2][2:], 16), # addr int(tokens[4][2:], 16), # filesize int(tokens[5][2:], 16), # memsize prot)) p.wait() if p.returncode != 0: return None return loadcmds class MemoryArea: def __init__(self, desc, file=None, addr=None, size=None, flags=None): self.desc = desc self.file = file self.addr = addr self.size = size self.flags = flags self.is_binary = False if file: loadcmds = get_loadcmds(file) if loadcmds: mapaddr = 0xffffffffffffffff mapaddr_end = 0 for (offset, addr, filesize, memsize, prot) in loadcmds: if rounddown(addr) < mapaddr: mapaddr = rounddown(addr) if roundup(addr + memsize) > mapaddr_end: mapaddr_end = roundup(addr + memsize) self.is_binary = True self.size = mapaddr_end - mapaddr if mapaddr > 0: self.addr = mapaddr else: self.size = os.stat(file).st_size if self.addr is not None: self.addr = rounddown(self.addr) if self.size is not None: self.size = roundup(self.size) def get_memory_areas(manifest, attr, args): areas = [] areas.append(MemoryArea('ssa', size=attr['thread_num'] * SSAFRAMESIZE * SSAFRAMENUM, flags=PAGEINFO_R|PAGEINFO_W|PAGEINFO_REG)) areas.append(MemoryArea('tcs', size=attr['thread_num'] * TCSSIZE, flags=PAGEINFO_TCS)) areas.append(MemoryArea('tls', size=attr['thread_num'] * PAGESIZE, flags=PAGEINFO_R|PAGEINFO_W|PAGEINFO_REG)) for t in range(attr['thread_num']): areas.append(MemoryArea('stack', size=ENCLAVE_STACK_SIZE, flags=PAGEINFO_R|PAGEINFO_W|PAGEINFO_REG)) areas.append(MemoryArea('pal', 
file=args['libpal'], flags=PAGEINFO_REG)) if 'exec' in args: areas.append(MemoryArea('exec', file=args['exec'], flags=PAGEINFO_W|PAGEINFO_REG)) return areas def populate_memory_areas(manifest, attr, areas): populating = attr['enclave_size'] for area in areas: if area.addr is not None: continue area.addr = populating - area.size if area.addr < ENCLAVE_HEAP_MIN: raise Exception("Enclave size is not large enough") if area.desc == 'exec': populating = area.addr; else: populating = area.addr - MEMORY_GAP free_areas = [] for area in areas: if area.addr + area.size < populating: addr = area.addr + area.size free_areas.append(MemoryArea('free', addr=addr, size=populating - addr, flags=PAGEINFO_R|PAGEINFO_W|PAGEINFO_X|PAGEINFO_REG)) populating = area.addr if populating > ENCLAVE_HEAP_MIN: free_areas.append(MemoryArea('free', addr=ENCLAVE_HEAP_MIN, size=populating - ENCLAVE_HEAP_MIN, flags=PAGEINFO_R|PAGEINFO_W|PAGEINFO_X|PAGEINFO_REG)) return areas + free_areas def generate_measurement(attr, areas): def do_ecreate(digest, size): data = struct.pack("<8sLQ44s", "ECREATE", SSAFRAMESIZE / PAGESIZE, size, "") digest.update(data) def do_eadd(digest, offset, flags): data = struct.pack("<8sQQ40s", "EADD", offset, flags, "") digest.update(data) def do_eextend(digest, offset): data = struct.pack("<8sQ48s", "EEXTEND", offset, "") digest.update(data) class mrenclave_digest: def __init__(self): self.digest = hashlib.sha256() def update(self, payload): for er in range(0, len(payload), 64): self.digest.update(payload[er:er+64]) def finalize(self): return self.digest.digest() mrenclave = mrenclave_digest() do_ecreate(mrenclave, attr['enclave_size']) def print_area(addr, size, flags, desc, measured): if flags & PAGEINFO_REG: type = 'REG' if flags & PAGEINFO_TCS: type = 'TCS' prot = ['-', '-', '-'] if flags & PAGEINFO_R: prot[0] = 'R' if flags & PAGEINFO_W: prot[1] = 'W' if flags & PAGEINFO_X: prot[2] = 'X' prot = ''.join(prot) desc = '(' + desc + ')' if measured: desc += ' measured' if size 
== PAGESIZE: print >>sys.stderr, " %016x [%s:%s] %s" % (addr, type, prot, desc) else: print >>sys.stderr, " %016x-%016lx [%s:%s] %s" % (addr, addr + size, type, prot, desc) def load_file(digest, f, offset, addr, filesize, memsize, desc, flags): f_addr = rounddown(offset) m_addr = rounddown(addr) f_size = roundup(offset + filesize) - f_addr m_size = roundup(addr + memsize) - m_addr print_area(m_addr, f_size, flags, desc, True) if f_size < m_size: print_area(m_addr + f_size, m_size - f_size, flags, "bss", False) for pg in range(m_addr, m_addr + m_size, PAGESIZE): do_eadd(digest, pg, flags) if (pg >= m_addr + f_size): continue for er in range(pg, pg + PAGESIZE, 256): do_eextend(digest, er) start = er - m_addr + f_addr end = start + 256 start_zero = "" if start < offset: if offset - start >= 256: start_zero = chr(0) * 256 else: start_zero = chr(0) * (offset - start) end_zero = "" if end > offset + filesize: if end - offset - filesize >= 256: end_zero = chr(0) * 256 else: end_zero = chr(0) * (end - offset - filesize) start += len(start_zero) end -= len(end_zero) if start < end: f.seek(start) data = f.read(end - start) else: data = "" if len(start_zero + data + end_zero) != 256: raise Exception("wrong calculation") digest.update(start_zero + data + end_zero) for area in areas: if area.file: with open(area.file, 'rb') as f: if area.is_binary: loadcmds = get_loadcmds(area.file) if loadcmds: mapaddr = 0xffffffffffffffff for (offset, addr, filesize, memsize, prot) in loadcmds: if rounddown(addr) < mapaddr: mapaddr = rounddown(addr) baseaddr = area.addr - mapaddr for (offset, addr, filesize, memsize, prot) in loadcmds: flags = area.flags if prot & 4: flags = flags | PAGEINFO_R if prot & 2: flags = flags | PAGEINFO_W if prot & 1: flags = flags | PAGEINFO_X if flags & PAGEINFO_X: desc = 'code' else: desc = 'data' load_file(mrenclave, f, offset, baseaddr + addr, filesize, memsize, desc, flags) else: load_file(mrenclave, f, 0, area.addr, os.stat(area.file).st_size, area.size, 
area.desc, area.flags) else: for a in range(area.addr, area.addr + area.size, PAGESIZE): do_eadd(mrenclave, a, area.flags) print_area(area.addr, area.size, area.flags, area.desc, False) return mrenclave.finalize() """ Generate Sigstruct """ def generate_sigstruct(attr, args, mrenclave): today = datetime.date.today() # field format: (offset, type, value) fields = dict() fields['header'] = ( 0, "<4L", 0x00000006, 0x000000e1, 0x00010000, 0x00000000) fields['vendor'] = ( 16, "= 900: struct.pack_into(field[1], sign_buffer, field[0] - 900 + 128, *field[2:]) else: struct.pack_into(field[1], sign_buffer, field[0], *field[2:]) #TODO: Got to work on this part now - take in a key p = subprocess.Popen(['openssl', 'rsa', '-modulus', '-in', args['key'], '-noout'], stdout=subprocess.PIPE) modulus_out = p.communicate()[0] modulus = modulus_out[8:8+384*2].lower().decode('hex') modulus = modulus[::-1] #TODO: Sign the manifest. Prolly do this in C code? p = subprocess.Popen(['openssl', 'sha256', '-binary', '-sign', args['key']], stdin=subprocess.PIPE, stdout=subprocess.PIPE) signature = p.communicate(sign_buffer)[0] signature = signature[::-1] def bytes_to_int(bytes): i = 0 q = 1 for digit in bytes: if ord(digit) != 0: i = i + ord(digit) * q q = q * 256 return i def int_to_bytes(i): b = "" l = 0 while i > 0: b = b + chr(i % 256) i = i // 256 l = l + 1 return b modulus_int = bytes_to_int(modulus) signature_int = bytes_to_int(signature) tmp1 = signature_int * signature_int q1_int = tmp1 // modulus_int tmp2 = tmp1 % modulus_int q2_int = tmp2 * signature_int // modulus_int q1 = int_to_bytes(q1_int) q2 = int_to_bytes(q2_int) fields['modulus'] = ( 128, "384s", modulus) fields['exponent'] = ( 512, "' if not optval[0]: usage_message += ']' print >> sys.stderr, usage_message os._exit(-1) def parse_args(): args = dict() for opt, optval in options.items(): if not optval[1]: args[opt] = False i = 1 while i < len(sys.argv): got = sys.argv[i] if got == '-help' or got == '-h': usage() invalid = 
True for opt, optval in options.items(): if got != '-' + opt: continue if optval[1] is not None: i += 1 if i == len(sys.argv): print >>sys.stderr, "Option %s needs a value." % (opt) usage() args[opt] = sys.argv[i] else: args[opt] = True invalid = False break if invalid: print >>sys.stderr, "Unknown option: %s." % (got[1:]) usage() i += 1 for opt, optval in options.items(): if optval[0] and opt not in args: print >>sys.stderr, "Must specify %s <%s>." % (opt, optval[1]) usage() return args if __name__ == "__main__": # Parse arguments args = parse_args() (manifest, manifest_layout) = read_manifest(args['manifest']) if 'exec' not in args: if 'loader.exec' in manifest: exec_url = manifest['loader.exec'] if exec_url[:5] != 'file:': print "executable must be a local file" os._exit(-1) args['exec'] = os.path.join(os.path.dirname(args['manifest']), exec_url[5:]) args['root'] = os.path.dirname(os.path.abspath(args['output'])) if 'sgx.sigfile' in manifest: args['sigfile'] = resolve_uri(manifest['sgx.sigfile'], False) else: sigfile = args['output'] for ext in ['.manifest.sgx', '.manifest']: if sigfile.endswith(ext): sigfile = sigfile[:-len(ext)] break args['sigfile'] = sigfile + '.sig' manifest['sgx.sigfile'] = 'file:' + os.path.basename(args['sigfile']) # Get attributes from manifest attr = dict() for key, default, parse in [ ('enclave_size', DEFAULT_ENCLAVE_SIZE, parse_size), ('thread_num', str(DEFAULT_THREAD_NUM), parse_int), ('isvprodid', '0', parse_int), ('isvsvn', '0', parse_int), ]: if 'sgx.' + key not in manifest: manifest['sgx.' + key] = default attr[key] = parse(manifest['sgx.' 
+ key]) (attr['flags'], attr['xfrms'], attr['miscs']) = get_enclave_attributes(manifest) print >>sys.stderr, "Attributes:" print >>sys.stderr, " size: %d" % (attr['enclave_size']) print >>sys.stderr, " threadnum: %d" % (attr['thread_num']) print >>sys.stderr, " isvprodid: %d" % (attr['isvprodid']) print >>sys.stderr, " isvsvn: %d" % (attr['isvsvn']) print >>sys.stderr, " flags: %016x" % (bytes_to_int(attr['flags'])) print >>sys.stderr, " xfrms: %016x" % (bytes_to_int(attr['xfrms'])) print >>sys.stderr, " miscs: %08x" % (bytes_to_int(attr['miscs'])) # Get trusted checksums and measurements #Fixed this to make it run without computing hashes. print >>sys.stderr, "Trusted files:" for key, val in get_trusted_files(manifest, args).items(): (uri, target, checksum) = val print >>sys.stderr, " ('%s', '%s')," % (target.split('/')[-1], checksum) #(checksum, uri) manifest['sgx.trusted_checksum.' + key] = checksum print >>sys.stderr, "Trusted children:" for key, val in get_trusted_children(manifest, args).items(): (uri, target, mrenclave) = val print >>sys.stderr, " %s %s" % (mrenclave, uri) manifest['sgx.trusted_mrenclave.' + key] = mrenclave #TODO: Uses readelf for .sgx manifest file, output file and libpal. #Try to hard-code results for libpal, exec in GET_LOADCMDS (called by the constructor of MemoryAreas). # Try populate memory areas memory_areas = get_memory_areas(manifest, attr, args) if len([a for a in memory_areas if a.addr is not None]) > 0: manifest['sgx.static_address'] = '1' else: ENCLAVE_HEAP_MIN = 0 # Add manifest at the top shutil.copy2(args['manifest'], args['output']) output_manifest(args['output'], manifest, manifest_layout) memory_areas = [ MemoryArea('manifest', file=args['output'], flags=PAGEINFO_R|PAGEINFO_REG) ] + memory_areas memory_areas = populate_memory_areas(manifest, attr, memory_areas) print >>sys.stderr, "Memory:" # Generate measurement #TODO:This should also tap into the hard-coded readelf results. 
mrenclave = generate_measurement(attr, memory_areas) print >>sys.stderr, "Measurement:" print >>sys.stderr, " " + mrenclave.encode('hex') #TODO:This should also be removed - prolly just return the sigstruct to the main verifier enclave. # Generate sigstruct open(args['sigfile'], 'wb').write(generate_sigstruct(attr, args, mrenclave))
# NOTE(review): Purpose. This is a Python 2 script (it uses `print >>f, ...`,
# a bare `print "..."` statement and str.encode('hex') / str.decode('hex')).
# Its __main__ section:
#   - reads the input manifest;
#   - derives enclave attributes (enclave_size, thread_num, isvprodid, isvsvn,
#     flags/xfrms/miscs);
#   - records trusted-file checksums and trusted-child MRENCLAVEs as
#     sgx.trusted_checksum.* / sgx.trusted_mrenclave.* manifest entries;
#   - writes the output manifest (copy2 + output_manifest);
#   - lays out the enclave memory areas;
#   - computes the MRENCLAVE measurement (ECREATE/EADD/EEXTEND records fed
#     into SHA-256);
#   - writes a SIGSTRUCT produced via openssl to the sigfile.
#
# NOTE(review): The text above is damaged and is NOT valid Python as-is.
#   * The original newlines and indentation were collapsed, so each physical
#     line holds many statements. The first line is swallowed entirely by the
#     shebang comment.
#   * Spans between a "<" and a later ">" appear to have been stripped.
#     - `struct.pack("= 7 and tokens[7] == "E"`: the body of
#       get_enclave_attributes runs straight into the tail of what looks like
#       get_loadcmds.
#     - `fields['vendor'] = ( 16, "= 900:`: most of the sigstruct field
#       table is gone.
#     - `fields['exponent'] = ( 512, "' if not optval[0]:`: the option table
#       and the start of usage() are gone.
#   * Names used but not defined anywhere in the visible text: the
#     PAGEINFO_* constants, resolve_uri, get_trusted_files,
#     get_trusted_children, options, and the def line of get_loadcmds.
#     Presumably they were in the lost spans; confirm against version
#     control.
#   Restore this file from version control rather than hand-editing it.
#
# NOTE(review): The module-level int_to_bytes emits the least-significant
# byte first (b + chr(i % 256)). The module-level bytes_to_int reads the
# most-significant byte first (i * 256 + ord(c)). They are therefore not
# inverses of each other. generate_sigstruct shadows both with local
# versions that appear to be least-significant-first. Keep these distinct
# when repairing.
#
# NOTE(review): The final sigfile write uses open(...).write(...) without an
# explicit close. Prefer a `with open(...)` block once the file is restored.