comms.cpp 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766
  1. #include <vector>
  2. #include <functional>
  3. #include <cstring>
  4. #include <pthread.h>
  5. #include "sgx_tcrypto.h"
  6. #include "sgx_tseal.h"
  7. #include "Enclave_t.h"
  8. #include "utils.hpp"
  9. #include "config.hpp"
  10. #include "route.hpp"
  11. #include "comms.hpp"
  12. // Our public and private identity keys
  13. static sgx_ec256_private_t g_privkey;
  14. static sgx_ec256_public_t g_pubkey;
  15. // The communication states for all the nodes. There's an entry for
  16. // ourselves in here, but it is unused.
  17. std::vector<NodeCommState> g_commstates;
  18. static nodenum_t tot_nodes, my_node_num;
  19. static class CompletedHandshakeCounter {
  20. // Mutex around completed_handshakes
  21. pthread_mutex_t mutex;
  22. // The number of completed handshakes
  23. nodenum_t completed_handshakes;
  24. // The callback pointer to use when all handshakes complete
  25. void *complete_handshake_cbpointer;
  26. public:
  27. CompletedHandshakeCounter() {
  28. pthread_mutex_init(&mutex, NULL);
  29. completed_handshakes = 0;
  30. complete_handshake_cbpointer = NULL;
  31. if (tot_nodes == 1) {
  32. // There's no one to handshake with, so we're already done
  33. pthread_mutex_lock(&mutex);
  34. void *cbpointer = complete_handshake_cbpointer;
  35. complete_handshake_cbpointer = NULL;
  36. pthread_mutex_unlock(&mutex);
  37. ocall_comms_ready(cbpointer);
  38. }
  39. }
  40. void reset(void *cbpointer) {
  41. pthread_mutex_lock(&mutex);
  42. completed_handshakes = 0;
  43. complete_handshake_cbpointer = cbpointer;
  44. if (tot_nodes == 1) {
  45. // There's no one to handshake with, so we're already done
  46. complete_handshake_cbpointer = NULL;
  47. pthread_mutex_unlock(&mutex);
  48. ocall_comms_ready(cbpointer);
  49. } else {
  50. pthread_mutex_unlock(&mutex);
  51. }
  52. }
  53. void inc() {
  54. pthread_mutex_lock(&mutex);
  55. ++completed_handshakes;
  56. nodenum_t num_completed = completed_handshakes;
  57. pthread_mutex_unlock(&mutex);
  58. if (num_completed == tot_nodes - 1) {
  59. pthread_mutex_lock(&mutex);
  60. void *cbpointer = complete_handshake_cbpointer;
  61. complete_handshake_cbpointer = NULL;
  62. completed_handshakes = 0;
  63. pthread_mutex_unlock(&mutex);
  64. ocall_comms_ready(cbpointer);
  65. }
  66. }
  67. } completed_handshake_counter;
  68. // A typical default in_msg_get_buf handler. It computes the maximum
  69. // possible size of the decrypted data, allocates that much memory, and
  70. // returns a pointer to it.
  71. uint8_t* default_in_msg_get_buf(NodeCommState &commst,
  72. uint32_t tot_enc_chunk_size)
  73. {
  74. uint32_t max_plaintext_bytes = tot_enc_chunk_size;
  75. // If the handshake is complete, chunks will be encrypted and have a
  76. // MAC tag attached which will not correspond to plaintext bytes, so
  77. // we can trim them.
  78. if (commst.handshake_step == HANDSHAKE_COMPLETE) {
  79. // The minimum number of chunks needed to transmit this message
  80. uint32_t min_num_chunks =
  81. (tot_enc_chunk_size + (FRAME_SIZE-1)) / FRAME_SIZE;
  82. // The maximum number of plaintext bytes this message could contain
  83. max_plaintext_bytes = tot_enc_chunk_size -
  84. SGX_AESGCM_MAC_SIZE * min_num_chunks;
  85. }
  86. return new uint8_t[max_plaintext_bytes];
  87. }
  88. // An in_msg_received handler when we don't actually expect a message
  89. // from a given node at a given time.
  90. void unknown_in_msg_received(NodeCommState &nodest,
  91. uint8_t *data, uint32_t plaintext_len, uint32_t)
  92. {
  93. printf("Received unknown message of %u bytes from node %lu:\n",
  94. plaintext_len, nodest.node_num);
  95. for (uint32_t i=0;i<plaintext_len;++i) {
  96. printf("%02x", data[i]);
  97. }
  98. printf("\n");
  99. delete[] data;
  100. }
  101. static void handshake_1_msg_received(NodeCommState &nodest,
  102. uint8_t *data, uint32_t plaintext_len, uint32_t);
  103. static void handshake_2_msg_received(NodeCommState &nodest,
  104. uint8_t *data, uint32_t plaintext_len, uint32_t);
  105. static void handshake_3_msg_received(NodeCommState &nodest,
  106. uint8_t *data, uint32_t plaintext_len, uint32_t);
  107. // Receive (at the server) the first handshake message
  108. static void handshake_1_msg_received(NodeCommState &nodest,
  109. uint8_t *data, uint32_t plaintext_len, uint32_t)
  110. {
  111. /*
  112. printf("Received handshake_1 message of %u bytes:\n", plaintext_len);
  113. for (uint32_t i=0;i<plaintext_len;++i) {
  114. printf("%02x", data[i]);
  115. }
  116. printf("\n");
  117. */
  118. if (plaintext_len != sizeof(sgx_ec256_public_t)) {
  119. printf("Received handshake_1 message of incorrect size %u\n",
  120. plaintext_len);
  121. return;
  122. }
  123. sgx_ecc_state_handle_t ecc_handle;
  124. sgx_ec256_public_t peer_dh_pubkey;
  125. memmove(&peer_dh_pubkey, data, sizeof(peer_dh_pubkey));
  126. delete[] data;
  127. sgx_ecc256_open_context(&ecc_handle);
  128. int valid;
  129. if (sgx_ecc256_check_point(&peer_dh_pubkey, ecc_handle, &valid)
  130. || !valid) {
  131. printf("Invalid public key received from node %hu\n",
  132. nodest.node_num);
  133. sgx_ecc256_close_context(ecc_handle);
  134. return;
  135. }
  136. printf("Valid public key received from node %hu\n", nodest.node_num);
  137. // Create our own DH key pair
  138. sgx_ec256_public_t our_dh_pubkey;
  139. sgx_ec256_private_t our_dh_privkey;
  140. sgx_ecc256_create_key_pair(&our_dh_privkey, &our_dh_pubkey, ecc_handle);
  141. // Construct the shared secret
  142. sgx_ec256_dh_shared_t sharedsecret;
  143. sgx_ecc256_compute_shared_dhkey(&our_dh_privkey, &peer_dh_pubkey,
  144. &sharedsecret, ecc_handle);
  145. memset(&our_dh_privkey, 0, sizeof(our_dh_privkey));
  146. // Compute H1(sharedsecret) and H2(sharedsecret)
  147. sgx_sha_state_handle_t sha_handle;
  148. sgx_sha256_hash_t h1, h2;
  149. sgx_sha256_init(&sha_handle);
  150. sgx_sha256_update((const uint8_t*)"\x01", 1, sha_handle);
  151. sgx_sha256_update((uint8_t*)&sharedsecret, sizeof(sharedsecret),
  152. sha_handle);
  153. sgx_sha256_get_hash(sha_handle, &h1);
  154. sgx_sha256_close(sha_handle);
  155. sgx_sha256_init(&sha_handle);
  156. sgx_sha256_update((const uint8_t*)"\x02", 1, sha_handle);
  157. sgx_sha256_update((uint8_t*)&sharedsecret, sizeof(sharedsecret),
  158. sha_handle);
  159. sgx_sha256_get_hash(sha_handle, &h2);
  160. sgx_sha256_close(sha_handle);
  161. // Compute the server-to-client MAC
  162. sgx_hmac_state_handle_t hmac_handle;
  163. uint8_t srv_cli_mac[16];
  164. sgx_hmac256_init(h1, 16, &hmac_handle);
  165. sgx_hmac256_update((uint8_t*)&our_dh_pubkey, sizeof(our_dh_pubkey),
  166. hmac_handle);
  167. sgx_hmac256_update((uint8_t*)&peer_dh_pubkey, sizeof(peer_dh_pubkey),
  168. hmac_handle);
  169. sgx_hmac256_update((uint8_t*)&g_pubkey, sizeof(g_pubkey),
  170. hmac_handle);
  171. sgx_hmac256_update((uint8_t*)&nodest.pubkey, sizeof(nodest.pubkey),
  172. hmac_handle);
  173. sgx_hmac256_final(srv_cli_mac, 16, hmac_handle);
  174. sgx_hmac256_close(hmac_handle);
  175. // Compute the client-to-server MAC
  176. uint8_t cli_srv_mac[16];
  177. sgx_hmac256_init(((uint8_t*)h1)+16, 16, &hmac_handle);
  178. sgx_hmac256_update((uint8_t*)&peer_dh_pubkey, sizeof(peer_dh_pubkey),
  179. hmac_handle);
  180. sgx_hmac256_update((uint8_t*)&our_dh_pubkey, sizeof(our_dh_pubkey),
  181. hmac_handle);
  182. sgx_hmac256_update((uint8_t*)&nodest.pubkey, sizeof(nodest.pubkey),
  183. hmac_handle);
  184. sgx_hmac256_update((uint8_t*)&g_pubkey, sizeof(g_pubkey),
  185. hmac_handle);
  186. sgx_hmac256_final(cli_srv_mac, 16, hmac_handle);
  187. sgx_hmac256_close(hmac_handle);
  188. // Sign the server-to-client MAC
  189. sgx_ec256_signature_t srv_cli_sig;
  190. sgx_ecdsa_sign(srv_cli_mac, 16, &g_privkey, &srv_cli_sig, ecc_handle);
  191. sgx_ecc256_close_context(ecc_handle);
  192. // Save the state we'll need to process handshake message 3
  193. memmove(&nodest.in_aes_key, h2, 16);
  194. memmove(&nodest.out_aes_key, ((uint8_t*)h2)+16, 16);
  195. memmove(&nodest.handshake_cli_srv_mac, cli_srv_mac, 16);
  196. // Get us ready to receive handshake message 3
  197. nodest.in_msg_get_buf = default_in_msg_get_buf;
  198. nodest.in_msg_received = handshake_3_msg_received;
  199. nodest.handshake_step = HANDSHAKE_S_SENT_2;
  200. // Send handshake message 2
  201. nodest.message_start(sizeof(our_dh_pubkey) + sizeof(srv_cli_sig),
  202. false);
  203. nodest.message_data((const uint8_t*)&our_dh_pubkey, sizeof(our_dh_pubkey),
  204. false);
  205. nodest.message_data((const uint8_t*)&srv_cli_sig, sizeof(srv_cli_sig),
  206. false);
  207. }
  208. // Receive (at the client) the second handshake message
  209. static void handshake_2_msg_received(NodeCommState &nodest,
  210. uint8_t *data, uint32_t plaintext_len, uint32_t)
  211. {
  212. /*
  213. printf("Received handshake_2 message of %u bytes:\n", plaintext_len);
  214. for (uint32_t i=0;i<plaintext_len;++i) {
  215. printf("%02x", data[i]);
  216. }
  217. printf("\n");
  218. */
  219. if (plaintext_len != sizeof(sgx_ec256_public_t) +
  220. sizeof(sgx_ec256_signature_t)) {
  221. printf("Received handshake_2 message of incorrect size %u\n",
  222. plaintext_len);
  223. return;
  224. }
  225. sgx_ecc_state_handle_t ecc_handle;
  226. sgx_ec256_public_t peer_dh_pubkey;
  227. sgx_ec256_signature_t peer_sig;
  228. memmove(&peer_dh_pubkey, data, sizeof(peer_dh_pubkey));
  229. memmove(&peer_sig, data+sizeof(peer_dh_pubkey), sizeof(peer_sig));
  230. delete[] data;
  231. sgx_ecc256_open_context(&ecc_handle);
  232. int valid;
  233. if (sgx_ecc256_check_point(&peer_dh_pubkey, ecc_handle, &valid)
  234. || !valid) {
  235. printf("Invalid public key received from node %hu\n",
  236. nodest.node_num);
  237. sgx_ecc256_close_context(ecc_handle);
  238. return;
  239. }
  240. // Construct the shared secret
  241. sgx_ec256_dh_shared_t sharedsecret;
  242. sgx_ecc256_compute_shared_dhkey(&nodest.handshake_dh_privkey,
  243. &peer_dh_pubkey, &sharedsecret, ecc_handle);
  244. memset(&nodest.handshake_dh_privkey, 0,
  245. sizeof(nodest.handshake_dh_privkey));
  246. // Compute H1(sharedsecret) and H2(sharedsecret)
  247. sgx_sha_state_handle_t sha_handle;
  248. sgx_sha256_hash_t h1, h2;
  249. sgx_sha256_init(&sha_handle);
  250. sgx_sha256_update((const uint8_t*)"\x01", 1, sha_handle);
  251. sgx_sha256_update((uint8_t*)&sharedsecret, sizeof(sharedsecret),
  252. sha_handle);
  253. sgx_sha256_get_hash(sha_handle, &h1);
  254. sgx_sha256_close(sha_handle);
  255. sgx_sha256_init(&sha_handle);
  256. sgx_sha256_update((const uint8_t*)"\x02", 1, sha_handle);
  257. sgx_sha256_update((uint8_t*)&sharedsecret, sizeof(sharedsecret),
  258. sha_handle);
  259. sgx_sha256_get_hash(sha_handle, &h2);
  260. sgx_sha256_close(sha_handle);
  261. // Compute the server-to-client MAC
  262. sgx_hmac_state_handle_t hmac_handle;
  263. uint8_t srv_cli_mac[16];
  264. sgx_hmac256_init(h1, 16, &hmac_handle);
  265. sgx_hmac256_update((uint8_t*)&peer_dh_pubkey, sizeof(peer_dh_pubkey),
  266. hmac_handle);
  267. sgx_hmac256_update((uint8_t*)&nodest.handshake_dh_pubkey,
  268. sizeof(nodest.handshake_dh_pubkey), hmac_handle);
  269. sgx_hmac256_update((uint8_t*)&nodest.pubkey, sizeof(nodest.pubkey),
  270. hmac_handle);
  271. sgx_hmac256_update((uint8_t*)&g_pubkey, sizeof(g_pubkey),
  272. hmac_handle);
  273. sgx_hmac256_final(srv_cli_mac, 16, hmac_handle);
  274. sgx_hmac256_close(hmac_handle);
  275. // Compute the client-to-server MAC
  276. uint8_t cli_srv_mac[16];
  277. sgx_hmac256_init(((uint8_t*)h1)+16, 16, &hmac_handle);
  278. sgx_hmac256_update((uint8_t*)&nodest.handshake_dh_pubkey,
  279. sizeof(nodest.handshake_dh_pubkey), hmac_handle);
  280. sgx_hmac256_update((uint8_t*)&peer_dh_pubkey, sizeof(peer_dh_pubkey),
  281. hmac_handle);
  282. sgx_hmac256_update((uint8_t*)&g_pubkey, sizeof(g_pubkey),
  283. hmac_handle);
  284. sgx_hmac256_update((uint8_t*)&nodest.pubkey, sizeof(nodest.pubkey),
  285. hmac_handle);
  286. sgx_hmac256_final(cli_srv_mac, 16, hmac_handle);
  287. sgx_hmac256_close(hmac_handle);
  288. // Verify the signature on the server-to-client MAC
  289. uint8_t result;
  290. if (sgx_ecdsa_verify(srv_cli_mac, 16, &nodest.pubkey, &peer_sig,
  291. &result, ecc_handle) || result != SGX_EC_VALID) {
  292. printf("Invalid signature received from node %hu\n",
  293. nodest.node_num);
  294. sgx_ecc256_close_context(ecc_handle);
  295. return;
  296. }
  297. printf("Valid signature received from node %hu\n", nodest.node_num);
  298. // Sign the client-to-server MAC
  299. sgx_ec256_signature_t cli_srv_sig;
  300. sgx_ecdsa_sign(cli_srv_mac, 16, &g_privkey, &cli_srv_sig, ecc_handle);
  301. sgx_ecc256_close_context(ecc_handle);
  302. // Our side of the handshake is complete
  303. memmove(&nodest.out_aes_key, h2, 16);
  304. memmove(&nodest.in_aes_key, ((uint8_t*)h2)+16, 16);
  305. memset(&nodest.out_aes_iv, 0, SGX_AESGCM_IV_SIZE);
  306. memset(&nodest.in_aes_iv, 0, SGX_AESGCM_IV_SIZE);
  307. nodest.handshake_step = HANDSHAKE_COMPLETE;
  308. nodest.in_msg_get_buf = default_in_msg_get_buf;
  309. nodest.in_msg_received = unknown_in_msg_received;
  310. // Send handshake message 3
  311. nodest.message_start(sizeof(cli_srv_sig), false);
  312. nodest.message_data((const uint8_t*)&cli_srv_sig, sizeof(cli_srv_sig),
  313. false);
  314. // Set the received message handler for routing
  315. route_init_msg_handler(nodest.node_num);
  316. // Mark the handshake as complete
  317. completed_handshake_counter.inc();
  318. }
  319. static void handshake_3_msg_received(NodeCommState &nodest,
  320. uint8_t *data, uint32_t plaintext_len, uint32_t)
  321. {
  322. /*
  323. printf("Received handshake_3 message of %u bytes:\n", plaintext_len);
  324. for (uint32_t i=0;i<plaintext_len;++i) {
  325. printf("%02x", data[i]);
  326. }
  327. printf("\n");
  328. */
  329. if (plaintext_len != sizeof(sgx_ec256_signature_t)) {
  330. printf("Received handshake_3 message of incorrect size %u\n",
  331. plaintext_len);
  332. return;
  333. }
  334. sgx_ecc_state_handle_t ecc_handle;
  335. sgx_ec256_signature_t peer_sig;
  336. memmove(&peer_sig, data, sizeof(peer_sig));
  337. delete[] data;
  338. sgx_ecc256_open_context(&ecc_handle);
  339. // Verify the signature on the client-to-server MAC
  340. uint8_t result;
  341. if (sgx_ecdsa_verify(nodest.handshake_cli_srv_mac, 16,
  342. &nodest.pubkey, &peer_sig, &result, ecc_handle)
  343. || result != SGX_EC_VALID) {
  344. printf("Invalid signature received from node %hu\n",
  345. nodest.node_num);
  346. sgx_ecc256_close_context(ecc_handle);
  347. return;
  348. }
  349. printf("Valid signature received from node %hu\n", nodest.node_num);
  350. // Our side of the handshake is complete
  351. memset(&nodest.out_aes_iv, 0, SGX_AESGCM_IV_SIZE);
  352. memset(&nodest.in_aes_iv, 0, SGX_AESGCM_IV_SIZE);
  353. nodest.handshake_step = HANDSHAKE_COMPLETE;
  354. nodest.in_msg_get_buf = default_in_msg_get_buf;
  355. nodest.in_msg_received = unknown_in_msg_received;
  356. // Set the received message handler for routing
  357. route_init_msg_handler(nodest.node_num);
  358. // Mark the handshake as complete
  359. completed_handshake_counter.inc();
  360. }
  361. // Start a new outgoing message. Pass the number of _plaintext_ bytes
  362. // the message will be.
  363. void NodeCommState::message_start(uint32_t plaintext_len, bool encrypt)
  364. {
  365. uint32_t ciphertext_len = plaintext_len;
  366. // If the handshake is complete, add SGX_AESGCM_MAC_SIZE bytes for
  367. // every FRAME_SIZE-SGX_AESGCM_MAC_SIZE bytes of plaintext.
  368. if (encrypt) {
  369. uint32_t num_chunks = (plaintext_len +
  370. FRAME_SIZE - SGX_AESGCM_MAC_SIZE - 1) /
  371. (FRAME_SIZE - SGX_AESGCM_MAC_SIZE);
  372. ciphertext_len = plaintext_len +
  373. num_chunks * SGX_AESGCM_MAC_SIZE;
  374. }
  375. ocall_message(&frame, node_num, ciphertext_len);
  376. frame_offset = 0;
  377. msg_size = ciphertext_len;
  378. msg_frame_offset = 0;
  379. msg_plaintext_size = plaintext_len;
  380. msg_plaintext_processed = 0;
  381. if (plaintext_len < FRAME_SIZE - SGX_AESGCM_MAC_SIZE) {
  382. msg_plaintext_chunk_remain = plaintext_len;
  383. } else {
  384. msg_plaintext_chunk_remain = FRAME_SIZE - SGX_AESGCM_MAC_SIZE;
  385. }
  386. if (!frame) {
  387. printf("Received NULL back from ocall_message\n");
  388. }
  389. if (msg_plaintext_chunk_remain > 0) {
  390. if (encrypt) {
  391. *(size_t*)out_aes_iv += 1;
  392. sgx_aes_gcm128_enc_init(out_aes_key, out_aes_iv,
  393. SGX_AESGCM_IV_SIZE, NULL, 0, &out_aes_gcm_state);
  394. }
  395. }
  396. }
  397. // Process len bytes of plaintext data into the current message.
  398. void NodeCommState::message_data(const uint8_t *data, uint32_t len, bool encrypt)
  399. {
  400. while (len > 0) {
  401. if (msg_plaintext_chunk_remain == 0) {
  402. printf("Attempt to queue too much message data\n");
  403. return;
  404. }
  405. uint32_t bytes_to_process = len;
  406. if (bytes_to_process > msg_plaintext_chunk_remain) {
  407. bytes_to_process = msg_plaintext_chunk_remain;
  408. }
  409. if (frame == NULL) {
  410. printf("frame is NULL when queueing message data\n");
  411. return;
  412. }
  413. if (encrypt) {
  414. // Encrypt the data
  415. sgx_aes_gcm128_enc_update(const_cast<uint8_t*>(data),
  416. bytes_to_process, frame+frame_offset, out_aes_gcm_state);
  417. } else {
  418. // Just copy the plaintext data during the handshake
  419. memmove(frame+frame_offset, data, bytes_to_process);
  420. }
  421. frame_offset += bytes_to_process;
  422. msg_plaintext_processed += bytes_to_process;
  423. msg_plaintext_chunk_remain -= bytes_to_process;
  424. len -= bytes_to_process;
  425. data += bytes_to_process;
  426. if (msg_plaintext_chunk_remain == 0) {
  427. // Complete and send this chunk
  428. if (encrypt) {
  429. sgx_aes_gcm128_enc_get_mac(frame+frame_offset,
  430. out_aes_gcm_state);
  431. frame_offset += SGX_AESGCM_MAC_SIZE;
  432. }
  433. uint8_t *nextframe = NULL;
  434. ocall_chunk(&nextframe, node_num, frame, frame_offset);
  435. frame = nextframe;
  436. msg_frame_offset += frame_offset;
  437. frame_offset = 0;
  438. msg_plaintext_chunk_remain =
  439. msg_plaintext_size - msg_plaintext_processed;
  440. if (msg_plaintext_chunk_remain >
  441. FRAME_SIZE - SGX_AESGCM_MAC_SIZE) {
  442. msg_plaintext_chunk_remain =
  443. FRAME_SIZE - SGX_AESGCM_MAC_SIZE;
  444. }
  445. if (encrypt) {
  446. sgx_aes_gcm_close(out_aes_gcm_state);
  447. if (msg_plaintext_chunk_remain > 0) {
  448. *(size_t*)out_aes_iv += 1;
  449. sgx_aes_gcm128_enc_init(out_aes_key, out_aes_iv,
  450. SGX_AESGCM_IV_SIZE, NULL, 0, &out_aes_gcm_state);
  451. }
  452. }
  453. }
  454. }
  455. }
  456. // Generate a new identity signature key. Output the public key and the
  457. // sealed private key. outsealedpriv must point to SEALEDPRIVKEY_SIZE =
  458. // sizeof(sgx_sealed_data_t) + sizeof(sgx_ec256_private_t) + 18 bytes of
  459. // memory.
  460. void ecall_identity_key_new(sgx_ec256_public_t *outpub,
  461. sgx_sealed_data_t *outsealedpriv)
  462. {
  463. sgx_ecc_state_handle_t ecc_handle;
  464. sgx_ecc256_open_context(&ecc_handle);
  465. sgx_ecc256_create_key_pair(&g_privkey, &g_pubkey, ecc_handle);
  466. memmove(outpub, &g_pubkey, sizeof(g_pubkey));
  467. sgx_ecc256_close_context(ecc_handle);
  468. sgx_seal_data(18, (const uint8_t*)"TEEMS Identity key",
  469. sizeof(g_privkey), (const uint8_t*)&g_privkey,
  470. SEALED_PRIVKEY_SIZE, outsealedpriv);
  471. }
  472. // Load an identity key from a sealed privkey. Output the resulting
  473. // public key. insealedpriv must point to sizeof(sgx_sealed_data_t) +
  474. // sizeof(sgx_ec256_private_t) bytes of memory. Returns true for
  475. // success, false for failure.
  476. bool ecall_identity_key_load(sgx_ec256_public_t *outpub,
  477. const sgx_sealed_data_t *insealedpriv)
  478. {
  479. sgx_ecc_state_handle_t ecc_handle;
  480. char aad[18];
  481. uint32_t aadsize = sizeof(aad);
  482. sgx_ec256_private_t privkey;
  483. uint32_t privkeysize = sizeof(privkey);
  484. sgx_status_t res = sgx_unseal_data(
  485. insealedpriv, (uint8_t*)aad, &aadsize,
  486. (uint8_t*)&privkey, &privkeysize);
  487. if (res || aadsize != sizeof(aad) || privkeysize != sizeof(privkey)
  488. || memcmp(aad, "TEEMS Identity key", sizeof(aad))) {
  489. return false;
  490. }
  491. sgx_ecc256_open_context(&ecc_handle);
  492. sgx_ec256_public_t pubkey;
  493. int valid;
  494. if (sgx_ecc256_calculate_pub_from_priv(&privkey, &pubkey) ||
  495. sgx_ecc256_check_point(&pubkey, ecc_handle, &valid) ||
  496. !valid) {
  497. sgx_ecc256_close_context(ecc_handle);
  498. return false;
  499. }
  500. sgx_ecc256_close_context(ecc_handle);
  501. memmove(&g_pubkey, &pubkey, sizeof(pubkey));
  502. memmove(&g_privkey, &privkey, sizeof(privkey));
  503. memmove(outpub, &pubkey, sizeof(pubkey));
  504. return true;
  505. }
  506. bool comms_init_nodestate(const EnclaveAPINodeConfig *apinodeconfigs,
  507. nodenum_t num_nodes, nodenum_t me)
  508. {
  509. sgx_ecc_state_handle_t ecc_handle;
  510. sgx_ecc256_open_context(&ecc_handle);
  511. g_commstates.clear();
  512. tot_nodes = 0;
  513. g_commstates.reserve(num_nodes);
  514. for (nodenum_t i=0; i<num_nodes; ++i) {
  515. // Check that the pubkey is valid
  516. int valid;
  517. if (sgx_ecc256_check_point(&apinodeconfigs[i].pubkey,
  518. ecc_handle, &valid) ||
  519. !valid) {
  520. printf("Pubkey for node %hu invalid\n", i);
  521. g_commstates.clear();
  522. sgx_ecc256_close_context(ecc_handle);
  523. return false;
  524. }
  525. g_commstates.emplace_back(&apinodeconfigs[i].pubkey, i);
  526. }
  527. sgx_ecc256_close_context(ecc_handle);
  528. my_node_num = me;
  529. // Check that no one other than us has our pubkey (deals with
  530. // reflection attacks)
  531. for (nodenum_t i=0; i<num_nodes; ++i) {
  532. if (i == my_node_num) continue;
  533. if (!memcmp(&g_commstates[i].pubkey,
  534. &g_commstates[my_node_num].pubkey,
  535. sizeof(g_commstates[i].pubkey))) {
  536. printf("Pubkey %hu matches our own; possible reflection attack?\n",
  537. i);
  538. g_commstates.clear();
  539. return false;
  540. }
  541. }
  542. tot_nodes = num_nodes;
  543. // There will be an enclave-to-enclave channel between us and each
  544. // other node's enclave. For the node numbers smaller than ours, we
  545. // will be the server for the handshake for that channel. Prepare
  546. // to receive the first handshake message from those nodes'
  547. // enclaves.
  548. for (nodenum_t i=0; i<my_node_num; ++i) {
  549. g_commstates[i].in_msg_get_buf = default_in_msg_get_buf;
  550. g_commstates[i].in_msg_received = handshake_1_msg_received;
  551. }
  552. return true;
  553. }
  554. bool ecall_message(nodenum_t node_num, uint32_t message_len)
  555. {
  556. if (node_num >= tot_nodes) {
  557. printf("Out-of-range node_num %hu received in ecall_message\n",
  558. node_num);
  559. return false;
  560. }
  561. NodeCommState &nodest = g_commstates[node_num];
  562. if (nodest.in_msg_size != nodest.in_msg_offset) {
  563. printf("Received ecall_message without completing previous message\n");
  564. return false;
  565. }
  566. if (!nodest.in_msg_get_buf) {
  567. printf("No message header handler registered\n");
  568. return false;
  569. }
  570. uint8_t *buf = nodest.in_msg_get_buf(nodest, message_len);
  571. if (!buf) {
  572. printf("Message header handler returned NULL\n");
  573. return false;
  574. }
  575. nodest.in_msg_size = message_len;
  576. nodest.in_msg_offset = 0;
  577. nodest.in_msg_plaintext_processed = 0;
  578. nodest.in_msg_buf = buf;
  579. // Just in case message_len == 0
  580. if (nodest.in_msg_offset == nodest.in_msg_size) {
  581. // This was the last chunk; handle the received message
  582. uint32_t plaintext_processed = nodest.in_msg_plaintext_processed;
  583. uint32_t msg_size = nodest.in_msg_size;
  584. nodest.in_msg_buf = NULL;
  585. nodest.in_msg_size = 0;
  586. nodest.in_msg_offset = 0;
  587. nodest.in_msg_plaintext_processed = 0;
  588. nodest.in_msg_received(nodest, buf, plaintext_processed, msg_size);
  589. }
  590. return true;
  591. }
  592. bool ecall_chunk(nodenum_t node_num, const uint8_t *chunkdata,
  593. uint32_t chunklen)
  594. {
  595. if (node_num >= tot_nodes) {
  596. printf("Out-of-range node_num %hu received in ecall_chunk\n",
  597. node_num);
  598. return false;
  599. }
  600. NodeCommState &nodest = g_commstates[node_num];
  601. if (nodest.in_msg_size == nodest.in_msg_offset) {
  602. printf("Received ecall_chunk after completing message\n");
  603. return false;
  604. }
  605. if (!nodest.in_msg_buf) {
  606. printf("No incoming message buffer allocated\n");
  607. return false;
  608. }
  609. if (!nodest.in_msg_received) {
  610. printf("No message received handler registered\n");
  611. return false;
  612. }
  613. if (nodest.in_msg_offset + chunklen > nodest.in_msg_size) {
  614. printf("Chunk larger than remaining message size\n");
  615. return false;
  616. }
  617. if (nodest.handshake_step == HANDSHAKE_COMPLETE) {
  618. // Decrypt the incoming data
  619. *(size_t*)(nodest.in_aes_iv) += 1;
  620. if (sgx_rijndael128GCM_decrypt(&nodest.in_aes_key, chunkdata,
  621. chunklen - SGX_AESGCM_MAC_SIZE,
  622. nodest.in_msg_buf + nodest.in_msg_plaintext_processed,
  623. nodest.in_aes_iv, SGX_AESGCM_IV_SIZE, NULL, 0,
  624. (const sgx_aes_gcm_128bit_tag_t *)
  625. (chunkdata + chunklen - SGX_AESGCM_MAC_SIZE))) {
  626. printf("Decryption failed\n");
  627. return false;
  628. }
  629. nodest.in_msg_plaintext_processed +=
  630. chunklen - SGX_AESGCM_MAC_SIZE;
  631. } else {
  632. // Just copy the handshake data
  633. memmove(nodest.in_msg_buf + nodest.in_msg_plaintext_processed,
  634. chunkdata, chunklen);
  635. nodest.in_msg_plaintext_processed += chunklen;
  636. }
  637. nodest.in_msg_offset += chunklen;
  638. if (nodest.in_msg_offset == nodest.in_msg_size) {
  639. // This was the last chunk; handle the received message
  640. uint8_t* buf = nodest.in_msg_buf;
  641. uint32_t plaintext_processed = nodest.in_msg_plaintext_processed;
  642. uint32_t msg_size = nodest.in_msg_size;
  643. nodest.in_msg_buf = NULL;
  644. nodest.in_msg_size = 0;
  645. nodest.in_msg_offset = 0;
  646. nodest.in_msg_plaintext_processed = 0;
  647. nodest.in_msg_received(nodest, buf, plaintext_processed, msg_size);
  648. }
  649. return true;
  650. }
  651. // Start the handshake (as the client)
  652. void NodeCommState::handshake_start()
  653. {
  654. sgx_ecc_state_handle_t ecc_handle;
  655. sgx_ecc256_open_context(&ecc_handle);
  656. // Create a DH keypair
  657. sgx_ecc256_create_key_pair(&handshake_dh_privkey, &handshake_dh_pubkey,
  658. ecc_handle);
  659. sgx_ecc256_close_context(ecc_handle);
  660. // Get us ready to receive handshake message 2
  661. in_msg_get_buf = default_in_msg_get_buf;
  662. in_msg_received = handshake_2_msg_received;
  663. handshake_step = HANDSHAKE_C_SENT_1;
  664. // Send the public key as the first message
  665. message_start(sizeof(handshake_dh_pubkey), false);
  666. message_data((const uint8_t*)&handshake_dh_pubkey,
  667. sizeof(handshake_dh_pubkey), false);
  668. }
  669. // Start all handshakes for which we are the client. Call
  670. // ocall_comms_ready(cbpointer) when the handshakes with all other nodes
  671. // (for which we are client or server) are complete.
  672. bool ecall_comms_start(void *cbpointer)
  673. {
  674. completed_handshake_counter.reset(cbpointer);
  675. for (nodenum_t t = my_node_num+1; t<tot_nodes; ++t) {
  676. g_commstates[t].handshake_start();
  677. }
  678. return true;
  679. }