Browse Source

Rename some internal buffers

The round1, etc., buffer is now where messages in round 1 _arrive_, not
where they're sent _from_.
Ian Goldberg 1 year ago
parent
commit
17a9e45e67
1 changed files with 63 additions and 60 deletions
  1. 63 60
      Enclave/route.cpp

+ 63 - 60
Enclave/route.cpp

@@ -62,12 +62,15 @@ enum RouteStep {
     ROUTE_ROUND_2
 };
 
-// The round1 MsgBuffer stores messages we ingest while waiting for
-// round 1 to start, which will be sorted and sent out in round 1.  The
-// round2 MsgBuffer stores messages we receive in round 1, which will be
-// padded, sorted, and sent out in round 2.
+// The ingbuf MsgBuffer stores messages an ingestion node ingests while
+// waiting for round 1 to start, which will be sorted and sent out in
+// round 1.  The round1 MsgBuffer stores messages a routing node
+// receives in round 1, which will be padded, sorted, and sent out in
+// round 2.  The round2 MsgBuffer stores messages a storage node
+// receives in round 2.
 
 static struct RouteState {
+    MsgBuffer ingbuf;
     MsgBuffer round1;
     MsgBuffer round2;
     RouteStep step;
@@ -144,8 +147,8 @@ bool route_init()
 
     // Create the route state
     try {
-        route_state.round1.alloc(tot_msg_per_ing);
-        route_state.round2.alloc(max_round2_msgs);
+        route_state.ingbuf.alloc(tot_msg_per_ing);
+        route_state.round1.alloc(max_round2_msgs);
     } catch (std::bad_alloc&) {
         printf("Memory allocation failed in route_init\n");
         return false;
@@ -245,11 +248,11 @@ static void round1_received(NodeCommState &nodest,
     assert((plaintext_len % uint32_t(msg_size)) == 0);
     uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
 
-    pthread_mutex_lock(&route_state.round2.mutex);
-    route_state.round2.inserted += num_msgs;
-    route_state.round2.nodes_received += 1;
-    nodenum_t nodes_received = route_state.round2.nodes_received;
-    pthread_mutex_unlock(&route_state.round2.mutex);
+    pthread_mutex_lock(&route_state.round1.mutex);
+    route_state.round1.inserted += num_msgs;
+    route_state.round1.nodes_received += 1;
+    nodenum_t nodes_received = route_state.round1.nodes_received;
+    pthread_mutex_unlock(&route_state.round1.mutex);
 
     if (nodes_received == g_teems_config.num_ingestion_nodes) {
         route_state.step = ROUTE_ROUND_1;
@@ -281,18 +284,18 @@ void route_init_msg_handler(nodenum_t node_num)
     // If we are a routing node (possibly among other roles) and they
     // are an ingestion node (possibly among other roles), a round 1
     // routing message is the first thing we expect from them.  We put
-    // these messages into the round2 buffer for processing.
+    // these messages into the round1 buffer for processing.
     if ((our_roles & ROLE_ROUTING) && (their_roles & ROLE_INGESTION)) {
         nodest.in_msg_get_buf = [&](NodeCommState &commst,
                 uint32_t tot_enc_chunk_size) {
-            return msgbuffer_get_buf(route_state.round2, commst,
+            return msgbuffer_get_buf(route_state.round1, commst,
                 tot_enc_chunk_size);
         };
         nodest.in_msg_received = round1_received;
     }
     // Otherwise, if we are a storage node (possibly among other roles)
     // and they are a routing node (possibly among other roles), a round
-    // 2 routing message is the first thing we expect form them
+    // 2 routing message is the first thing we expect from them
     else if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
         nodest.in_msg_get_buf = default_in_msg_get_buf;
         nodest.in_msg_received = round2_received;
@@ -305,30 +308,30 @@ void route_init_msg_handler(nodenum_t node_num)
     }
 }
 
-// Directly ingest a buffer of num_msgs messages into the round1 buffer.
+// Directly ingest a buffer of num_msgs messages into the ingbuf buffer.
 // Return true on success, false on failure.
 bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
 {
     uint16_t msg_size = g_teems_config.msg_size;
-    MsgBuffer &round1 = route_state.round1;
+    MsgBuffer &ingbuf = route_state.ingbuf;
 
-    pthread_mutex_lock(&round1.mutex);
-    uint32_t start = round1.reserved;
+    pthread_mutex_lock(&ingbuf.mutex);
+    uint32_t start = ingbuf.reserved;
     if (start + num_msgs > route_state.tot_msg_per_ing) {
-        pthread_mutex_unlock(&round1.mutex);
+        pthread_mutex_unlock(&ingbuf.mutex);
         printf("Max %u messages exceeded\n",
             route_state.tot_msg_per_ing);
         return false;
     }
-    round1.reserved += num_msgs;
-    pthread_mutex_unlock(&round1.mutex);
+    ingbuf.reserved += num_msgs;
+    pthread_mutex_unlock(&ingbuf.mutex);
 
-    memmove(round1.buf + start * msg_size,
+    memmove(ingbuf.buf + start * msg_size,
         msgs, num_msgs * msg_size);
 
-    pthread_mutex_lock(&round1.mutex);
-    round1.inserted += num_msgs;
-    pthread_mutex_unlock(&round1.mutex);
+    pthread_mutex_lock(&ingbuf.mutex);
+    ingbuf.inserted += num_msgs;
+    pthread_mutex_unlock(&ingbuf.mutex);
 
     return true;
 }
@@ -373,19 +376,19 @@ static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
 
         if (routing_node == my_node_num) {
             // Special case: we're sending to ourselves; just put the
-            // messages in our own round2 buffer
-            MsgBuffer &round2 = route_state.round2;
-
-            pthread_mutex_lock(&round2.mutex);
-            uint32_t start = round2.reserved;
-            if (start + num_msgs > round2.bufsize) {
-                pthread_mutex_unlock(&round2.mutex);
-                printf("Max %u messages exceeded\n", round2.bufsize);
+            // messages in our own round1 buffer
+            MsgBuffer &round1 = route_state.round1;
+
+            pthread_mutex_lock(&round1.mutex);
+            uint32_t start = round1.reserved;
+            if (start + num_msgs > round1.bufsize) {
+                pthread_mutex_unlock(&round1.mutex);
+                printf("Max %u messages exceeded\n", round1.bufsize);
                 return;
             }
-            round2.reserved += num_msgs;
-            pthread_mutex_unlock(&round2.mutex);
-            uint8_t *buf = round2.buf + start * msg_size;
+            round1.reserved += num_msgs;
+            pthread_mutex_unlock(&round1.mutex);
+            uint8_t *buf = round1.buf + start * msg_size;
 
             for (uint32_t i=0; i<full_rows; ++i) {
                 const uint64_t *idxp = indices + i*tot_weight + start_weight;
@@ -400,11 +403,11 @@ static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
                 buf += msg_size;
             }
 
-            pthread_mutex_lock(&round2.mutex);
-            round2.inserted += num_msgs;
-            round2.nodes_received += 1;
-            nodenum_t nodes_received = round2.nodes_received;
-            pthread_mutex_unlock(&round2.mutex);
+            pthread_mutex_lock(&round1.mutex);
+            round1.inserted += num_msgs;
+            round1.nodes_received += 1;
+            nodenum_t nodes_received = round1.nodes_received;
+            pthread_mutex_unlock(&round1.mutex);
 
             if (nodes_received == g_teems_config.num_ingestion_nodes) {
                 route_state.step = ROUTE_ROUND_1;
@@ -435,47 +438,47 @@ void ecall_routing_proceed(void *cbpointer)
 {
     if (route_state.step == ROUTE_NOT_STARTED) {
         route_state.cbpointer = cbpointer;
-        MsgBuffer &round1 = route_state.round1;
+        MsgBuffer &ingbuf = route_state.ingbuf;
 
-        pthread_mutex_lock(&round1.mutex);
+        pthread_mutex_lock(&ingbuf.mutex);
         // Ensure there are no pending messages currently being inserted
         // into the buffer
-        while (round1.reserved != round1.inserted) {
-            pthread_mutex_unlock(&round1.mutex);
-            pthread_mutex_lock(&round1.mutex);
+        while (ingbuf.reserved != ingbuf.inserted) {
+            pthread_mutex_unlock(&ingbuf.mutex);
+            pthread_mutex_lock(&ingbuf.mutex);
         }
         // Sort the messages we've received
 #ifdef PROFILE_ROUTING
-        uint32_t inserted = round1.inserted;
+        uint32_t inserted = ingbuf.inserted;
         unsigned long start = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
 #endif
-        sort_mtobliv(g_teems_config.nthreads, round1.buf,
-            g_teems_config.msg_size, round1.inserted,
+        sort_mtobliv(g_teems_config.nthreads, ingbuf.buf,
+            g_teems_config.msg_size, ingbuf.inserted,
             route_state.tot_msg_per_ing, send_round1_msgs);
 #ifdef PROFILE_ROUTING
         printf_with_rtclock_diff(start, "end oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
 #endif
-        round1.reset();
-        pthread_mutex_unlock(&round1.mutex);
+        ingbuf.reset();
+        pthread_mutex_unlock(&ingbuf.mutex);
     } else if (route_state.step == ROUTE_ROUND_1) {
         route_state.cbpointer = cbpointer;
-        MsgBuffer &round2 = route_state.round2;
+        MsgBuffer &round1 = route_state.round1;
 
-        pthread_mutex_lock(&round2.mutex);
+        pthread_mutex_lock(&round1.mutex);
         // Ensure there are no pending messages currently being inserted
         // into the buffer
-        while (round2.reserved != round2.inserted) {
-            pthread_mutex_unlock(&round2.mutex);
-            pthread_mutex_lock(&round2.mutex);
+        while (round1.reserved != round1.inserted) {
+            pthread_mutex_unlock(&round1.mutex);
+            pthread_mutex_lock(&round1.mutex);
         }
 
         uint32_t msg_size = g_teems_config.msg_size;
-        for(uint32_t i=0;i<round2.inserted;++i) {
-            uint32_t destaddr = *(uint32_t*)(round2.buf+i*msg_size);
+        for(uint32_t i=0;i<round1.inserted;++i) {
+            uint32_t destaddr = *(uint32_t*)(round1.buf+i*msg_size);
             printf("%08x\n", destaddr);
         }
 
-        round2.reset();
-        pthread_mutex_unlock(&round2.mutex);
+        round1.reset();
+        pthread_mutex_unlock(&round1.mutex);
     }
 }