Forráskód Böngészése

Receive round 1 messages

Ian Goldberg 1 éve
szülő
commit
b5fcd9bea5
4 módosított fájl, 145 hozzáadás és 15 törlés
  1. 11 2
      Enclave/comms.cpp
  2. 11 0
      Enclave/comms.hpp
  3. 119 13
      Enclave/route.cpp
  4. 4 0
      Enclave/route.hpp

+ 11 - 2
Enclave/comms.cpp

@@ -8,6 +8,7 @@
 #include "Enclave_t.h"
 #include "utils.hpp"
 #include "config.hpp"
+#include "route.hpp"
 #include "comms.hpp"
 
 // Our public and private identity keys
@@ -75,7 +76,7 @@ public:
 // A typical default in_msg_get_buf handler.  It computes the maximum
 // possible size of the decrypted data, allocates that much memory, and
 // returns a pointer to it.
-static uint8_t* default_in_msg_get_buf(NodeCommState &commst,
+uint8_t* default_in_msg_get_buf(NodeCommState &commst,
     uint32_t tot_enc_chunk_size)
 {
     uint32_t max_plaintext_bytes = tot_enc_chunk_size;
@@ -96,7 +97,9 @@ static uint8_t* default_in_msg_get_buf(NodeCommState &commst,
     return new uint8_t[max_plaintext_bytes];
 }
 
-static void unknown_in_msg_received(NodeCommState &nodest,
+// An in_msg_received handler when we don't actually expect a message
+// from a given node at a given time.
+void unknown_in_msg_received(NodeCommState &nodest,
     uint8_t *data, uint32_t plaintext_len, uint32_t)
 {
     printf("Received unknown message of %u bytes from node %lu:\n",
@@ -352,6 +355,9 @@ static void handshake_2_msg_received(NodeCommState &nodest,
     nodest.message_data((const uint8_t*)&cli_srv_sig, sizeof(cli_srv_sig),
         false);
 
+    // Set the received message handler for routing
+    route_init_msg_handler(nodest.node_num);
+
     // Mark the handshake as complete
     completed_handshake_counter.inc();
 }
@@ -398,6 +404,9 @@ static void handshake_3_msg_received(NodeCommState &nodest,
     nodest.in_msg_get_buf = default_in_msg_get_buf;
     nodest.in_msg_received = unknown_in_msg_received;
 
+    // Set the received message handler for routing
+    route_init_msg_handler(nodest.node_num);
+
     // Mark the handshake as complete
     completed_handshake_counter.inc();
 }

+ 11 - 0
Enclave/comms.hpp

@@ -106,6 +106,17 @@ struct NodeCommState {
 // ourselves in here, but it is unused.
 extern std::vector<NodeCommState> g_commstates;
 
+// A typical default in_msg_get_buf handler.  It computes the maximum
+// possible size of the decrypted data, allocates that much memory, and
+// returns a pointer to it.
+uint8_t* default_in_msg_get_buf(NodeCommState &commst,
+    uint32_t tot_enc_chunk_size);
+
+// An in_msg_received handler when we don't actually expect a message
+// from a given node at a given time.
+void unknown_in_msg_received(NodeCommState &nodest,
+    uint8_t *data, uint32_t plaintext_len, uint32_t);
+
 // The enclave-to-enclave communication protocol is as follows.  It
 // probably could just be attested TLS in a production environment, but
 // we're not implementing remote attestation at this time.  This means

+ 119 - 13
Enclave/route.cpp

@@ -18,8 +18,11 @@ struct MsgBuffer {
     uint32_t inserted;
     // The number of messages that can fit in buf
     uint32_t bufsize;
+    // The number of nodes we've heard from
+    nodenum_t nodes_received;
 
-    MsgBuffer() : buf(NULL), reserved(0), inserted(0), bufsize(0) {
+    MsgBuffer() : buf(NULL), reserved(0), inserted(0), bufsize(0),
+            nodes_received(0) {
         pthread_mutex_init(&mutex, NULL);
     }
 
@@ -36,6 +39,7 @@ struct MsgBuffer {
         // This may throw bad_alloc, but we'll catch it higher up
         buf = new uint8_t[size_t(msgs) * g_teems_config.msg_size];
         bufsize = msgs;
+        nodes_received = 0;
     }
 
     // Reset the contents of the buffer
@@ -43,6 +47,7 @@ struct MsgBuffer {
         memset(buf, 0, bufsize * g_teems_config.msg_size);
         reserved = 0;
         inserted = 0;
+        nodes_received = 0;
     }
 
     // You can't copy a MsgBuffer
@@ -68,6 +73,7 @@ static struct RouteState {
     uint32_t tot_msg_per_ing;
     uint32_t max_msg_to_each_str;
     uint32_t max_round2_msgs;
+    void *cbpointer;
 } route_state;
 
 // Computes ceil(x/y) where x and y are integers, x>=0, y>0.
@@ -147,6 +153,7 @@ bool route_init()
     route_state.tot_msg_per_ing = tot_msg_per_ing;
     route_state.max_msg_to_each_str = max_msg_to_each_str;
     route_state.max_round2_msgs = max_round2_msgs;
+    route_state.cbpointer = NULL;
 
     threadid_t nthreads = g_teems_config.nthreads;
 #ifdef PROFILE_ROUTING
@@ -196,6 +203,107 @@ size_t ecall_precompute_sort(int sizeidx)
     return ret;
 }
 
+static uint8_t* msgbuffer_get_buf(MsgBuffer &msgbuf,
+        NodeCommState &, uint32_t tot_enc_chunk_size)
+{
+    uint16_t msg_size = g_teems_config.msg_size;
+
+    // Chunks will be encrypted and have a MAC tag attached which will
+    // not correspond to plaintext bytes, so we can trim them.
+
+    // The minimum number of chunks needed to transmit this message
+    uint32_t min_num_chunks =
+        (tot_enc_chunk_size + (FRAME_SIZE-1)) / FRAME_SIZE;
+    // The number of plaintext bytes this message could contain
+    uint32_t plaintext_bytes = tot_enc_chunk_size -
+        SGX_AESGCM_MAC_SIZE * min_num_chunks;
+
+    assert ((plaintext_bytes % uint32_t(msg_size)) == 0);
+
+    uint32_t num_msgs = plaintext_bytes/uint32_t(msg_size);
+
+    pthread_mutex_lock(&msgbuf.mutex);
+    uint32_t start = msgbuf.reserved;
+    if (start + num_msgs > msgbuf.bufsize) {
+        pthread_mutex_unlock(&msgbuf.mutex);
+        printf("Max %u messages exceeded\n", msgbuf.bufsize);
+        return NULL;
+    }
+    msgbuf.reserved += num_msgs;
+    pthread_mutex_unlock(&msgbuf.mutex);
+
+    return msgbuf.buf + start * msg_size;
+}
+
+// A round 1 message was received by a routing node from an ingestion
+// node; we put it into the round 2 buffer for processing in round 2
+static void round1_received(NodeCommState &nodest,
+    uint8_t *data, uint32_t plaintext_len, uint32_t)
+{
+    uint16_t msg_size = g_teems_config.msg_size;
+    assert((plaintext_len % uint32_t(msg_size)) == 0);
+    uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
+
+    pthread_mutex_lock(&route_state.round2.mutex);
+    route_state.round2.inserted += num_msgs;
+    route_state.round2.nodes_received += 1;
+    nodenum_t nodes_received = route_state.round2.nodes_received;
+    pthread_mutex_unlock(&route_state.round2.mutex);
+
+    if (nodes_received == g_teems_config.num_ingestion_nodes) {
+        route_state.step = ROUTE_ROUND_1;
+        void *cbpointer = route_state.cbpointer;
+        route_state.cbpointer = NULL;
+        ocall_routing_round_complete(cbpointer, 1);
+    }
+}
+
+// A round 2 message was received by a storage node from a routing node
+static void round2_received(NodeCommState &nodest,
+    uint8_t *data, uint32_t plaintext_len, uint32_t)
+{
+    uint16_t msg_size = g_teems_config.msg_size;
+    assert((plaintext_len % uint32_t(msg_size)) == 0);
+}
+
+// For a given other node, set the received message handler to the first
+// message we would expect from them, given their roles and our roles.
+void route_init_msg_handler(nodenum_t node_num)
+{
+    // Our roles and their roles
+    uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
+    uint8_t their_roles = g_teems_config.roles[node_num];
+
+    // The node communication state
+    NodeCommState &nodest = g_commstates[node_num];
+
+    // If we are a routing node (possibly among other roles) and they
+    // are an ingestion node (possibly among other roles), a round 1
+    // routing message is the first thing we expect from them.  We put
+    // these messages into the round2 buffer for processing.
+    if ((our_roles & ROLE_ROUTING) && (their_roles & ROLE_INGESTION)) {
+        nodest.in_msg_get_buf = [&](NodeCommState &commst,
+                uint32_t tot_enc_chunk_size) {
+            return msgbuffer_get_buf(route_state.round2, commst,
+                tot_enc_chunk_size);
+        };
+        nodest.in_msg_received = round1_received;
+    }
+    // Otherwise, if we are a storage node (possibly among other roles)
+    // and they are a routing node (possibly among other roles), a round
+    // 2 routing message is the first thing we expect form them
+    else if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
+        nodest.in_msg_get_buf = default_in_msg_get_buf;
+        nodest.in_msg_received = round2_received;
+    }
+    // Otherwise, we don't expect a message from this node. Set the
+    // unknown message handler.
+    else {
+        nodest.in_msg_get_buf = default_in_msg_get_buf;
+        nodest.in_msg_received = unknown_in_msg_received;
+    }
+}
+
 // Directly ingest a buffer of num_msgs messages into the round1 buffer.
 // Return true on success, false on failure.
 bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
@@ -293,7 +401,16 @@ static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
 
             pthread_mutex_lock(&round2.mutex);
             round2.inserted += num_msgs;
+            round2.nodes_received += 1;
+            nodenum_t nodes_received = round2.nodes_received;
             pthread_mutex_unlock(&round2.mutex);
+
+            if (nodes_received == g_teems_config.num_ingestion_nodes) {
+                route_state.step = ROUTE_ROUND_1;
+                void *cbpointer = route_state.cbpointer;
+                route_state.cbpointer = NULL;
+                ocall_routing_round_complete(cbpointer, 1);
+            }
         } else {
             NodeCommState &nodecom = g_commstates[routing_node];
             nodecom.message_start(num_msgs * msg_size);
@@ -309,16 +426,6 @@ static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
             }
         }
     }
-
-    /*
-    for (uint32_t i=0;i<N;++i) {
-        const uint8_t *msg = msgs + indices[i]*msg_size;
-        for (uint16_t j=0;j<msg_size/4;++j) {
-            printf("%08x ", ((const uint32_t*)msg)[j]);
-        }
-        printf("\n");
-    }
-    */
 }
 
 // Perform the next round of routing.  The callback pointer will be
@@ -327,6 +434,7 @@ void ecall_routing_proceed(void *cbpointer)
 {
     if (route_state.step == ROUTE_NOT_STARTED) {
 
+        route_state.cbpointer = cbpointer;
         MsgBuffer &round1 = route_state.round1;
 
         pthread_mutex_lock(&round1.mutex);
@@ -348,7 +456,5 @@ void ecall_routing_proceed(void *cbpointer)
 #ifdef PROFILE_ROUTING
         printf_with_rtclock_diff(start, "end oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
 #endif
-        route_state.step = ROUTE_ROUND_1;
-        ocall_routing_round_complete(cbpointer, 1);
     }
 }

+ 4 - 0
Enclave/route.hpp

@@ -5,4 +5,8 @@
 // comms_init_nodestate. Returns true on success, false on failure.
 bool route_init();
 
+// For a given other node, set the received message handler to the first
+// message we would expect from them, given their roles and our roles.
+void route_init_msg_handler(nodenum_t node_num);
+
 #endif