#!/usr/bin/env python3
# pylint: disable=invalid-name

import argparse
import os
import socket
import struct
import sys

import aesm_pb2

sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
import generated_offsets as offs # pylint: disable=import-error,wrong-import-position

# pylint: enable=invalid-name

def set_optional_sgx_features(attr):
    """Set optional SGX features if they are available on this machine."""
    optional_sgx_features = {
        offs.SGX_XFRM_AVX:      "avx",
        offs.SGX_XFRM_AVX512:   "avx512f",
        offs.SGX_XFRM_MPX:      "mpx",
    }

    cpu_features = ""
    with open("/proc/cpuinfo", "r") as file:
        for line in file:
            if line.startswith("flags"):
                cpu_features = line.split(":")[1].strip().split()
                break
        else:
            raise Exception("Failed to parse CPU flags")

    xfrms = int.from_bytes(attr['xfrms'], byteorder='little')
    xfrmmask = int.from_bytes(attr['xfrm_mask'], byteorder='little')

    new_xfrms = 0
    for (bits, feature) in optional_sgx_features.items():
        # Check if SIGSTRUCT allows enabling an optional CPU feature.
        # If all the xfrm bits for a feature, after applying xfrmmask, are set in xfrms,
        # we can set the remaining bits if the feature is available.
        # If the xfrmmask includes all the required xfrm bits, then these bits cannot be
        # changed in xfrm (need to stay the same as signed).
        if xfrms & (bits & xfrmmask) == (bits & xfrmmask) and feature in cpu_features:
            new_xfrms |= xfrms | bits

    attr['xfrms'] = new_xfrms.to_bytes(length=8, byteorder='little')


def read_sigstruct(sig):
    """Reading Sigstruct."""
    # field format: (offset, type, value)
    # SGX_ARCH_ENCLAVE_CSS_
    fields = {
        'date': (offs.SGX_ARCH_ENCLAVE_CSS_DATE, "<HBB", 'year', 'month', 'day'),
        'modulus': (offs.SGX_ARCH_ENCLAVE_CSS_MODULUS, "384s", 'modulus'),
        'exponent': (offs.SGX_ARCH_ENCLAVE_CSS_EXPONENT, "<L", 'exponent'),
        'signature': (offs.SGX_ARCH_ENCLAVE_CSS_SIGNATURE, "384s", 'signature'),

        'misc_select': (offs.SGX_ARCH_ENCLAVE_CSS_MISC_SELECT, "4s", 'misc_select'),
        'misc_mask': (offs.SGX_ARCH_ENCLAVE_CSS_MISC_MASK, "4s", 'misc_mask'),
        'attributes': (offs.SGX_ARCH_ENCLAVE_CSS_ATTRIBUTES, "8s8s", 'flags', 'xfrms'),
        'attribute_mask': (offs.SGX_ARCH_ENCLAVE_CSS_ATTRIBUTE_MASK, "8s8s",
            'flag_mask', 'xfrm_mask'),
        'enclave_hash': (offs.SGX_ARCH_ENCLAVE_CSS_ENCLAVE_HASH, "32s", 'enclave_hash'),
        'isv_prod_id': (offs.SGX_ARCH_ENCLAVE_CSS_ISV_PROD_ID, "<H", 'isv_prod_id'),
        'isv_svn': (offs.SGX_ARCH_ENCLAVE_CSS_ISV_SVN, "<H", 'isv_svn'),
    }

    attr = dict()
    for field in fields.values():
        values = struct.unpack_from(field[1], sig, field[0])

        for i, value in enumerate(values):
            attr[field[i + 2]] = value

    return attr


def connect_aesmd(attr):
    """Connect with AESMD."""

    req_msg = aesm_pb2.GetTokenReq()
    req_msg.req.signature = attr['enclave_hash']
    req_msg.req.key = attr['modulus']
    req_msg.req.attributes = attr['flags'] + attr['xfrms']
    req_msg.req.timeout = 10000

    req_msg_raw = req_msg.SerializeToString()

    aesm_service = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)

    # Try to connect to all possible interfaces exposed by aesm service
    connections = (
        "/var/run/aesmd/aesm.socket",         # named socket (for PSW 1.8+)
        "\0sgx_aesm_socket_base" + "\0" * 87  # unnamed socket (for PSW 1.6/1.7)
    )

    for conn in connections:
        try:
            aesm_service.connect(conn)
        except socket.error:
            continue
        break
    else:
        raise socket.error("Cannot connect to the AESMD service")

    aesm_service.send(struct.pack("<I", len(req_msg_raw)))
    aesm_service.send(req_msg_raw)

    ret_msg_size = struct.unpack("<I", aesm_service.recv(4))[0]
    ret_msg = aesm_pb2.GetTokenRet()
    ret_msg_raw = aesm_service.recv(ret_msg_size)
    ret_msg.ParseFromString(ret_msg_raw)

    if ret_msg.ret.error != 0:
        raise Exception("Failed. (Error Code = %d)" % (ret_msg.ret.error))

    return ret_msg.ret.token


argparser = argparse.ArgumentParser()
argparser.add_argument('--sig', '-sig', metavar='SIGNATURE',
                       type=argparse.FileType('rb'), required=True,
                       help='Input .sig file (contains SIGSTRUCT)')
argparser.add_argument('--output', '-output', metavar='OUTPUT',
                       type=argparse.FileType('wb'), required=True,
                       help='Output .token file (contains EINITTOKEN)')


def main(args=None):
    """Main Program."""
    args = argparser.parse_args(args)

    attr = read_sigstruct(args.sig.read())
    set_optional_sgx_features(attr)

    print("Attributes:")
    print("    mr_enclave:  %s" % attr['enclave_hash'].hex())
    print("    isv_prod_id: %d" % attr['isv_prod_id'])
    print("    isv_svn:     %d" % attr['isv_svn'])
    print("    attr.flags:  %016x" % int.from_bytes(attr['flags'], byteorder='big'))
    print("    attr.xfrm:   %016x" % int.from_bytes(attr['xfrms'], byteorder='big'))
    print("    misc_select: %08x" % int.from_bytes(attr['misc_select'], byteorder='big'))
    print("    misc_mask:   %08x" % int.from_bytes(attr['misc_mask'], byteorder='big'))
    print("    modulus:     %s..." % attr['modulus'].hex()[:32])
    print("    exponent:    %d" % attr['exponent'])
    print("    signature:   %s..." % attr['signature'].hex()[:32])
    print("    date:        %d-%02d-%02d" % (attr['year'], attr['month'], attr['day']))

    token = connect_aesmd(attr)
    args.output.write(token)
    return 0


if __name__ == "__main__":
    sys.exit(main())
