pal-sgx-get-token 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. #!/usr/bin/env python3
  2. # pylint: disable=invalid-name
  3. import argparse
  4. import os
  5. import socket
  6. import struct
  7. import sys
  8. import aesm_pb2
  9. sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
  10. import generated_offsets as offs # pylint: disable=import-error,wrong-import-position
  11. # pylint: enable=invalid-name
  12. def set_optional_sgx_features(attr):
  13. """Set optional SGX features if they are available on this machine."""
  14. optional_sgx_features = {
  15. offs.SGX_XFRM_AVX: "avx",
  16. offs.SGX_XFRM_AVX512: "avx512f",
  17. offs.SGX_XFRM_MPX: "mpx",
  18. }
  19. cpu_features = ""
  20. with open("/proc/cpuinfo", "r") as file:
  21. for line in file:
  22. if line.startswith("flags"):
  23. cpu_features = line.split(":")[1].strip().split()
  24. break
  25. else:
  26. raise Exception("Failed to parse CPU flags")
  27. xfrms = int.from_bytes(attr['xfrms'], byteorder='little')
  28. xfrmmask = int.from_bytes(attr['xfrm_mask'], byteorder='little')
  29. new_xfrms = 0
  30. for (bits, feature) in optional_sgx_features.items():
  31. # Check if SIGSTRUCT allows enabling an optional CPU feature.
  32. # If all the xfrm bits for a feature, after applying xfrmmask, are set in xfrms,
  33. # we can set the remaining bits if the feature is available.
  34. # If the xfrmmask includes all the required xfrm bits, then these bits cannot be
  35. # changed in xfrm (need to stay the same as signed).
  36. if xfrms & (bits & xfrmmask) == (bits & xfrmmask) and feature in cpu_features:
  37. new_xfrms |= xfrms | bits
  38. attr['xfrms'] = new_xfrms.to_bytes(length=8, byteorder='little')
  39. def read_sigstruct(sig):
  40. """Reading Sigstruct."""
  41. # field format: (offset, type, value)
  42. # SGX_ARCH_ENCLAVE_CSS_
  43. fields = {
  44. 'date': (offs.SGX_ARCH_ENCLAVE_CSS_DATE, "<HBB", 'year', 'month', 'day'),
  45. 'modulus': (offs.SGX_ARCH_ENCLAVE_CSS_MODULUS, "384s", 'modulus'),
  46. 'exponent': (offs.SGX_ARCH_ENCLAVE_CSS_EXPONENT, "<L", 'exponent'),
  47. 'signature': (offs.SGX_ARCH_ENCLAVE_CSS_SIGNATURE, "384s", 'signature'),
  48. 'misc_select': (offs.SGX_ARCH_ENCLAVE_CSS_MISC_SELECT, "4s", 'misc_select'),
  49. 'misc_mask': (offs.SGX_ARCH_ENCLAVE_CSS_MISC_MASK, "4s", 'misc_mask'),
  50. 'attributes': (offs.SGX_ARCH_ENCLAVE_CSS_ATTRIBUTES, "8s8s", 'flags', 'xfrms'),
  51. 'attribute_mask': (offs.SGX_ARCH_ENCLAVE_CSS_ATTRIBUTE_MASK, "8s8s",
  52. 'flag_mask', 'xfrm_mask'),
  53. 'enclave_hash': (offs.SGX_ARCH_ENCLAVE_CSS_ENCLAVE_HASH, "32s", 'enclave_hash'),
  54. 'isv_prod_id': (offs.SGX_ARCH_ENCLAVE_CSS_ISV_PROD_ID, "<H", 'isv_prod_id'),
  55. 'isv_svn': (offs.SGX_ARCH_ENCLAVE_CSS_ISV_SVN, "<H", 'isv_svn'),
  56. }
  57. attr = dict()
  58. for field in fields.values():
  59. values = struct.unpack_from(field[1], sig, field[0])
  60. for i, value in enumerate(values):
  61. attr[field[i + 2]] = value
  62. return attr
  63. def connect_aesmd(attr):
  64. """Connect with AESMD."""
  65. req_msg = aesm_pb2.GetTokenReq()
  66. req_msg.req.signature = attr['enclave_hash']
  67. req_msg.req.key = attr['modulus']
  68. req_msg.req.attributes = attr['flags'] + attr['xfrms']
  69. req_msg.req.timeout = 10000
  70. req_msg_raw = req_msg.SerializeToString()
  71. aesm_service = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
  72. # Try to connect to all possible interfaces exposed by aesm service
  73. connections = (
  74. "/var/run/aesmd/aesm.socket", # named socket (for PSW 1.8+)
  75. "\0sgx_aesm_socket_base" + "\0" * 87 # unnamed socket (for PSW 1.6/1.7)
  76. )
  77. for conn in connections:
  78. try:
  79. aesm_service.connect(conn)
  80. except socket.error:
  81. continue
  82. break
  83. else:
  84. raise socket.error("Cannot connect to the AESMD service")
  85. aesm_service.send(struct.pack("<I", len(req_msg_raw)))
  86. aesm_service.send(req_msg_raw)
  87. ret_msg_size = struct.unpack("<I", aesm_service.recv(4))[0]
  88. ret_msg = aesm_pb2.GetTokenRet()
  89. ret_msg_raw = aesm_service.recv(ret_msg_size)
  90. ret_msg.ParseFromString(ret_msg_raw)
  91. if ret_msg.ret.error != 0:
  92. raise Exception("Failed. (Error Code = %d)" % (ret_msg.ret.error))
  93. return ret_msg.ret.token
  94. argparser = argparse.ArgumentParser()
  95. argparser.add_argument('--sig', '-sig', metavar='SIGNATURE',
  96. type=argparse.FileType('rb'), required=True,
  97. help='Input .sig file (contains SIGSTRUCT)')
  98. argparser.add_argument('--output', '-output', metavar='OUTPUT',
  99. type=argparse.FileType('wb'), required=True,
  100. help='Output .token file (contains EINITTOKEN)')
  101. def main(args=None):
  102. """Main Program."""
  103. args = argparser.parse_args(args)
  104. attr = read_sigstruct(args.sig.read())
  105. set_optional_sgx_features(attr)
  106. print("Attributes:")
  107. print(" mr_enclave: %s" % attr['enclave_hash'].hex())
  108. print(" isv_prod_id: %d" % attr['isv_prod_id'])
  109. print(" isv_svn: %d" % attr['isv_svn'])
  110. print(" attr.flags: %016x" % int.from_bytes(attr['flags'], byteorder='big'))
  111. print(" attr.xfrm: %016x" % int.from_bytes(attr['xfrms'], byteorder='big'))
  112. print(" misc_select: %08x" % int.from_bytes(attr['misc_select'], byteorder='big'))
  113. print(" misc_mask: %08x" % int.from_bytes(attr['misc_mask'], byteorder='big'))
  114. print(" modulus: %s..." % attr['modulus'].hex()[:32])
  115. print(" exponent: %d" % attr['exponent'])
  116. print(" signature: %s..." % attr['signature'].hex()[:32])
  117. print(" date: %d-%02d-%02d" % (attr['year'], attr['month'], attr['day']))
  118. token = connect_aesmd(attr)
  119. args.output.write(token)
  120. return 0
  121. if __name__ == "__main__":
  122. sys.exit(main())