|
@@ -1,4 +1,5 @@
|
|
|
#include <vector>
|
|
|
+#include <functional>
|
|
|
#include <cstring>
|
|
|
|
|
|
#include "sgx_tcrypto.h"
|
|
@@ -39,30 +40,204 @@ struct NodeCommState {
|
|
|
uint8_t out_aes_iv[SGX_AESGCM_IV_SIZE];
|
|
|
uint8_t in_aes_iv[SGX_AESGCM_IV_SIZE];
|
|
|
|
|
|
+ // The GCM state for incrementally building each outgoing chunk
|
|
|
+ sgx_aes_state_handle_t out_aes_gcm_state;
|
|
|
+
|
|
|
// The current outgoing frame and the current offset into it
|
|
|
uint8_t *frame;
|
|
|
uint32_t frame_offset;
|
|
|
|
|
|
- // The current outgoing message size and the offset into it of the
|
|
|
- // start of the current frame
|
|
|
+ // The current outgoing message ciphertext size and the offset into
|
|
|
+ // it of the start of the current frame
|
|
|
uint32_t msg_size;
|
|
|
uint32_t msg_frame_offset;
|
|
|
|
|
|
- // The current incoming message size and the offset into it of all
|
|
|
- // previous chunks of this message
|
|
|
+ // The current outgoing message plaintext size, how many plaintext
|
|
|
+ // bytes we've already processed with message_data, and how many
|
|
|
+ // plaintext bytes remain for the current chunk
|
|
|
+ uint32_t msg_plaintext_size;
|
|
|
+ uint32_t msg_plaintext_processed;
|
|
|
+ uint32_t msg_plaintext_chunk_remain;
|
|
|
+
|
|
|
+ // The current incoming message ciphertext size and the offset into
|
|
|
+ // it of all previous chunks of this message
|
|
|
uint32_t in_msg_size;
|
|
|
uint32_t in_msg_offset;
|
|
|
+ // The current incoming message number of plaintext bytes processed
|
|
|
+ uint32_t in_msg_plaintext_processed;
|
|
|
// The internal buffer where we're storing the (decrypted) message
|
|
|
uint8_t *in_msg_buf;
|
|
|
|
|
|
+ // The function to call when a new incoming message header arrives.
|
|
|
+ // This function should return a pointer to enough memory to hold
|
|
|
+ // the (decrypted) chunks of the message. Remember that the length
|
|
|
+ // passed here is the total size of the _encrypted_ chunks. This
|
|
|
+ // function should not itself modify the in_msg_size, in_msg_offset,
|
|
|
+ // or in_msg_buf members. This function will usually allocate an
|
|
|
+ // appropriate amount of memory and return the pointer to it, but
|
|
|
+ // may do other things, like return a pointer to the middle of a
|
|
|
+ // previously allocated region of memory.
|
|
|
+ std::function<uint8_t*(NodeCommState&,uint32_t)> in_msg_get_buf;
|
|
|
+
|
|
|
+ // The function to call after the last chunk of a message has been
|
|
|
+ // received. If in_msg_get_buf allocated memory, this function
|
|
|
+ // should deallocate it. in_msg_size, in_msg_offset, and in_msg_buf
|
|
|
+ // will already have been reset when this function is called. The
|
|
|
+ // uint32_t that is passed are the total size of the _decrypted_
|
|
|
+ // data and the original total size of the _encrypted_ chunks that
|
|
|
+ // was passed to in_msg_get_buf.
|
|
|
+ std::function<void(NodeCommState&,uint8_t*,uint32_t,uint32_t)>
|
|
|
+ in_msg_received;
|
|
|
+
|
|
|
NodeCommState(const sgx_ec256_public_t* conf_pubkey, nodenum_t i) :
|
|
|
- node_num(i), handshake_step(HANDSHAKE_NONE), frame(NULL),
|
|
|
+ node_num(i), handshake_step(HANDSHAKE_NONE),
|
|
|
+ out_aes_gcm_state(NULL), frame(NULL),
|
|
|
frame_offset(0), msg_size(0), msg_frame_offset(0),
|
|
|
- in_msg_size(0), in_msg_offset(0), in_msg_buf(NULL) {
|
|
|
+ msg_plaintext_size(0), msg_plaintext_processed(0),
|
|
|
+ msg_plaintext_chunk_remain(0),
|
|
|
+ in_msg_size(0), in_msg_offset(0),
|
|
|
+ in_msg_plaintext_processed(0), in_msg_buf(NULL),
|
|
|
+ in_msg_get_buf(NULL), in_msg_received(NULL) {
|
|
|
memmove(&pubkey, conf_pubkey, sizeof(pubkey));
|
|
|
}
|
|
|
+
|
|
|
+ void message_start(uint32_t plaintext_len);
|
|
|
+
|
|
|
+ void message_data(uint8_t *data, uint32_t len);
|
|
|
};
|
|
|
|
|
|
+// A typical default in_msg_get_buf handler. It computes the maximum
|
|
|
+// possible size of the decrypted data, allocates that much memory, and
|
|
|
+// returns a pointer to it.
|
|
|
+static uint8_t* default_in_msg_get_buf(NodeCommState &commst,
|
|
|
+ uint32_t tot_enc_chunk_size)
|
|
|
+{
|
|
|
+ uint32_t max_plaintext_bytes = tot_enc_chunk_size;
|
|
|
+
|
|
|
+ // If the handshake is complete, chunks will be encrypted and have a
|
|
|
+ // MAC tag attached which will not correspond to plaintext bytes, so
|
|
|
+ // we can trim them.
|
|
|
+
|
|
|
+ if (commst.handshake_step == HANDSHAKE_COMPLETE) {
|
|
|
+ // The minimum number of chunks needed to transmit this message
|
|
|
+ uint32_t min_num_chunks =
|
|
|
+ (tot_enc_chunk_size + (FRAME_SIZE-1)) / FRAME_SIZE;
|
|
|
+ // The maximum number of plaintext bytes this message could contain
|
|
|
+ max_plaintext_bytes = tot_enc_chunk_size -
|
|
|
+ SGX_AESGCM_MAC_SIZE * min_num_chunks;
|
|
|
+ }
|
|
|
+
|
|
|
+ return new uint8_t[max_plaintext_bytes];
|
|
|
+}
|
|
|
+
|
|
|
+// Receive (at the server) the first handshake message
|
|
|
+static void handshake_1_msg_received(NodeCommState &nodest,
|
|
|
+ uint8_t *data, uint32_t plaintext_len, uint32_t message_len)
|
|
|
+{
|
|
|
+ printf("Received handshake_1 message of %u bytes:\n", plaintext_len);
|
|
|
+ for (uint32_t i=0;i<plaintext_len;++i) {
|
|
|
+ printf("%02x", data[i]);
|
|
|
+ }
|
|
|
+ printf("\n");
|
|
|
+ delete[] data;
|
|
|
+}
|
|
|
+
|
|
|
+// Start a new outgoing message. Pass the number of _plaintext_ bytes
|
|
|
+// the message will be.
|
|
|
+void NodeCommState::message_start(uint32_t plaintext_len)
|
|
|
+{
|
|
|
+ uint32_t ciphertext_len = plaintext_len;
|
|
|
+
|
|
|
+ // If the handshake is complete, add SGX_AESGCM_MAC_SIZE bytes for
|
|
|
+ // every FRAME_SIZE-SGX_AESGCM_MAC_SIZE bytes of plaintext.
|
|
|
+ if (handshake_step == HANDSHAKE_COMPLETE) {
|
|
|
+ uint32_t num_chunks = (plaintext_len +
|
|
|
+ FRAME_SIZE - SGX_AESGCM_MAC_SIZE - 1) /
|
|
|
+ (FRAME_SIZE - SGX_AESGCM_MAC_SIZE);
|
|
|
+ ciphertext_len = plaintext_len +
|
|
|
+ num_chunks * SGX_AESGCM_MAC_SIZE;
|
|
|
+ }
|
|
|
+ ocall_message(&frame, node_num, ciphertext_len);
|
|
|
+ frame_offset = 0;
|
|
|
+ msg_size = ciphertext_len;
|
|
|
+ msg_frame_offset = 0;
|
|
|
+ msg_plaintext_size = plaintext_len;
|
|
|
+ msg_plaintext_processed = 0;
|
|
|
+ if (plaintext_len < FRAME_SIZE - SGX_AESGCM_MAC_SIZE) {
|
|
|
+ msg_plaintext_chunk_remain = plaintext_len;
|
|
|
+ } else {
|
|
|
+ msg_plaintext_chunk_remain = FRAME_SIZE - SGX_AESGCM_MAC_SIZE;
|
|
|
+ }
|
|
|
+ if (!frame) {
|
|
|
+ printf("Received NULL back from ocall_message\n");
|
|
|
+ }
|
|
|
+ if (msg_plaintext_chunk_remain > 0) {
|
|
|
+ *(size_t*)out_aes_iv += 1;
|
|
|
+ sgx_aes_gcm128_enc_init(out_aes_key, out_aes_iv, SGX_AESGCM_IV_SIZE,
|
|
|
+ NULL, 0, &out_aes_gcm_state);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Process len bytes of plaintext data into the current message.
|
|
|
+void NodeCommState::message_data(uint8_t *data, uint32_t len)
|
|
|
+{
|
|
|
+ while (len > 0) {
|
|
|
+ if (msg_plaintext_chunk_remain == 0) {
|
|
|
+ printf("Attempt to queue too much message data\n");
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ uint32_t bytes_to_process = len;
|
|
|
+ if (bytes_to_process > msg_plaintext_chunk_remain) {
|
|
|
+ bytes_to_process = msg_plaintext_chunk_remain;
|
|
|
+ }
|
|
|
+ if (frame == NULL) {
|
|
|
+ printf("frame is NULL when queueing message data\n");
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ if (handshake_step == HANDSHAKE_COMPLETE) {
|
|
|
+ // Encrypt the data
|
|
|
+ sgx_aes_gcm128_enc_update(data, bytes_to_process,
|
|
|
+ frame+frame_offset, out_aes_gcm_state);
|
|
|
+ } else {
|
|
|
+ // Just copy the plaintext data during the handshake
|
|
|
+ memmove(frame+frame_offset, data, bytes_to_process);
|
|
|
+ }
|
|
|
+ frame_offset += bytes_to_process;
|
|
|
+ msg_plaintext_processed += bytes_to_process;
|
|
|
+ msg_plaintext_chunk_remain -= bytes_to_process;
|
|
|
+ len -= bytes_to_process;
|
|
|
+ data += bytes_to_process;
|
|
|
+ if (msg_plaintext_chunk_remain == 0) {
|
|
|
+ // Complete and send this chunk
|
|
|
+ if (handshake_step == HANDSHAKE_COMPLETE) {
|
|
|
+ sgx_aes_gcm128_enc_get_mac(frame+frame_offset,
|
|
|
+ out_aes_gcm_state);
|
|
|
+ frame_offset += SGX_AESGCM_MAC_SIZE;
|
|
|
+ }
|
|
|
+ uint8_t *nextframe = NULL;
|
|
|
+ ocall_chunk(&nextframe, node_num, frame, frame_offset);
|
|
|
+ frame = nextframe;
|
|
|
+ msg_frame_offset += frame_offset;
|
|
|
+ frame_offset = 0;
|
|
|
+ msg_plaintext_chunk_remain =
|
|
|
+ msg_plaintext_size - msg_plaintext_processed;
|
|
|
+ if (msg_plaintext_chunk_remain >
|
|
|
+ FRAME_SIZE - SGX_AESGCM_MAC_SIZE) {
|
|
|
+ msg_plaintext_chunk_remain =
|
|
|
+ FRAME_SIZE - SGX_AESGCM_MAC_SIZE;
|
|
|
+ }
|
|
|
+ if (handshake_step == HANDSHAKE_COMPLETE) {
|
|
|
+ sgx_aes_gcm_close(out_aes_gcm_state);
|
|
|
+ if (msg_plaintext_chunk_remain > 0) {
|
|
|
+ *(size_t*)out_aes_iv += 1;
|
|
|
+ sgx_aes_gcm128_enc_init(out_aes_key, out_aes_iv,
|
|
|
+ SGX_AESGCM_IV_SIZE, NULL, 0, &out_aes_gcm_state);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
// The communication states for all the nodes. There's an entry for
|
|
|
// ourselves in here, but it is unused.
|
|
|
static std::vector<NodeCommState> commstates;
|
|
@@ -170,6 +345,16 @@ bool comms_init_nodestate(const EnclaveAPINodeConfig *apinodeconfigs,
|
|
|
}
|
|
|
|
|
|
tot_nodes = num_nodes;
|
|
|
+
|
|
|
+ // There will be an enclave-to-enclave channel between us and each
|
|
|
+ // other node's enclave. For the node numbers smaller than ours, we
|
|
|
+ // will be the server for the handshake for that channel. Prepare
|
|
|
+ // to receive the first handshake message from those nodes'
|
|
|
+ // enclaves.
|
|
|
+ for (nodenum_t i=0; i<my_node_num; ++i) {
|
|
|
+ commstates[i].in_msg_get_buf = default_in_msg_get_buf;
|
|
|
+ commstates[i].in_msg_received = handshake_1_msg_received;
|
|
|
+ }
|
|
|
return true;
|
|
|
}
|
|
|
|
|
@@ -186,13 +371,73 @@ bool ecall_message(nodenum_t node_num, uint32_t message_len)
|
|
|
printf("Received ecall_message without completing previous message\n");
|
|
|
return false;
|
|
|
}
|
|
|
- printf("ecall_message called\n");
|
|
|
+ if (!nodest.in_msg_get_buf) {
|
|
|
+ printf("No message header handler registered\n");
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ uint8_t *buf = nodest.in_msg_get_buf(nodest, message_len);
|
|
|
+ if (!buf) {
|
|
|
+ printf("Message header handler returned NULL\n");
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ nodest.in_msg_size = message_len;
|
|
|
+ nodest.in_msg_offset = 0;
|
|
|
+ nodest.in_msg_plaintext_processed = 0;
|
|
|
+ nodest.in_msg_buf = buf;
|
|
|
return true;
|
|
|
}
|
|
|
|
|
|
bool ecall_chunk(nodenum_t node_num, const uint8_t *chunkdata,
|
|
|
uint32_t chunklen)
|
|
|
{
|
|
|
- printf("ecall_chunk called\n");
|
|
|
+ if (node_num >= tot_nodes) {
|
|
|
+ printf("Out-of-range node_num %hu received in ecall_chunk\n",
|
|
|
+ node_num);
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ NodeCommState &nodest = commstates[node_num];
|
|
|
+
|
|
|
+ if (nodest.in_msg_size == nodest.in_msg_offset) {
|
|
|
+ printf("Received ecall_chunk after completing message\n");
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ if (!nodest.in_msg_buf) {
|
|
|
+ printf("No incoming message buffer allocated\n");
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ if (!nodest.in_msg_received) {
|
|
|
+ printf("No message received handler registered\n");
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ if (nodest.in_msg_offset + chunklen > nodest.in_msg_size) {
|
|
|
+ printf("Chunk larger than remaining message size\n");
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ if (nodest.handshake_step == HANDSHAKE_COMPLETE) {
|
|
|
+ // Decrypt the incoming data
|
|
|
+ *(size_t*)(nodest.in_aes_iv) += 1;
|
|
|
+ if (sgx_rijndael128GCM_decrypt(&nodest.in_aes_key, chunkdata,
|
|
|
+ chunklen - SGX_AESGCM_MAC_SIZE,
|
|
|
+ nodest.in_msg_buf + nodest.in_msg_plaintext_processed,
|
|
|
+ nodest.in_aes_iv, SGX_AESGCM_IV_SIZE, NULL, 0,
|
|
|
+ (const sgx_aes_gcm_128bit_tag_t *)
|
|
|
+ (chunkdata + chunklen - SGX_AESGCM_MAC_SIZE))) {
|
|
|
+ printf("Decryption failed\n");
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ nodest.in_msg_plaintext_processed +=
|
|
|
+ chunklen - SGX_AESGCM_MAC_SIZE;
|
|
|
+ } else {
|
|
|
+ // Just copy the handshake data
|
|
|
+ memmove(nodest.in_msg_buf + nodest.in_msg_plaintext_processed,
|
|
|
+ chunkdata, chunklen);
|
|
|
+ nodest.in_msg_plaintext_processed += chunklen;
|
|
|
+ }
|
|
|
+ nodest.in_msg_offset += chunklen;
|
|
|
+ if (nodest.in_msg_offset == nodest.in_msg_size) {
|
|
|
+ // This was the last chunk; handle the received message
|
|
|
+ nodest.in_msg_received(nodest, nodest.in_msg_buf,
|
|
|
+ nodest.in_msg_plaintext_processed, nodest.in_msg_size);
|
|
|
+ }
|
|
|
return true;
|
|
|
}
|