#!/usr/bin/env python2

import os
import sys
import struct
import socket
from google.protobuf import message as _message
from Crypto.PublicKey import RSA
import aesm_pb2

""" Utilities """

def int_to_bytes(i):
    """Encode an integer as a little-endian byte string.

    The least significant byte is emitted first.  Zero and negative
    values produce the empty string, because bytes are only emitted
    while the remaining value is positive.

    Note that this is NOT the inverse of ``bytes_to_int``, which reads
    its input most-significant-byte first.
    """
    pieces = []
    while i > 0:
        i, low_byte = divmod(i, 256)
        pieces.append(chr(low_byte))
    # Join once instead of growing the string on every iteration.
    return "".join(pieces)

def bytes_to_int(b):
    """Fold a byte string into an integer, first byte most significant.

    Each element of ``b`` is passed through ``ord``, so ``b`` must
    iterate as single characters (a Python 2 ``str``).  An empty input
    yields 0.
    """
    value = 0
    for ch in b:
        # Shift the accumulated value up one byte, then add the new one.
        value = (value << 8) + ord(ch)
    return value


""" Reading Sigstruct """

def read_sigstruct(sig):
    """Extract selected fields from a raw sigstruct buffer.

    ``sig`` is the binary content of the sigstruct file.  The fields are
    read at fixed byte offsets with ``struct.unpack_from``; the buffer
    must therefore be at least 1028 bytes long (the last field, at
    offset 1026, is two bytes wide), otherwise ``struct.error`` is raised.

    Returns a dict with the keys ``year``, ``month``, ``day``,
    ``modulus``, ``exponent``, ``signature``, ``miscs``, ``miscmask``,
    ``flags``, ``xfrms``, ``flagmask``, ``xfrmmask``, ``mrenclave``,
    ``isvprodid`` and ``isvsvn``.  Fields with an ``s`` format are
    returned as raw byte strings, the others as integers.
    """
    # Each row: (byte offset, struct format, names for the unpacked values)
    layout = (
        (  20, "<HBB", ('year', 'month', 'day')),

        ( 128, "384s", ('modulus',)),
        ( 512, "<L",   ('exponent',)),
        ( 516, "384s", ('signature',)),

        ( 900, "4s",   ('miscs',)),
        ( 904, "4s",   ('miscmask',)),
        ( 928, "8s8s", ('flags', 'xfrms')),
        ( 944, "8s8s", ('flagmask', 'xfrmmask')),
        ( 960, "32s",  ('mrenclave',)),
        (1024, "<H",   ('isvprodid',)),
        (1026, "<H",   ('isvsvn',)),
    )

    attr = {}
    for offset, fmt, names in layout:
        unpacked = struct.unpack_from(fmt, sig, offset)
        # One format may yield several values; store each under its own name.
        for position, value in enumerate(unpacked):
            attr[names[position]] = value

    return attr

""" Connect with AESMD """

def _recv_exact(sock, size):
    """Read exactly ``size`` bytes from ``sock``.

    A single ``recv`` on a stream socket may return fewer bytes than
    requested, so keep reading until the full amount has arrived.
    Raises ``socket.error`` if the peer closes the connection first.
    """
    chunks = []
    remaining = size
    while remaining > 0:
        chunk = sock.recv(remaining)
        if not chunk:
            raise socket.error(
                "Connection to the AESMD service closed with %d byte(s) "
                "still expected" % (remaining))
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def connect_aesmd(attr):
    """Request a token from the AESMD service for the given attributes.

    ``attr`` is the dict produced by ``read_sigstruct``; the keys
    ``mrenclave``, ``modulus``, ``flags`` and ``xfrms`` are used to fill
    a ``GetTokenReq`` message.  The message is sent over a Unix stream
    socket, prefixed with its length as a 4-byte little-endian integer;
    the reply uses the same framing and is parsed as ``GetTokenRet``.

    Returns the ``token`` field of the reply.

    Raises ``socket.error`` if neither socket address accepts the
    connection or the reply is cut short, and ``Exception`` if the
    service reports a non-zero error code.
    """
    req_msg = aesm_pb2.GetTokenReq()
    req_msg.req.signature = attr['mrenclave']
    req_msg.req.key = attr['modulus']
    req_msg.req.attributes = attr['flags'] + attr['xfrms']
    req_msg.req.timeout = 10000

    req_msg_raw = req_msg.SerializeToString()

    addresses = (
        # abstract-namespace socket, leading NUL byte (for PSW 1.6 and 1.7)
        "\0sgx_aesm_socket_base" + "\0" * 87,
        # named socket on the filesystem (for PSW 1.8+)
        "/var/run/aesmd/aesm.socket",
    )

    aesm_service = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        for address in addresses:
            try:
                aesm_service.connect(address)
            except socket.error:
                # Best effort: fall through to the next candidate address.
                continue
            break
        else:
            raise socket.error("Cannot connect to the AESMD service")

        # sendall() retries until everything is written; send() may not.
        aesm_service.sendall(struct.pack("<I", len(req_msg_raw)))
        aesm_service.sendall(req_msg_raw)

        ret_msg_size = struct.unpack("<I", _recv_exact(aesm_service, 4))[0]
        ret_msg_raw = _recv_exact(aesm_service, ret_msg_size)
    finally:
        # Always release the socket, including on error paths.
        aesm_service.close()

    ret_msg = aesm_pb2.GetTokenRet()
    ret_msg.ParseFromString(ret_msg_raw)

    if ret_msg.ret.error != 0:
        raise Exception("Failed. (Error Code = %d)" % (ret_msg.ret.error))

    return ret_msg.ret.token

""" Main Program """

# Command-line options, keyed by option name (passed as "-<name>").
# Each value is a pair (required, value description):
#   required          -- True if the option must be given
#   value description -- shown as "<...>" in the usage text; None marks
#                        an option that takes no value (a flag)
options = {
    'output': (True, 'output'),
    'sig': (True, 'sigstruct file'),
}

def usage():
    """Print the usage line to stderr and exit with status -1.

    The options are listed in sorted order so the output does not depend
    on dict ordering.  Options that are not required are wrapped in
    brackets; options that take a value show its description as
    ``<description>``.

    This function never returns: it raises ``SystemExit`` through
    ``sys.exit``, which lets the interpreter shut down normally
    (``os._exit`` would skip that).
    """
    parts = ['USAGE: ' + sys.argv[0], '[-help|-h]']

    for opt in sorted(options):
        required, description = options[opt]
        part = '-' + opt
        if description:
            part += ' <' + description + '>'
        if not required:
            part = '[' + part + ']'
        parts.append(part)

    # Options are independent arguments, not alternatives: separate them
    # with spaces rather than '|'.
    print >> sys.stderr, ' '.join(parts)
    sys.exit(-1)

def parse_args():
    """Parse ``sys.argv`` against the ``options`` table.

    Returns a dict mapping option names to their values.  An option
    whose value description is ``None`` is a flag: it defaults to False
    and becomes True when present.  Every other option consumes the
    following argument as its value.

    ``-help`` / ``-h``, an unknown argument, a missing value or a missing
    required option all end in ``usage()``, which terminates the program.
    """
    args = dict()
    for opt, optval in options.items():
        if optval[1] is None:
            args[opt] = False

    argv = sys.argv
    i = 1
    while i < len(argv):
        got = argv[i]

        if got == '-help' or got == '-h':
            usage()

        # An argument names an option only if it is "-" followed by a key
        # of the options table.
        opt = got[1:] if got.startswith('-') else None
        if opt not in options:
            # Report arguments without a leading dash verbatim instead of
            # chopping off their first character.
            shown = opt if opt is not None else got
            print >>sys.stderr, "Unknown option: %s." % (shown)
            usage()

        if options[opt][1] is not None:
            # The option takes a value: consume the next argument.
            i += 1
            if i == len(argv):
                print >>sys.stderr, "Option %s needs a value." % (opt)
                usage()
            args[opt] = argv[i]
        else:
            args[opt] = True

        i += 1

    for opt, optval in options.items():
        if optval[0] and opt not in args:
            print >>sys.stderr, "Must specify %s <%s>." % (opt, optval[1])
            usage()

    return args

if __name__ == "__main__":

    # Parse arguments
    args = parse_args()

    # Read the whole sigstruct file and close it before going on.
    with open(args['sig'], 'rb') as sig_file:
        attr = read_sigstruct(sig_file.read())

    # Dump the parsed attributes to stderr.  The byte-string fields are
    # shown in file byte order (bytes_to_int treats the first byte as the
    # most significant one).
    print >>sys.stderr, "Attributes:"
    print >>sys.stderr, "    mrenclave: %s" % (attr['mrenclave'].encode('hex'))
    print >>sys.stderr, "    isvprodid: %d" % (attr['isvprodid'])
    print >>sys.stderr, "    isvsvn:    %d" % (attr['isvsvn'])
    print >>sys.stderr, "    flags:     %016x" % (bytes_to_int(attr['flags']))
    print >>sys.stderr, "    xfrms:     %016x" % (bytes_to_int(attr['xfrms']))
    print >>sys.stderr, "    miscs:     %08x"  % (bytes_to_int(attr['miscs']))
    print >>sys.stderr, "    miscmask:  %08x"  % (bytes_to_int(attr['miscmask']))
    print >>sys.stderr, "    modulus:   %s..." % (attr['modulus'].encode('hex')[:32])
    print >>sys.stderr, "    exponent:  %d" % (attr['exponent'])
    print >>sys.stderr, "    signature: %s..." % (attr['signature'].encode('hex')[:32])

    token = connect_aesmd(attr)

    # Only create the output file once a token has been obtained; the
    # with-block guarantees it is flushed and closed.
    with open(args['output'], 'wb') as token_file:
        token_file.write(token)
