Bläddra i källkod

Send the round 1 messages

Ian Goldberg 1 år sedan
förälder
incheckning
444af7fc6d
5 ändrade filer med 110 tillägg och 37 borttagningar
  1. 0 10
      App/start.cpp
  2. 2 2
      App/teems.cpp
  3. 20 20
      Enclave/comms.cpp
  4. 2 2
      Enclave/comms.hpp
  5. 86 3
      Enclave/route.cpp

+ 0 - 10
App/start.cpp

@@ -75,16 +75,6 @@ static void route_test(NetIO &netio, char **args)
         ecall_precompute_sort(i);
     }
 
-    netio.recv_commands(
-        // error_cb
-        [](boost::system::error_code) {
-            printf("Error\n");
-        },
-        // epoch_cb
-        [](uint32_t epoch) {
-            printf("Epoch %u\n", epoch);
-        });
-
     if (!ecall_ingest_raw(msgs, tot_tokens)) {
         printf("Ingestion failed\n");
         return;

+ 2 - 2
App/teems.cpp

@@ -265,8 +265,8 @@ int main(int argc, char **argv)
             NodeIO &node = netio.node(node_num);
             node.recv_commands(
                 // error_cb
-                [](boost::system::error_code) {
-                    printf("Error\n");
+                [](boost::system::error_code ec) {
+                    printf("Error %d\n", ec.value());
                 },
                 // epoch_cb
                 [](uint32_t epoch) {

+ 20 - 20
Enclave/comms.cpp

@@ -16,7 +16,7 @@ static sgx_ec256_public_t g_pubkey;
 
 // The communication states for all the nodes.  There's an entry for
 // ourselves in here, but it is unused.
-std::vector<NodeCommState> commstates;
+std::vector<NodeCommState> g_commstates;
 
 static nodenum_t tot_nodes, my_node_num;
 static class CompletedHandshakeCounter {
@@ -211,9 +211,9 @@ static void handshake_1_msg_received(NodeCommState &nodest,
     // Send handshake message 2
     nodest.message_start(sizeof(our_dh_pubkey) + sizeof(srv_cli_sig),
         false);
-    nodest.message_data((uint8_t*)&our_dh_pubkey, sizeof(our_dh_pubkey),
+    nodest.message_data((const uint8_t*)&our_dh_pubkey, sizeof(our_dh_pubkey),
         false);
-    nodest.message_data((uint8_t*)&srv_cli_sig, sizeof(srv_cli_sig),
+    nodest.message_data((const uint8_t*)&srv_cli_sig, sizeof(srv_cli_sig),
         false);
 }
 
@@ -334,7 +334,7 @@ static void handshake_2_msg_received(NodeCommState &nodest,
 
     // Send handshake message 3
     nodest.message_start(sizeof(cli_srv_sig), false);
-    nodest.message_data((uint8_t*)&cli_srv_sig, sizeof(cli_srv_sig),
+    nodest.message_data((const uint8_t*)&cli_srv_sig, sizeof(cli_srv_sig),
         false);
 
     // Mark the handshake as complete
@@ -426,7 +426,7 @@ void NodeCommState::message_start(uint32_t plaintext_len, bool encrypt)
 }
 
 // Process len bytes of plaintext data into the current message.
-void NodeCommState::message_data(uint8_t *data, uint32_t len, bool encrypt)
+void NodeCommState::message_data(const uint8_t *data, uint32_t len, bool encrypt)
 {
     while (len > 0) {
         if (msg_plaintext_chunk_remain == 0) {
@@ -443,7 +443,7 @@ void NodeCommState::message_data(uint8_t *data, uint32_t len, bool encrypt)
         }
         if (encrypt) {
             // Encrypt the data
-            sgx_aes_gcm128_enc_update(data, bytes_to_process,
+            sgx_aes_gcm128_enc_update((uint8_t*)data, bytes_to_process,
                 frame+frame_offset, out_aes_gcm_state);
         } else {
             // Just copy the plaintext data during the handshake
@@ -554,9 +554,9 @@ bool comms_init_nodestate(const EnclaveAPINodeConfig *apinodeconfigs,
     sgx_ecc_state_handle_t ecc_handle;
     sgx_ecc256_open_context(&ecc_handle);
 
-    commstates.clear();
+    g_commstates.clear();
     tot_nodes = 0;
-    commstates.reserve(num_nodes);
+    g_commstates.reserve(num_nodes);
     for (nodenum_t i=0; i<num_nodes; ++i) {
         // Check that the pubkey is valid
         int valid;
@@ -564,11 +564,11 @@ bool comms_init_nodestate(const EnclaveAPINodeConfig *apinodeconfigs,
                 ecc_handle, &valid) ||
                 !valid) {
             printf("Pubkey for node %hu invalid\n", i);
-            commstates.clear();
+            g_commstates.clear();
             sgx_ecc256_close_context(ecc_handle);
             return false;
         }
-        commstates.emplace_back(&apinodeconfigs[i].pubkey, i);
+        g_commstates.emplace_back(&apinodeconfigs[i].pubkey, i);
     }
     sgx_ecc256_close_context(ecc_handle);
 
@@ -578,12 +578,12 @@ bool comms_init_nodestate(const EnclaveAPINodeConfig *apinodeconfigs,
     // reflection attacks)
     for (nodenum_t i=0; i<num_nodes; ++i) {
         if (i == my_node_num) continue;
-        if (!memcmp(&commstates[i].pubkey,
-                &commstates[my_node_num].pubkey,
-                sizeof(commstates[i].pubkey))) {
+        if (!memcmp(&g_commstates[i].pubkey,
+                &g_commstates[my_node_num].pubkey,
+                sizeof(g_commstates[i].pubkey))) {
             printf("Pubkey %hu matches our own; possible reflection attack?\n",
                 i);
-            commstates.clear();
+            g_commstates.clear();
             return false;
         }
     }
@@ -596,8 +596,8 @@ bool comms_init_nodestate(const EnclaveAPINodeConfig *apinodeconfigs,
     // to receive the first handshake message from those nodes'
     // enclaves.
     for (nodenum_t i=0; i<my_node_num; ++i) {
-        commstates[i].in_msg_get_buf = default_in_msg_get_buf;
-        commstates[i].in_msg_received = handshake_1_msg_received;
+        g_commstates[i].in_msg_get_buf = default_in_msg_get_buf;
+        g_commstates[i].in_msg_received = handshake_1_msg_received;
     }
     return true;
 }
@@ -609,7 +609,7 @@ bool ecall_message(nodenum_t node_num, uint32_t message_len)
             node_num);
         return false;
     }
-    NodeCommState &nodest = commstates[node_num];
+    NodeCommState &nodest = g_commstates[node_num];
 
     if (nodest.in_msg_size != nodest.in_msg_offset) {
         printf("Received ecall_message without completing previous message\n");
@@ -639,7 +639,7 @@ bool ecall_chunk(nodenum_t node_num, const uint8_t *chunkdata,
             node_num);
         return false;
     }
-    NodeCommState &nodest = commstates[node_num];
+    NodeCommState &nodest = g_commstates[node_num];
 
     if (nodest.in_msg_size == nodest.in_msg_offset) {
         printf("Received ecall_chunk after completing message\n");
@@ -713,7 +713,7 @@ void NodeCommState::handshake_start()
     // Send the public key as the first message
     message_start(sizeof(handshake_dh_pubkey), false);
 
-    message_data((uint8_t*)&handshake_dh_pubkey,
+    message_data((const uint8_t*)&handshake_dh_pubkey,
         sizeof(handshake_dh_pubkey), false);
 }
 
@@ -725,7 +725,7 @@ bool ecall_comms_start(void *cbpointer)
     completed_handshake_counter.reset(cbpointer);
 
     for (nodenum_t t = my_node_num+1; t<tot_nodes; ++t) {
-        commstates[t].handshake_start();
+        g_commstates[t].handshake_start();
     }
     return true;
 }

+ 2 - 2
Enclave/comms.hpp

@@ -96,7 +96,7 @@ struct NodeCommState {
 
     void message_start(uint32_t plaintext_len, bool encrypt=true);
 
-    void message_data(uint8_t *data, uint32_t len, bool encrypt=true);
+    void message_data(const uint8_t *data, uint32_t len, bool encrypt=true);
 
     // Start the handshake (as the client)
     void handshake_start();
@@ -104,7 +104,7 @@ struct NodeCommState {
 
 // The communication states for all the nodes.  There's an entry for
 // ourselves in here, but it is unused.
-extern std::vector<NodeCommState> commstates;
+extern std::vector<NodeCommState> g_commstates;
 
 // The enclave-to-enclave communication protocol is as follows.  It
 // probably could just be attested TLS in a production environment, but

+ 86 - 3
Enclave/route.cpp

@@ -3,6 +3,7 @@
 #include "config.hpp"
 #include "utils.hpp"
 #include "sort.hpp"
+#include "comms.hpp"
 #include "route.hpp"
 
 #define PROFILE_ROUTING
@@ -15,8 +16,10 @@ struct MsgBuffer {
     uint32_t reserved;
     // The number of messages definitely in the buffer
     uint32_t inserted;
+    // The number of messages that can fit in buf
+    uint32_t bufsize;
 
-    MsgBuffer() : buf(NULL), reserved(0), inserted(0) {
+    MsgBuffer() : buf(NULL), reserved(0), inserted(0), bufsize(0) {
         pthread_mutex_init(&mutex, NULL);
     }
 
@@ -32,11 +35,12 @@ struct MsgBuffer {
         inserted = 0;
         // This may throw bad_alloc, but we'll catch it higher up
         buf = new uint8_t[size_t(msgs) * g_teems_config.msg_size];
+        bufsize = msgs;
     }
 
     // Reset the contents of the buffer
     void reset() {
-        memset(buf, 0, inserted * g_teems_config.msg_size);
+        memset(buf, 0, bufsize * g_teems_config.msg_size);
         reserved = 0;
         inserted = 0;
     }
@@ -220,12 +224,91 @@ bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
     return true;
 }
 
-// Send the round 1 messages
+// Send the round 1 messages.  Note that N here is not private.
 static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
     uint32_t N)
 {
     uint16_t msg_size = g_teems_config.msg_size;
     uint16_t tot_weight = g_teems_config.tot_weight;
+    nodenum_t my_node_num = g_teems_config.my_node_num;
+
+    uint32_t full_rows = N / uint32_t(tot_weight);
+    uint32_t last_row = N % uint32_t(tot_weight);
+
+    for (auto &routing_node: g_teems_config.routing_nodes) {
+        uint8_t weight =
+            g_teems_config.weights[routing_node].weight;
+        if (weight == 0) {
+            // This shouldn't happen, but just in case
+            continue;
+        }
+        uint16_t start_weight =
+            g_teems_config.weights[routing_node].startweight;
+
+        // The number of messages headed for this routing node from the
+        // full rows
+        uint32_t num_msgs_full_rows = full_rows * uint32_t(weight);
+        // The number of messages headed for this routing node from the
+        // incomplete last row is:
+        // 0 if last_row < start_weight
+        // last_row-start_weight if start_weight <= last_row < start_weight + weight
+        // weight if start_weight + weight <= last_row
+        uint32_t num_msgs_last_row = 0;
+        if (start_weight <= last_row && last_row < start_weight + weight) {
+            num_msgs_last_row = last_row-start_weight;
+        } else if (start_weight + weight <= last_row) {
+            num_msgs_last_row = weight;
+        }
+        // The total number of messages headed for this routing node
+        uint32_t num_msgs = num_msgs_full_rows + num_msgs_last_row;
+
+        if (routing_node == my_node_num) {
+            // Special case: we're sending to ourselves; just put the
+            // messages in our own round2 buffer
+            MsgBuffer &round2 = route_state.round2;
+
+            pthread_mutex_lock(&round2.mutex);
+            uint32_t start = round2.reserved;
+            if (start + num_msgs > round2.bufsize) {
+                pthread_mutex_unlock(&round2.mutex);
+                printf("Max %u messages exceeded\n", round2.bufsize);
+                return;
+            }
+            round2.reserved += num_msgs;
+            pthread_mutex_unlock(&round2.mutex);
+            uint8_t *buf = round2.buf + start * msg_size;
+
+            for (uint32_t i=0; i<full_rows; ++i) {
+                const uint64_t *idxp = indices + i*tot_weight + start_weight;
+                for (uint32_t j=0; j<weight; ++j) {
+                    memmove(buf, msgs + idxp[j]*msg_size, msg_size);
+                    buf += msg_size;
+                }
+            }
+            const uint64_t *idxp = indices + full_rows*tot_weight + start_weight;
+            for (uint32_t j=0; j<num_msgs_last_row; ++j) {
+                memmove(buf, msgs + idxp[j]*msg_size, msg_size);
+                buf += msg_size;
+            }
+
+            pthread_mutex_lock(&round2.mutex);
+            round2.inserted += num_msgs;
+            pthread_mutex_unlock(&round2.mutex);
+        } else {
+            NodeCommState &nodecom = g_commstates[routing_node];
+            nodecom.message_start(num_msgs * msg_size);
+            for (uint32_t i=0; i<full_rows; ++i) {
+                const uint64_t *idxp = indices + i*tot_weight + start_weight;
+                for (uint32_t j=0; j<weight; ++j) {
+                    nodecom.message_data(msgs + idxp[j]*msg_size, msg_size);
+                }
+            }
+            const uint64_t *idxp = indices + full_rows*tot_weight + start_weight;
+            for (uint32_t j=0; j<num_msgs_last_row; ++j) {
+                nodecom.message_data(msgs + idxp[j]*msg_size, msg_size);
+            }
+        }
+    }
 
     /*
     for (uint32_t i=0;i<N;++i) {