123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- #!/usr/bin/env python3
- # pylint: disable=invalid-name
- import argparse
- import os
- import socket
- import struct
- import sys
- import aesm_pb2
- sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
- import generated_offsets as offs # pylint: disable=import-error,wrong-import-position
- # pylint: enable=invalid-name
def _read_cpu_flags(cpuinfo_path="/proc/cpuinfo"):
    """Return the set of CPU feature flags from the first ``flags`` line of *cpuinfo_path*.

    Raises:
        Exception: if the file contains no line starting with ``flags``.
    """
    with open(cpuinfo_path, "r", encoding="utf-8") as file:
        for line in file:
            if line.startswith("flags"):
                # Line format: "flags\t\t: fpu vme de ..."
                return set(line.split(":", 1)[1].split())
    raise Exception("Failed to parse CPU flags")


def set_optional_sgx_features(attr, *, cpu_features=None, optional_sgx_features=None):
    """Set optional SGX features if they are available on this machine.

    Decodes ``attr['xfrms']`` and ``attr['xfrm_mask']`` as 8-byte little-endian
    integers. For every optional feature that the SIGSTRUCT permits and the CPU
    supports, it ORs that feature's XFRM bits into the signed XFRM value. The
    result is written back to ``attr['xfrms']`` (8 bytes, little-endian). Bits
    that were already set in the signed XFRM value are always preserved.

    Args:
        attr: dict as produced by ``read_sigstruct``. It is read for 'xfrms'
            and 'xfrm_mask' and modified in place ('xfrms').
        cpu_features: optional iterable of CPU flag names. Defaults to the
            flags listed in /proc/cpuinfo.
        optional_sgx_features: optional mapping {xfrm_bits: cpu_flag_name}.
            Defaults to AVX, AVX-512 (avx512f) and MPX.

    Raises:
        Exception: if /proc/cpuinfo has to be read and has no ``flags`` line.
    """
    if optional_sgx_features is None:
        optional_sgx_features = {
            offs.SGX_XFRM_AVX: "avx",
            offs.SGX_XFRM_AVX512: "avx512f",
            offs.SGX_XFRM_MPX: "mpx",
        }
    if cpu_features is None:
        cpu_features = _read_cpu_flags()
    else:
        cpu_features = set(cpu_features)

    xfrms = int.from_bytes(attr['xfrms'], byteorder='little')
    xfrmmask = int.from_bytes(attr['xfrm_mask'], byteorder='little')

    # Start from the signed value. Starting from 0 would reset XFRM to 0
    # (dropping every signed bit) whenever no optional feature qualifies.
    new_xfrms = xfrms
    for bits, feature in optional_sgx_features.items():
        # Check if SIGSTRUCT allows enabling an optional CPU feature.
        # If all the xfrm bits for a feature, after applying xfrmmask, are set in xfrms,
        # we can set the remaining bits if the feature is available.
        # If the xfrmmask includes all the required xfrm bits, then these bits cannot be
        # changed in xfrm (need to stay the same as signed).
        fixed_bits = bits & xfrmmask
        if xfrms & fixed_bits == fixed_bits and feature in cpu_features:
            new_xfrms |= bits
    attr['xfrms'] = new_xfrms.to_bytes(length=8, byteorder='little')
def read_sigstruct(sig):
    """Decode selected SIGSTRUCT fields from the raw bytes *sig*.

    Returns a flat dict keyed by sub-field name ('year', 'month', 'day',
    'modulus', 'exponent', 'signature', 'misc_select', 'misc_mask', 'flags',
    'xfrms', 'flag_mask', 'xfrm_mask', 'enclave_hash', 'isv_prod_id',
    'isv_svn'). Fields decoded with an 's' format stay as bytes; the others
    become ints.
    """
    # Each entry: (offset into SIGSTRUCT, struct format, names of the unpacked values).
    # Offsets come from the generated SGX_ARCH_ENCLAVE_CSS_* constants.
    layout = (
        (offs.SGX_ARCH_ENCLAVE_CSS_DATE, "<HBB", ('year', 'month', 'day')),
        (offs.SGX_ARCH_ENCLAVE_CSS_MODULUS, "384s", ('modulus',)),
        (offs.SGX_ARCH_ENCLAVE_CSS_EXPONENT, "<L", ('exponent',)),
        (offs.SGX_ARCH_ENCLAVE_CSS_SIGNATURE, "384s", ('signature',)),
        (offs.SGX_ARCH_ENCLAVE_CSS_MISC_SELECT, "4s", ('misc_select',)),
        (offs.SGX_ARCH_ENCLAVE_CSS_MISC_MASK, "4s", ('misc_mask',)),
        (offs.SGX_ARCH_ENCLAVE_CSS_ATTRIBUTES, "8s8s", ('flags', 'xfrms')),
        (offs.SGX_ARCH_ENCLAVE_CSS_ATTRIBUTE_MASK, "8s8s", ('flag_mask', 'xfrm_mask')),
        (offs.SGX_ARCH_ENCLAVE_CSS_ENCLAVE_HASH, "32s", ('enclave_hash',)),
        (offs.SGX_ARCH_ENCLAVE_CSS_ISV_PROD_ID, "<H", ('isv_prod_id',)),
        (offs.SGX_ARCH_ENCLAVE_CSS_ISV_SVN, "<H", ('isv_svn',)),
    )

    decoded = {}
    for offset, fmt, names in layout:
        decoded.update(zip(names, struct.unpack_from(fmt, sig, offset)))
    return decoded
def _open_aesmd_socket():
    """Return a stream socket connected to the first reachable AESMD endpoint.

    Raises:
        socket.error: if none of the known endpoints accepts a connection.
    """
    # Try to connect to all possible interfaces exposed by aesm service
    endpoints = (
        "/var/run/aesmd/aesm.socket",  # named socket (for PSW 1.8+)
        "\0sgx_aesm_socket_base" + "\0" * 87,  # unnamed socket (for PSW 1.6/1.7)
    )
    for endpoint in endpoints:
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            sock.connect(endpoint)
        except socket.error:
            # A socket's state after a failed connect() is unspecified, so
            # discard it and use a fresh socket for the next endpoint.
            sock.close()
            continue
        return sock
    raise socket.error("Cannot connect to the AESMD service")


def _recv_exact(sock, size):
    """Read exactly *size* bytes from *sock*.

    A single recv() may return fewer bytes than requested, so keep reading
    until the buffer is full.

    Raises:
        ConnectionError: if the peer closes the connection before *size* bytes arrive.
    """
    buf = bytearray()
    while len(buf) < size:
        chunk = sock.recv(size - len(buf))
        if not chunk:
            raise ConnectionError(
                "AESMD closed the connection after %d of %d bytes" % (len(buf), size))
        buf += chunk
    return bytes(buf)


def connect_aesmd(attr):
    """Connect with AESMD and request a launch token for the enclave in *attr*.

    The request carries the enclave hash, the signer modulus and the
    concatenated flags + xfrms attributes taken from *attr*, which is a dict
    as produced by ``read_sigstruct``.

    Returns:
        The token bytes from AESMD's reply (``main`` writes them to the .token file).

    Raises:
        socket.error: if no AESMD endpoint is reachable.
        ConnectionError: if AESMD closes the connection mid-reply.
        Exception: if AESMD reports a non-zero error code.
    """
    req_msg = aesm_pb2.GetTokenReq()
    req_msg.req.signature = attr['enclave_hash']
    req_msg.req.key = attr['modulus']
    req_msg.req.attributes = attr['flags'] + attr['xfrms']
    req_msg.req.timeout = 10000  # NOTE(review): units not visible here (presumably ms)
    req_msg_raw = req_msg.SerializeToString()

    with _open_aesmd_socket() as aesm_service:
        # Framing in both directions: 4-byte little-endian length, then the
        # serialized protobuf message.
        aesm_service.sendall(struct.pack("<I", len(req_msg_raw)) + req_msg_raw)
        ret_msg_size = struct.unpack("<I", _recv_exact(aesm_service, 4))[0]
        ret_msg_raw = _recv_exact(aesm_service, ret_msg_size)

    ret_msg = aesm_pb2.GetTokenRet()
    ret_msg.ParseFromString(ret_msg_raw)
    if ret_msg.ret.error != 0:
        raise Exception("Failed. (Error Code = %d)" % (ret_msg.ret.error))
    return ret_msg.ret.token
# Command-line interface. argparse.FileType opens both files during
# parse_args(), so the handles arrive in main() already open.
argparser = argparse.ArgumentParser()
argparser.add_argument(
    '--sig', '-sig',
    required=True,
    type=argparse.FileType('rb'),
    metavar='SIGNATURE',
    help='Input .sig file (contains SIGSTRUCT)',
)
argparser.add_argument(
    '--output', '-output',
    required=True,
    type=argparse.FileType('wb'),
    metavar='OUTPUT',
    help='Output .token file (contains EINITTOKEN)',
)
def _print_attributes(attr):
    """Print a human-readable summary of the SIGSTRUCT fields in *attr*."""
    # flags/xfrm/misc_* are raw byte strings. Decoding them big-endian prints
    # the bytes in stored order, which is byte-reversed relative to the
    # little-endian integers used by set_optional_sgx_features().
    print("Attributes:")
    print(" mr_enclave: %s" % attr['enclave_hash'].hex())
    print(" isv_prod_id: %d" % attr['isv_prod_id'])
    print(" isv_svn: %d" % attr['isv_svn'])
    print(" attr.flags: %016x" % int.from_bytes(attr['flags'], byteorder='big'))
    print(" attr.xfrm: %016x" % int.from_bytes(attr['xfrms'], byteorder='big'))
    print(" misc_select: %08x" % int.from_bytes(attr['misc_select'], byteorder='big'))
    print(" misc_mask: %08x" % int.from_bytes(attr['misc_mask'], byteorder='big'))
    print(" modulus: %s..." % attr['modulus'].hex()[:32])
    print(" exponent: %d" % attr['exponent'])
    print(" signature: %s..." % attr['signature'].hex()[:32])
    print(" date: %d-%02d-%02d" % (attr['year'], attr['month'], attr['day']))


def main(args=None):
    """Main Program: fetch an EINITTOKEN for the SIGSTRUCT in --sig and write it to --output.

    Args:
        args: optional argument list. When None, argparse reads sys.argv.

    Returns:
        0 on success (used as the process exit status).
    """
    args = argparser.parse_args(args)
    # argparse.FileType has already opened both files. Close them
    # deterministically (including on error paths) so the token is flushed to
    # disk even when main() is called as a library function.
    with args.sig as sig_file, args.output as out_file:
        attr = read_sigstruct(sig_file.read())
        set_optional_sgx_features(attr)
        _print_attributes(attr)
        token = connect_aesmd(attr)
        out_file.write(token)
    return 0
if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the process exit status.
    raise SystemExit(main())
|