#!/usr/bin/env python2

import os
import sys
import re
import datetime
import struct
import subprocess
import hashlib
import binascii
import shutil

sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from generated_offsets import *

""" Default / Architectural Options """

# Target architecture; get_enclave_attributes() adds FLAG_MODE64BIT when this
# is 'amd64'.
ARCHITECTURE = "amd64"

# Size of a single SSA frame (one page).
SSAFRAMESIZE = PAGESIZE

# Defaults used when the manifest does not set sgx.enclave_size /
# sgx.thread_num.
DEFAULT_ENCLAVE_SIZE = '256M'
DEFAULT_THREAD_NUM = 4
# Lowest enclave offset available for memory areas / the heap. The main
# program resets this global to 0 when no memory area has a preset address,
# which also makes baseaddr() return ENCLAVE_HIGH_ADDRESS.
enclave_heap_min = DEFAULT_HEAP_MIN

""" Utilities """

# One page of zero bytes: page-sized zero padding and the default content of
# pages that have no explicit content.
ZERO_PAGE = "\0" * PAGESIZE

def roundup(addr):
    """Round `addr` up to the next multiple of PAGESIZE (unchanged if aligned)."""
    # (-addr) % PAGESIZE is exactly the distance to the next page boundary,
    # and 0 when addr is already aligned.
    return addr + (-addr) % PAGESIZE

def rounddown(addr):
    """Round `addr` down to the previous multiple of PAGESIZE."""
    return (addr // PAGESIZE) * PAGESIZE

def roundup_data(data):
    """Pad the string `data` with NUL bytes up to a whole number of pages."""
    padding = roundup(len(data)) - len(data)
    return data + '\0' * padding

def int_to_bytes(i):
    """Encode a non-negative integer as a minimal little-endian byte string.

    The least-significant byte comes first and no padding is added, so 0 (or
    any non-positive value) encodes to the empty string.
    """
    digits = []
    while i > 0:
        digits.append(chr(i % 256))
        i //= 256
    # Join once instead of repeated concatenation (quadratic for big ints).
    return ''.join(digits)

def bytes_to_int(b):
    """Decode the byte string `b` as a big-endian unsigned integer.

    Note: this is the opposite byte order from int_to_bytes(), which emits
    the least-significant byte first.
    """
    value = 0
    for ch in b:
        value = value * 256 + ord(ch)
    return value

def parse_int(s):
    """Parse an integer literal as written in manifests and tool output.

    '0x'/'0X' prefix -> hexadecimal; any other leading '0' (with more digits)
    -> octal; otherwise decimal via int(). Raises ValueError on bad input.
    """
    # Accept the upper-case hex prefix too; it used to fall through to the
    # octal branch and fail on the 'X'.
    if len(s) > 2 and s[:2] in ("0x", "0X"):
        return int(s[2:], 16)
    if len(s) > 1 and s.startswith("0"):
        return int(s[1:], 8)
    return int(s)

def parse_size(s):
    """Parse a size such as '256M' into bytes.

    An optional trailing 'K', 'M' or 'G' scales by 2**10, 2**20 or 2**30;
    the remaining number is parsed with parse_int().
    """
    multipliers = {'K': 1 << 10, 'M': 1 << 20, 'G': 1 << 30}
    suffix = s[-1:]
    if suffix in multipliers:
        return parse_int(s[:-1]) * multipliers[suffix]
    return parse_int(s)


""" Reading / Writing Manifests """

def read_manifest(filename):
    """Parse a manifest file into its key/value pairs and its line layout.

    Each line may carry a `# comment`. The text before the first '#' is split
    at the first '=' into a stripped key and value. Lines without '=' (blank
    or comment-only) are recorded with key None.

    Returns (manifest, manifest_layout):
      manifest        -- dict mapping key -> value string
      manifest_layout -- list of (key or None, comment or None), one entry per
                         input line, so output_manifest() can reproduce it.
    """
    manifest = dict()
    manifest_layout = []
    with open(filename, "r") as f:
        # Lines from a file always keep their '\n', so there is no "empty
        # line" sentinel to check for; blank lines become (None, None).
        for raw in f:
            head, pound, tail = raw.partition("#")
            comment = (pound + tail).strip() if pound else None

            key, equal, value = head.strip().partition("=")
            if equal:
                key = key.strip()
                manifest[key] = value.strip()
            else:
                key = None

            manifest_layout.append((key, comment))

    return (manifest, manifest_layout)

def output_manifest(filename, manifest, manifest_layout):
    """Write `manifest` to `filename`, reproducing `manifest_layout`.

    First one line per layout entry is written: "key = value" for keyed
    entries, followed by the saved comment (if any) after a space. Then a
    "# Generated by Graphene" banner is written, followed by every key the
    layout did not cover, in sorted order.
    """
    emitted = set()
    with open(filename, 'w') as out:
        for key, comment in manifest_layout:
            pieces = []
            if key is not None:
                pieces.append(key + ' = ' + manifest[key])
                emitted.add(key)
            if comment is not None:
                pieces.append(comment)
            out.write(' '.join(pieces) + '\n')

        out.write('\n# Generated by Graphene\n\n')

        for key in sorted(manifest):
            if key not in emitted:
                out.write('%s = %s\n' % (key, manifest[key]))


""" Loading Enclave Attributes """

def get_enclave_attributes(manifest):
    """Compute the SGX attribute flags, XFRM and MISCSELECT for the enclave.

    Starts from FLAG_DEBUG, XFRM_LEGACY and XFRM_AVX (plus FLAG_MODE64BIT when
    ARCHITECTURE is 'amd64'), then applies the manifest switches sgx.debug,
    sgx.enable_avx, sgx.enable_avx512, sgx.enable_mpx and sgx.support_exinfo:
    a value of '1' enables the attribute, any other value disables it.

    Returns (flags, xfrms, miscs) packed little-endian as 8, 8 and 4 bytes.
    """
    sgx_flags = {
        'FLAG_DEBUG'          : SGX_FLAGS_DEBUG,
        'FLAG_MODE64BIT'      : SGX_FLAGS_MODE64BIT,
    }

    sgx_xfrms = {
        'XFRM_LEGACY'         : SGX_XFRM_LEGACY,
        'XFRM_AVX'            : SGX_XFRM_AVX,
        'XFRM_AVX512'         : SGX_XFRM_AVX512,
        'XFRM_MPX'            : SGX_XFRM_MPX,
    }

    sgx_miscs = {
        'MISC_EXINFO'         : SGX_MISCSELECT_EXINFO,
    }

    attributes = {
        'FLAG_DEBUG',
        'XFRM_LEGACY',
        'XFRM_AVX',
    }

    if ARCHITECTURE == 'amd64':
        attributes.add('FLAG_MODE64BIT')

    manifest_options = {
        'debug'          : 'FLAG_DEBUG',
        'enable_avx'     : 'XFRM_AVX',
        'enable_avx512'  : 'XFRM_AVX512',
        'enable_mpx'     : 'XFRM_MPX',
        'support_exinfo' : 'MISC_EXINFO',
    }

    for opt, attr_name in manifest_options.items():
        key = 'sgx.' + opt
        if key not in manifest:
            continue
        if manifest[key] == '1':
            attributes.add(attr_name)
        else:
            attributes.discard(attr_name)

    # OR the integer masks and pack once at the end; this equals OR-ing the
    # individually packed little-endian byte strings.
    flags = xfrms = miscs = 0
    for name in attributes:
        flags |= sgx_flags.get(name, 0)
        xfrms |= sgx_xfrms.get(name, 0)
        miscs |= sgx_miscs.get(name, 0)

    return (struct.pack("<Q", flags),
            struct.pack("<Q", xfrms),
            struct.pack("<L", miscs))


""" Generate Checksums / Measurement """

def resolve_uri(uri, check_exist=True):
    """Map a manifest URI to a normalized local path.

    A leading 'file:' scheme is stripped; any other string is treated as a
    plain path. When `check_exist` is true, raise Exception if the resulting
    path does not exist.
    """
    path = uri[len('file:'):] if uri.startswith('file:') else uri
    target = os.path.normpath(path)
    if check_exist and not os.path.exists(target):
        raise Exception('Cannot resolve ' + uri + ' or the file does not exist.')
    return target

def get_checksum(file):
    """Return the raw 32-byte SHA-256 digest of the file at path `file`.

    The file is hashed in 1 MiB chunks so large trusted files are never held
    in memory all at once.
    """
    digest = hashlib.sha256()
    with open(file, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            digest.update(chunk)
    return digest.digest()

def get_trusted_files(manifest, args):
    """Collect the files whose SHA-256 checksums go into the manifest.

    Covers the executable (args['exec'], key 'exec'), each comma-separated
    loader.preload entry (keys 'preload0', 'preload1', ...), and every
    sgx.trusted_files.<key> entry (key <key>).

    Returns a dict mapping key -> (uri, resolved_path, hex_sha256).
    Raises Exception on a repeated key or a path that does not exist.
    """
    targets = dict()

    if 'exec' in args:
        targets['exec'] = (args['exec'], resolve_uri(args['exec']))

    if 'loader.preload' in manifest:
        for i, uri in enumerate(manifest['loader.preload'].split(',')):
            targets['preload' + str(i)] = (uri, resolve_uri(uri))

    prefix = 'sgx.trusted_files.'
    for (key, val) in manifest.items():
        if not key.startswith(prefix):
            continue
        key = key[len(prefix):]
        # e.g. a trusted file literally named 'exec' or 'preload0'
        if key in targets:
            raise Exception('repeated key in manifest: sgx.trusted_files.' + key)
        targets[key] = (val, resolve_uri(val))

    return {key: (uri, target, get_checksum(target).encode('hex'))
            for key, (uri, target) in targets.items()}

def get_trusted_children(manifest, args):
    """Collect the MRENCLAVE of every sgx.trusted_children.<key> entry.

    Each entry names a child enclave; its signature file (the path with
    '.sig' appended unless already present) is read and 32 bytes at offset
    960 are hex-encoded. Presumably this is the SIGSTRUCT ENCLAVEHASH field;
    the main program stores it as sgx.trusted_mrenclave.<key>.

    `args` is accepted for symmetry with get_trusted_files() and is unused.

    Returns a dict mapping key -> (uri, sig_path, hex_mrenclave).
    """
    targets = dict()

    prefix = 'sgx.trusted_children.'
    for (key, val) in manifest.items():
        if not key.startswith(prefix):
            continue
        key = key[len(prefix):]
        if key in targets:
            raise Exception('repeated key in manifest: sgx.trusted_children.' + key)

        target = resolve_uri(val)
        if not target.endswith('.sig'):
            target += '.sig'
        # Read only the 32 bytes we need and close the file deterministically.
        with open(target, 'rb') as f:
            f.seek(960)
            sig = f.read(32).encode('hex')
        targets[key] = (val, target, sig)

    return targets

""" Populate Enclave Memory """

# Per-page flags passed to EADD in generate_measurement() and decoded by its
# print_area(): R/W/X permissions plus the page type (TCS or REG).
# NOTE(review): presumably these mirror the SGX SECINFO.FLAGS encoding --
# confirm against the SDM.
PAGEINFO_R = 0x1
PAGEINFO_W = 0x2
PAGEINFO_X = 0x4
PAGEINFO_TCS = 0x100
PAGEINFO_REG = 0x200

def get_loadcmds(filename):
    """Return the LOAD program headers of `filename` as reported by readelf.

    Each entry is (offset, vaddr, filesize, memsize, prot), where prot uses
    the ELF encoding R=4, W=2, E=1. Returns None when readelf exits with a
    non-zero status (e.g. `filename` is not an ELF file).
    """
    p = subprocess.Popen(['readelf', '-l', '-W', filename],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True)
    # communicate() drains stdout and stderr together; reading only stdout
    # while stderr is also piped can deadlock once readelf fills stderr.
    out, _ = p.communicate()
    if p.returncode != 0:
        return None

    loadcmds = []
    for line in out.splitlines():
        stripped = line.strip()
        if not stripped.startswith('LOAD'):
            continue
        # LOAD Offset VirtAddr PhysAddr FileSiz MemSiz Flg... Align
        tokens = stripped.split()
        if len(tokens) < 7:
            continue
        perms = tokens[6]
        # readelf prints "R E" (no W) as two separate tokens.
        if len(tokens) > 7 and tokens[7] == "E":
            perms += tokens[7]
        prot = 0
        if "R" in perms:
            prot |= 4
        if "W" in perms:
            prot |= 2
        if "E" in perms:
            prot |= 1

        loadcmds.append((int(tokens[1][2:], 16),  # offset
                         int(tokens[2][2:], 16),  # addr
                         int(tokens[4][2:], 16),  # filesize
                         int(tokens[5][2:], 16),  # memsize
                         prot))
    return loadcmds

class MemoryArea:
    """A contiguous region of the enclave image.

    Attributes (all set in __init__):
      desc      -- label used to look the area up ('ssa', 'tcs', 'pal', ...)
      file      -- backing file path, or None
      content   -- explicit page contents, or None
      addr      -- page-aligned start address, or None until placed
      size      -- page-aligned size in bytes
      flags     -- PAGEINFO_* bits
      is_binary -- True when `file` has LOAD segments (per get_loadcmds)
      measure   -- whether page contents are EEXTENDed into the measurement

    For a file with LOAD segments, size spans all segments (page-aligned), and
    addr is taken from the lowest segment when that is non-zero. For any other
    file, size is the file size.
    """
    def __init__(self, desc, file=None, content=None, addr=None, size=None, flags=None, measure=True):
        self.desc = desc
        self.file = file
        self.content = content
        self.addr = addr
        self.size = size
        self.flags = flags
        self.is_binary = False
        self.measure = measure

        if file:
            segments = get_loadcmds(file)
            if segments:
                # segment tuple: (offset, vaddr, filesize, memsize, prot)
                low = min(rounddown(seg[1]) for seg in segments)
                high = max(roundup(seg[1] + seg[3]) for seg in segments)

                self.is_binary = True
                self.size = high - low
                if low > 0:
                    self.addr = low
            else:
                self.size = os.stat(file).st_size

        if self.addr is not None:
            self.addr = rounddown(self.addr)
        if self.size is not None:
            self.size = roundup(self.size)

def get_memory_areas(manifest, attr, args):
    """Return the enclave's MemoryArea list (without the manifest area).

    Order: SSA, TCS and TLS regions sized for attr['thread_num'] threads, one
    stack per thread, the PAL binary (args['libpal']) and, when args has an
    'exec' entry, the executable. `manifest` is not read.
    """
    nthreads = attr['thread_num']
    rw_reg = PAGEINFO_R | PAGEINFO_W | PAGEINFO_REG

    areas = [
        MemoryArea('ssa', size=nthreads * SSAFRAMESIZE * SSAFRAMENUM, flags=rw_reg),
        MemoryArea('tcs', size=nthreads * TCS_SIZE, flags=PAGEINFO_TCS),
        MemoryArea('tls', size=nthreads * PAGESIZE, flags=rw_reg),
    ]
    areas.extend(MemoryArea('stack', size=ENCLAVE_STACK_SIZE, flags=rw_reg)
                 for _ in range(nthreads))
    areas.append(MemoryArea('pal', file=args['libpal'], flags=PAGEINFO_REG))

    if 'exec' in args:
        areas.append(MemoryArea('exec', file=args['exec'],
                                flags=PAGEINFO_W | PAGEINFO_REG))
    return areas

def find_areas(areas, desc):
    """Return a list of the areas whose `desc` equals `desc`, in order.

    Always a real list: callers take len() of it and index into it, which a
    bare filter() object (an iterator on Python 3) does not support.
    """
    return [area for area in areas if area.desc == desc]

def find_area(areas, desc, allow_none=False):
    """Return the single area whose `desc` matches.

    With no match and `allow_none` true, return None. Any other count than
    exactly one match raises KeyError.
    """
    matching = find_areas(areas, desc)
    count = len(matching)
    if count == 1:
        return matching[0]
    if count == 0 and allow_none:
        return None
    raise KeyError("Could not find exactly one MemoryArea '{}'".format(desc))

def entry_point(elf_path):
    """Return the ELF entry point address of `elf_path`, as read by readelf.

    readelf runs with LC_ALL=C so its "Entry point" header is not localized.
    Raises subprocess.CalledProcessError if readelf fails and ValueError if
    no entry point line is found.
    """
    # Work on a copy: assigning into os.environ itself would leak LC_ALL=C
    # into this process and every subprocess it launches afterwards.
    env = dict(os.environ)
    env['LC_ALL'] = 'C'
    out = subprocess.check_output(['readelf', '-l', '--', elf_path], env=env,
                                  universal_newlines=True)
    for line in out.splitlines():
        if line.startswith("Entry point "):
            # readelf prints the address as 0x-prefixed hex, which
            # int(..., 16) accepts directly.
            return int(line[12:], 16)
    raise ValueError("Could not find entry point of elf file")

def baseaddr():
    """Return the base address that enclave offsets are relative to.

    ENCLAVE_HIGH_ADDRESS when the global enclave_heap_min is 0 (the main
    program sets it so when no area has a preset address), otherwise 0.
    """
    return ENCLAVE_HIGH_ADDRESS if enclave_heap_min == 0 else 0

def gen_area_content(attr, areas):
    """Fill in the initial TCS and TLS contents for every enclave thread.

    Looks up the single 'manifest', 'pal', 'ssa', 'tcs' and 'tls' areas and
    the optional 'exec' area. It then checks that every area outside the heap
    range (and the exec area) is measured, and that only 'free' areas lie
    inside the heap range. Finally it stores per-thread TCS fields in
    tcs_area.content and per-thread TLS fields in tls_area.content.

    Raises KeyError if a required area is missing or duplicated, ValueError
    if the heap sanity check fails.
    """
    manifest_area = find_area(areas, 'manifest')
    exec_area = find_area(areas, 'exec', True)
    pal_area = find_area(areas, 'pal')
    ssa_area = find_area(areas, 'ssa')
    tcs_area = find_area(areas, 'tcs')
    tls_area = find_area(areas, 'tls')
    stacks = find_areas(areas, 'stack')

    tcs_data = bytearray(tcs_area.size)
    def set_tcs_field(t, offset, pack_fmt, value):
        # Thread t owns the t-th TCS_SIZE slot of the TCS area.
        struct.pack_into(pack_fmt, tcs_data, t * TCS_SIZE + offset, value)

    tls_data = bytearray(tls_area.size)
    def set_tls_field(t, offset, value):
        # Thread t owns the t-th page of the TLS area; all fields are u64.
        struct.pack_into('<Q', tls_data, t * PAGESIZE + offset, value)

    enclave_heap_max = pal_area.addr - MEMORY_GAP

    # Sanity check that we measure everything except the heap which is zeroed
    # on enclave startup.
    for area in areas:
        if area.addr + area.size <= enclave_heap_min or area.addr >= enclave_heap_max or area is exec_area:
            if not area.measure:
                raise ValueError("Memory area, which is not the heap, is not measured")
        elif area.desc != 'free':
            raise ValueError("Unexpected memory area is in heap range")

    # Values shared by all threads. entry_point() spawns readelf, so it is
    # evaluated once here instead of once per thread.
    base = baseaddr()
    pal_entry = pal_area.addr + entry_point(pal_area.file)
    manifest_size = os.stat(manifest_area.file).st_size

    for t in range(attr['thread_num']):
        ssa_offset = ssa_area.addr + SSAFRAMESIZE * SSAFRAMENUM * t
        ssa = base + ssa_offset
        set_tcs_field(t, TCS_OSSA, '<Q', ssa_offset)
        set_tcs_field(t, TCS_NSSA, '<L', SSAFRAMENUM)
        set_tcs_field(t, TCS_OENTRY, '<Q', pal_entry)
        set_tcs_field(t, TCS_OGSBASGX, '<Q', tls_area.addr + PAGESIZE * t)
        set_tcs_field(t, TCS_FSLIMIT, '<L', 0xfff)
        set_tcs_field(t, TCS_GSLIMIT, '<L', 0xfff)

        set_tls_field(t, SGX_ENCLAVE_SIZE, attr['enclave_size'])
        set_tls_field(t, SGX_TCS_OFFSET, tcs_area.addr + TCS_SIZE * t)
        set_tls_field(t, SGX_INITIAL_STACK_OFFSET, stacks[t].addr + stacks[t].size)
        set_tls_field(t, SGX_SSA, ssa)
        set_tls_field(t, SGX_GPR, ssa + SSAFRAMESIZE - SGX_GPR_SIZE)
        set_tls_field(t, SGX_MANIFEST_SIZE, manifest_size)
        set_tls_field(t, SGX_HEAP_MIN, base + enclave_heap_min)
        set_tls_field(t, SGX_HEAP_MAX, base + enclave_heap_max)
        if exec_area is not None:
            set_tls_field(t, SGX_EXEC_ADDR, base + exec_area.addr)
            set_tls_field(t, SGX_EXEC_SIZE, exec_area.size)

    tcs_area.content = tcs_data
    tls_area.content = tls_data

def populate_memory_areas(manifest, attr, areas):
    """Assign addresses to unplaced areas and create the free (heap) areas.

    First pass: every area whose addr is still None is placed top-down, in
    list order, starting from attr['enclave_size']. A MEMORY_GAP is left below
    each placed area, except below 'exec', which the next area abuts directly.

    Second pass: any area ending below the current low-water mark (e.g. one
    whose address was preset) gets a 'free' area covering the gap above it,
    and the mark moves down to that area's start. A final 'free' area then
    spans enclave_heap_min up to the mark. Free areas are R/W/X REG pages and
    are not measured.

    gen_area_content() is then run on `areas` (without the free areas).
    `manifest` is not used here.

    Returns `areas` followed by the generated free areas.
    Raises Exception if an area would be placed below enclave_heap_min.
    """
    # Low-water mark: areas are placed immediately below this address.
    populating = attr['enclave_size']

    for area in areas:
        # Areas with a preset address keep it.
        if area.addr is not None:
            continue

        area.addr = populating - area.size
        if area.addr < enclave_heap_min:
            raise Exception("Enclave size is not large enough")
        if area.desc == 'exec':
            populating = area.addr;
        else:
            populating = area.addr - MEMORY_GAP

    free_areas = []
    for area in areas:
        if area.addr + area.size < populating:
            addr = area.addr + area.size
            free_areas.append(MemoryArea('free', addr=addr, size=populating - addr,
                                flags=PAGEINFO_R|PAGEINFO_W|PAGEINFO_X|PAGEINFO_REG,
                                measure=False))
            populating = area.addr

    # Everything between the heap floor and the lowest allocation is heap.
    if populating > enclave_heap_min:
        free_areas.append(MemoryArea('free', addr=enclave_heap_min,
                                     size=populating - enclave_heap_min,
                                     flags=PAGEINFO_R|PAGEINFO_W|PAGEINFO_X|PAGEINFO_REG,
                                     measure=False))

    gen_area_content(attr, areas)

    return areas + free_areas

def generate_measurement(attr, areas):
    """Compute MRENCLAVE for the enclave image described by `areas`.

    Feeds an ECREATE record and then, for every page of every area, an EADD
    record (plus EEXTEND records for measured pages) into one SHA-256 digest,
    printing a memory map to stderr along the way. How each area is loaded:
      - ELF areas: loaded segment by segment from their file.
      - Other file-backed areas: loaded as raw bytes.
      - Content-only areas: use their `content` buffer, or zero pages when
        content is None.

    Returns the raw 32-byte digest.
    """

    def do_ecreate(digest, size):
        # SSA frame size is recorded in pages; integer division keeps the
        # "<L" field an int.
        data = struct.pack("<8sLQ44s", "ECREATE", SSAFRAMESIZE // PAGESIZE, size, "")
        digest.update(data)

    def do_eadd(digest, offset, flags):
        data = struct.pack("<8sQQ40s", "EADD", offset, flags, "")
        digest.update(data)

    def do_eextend(digest, offset, content):
        if len(content) != 256:
            raise ValueError("Exactly 256 bytes expected")

        data = struct.pack("<8sQ48s", "EEXTEND", offset, "")
        digest.update(data)
        digest.update(content)

    def include_page(digest, offset, flags, content, measure):
        if len(content) != PAGESIZE:
            raise ValueError("Exactly one page expected")

        do_eadd(digest, offset, flags)
        if measure:
            # Page contents are extended in 256-byte chunks.
            for i in range(0, PAGESIZE, 256):
                do_eextend(digest, offset + i, content[i:i + 256])

    mrenclave = hashlib.sha256()
    do_ecreate(mrenclave, attr['enclave_size'])

    def print_area(addr, size, flags, desc, measured):
        # TCS takes precedence over REG; '???' if neither type bit is set.
        page_type = '???'
        if flags & PAGEINFO_REG:
            page_type = 'REG'
        if flags & PAGEINFO_TCS:
            page_type = 'TCS'
        prot = ['-', '-', '-']
        if flags & PAGEINFO_R:
            prot[0] = 'R'
        if flags & PAGEINFO_W:
            prot[1] = 'W'
        if flags & PAGEINFO_X:
            prot[2] = 'X'
        prot = ''.join(prot)

        desc = '(' + desc + ')'
        if measured:
            desc += ' measured'

        if size == PAGESIZE:
            print >>sys.stderr, "    %016x [%s:%s] %s" % (addr, page_type, prot, desc)
        else:
            print >>sys.stderr, "    %016x-%016lx [%s:%s] %s" % (addr, addr + size, page_type, prot, desc)

    def load_file(digest, f, offset, addr, filesize, memsize, desc, flags):
        # Map file range [offset, offset+filesize) to [addr, addr+memsize),
        # zero-filling the unaligned head and the tail beyond filesize.
        f_addr = rounddown(offset)
        m_addr = rounddown(addr)
        f_size = roundup(offset + filesize) - f_addr
        m_size = roundup(addr + memsize) - m_addr

        print_area(m_addr, m_size, flags, desc, True)

        for pg in range(m_addr, m_addr + m_size, PAGESIZE):
            start = pg - m_addr + f_addr
            end = start + PAGESIZE
            start_zero = ""
            if start < offset:
                if offset - start >= PAGESIZE:
                    start_zero = ZERO_PAGE
                else:
                    start_zero = chr(0) * (offset - start)
            end_zero = ""
            if end > offset + filesize:
                if end - offset - filesize >= PAGESIZE:
                    end_zero = ZERO_PAGE
                else:
                    end_zero = chr(0) * (end - offset - filesize)
            start += len(start_zero)
            end -= len(end_zero)
            if start < end:
                f.seek(start)
                data = f.read(end - start)
            else:
                data = ""
            if len(start_zero + data + end_zero) != PAGESIZE:
                    raise Exception("wrong calculation")

            include_page(digest, pg, flags, start_zero + data + end_zero, True)

    for area in areas:
        if area.file:
            with open(area.file, 'rb') as f:
                if area.is_binary:
                    loadcmds = get_loadcmds(area.file)
                    if not loadcmds:
                        raise Exception('Cannot read the load commands of ' + area.file)
                    mapaddr = min(rounddown(addr) for (_, addr, _, _, _) in loadcmds)
                    # Relocation from link-time addresses to the area's placement.
                    load_bias = area.addr - mapaddr
                    for (offset, addr, filesize, memsize, prot) in loadcmds:
                        flags = area.flags
                        if prot & 4:
                            flags = flags | PAGEINFO_R
                        if prot & 2:
                            flags = flags | PAGEINFO_W
                        if prot & 1:
                            flags = flags | PAGEINFO_X

                        if flags & PAGEINFO_X:
                            desc = 'code'
                        else:
                            desc = 'data'
                        load_file(mrenclave, f, offset, load_bias + addr,
                                  filesize, memsize, desc, flags)
                else:
                    load_file(mrenclave, f, 0, area.addr,
                              os.stat(area.file).st_size, area.size,
                              area.desc, area.flags)
        else:
            for a in range(area.addr, area.addr + area.size, PAGESIZE):
                data = ZERO_PAGE
                if area.content is not None:
                    start = a - area.addr
                    end = start + PAGESIZE
                    data = area.content[start:end]

                include_page(mrenclave, a, area.flags, data, area.measure)

            print_area(area.addr, area.size, area.flags, area.desc, area.measure)

    return mrenclave.digest()

""" Generate Sigstruct """

def generate_sigstruct(attr, args, mrenclave):
    """Build and sign the SIGSTRUCT for an enclave.

    How the SIGSTRUCT is produced:
      - The signed fields (SIGSTRUCT offsets below 128 and from 900 on) are
        packed into a 256-byte buffer.
      - That buffer is signed with `openssl sha256 -sign` using the RSA key
        at args['key'].
      - The result is completed with the modulus, exponent 3, the signature
        and the Q1/Q2 helper values, all stored little-endian.

    Raises Exception if openssl fails or the key modulus/signature is not
    384 bytes (3072-bit RSA). Returns a bytearray of SGX_ARCH_SIGSTRUCT_SIZE.
    """
    today = datetime.date.today()

    # field format: (offset, type, value)
    fields = dict()

    fields['header']    = (SGX_ARCH_SIGSTRUCT_HEADER,
                           "<4L",  0x00000006, 0x000000e1, 0x00010000, 0x00000000)
    fields['vendor']    = (SGX_ARCH_SIGSTRUCT_VENDOR,
                           "<L",   0x00000000)
    fields['date']      = (SGX_ARCH_SIGSTRUCT_DATE,
                           "<HBB", today.year, today.month, today.day)
    fields['header2']   = (SGX_ARCH_SIGSTRUCT_HEADER2,
                           "<4L",  0x00000101, 0x00000060, 0x00000060, 0x00000001)
    fields['swdefined'] = (SGX_ARCH_SIGSTRUCT_SWDEFINED,
                           "<L",   0x00000000)

    fields['miscs']     = (SGX_ARCH_SIGSTRUCT_MISCSELECT,
                           "4s",   attr['miscs'])
    fields['miscmask']  = (SGX_ARCH_SIGSTRUCT_MISCSELECT_MASK,
                           "4s",   attr['miscs'])
    fields['attrs']     = (SGX_ARCH_SIGSTRUCT_ATTRIBUTES,
                           "8s8s", attr['flags'], attr['xfrms'])
    fields['attrmask']  = (SGX_ARCH_SIGSTRUCT_ATTRIBUTES_MASK,
                           "8s8s", attr['flags'], attr['xfrms'])
    fields['mrenclave'] = (SGX_ARCH_SIGSTRUCT_ENCLAVE_HASH,
                           "32s",  mrenclave)
    fields['isvprodid'] = (SGX_ARCH_SIGSTRUCT_ISVPRODID,
                           "<H",   attr['isvprodid'])
    fields['isvsvn']    = (SGX_ARCH_SIGSTRUCT_ISVSVN,
                           "<H",   attr['isvsvn'])

    # Signed data: the first 128 bytes followed by the 128 bytes at offset 900.
    sign_buffer = bytearray(128 + 128)

    for key, field in fields.items():
        if field[0] >= 900:
            struct.pack_into(field[1], sign_buffer, field[0] - 900 + 128, *field[2:])
        else:
            struct.pack_into(field[1], sign_buffer, field[0], *field[2:])

    p = subprocess.Popen(['openssl', 'rsa', '-modulus', '-in', args['key'], '-noout'], stdout=subprocess.PIPE)
    modulus_out = p.communicate()[0].strip()
    if p.returncode != 0 or not modulus_out.startswith('Modulus='):
        raise Exception('Failed to read the RSA modulus from ' + args['key'])
    modulus_hex = modulus_out[len('Modulus='):]
    # The SIGSTRUCT modulus field holds exactly 384 bytes (3072-bit RSA).
    if len(modulus_hex) != 384 * 2:
        raise Exception('Signing key must be a 3072-bit RSA key: ' + args['key'])
    modulus = modulus_hex.lower().decode('hex')
    modulus = modulus[::-1]

    p = subprocess.Popen(['openssl', 'sha256', '-binary', '-sign', args['key']],
            stdin=subprocess.PIPE, stdout=subprocess.PIPE)
    signature = p.communicate(sign_buffer)[0]
    # Without this check a failed signing step would silently produce a
    # zero-filled, invalid signature.
    if p.returncode != 0 or len(signature) != 384:
        raise Exception('Failed to sign the enclave with ' + args['key'])
    signature = signature[::-1]

    def bytes_to_int(bytes):
        # Little-endian decode (the reversed modulus/signature above).
        i = 0
        q = 1
        for digit in bytes:
            if ord(digit) != 0:
                i = i + ord(digit) * q
            q = q * 256
        return i

    def int_to_bytes(i):
        # Little-endian encode, minimal length.
        b = ""
        while i > 0:
            b = b + chr(i % 256)
            i = i // 256
        return b

    modulus_int = bytes_to_int(modulus)
    signature_int = bytes_to_int(signature)

    # Q1 = floor(S^2 / M), Q2 = floor((S^3 - Q1*S*M) / M)
    tmp1   = signature_int * signature_int
    q1_int = tmp1 // modulus_int
    tmp2   = tmp1 % modulus_int
    q2_int = tmp2 * signature_int // modulus_int

    q1 = int_to_bytes(q1_int)
    q2 = int_to_bytes(q2_int)

    fields['modulus']   = (SGX_ARCH_SIGSTRUCT_MODULUS, "384s", modulus)
    fields['exponent']  = (SGX_ARCH_SIGSTRUCT_EXPONENT, "<L",   3)
    fields['signature'] = (SGX_ARCH_SIGSTRUCT_SIGNATURE, "384s", signature)

    fields['q1']        = (SGX_ARCH_SIGSTRUCT_Q1, "384s", q1)
    fields['q2']        = (SGX_ARCH_SIGSTRUCT_Q2, "384s", q2)

    buffer = bytearray(SGX_ARCH_SIGSTRUCT_SIZE)

    for key, field in fields.items():
        struct.pack_into(field[1], buffer, field[0], *field[2:])

    return buffer

""" Main Program """

# Command-line options understood by parse_args() and listed by usage().
# Each value is a (required, value_description) pair:
#   required          -- parse_args() exits via usage() if the option is absent
#   value_description -- when not None, parse_args() consumes the following
#                        argument as the option's value; usage() shows it as
#                        <value_description>.
options = {
#       Option name : (Required  Value)
        'output':    (True,    'output'),
        'libpal':    (True,    'libpal path'),
        'key':       (True,    'signing key'),
        'manifest':  (True,    'manifest'),
        'exec':      (False,   'executable'),
    }

def usage():
    """Print a usage line built from `options` to stderr and exit with -1.

    Optional options are shown in [brackets]. Exits through os._exit, so no
    cleanup handlers run.
    """
    parts = ['USAGE: ' + sys.argv[0] + ' -help|-h']
    for opt, (required, value_desc) in options.items():
        piece = '|-' + opt
        if value_desc:
            piece += ' <' + value_desc + '>'
        if not required:
            piece = '[' + piece + ']'
        parts.append(piece)

    sys.stderr.write(''.join(parts) + '\n')
    os._exit(-1)

def parse_args():
    """Parse sys.argv into a dict keyed by option name.

    Options come from the module-level `options` table. An option with an
    empty/None value description is a flag: it defaults to False and becomes
    True when given. Every other option consumes the next argument as its
    value. These cases print a message and exit via usage():
      - -help/-h
      - an unknown option
      - an option missing its value
      - a missing required option
    """
    def takes_value(optval):
        # Same rule usage() uses to decide whether to show "<value>".
        return bool(optval[1])

    args = dict()
    for opt, optval in options.items():
        if not takes_value(optval):
            args[opt] = False

    i = 1
    while i < len(sys.argv):
        got = sys.argv[i]

        if got == '-help' or got == '-h':
            usage()

        invalid = True
        for opt, optval in options.items():
            if got != '-' + opt:
                continue

            if takes_value(optval):
                i += 1
                if i == len(sys.argv):
                    print >>sys.stderr, "Option %s needs a value." % (opt)
                    usage()
                args[opt] = sys.argv[i]
            else:
                args[opt] = True

            invalid = False
            break

        if invalid:
            # Strip only a leading dash so bare words are reported intact.
            name = got[1:] if got.startswith('-') else got
            print >>sys.stderr, "Unknown option: %s." % (name)
            usage()
        i += 1

    for opt, optval in options.items():
        if optval[0] and opt not in args:
            print >>sys.stderr, "Must specify %s <%s>." % (opt, optval[1])
            usage()

    return args

if __name__ == "__main__":

    # Parse arguments
    args = parse_args()

    (manifest, manifest_layout) = read_manifest(args['manifest'])

    if 'exec' not in args:
        if 'loader.exec' in manifest:
            exec_url = manifest['loader.exec']
            if exec_url[:5] != 'file:':
                # stderr like every other diagnostic: stdout is buffered and
                # os._exit() does not flush it, so the message could be lost.
                print >>sys.stderr, "executable must be a local file"
                os._exit(-1)

            args['exec'] = os.path.join(os.path.dirname(args['manifest']), exec_url[5:])

    args['root'] = os.path.dirname(os.path.abspath(args['output']))

    if 'sgx.sigfile' in manifest:
        args['sigfile'] = resolve_uri(manifest['sgx.sigfile'], False)
    else:
        sigfile = args['output']
        for ext in ['.manifest.sgx', '.manifest']:
            if sigfile.endswith(ext):
                sigfile = sigfile[:-len(ext)]
                break
        args['sigfile'] = sigfile + '.sig'
        manifest['sgx.sigfile'] = 'file:' + os.path.basename(args['sigfile'])

    # Get attributes from manifest
    attr = dict()

    for key, default, parse in [
        ('enclave_size', DEFAULT_ENCLAVE_SIZE,    parse_size),
        ('thread_num',   str(DEFAULT_THREAD_NUM), parse_int),
        ('isvprodid',    '0',                     parse_int),
        ('isvsvn',       '0',                     parse_int),
    ]:
        if 'sgx.' + key not in manifest:
            manifest['sgx.' + key] = default
        attr[key] = parse(manifest['sgx.' + key])

    (attr['flags'], attr['xfrms'], attr['miscs']) = get_enclave_attributes(manifest)

    # flags/xfrms/miscs are packed little-endian; decode them the same way
    # so the printed hex matches the actual masks.
    print >>sys.stderr, "Attributes:"
    print >>sys.stderr, "    size:      %d" % (attr['enclave_size'])
    print >>sys.stderr, "    threadnum: %d" % (attr['thread_num'])
    print >>sys.stderr, "    isvprodid: %d" % (attr['isvprodid'])
    print >>sys.stderr, "    isvsvn:    %d" % (attr['isvsvn'])
    print >>sys.stderr, "    flags:     %016x" % (struct.unpack('<Q', attr['flags'])[0])
    print >>sys.stderr, "    xfrms:     %016x" % (struct.unpack('<Q', attr['xfrms'])[0])
    print >>sys.stderr, "    miscs:     %08x"  % (struct.unpack('<L', attr['miscs'])[0])

    # Get trusted checksums and measurements
    print >>sys.stderr, "Trusted files:"
    for key, val in get_trusted_files(manifest, args).items():
        (uri, target, checksum) = val
        print >>sys.stderr, "    %s %s" % (checksum, uri)
        manifest['sgx.trusted_checksum.' + key] = checksum

    print >>sys.stderr, "Trusted children:"
    for key, val in get_trusted_children(manifest, args).items():
        (uri, target, mrenclave) = val
        print >>sys.stderr, "    %s %s" % (mrenclave, uri)
        manifest['sgx.trusted_mrenclave.' + key] = mrenclave

    # Try populate memory areas
    memory_areas = get_memory_areas(manifest, attr, args)

    if any(a.addr is not None for a in memory_areas):
        manifest['sgx.static_address'] = '1'
    else:
        # Module-level global read by baseaddr() and the placement code.
        enclave_heap_min = 0

    # Add manifest at the top (copy2 first to carry over file metadata).
    shutil.copy2(args['manifest'], args['output'])
    output_manifest(args['output'], manifest, manifest_layout)

    memory_areas = [
            MemoryArea('manifest', file=args['output'],
                       flags=PAGEINFO_R|PAGEINFO_REG)
            ] + memory_areas

    memory_areas = populate_memory_areas(manifest, attr, memory_areas)

    print >>sys.stderr, "Memory:"
    # Generate measurement
    mrenclave = generate_measurement(attr, memory_areas)

    print >>sys.stderr, "Measurement:"
    print >>sys.stderr, "    " + mrenclave.encode('hex')

    # Generate sigstruct
    sigstruct = generate_sigstruct(attr, args, mrenclave)
    with open(args['sigfile'], 'wb') as sig_out:
        sig_out.write(sigstruct)
