comms.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504
  #include <vector>
  #include <functional>
  #include <cstring>
  #include <new>
  #include "sgx_tcrypto.h"
  #include "sgx_tseal.h"
  #include "Enclave_t.h"
  #include "utils.hpp"
  #include "config.hpp"
  9. // Our public and private identity keys
  10. static sgx_ec256_private_t g_privkey;
  11. static sgx_ec256_public_t g_pubkey;
  // Progress of the enclave-to-enclave handshake on one channel. Only
  // HANDSHAKE_COMPLETE changes how data is framed: in that state every
  // chunk is AES-GCM encrypted and carries a MAC tag, while in every
  // other state chunks travel as plaintext.
  enum HandshakeStep {
      HANDSHAKE_NONE = 0,      // initial state of every NodeCommState
      HANDSHAKE_C_SENT_1 = 1,  // presumably: client has sent message 1
      HANDSHAKE_S_SENT_2 = 2,  // presumably: server has sent message 2
      HANDSHAKE_COMPLETE = 3   // traffic on this channel is encrypted
  };
  19. // Communication state for a node
  20. struct NodeCommState {
  21. sgx_ec256_public_t pubkey;
  22. nodenum_t node_num;
  23. HandshakeStep handshake_step;
  24. // Our DH keypair during the handshake
  25. sgx_ec256_private_t handshake_privkey;
  26. sgx_ec256_public_t handshake_pubkey;
  27. // The peer's DH public key during the handshake
  28. sgx_ec256_public_t handshake_peer_pubkey;
  29. // The outgoing and incoming AES keys after the handshake
  30. sgx_aes_gcm_128bit_key_t out_aes_key, in_aes_key;
  31. // The outgoing and incoming IV counters
  32. uint8_t out_aes_iv[SGX_AESGCM_IV_SIZE];
  33. uint8_t in_aes_iv[SGX_AESGCM_IV_SIZE];
  34. // The GCM state for incrementally building each outgoing chunk
  35. sgx_aes_state_handle_t out_aes_gcm_state;
  36. // The current outgoing frame and the current offset into it
  37. uint8_t *frame;
  38. uint32_t frame_offset;
  39. // The current outgoing message ciphertext size and the offset into
  40. // it of the start of the current frame
  41. uint32_t msg_size;
  42. uint32_t msg_frame_offset;
  43. // The current outgoing message plaintext size, how many plaintext
  44. // bytes we've already processed with message_data, and how many
  45. // plaintext bytes remain for the current chunk
  46. uint32_t msg_plaintext_size;
  47. uint32_t msg_plaintext_processed;
  48. uint32_t msg_plaintext_chunk_remain;
  49. // The current incoming message ciphertext size and the offset into
  50. // it of all previous chunks of this message
  51. uint32_t in_msg_size;
  52. uint32_t in_msg_offset;
  53. // The current incoming message number of plaintext bytes processed
  54. uint32_t in_msg_plaintext_processed;
  55. // The internal buffer where we're storing the (decrypted) message
  56. uint8_t *in_msg_buf;
  57. // The function to call when a new incoming message header arrives.
  58. // This function should return a pointer to enough memory to hold
  59. // the (decrypted) chunks of the message. Remember that the length
  60. // passed here is the total size of the _encrypted_ chunks. This
  61. // function should not itself modify the in_msg_size, in_msg_offset,
  62. // or in_msg_buf members. This function will usually allocate an
  63. // appropriate amount of memory and return the pointer to it, but
  64. // may do other things, like return a pointer to the middle of a
  65. // previously allocated region of memory.
  66. std::function<uint8_t*(NodeCommState&,uint32_t)> in_msg_get_buf;
  67. // The function to call after the last chunk of a message has been
  68. // received. If in_msg_get_buf allocated memory, this function
  69. // should deallocate it. in_msg_size, in_msg_offset, and in_msg_buf
  70. // will already have been reset when this function is called. The
  71. // uint32_t that is passed are the total size of the _decrypted_
  72. // data and the original total size of the _encrypted_ chunks that
  73. // was passed to in_msg_get_buf.
  74. std::function<void(NodeCommState&,uint8_t*,uint32_t,uint32_t)>
  75. in_msg_received;
  76. NodeCommState(const sgx_ec256_public_t* conf_pubkey, nodenum_t i) :
  77. node_num(i), handshake_step(HANDSHAKE_NONE),
  78. out_aes_gcm_state(NULL), frame(NULL),
  79. frame_offset(0), msg_size(0), msg_frame_offset(0),
  80. msg_plaintext_size(0), msg_plaintext_processed(0),
  81. msg_plaintext_chunk_remain(0),
  82. in_msg_size(0), in_msg_offset(0),
  83. in_msg_plaintext_processed(0), in_msg_buf(NULL),
  84. in_msg_get_buf(NULL), in_msg_received(NULL) {
  85. memmove(&pubkey, conf_pubkey, sizeof(pubkey));
  86. }
  87. void message_start(uint32_t plaintext_len);
  88. void message_data(uint8_t *data, uint32_t len);
  89. // Start the handshake (as the client)
  90. void handshake_start();
  91. };
  92. // A typical default in_msg_get_buf handler. It computes the maximum
  93. // possible size of the decrypted data, allocates that much memory, and
  94. // returns a pointer to it.
  95. static uint8_t* default_in_msg_get_buf(NodeCommState &commst,
  96. uint32_t tot_enc_chunk_size)
  97. {
  98. uint32_t max_plaintext_bytes = tot_enc_chunk_size;
  99. // If the handshake is complete, chunks will be encrypted and have a
  100. // MAC tag attached which will not correspond to plaintext bytes, so
  101. // we can trim them.
  102. if (commst.handshake_step == HANDSHAKE_COMPLETE) {
  103. // The minimum number of chunks needed to transmit this message
  104. uint32_t min_num_chunks =
  105. (tot_enc_chunk_size + (FRAME_SIZE-1)) / FRAME_SIZE;
  106. // The maximum number of plaintext bytes this message could contain
  107. max_plaintext_bytes = tot_enc_chunk_size -
  108. SGX_AESGCM_MAC_SIZE * min_num_chunks;
  109. }
  110. return new uint8_t[max_plaintext_bytes];
  111. }
  112. // Receive (at the server) the first handshake message
  113. static void handshake_1_msg_received(NodeCommState &nodest,
  114. uint8_t *data, uint32_t plaintext_len, uint32_t)
  115. {
  116. /*
  117. printf("Received handshake_1 message of %u bytes:\n", plaintext_len);
  118. for (uint32_t i=0;i<plaintext_len;++i) {
  119. printf("%02x", data[i]);
  120. }
  121. printf("\n");
  122. */
  123. if (plaintext_len != sizeof(sgx_ec256_public_t)) {
  124. printf("Received handshake_1 message of incorrect size %u\n",
  125. plaintext_len);
  126. return;
  127. }
  128. sgx_ecc_state_handle_t ecc_handle;
  129. sgx_ec256_public_t pubkey;
  130. memmove(&pubkey, data, sizeof(pubkey));
  131. sgx_ecc256_open_context(&ecc_handle);
  132. int valid;
  133. if (sgx_ecc256_check_point(&pubkey, ecc_handle, &valid) || !valid) {
  134. printf("Invalid public key received from node %hu\n",
  135. nodest.node_num);
  136. sgx_ecc256_close_context(ecc_handle);
  137. return;
  138. }
  139. delete[] data;
  140. printf("Valid public key received from node %hu\n", nodest.node_num);
  141. memmove(&nodest.handshake_peer_pubkey, &pubkey, sizeof(pubkey));
  142. // Create our own DH key pair
  143. sgx_ecc256_create_key_pair(&nodest.handshake_privkey,
  144. &nodest.handshake_pubkey, ecc_handle);
  145. sgx_ecc256_close_context(ecc_handle);
  146. }
  147. // Start a new outgoing message. Pass the number of _plaintext_ bytes
  148. // the message will be.
  149. void NodeCommState::message_start(uint32_t plaintext_len)
  150. {
  151. uint32_t ciphertext_len = plaintext_len;
  152. // If the handshake is complete, add SGX_AESGCM_MAC_SIZE bytes for
  153. // every FRAME_SIZE-SGX_AESGCM_MAC_SIZE bytes of plaintext.
  154. if (handshake_step == HANDSHAKE_COMPLETE) {
  155. uint32_t num_chunks = (plaintext_len +
  156. FRAME_SIZE - SGX_AESGCM_MAC_SIZE - 1) /
  157. (FRAME_SIZE - SGX_AESGCM_MAC_SIZE);
  158. ciphertext_len = plaintext_len +
  159. num_chunks * SGX_AESGCM_MAC_SIZE;
  160. }
  161. ocall_message(&frame, node_num, ciphertext_len);
  162. frame_offset = 0;
  163. msg_size = ciphertext_len;
  164. msg_frame_offset = 0;
  165. msg_plaintext_size = plaintext_len;
  166. msg_plaintext_processed = 0;
  167. if (plaintext_len < FRAME_SIZE - SGX_AESGCM_MAC_SIZE) {
  168. msg_plaintext_chunk_remain = plaintext_len;
  169. } else {
  170. msg_plaintext_chunk_remain = FRAME_SIZE - SGX_AESGCM_MAC_SIZE;
  171. }
  172. if (!frame) {
  173. printf("Received NULL back from ocall_message\n");
  174. }
  175. if (msg_plaintext_chunk_remain > 0) {
  176. *(size_t*)out_aes_iv += 1;
  177. sgx_aes_gcm128_enc_init(out_aes_key, out_aes_iv, SGX_AESGCM_IV_SIZE,
  178. NULL, 0, &out_aes_gcm_state);
  179. }
  180. }
  181. // Process len bytes of plaintext data into the current message.
  182. void NodeCommState::message_data(uint8_t *data, uint32_t len)
  183. {
  184. while (len > 0) {
  185. if (msg_plaintext_chunk_remain == 0) {
  186. printf("Attempt to queue too much message data\n");
  187. return;
  188. }
  189. uint32_t bytes_to_process = len;
  190. if (bytes_to_process > msg_plaintext_chunk_remain) {
  191. bytes_to_process = msg_plaintext_chunk_remain;
  192. }
  193. if (frame == NULL) {
  194. printf("frame is NULL when queueing message data\n");
  195. return;
  196. }
  197. if (handshake_step == HANDSHAKE_COMPLETE) {
  198. // Encrypt the data
  199. sgx_aes_gcm128_enc_update(data, bytes_to_process,
  200. frame+frame_offset, out_aes_gcm_state);
  201. } else {
  202. // Just copy the plaintext data during the handshake
  203. memmove(frame+frame_offset, data, bytes_to_process);
  204. }
  205. frame_offset += bytes_to_process;
  206. msg_plaintext_processed += bytes_to_process;
  207. msg_plaintext_chunk_remain -= bytes_to_process;
  208. len -= bytes_to_process;
  209. data += bytes_to_process;
  210. if (msg_plaintext_chunk_remain == 0) {
  211. // Complete and send this chunk
  212. if (handshake_step == HANDSHAKE_COMPLETE) {
  213. sgx_aes_gcm128_enc_get_mac(frame+frame_offset,
  214. out_aes_gcm_state);
  215. frame_offset += SGX_AESGCM_MAC_SIZE;
  216. }
  217. uint8_t *nextframe = NULL;
  218. ocall_chunk(&nextframe, node_num, frame, frame_offset);
  219. frame = nextframe;
  220. msg_frame_offset += frame_offset;
  221. frame_offset = 0;
  222. msg_plaintext_chunk_remain =
  223. msg_plaintext_size - msg_plaintext_processed;
  224. if (msg_plaintext_chunk_remain >
  225. FRAME_SIZE - SGX_AESGCM_MAC_SIZE) {
  226. msg_plaintext_chunk_remain =
  227. FRAME_SIZE - SGX_AESGCM_MAC_SIZE;
  228. }
  229. if (handshake_step == HANDSHAKE_COMPLETE) {
  230. sgx_aes_gcm_close(out_aes_gcm_state);
  231. if (msg_plaintext_chunk_remain > 0) {
  232. *(size_t*)out_aes_iv += 1;
  233. sgx_aes_gcm128_enc_init(out_aes_key, out_aes_iv,
  234. SGX_AESGCM_IV_SIZE, NULL, 0, &out_aes_gcm_state);
  235. }
  236. }
  237. }
  238. }
  239. }
  240. // The communication states for all the nodes. There's an entry for
  241. // ourselves in here, but it is unused.
  242. static std::vector<NodeCommState> commstates;
  243. static nodenum_t tot_nodes, my_node_num;
  244. // Generate a new identity signature key. Output the public key and the
  245. // sealed private key. outsealedpriv must point to SEALEDPRIVKEY_SIZE =
  246. // sizeof(sgx_sealed_data_t) + sizeof(sgx_ec256_private_t) + 18 bytes of
  247. // memory.
  248. void ecall_identity_key_new(sgx_ec256_public_t *outpub,
  249. sgx_sealed_data_t *outsealedpriv)
  250. {
  251. sgx_ecc_state_handle_t ecc_handle;
  252. sgx_ecc256_open_context(&ecc_handle);
  253. sgx_ecc256_create_key_pair(&g_privkey, &g_pubkey, ecc_handle);
  254. memmove(outpub, &g_pubkey, sizeof(g_pubkey));
  255. sgx_ecc256_close_context(ecc_handle);
  256. sgx_seal_data(18, (const uint8_t*)"TEEMS Identity key",
  257. sizeof(g_privkey), (const uint8_t*)&g_privkey,
  258. SEALED_PRIVKEY_SIZE, outsealedpriv);
  259. }
  260. // Load an identity key from a sealed privkey. Output the resulting
  261. // public key. insealedpriv must point to sizeof(sgx_sealed_data_t) +
  262. // sizeof(sgx_ec256_private_t) bytes of memory. Returns true for
  263. // success, false for failure.
  264. bool ecall_identity_key_load(sgx_ec256_public_t *outpub,
  265. const sgx_sealed_data_t *insealedpriv)
  266. {
  267. sgx_ecc_state_handle_t ecc_handle;
  268. char aad[18];
  269. uint32_t aadsize = sizeof(aad);
  270. sgx_ec256_private_t privkey;
  271. uint32_t privkeysize = sizeof(privkey);
  272. sgx_status_t res = sgx_unseal_data(
  273. insealedpriv, (uint8_t*)aad, &aadsize,
  274. (uint8_t*)&privkey, &privkeysize);
  275. if (res || aadsize != sizeof(aad) || privkeysize != sizeof(privkey)
  276. || memcmp(aad, "TEEMS Identity key", sizeof(aad))) {
  277. return false;
  278. }
  279. sgx_ecc256_open_context(&ecc_handle);
  280. sgx_ec256_public_t pubkey;
  281. int valid;
  282. if (sgx_ecc256_calculate_pub_from_priv(&privkey, &pubkey) ||
  283. sgx_ecc256_check_point(&pubkey, ecc_handle, &valid) ||
  284. !valid) {
  285. sgx_ecc256_close_context(ecc_handle);
  286. return false;
  287. }
  288. sgx_ecc256_close_context(ecc_handle);
  289. memmove(&g_pubkey, &pubkey, sizeof(pubkey));
  290. memmove(&g_privkey, &privkey, sizeof(privkey));
  291. memmove(outpub, &pubkey, sizeof(pubkey));
  292. return true;
  293. }
  294. bool comms_init_nodestate(const EnclaveAPINodeConfig *apinodeconfigs,
  295. nodenum_t num_nodes, nodenum_t me)
  296. {
  297. sgx_ecc_state_handle_t ecc_handle;
  298. sgx_ecc256_open_context(&ecc_handle);
  299. commstates.clear();
  300. tot_nodes = 0;
  301. commstates.reserve(num_nodes);
  302. for (nodenum_t i=0; i<num_nodes; ++i) {
  303. // Check that the pubkey is valid
  304. int valid;
  305. if (sgx_ecc256_check_point(&apinodeconfigs[i].pubkey,
  306. ecc_handle, &valid) ||
  307. !valid) {
  308. printf("Pubkey for node %hu invalid\n", i);
  309. commstates.clear();
  310. sgx_ecc256_close_context(ecc_handle);
  311. return false;
  312. }
  313. commstates.emplace_back(&apinodeconfigs[i].pubkey, i);
  314. }
  315. sgx_ecc256_close_context(ecc_handle);
  316. my_node_num = me;
  317. // Check that no one other than us has our pubkey (deals with
  318. // reflection attacks)
  319. for (nodenum_t i=0; i<num_nodes; ++i) {
  320. if (i == my_node_num) continue;
  321. if (!memcmp(&commstates[i].pubkey,
  322. &commstates[my_node_num].pubkey,
  323. sizeof(commstates[i].pubkey))) {
  324. printf("Pubkey %hu matches our own; possible reflection attack?\n",
  325. i);
  326. commstates.clear();
  327. return false;
  328. }
  329. }
  330. tot_nodes = num_nodes;
  331. // There will be an enclave-to-enclave channel between us and each
  332. // other node's enclave. For the node numbers smaller than ours, we
  333. // will be the server for the handshake for that channel. Prepare
  334. // to receive the first handshake message from those nodes'
  335. // enclaves.
  336. for (nodenum_t i=0; i<my_node_num; ++i) {
  337. commstates[i].in_msg_get_buf = default_in_msg_get_buf;
  338. commstates[i].in_msg_received = handshake_1_msg_received;
  339. }
  340. return true;
  341. }
  342. bool ecall_message(nodenum_t node_num, uint32_t message_len)
  343. {
  344. if (node_num >= tot_nodes) {
  345. printf("Out-of-range node_num %hu received in ecall_message\n",
  346. node_num);
  347. return false;
  348. }
  349. NodeCommState &nodest = commstates[node_num];
  350. if (nodest.in_msg_size != nodest.in_msg_offset) {
  351. printf("Received ecall_message without completing previous message\n");
  352. return false;
  353. }
  354. if (!nodest.in_msg_get_buf) {
  355. printf("No message header handler registered\n");
  356. return false;
  357. }
  358. uint8_t *buf = nodest.in_msg_get_buf(nodest, message_len);
  359. if (!buf) {
  360. printf("Message header handler returned NULL\n");
  361. return false;
  362. }
  363. nodest.in_msg_size = message_len;
  364. nodest.in_msg_offset = 0;
  365. nodest.in_msg_plaintext_processed = 0;
  366. nodest.in_msg_buf = buf;
  367. return true;
  368. }
  369. bool ecall_chunk(nodenum_t node_num, const uint8_t *chunkdata,
  370. uint32_t chunklen)
  371. {
  372. if (node_num >= tot_nodes) {
  373. printf("Out-of-range node_num %hu received in ecall_chunk\n",
  374. node_num);
  375. return false;
  376. }
  377. NodeCommState &nodest = commstates[node_num];
  378. if (nodest.in_msg_size == nodest.in_msg_offset) {
  379. printf("Received ecall_chunk after completing message\n");
  380. return false;
  381. }
  382. if (!nodest.in_msg_buf) {
  383. printf("No incoming message buffer allocated\n");
  384. return false;
  385. }
  386. if (!nodest.in_msg_received) {
  387. printf("No message received handler registered\n");
  388. return false;
  389. }
  390. if (nodest.in_msg_offset + chunklen > nodest.in_msg_size) {
  391. printf("Chunk larger than remaining message size\n");
  392. return false;
  393. }
  394. if (nodest.handshake_step == HANDSHAKE_COMPLETE) {
  395. // Decrypt the incoming data
  396. *(size_t*)(nodest.in_aes_iv) += 1;
  397. if (sgx_rijndael128GCM_decrypt(&nodest.in_aes_key, chunkdata,
  398. chunklen - SGX_AESGCM_MAC_SIZE,
  399. nodest.in_msg_buf + nodest.in_msg_plaintext_processed,
  400. nodest.in_aes_iv, SGX_AESGCM_IV_SIZE, NULL, 0,
  401. (const sgx_aes_gcm_128bit_tag_t *)
  402. (chunkdata + chunklen - SGX_AESGCM_MAC_SIZE))) {
  403. printf("Decryption failed\n");
  404. return false;
  405. }
  406. nodest.in_msg_plaintext_processed +=
  407. chunklen - SGX_AESGCM_MAC_SIZE;
  408. } else {
  409. // Just copy the handshake data
  410. memmove(nodest.in_msg_buf + nodest.in_msg_plaintext_processed,
  411. chunkdata, chunklen);
  412. nodest.in_msg_plaintext_processed += chunklen;
  413. }
  414. nodest.in_msg_offset += chunklen;
  415. if (nodest.in_msg_offset == nodest.in_msg_size) {
  416. // This was the last chunk; handle the received message
  417. nodest.in_msg_received(nodest, nodest.in_msg_buf,
  418. nodest.in_msg_plaintext_processed, nodest.in_msg_size);
  419. }
  420. return true;
  421. }
  422. // Start the handshake (as the client)
  423. void NodeCommState::handshake_start()
  424. {
  425. sgx_ecc_state_handle_t ecc_handle;
  426. sgx_ecc256_open_context(&ecc_handle);
  427. // Create a DH keypair
  428. sgx_ecc256_create_key_pair(&handshake_privkey, &handshake_pubkey,
  429. ecc_handle);
  430. sgx_ecc256_close_context(ecc_handle);
  431. // Send the public key as the first message
  432. message_start(sizeof(handshake_pubkey));
  433. message_data((uint8_t*)&handshake_pubkey, sizeof(handshake_pubkey));
  434. }
  435. // Start all handshakes for which we are the client
  436. bool ecall_comms_start()
  437. {
  438. for (nodenum_t t = my_node_num+1; t<tot_nodes; ++t) {
  439. commstates[t].handshake_start();
  440. }
  441. return true;
  442. }