pal-sgx-get-token 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. #!/usr/bin/env python3
  2. import argparse
  3. import os
  4. import socket
  5. import struct
  6. import sys
  7. import aesm_pb2
  8. sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
  9. from generated_offsets import *
  10. def set_optional_sgx_features(attr):
  11. """Set optional SGX features if they are available on this machine."""
  12. optional_sgx_features = {
  13. SGX_XFRM_AVX: "avx",
  14. SGX_XFRM_AVX512: "avx512f",
  15. SGX_XFRM_MPX: "mpx",
  16. }
  17. cpu_features = ""
  18. with open("/proc/cpuinfo", "r") as f:
  19. for line in f:
  20. if line.startswith("flags"):
  21. cpu_features = line.split(":")[1].strip().split()
  22. break
  23. else:
  24. raise Exception("Failed to parse CPU flags")
  25. xfrms = int.from_bytes(attr['xfrms'], byteorder='little')
  26. xfrmmask = int.from_bytes(attr['xfrm_mask'], byteorder='little')
  27. new_xfrms = 0
  28. for (bits, feature) in optional_sgx_features.items():
  29. # Check if SIGSTRUCT allows enabling an optional CPU feature.
  30. # If all the xfrm bits for a feature, after applying xfrmmask, are set in xfrms,
  31. # we can set the remaining bits if the feature is available.
  32. # If the xfrmmask includes all the required xfrm bits, then these bits cannot be
  33. # changed in xfrm (need to stay the same as signed).
  34. if xfrms & (bits & xfrmmask) == (bits & xfrmmask) and feature in cpu_features:
  35. new_xfrms |= xfrms | bits
  36. attr['xfrms'] = new_xfrms.to_bytes(length=8, byteorder='little')
  37. def read_sigstruct(sig):
  38. """Reading Sigstruct."""
  39. # field format: (offset, type, value)
  40. # SGX_ARCH_ENCLAVE_CSS_
  41. fields = {
  42. 'date': (SGX_ARCH_ENCLAVE_CSS_DATE, "<HBB", 'year', 'month', 'day'),
  43. 'modulus': (SGX_ARCH_ENCLAVE_CSS_MODULUS, "384s", 'modulus'),
  44. 'exponent': (SGX_ARCH_ENCLAVE_CSS_EXPONENT, "<L", 'exponent'),
  45. 'signature': (SGX_ARCH_ENCLAVE_CSS_SIGNATURE, "384s", 'signature'),
  46. 'misc_select': (SGX_ARCH_ENCLAVE_CSS_MISC_SELECT, "4s", 'misc_select'),
  47. 'misc_mask': (SGX_ARCH_ENCLAVE_CSS_MISC_MASK, "4s", 'misc_mask'),
  48. 'attributes': (SGX_ARCH_ENCLAVE_CSS_ATTRIBUTES, "8s8s", 'flags', 'xfrms'),
  49. 'attribute_mask': (SGX_ARCH_ENCLAVE_CSS_ATTRIBUTE_MASK, "8s8s", 'flag_mask', 'xfrm_mask'),
  50. 'enclave_hash': (SGX_ARCH_ENCLAVE_CSS_ENCLAVE_HASH, "32s", 'enclave_hash'),
  51. 'isv_prod_id': (SGX_ARCH_ENCLAVE_CSS_ISV_PROD_ID, "<H", 'isv_prod_id'),
  52. 'isv_svn': (SGX_ARCH_ENCLAVE_CSS_ISV_SVN, "<H", 'isv_svn'),
  53. }
  54. attr = dict()
  55. for field in fields.values():
  56. values = struct.unpack_from(field[1], sig, field[0])
  57. for i, value in enumerate(values):
  58. attr[field[i + 2]] = value
  59. return attr
  60. def connect_aesmd(attr):
  61. """Connect with AESMD."""
  62. req_msg = aesm_pb2.GetTokenReq()
  63. req_msg.req.signature = attr['enclave_hash']
  64. req_msg.req.key = attr['modulus']
  65. req_msg.req.attributes = attr['flags'] + attr['xfrms']
  66. req_msg.req.timeout = 10000
  67. req_msg_raw = req_msg.SerializeToString()
  68. aesm_service = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
  69. # Try to connect to all possible interfaces exposed by aesm service
  70. connections = (
  71. "/var/run/aesmd/aesm.socket", # named socket (for PSW 1.8+)
  72. "\0sgx_aesm_socket_base" + "\0" * 87 # unnamed socket (for PSW 1.6/1.7)
  73. )
  74. for conn in connections:
  75. try:
  76. aesm_service.connect(conn)
  77. except socket.error:
  78. continue
  79. break
  80. else:
  81. raise socket.error("Cannot connect to the AESMD service")
  82. aesm_service.send(struct.pack("<I", len(req_msg_raw)))
  83. aesm_service.send(req_msg_raw)
  84. ret_msg_size = struct.unpack("<I", aesm_service.recv(4))[0]
  85. ret_msg = aesm_pb2.GetTokenRet()
  86. ret_msg_raw = aesm_service.recv(ret_msg_size)
  87. ret_msg.ParseFromString(ret_msg_raw)
  88. if ret_msg.ret.error != 0:
  89. raise Exception("Failed. (Error Code = %d)" % (ret_msg.ret.error))
  90. return ret_msg.ret.token
  91. argparser = argparse.ArgumentParser()
  92. argparser.add_argument('--sig', '-sig', metavar='SIGNATURE',
  93. type=argparse.FileType('rb'), required=True,
  94. help='Input .sig file (contains SIGSTRUCT)')
  95. argparser.add_argument('--output', '-output', metavar='OUTPUT',
  96. type=argparse.FileType('wb'), required=True,
  97. help='Output .token file (contains EINITTOKEN)')
  98. def main(args=None):
  99. """Main Program."""
  100. args = argparser.parse_args(args)
  101. attr = read_sigstruct(args.sig.read())
  102. set_optional_sgx_features(attr)
  103. print("Attributes:")
  104. print(" mr_enclave: %s" % attr['enclave_hash'].hex())
  105. print(" isv_prod_id: %d" % attr['isv_prod_id'])
  106. print(" isv_svn: %d" % attr['isv_svn'])
  107. print(" attr.flags: %016x" % int.from_bytes(attr['flags'], byteorder='big'))
  108. print(" attr.xfrm: %016x" % int.from_bytes(attr['xfrms'], byteorder='big'))
  109. print(" misc_select: %08x" % int.from_bytes(attr['misc_select'], byteorder='big'))
  110. print(" misc_mask: %08x" % int.from_bytes(attr['misc_mask'], byteorder='big'))
  111. print(" modulus: %s..." % attr['modulus'].hex()[:32])
  112. print(" exponent: %d" % attr['exponent'])
  113. print(" signature: %s..." % attr['signature'].hex()[:32])
  114. print(" date: %d-%02d-%02d" % (attr['year'], attr['month'], attr['day']))
  115. token = connect_aesmd(attr)
  116. args.output.write(token)
  117. return 0
  118. if __name__ == "__main__":
  119. sys.exit(main())