#!/usr/bin/env python3

import argparse
import os
import socket
import struct
import sys

import aesm_pb2

sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from generated_offsets import *


def read_sigstruct(sig):
    """Extract the fields of interest from a raw SIGSTRUCT buffer.

    Args:
        sig: bytes-like object holding the contents of a .sig file.

    Returns:
        A dict mapping field names ('year', 'month', 'day', 'modulus',
        'exponent', 'signature', 'misc_select', 'misc_mask', 'flags',
        'xfrms', 'flag_mask', 'xfrm_mask', 'enclave_hash', 'isv_prod_id',
        'isv_svn') to the values unpacked from `sig`. Fields declared with an
        "s" format stay as raw bytes; the others are unpacked as
        little-endian integers.

    Raises:
        struct.error: if `sig` is too short to contain one of the fields.
    """
    # Each entry: (offset inside SIGSTRUCT, struct format, names given to the
    # unpacked values, in order).
    layout = (
        (SGX_ARCH_ENCLAVE_CSS_DATE, "<HBB", ('year', 'month', 'day')),
        (SGX_ARCH_ENCLAVE_CSS_MODULUS, "384s", ('modulus',)),
        (SGX_ARCH_ENCLAVE_CSS_EXPONENT, "<L", ('exponent',)),
        (SGX_ARCH_ENCLAVE_CSS_SIGNATURE, "384s", ('signature',)),

        (SGX_ARCH_ENCLAVE_CSS_MISC_SELECT, "4s", ('misc_select',)),
        (SGX_ARCH_ENCLAVE_CSS_MISC_MASK, "4s", ('misc_mask',)),
        (SGX_ARCH_ENCLAVE_CSS_ATTRIBUTES, "8s8s", ('flags', 'xfrms')),
        (SGX_ARCH_ENCLAVE_CSS_ATTRIBUTE_MASK, "8s8s", ('flag_mask', 'xfrm_mask')),
        (SGX_ARCH_ENCLAVE_CSS_ENCLAVE_HASH, "32s", ('enclave_hash',)),
        (SGX_ARCH_ENCLAVE_CSS_ISV_PROD_ID, "<H", ('isv_prod_id',)),
        (SGX_ARCH_ENCLAVE_CSS_ISV_SVN, "<H", ('isv_svn',)),
    )

    attr = {}
    for offset, fmt, names in layout:
        unpacked = struct.unpack_from(fmt, sig, offset)
        # Pair every unpacked value with the name declared for its position.
        attr.update(zip(names, unpacked))

    return attr


def _recv_exact(sock, size):
    """Read exactly `size` bytes from `sock`.

    A single `recv()` on a stream socket may return fewer bytes than
    requested, so keep reading until the whole message has arrived.

    Raises:
        socket.error: if the peer closes the connection before `size` bytes
            have been received.
    """
    chunks = []
    remaining = size
    while remaining > 0:
        chunk = sock.recv(remaining)
        if not chunk:
            raise socket.error("Connection to the AESMD service closed unexpectedly")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def connect_aesmd(attr):
    """Request a launch token from the AESMD service.

    Builds a `GetTokenReq` from the SIGSTRUCT attributes, sends it to the
    AESM daemon over a Unix stream socket and returns the token carried in
    the `GetTokenRet` reply. Both directions use the same framing: a 4-byte
    little-endian length followed by the serialized protobuf message.

    Args:
        attr: dict as returned by `read_sigstruct`; the keys 'enclave_hash',
            'modulus', 'flags' and 'xfrms' are used.

    Returns:
        The token from the reply (`ret_msg.ret.token`).

    Raises:
        socket.error: if none of the known AESMD sockets accepts a connection,
            or the service closes the connection before replying in full.
        Exception: if the service reports a non-zero error code.
    """

    req_msg = aesm_pb2.GetTokenReq()
    req_msg.req.signature = attr['enclave_hash']
    req_msg.req.key = attr['modulus']
    req_msg.req.attributes = attr['flags'] + attr['xfrms']
    req_msg.req.timeout = 10000

    req_msg_raw = req_msg.SerializeToString()

    # Try to connect to all possible interfaces exposed by aesm service
    connections = (
        "/var/run/aesmd/aesm.socket",         # named socket (for PSW 1.8+)
        "\0sgx_aesm_socket_base" + "\0" * 87  # unnamed socket (for PSW 1.6/1.7)
    )

    # The `with` block guarantees the socket is closed on every exit path,
    # including the error paths below.
    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as aesm_service:
        for conn in connections:
            try:
                aesm_service.connect(conn)
            except socket.error:
                continue
            break
        else:
            raise socket.error("Cannot connect to the AESMD service")

        # sendall() retries until the whole buffer is written; plain send()
        # is allowed to transmit only part of it.
        aesm_service.sendall(struct.pack("<I", len(req_msg_raw)))
        aesm_service.sendall(req_msg_raw)

        ret_msg_size = struct.unpack("<I", _recv_exact(aesm_service, 4))[0]
        ret_msg_raw = _recv_exact(aesm_service, ret_msg_size)

    ret_msg = aesm_pb2.GetTokenRet()
    ret_msg.ParseFromString(ret_msg_raw)

    if ret_msg.ret.error != 0:
        raise Exception("Failed. (Error Code = %d)" % (ret_msg.ret.error))

    return ret_msg.ret.token


def _build_argparser():
    """Create the command-line parser used by `main`."""
    parser = argparse.ArgumentParser()

    # (option strings, metavar, file mode, help text). Both options are
    # mandatory and are opened as binary files by argparse itself.
    options = (
        (('--sig', '-sig'), 'SIGNATURE', 'rb',
         'Input .sig file (contains SIGSTRUCT)'),
        (('--output', '-output'), 'OUTPUT', 'wb',
         'Output .token file (contains EINITTOKEN)'),
    )
    for flags, metavar, mode, description in options:
        parser.add_argument(*flags, metavar=metavar,
                            type=argparse.FileType(mode), required=True,
                            help=description)

    return parser


argparser = _build_argparser()


def _close_unless_standard(stream):
    """Close `stream` unless it is one of the interpreter's standard streams.

    `argparse.FileType` returns stdin/stdout when the path argument is `-`;
    those belong to the whole process and must stay open.
    """
    standard = (sys.stdin, sys.stdout,
                getattr(sys.stdin, 'buffer', None),
                getattr(sys.stdout, 'buffer', None))
    if not any(stream is std for std in standard):
        stream.close()


def main(args=None):
    """Read a SIGSTRUCT file, print its attributes and write the launch token.

    Args:
        args: optional list of command-line arguments; when None,
            `argparser` falls back to `sys.argv`.

    Returns:
        0 on success. Errors raised while parsing the SIGSTRUCT or talking to
        AESMD propagate to the caller.
    """
    args = argparser.parse_args(args)

    # argparse has already opened both files; make sure they are closed (and
    # the token flushed to disk) no matter how this function exits.
    try:
        attr = read_sigstruct(args.sig.read())

        print("Attributes:")
        print("    mr_enclave:  %s" % attr['enclave_hash'].hex())
        print("    isv_prod_id: %d" % attr['isv_prod_id'])
        print("    isv_svn:     %d" % attr['isv_svn'])
        print("    attr.flags:  %016x" % int.from_bytes(attr['flags'], byteorder='big'))
        print("    attr.xfrm:   %016x" % int.from_bytes(attr['xfrms'], byteorder='big'))
        print("    misc_select: %08x" % int.from_bytes(attr['misc_select'], byteorder='big'))
        print("    misc_mask:   %08x" % int.from_bytes(attr['misc_mask'], byteorder='big'))
        # Only the leading 32 hex digits of the long fields are shown.
        print("    modulus:     %s..." % attr['modulus'].hex()[:32])
        print("    exponent:    %d" % attr['exponent'])
        print("    signature:   %s..." % attr['signature'].hex()[:32])
        print("    date:        %d-%02d-%02d" % (attr['year'], attr['month'], attr['day']))

        token = connect_aesmd(attr)
        args.output.write(token)
    finally:
        _close_unless_standard(args.sig)
        _close_unless_standard(args.output)
    return 0


# Script entry point: run main() and use its return value as the process
# exit status.
if __name__ == "__main__":
    raise SystemExit(main())
