comms.cpp 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826
  1. #include <vector>
  2. #include <functional>
  3. #include <cstring>
  4. #include <pthread.h>
  5. #include "sgx_tcrypto.h"
  6. #include "sgx_tseal.h"
  7. #include "Enclave_t.h"
  8. #include "utils.hpp"
  9. #include "config.hpp"
  10. // Our public and private identity keys
  11. static sgx_ec256_private_t g_privkey;
  12. static sgx_ec256_public_t g_pubkey;
  13. // What step of the handshake are we on?
  14. enum HandshakeStep {
  15. HANDSHAKE_NONE,
  16. HANDSHAKE_C_SENT_1,
  17. HANDSHAKE_S_SENT_2,
  18. HANDSHAKE_COMPLETE
  19. };
  20. // Communication state for a node
  21. struct NodeCommState {
  22. sgx_ec256_public_t pubkey;
  23. nodenum_t node_num;
  24. HandshakeStep handshake_step;
  25. // Our DH keypair during the handshake
  26. sgx_ec256_private_t handshake_dh_privkey;
  27. sgx_ec256_public_t handshake_dh_pubkey;
  28. // The server keeps this state between handshake messages 1 and 3
  29. uint8_t handshake_cli_srv_mac[16];
  30. // The outgoing and incoming AES keys after the handshake
  31. sgx_aes_gcm_128bit_key_t out_aes_key, in_aes_key;
  32. // The outgoing and incoming IV counters
  33. uint8_t out_aes_iv[SGX_AESGCM_IV_SIZE];
  34. uint8_t in_aes_iv[SGX_AESGCM_IV_SIZE];
  35. // The GCM state for incrementally building each outgoing chunk
  36. sgx_aes_state_handle_t out_aes_gcm_state;
  37. // The current outgoing frame and the current offset into it
  38. uint8_t *frame;
  39. uint32_t frame_offset;
  40. // The current outgoing message ciphertext size and the offset into
  41. // it of the start of the current frame
  42. uint32_t msg_size;
  43. uint32_t msg_frame_offset;
  44. // The current outgoing message plaintext size, how many plaintext
  45. // bytes we've already processed with message_data, and how many
  46. // plaintext bytes remain for the current chunk
  47. uint32_t msg_plaintext_size;
  48. uint32_t msg_plaintext_processed;
  49. uint32_t msg_plaintext_chunk_remain;
  50. // The current incoming message ciphertext size and the offset into
  51. // it of all previous chunks of this message
  52. uint32_t in_msg_size;
  53. uint32_t in_msg_offset;
  54. // The current incoming message number of plaintext bytes processed
  55. uint32_t in_msg_plaintext_processed;
  56. // The internal buffer where we're storing the (decrypted) message
  57. uint8_t *in_msg_buf;
  58. // The function to call when a new incoming message header arrives.
  59. // This function should return a pointer to enough memory to hold
  60. // the (decrypted) chunks of the message. Remember that the length
  61. // passed here is the total size of the _encrypted_ chunks. This
  62. // function should not itself modify the in_msg_size, in_msg_offset,
  63. // or in_msg_buf members. This function will usually allocate an
  64. // appropriate amount of memory and return the pointer to it, but
  65. // may do other things, like return a pointer to the middle of a
  66. // previously allocated region of memory.
  67. std::function<uint8_t*(NodeCommState&,uint32_t)> in_msg_get_buf;
  68. // The function to call after the last chunk of a message has been
  69. // received. If in_msg_get_buf allocated memory, this function
  70. // should deallocate it. in_msg_size, in_msg_offset,
  71. // in_msg_plaintext_processed, and in_msg_buf will already have been
  72. // reset when this function is called. The uint32_t that is passed
  73. // are the total size of the _decrypted_ data and the original total
  74. // size of the _encrypted_ chunks that was passed to in_msg_get_buf.
  75. std::function<void(NodeCommState&,uint8_t*,uint32_t,uint32_t)>
  76. in_msg_received;
  77. NodeCommState(const sgx_ec256_public_t* conf_pubkey, nodenum_t i) :
  78. node_num(i), handshake_step(HANDSHAKE_NONE),
  79. out_aes_gcm_state(NULL), frame(NULL),
  80. frame_offset(0), msg_size(0), msg_frame_offset(0),
  81. msg_plaintext_size(0), msg_plaintext_processed(0),
  82. msg_plaintext_chunk_remain(0),
  83. in_msg_size(0), in_msg_offset(0),
  84. in_msg_plaintext_processed(0), in_msg_buf(NULL),
  85. in_msg_get_buf(NULL), in_msg_received(NULL) {
  86. memmove(&pubkey, conf_pubkey, sizeof(pubkey));
  87. }
  88. void message_start(uint32_t plaintext_len, bool encrypt=true);
  89. void message_data(uint8_t *data, uint32_t len, bool encrypt=true);
  90. // Start the handshake (as the client)
  91. void handshake_start();
  92. };
  93. // The communication states for all the nodes. There's an entry for
  94. // ourselves in here, but it is unused.
  95. static std::vector<NodeCommState> commstates;
  96. static nodenum_t tot_nodes, my_node_num;
  97. static class CompletedHandshakeCounter {
  98. // Mutex around completed_handshakes
  99. pthread_mutex_t mutex;
  100. // The number of completed handshakes
  101. nodenum_t completed_handshakes;
  102. // The callback pointer to use when all handshakes complete
  103. void *complete_handshake_cbpointer;
  104. public:
  105. CompletedHandshakeCounter() {
  106. pthread_mutex_init(&mutex, NULL);
  107. completed_handshakes = 0;
  108. complete_handshake_cbpointer = NULL;
  109. }
  110. void reset(void *cbpointer) {
  111. pthread_mutex_lock(&mutex);
  112. completed_handshakes = 0;
  113. complete_handshake_cbpointer = cbpointer;
  114. pthread_mutex_unlock(&mutex);
  115. }
  116. void inc() {
  117. pthread_mutex_lock(&mutex);
  118. ++completed_handshakes;
  119. nodenum_t num_completed = completed_handshakes;
  120. pthread_mutex_unlock(&mutex);
  121. if (num_completed == tot_nodes - 1) {
  122. pthread_mutex_lock(&mutex);
  123. void *cbpointer = complete_handshake_cbpointer;
  124. complete_handshake_cbpointer = NULL;
  125. completed_handshakes = 0;
  126. pthread_mutex_unlock(&mutex);
  127. ocall_comms_ready(cbpointer);
  128. }
  129. }
  130. } completed_handshake_counter;
  131. // A typical default in_msg_get_buf handler. It computes the maximum
  132. // possible size of the decrypted data, allocates that much memory, and
  133. // returns a pointer to it.
  134. static uint8_t* default_in_msg_get_buf(NodeCommState &commst,
  135. uint32_t tot_enc_chunk_size)
  136. {
  137. uint32_t max_plaintext_bytes = tot_enc_chunk_size;
  138. // If the handshake is complete, chunks will be encrypted and have a
  139. // MAC tag attached which will not correspond to plaintext bytes, so
  140. // we can trim them.
  141. if (commst.handshake_step == HANDSHAKE_COMPLETE) {
  142. // The minimum number of chunks needed to transmit this message
  143. uint32_t min_num_chunks =
  144. (tot_enc_chunk_size + (FRAME_SIZE-1)) / FRAME_SIZE;
  145. // The maximum number of plaintext bytes this message could contain
  146. max_plaintext_bytes = tot_enc_chunk_size -
  147. SGX_AESGCM_MAC_SIZE * min_num_chunks;
  148. }
  149. return new uint8_t[max_plaintext_bytes];
  150. }
  151. static void default_in_msg_received(NodeCommState &nodest,
  152. uint8_t *data, uint32_t plaintext_len, uint32_t)
  153. {
  154. printf("Received message of %u bytes from node %lu:\n",
  155. plaintext_len, nodest.node_num);
  156. for (uint32_t i=0;i<plaintext_len;++i) {
  157. printf("%02x", data[i]);
  158. }
  159. printf("\n");
  160. delete[] data;
  161. }
  162. static void handshake_1_msg_received(NodeCommState &nodest,
  163. uint8_t *data, uint32_t plaintext_len, uint32_t);
  164. static void handshake_2_msg_received(NodeCommState &nodest,
  165. uint8_t *data, uint32_t plaintext_len, uint32_t);
  166. static void handshake_3_msg_received(NodeCommState &nodest,
  167. uint8_t *data, uint32_t plaintext_len, uint32_t);
  168. // Receive (at the server) the first handshake message
  169. static void handshake_1_msg_received(NodeCommState &nodest,
  170. uint8_t *data, uint32_t plaintext_len, uint32_t)
  171. {
  172. /*
  173. printf("Received handshake_1 message of %u bytes:\n", plaintext_len);
  174. for (uint32_t i=0;i<plaintext_len;++i) {
  175. printf("%02x", data[i]);
  176. }
  177. printf("\n");
  178. */
  179. if (plaintext_len != sizeof(sgx_ec256_public_t)) {
  180. printf("Received handshake_1 message of incorrect size %u\n",
  181. plaintext_len);
  182. return;
  183. }
  184. sgx_ecc_state_handle_t ecc_handle;
  185. sgx_ec256_public_t peer_dh_pubkey;
  186. memmove(&peer_dh_pubkey, data, sizeof(peer_dh_pubkey));
  187. delete[] data;
  188. sgx_ecc256_open_context(&ecc_handle);
  189. int valid;
  190. if (sgx_ecc256_check_point(&peer_dh_pubkey, ecc_handle, &valid)
  191. || !valid) {
  192. printf("Invalid public key received from node %hu\n",
  193. nodest.node_num);
  194. sgx_ecc256_close_context(ecc_handle);
  195. return;
  196. }
  197. printf("Valid public key received from node %hu\n", nodest.node_num);
  198. // Create our own DH key pair
  199. sgx_ec256_public_t our_dh_pubkey;
  200. sgx_ec256_private_t our_dh_privkey;
  201. sgx_ecc256_create_key_pair(&our_dh_privkey, &our_dh_pubkey, ecc_handle);
  202. // Construct the shared secret
  203. sgx_ec256_dh_shared_t sharedsecret;
  204. sgx_ecc256_compute_shared_dhkey(&our_dh_privkey, &peer_dh_pubkey,
  205. &sharedsecret, ecc_handle);
  206. memset(&our_dh_privkey, 0, sizeof(our_dh_privkey));
  207. // Compute H1(sharedsecret) and H2(sharedsecret)
  208. sgx_sha_state_handle_t sha_handle;
  209. sgx_sha256_hash_t h1, h2;
  210. sgx_sha256_init(&sha_handle);
  211. sgx_sha256_update((const uint8_t*)"\x01", 1, sha_handle);
  212. sgx_sha256_update((uint8_t*)&sharedsecret, sizeof(sharedsecret),
  213. sha_handle);
  214. sgx_sha256_get_hash(sha_handle, &h1);
  215. sgx_sha256_close(sha_handle);
  216. sgx_sha256_init(&sha_handle);
  217. sgx_sha256_update((const uint8_t*)"\x02", 1, sha_handle);
  218. sgx_sha256_update((uint8_t*)&sharedsecret, sizeof(sharedsecret),
  219. sha_handle);
  220. sgx_sha256_get_hash(sha_handle, &h2);
  221. sgx_sha256_close(sha_handle);
  222. // Compute the server-to-client MAC
  223. sgx_hmac_state_handle_t hmac_handle;
  224. uint8_t srv_cli_mac[16];
  225. sgx_hmac256_init(h1, 16, &hmac_handle);
  226. sgx_hmac256_update((uint8_t*)&our_dh_pubkey, sizeof(our_dh_pubkey),
  227. hmac_handle);
  228. sgx_hmac256_update((uint8_t*)&peer_dh_pubkey, sizeof(peer_dh_pubkey),
  229. hmac_handle);
  230. sgx_hmac256_update((uint8_t*)&g_pubkey, sizeof(g_pubkey),
  231. hmac_handle);
  232. sgx_hmac256_update((uint8_t*)&nodest.pubkey, sizeof(nodest.pubkey),
  233. hmac_handle);
  234. sgx_hmac256_final(srv_cli_mac, 16, hmac_handle);
  235. sgx_hmac256_close(hmac_handle);
  236. // Compute the client-to-server MAC
  237. uint8_t cli_srv_mac[16];
  238. sgx_hmac256_init(((uint8_t*)h1)+16, 16, &hmac_handle);
  239. sgx_hmac256_update((uint8_t*)&peer_dh_pubkey, sizeof(peer_dh_pubkey),
  240. hmac_handle);
  241. sgx_hmac256_update((uint8_t*)&our_dh_pubkey, sizeof(our_dh_pubkey),
  242. hmac_handle);
  243. sgx_hmac256_update((uint8_t*)&nodest.pubkey, sizeof(nodest.pubkey),
  244. hmac_handle);
  245. sgx_hmac256_update((uint8_t*)&g_pubkey, sizeof(g_pubkey),
  246. hmac_handle);
  247. sgx_hmac256_final(cli_srv_mac, 16, hmac_handle);
  248. sgx_hmac256_close(hmac_handle);
  249. // Sign the server-to-client MAC
  250. sgx_ec256_signature_t srv_cli_sig;
  251. sgx_ecdsa_sign(srv_cli_mac, 16, &g_privkey, &srv_cli_sig, ecc_handle);
  252. sgx_ecc256_close_context(ecc_handle);
  253. // Save the state we'll need to process handshake message 3
  254. memmove(&nodest.in_aes_key, h2, 16);
  255. memmove(&nodest.out_aes_key, ((uint8_t*)h2)+16, 16);
  256. memmove(&nodest.handshake_cli_srv_mac, cli_srv_mac, 16);
  257. // Get us ready to receive handshake message 3
  258. nodest.in_msg_get_buf = default_in_msg_get_buf;
  259. nodest.in_msg_received = handshake_3_msg_received;
  260. nodest.handshake_step = HANDSHAKE_S_SENT_2;
  261. // Send handshake message 2
  262. nodest.message_start(sizeof(our_dh_pubkey) + sizeof(srv_cli_sig),
  263. false);
  264. nodest.message_data((uint8_t*)&our_dh_pubkey, sizeof(our_dh_pubkey),
  265. false);
  266. nodest.message_data((uint8_t*)&srv_cli_sig, sizeof(srv_cli_sig),
  267. false);
  268. }
  269. // Receive (at the client) the secong handshake message
  270. static void handshake_2_msg_received(NodeCommState &nodest,
  271. uint8_t *data, uint32_t plaintext_len, uint32_t)
  272. {
  273. /*
  274. printf("Received handshake_2 message of %u bytes:\n", plaintext_len);
  275. for (uint32_t i=0;i<plaintext_len;++i) {
  276. printf("%02x", data[i]);
  277. }
  278. printf("\n");
  279. */
  280. if (plaintext_len != sizeof(sgx_ec256_public_t) +
  281. sizeof(sgx_ec256_signature_t)) {
  282. printf("Received handshake_2 message of incorrect size %u\n",
  283. plaintext_len);
  284. return;
  285. }
  286. sgx_ecc_state_handle_t ecc_handle;
  287. sgx_ec256_public_t peer_dh_pubkey;
  288. sgx_ec256_signature_t peer_sig;
  289. memmove(&peer_dh_pubkey, data, sizeof(peer_dh_pubkey));
  290. memmove(&peer_sig, data+sizeof(peer_dh_pubkey), sizeof(peer_sig));
  291. delete[] data;
  292. sgx_ecc256_open_context(&ecc_handle);
  293. int valid;
  294. if (sgx_ecc256_check_point(&peer_dh_pubkey, ecc_handle, &valid)
  295. || !valid) {
  296. printf("Invalid public key received from node %hu\n",
  297. nodest.node_num);
  298. sgx_ecc256_close_context(ecc_handle);
  299. return;
  300. }
  301. // Construct the shared secret
  302. sgx_ec256_dh_shared_t sharedsecret;
  303. sgx_ecc256_compute_shared_dhkey(&nodest.handshake_dh_privkey,
  304. &peer_dh_pubkey, &sharedsecret, ecc_handle);
  305. memset(&nodest.handshake_dh_privkey, 0,
  306. sizeof(nodest.handshake_dh_privkey));
  307. // Compute H1(sharedsecret) and H2(sharedsecret)
  308. sgx_sha_state_handle_t sha_handle;
  309. sgx_sha256_hash_t h1, h2;
  310. sgx_sha256_init(&sha_handle);
  311. sgx_sha256_update((const uint8_t*)"\x01", 1, sha_handle);
  312. sgx_sha256_update((uint8_t*)&sharedsecret, sizeof(sharedsecret),
  313. sha_handle);
  314. sgx_sha256_get_hash(sha_handle, &h1);
  315. sgx_sha256_close(sha_handle);
  316. sgx_sha256_init(&sha_handle);
  317. sgx_sha256_update((const uint8_t*)"\x02", 1, sha_handle);
  318. sgx_sha256_update((uint8_t*)&sharedsecret, sizeof(sharedsecret),
  319. sha_handle);
  320. sgx_sha256_get_hash(sha_handle, &h2);
  321. sgx_sha256_close(sha_handle);
  322. // Compute the server-to-client MAC
  323. sgx_hmac_state_handle_t hmac_handle;
  324. uint8_t srv_cli_mac[16];
  325. sgx_hmac256_init(h1, 16, &hmac_handle);
  326. sgx_hmac256_update((uint8_t*)&peer_dh_pubkey, sizeof(peer_dh_pubkey),
  327. hmac_handle);
  328. sgx_hmac256_update((uint8_t*)&nodest.handshake_dh_pubkey,
  329. sizeof(nodest.handshake_dh_pubkey), hmac_handle);
  330. sgx_hmac256_update((uint8_t*)&nodest.pubkey, sizeof(nodest.pubkey),
  331. hmac_handle);
  332. sgx_hmac256_update((uint8_t*)&g_pubkey, sizeof(g_pubkey),
  333. hmac_handle);
  334. sgx_hmac256_final(srv_cli_mac, 16, hmac_handle);
  335. sgx_hmac256_close(hmac_handle);
  336. // Compute the client-to-server MAC
  337. uint8_t cli_srv_mac[16];
  338. sgx_hmac256_init(((uint8_t*)h1)+16, 16, &hmac_handle);
  339. sgx_hmac256_update((uint8_t*)&nodest.handshake_dh_pubkey,
  340. sizeof(nodest.handshake_dh_pubkey), hmac_handle);
  341. sgx_hmac256_update((uint8_t*)&peer_dh_pubkey, sizeof(peer_dh_pubkey),
  342. hmac_handle);
  343. sgx_hmac256_update((uint8_t*)&g_pubkey, sizeof(g_pubkey),
  344. hmac_handle);
  345. sgx_hmac256_update((uint8_t*)&nodest.pubkey, sizeof(nodest.pubkey),
  346. hmac_handle);
  347. sgx_hmac256_final(cli_srv_mac, 16, hmac_handle);
  348. sgx_hmac256_close(hmac_handle);
  349. // Verify the signature on the server-to-client MAC
  350. uint8_t result;
  351. if (sgx_ecdsa_verify(srv_cli_mac, 16, &nodest.pubkey, &peer_sig,
  352. &result, ecc_handle) || result != SGX_EC_VALID) {
  353. printf("Invalid signature received from node %hu\n",
  354. nodest.node_num);
  355. sgx_ecc256_close_context(ecc_handle);
  356. return;
  357. }
  358. printf("Valid signature received from node %hu\n", nodest.node_num);
  359. // Sign the client-to-server MAC
  360. sgx_ec256_signature_t cli_srv_sig;
  361. sgx_ecdsa_sign(cli_srv_mac, 16, &g_privkey, &cli_srv_sig, ecc_handle);
  362. sgx_ecc256_close_context(ecc_handle);
  363. // Our side of the handshake is complete
  364. memmove(&nodest.out_aes_key, h2, 16);
  365. memmove(&nodest.in_aes_key, ((uint8_t*)h2)+16, 16);
  366. memset(&nodest.out_aes_iv, 0, SGX_AESGCM_IV_SIZE);
  367. memset(&nodest.in_aes_iv, 0, SGX_AESGCM_IV_SIZE);
  368. nodest.handshake_step = HANDSHAKE_COMPLETE;
  369. nodest.in_msg_get_buf = default_in_msg_get_buf;
  370. nodest.in_msg_received = default_in_msg_received;
  371. // Send handshake message 3
  372. nodest.message_start(sizeof(cli_srv_sig), false);
  373. nodest.message_data((uint8_t*)&cli_srv_sig, sizeof(cli_srv_sig),
  374. false);
  375. // Mark the handshake as complete
  376. completed_handshake_counter.inc();
  377. }
  378. static void handshake_3_msg_received(NodeCommState &nodest,
  379. uint8_t *data, uint32_t plaintext_len, uint32_t)
  380. {
  381. /*
  382. printf("Received handshake_3 message of %u bytes:\n", plaintext_len);
  383. for (uint32_t i=0;i<plaintext_len;++i) {
  384. printf("%02x", data[i]);
  385. }
  386. printf("\n");
  387. */
  388. if (plaintext_len != sizeof(sgx_ec256_signature_t)) {
  389. printf("Received handshake_3 message of incorrect size %u\n",
  390. plaintext_len);
  391. return;
  392. }
  393. sgx_ecc_state_handle_t ecc_handle;
  394. sgx_ec256_signature_t peer_sig;
  395. memmove(&peer_sig, data, sizeof(peer_sig));
  396. delete[] data;
  397. sgx_ecc256_open_context(&ecc_handle);
  398. // Verify the signature on the client-to-server MAC
  399. uint8_t result;
  400. if (sgx_ecdsa_verify(nodest.handshake_cli_srv_mac, 16,
  401. &nodest.pubkey, &peer_sig, &result, ecc_handle)
  402. || result != SGX_EC_VALID) {
  403. printf("Invalid signature received from node %hu\n",
  404. nodest.node_num);
  405. sgx_ecc256_close_context(ecc_handle);
  406. return;
  407. }
  408. printf("Valid signature received from node %hu\n", nodest.node_num);
  409. // Our side of the handshake is complete
  410. memset(&nodest.out_aes_iv, 0, SGX_AESGCM_IV_SIZE);
  411. memset(&nodest.in_aes_iv, 0, SGX_AESGCM_IV_SIZE);
  412. nodest.handshake_step = HANDSHAKE_COMPLETE;
  413. nodest.in_msg_get_buf = default_in_msg_get_buf;
  414. nodest.in_msg_received = default_in_msg_received;
  415. // Mark the handshake as complete
  416. completed_handshake_counter.inc();
  417. }
  418. // Start a new outgoing message. Pass the number of _plaintext_ bytes
  419. // the message will be.
  420. void NodeCommState::message_start(uint32_t plaintext_len, bool encrypt)
  421. {
  422. uint32_t ciphertext_len = plaintext_len;
  423. // If the handshake is complete, add SGX_AESGCM_MAC_SIZE bytes for
  424. // every FRAME_SIZE-SGX_AESGCM_MAC_SIZE bytes of plaintext.
  425. if (encrypt) {
  426. uint32_t num_chunks = (plaintext_len +
  427. FRAME_SIZE - SGX_AESGCM_MAC_SIZE - 1) /
  428. (FRAME_SIZE - SGX_AESGCM_MAC_SIZE);
  429. ciphertext_len = plaintext_len +
  430. num_chunks * SGX_AESGCM_MAC_SIZE;
  431. }
  432. ocall_message(&frame, node_num, ciphertext_len);
  433. frame_offset = 0;
  434. msg_size = ciphertext_len;
  435. msg_frame_offset = 0;
  436. msg_plaintext_size = plaintext_len;
  437. msg_plaintext_processed = 0;
  438. if (plaintext_len < FRAME_SIZE - SGX_AESGCM_MAC_SIZE) {
  439. msg_plaintext_chunk_remain = plaintext_len;
  440. } else {
  441. msg_plaintext_chunk_remain = FRAME_SIZE - SGX_AESGCM_MAC_SIZE;
  442. }
  443. if (!frame) {
  444. printf("Received NULL back from ocall_message\n");
  445. }
  446. if (msg_plaintext_chunk_remain > 0) {
  447. if (encrypt) {
  448. *(size_t*)out_aes_iv += 1;
  449. sgx_aes_gcm128_enc_init(out_aes_key, out_aes_iv,
  450. SGX_AESGCM_IV_SIZE, NULL, 0, &out_aes_gcm_state);
  451. }
  452. }
  453. }
  454. // Process len bytes of plaintext data into the current message.
  455. void NodeCommState::message_data(uint8_t *data, uint32_t len, bool encrypt)
  456. {
  457. while (len > 0) {
  458. if (msg_plaintext_chunk_remain == 0) {
  459. printf("Attempt to queue too much message data\n");
  460. return;
  461. }
  462. uint32_t bytes_to_process = len;
  463. if (bytes_to_process > msg_plaintext_chunk_remain) {
  464. bytes_to_process = msg_plaintext_chunk_remain;
  465. }
  466. if (frame == NULL) {
  467. printf("frame is NULL when queueing message data\n");
  468. return;
  469. }
  470. if (encrypt) {
  471. // Encrypt the data
  472. sgx_aes_gcm128_enc_update(data, bytes_to_process,
  473. frame+frame_offset, out_aes_gcm_state);
  474. } else {
  475. // Just copy the plaintext data during the handshake
  476. memmove(frame+frame_offset, data, bytes_to_process);
  477. }
  478. frame_offset += bytes_to_process;
  479. msg_plaintext_processed += bytes_to_process;
  480. msg_plaintext_chunk_remain -= bytes_to_process;
  481. len -= bytes_to_process;
  482. data += bytes_to_process;
  483. if (msg_plaintext_chunk_remain == 0) {
  484. // Complete and send this chunk
  485. if (encrypt) {
  486. sgx_aes_gcm128_enc_get_mac(frame+frame_offset,
  487. out_aes_gcm_state);
  488. frame_offset += SGX_AESGCM_MAC_SIZE;
  489. }
  490. uint8_t *nextframe = NULL;
  491. ocall_chunk(&nextframe, node_num, frame, frame_offset);
  492. frame = nextframe;
  493. msg_frame_offset += frame_offset;
  494. frame_offset = 0;
  495. msg_plaintext_chunk_remain =
  496. msg_plaintext_size - msg_plaintext_processed;
  497. if (msg_plaintext_chunk_remain >
  498. FRAME_SIZE - SGX_AESGCM_MAC_SIZE) {
  499. msg_plaintext_chunk_remain =
  500. FRAME_SIZE - SGX_AESGCM_MAC_SIZE;
  501. }
  502. if (encrypt) {
  503. sgx_aes_gcm_close(out_aes_gcm_state);
  504. if (msg_plaintext_chunk_remain > 0) {
  505. *(size_t*)out_aes_iv += 1;
  506. sgx_aes_gcm128_enc_init(out_aes_key, out_aes_iv,
  507. SGX_AESGCM_IV_SIZE, NULL, 0, &out_aes_gcm_state);
  508. }
  509. }
  510. }
  511. }
  512. }
  513. // Generate a new identity signature key. Output the public key and the
  514. // sealed private key. outsealedpriv must point to SEALEDPRIVKEY_SIZE =
  515. // sizeof(sgx_sealed_data_t) + sizeof(sgx_ec256_private_t) + 18 bytes of
  516. // memory.
  517. void ecall_identity_key_new(sgx_ec256_public_t *outpub,
  518. sgx_sealed_data_t *outsealedpriv)
  519. {
  520. sgx_ecc_state_handle_t ecc_handle;
  521. sgx_ecc256_open_context(&ecc_handle);
  522. sgx_ecc256_create_key_pair(&g_privkey, &g_pubkey, ecc_handle);
  523. memmove(outpub, &g_pubkey, sizeof(g_pubkey));
  524. sgx_ecc256_close_context(ecc_handle);
  525. sgx_seal_data(18, (const uint8_t*)"TEEMS Identity key",
  526. sizeof(g_privkey), (const uint8_t*)&g_privkey,
  527. SEALED_PRIVKEY_SIZE, outsealedpriv);
  528. }
  529. // Load an identity key from a sealed privkey. Output the resulting
  530. // public key. insealedpriv must point to sizeof(sgx_sealed_data_t) +
  531. // sizeof(sgx_ec256_private_t) bytes of memory. Returns true for
  532. // success, false for failure.
  533. bool ecall_identity_key_load(sgx_ec256_public_t *outpub,
  534. const sgx_sealed_data_t *insealedpriv)
  535. {
  536. sgx_ecc_state_handle_t ecc_handle;
  537. char aad[18];
  538. uint32_t aadsize = sizeof(aad);
  539. sgx_ec256_private_t privkey;
  540. uint32_t privkeysize = sizeof(privkey);
  541. sgx_status_t res = sgx_unseal_data(
  542. insealedpriv, (uint8_t*)aad, &aadsize,
  543. (uint8_t*)&privkey, &privkeysize);
  544. if (res || aadsize != sizeof(aad) || privkeysize != sizeof(privkey)
  545. || memcmp(aad, "TEEMS Identity key", sizeof(aad))) {
  546. return false;
  547. }
  548. sgx_ecc256_open_context(&ecc_handle);
  549. sgx_ec256_public_t pubkey;
  550. int valid;
  551. if (sgx_ecc256_calculate_pub_from_priv(&privkey, &pubkey) ||
  552. sgx_ecc256_check_point(&pubkey, ecc_handle, &valid) ||
  553. !valid) {
  554. sgx_ecc256_close_context(ecc_handle);
  555. return false;
  556. }
  557. sgx_ecc256_close_context(ecc_handle);
  558. memmove(&g_pubkey, &pubkey, sizeof(pubkey));
  559. memmove(&g_privkey, &privkey, sizeof(privkey));
  560. memmove(outpub, &pubkey, sizeof(pubkey));
  561. return true;
  562. }
  563. bool comms_init_nodestate(const EnclaveAPINodeConfig *apinodeconfigs,
  564. nodenum_t num_nodes, nodenum_t me)
  565. {
  566. sgx_ecc_state_handle_t ecc_handle;
  567. sgx_ecc256_open_context(&ecc_handle);
  568. commstates.clear();
  569. tot_nodes = 0;
  570. commstates.reserve(num_nodes);
  571. for (nodenum_t i=0; i<num_nodes; ++i) {
  572. // Check that the pubkey is valid
  573. int valid;
  574. if (sgx_ecc256_check_point(&apinodeconfigs[i].pubkey,
  575. ecc_handle, &valid) ||
  576. !valid) {
  577. printf("Pubkey for node %hu invalid\n", i);
  578. commstates.clear();
  579. sgx_ecc256_close_context(ecc_handle);
  580. return false;
  581. }
  582. commstates.emplace_back(&apinodeconfigs[i].pubkey, i);
  583. }
  584. sgx_ecc256_close_context(ecc_handle);
  585. my_node_num = me;
  586. // Check that no one other than us has our pubkey (deals with
  587. // reflection attacks)
  588. for (nodenum_t i=0; i<num_nodes; ++i) {
  589. if (i == my_node_num) continue;
  590. if (!memcmp(&commstates[i].pubkey,
  591. &commstates[my_node_num].pubkey,
  592. sizeof(commstates[i].pubkey))) {
  593. printf("Pubkey %hu matches our own; possible reflection attack?\n",
  594. i);
  595. commstates.clear();
  596. return false;
  597. }
  598. }
  599. tot_nodes = num_nodes;
  600. // There will be an enclave-to-enclave channel between us and each
  601. // other node's enclave. For the node numbers smaller than ours, we
  602. // will be the server for the handshake for that channel. Prepare
  603. // to receive the first handshake message from those nodes'
  604. // enclaves.
  605. for (nodenum_t i=0; i<my_node_num; ++i) {
  606. commstates[i].in_msg_get_buf = default_in_msg_get_buf;
  607. commstates[i].in_msg_received = handshake_1_msg_received;
  608. }
  609. return true;
  610. }
  611. bool ecall_message(nodenum_t node_num, uint32_t message_len)
  612. {
  613. if (node_num >= tot_nodes) {
  614. printf("Out-of-range node_num %hu received in ecall_message\n",
  615. node_num);
  616. return false;
  617. }
  618. NodeCommState &nodest = commstates[node_num];
  619. if (nodest.in_msg_size != nodest.in_msg_offset) {
  620. printf("Received ecall_message without completing previous message\n");
  621. return false;
  622. }
  623. if (!nodest.in_msg_get_buf) {
  624. printf("No message header handler registered\n");
  625. return false;
  626. }
  627. uint8_t *buf = nodest.in_msg_get_buf(nodest, message_len);
  628. if (!buf) {
  629. printf("Message header handler returned NULL\n");
  630. return false;
  631. }
  632. nodest.in_msg_size = message_len;
  633. nodest.in_msg_offset = 0;
  634. nodest.in_msg_plaintext_processed = 0;
  635. nodest.in_msg_buf = buf;
  636. return true;
  637. }
  638. bool ecall_chunk(nodenum_t node_num, const uint8_t *chunkdata,
  639. uint32_t chunklen)
  640. {
  641. if (node_num >= tot_nodes) {
  642. printf("Out-of-range node_num %hu received in ecall_chunk\n",
  643. node_num);
  644. return false;
  645. }
  646. NodeCommState &nodest = commstates[node_num];
  647. if (nodest.in_msg_size == nodest.in_msg_offset) {
  648. printf("Received ecall_chunk after completing message\n");
  649. return false;
  650. }
  651. if (!nodest.in_msg_buf) {
  652. printf("No incoming message buffer allocated\n");
  653. return false;
  654. }
  655. if (!nodest.in_msg_received) {
  656. printf("No message received handler registered\n");
  657. return false;
  658. }
  659. if (nodest.in_msg_offset + chunklen > nodest.in_msg_size) {
  660. printf("Chunk larger than remaining message size\n");
  661. return false;
  662. }
  663. if (nodest.handshake_step == HANDSHAKE_COMPLETE) {
  664. // Decrypt the incoming data
  665. *(size_t*)(nodest.in_aes_iv) += 1;
  666. if (sgx_rijndael128GCM_decrypt(&nodest.in_aes_key, chunkdata,
  667. chunklen - SGX_AESGCM_MAC_SIZE,
  668. nodest.in_msg_buf + nodest.in_msg_plaintext_processed,
  669. nodest.in_aes_iv, SGX_AESGCM_IV_SIZE, NULL, 0,
  670. (const sgx_aes_gcm_128bit_tag_t *)
  671. (chunkdata + chunklen - SGX_AESGCM_MAC_SIZE))) {
  672. printf("Decryption failed\n");
  673. return false;
  674. }
  675. nodest.in_msg_plaintext_processed +=
  676. chunklen - SGX_AESGCM_MAC_SIZE;
  677. } else {
  678. // Just copy the handshake data
  679. memmove(nodest.in_msg_buf + nodest.in_msg_plaintext_processed,
  680. chunkdata, chunklen);
  681. nodest.in_msg_plaintext_processed += chunklen;
  682. }
  683. nodest.in_msg_offset += chunklen;
  684. if (nodest.in_msg_offset == nodest.in_msg_size) {
  685. // This was the last chunk; handle the received message
  686. uint8_t* buf = nodest.in_msg_buf;
  687. uint32_t plaintext_processed = nodest.in_msg_plaintext_processed;
  688. uint32_t msg_size = nodest.in_msg_size;
  689. nodest.in_msg_buf = NULL;
  690. nodest.in_msg_size = 0;
  691. nodest.in_msg_offset = 0;
  692. nodest.in_msg_plaintext_processed = 0;
  693. nodest.in_msg_received(nodest, buf, plaintext_processed, msg_size);
  694. }
  695. return true;
  696. }
  697. // Start the handshake (as the client)
  698. void NodeCommState::handshake_start()
  699. {
  700. sgx_ecc_state_handle_t ecc_handle;
  701. sgx_ecc256_open_context(&ecc_handle);
  702. // Create a DH keypair
  703. sgx_ecc256_create_key_pair(&handshake_dh_privkey, &handshake_dh_pubkey,
  704. ecc_handle);
  705. sgx_ecc256_close_context(ecc_handle);
  706. // Get us ready to receive handshake message 2
  707. in_msg_get_buf = default_in_msg_get_buf;
  708. in_msg_received = handshake_2_msg_received;
  709. handshake_step = HANDSHAKE_C_SENT_1;
  710. // Send the public key as the first message
  711. message_start(sizeof(handshake_dh_pubkey), false);
  712. message_data((uint8_t*)&handshake_dh_pubkey,
  713. sizeof(handshake_dh_pubkey), false);
  714. }
  715. // Start all handshakes for which we are the client. Call
  716. // ocall_comms_ready(cbpointer) when the handshakes with all other nodes
  717. // (for which we are client or server) are complete.
  718. bool ecall_comms_start(void *cbpointer)
  719. {
  720. completed_handshake_counter.reset(cbpointer);
  721. for (nodenum_t t = my_node_num+1; t<tot_nodes; ++t) {
  722. commstates[t].handshake_start();
  723. }
  724. return true;
  725. }