Browse Source

Round 2 of private routing complete

Ian Goldberg 1 year ago
parent
commit
3811f012ef
2 changed files with 186 additions and 23 deletions
  1. 8 5
      App/start.cpp
  2. 178 18
      Enclave/route.cpp

+ 8 - 5
App/start.cpp

@@ -84,13 +84,16 @@ static void route_test(NetIO &netio, char **args)
 
 
     ecall_routing_proceed([&](uint32_t round_num) {
     ecall_routing_proceed([&](uint32_t round_num) {
         printf("Round %u complete\n", round_num);
         printf("Round %u complete\n", round_num);
-        if (round_num == 1) {
-            boost::asio::post(netio.io_context(), []{
-                ecall_routing_proceed([&](uint32_t round_num2) {
-                    printf("Round %u complete\n", round_num2);
+        boost::asio::post(netio.io_context(), [&]{
+            ecall_routing_proceed([&](uint32_t round_num2) {
+                printf("Round %u complete\n", round_num2);
+                boost::asio::post(netio.io_context(), []{
+                    ecall_routing_proceed([](uint32_t round_num3) {
+                        printf("Round %u complete\n", round_num3);
+                    });
                 });
                 });
             });
             });
-        }
+        });
     });
     });
 }
 }
 
 

+ 178 - 18
Enclave/route.cpp

@@ -21,9 +21,11 @@ struct MsgBuffer {
     uint32_t bufsize;
     uint32_t bufsize;
     // The number of nodes we've heard from
     // The number of nodes we've heard from
     nodenum_t nodes_received;
     nodenum_t nodes_received;
+    // Have we completed the previous round yet?
+    bool completed_prev_round;
 
 
     MsgBuffer() : buf(NULL), reserved(0), inserted(0), bufsize(0),
     MsgBuffer() : buf(NULL), reserved(0), inserted(0), bufsize(0),
-            nodes_received(0) {
+            nodes_received(0), completed_prev_round(false) {
         pthread_mutex_init(&mutex, NULL);
         pthread_mutex_init(&mutex, NULL);
     }
     }
 
 
@@ -42,6 +44,7 @@ struct MsgBuffer {
         memset(buf, 0, size_t(msgs) * g_teems_config.msg_size);
         memset(buf, 0, size_t(msgs) * g_teems_config.msg_size);
         bufsize = msgs;
         bufsize = msgs;
         nodes_received = 0;
         nodes_received = 0;
+        completed_prev_round = false;
     }
     }
 
 
     // Reset the contents of the buffer
     // Reset the contents of the buffer
@@ -50,6 +53,7 @@ struct MsgBuffer {
         reserved = 0;
         reserved = 0;
         inserted = 0;
         inserted = 0;
         nodes_received = 0;
         nodes_received = 0;
+        completed_prev_round = false;
     }
     }
 
 
     // You can't copy a MsgBuffer
     // You can't copy a MsgBuffer
@@ -249,6 +253,9 @@ static uint8_t* msgbuffer_get_buf(MsgBuffer &msgbuf,
     return msgbuf.buf + start * msg_size;
     return msgbuf.buf + start * msg_size;
 }
 }
 
 
+static void round2_received(NodeCommState &nodest,
+    uint8_t *data, uint32_t plaintext_len, uint32_t);
+
 // A round 1 message was received by a routing node from an ingestion
 // A round 1 message was received by a routing node from an ingestion
 // node; we put it into the round 2 buffer for processing in round 2
 // node; we put it into the round 2 buffer for processing in round 2
 static void round1_received(NodeCommState &nodest,
 static void round1_received(NodeCommState &nodest,
@@ -257,14 +264,30 @@ static void round1_received(NodeCommState &nodest,
     uint16_t msg_size = g_teems_config.msg_size;
     uint16_t msg_size = g_teems_config.msg_size;
     assert((plaintext_len % uint32_t(msg_size)) == 0);
     assert((plaintext_len % uint32_t(msg_size)) == 0);
     uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
     uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
+    uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
+    uint8_t their_roles = g_teems_config.roles[nodest.node_num];
 
 
     pthread_mutex_lock(&route_state.round1.mutex);
     pthread_mutex_lock(&route_state.round1.mutex);
     route_state.round1.inserted += num_msgs;
     route_state.round1.inserted += num_msgs;
     route_state.round1.nodes_received += 1;
     route_state.round1.nodes_received += 1;
     nodenum_t nodes_received = route_state.round1.nodes_received;
     nodenum_t nodes_received = route_state.round1.nodes_received;
+    bool completed_prev_round = route_state.round1.completed_prev_round;
     pthread_mutex_unlock(&route_state.round1.mutex);
     pthread_mutex_unlock(&route_state.round1.mutex);
 
 
-    if (nodes_received == g_teems_config.num_ingestion_nodes) {
+    // What is the next message we expect from this node?
+    if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
+        nodest.in_msg_get_buf = [&](NodeCommState &commst,
+                uint32_t tot_enc_chunk_size) {
+            return msgbuffer_get_buf(route_state.round2, commst,
+                tot_enc_chunk_size);
+        };
+        nodest.in_msg_received = round2_received;
+    }
+    // Otherwise, it's just the next round 1 message, so don't change
+    // the handlers.
+
+    if (nodes_received == g_teems_config.num_ingestion_nodes &&
+            completed_prev_round) {
         route_state.step = ROUTE_ROUND_1;
         route_state.step = ROUTE_ROUND_1;
         void *cbpointer = route_state.cbpointer;
         void *cbpointer = route_state.cbpointer;
         route_state.cbpointer = NULL;
         route_state.cbpointer = NULL;
@@ -278,6 +301,36 @@ static void round2_received(NodeCommState &nodest,
 {
 {
     uint16_t msg_size = g_teems_config.msg_size;
     uint16_t msg_size = g_teems_config.msg_size;
     assert((plaintext_len % uint32_t(msg_size)) == 0);
     assert((plaintext_len % uint32_t(msg_size)) == 0);
+    uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
+    uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
+    uint8_t their_roles = g_teems_config.roles[nodest.node_num];
+
+    pthread_mutex_lock(&route_state.round2.mutex);
+    route_state.round2.inserted += num_msgs;
+    route_state.round2.nodes_received += 1;
+    nodenum_t nodes_received = route_state.round2.nodes_received;
+    bool completed_prev_round = route_state.round2.completed_prev_round;
+    pthread_mutex_unlock(&route_state.round2.mutex);
+
+    // What is the next message we expect from this node?
+    if ((our_roles & ROLE_ROUTING) && (their_roles & ROLE_INGESTION)) {
+        nodest.in_msg_get_buf = [&](NodeCommState &commst,
+                uint32_t tot_enc_chunk_size) {
+            return msgbuffer_get_buf(route_state.round1, commst,
+                tot_enc_chunk_size);
+        };
+        nodest.in_msg_received = round1_received;
+    }
+    // Otherwise, it's just the next round 2 message, so don't change
+    // the handlers.
+
+    if (nodes_received == g_teems_config.num_routing_nodes &&
+            completed_prev_round) {
+        route_state.step = ROUTE_ROUND_2;
+        void *cbpointer = route_state.cbpointer;
+        route_state.cbpointer = NULL;
+        ocall_routing_round_complete(cbpointer, 2);
+    }
 }
 }
 
 
 // For a given other node, set the received message handler to the first
 // For a given other node, set the received message handler to the first
@@ -307,7 +360,11 @@ void route_init_msg_handler(nodenum_t node_num)
     // and they are a routing node (possibly among other roles), a round
     // and they are a routing node (possibly among other roles), a round
     // 2 routing message is the first thing we expect from them
     // 2 routing message is the first thing we expect from them
     else if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
     else if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
-        nodest.in_msg_get_buf = default_in_msg_get_buf;
+        nodest.in_msg_get_buf = [&](NodeCommState &commst,
+                uint32_t tot_enc_chunk_size) {
+            return msgbuffer_get_buf(route_state.round2, commst,
+                tot_enc_chunk_size);
+        };
         nodest.in_msg_received = round2_received;
         nodest.in_msg_received = round2_received;
     }
     }
     // Otherwise, we don't expect a message from this node. Set the
     // Otherwise, we don't expect a message from this node. Set the
@@ -416,15 +473,8 @@ static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
             pthread_mutex_lock(&round1.mutex);
             pthread_mutex_lock(&round1.mutex);
             round1.inserted += num_msgs;
             round1.inserted += num_msgs;
             round1.nodes_received += 1;
             round1.nodes_received += 1;
-            nodenum_t nodes_received = round1.nodes_received;
             pthread_mutex_unlock(&round1.mutex);
             pthread_mutex_unlock(&round1.mutex);
 
 
-            if (nodes_received == g_teems_config.num_ingestion_nodes) {
-                route_state.step = ROUTE_ROUND_1;
-                void *cbpointer = route_state.cbpointer;
-                route_state.cbpointer = NULL;
-                ocall_routing_round_complete(cbpointer, 1);
-            }
         } else {
         } else {
             NodeCommState &nodecom = g_commstates[routing_node];
             NodeCommState &nodecom = g_commstates[routing_node];
             nodecom.message_start(num_msgs * msg_size);
             nodecom.message_start(num_msgs * msg_size);
@@ -442,6 +492,66 @@ static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
     }
     }
 }
 }
 
 
+// Send the round 2 messages from the round 1 buffer, which are already
+// padded and shuffled, so this can be done non-obliviously.  tot_msgs
+// is the total number of messages in the input buffer, which may
+// include padding messages added by the shuffle.  Those messages are
+// not sent anywhere.  There are num_msgs_per_stg messages for each
+// storage node labelled for that node.
+static void send_round2_msgs(uint32_t tot_msgs, uint32_t num_msgs_per_stg)
+{
+    uint16_t msg_size = g_teems_config.msg_size;
+    MsgBuffer &round1 = route_state.round1;
+    const uint8_t* buf = round1.buf;
+    nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
+    nodenum_t my_node_num = g_teems_config.my_node_num;
+    uint8_t *myself_buf = NULL;
+
+    for (nodenum_t i=0; i<num_storage_nodes; ++i) {
+        nodenum_t node = g_teems_config.storage_nodes[i];
+        if (node != my_node_num) {
+            g_commstates[node].message_start(msg_size * num_msgs_per_stg);
+        } else {
+            MsgBuffer &round2 = route_state.round2;
+            pthread_mutex_lock(&round2.mutex);
+            uint32_t start = round2.reserved;
+            if (start + num_msgs_per_stg > round2.bufsize) {
+                pthread_mutex_unlock(&round2.mutex);
+                printf("Max %u messages exceeded\n", round2.bufsize);
+                return;
+            }
+            round2.reserved += num_msgs_per_stg;
+            pthread_mutex_unlock(&round2.mutex);
+            myself_buf = round2.buf + start * msg_size;
+        }
+    }
+
+    while (tot_msgs) {
+        nodenum_t storage_node_id =
+            nodenum_t((*(const uint32_t *)buf)>>DEST_UID_BITS);
+        if (storage_node_id < num_storage_nodes) {
+            nodenum_t node = g_teems_config.storage_map[storage_node_id];
+            if (node == my_node_num) {
+                memmove(myself_buf, buf, msg_size);
+                myself_buf += msg_size;
+            } else {
+                g_commstates[node].message_data(buf, msg_size);
+            }
+        }
+
+        buf += msg_size;
+        --tot_msgs;
+    }
+
+    if (myself_buf) {
+        MsgBuffer &round2 = route_state.round2;
+        pthread_mutex_lock(&round2.mutex);
+        round2.inserted += num_msgs_per_stg;
+        round2.nodes_received += 1;
+        pthread_mutex_unlock(&round2.mutex);
+    }
+}
+
 // Perform the next round of routing.  The callback pointer will be
 // Perform the next round of routing.  The callback pointer will be
 // passed to ocall_routing_round_complete when the round is complete.
 // passed to ocall_routing_round_complete when the round is complete.
 void ecall_routing_proceed(void *cbpointer)
 void ecall_routing_proceed(void *cbpointer)
@@ -452,6 +562,7 @@ void ecall_routing_proceed(void *cbpointer)
         if (my_roles & ROLE_INGESTION) {
         if (my_roles & ROLE_INGESTION) {
             route_state.cbpointer = cbpointer;
             route_state.cbpointer = cbpointer;
             MsgBuffer &ingbuf = route_state.ingbuf;
             MsgBuffer &ingbuf = route_state.ingbuf;
+            MsgBuffer &round1 = route_state.round1;
 
 
             pthread_mutex_lock(&ingbuf.mutex);
             pthread_mutex_lock(&ingbuf.mutex);
             // Ensure there are no pending messages currently being inserted
             // Ensure there are no pending messages currently being inserted
@@ -473,6 +584,17 @@ void ecall_routing_proceed(void *cbpointer)
 #endif
 #endif
             ingbuf.reset();
             ingbuf.reset();
             pthread_mutex_unlock(&ingbuf.mutex);
             pthread_mutex_unlock(&ingbuf.mutex);
+
+            pthread_mutex_lock(&round1.mutex);
+            round1.completed_prev_round = true;
+            nodenum_t nodes_received = round1.nodes_received;
+            pthread_mutex_unlock(&round1.mutex);
+
+            if (nodes_received == g_teems_config.num_ingestion_nodes) {
+                route_state.step = ROUTE_ROUND_1;
+                route_state.cbpointer = NULL;
+                ocall_routing_round_complete(cbpointer, 1);
+            }
         } else {
         } else {
             route_state.step = ROUTE_ROUND_1;
             route_state.step = ROUTE_ROUND_1;
             ocall_routing_round_complete(cbpointer, 1);
             ocall_routing_round_complete(cbpointer, 1);
@@ -481,6 +603,7 @@ void ecall_routing_proceed(void *cbpointer)
         if (my_roles & ROLE_ROUTING) {
         if (my_roles & ROLE_ROUTING) {
             route_state.cbpointer = cbpointer;
             route_state.cbpointer = cbpointer;
             MsgBuffer &round1 = route_state.round1;
             MsgBuffer &round1 = route_state.round1;
+            MsgBuffer &round2 = route_state.round2;
 
 
             pthread_mutex_lock(&round1.mutex);
             pthread_mutex_lock(&round1.mutex);
             // Ensure there are no pending messages currently being inserted
             // Ensure there are no pending messages currently being inserted
@@ -509,7 +632,7 @@ void ecall_routing_proceed(void *cbpointer)
             unsigned long start_round2 = printf_with_rtclock("begin round2 processing (%u,%u)\n", inserted, round1.bufsize);
             unsigned long start_round2 = printf_with_rtclock("begin round2 processing (%u,%u)\n", inserted, round1.bufsize);
             unsigned long start_tally = printf_with_rtclock("begin tally (%u)\n", inserted);
             unsigned long start_tally = printf_with_rtclock("begin tally (%u)\n", inserted);
 #endif
 #endif
-            uint32_t msg_size = g_teems_config.msg_size;
+            uint16_t msg_size = g_teems_config.msg_size;
             nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
             nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
             std::vector<uint32_t> tally = obliv_tally_stg(
             std::vector<uint32_t> tally = obliv_tally_stg(
                 round1.buf, msg_size, round1.inserted, num_storage_nodes);
                 round1.buf, msg_size, round1.inserted, num_storage_nodes);
@@ -555,7 +678,7 @@ void ecall_routing_proceed(void *cbpointer)
             uint32_t num_shuffled = shuffle_mtobliv(g_teems_config.nthreads,
             uint32_t num_shuffled = shuffle_mtobliv(g_teems_config.nthreads,
                 round1.buf, msg_size, round1.inserted, round1.bufsize);
                 round1.buf, msg_size, round1.inserted, round1.bufsize);
 #ifdef PROFILE_ROUTING
 #ifdef PROFILE_ROUTING
-            printf_with_rtclock_diff(start_pad, "end shuffle (%u,%u)\n", round1.inserted, round1.bufsize);
+            printf_with_rtclock_diff(start_shuffle, "end shuffle (%u,%u)\n", round1.inserted, round1.bufsize);
             printf_with_rtclock_diff(start_round2, "end round2 processing (%u,%u)\n", inserted, round1.bufsize);
             printf_with_rtclock_diff(start_round2, "end round2 processing (%u,%u)\n", inserted, round1.bufsize);
 #endif
 #endif
 
 
@@ -564,18 +687,55 @@ void ecall_routing_proceed(void *cbpointer)
             // storage node, and the oblivious shuffle broke the
             // storage node, and the oblivious shuffle broke the
             // connection between where each message came from and where
             // connection between where each message came from and where
             // it's going.
             // it's going.
-
-            for(uint32_t i=0;i<num_shuffled;++i) {
-                uint32_t destaddr = *(uint32_t*)(round1.buf+i*msg_size);
-                printf("%08x\n", destaddr);
-            }
+            send_round2_msgs(num_shuffled, msgs_per_stg);
 
 
             round1.reset();
             round1.reset();
             pthread_mutex_unlock(&round1.mutex);
             pthread_mutex_unlock(&round1.mutex);
+
+            pthread_mutex_lock(&round2.mutex);
+            round2.completed_prev_round = true;
+            nodenum_t nodes_received = round2.nodes_received;
+            pthread_mutex_unlock(&round2.mutex);
+
+            if (nodes_received == g_teems_config.num_routing_nodes) {
+                route_state.step = ROUTE_ROUND_2;
+                route_state.cbpointer = NULL;
+                ocall_routing_round_complete(cbpointer, 2);
+            }
+        } else {
+            route_state.step = ROUTE_ROUND_2;
+            ocall_routing_round_complete(cbpointer, 2);
+        }
+    } else if (route_state.step == ROUTE_ROUND_2) {
+        if (my_roles & ROLE_STORAGE) {
+            uint16_t msg_size = g_teems_config.msg_size;
+            MsgBuffer &round2 = route_state.round2;
+
+            pthread_mutex_lock(&round2.mutex);
+            // Ensure there are no pending messages currently being inserted
+            // into the buffer
+            while (round2.reserved != round2.inserted) {
+                pthread_mutex_unlock(&round2.mutex);
+                pthread_mutex_lock(&round2.mutex);
+            }
+
+            printf("Storage server received %u messages:\n", round2.inserted);
+            const uint8_t *buf = round2.buf;
+            for (uint32_t i=0; i<round2.inserted; ++i) {
+                printf("%08x\n", *(const uint32_t*)buf);
+                buf += msg_size;
+            }
+
+            round2.reset();
+            pthread_mutex_unlock(&round2.mutex);
+
+            // We're done
+            route_state.step = ROUTE_NOT_STARTED;
+            ocall_routing_round_complete(cbpointer, 3);
         } else {
         } else {
             // We're done
             // We're done
             route_state.step = ROUTE_NOT_STARTED;
             route_state.step = ROUTE_NOT_STARTED;
-            ocall_routing_round_complete(cbpointer, 2);
+            ocall_routing_round_complete(cbpointer, 3);
         }
         }
     }
     }
 }
 }