
5-round public routing implementation

Aaron Johnson, 1 year ago
parent
Commit
cb3d54a627
6 changed files with 726 additions and 166 deletions
  1. Client/clientlaunch (+2 -1)
  2. Enclave/Enclave.config.xml (+2 -1)
  3. Enclave/config.cpp (+6 -1)
  4. Enclave/route.cpp (+700 -162)
  5. Enclave/route.hpp (+15 -0)
  6. Enclave/storage.cpp (+1 -1)

+ 2 - 1
Client/clientlaunch

@@ -21,7 +21,8 @@ PUBKEYS = "./../App/pubkeys.yaml"
 CLIENTS = "./clients"
 
 # Client thread allocation
-prefix = "numactl -C36-39,76-79 "
+#prefix = "numactl -C36-39,76-79 "
+prefix = ""
 
 def launch(config, cmd, threads, lgfile):
     cmdline = ''

+ 2 - 1
Enclave/Enclave.config.xml

@@ -3,10 +3,11 @@
   <ProdID>0</ProdID>
   <ISVSVN>0</ISVSVN>
   <StackMaxSize>0x40000</StackMaxSize>
-  <HeapMaxSize>0x10000000</HeapMaxSize>
+  <HeapMaxSize>0x244e9000</HeapMaxSize>
   <TCSNum>32</TCSNum>
   <TCSPolicy>1</TCSPolicy>
   <DisableDebug>0</DisableDebug>
   <MiscSelect>0</MiscSelect>
   <MiscMask>0xFFFFFFFF</MiscMask>
 </EnclaveConfiguration>
+    

+ 6 - 1
Enclave/config.cpp

@@ -98,7 +98,12 @@ bool ecall_config_load(threadid_t nthreads,
             }
         }
         if (apinodeconfigs[i].roles & ROLE_ROUTING) {
-            nw.weight = apinodeconfigs[i].weight;
+            // Only use weights in private routing
+            if (g_teems_config.private_routing) {
+                nw.weight = apinodeconfigs[i].weight;
+            } else  {
+                nw.weight = 1;
+            }
             g_teems_config.num_routing_nodes += 1;
             if (i < my_node_num) {
                 rte_smaller.push_back(i);
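
A minimal sketch of what this change amounts to, for illustration only (effective_weight is a hypothetical helper, not part of the commit): in public routing every routing node contributes weight 1, so tot_weight simply equals the number of routing nodes.

    #include <cstdint>

    // Hypothetical helper: the weight a routing node effectively contributes.
    static inline uint8_t effective_weight(bool private_routing, uint8_t configured_weight)
    {
        // Public routing ignores the configured weights; every routing node
        // counts as 1, so tot_weight == num_routing_nodes in that mode.
        return private_routing ? configured_weight : 1;
    }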

+ 700 - 162
Enclave/route.cpp

@@ -35,10 +35,11 @@ bool route_init()
     }
 
     // Compute the maximum number of messages we could receive in round 1
-    // Each ingestion node will send us an our_weight/tot_weight
-    // fraction of the messages they hold
-    uint32_t max_msg_from_each_ing = CEILDIV(tot_msg_per_ing,
-        g_teems_config.tot_weight) * g_teems_config.my_weight;
+    // In private routing, each ingestion node will send us an
+    // our_weight/tot_weight fraction of the messages they hold
+    uint32_t max_msg_from_each_ing;
+    max_msg_from_each_ing = CEILDIV(tot_msg_per_ing, g_teems_config.tot_weight) *
+        g_teems_config.my_weight;
 
     // And the maximum number we can receive in total is that times the
     // number of ingestion nodes
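
A small worked example of the per-ingestion-node cap computed above, with hypothetical numbers; CEILDIV is assumed to be the usual round-up division used by the enclave code:

    #include <cstdint>

    #define CEILDIV(a, b) (((a) + (b) - 1) / (b))   // assumption: round-up division

    // With tot_msg_per_ing = 1000, tot_weight = 7 and my_weight = 2 (private routing),
    // each ingestion node sends us at most CEILDIV(1000, 7) * 2 = 143 * 2 = 286 messages.
    static uint32_t max_msgs_from_one_ing(uint32_t tot_msg_per_ing,
        uint16_t tot_weight, uint16_t my_weight)
    {
        return CEILDIV(tot_msg_per_ing, uint32_t(tot_weight)) * my_weight;
    }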
@@ -60,8 +61,9 @@ bool route_init()
     }
 
     // Which will be at most this many from us
-    uint32_t max_msg_to_each_stg = CEILDIV(tot_msg_per_stg,
-        g_teems_config.tot_weight) * g_teems_config.my_weight;
+    uint32_t max_msg_to_each_stg;
+    max_msg_to_each_stg = CEILDIV(tot_msg_per_stg, g_teems_config.tot_weight) *
+        g_teems_config.my_weight;
 
     // But we can't send more messages to each storage server than we
     // could receive in total
@@ -81,7 +83,16 @@ bool route_init()
     }
 
     // The max number of messages that can arrive at a storage server
-    uint32_t max_stg_msgs = tot_msg_per_stg + g_teems_config.tot_weight;
+    uint32_t max_stg_msgs;
+    max_stg_msgs = tot_msg_per_stg + g_teems_config.tot_weight;
+
+    // Calculating public-routing buffer sizes
+    // Weights are not used in public routing
+    uint32_t max_round1a_msgs = max_round1_msgs;
+    uint32_t max_round1b_msgs_to_adj_rtr =
+        (g_teems_config.num_routing_nodes-1)*(g_teems_config.num_routing_nodes-1);
+    uint32_t max_round1b_msgs = 2*max_round1b_msgs_to_adj_rtr;
+    uint32_t max_round1c_msgs = max_round1a_msgs;
 
     /*
     printf("users_per_ing=%u, tot_msg_per_ing=%u, max_msg_from_each_ing=%u, max_round1_msgs=%u, users_per_stg=%u, tot_msg_per_stg=%u, max_msg_to_each_stg=%u, max_round2_msgs=%u, max_stg_msgs=%u\n", users_per_ing, tot_msg_per_ing, max_msg_from_each_ing, max_round1_msgs, users_per_stg, tot_msg_per_stg, max_msg_to_each_stg, max_round2_msgs, max_stg_msgs);
@@ -95,6 +106,9 @@ bool route_init()
         }
         if (my_roles & ROLE_ROUTING) {
             route_state.round1.alloc(max_round2_msgs);
+            route_state.round1a.alloc(max_round1a_msgs);
+            route_state.round1b.alloc(2*max_round1b_msgs); // double space for sorting with 1a msgs
+            route_state.round1c.alloc(max_round1c_msgs);
         }
         if (my_roles & ROLE_STORAGE) {
             route_state.round2.alloc(max_stg_msgs);
@@ -108,6 +122,10 @@ bool route_init()
     }
     route_state.step = ROUTE_NOT_STARTED;
     route_state.tot_msg_per_ing = tot_msg_per_ing;
+    route_state.max_round1_msgs = max_round1_msgs;
+    route_state.max_round1a_msgs = max_round1a_msgs;
+    route_state.max_round1b_msgs_to_adj_rtr = max_round1b_msgs_to_adj_rtr;
+    route_state.max_round1c_msgs = max_round1c_msgs;
     route_state.max_msg_to_each_stg = max_msg_to_each_stg;
     route_state.max_round2_msgs = max_round2_msgs;
     route_state.max_stg_msgs = max_stg_msgs;
@@ -123,14 +141,12 @@ bool route_init()
     if (my_roles & ROLE_ROUTING) {
         sort_precompute_evalplan(max_round2_msgs, nthreads);
         if(!g_teems_config.private_routing) {
-            sort_precompute_evalplan(max_round2_msgs, nthreads);
+            sort_precompute_evalplan(max_round1a_msgs, nthreads);
+            sort_precompute_evalplan(max_round1b_msgs, nthreads);
         }
     }
     if (my_roles & ROLE_STORAGE) {
         sort_precompute_evalplan(max_stg_msgs, nthreads);
-        if(!g_teems_config.private_routing) {
-            sort_precompute_evalplan(max_stg_msgs, nthreads);
-        }
     }
 #ifdef PROFILE_ROUTING
     printf_with_rtclock_diff(start, "end precompute evalplans\n");
@@ -181,7 +197,10 @@ size_t ecall_precompute_sort(int sizeidx)
         if (my_roles & ROLE_ROUTING) {
             used_sizes.push_back(route_state.max_round2_msgs);
             if(!g_teems_config.private_routing) {
-                used_sizes.push_back(route_state.max_round2_msgs);
+                used_sizes.push_back(route_state.max_round1a_msgs);
+                used_sizes.push_back(route_state.max_round1b_msgs);
+                used_sizes.push_back(route_state.max_round1b_msgs);
+                used_sizes.push_back(route_state.max_round1c_msgs);
             }
         }
         if (my_roles & ROLE_STORAGE) {
@@ -227,6 +246,15 @@ static uint8_t* msgbuffer_get_buf(MsgBuffer &msgbuf,
     return msgbuf.buf + start * msg_size;
 }
 
+static void round1a_received(NodeCommState &nodest,
+    uint8_t *data, uint32_t plaintext_len, uint32_t);
+
+static void round1b_received(NodeCommState &nodest,
+    uint8_t *data, uint32_t plaintext_len, uint32_t);
+
+static void round1c_received(NodeCommState &nodest, uint8_t *data,
+    uint32_t plaintext_len, uint32_t);
+
 static void round2_received(NodeCommState &nodest,
     uint8_t *data, uint32_t plaintext_len, uint32_t);
 
@@ -249,23 +277,169 @@ static void round1_received(NodeCommState &nodest,
     pthread_mutex_unlock(&route_state.round1.mutex);
 
     // What is the next message we expect from this node?
-    if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
+    if (g_teems_config.private_routing) {
+        if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
+            nodest.in_msg_get_buf = [&](NodeCommState &commst,
+                    uint32_t tot_enc_chunk_size) {
+                return msgbuffer_get_buf(route_state.round2, commst,
+                    tot_enc_chunk_size);
+            };
+            nodest.in_msg_received = round2_received;
+        }
+        // Otherwise, it's just the next round 1 message, so don't change
+        // the handlers.
+    } else {
+        if (their_roles & ROLE_ROUTING) {
+            nodest.in_msg_get_buf = [&](NodeCommState &commst,
+                    uint32_t tot_enc_chunk_size) {
+                return msgbuffer_get_buf(route_state.round1a, commst,
+                    tot_enc_chunk_size);
+            };
+            nodest.in_msg_received = round1a_received;
+        }
+        // Otherwise, it's just the next round 1 message, so don't change
+        // the handlers.
+    }
+
+    if (nodes_received == g_teems_config.num_ingestion_nodes &&
+            completed_prev_round) {
+        route_state.step = ROUTE_ROUND_1;
+        void *cbpointer = route_state.cbpointer;
+        route_state.cbpointer = NULL;
+        ocall_routing_round_complete(cbpointer, 1);
+    }
+}
+
+// A round 1a message was received by a routing node
+static void round1a_received(NodeCommState &nodest,
+    uint8_t *data, uint32_t plaintext_len, uint32_t)
+{
+    uint16_t msg_size = g_teems_config.msg_size;
+    assert((plaintext_len % uint32_t(msg_size)) == 0);
+    uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
+    pthread_mutex_lock(&route_state.round1a.mutex);
+    route_state.round1a.inserted += num_msgs;
+    route_state.round1a.nodes_received += 1;
+    nodenum_t nodes_received = route_state.round1a.nodes_received;
+    bool completed_prev_round = route_state.round1a.completed_prev_round;
+    pthread_mutex_unlock(&route_state.round1a.mutex);
+
+    // Both are routing nodes
+    // We only expect a message from the previous and next nodes (if they exist)
+    //FIX: replace handlers for previous and next nodes, to put messages in correct locations
+    nodenum_t my_node_num = g_teems_config.my_node_num;
+    nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
+    uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
+    if ((prev_nodes > 0) &&
+        (nodest.node_num == g_teems_config.routing_nodes[num_routing_nodes-1])) {
+        // Node is previous routing node
+        nodest.in_msg_get_buf = [&](NodeCommState &commst,
+                uint32_t tot_enc_chunk_size) {
+            return msgbuffer_get_buf(route_state.round1b, commst,
+                tot_enc_chunk_size);
+        };
+        nodest.in_msg_received = round1b_received;
+    } else if ((prev_nodes < num_routing_nodes-1) &&
+        (nodest.node_num == g_teems_config.routing_nodes[1])) {
+        // Node is next routing node
+        nodest.in_msg_get_buf = [&](NodeCommState &commst,
+                uint32_t tot_enc_chunk_size) {
+            return msgbuffer_get_buf(route_state.round1b, commst,
+                tot_enc_chunk_size);
+        };
+        nodest.in_msg_received = round1b_received;
+    } else {
+        // other routing nodes will not send to this node until round 1c
+        nodest.in_msg_get_buf = [&](NodeCommState &commst,
+                uint32_t tot_enc_chunk_size) {
+            return msgbuffer_get_buf(route_state.round1c, commst,
+                tot_enc_chunk_size);
+        };
+        nodest.in_msg_received = round1c_received;
+    }
+
+    if (nodes_received == g_teems_config.num_routing_nodes &&
+            completed_prev_round) {
+        route_state.step = ROUTE_ROUND_1A;
+        void *cbpointer = route_state.cbpointer;
+        route_state.cbpointer = NULL;
+        ocall_routing_round_complete(cbpointer, ROUND_1A);
+    }
+}
+
+// A round 1b message was received by a routing node
+static void round1b_received(NodeCommState &nodest,
+    uint8_t *data, uint32_t plaintext_len, uint32_t)
+{
+    uint16_t msg_size = g_teems_config.msg_size;
+    assert((plaintext_len % uint32_t(msg_size)) == 0);
+    uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
+    uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
+    uint8_t their_roles = g_teems_config.roles[nodest.node_num];
+    pthread_mutex_lock(&route_state.round1b.mutex);
+    route_state.round1b.inserted += num_msgs;
+    route_state.round1b.nodes_received += 1;
+    nodenum_t nodes_received = route_state.round1b.nodes_received;
+    bool completed_prev_round = route_state.round1b.completed_prev_round;
+    pthread_mutex_unlock(&route_state.round1b.mutex);
+    // Set handler back to standard encrypted message handler
+    nodest.in_msg_get_buf = [&](NodeCommState &commst,
+            uint32_t tot_enc_chunk_size) {
+        return msgbuffer_get_buf(route_state.round1c, commst,
+            tot_enc_chunk_size);
+    };
+    nodest.in_msg_received = round1c_received;
+    nodenum_t my_node_num = g_teems_config.my_node_num;
+    uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
+    nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
+    nodenum_t adjacent_nodes =
+        (((prev_nodes == 0) || (prev_nodes == num_routing_nodes-1)) ? 1 : 2);        
+    if (nodes_received == adjacent_nodes && completed_prev_round) {
+        route_state.step = ROUTE_ROUND_1B;
+        void *cbpointer = route_state.cbpointer;
+        route_state.cbpointer = NULL;
+        ocall_routing_round_complete(cbpointer, ROUND_1B);
+    }
+}
+
+// Message received in round 1c
+static void round1c_received(NodeCommState &nodest, uint8_t *data,
+    uint32_t plaintext_len, uint32_t)
+{
+    uint16_t msg_size = g_teems_config.msg_size;
+    assert((plaintext_len % uint32_t(msg_size)) == 0);
+    uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
+    uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
+    uint8_t their_roles = g_teems_config.roles[nodest.node_num];
+    pthread_mutex_lock(&route_state.round1c.mutex);
+    route_state.round1c.inserted += num_msgs;
+    route_state.round1c.nodes_received += 1;
+    nodenum_t nodes_received = route_state.round1c.nodes_received;
+    bool completed_prev_round = route_state.round1c.completed_prev_round;
+    pthread_mutex_unlock(&route_state.round1c.mutex);
+
+    // What is the next message we expect from this node?
+    if (our_roles & ROLE_STORAGE) {
         nodest.in_msg_get_buf = [&](NodeCommState &commst,
                 uint32_t tot_enc_chunk_size) {
             return msgbuffer_get_buf(route_state.round2, commst,
                 tot_enc_chunk_size);
         };
         nodest.in_msg_received = round2_received;
+    } else if (their_roles & ROLE_INGESTION) {
+        nodest.in_msg_get_buf = [&](NodeCommState &commst,
+                uint32_t tot_enc_chunk_size) {
+            return msgbuffer_get_buf(route_state.round1, commst,
+                tot_enc_chunk_size);
+        };
+        nodest.in_msg_received = round1_received;
     }
-    // Otherwise, it's just the next round 1 message, so don't change
-    // the handlers.
-
-    if (nodes_received == g_teems_config.num_ingestion_nodes &&
+    if (nodes_received == g_teems_config.num_routing_nodes &&
             completed_prev_round) {
-        route_state.step = ROUTE_ROUND_1;
+        route_state.step = ROUTE_ROUND_1C;
         void *cbpointer = route_state.cbpointer;
         route_state.cbpointer = NULL;
-        ocall_routing_round_complete(cbpointer, 1);
+        ocall_routing_round_complete(cbpointer, ROUND_1C);
     }
 }
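
The previous/next checks above reuse each routing node's startweight as its position among the routing nodes. A minimal sketch of that adjacency rule, assuming startweight counts how many routing nodes come before this one:

    #include <cstdint>

    using nodenum_t = uint16_t;   // assumption: stands in for the repo's typedef

    struct Neighbors { bool has_prev; bool has_next; };

    // position = this node's startweight (number of routing nodes before it)
    static Neighbors routing_neighbors(uint16_t position, nodenum_t num_routing_nodes)
    {
        Neighbors n;
        n.has_prev = (position > 0);                        // first node has no predecessor
        n.has_next = (position < num_routing_nodes - 1);    // last node has no successor
        return n;  // endpoints exchange round-1b messages with 1 neighbor, interior nodes with 2
    }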
 
@@ -378,19 +552,23 @@ bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
 }
 
 // Send the round 1 messages.  Note that N here is not private.
-static void send_round1_msgs(const uint8_t *msgs, const UidKey *indices,
+template<typename T>
+static void send_round1_msgs(const uint8_t *msgs, const T *indices,
     uint32_t N)
 {
     uint16_t msg_size = g_teems_config.msg_size;
-    uint16_t tot_weight = g_teems_config.tot_weight;
+    uint16_t tot_weight;
+    tot_weight = g_teems_config.tot_weight;
     nodenum_t my_node_num = g_teems_config.my_node_num;
 
-    uint32_t full_rows = N / uint32_t(tot_weight);
-    uint32_t last_row = N % uint32_t(tot_weight);
+    uint32_t full_rows;
+    uint32_t last_row;
+    full_rows = N / uint32_t(tot_weight);
+    last_row = N % uint32_t(tot_weight);
 
     for (auto &routing_node: g_teems_config.routing_nodes) {
-        uint8_t weight =
-            g_teems_config.weights[routing_node].weight;
+        uint8_t weight = g_teems_config.weights[routing_node].weight;
+        
         if (weight == 0) {
             // This shouldn't happen, but just in case
             continue;
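
The per-node share of the sorted messages comes from full rows of tot_weight messages plus a partial last row; the counting logic (unchanged here, and repeated in send_round1c_msgs later in this diff) works as in this hypothetical example:

    #include <cstdint>

    static uint32_t example_round1_share()
    {
        // Hypothetical values: N = 20 sorted messages, tot_weight = 7, and a routing
        // node owning weight = 2 columns starting at start_weight = 3 in each row.
        uint32_t N = 20, tot_weight = 7;
        uint32_t weight = 2, start_weight = 3;
        uint32_t full_rows = N / tot_weight;    // 2 complete rows
        uint32_t last_row = N % tot_weight;     // 6 messages in the partial last row
        uint32_t num_msgs_last_row = 0;
        if (start_weight <= last_row && last_row < start_weight + weight)
            num_msgs_last_row = last_row - start_weight;
        else if (start_weight + weight <= last_row)
            num_msgs_last_row = weight;         // here: 6 >= 3+2, so the full 2
        return full_rows * weight + num_msgs_last_row;   // 2*2 + 2 = 6 messages
    }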
@@ -432,13 +610,13 @@ static void send_round1_msgs(const uint8_t *msgs, const UidKey *indices,
             uint8_t *buf = round1.buf + start * msg_size;
 
             for (uint32_t i=0; i<full_rows; ++i) {
-                const UidKey *idxp = indices + i*tot_weight + start_weight;
+                const T *idxp = indices + i*tot_weight + start_weight;
                 for (uint32_t j=0; j<weight; ++j) {
                     memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
                     buf += msg_size;
                 }
             }
-            const UidKey *idxp = indices + full_rows*tot_weight + start_weight;
+            const T *idxp = indices + full_rows*tot_weight + start_weight;
             for (uint32_t j=0; j<num_msgs_last_row; ++j) {
                 memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
                 buf += msg_size;
@@ -453,12 +631,209 @@ static void send_round1_msgs(const uint8_t *msgs, const UidKey *indices,
             NodeCommState &nodecom = g_commstates[routing_node];
             nodecom.message_start(num_msgs * msg_size);
             for (uint32_t i=0; i<full_rows; ++i) {
-                const UidKey *idxp = indices + i*tot_weight + start_weight;
+                const T *idxp = indices + i*tot_weight + start_weight;
                 for (uint32_t j=0; j<weight; ++j) {
                     nodecom.message_data(msgs + idxp[j].index()*msg_size, msg_size);
                 }
             }
-            const UidKey *idxp = indices + full_rows*tot_weight + start_weight;
+            const T *idxp = indices + full_rows*tot_weight + start_weight;
+            for (uint32_t j=0; j<num_msgs_last_row; ++j) {
+                nodecom.message_data(msgs + idxp[j].index()*msg_size, msg_size);
+            }
+        }
+    }
+}
+
+// Send the round 1a messages from the round 1 buffer, which only occurs in public-channel routing.
+// msgs points to the message buffer, indices points to the sorted indices, and N is the number
+// of non-padding items.
+static void send_round1a_msgs(const uint8_t *msgs, const UidPriorityKey *indices, uint32_t N) {
+    uint16_t msg_size = g_teems_config.msg_size;
+    nodenum_t my_node_num = g_teems_config.my_node_num;
+    nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
+
+    uint32_t min_msgs_per_node = route_state.max_round1a_msgs / num_routing_nodes;
+    uint32_t extra_msgs = route_state.max_round1a_msgs % num_routing_nodes;
+    for (auto &routing_node: g_teems_config.routing_nodes) {
+        // In this unweighted setting, start_weight represents the position among routing nodes
+        uint16_t prev_nodes = g_teems_config.weights[routing_node].startweight;
+        uint32_t start_msg, num_msgs;
+        if (prev_nodes >= extra_msgs) {
+            start_msg = min_msgs_per_node * prev_nodes + extra_msgs;
+            num_msgs = min_msgs_per_node;
+        } else {
+            start_msg = min_msgs_per_node * prev_nodes + prev_nodes;
+            num_msgs = min_msgs_per_node + 1;
+        }
+        // take number of messages into account
+        if (start_msg >= N) {
+            num_msgs = 0;
+        } else if (start_msg + num_msgs > N) {
+            num_msgs = N - start_msg;
+        }
+
+        if (routing_node == my_node_num) {
+            // Special case: we're sending to ourselves; just put the
+            // messages in our own buffer
+            MsgBuffer &round1a = route_state.round1a;
+            pthread_mutex_lock(&round1a.mutex);
+            uint32_t start = round1a.reserved;
+            if (start + num_msgs > round1a.bufsize) {
+                pthread_mutex_unlock(&round1a.mutex);
+                printf("Max %u messages exceeded in round 1a\n", round1a.bufsize);
+                return;
+            }
+            round1a.reserved += num_msgs;
+            pthread_mutex_unlock(&round1a.mutex);
+            uint8_t *buf = round1a.buf + start * msg_size;
+            for (uint32_t i=0; i<num_msgs; ++i) {
+                const UidPriorityKey *idxp = indices + start_msg + i;
+                memmove(buf, msgs + idxp->index()*msg_size, msg_size);
+                buf += msg_size;
+            }
+            pthread_mutex_lock(&round1a.mutex);
+            round1a.inserted += num_msgs;
+            round1a.nodes_received += 1;
+            pthread_mutex_unlock(&round1a.mutex);
+        } else {
+            NodeCommState &nodecom = g_commstates[routing_node];
+            nodecom.message_start(num_msgs * msg_size);
+            for (uint32_t i=0; i<num_msgs; ++i) {
+                const UidPriorityKey *idxp = indices + start_msg + i;
+                nodecom.message_data(msgs + idxp->index()*msg_size, msg_size);
+            }
+        }
+    }
+}
+
+// Send the round 1b messages from the round 1a buffer, which only occurs in public-channel routing.
+// msgs points to the message buffer, indices points to the sorted indices, and N is the number
+// of non-padding items.
+static void send_round1b_msgs(const uint8_t *msgs, const UidPriorityKey *indices, uint32_t N) {
+    uint16_t msg_size = g_teems_config.msg_size;
+    nodenum_t my_node_num = g_teems_config.my_node_num;
+    nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
+    uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
+
+    // send to previous node
+    if (prev_nodes > 0) {
+        nodenum_t prev_node = g_teems_config.routing_nodes[num_routing_nodes-1];
+        NodeCommState &nodecom = g_commstates[prev_node];
+        uint32_t num_msgs = min(route_state.max_round1a_msgs,
+            route_state.max_round1b_msgs_to_adj_rtr);
+        nodecom.message_start(num_msgs * msg_size);
+        for (uint32_t i=0; i<num_msgs; ++i) {
+            const UidPriorityKey *idxp = indices + i;
+            nodecom.message_data(msgs + idxp->index()*msg_size, msg_size);
+        }
+    }
+    // send to next node
+    if (prev_nodes < num_routing_nodes-1) {
+        nodenum_t next_node = g_teems_config.routing_nodes[1];
+        NodeCommState &nodecom = g_commstates[next_node];
+        if (N <= route_state.max_round1a_msgs - route_state.max_round1b_msgs_to_adj_rtr) {
+            // No messages to exchange with next node
+            nodecom.message_start(0);
+            // No need to call message_data()
+        } else {
+            uint32_t start_msg =
+                route_state.max_round1a_msgs - route_state.max_round1b_msgs_to_adj_rtr;
+            uint32_t num_msgs = N - start_msg;
+            nodecom.message_start(num_msgs * msg_size);
+            for (uint32_t i=0; i<num_msgs; ++i) {
+                const UidPriorityKey *idxp = indices + start_msg + i;
+                nodecom.message_data(msgs + idxp->index()*msg_size, msg_size);
+            }
+        }
+    }
+}
+
+// Send the round 1c messages.  Note that N here is not private.
+// FIX: combine with send_round1_msgs(), which is similar
+static void send_round1c_msgs(const uint8_t *msgs, const UidPriorityKey *indices,
+    uint32_t N)
+{
+    uint16_t msg_size = g_teems_config.msg_size;
+    uint16_t tot_weight;
+    tot_weight = g_teems_config.tot_weight;
+    nodenum_t my_node_num = g_teems_config.my_node_num;
+
+    uint32_t full_rows;
+    uint32_t last_row;
+    full_rows = N / uint32_t(tot_weight);
+    last_row = N % uint32_t(tot_weight);
+
+    for (auto &routing_node: g_teems_config.routing_nodes) {
+        uint8_t weight = g_teems_config.weights[routing_node].weight;
+        
+        if (weight == 0) {
+            // This shouldn't happen, but just in case
+            continue;
+        }
+        uint16_t start_weight =
+            g_teems_config.weights[routing_node].startweight;
+
+        // The number of messages headed for this routing node from the
+        // full rows
+        uint32_t num_msgs_full_rows = full_rows * uint32_t(weight);
+        // The number of messages headed for this routing node from the
+        // incomplete last row is:
+        // 0 if last_row < start_weight
+        // last_row-start_weight if start_weight <= last_row < start_weight + weight
+        // weight if start_weight + weight <= last_row
+        uint32_t num_msgs_last_row = 0;
+        if (start_weight <= last_row && last_row < start_weight + weight) {
+            num_msgs_last_row = last_row-start_weight;
+        } else if (start_weight + weight <= last_row) {
+            num_msgs_last_row = weight;
+        }
+        // The total number of messages headed for this routing node
+        uint32_t num_msgs = num_msgs_full_rows + num_msgs_last_row;
+
+        if (routing_node == my_node_num) {
+            // Special case: we're sending to ourselves; just put the
+            // messages in our own round1c buffer
+            MsgBuffer &round1c = route_state.round1c;
+
+            pthread_mutex_lock(&round1c.mutex);
+            uint32_t start = round1c.reserved;
+            if (start + num_msgs > round1c.bufsize) {
+                pthread_mutex_unlock(&round1c.mutex);
+                printf("Max %u messages exceeded\n", round1c.bufsize);
+                return;
+            }
+            round1c.reserved += num_msgs;
+            pthread_mutex_unlock(&round1c.mutex);
+            uint8_t *buf = round1c.buf + start * msg_size;
+
+            for (uint32_t i=0; i<full_rows; ++i) {
+                const UidPriorityKey *idxp = indices + i*tot_weight + start_weight;
+                for (uint32_t j=0; j<weight; ++j) {
+                    memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
+                    buf += msg_size;
+                }
+            }
+            const UidPriorityKey *idxp = indices + full_rows*tot_weight + start_weight;
+            for (uint32_t j=0; j<num_msgs_last_row; ++j) {
+                memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
+                buf += msg_size;
+            }
+
+            pthread_mutex_lock(&round1c.mutex);
+            round1c.inserted += num_msgs;
+            round1c.nodes_received += 1;
+            pthread_mutex_unlock(&round1c.mutex);
+
+        } else {
+            NodeCommState &nodecom = g_commstates[routing_node];
+            nodecom.message_start(num_msgs * msg_size);
+            for (uint32_t i=0; i<full_rows; ++i) {
+                const UidPriorityKey *idxp = indices + i*tot_weight + start_weight;
+                for (uint32_t j=0; j<weight; ++j) {
+                    nodecom.message_data(msgs + idxp[j].index()*msg_size, msg_size);
+                }
+            }
+            const UidPriorityKey *idxp = indices + full_rows*tot_weight + start_weight;
             for (uint32_t j=0; j<num_msgs_last_row; ++j) {
                 nodecom.message_data(msgs + idxp[j].index()*msg_size, msg_size);
             }
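
A worked example of the round-1a slot assignment in send_round1a_msgs above, with hypothetical sizes: the max_round1a_msgs slots are split as evenly as possible across the routing nodes, and the first extra_msgs nodes each take one extra slot.

    #include <cstdint>
    #include <cstdio>

    int main()
    {
        uint32_t max_round1a_msgs = 10;     // hypothetical
        uint32_t num_routing_nodes = 3;     // hypothetical
        uint32_t min_per_node = max_round1a_msgs / num_routing_nodes;   // 3
        uint32_t extra_msgs = max_round1a_msgs % num_routing_nodes;     // 1
        for (uint32_t pos = 0; pos < num_routing_nodes; ++pos) {
            uint32_t start_msg, num_msgs;
            if (pos >= extra_msgs) {
                start_msg = min_per_node * pos + extra_msgs;
                num_msgs = min_per_node;
            } else {
                start_msg = min_per_node * pos + pos;
                num_msgs = min_per_node + 1;
            }
            // Prints: node 0: slots [0,4)  node 1: slots [4,7)  node 2: slots [7,10)
            printf("node %u: slots [%u,%u)\n", pos, start_msg, start_msg + num_msgs);
        }
        return 0;
    }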
@@ -526,6 +901,283 @@ static void send_round2_msgs(uint32_t tot_msgs, uint32_t num_msgs_per_stg)
     }
 }
 
+
+static void round1a_processing(void *cbpointer) {
+    uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
+    MsgBuffer &round1 = route_state.round1;
+
+    if (my_roles & ROLE_ROUTING) {
+        route_state.cbpointer = cbpointer;
+        pthread_mutex_lock(&round1.mutex);
+        // Ensure there are no pending messages currently being inserted
+        // into the buffer
+        while (round1.reserved != round1.inserted) {
+            pthread_mutex_unlock(&round1.mutex);
+            pthread_mutex_lock(&round1.mutex);
+        }
+#ifdef PROFILE_ROUTING
+        uint32_t inserted = round1.inserted;
+        unsigned long start_round1a = printf_with_rtclock("begin round1a processing (%u)\n", inserted);
+        // Sort the messages we've received
+        unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.max_round1_msgs);
+#endif
+        // Sort received messages by increasing user ID and
+        // priority. Smaller priority number indicates higher priority.
+        sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1.buf,
+            g_teems_config.msg_size, round1.inserted, route_state.max_round1_msgs,
+            send_round1a_msgs);
+
+#ifdef PROFILE_ROUTING
+        printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.max_round1_msgs);
+        printf_with_rtclock_diff(start_round1a, "end round1a processing (%u)\n", inserted);
+#endif
+        round1.reset();
+        pthread_mutex_unlock(&round1.mutex);
+
+        MsgBuffer &round1a = route_state.round1a;
+        pthread_mutex_lock(&round1a.mutex);
+        round1a.completed_prev_round = true;
+        nodenum_t nodes_received = round1a.nodes_received;
+        pthread_mutex_unlock(&round1a.mutex);
+        if (nodes_received == g_teems_config.num_routing_nodes) {
+            route_state.step = ROUTE_ROUND_1A;
+            route_state.cbpointer = NULL;
+            ocall_routing_round_complete(cbpointer, ROUND_1A);
+        }
+    } else {
+        route_state.step = ROUTE_ROUND_1A;
+        route_state.round1a.completed_prev_round = true;
+        ocall_routing_round_complete(cbpointer, ROUND_1A);
+    }
+}
+
+static void round1b_processing(void *cbpointer) {
+    uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
+    nodenum_t my_node_num = g_teems_config.my_node_num;
+    uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
+    nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
+    MsgBuffer &round1a = route_state.round1a;
+
+    if (my_roles & ROLE_ROUTING) {
+        route_state.cbpointer = cbpointer;
+        pthread_mutex_lock(&round1a.mutex);
+        // Ensure there are no pending messages currently being inserted
+        // into the buffer
+        while (round1a.reserved != round1a.inserted) {
+            pthread_mutex_unlock(&round1a.mutex);
+            pthread_mutex_lock(&round1a.mutex);
+        }
+
+#ifdef PROFILE_ROUTING
+        uint32_t inserted = round1a.inserted;
+        unsigned long start_round1b = printf_with_rtclock("begin round1b processing (%u)\n", inserted);
+        // Sort the messages we've received
+        unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.max_round1a_msgs);
+#endif
+        // Sort received messages by increasing user ID and
+        // priority. Smaller priority number indicates higher priority.
+        sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1a.buf,
+            g_teems_config.msg_size, round1a.inserted, route_state.max_round1a_msgs,
+            send_round1b_msgs);
+
+#ifdef PROFILE_ROUTING
+        printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.max_round1a_msgs);
+        printf_with_rtclock_diff(start_round1b, "end round1b processing (%u)\n", inserted);
+#endif
+
+        //round1a.reset(); // Don't reset until end of round 1c
+        pthread_mutex_unlock(&round1a.mutex);
+        MsgBuffer &round1b = route_state.round1b;
+        pthread_mutex_lock(&round1b.mutex);
+        round1b.completed_prev_round = true;
+        nodenum_t nodes_received = round1b.nodes_received;
+        nodenum_t adjacent_nodes =
+            (((prev_nodes == 0) || (prev_nodes == num_routing_nodes-1)) ? 1 : 2);        
+        pthread_mutex_unlock(&round1b.mutex);
+        if (nodes_received == adjacent_nodes) {
+            route_state.step = ROUTE_ROUND_1B;
+            route_state.cbpointer = NULL;
+            ocall_routing_round_complete(cbpointer, ROUND_1B);
+        }
+    } else {
+        route_state.step = ROUTE_ROUND_1B;
+        route_state.round1b.completed_prev_round = true;
+        ocall_routing_round_complete(cbpointer, ROUND_1B);
+    }
+}
+
+//FIX: adjust the message totals based on the sorts
+static void round1c_processing(void *cbpointer) {
+    uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
+    nodenum_t my_node_num = g_teems_config.my_node_num;
+    nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
+    MsgBuffer &round1a = route_state.round1a;
+    MsgBuffer &round1b = route_state.round1b;
+
+    if (my_roles & ROLE_ROUTING) {
+        route_state.cbpointer = cbpointer;
+        pthread_mutex_lock(&round1b.mutex);
+        // Ensure there are no pending messages currently being inserted
+        // into the buffer
+        while (round1b.reserved != round1b.inserted) {
+            pthread_mutex_unlock(&round1b.mutex);
+            pthread_mutex_lock(&round1b.mutex);
+        }
+
+#ifdef PROFILE_ROUTING
+        uint32_t inserted = round1a.inserted;
+        unsigned long start_round1c = printf_with_rtclock("begin round1c processing (%u)\n", inserted);
+        // Sort the messages we've received
+        unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.max_round1a_msgs);
+#endif
+
+        // Sort received messages by increasing user ID and
+        // priority. Smaller priority number indicates higher priority.
+        sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1a.buf,
+            g_teems_config.msg_size, round1a.inserted, route_state.max_round1a_msgs,
+            send_round1c_msgs);
+
+#ifdef PROFILE_ROUTING
+        printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.max_round1a_msgs);
+        printf_with_rtclock_diff(start_round1c, "end round1c processing (%u)\n", inserted);
+#endif
+
+        round1a.reset();
+        pthread_mutex_unlock(&round1a.mutex);
+        pthread_mutex_lock(&round1b.mutex);
+        round1b.reset();
+        pthread_mutex_unlock(&round1b.mutex);
+
+        MsgBuffer &round1c = route_state.round1c;
+        pthread_mutex_lock(&round1c.mutex);
+        round1c.completed_prev_round = true;
+        nodenum_t nodes_received = round1c.nodes_received;
+        pthread_mutex_unlock(&round1c.mutex);
+        if (nodes_received == num_routing_nodes) {
+            route_state.step = ROUTE_ROUND_1C;
+            route_state.cbpointer = NULL;
+            ocall_routing_round_complete(cbpointer, ROUND_1C);
+        }
+    } else {
+        route_state.step = ROUTE_ROUND_1C;
+        route_state.round1b.completed_prev_round = true;
+        ocall_routing_round_complete(cbpointer, ROUND_1C);
+    }
+}
+
+// Process messages in round 2
+static void round2_processing(uint8_t my_roles, void *cbpointer, MsgBuffer &prevround) {
+    if (my_roles & ROLE_ROUTING) {
+        route_state.cbpointer = cbpointer;
+
+        pthread_mutex_lock(&prevround.mutex);
+        // Ensure there are no pending messages currently being inserted
+        // into the buffer
+        while (prevround.reserved != prevround.inserted) {
+            pthread_mutex_unlock(&prevround.mutex);
+            pthread_mutex_lock(&prevround.mutex);
+        }
+
+        // If the _total_ number of messages we received in round 1
+        // is less than the max number of messages we could send to
+        // _each_ storage node, then cap the number of messages we
+        // will send to each storage node to that number.
+        uint32_t msgs_per_stg = route_state.max_msg_to_each_stg;
+        if (prevround.inserted < msgs_per_stg) {
+            msgs_per_stg = prevround.inserted;
+        }
+
+        // Note: at this point, it is required that each message in
+        // the prevround buffer have a _valid_ storage node id field.
+
+        // Obliviously tally the number of messages we received in
+        // the previous round destined for each storage node
+#ifdef PROFILE_ROUTING
+        uint32_t inserted = prevround.inserted;
+        unsigned long start_round2 = printf_with_rtclock("begin round2 processing (%u,%u)\n", inserted, prevround.bufsize);
+        unsigned long start_tally = printf_with_rtclock("begin tally (%u)\n", inserted);
+#endif
+        uint16_t msg_size = g_teems_config.msg_size;
+        nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
+        std::vector<uint32_t> tally = obliv_tally_stg(
+            prevround.buf, msg_size, prevround.inserted, num_storage_nodes);
+#ifdef PROFILE_ROUTING
+        printf_with_rtclock_diff(start_tally, "end tally (%u)\n", inserted);
+#endif
+        // Note: tally contains private values!  It's OK to
+        // non-obliviously check for an error condition, though.
+        // While we're at it, obliviously change the tally of
+        // messages received to a tally of padding messages
+        // required.
+        uint32_t tot_padding = 0;
+        for (nodenum_t i=0; i<num_storage_nodes; ++i) {
+            if (tally[i] > msgs_per_stg) {
+                printf("Received too many messages for storage node %u\n", i);
+                assert(tally[i] <= msgs_per_stg);
+            }
+            tally[i] = msgs_per_stg - tally[i];
+            tot_padding += tally[i];
+        }
+
+        prevround.reserved += tot_padding;
+        assert(prevround.reserved <= prevround.bufsize);
+
+        // Obliviously add padding for each storage node according
+        // to the (private) padding tally.
+#ifdef PROFILE_ROUTING
+        unsigned long start_pad = printf_with_rtclock("begin pad (%u)\n", tot_padding);
+#endif
+        obliv_pad_stg(prevround.buf + prevround.inserted * msg_size,
+            msg_size, tally, tot_padding);
+#ifdef PROFILE_ROUTING
+        printf_with_rtclock_diff(start_pad, "end pad (%u)\n", tot_padding);
+#endif
+
+        prevround.inserted += tot_padding;
+
+        // Obliviously shuffle the messages
+#ifdef PROFILE_ROUTING
+        unsigned long start_shuffle = printf_with_rtclock("begin shuffle (%u,%u)\n", prevround.inserted, prevround.bufsize);
+#endif
+        uint32_t num_shuffled = shuffle_mtobliv(g_teems_config.nthreads,
+            prevround.buf, msg_size, prevround.inserted, prevround.bufsize);
+#ifdef PROFILE_ROUTING
+        printf_with_rtclock_diff(start_shuffle, "end shuffle (%u,%u)\n", prevround.inserted, prevround.bufsize);
+        printf_with_rtclock_diff(start_round2, "end round2 processing (%u,%u)\n", inserted, prevround.bufsize);
+#endif
+
+        // Now we can handle the messages non-obliviously, since we
+        // know there will be exactly msgs_per_stg messages to each
+        // storage node, and the oblivious shuffle broke the
+        // connection between where each message came from and where
+        // it's going.
+        send_round2_msgs(num_shuffled, msgs_per_stg);
+
+        prevround.reset();
+        pthread_mutex_unlock(&prevround.mutex);
+    }
+
+    if (my_roles & ROLE_STORAGE) {
+        route_state.cbpointer = cbpointer;
+        MsgBuffer &round2 = route_state.round2;
+
+        pthread_mutex_lock(&round2.mutex);
+        round2.completed_prev_round = true;
+        nodenum_t nodes_received = round2.nodes_received;
+        pthread_mutex_unlock(&round2.mutex);
+
+        if (nodes_received == g_teems_config.num_routing_nodes) {
+            route_state.step = ROUTE_ROUND_2;
+            route_state.cbpointer = NULL;
+            ocall_routing_round_complete(cbpointer, 2);
+        }
+    } else {
+        route_state.step = ROUTE_ROUND_2;
+        route_state.round2.completed_prev_round = true;
+        ocall_routing_round_complete(cbpointer, 2);
+    }
+}
+
 // Perform the next round of routing.  The callback pointer will be
 // passed to ocall_routing_round_complete when the round is complete.
 void ecall_routing_proceed(void *cbpointer)
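
A small illustration of the padding step in round2_processing above, under hypothetical numbers: with msgs_per_stg = 5 and a (private) tally of {3, 5, 2} messages received per storage node, the tally is flipped in place to the padding required, {2, 0, 3}, for a total of 5 padding messages.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Sketch only: convert a per-storage-node receive tally into a padding tally.
    static uint32_t tally_to_padding(std::vector<uint32_t> &tally, uint32_t msgs_per_stg)
    {
        uint32_t tot_padding = 0;
        for (uint32_t &t : tally) {
            assert(t <= msgs_per_stg);   // mirrors the error check above
            t = msgs_per_stg - t;        // received count -> padding needed
            tot_padding += t;
        }
        return tot_padding;
    }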
@@ -551,9 +1203,17 @@ void ecall_routing_proceed(void *cbpointer)
             unsigned long start_round1 = printf_with_rtclock("begin round1 processing (%u)\n", inserted);
             unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
 #endif
-            sort_mtobliv<UidKey>(g_teems_config.nthreads, ingbuf.buf,
-                g_teems_config.msg_size, ingbuf.inserted,
-                route_state.tot_msg_per_ing, send_round1_msgs);
+            if (g_teems_config.private_routing) {
+                sort_mtobliv<UidKey>(g_teems_config.nthreads, ingbuf.buf,
+                    g_teems_config.msg_size, ingbuf.inserted,
+                    route_state.tot_msg_per_ing, send_round1_msgs<UidKey>);
+            } else {
+                // Sort received messages by increasing user ID and
+                // priority. Smaller priority number indicates higher priority.
+                sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, ingbuf.buf,
+                    g_teems_config.msg_size, ingbuf.inserted, route_state.tot_msg_per_ing,
+                    send_round1_msgs<UidPriorityKey>);
+            }
 #ifdef PROFILE_ROUTING
             printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
             printf_with_rtclock_diff(start_round1, "end round1 processing (%u)\n", inserted);
@@ -580,137 +1240,15 @@ void ecall_routing_proceed(void *cbpointer)
             ocall_routing_round_complete(cbpointer, 1);
         }
     } else if (route_state.step == ROUTE_ROUND_1) {
-        if (my_roles & ROLE_ROUTING) {
-            route_state.cbpointer = cbpointer;
-            MsgBuffer &round1 = route_state.round1;
-
-            pthread_mutex_lock(&round1.mutex);
-            // Ensure there are no pending messages currently being inserted
-            // into the buffer
-            while (round1.reserved != round1.inserted) {
-                pthread_mutex_unlock(&round1.mutex);
-                pthread_mutex_lock(&round1.mutex);
-            }
-
-            // If the _total_ number of messages we received in round 1
-            // is less than the max number of messages we could send to
-            // _each_ storage node, then cap the number of messages we
-            // will send to each storage node to that number.
-            uint32_t msgs_per_stg = route_state.max_msg_to_each_stg;
-            if (round1.inserted < msgs_per_stg) {
-                msgs_per_stg = round1.inserted;
-            }
-
-            // Note: at this point, it is required that each message in
-            // the round1 buffer have a _valid_ storage node id field.
-
-            // Obliviously tally the number of messages we received in
-            // round1 destined for each storage node
-#ifdef PROFILE_ROUTING
-            uint32_t inserted = round1.inserted;
-            unsigned long start_round2 = printf_with_rtclock("begin round2 processing (%u,%u)\n", inserted, round1.bufsize);
-            unsigned long start_tally = printf_with_rtclock("begin tally (%u)\n", inserted);
-#endif
-            uint16_t msg_size = g_teems_config.msg_size;
-            nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
-            std::vector<uint32_t> tally = obliv_tally_stg(
-                round1.buf, msg_size, round1.inserted, num_storage_nodes);
-#ifdef PROFILE_ROUTING
-            printf_with_rtclock_diff(start_tally, "end tally (%u)\n", inserted);
-#endif
-
-            // For public routing, convert excess messages to padding destined
-            // for other storage nodes with fewer messages than the maximum.
-            if (!g_teems_config.private_routing) {
-#ifdef PROFILE_ROUTING
-                unsigned long start_convert_excess = printf_with_rtclock("begin converting excess messages (%u)\n", round1.inserted);
-#endif
-                // Sort received messages by increasing storage node and
-                // priority. Smaller priority number indicates higher priority.
-                // Sorted messages are put back into source buffer.
-                sort_mtobliv<NidPriorityKey>(g_teems_config.nthreads,
-                    round1.buf, g_teems_config.msg_size, round1.inserted,
-                    round1.bufsize);
-                // Convert excess messages into padding
-                obliv_excess_to_padding(round1.buf, msg_size, round1.inserted,
-                    tally, msgs_per_stg);
-#ifdef PROFILE_ROUTING
-                printf_with_rtclock_diff(start_convert_excess, "end converting excess messages (%u)\n", round1.inserted);
-#endif
-            }
-
-            // Note: tally contains private values!  It's OK to
-            // non-obliviously check for an error condition, though.
-            // While we're at it, obliviously change the tally of
-            // messages received to a tally of padding messages
-            // required.
-            uint32_t tot_padding = 0;
-            for (nodenum_t i=0; i<num_storage_nodes; ++i) {
-                if (tally[i] > msgs_per_stg) {
-                    printf("Received too many messages for storage node %u\n", i);
-                    assert(tally[i] <= msgs_per_stg);
-                }
-                tally[i] = msgs_per_stg - tally[i];
-                tot_padding += tally[i];
-            }
-
-            round1.reserved += tot_padding;
-            assert(round1.reserved <= round1.bufsize);
-
-            // Obliviously add padding for each storage node according
-            // to the (private) padding tally.
-#ifdef PROFILE_ROUTING
-            unsigned long start_pad = printf_with_rtclock("begin pad (%u)\n", tot_padding);
-#endif
-            obliv_pad_stg(round1.buf + round1.inserted * msg_size,
-                msg_size, tally, tot_padding);
-#ifdef PROFILE_ROUTING
-            printf_with_rtclock_diff(start_pad, "end pad (%u)\n", tot_padding);
-#endif
-
-            round1.inserted += tot_padding;
-
-            // Obliviously shuffle the messages
-#ifdef PROFILE_ROUTING
-            unsigned long start_shuffle = printf_with_rtclock("begin shuffle (%u,%u)\n", round1.inserted, round1.bufsize);
-#endif
-            uint32_t num_shuffled = shuffle_mtobliv(g_teems_config.nthreads,
-                round1.buf, msg_size, round1.inserted, round1.bufsize);
-#ifdef PROFILE_ROUTING
-            printf_with_rtclock_diff(start_shuffle, "end shuffle (%u,%u)\n", round1.inserted, round1.bufsize);
-            printf_with_rtclock_diff(start_round2, "end round2 processing (%u,%u)\n", inserted, round1.bufsize);
-#endif
-
-            // Now we can handle the messages non-obliviously, since we
-            // know there will be exactly msgs_per_stg messages to each
-            // storage node, and the oblivious shuffle broke the
-            // connection between where each message came from and where
-            // it's going.
-            send_round2_msgs(num_shuffled, msgs_per_stg);
-
-            round1.reset();
-            pthread_mutex_unlock(&round1.mutex);
-
-        }
-        if (my_roles & ROLE_STORAGE) {
-            route_state.cbpointer = cbpointer;
-            MsgBuffer &round2 = route_state.round2;
-
-            pthread_mutex_lock(&round2.mutex);
-            round2.completed_prev_round = true;
-            nodenum_t nodes_received = round2.nodes_received;
-            pthread_mutex_unlock(&round2.mutex);
-
-            if (nodes_received == g_teems_config.num_routing_nodes) {
-                route_state.step = ROUTE_ROUND_2;
-                route_state.cbpointer = NULL;
-                ocall_routing_round_complete(cbpointer, 2);
-            }
-        } else {
-            route_state.step = ROUTE_ROUND_2;
-            route_state.round2.completed_prev_round = true;
-            ocall_routing_round_complete(cbpointer, 2);
+        if (g_teems_config.private_routing) { // private routing next round
+            round2_processing(my_roles, cbpointer, route_state.round1);
+        } else { // public routing next round
+            round1a_processing(cbpointer);
         }
+    } else if (route_state.step == ROUTE_ROUND_1A) {
+        round1b_processing(cbpointer);
+    } else if (route_state.step == ROUTE_ROUND_1B) {
+        round1c_processing(cbpointer);
     } else if (route_state.step == ROUTE_ROUND_2) {
         if (my_roles & ROLE_STORAGE) {
             MsgBuffer &round2 = route_state.round2;

+ 15 - 0
Enclave/route.hpp

@@ -3,6 +3,10 @@
 
 #include <pthread.h>
 
+#define ROUND_1A 11
+#define ROUND_1B 12
+#define ROUND_1C 13
+
 struct MsgBuffer {
     pthread_mutex_t mutex;
     uint8_t *buf;
@@ -59,6 +63,9 @@ struct MsgBuffer {
 enum RouteStep {
     ROUTE_NOT_STARTED,
     ROUTE_ROUND_1,
+    ROUTE_ROUND_1A,
+    ROUTE_ROUND_1B,
+    ROUTE_ROUND_1C,
     ROUTE_ROUND_2
 };
 
@@ -72,10 +79,18 @@ enum RouteStep {
 struct RouteState {
     MsgBuffer ingbuf;
     MsgBuffer round1;
+    MsgBuffer round1a;
+    MsgBuffer round1b;
+    MsgBuffer round1c;
     MsgBuffer round2;
     RouteStep step;
     uint32_t tot_msg_per_ing;
     uint32_t max_msg_to_each_stg;
+    uint32_t max_round1_msgs;
+    uint32_t max_round1a_msgs;
+    uint32_t max_round1b_msgs_to_adj_rtr;
+    uint32_t max_round1b_msgs;
+    uint32_t max_round1c_msgs;
     uint32_t max_round2_msgs;
     uint32_t max_stg_msgs;
     void *cbpointer;
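
For reference, the new identifiers line up with the round codes passed to ocall_routing_round_complete in route.cpp above; round_for_step below is a hypothetical helper for illustration only (the enclave passes these values directly):

    #define ROUND_1A 11
    #define ROUND_1B 12
    #define ROUND_1C 13

    enum RouteStep { ROUTE_NOT_STARTED, ROUTE_ROUND_1, ROUTE_ROUND_1A,
                     ROUTE_ROUND_1B, ROUTE_ROUND_1C, ROUTE_ROUND_2 };

    static int round_for_step(RouteStep step)
    {
        switch (step) {
        case ROUTE_ROUND_1:  return 1;
        case ROUTE_ROUND_1A: return ROUND_1A;   // 11
        case ROUTE_ROUND_1B: return ROUND_1B;   // 12
        case ROUTE_ROUND_1C: return ROUND_1C;   // 13
        case ROUTE_ROUND_2:  return 2;
        default:             return 0;          // ROUTE_NOT_STARTED
        }
    }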

+ 1 - 1
Enclave/storage.cpp

@@ -20,7 +20,7 @@ static struct {
     // The destination vector for ORExpand
     std::vector<uint32_t> dest;
     // The selected array for compaction during public routing
-    // Need an bool array for compaction, and std:vector<bool> lacks .data()
+    // Need a bool array for compaction, and std::vector<bool> lacks .data()
     bool *pub_selected;
 } storage_state;