Browse Source

The enclave side of the untrusted app to enclave I/O

Ian Goldberg 1 year ago
parent
commit
9e22973195
3 changed files with 258 additions and 9 deletions
  1. 4 0
      App/net.cpp
  2. 1 1
      Enclave/Enclave.config.xml
  3. 253 8
      Enclave/comms.cpp

+ 4 - 0
App/net.cpp

@@ -140,6 +140,8 @@ void NodeIO::recv_commands(
                 recv_chunksize_inflight = 0;
                 if (ecall_message(node_num, recv_msgsize_inflight)) {
                     recv_commands(error_cb, epoch_cb);
+                } else {
+                    printf("ecall_message failed\n");
                 }
             } else if ((receive_header & 0xff) == COMMAND_CHUNK) {
                 uint32_t this_chunk_size = uint32_t(receive_header >> 8);
@@ -157,6 +159,8 @@ void NodeIO::recv_commands(
                         if (ecall_chunk(node_num, receive_frame,
                                 this_chunk_size)) {
                             recv_commands(error_cb, epoch_cb);
+                        } else {
+                            printf("ecall_chunk failed\n");
                         }
                     });
             } else {

+ 1 - 1
Enclave/Enclave.config.xml

@@ -3,7 +3,7 @@
   <ProdID>0</ProdID>
   <ISVSVN>0</ISVSVN>
   <StackMaxSize>0x40000</StackMaxSize>
-  <HeapMaxSize>0x100000</HeapMaxSize>
+  <HeapMaxSize>0x2000000</HeapMaxSize>
   <TCSNum>10</TCSNum>
   <TCSPolicy>1</TCSPolicy>
   <DisableDebug>0</DisableDebug>

+ 253 - 8
Enclave/comms.cpp

@@ -1,4 +1,5 @@
 #include <vector>
+#include <functional>
 #include <cstring>
 
 #include "sgx_tcrypto.h"
@@ -39,30 +40,204 @@ struct NodeCommState {
     uint8_t out_aes_iv[SGX_AESGCM_IV_SIZE];
     uint8_t in_aes_iv[SGX_AESGCM_IV_SIZE];
 
+    // The GCM state for incrementally building each outgoing chunk
+    sgx_aes_state_handle_t out_aes_gcm_state;
+
     // The current outgoing frame and the current offset into it
     uint8_t *frame;
     uint32_t frame_offset;
 
-    // The current outgoing message size and the offset into it of the
-    // start of the current frame
+    // The current outgoing message ciphertext size and the offset into
+    // it of the start of the current frame
     uint32_t msg_size;
     uint32_t msg_frame_offset;
 
-    // The current incoming message size and the offset into it of all
-    // previous chunks of this message
+    // The current outgoing message plaintext size, how many plaintext
+    // bytes we've already processed with message_data, and how many
+    // plaintext bytes remain for the current chunk
+    uint32_t msg_plaintext_size;
+    uint32_t msg_plaintext_processed;
+    uint32_t msg_plaintext_chunk_remain;
+
+    // The current incoming message ciphertext size and the offset into
+    // it of all previous chunks of this message
     uint32_t in_msg_size;
     uint32_t in_msg_offset;
+    // The current incoming message number of plaintext bytes processed
+    uint32_t in_msg_plaintext_processed;
     // The internal buffer where we're storing the (decrypted) message
     uint8_t *in_msg_buf;
 
+    // The function to call when a new incoming message header arrives.
+    // This function should return a pointer to enough memory to hold
+    // the (decrypted) chunks of the message.  Remember that the length
+    // passed here is the total size of the _encrypted_ chunks.  This
+    // function should not itself modify the in_msg_size, in_msg_offset,
+    // or in_msg_buf members.  This function will usually allocate an
+    // appropriate amount of memory and return the pointer to it, but
+    // may do other things, like return a pointer to the middle of a
+    // previously allocated region of memory.
+    std::function<uint8_t*(NodeCommState&,uint32_t)> in_msg_get_buf;
+
+    // The function to call after the last chunk of a message has been
+    // received.  If in_msg_get_buf allocated memory, this function
+    // should deallocate it.  in_msg_size, in_msg_offset, and in_msg_buf
+    // will already have been reset when this function is called.  The
+    // uint32_t that is passed are the total size of the _decrypted_
+    // data and the original total size of the _encrypted_ chunks that
+    // was passed to in_msg_get_buf.
+    std::function<void(NodeCommState&,uint8_t*,uint32_t,uint32_t)>
+        in_msg_received;
+
     NodeCommState(const sgx_ec256_public_t* conf_pubkey, nodenum_t i) :
-            node_num(i), handshake_step(HANDSHAKE_NONE), frame(NULL),
+            node_num(i), handshake_step(HANDSHAKE_NONE),
+            out_aes_gcm_state(NULL), frame(NULL),
             frame_offset(0), msg_size(0), msg_frame_offset(0),
-            in_msg_size(0), in_msg_offset(0), in_msg_buf(NULL) {
+            msg_plaintext_size(0), msg_plaintext_processed(0),
+            msg_plaintext_chunk_remain(0),
+            in_msg_size(0), in_msg_offset(0),
+            in_msg_plaintext_processed(0), in_msg_buf(NULL),
+            in_msg_get_buf(NULL), in_msg_received(NULL) {
         memmove(&pubkey, conf_pubkey, sizeof(pubkey));
     }
+
+    void message_start(uint32_t plaintext_len);
+
+    void message_data(uint8_t *data, uint32_t len);
 };
 
+// A typical default in_msg_get_buf handler.  It computes the maximum
+// possible size of the decrypted data, allocates that much memory, and
+// returns a pointer to it.
+static uint8_t* default_in_msg_get_buf(NodeCommState &commst,
+    uint32_t tot_enc_chunk_size)
+{
+    uint32_t max_plaintext_bytes = tot_enc_chunk_size;
+
+    // If the handshake is complete, chunks will be encrypted and have a
+    // MAC tag attached which will not correspond to plaintext bytes, so
+    // we can trim them.
+
+    if (commst.handshake_step == HANDSHAKE_COMPLETE) {
+        // The minimum number of chunks needed to transmit this message
+        uint32_t min_num_chunks =
+            (tot_enc_chunk_size + (FRAME_SIZE-1)) / FRAME_SIZE;
+        // The maximum number of plaintext bytes this message could contain
+        max_plaintext_bytes = tot_enc_chunk_size -
+            SGX_AESGCM_MAC_SIZE * min_num_chunks;
+    }
+
+    return new uint8_t[max_plaintext_bytes];
+}
+
+// Receive (at the server) the first handshake message
+static void handshake_1_msg_received(NodeCommState &nodest,
+    uint8_t *data, uint32_t plaintext_len, uint32_t message_len)
+{
+    printf("Received handshake_1 message of %u bytes:\n", plaintext_len);
+    for (uint32_t i=0;i<plaintext_len;++i) {
+        printf("%02x", data[i]);
+    }
+    printf("\n");
+    delete[] data;
+}
+
+// Start a new outgoing message.  Pass the number of _plaintext_ bytes
+// the message will be.
+void NodeCommState::message_start(uint32_t plaintext_len)
+{
+    uint32_t ciphertext_len = plaintext_len;
+
+    // If the handshake is complete, add SGX_AESGCM_MAC_SIZE bytes for
+    // every FRAME_SIZE-SGX_AESGCM_MAC_SIZE bytes of plaintext.
+    if (handshake_step == HANDSHAKE_COMPLETE) {
+        uint32_t num_chunks = (plaintext_len +
+            FRAME_SIZE - SGX_AESGCM_MAC_SIZE - 1) /
+            (FRAME_SIZE - SGX_AESGCM_MAC_SIZE);
+        ciphertext_len = plaintext_len +
+            num_chunks * SGX_AESGCM_MAC_SIZE;
+    }
+    ocall_message(&frame, node_num, ciphertext_len);
+    frame_offset = 0;
+    msg_size = ciphertext_len;
+    msg_frame_offset = 0;
+    msg_plaintext_size = plaintext_len;
+    msg_plaintext_processed = 0;
+    if (plaintext_len < FRAME_SIZE - SGX_AESGCM_MAC_SIZE) {
+        msg_plaintext_chunk_remain = plaintext_len;
+    } else {
+        msg_plaintext_chunk_remain = FRAME_SIZE - SGX_AESGCM_MAC_SIZE;
+    }
+    if (!frame) {
+        printf("Received NULL back from ocall_message\n");
+    }
+    if (msg_plaintext_chunk_remain > 0) {
+        *(size_t*)out_aes_iv += 1;
+        sgx_aes_gcm128_enc_init(out_aes_key, out_aes_iv, SGX_AESGCM_IV_SIZE,
+            NULL, 0, &out_aes_gcm_state);
+    }
+}
+
+// Process len bytes of plaintext data into the current message.
+void NodeCommState::message_data(uint8_t *data, uint32_t len)
+{
+    while (len > 0) {
+        if (msg_plaintext_chunk_remain == 0) {
+            printf("Attempt to queue too much message data\n");
+            return;
+        }
+        uint32_t bytes_to_process = len;
+        if (bytes_to_process > msg_plaintext_chunk_remain) {
+            bytes_to_process = msg_plaintext_chunk_remain;
+        }
+        if (frame == NULL) {
+            printf("frame is NULL when queueing message data\n");
+            return;
+        }
+        if (handshake_step == HANDSHAKE_COMPLETE) {
+            // Encrypt the data
+            sgx_aes_gcm128_enc_update(data, bytes_to_process,
+                frame+frame_offset, out_aes_gcm_state);
+        } else {
+            // Just copy the plaintext data during the handshake
+            memmove(frame+frame_offset, data, bytes_to_process);
+        }
+        frame_offset += bytes_to_process;
+        msg_plaintext_processed += bytes_to_process;
+        msg_plaintext_chunk_remain -= bytes_to_process;
+        len -= bytes_to_process;
+        data += bytes_to_process;
+        if (msg_plaintext_chunk_remain == 0) {
+            // Complete and send this chunk
+            if (handshake_step == HANDSHAKE_COMPLETE) {
+                sgx_aes_gcm128_enc_get_mac(frame+frame_offset,
+                    out_aes_gcm_state);
+                frame_offset += SGX_AESGCM_MAC_SIZE;
+            }
+            uint8_t *nextframe = NULL;
+            ocall_chunk(&nextframe, node_num, frame, frame_offset);
+            frame = nextframe;
+            msg_frame_offset += frame_offset;
+            frame_offset = 0;
+            msg_plaintext_chunk_remain =
+                msg_plaintext_size - msg_plaintext_processed;
+            if (msg_plaintext_chunk_remain >
+                    FRAME_SIZE - SGX_AESGCM_MAC_SIZE) {
+                msg_plaintext_chunk_remain =
+                    FRAME_SIZE - SGX_AESGCM_MAC_SIZE;
+            }
+            if (handshake_step == HANDSHAKE_COMPLETE) {
+                sgx_aes_gcm_close(out_aes_gcm_state);
+                if (msg_plaintext_chunk_remain > 0) {
+                    *(size_t*)out_aes_iv += 1;
+                    sgx_aes_gcm128_enc_init(out_aes_key, out_aes_iv,
+                        SGX_AESGCM_IV_SIZE, NULL, 0, &out_aes_gcm_state);
+                }
+            }
+        }
+    }
+}
+
 // The communication states for all the nodes.  There's an entry for
 // ourselves in here, but it is unused.
 static std::vector<NodeCommState> commstates;
@@ -170,6 +345,16 @@ bool comms_init_nodestate(const EnclaveAPINodeConfig *apinodeconfigs,
     }
 
     tot_nodes = num_nodes;
+
+    // There will be an enclave-to-enclave channel between us and each
+    // other node's enclave.  For the node numbers smaller than ours, we
+    // will be the server for the handshake for that channel.  Prepare
+    // to receive the first handshake message from those nodes'
+    // enclaves.
+    for (nodenum_t i=0; i<my_node_num; ++i) {
+        commstates[i].in_msg_get_buf = default_in_msg_get_buf;
+        commstates[i].in_msg_received = handshake_1_msg_received;
+    }
     return true;
 }
 
@@ -186,13 +371,73 @@ bool ecall_message(nodenum_t node_num, uint32_t message_len)
         printf("Received ecall_message without completing previous message\n");
         return false;
     }
-    printf("ecall_message called\n");
+    if (!nodest.in_msg_get_buf) {
+        printf("No message header handler registered\n");
+        return false;
+    }
+    uint8_t *buf = nodest.in_msg_get_buf(nodest, message_len);
+    if (!buf) {
+        printf("Message header handler returned NULL\n");
+        return false;
+    }
+    nodest.in_msg_size = message_len;
+    nodest.in_msg_offset = 0;
+    nodest.in_msg_plaintext_processed = 0;
+    nodest.in_msg_buf = buf;
     return true;
 }
 
 bool ecall_chunk(nodenum_t node_num, const uint8_t *chunkdata,
     uint32_t chunklen)
 {
-    printf("ecall_chunk called\n");
+    if (node_num >= tot_nodes) {
+        printf("Out-of-range node_num %hu received in ecall_chunk\n",
+            node_num);
+        return false;
+    }
+    NodeCommState &nodest = commstates[node_num];
+
+    if (nodest.in_msg_size == nodest.in_msg_offset) {
+        printf("Received ecall_chunk after completing message\n");
+        return false;
+    }
+    if (!nodest.in_msg_buf) {
+        printf("No incoming message buffer allocated\n");
+        return false;
+    }
+    if (!nodest.in_msg_received) {
+        printf("No message received handler registered\n");
+        return false;
+    }
+    if (nodest.in_msg_offset + chunklen > nodest.in_msg_size) {
+        printf("Chunk larger than remaining message size\n");
+        return false;
+    }
+    if (nodest.handshake_step == HANDSHAKE_COMPLETE) {
+        // Decrypt the incoming data
+        *(size_t*)(nodest.in_aes_iv) += 1;
+        if (sgx_rijndael128GCM_decrypt(&nodest.in_aes_key, chunkdata,
+                chunklen - SGX_AESGCM_MAC_SIZE,
+                nodest.in_msg_buf + nodest.in_msg_plaintext_processed,
+                nodest.in_aes_iv, SGX_AESGCM_IV_SIZE, NULL, 0,
+                (const sgx_aes_gcm_128bit_tag_t *)
+                    (chunkdata + chunklen - SGX_AESGCM_MAC_SIZE))) {
+            printf("Decryption failed\n");
+            return false;
+        }
+        nodest.in_msg_plaintext_processed +=
+            chunklen - SGX_AESGCM_MAC_SIZE;
+    } else {
+        // Just copy the handshake data
+        memmove(nodest.in_msg_buf + nodest.in_msg_plaintext_processed,
+            chunkdata, chunklen);
+        nodest.in_msg_plaintext_processed += chunklen;
+    }
+    nodest.in_msg_offset += chunklen;
+    if (nodest.in_msg_offset == nodest.in_msg_size) {
+        // This was the last chunk; handle the received message
+        nodest.in_msg_received(nodest, nodest.in_msg_buf,
+            nodest.in_msg_plaintext_processed, nodest.in_msg_size);
+    }
     return true;
 }