#include #include #include #include #include "sgx_tcrypto.h" #include "sgx_tseal.h" #include "Enclave_t.h" #include "utils.hpp" #include "config.hpp" #include "comms.hpp" // Our public and private identity keys static sgx_ec256_private_t g_privkey; static sgx_ec256_public_t g_pubkey; // The communication states for all the nodes. There's an entry for // ourselves in here, but it is unused. std::vector commstates; static nodenum_t tot_nodes, my_node_num; static class CompletedHandshakeCounter { // Mutex around completed_handshakes pthread_mutex_t mutex; // The number of completed handshakes nodenum_t completed_handshakes; // The callback pointer to use when all handshakes complete void *complete_handshake_cbpointer; public: CompletedHandshakeCounter() { pthread_mutex_init(&mutex, NULL); completed_handshakes = 0; complete_handshake_cbpointer = NULL; } void reset(void *cbpointer) { pthread_mutex_lock(&mutex); completed_handshakes = 0; complete_handshake_cbpointer = cbpointer; pthread_mutex_unlock(&mutex); } void inc() { pthread_mutex_lock(&mutex); ++completed_handshakes; nodenum_t num_completed = completed_handshakes; pthread_mutex_unlock(&mutex); if (num_completed == tot_nodes - 1) { pthread_mutex_lock(&mutex); void *cbpointer = complete_handshake_cbpointer; complete_handshake_cbpointer = NULL; completed_handshakes = 0; pthread_mutex_unlock(&mutex); ocall_comms_ready(cbpointer); } } } completed_handshake_counter; // A typical default in_msg_get_buf handler. It computes the maximum // possible size of the decrypted data, allocates that much memory, and // returns a pointer to it. static uint8_t* default_in_msg_get_buf(NodeCommState &commst, uint32_t tot_enc_chunk_size) { uint32_t max_plaintext_bytes = tot_enc_chunk_size; // If the handshake is complete, chunks will be encrypted and have a // MAC tag attached which will not correspond to plaintext bytes, so // we can trim them. if (commst.handshake_step == HANDSHAKE_COMPLETE) { // The minimum number of chunks needed to transmit this message uint32_t min_num_chunks = (tot_enc_chunk_size + (FRAME_SIZE-1)) / FRAME_SIZE; // The maximum number of plaintext bytes this message could contain max_plaintext_bytes = tot_enc_chunk_size - SGX_AESGCM_MAC_SIZE * min_num_chunks; } return new uint8_t[max_plaintext_bytes]; } static void default_in_msg_received(NodeCommState &nodest, uint8_t *data, uint32_t plaintext_len, uint32_t) { printf("Received message of %u bytes from node %lu:\n", plaintext_len, nodest.node_num); for (uint32_t i=0;i 0) { if (encrypt) { *(size_t*)out_aes_iv += 1; sgx_aes_gcm128_enc_init(out_aes_key, out_aes_iv, SGX_AESGCM_IV_SIZE, NULL, 0, &out_aes_gcm_state); } } } // Process len bytes of plaintext data into the current message. void NodeCommState::message_data(uint8_t *data, uint32_t len, bool encrypt) { while (len > 0) { if (msg_plaintext_chunk_remain == 0) { printf("Attempt to queue too much message data\n"); return; } uint32_t bytes_to_process = len; if (bytes_to_process > msg_plaintext_chunk_remain) { bytes_to_process = msg_plaintext_chunk_remain; } if (frame == NULL) { printf("frame is NULL when queueing message data\n"); return; } if (encrypt) { // Encrypt the data sgx_aes_gcm128_enc_update(data, bytes_to_process, frame+frame_offset, out_aes_gcm_state); } else { // Just copy the plaintext data during the handshake memmove(frame+frame_offset, data, bytes_to_process); } frame_offset += bytes_to_process; msg_plaintext_processed += bytes_to_process; msg_plaintext_chunk_remain -= bytes_to_process; len -= bytes_to_process; data += bytes_to_process; if (msg_plaintext_chunk_remain == 0) { // Complete and send this chunk if (encrypt) { sgx_aes_gcm128_enc_get_mac(frame+frame_offset, out_aes_gcm_state); frame_offset += SGX_AESGCM_MAC_SIZE; } uint8_t *nextframe = NULL; ocall_chunk(&nextframe, node_num, frame, frame_offset); frame = nextframe; msg_frame_offset += frame_offset; frame_offset = 0; msg_plaintext_chunk_remain = msg_plaintext_size - msg_plaintext_processed; if (msg_plaintext_chunk_remain > FRAME_SIZE - SGX_AESGCM_MAC_SIZE) { msg_plaintext_chunk_remain = FRAME_SIZE - SGX_AESGCM_MAC_SIZE; } if (encrypt) { sgx_aes_gcm_close(out_aes_gcm_state); if (msg_plaintext_chunk_remain > 0) { *(size_t*)out_aes_iv += 1; sgx_aes_gcm128_enc_init(out_aes_key, out_aes_iv, SGX_AESGCM_IV_SIZE, NULL, 0, &out_aes_gcm_state); } } } } } // Generate a new identity signature key. Output the public key and the // sealed private key. outsealedpriv must point to SEALEDPRIVKEY_SIZE = // sizeof(sgx_sealed_data_t) + sizeof(sgx_ec256_private_t) + 18 bytes of // memory. void ecall_identity_key_new(sgx_ec256_public_t *outpub, sgx_sealed_data_t *outsealedpriv) { sgx_ecc_state_handle_t ecc_handle; sgx_ecc256_open_context(&ecc_handle); sgx_ecc256_create_key_pair(&g_privkey, &g_pubkey, ecc_handle); memmove(outpub, &g_pubkey, sizeof(g_pubkey)); sgx_ecc256_close_context(ecc_handle); sgx_seal_data(18, (const uint8_t*)"TEEMS Identity key", sizeof(g_privkey), (const uint8_t*)&g_privkey, SEALED_PRIVKEY_SIZE, outsealedpriv); } // Load an identity key from a sealed privkey. Output the resulting // public key. insealedpriv must point to sizeof(sgx_sealed_data_t) + // sizeof(sgx_ec256_private_t) bytes of memory. Returns true for // success, false for failure. bool ecall_identity_key_load(sgx_ec256_public_t *outpub, const sgx_sealed_data_t *insealedpriv) { sgx_ecc_state_handle_t ecc_handle; char aad[18]; uint32_t aadsize = sizeof(aad); sgx_ec256_private_t privkey; uint32_t privkeysize = sizeof(privkey); sgx_status_t res = sgx_unseal_data( insealedpriv, (uint8_t*)aad, &aadsize, (uint8_t*)&privkey, &privkeysize); if (res || aadsize != sizeof(aad) || privkeysize != sizeof(privkey) || memcmp(aad, "TEEMS Identity key", sizeof(aad))) { return false; } sgx_ecc256_open_context(&ecc_handle); sgx_ec256_public_t pubkey; int valid; if (sgx_ecc256_calculate_pub_from_priv(&privkey, &pubkey) || sgx_ecc256_check_point(&pubkey, ecc_handle, &valid) || !valid) { sgx_ecc256_close_context(ecc_handle); return false; } sgx_ecc256_close_context(ecc_handle); memmove(&g_pubkey, &pubkey, sizeof(pubkey)); memmove(&g_privkey, &privkey, sizeof(privkey)); memmove(outpub, &pubkey, sizeof(pubkey)); return true; } bool comms_init_nodestate(const EnclaveAPINodeConfig *apinodeconfigs, nodenum_t num_nodes, nodenum_t me) { sgx_ecc_state_handle_t ecc_handle; sgx_ecc256_open_context(&ecc_handle); commstates.clear(); tot_nodes = 0; commstates.reserve(num_nodes); for (nodenum_t i=0; i= tot_nodes) { printf("Out-of-range node_num %hu received in ecall_message\n", node_num); return false; } NodeCommState &nodest = commstates[node_num]; if (nodest.in_msg_size != nodest.in_msg_offset) { printf("Received ecall_message without completing previous message\n"); return false; } if (!nodest.in_msg_get_buf) { printf("No message header handler registered\n"); return false; } uint8_t *buf = nodest.in_msg_get_buf(nodest, message_len); if (!buf) { printf("Message header handler returned NULL\n"); return false; } nodest.in_msg_size = message_len; nodest.in_msg_offset = 0; nodest.in_msg_plaintext_processed = 0; nodest.in_msg_buf = buf; return true; } bool ecall_chunk(nodenum_t node_num, const uint8_t *chunkdata, uint32_t chunklen) { if (node_num >= tot_nodes) { printf("Out-of-range node_num %hu received in ecall_chunk\n", node_num); return false; } NodeCommState &nodest = commstates[node_num]; if (nodest.in_msg_size == nodest.in_msg_offset) { printf("Received ecall_chunk after completing message\n"); return false; } if (!nodest.in_msg_buf) { printf("No incoming message buffer allocated\n"); return false; } if (!nodest.in_msg_received) { printf("No message received handler registered\n"); return false; } if (nodest.in_msg_offset + chunklen > nodest.in_msg_size) { printf("Chunk larger than remaining message size\n"); return false; } if (nodest.handshake_step == HANDSHAKE_COMPLETE) { // Decrypt the incoming data *(size_t*)(nodest.in_aes_iv) += 1; if (sgx_rijndael128GCM_decrypt(&nodest.in_aes_key, chunkdata, chunklen - SGX_AESGCM_MAC_SIZE, nodest.in_msg_buf + nodest.in_msg_plaintext_processed, nodest.in_aes_iv, SGX_AESGCM_IV_SIZE, NULL, 0, (const sgx_aes_gcm_128bit_tag_t *) (chunkdata + chunklen - SGX_AESGCM_MAC_SIZE))) { printf("Decryption failed\n"); return false; } nodest.in_msg_plaintext_processed += chunklen - SGX_AESGCM_MAC_SIZE; } else { // Just copy the handshake data memmove(nodest.in_msg_buf + nodest.in_msg_plaintext_processed, chunkdata, chunklen); nodest.in_msg_plaintext_processed += chunklen; } nodest.in_msg_offset += chunklen; if (nodest.in_msg_offset == nodest.in_msg_size) { // This was the last chunk; handle the received message uint8_t* buf = nodest.in_msg_buf; uint32_t plaintext_processed = nodest.in_msg_plaintext_processed; uint32_t msg_size = nodest.in_msg_size; nodest.in_msg_buf = NULL; nodest.in_msg_size = 0; nodest.in_msg_offset = 0; nodest.in_msg_plaintext_processed = 0; nodest.in_msg_received(nodest, buf, plaintext_processed, msg_size); } return true; } // Start the handshake (as the client) void NodeCommState::handshake_start() { sgx_ecc_state_handle_t ecc_handle; sgx_ecc256_open_context(&ecc_handle); // Create a DH keypair sgx_ecc256_create_key_pair(&handshake_dh_privkey, &handshake_dh_pubkey, ecc_handle); sgx_ecc256_close_context(ecc_handle); // Get us ready to receive handshake message 2 in_msg_get_buf = default_in_msg_get_buf; in_msg_received = handshake_2_msg_received; handshake_step = HANDSHAKE_C_SENT_1; // Send the public key as the first message message_start(sizeof(handshake_dh_pubkey), false); message_data((uint8_t*)&handshake_dh_pubkey, sizeof(handshake_dh_pubkey), false); } // Start all handshakes for which we are the client. Call // ocall_comms_ready(cbpointer) when the handshakes with all other nodes // (for which we are client or server) are complete. bool ecall_comms_start(void *cbpointer) { completed_handshake_counter.reset(cbpointer); for (nodenum_t t = my_node_num+1; t