
Only send messages in one direction in round1b

This saves one small sort and one large sort
Ian Goldberg, 11 months ago
commit e92d046286
4 changed files with 140 additions and 206 deletions
  1. Enclave/route.cpp        +136 -202
  2. Enclave/route.hpp          +0 -1
  3. gen_enclave_config.py      +3 -2
  4. run_experiments.py         +1 -1
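
With this change, round 1b blocks flow in only one direction: each router sends its lowest-sorted block to the previous router and waits for at most one block, from the next router. The round1b_prev buffer and its merge sort go away, and (as the route.cpp hunks below show) round 1c can feed the already-merged round1a_sorted buffer straight to the round-robin send instead of re-sorting everything, which is the small sort and the large sort the commit message refers to. A minimal standalone sketch of the new expected-sender count, assuming (as the code suggests) that startweight counts the routers preceding this node; the names here are illustrative, not the enclave's API:

    #include <cstdint>
    #include <cstdio>

    // Under the one-direction scheme a router waits for a round 1b block
    // only from the next router (if there is one), so the count is 0 or 1.
    static uint32_t expected_round1b_senders(uint32_t prev_nodes,
        uint32_t num_routing_nodes)
    {
        if (num_routing_nodes == 1)
            return 0;   // single router: no neighbours at all
        if (prev_nodes == num_routing_nodes - 1)
            return 0;   // last router: nobody is "next"
        return 1;       // everyone else: exactly one sender, the next router
    }

    int main()
    {
        // With 4 routers, positions 0..2 each expect one block; position 3 expects none.
        for (uint32_t prev = 0; prev < 4; ++prev)
            printf("prev_nodes=%u -> %u sender(s)\n", prev,
                expected_round1b_senders(prev, 4));
        return 0;
    }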

+ 136 - 202
Enclave/route.cpp

@@ -91,6 +91,9 @@ bool route_init()
     uint32_t max_msg_to_each_stg;
     max_msg_to_each_stg = CEILDIV(tot_msg_per_stg, g_teems_config.tot_weight) *
         g_teems_config.my_weight;
+    if (!g_teems_config.private_routing) {
+        max_msg_to_each_stg += 1;
+    }
 
     // But we can't send more messages to each storage server than we
     // could receive in total
@@ -123,7 +126,8 @@ bool route_init()
         std::max(max_round1_msgs, 2*max_round1b_msgs_to_adj_rtr),
         g_teems_config.num_routing_nodes) *
         g_teems_config.num_routing_nodes;
-    uint32_t max_round1c_msgs = max_round1a_msgs;
+    uint32_t max_round1c_msgs = std::max(max_round1a_msgs,
+        max_round2_msgs);
 
     /*
     printf("users_per_ing=%u, tot_msg_per_ing=%u, max_msg_from_each_ing=%u, max_round1_msgs=%u, users_per_stg=%u, tot_msg_per_stg=%u, max_msg_to_each_stg=%u, max_round2_msgs=%u, max_stg_msgs=%u\n", users_per_ing, tot_msg_per_ing, max_msg_from_each_ing, max_round1_msgs, users_per_stg, tot_msg_per_stg, max_msg_to_each_stg, max_round2_msgs, max_stg_msgs);
@@ -150,15 +154,15 @@ bool route_init()
 #endif
             if (!g_teems_config.private_routing) {
                 route_state.round1a.alloc(max_round1a_msgs);
-                route_state.round1a_sorted.alloc(max_round1a_msgs);
+                route_state.round1a_sorted.alloc(max_round1a_msgs +
+                    max_round1b_msgs_to_adj_rtr);
                 // double round 1b buffers to sort with some round 1a messages
-                route_state.round1b_prev.alloc(2*max_round1b_msgs_to_adj_rtr);
                 route_state.round1b_next.alloc(2*max_round1b_msgs_to_adj_rtr);
                 route_state.round1c.alloc(max_round1c_msgs);
 #ifdef TRACK_HEAP_USAGE
                 printf("route_init alloc %u msgs\n", max_round1a_msgs);
-                printf("route_init alloc %u msgs\n", max_round1a_msgs);
-                printf("route_init alloc %u msgs\n", 2*max_round1b_msgs_to_adj_rtr);
+                printf("route_init alloc %u msgs\n", max_round1a_msgs +
+                    max_round1b_msgs_to_adj_rtr);
                 printf("route_init alloc %u msgs\n", 2*max_round1b_msgs_to_adj_rtr);
                 printf("route_init alloc %u msgs\n", max_round1c_msgs);
                 printf("route_init H4 heap %u\n", g_peak_heap_used);
@@ -261,8 +265,6 @@ size_t ecall_precompute_sort(int sizeidx)
             if(!g_teems_config.private_routing) {
                 used_sizes.push_back(route_state.max_round1a_msgs);
                 used_sizes.push_back(2*route_state.max_round1b_msgs_to_adj_rtr);
-                used_sizes.push_back(2*route_state.max_round1b_msgs_to_adj_rtr);
-                used_sizes.push_back(route_state.max_round1c_msgs);
             }
         }
         if (my_roles & ROLE_STORAGE) {
@@ -299,7 +301,8 @@ static uint8_t* msgbuffer_get_buf(MsgBuffer &msgbuf,
     uint32_t start = msgbuf.reserved;
     if (start + num_msgs > msgbuf.bufsize) {
         pthread_mutex_unlock(&msgbuf.mutex);
-        printf("Max %u messages exceeded\n", msgbuf.bufsize);
+        printf("Max %u messages exceeded (msgbuffer_get_buf)\n",
+            msgbuf.bufsize);
         return NULL;
     }
     msgbuf.reserved += num_msgs;
@@ -311,9 +314,6 @@ static uint8_t* msgbuffer_get_buf(MsgBuffer &msgbuf,
 static void round1a_received(NodeCommState &nodest,
     uint8_t *data, uint32_t plaintext_len, uint32_t);
 
-static void round1b_prev_received(NodeCommState &nodest,
-    uint8_t *data, uint32_t plaintext_len, uint32_t);
-
 static void round1b_next_received(NodeCommState &nodest,
     uint8_t *data, uint32_t plaintext_len, uint32_t);
 
@@ -390,20 +390,11 @@ static void round1a_received(NodeCommState &nodest,
     pthread_mutex_unlock(&route_state.round1a.mutex);
 
     // Both are routing nodes
-    // We only expect a message from the previous and next nodes (if they exist)
+    // We only expect a message from the next node (if it exists)
     nodenum_t my_node_num = g_teems_config.my_node_num;
     nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
     uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
-    if ((prev_nodes > 0) &&
-        (nodest.node_num == g_teems_config.routing_nodes[num_routing_nodes-1])) {
-        // Node is previous routing node
-        nodest.in_msg_get_buf = [&](NodeCommState &commst,
-                uint32_t tot_enc_chunk_size) {
-            return msgbuffer_get_buf(route_state.round1b_prev, commst,
-                tot_enc_chunk_size);
-        };
-        nodest.in_msg_received = round1b_prev_received;
-    } else if ((prev_nodes < num_routing_nodes-1) &&
+    if ((prev_nodes < num_routing_nodes-1) &&
         (nodest.node_num == g_teems_config.routing_nodes[1])) {
         // Node is next routing node
         nodest.in_msg_get_buf = [&](NodeCommState &commst,
@@ -431,49 +422,6 @@ static void round1a_received(NodeCommState &nodest,
     }
 }
 
-// A round 1b message was received from the previous routing node
-static void round1b_prev_received(NodeCommState &nodest,
-    uint8_t *data, uint32_t plaintext_len, uint32_t)
-{
-    uint16_t msg_size = g_teems_config.msg_size;
-    assert((plaintext_len % uint32_t(msg_size)) == 0);
-    uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
-    uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
-    uint8_t their_roles = g_teems_config.roles[nodest.node_num];
-    pthread_mutex_lock(&route_state.round1b_prev.mutex);
-    route_state.round1b_prev.inserted += num_msgs;
-    route_state.round1b_prev.nodes_received += 1;
-    nodenum_t nodes_received = route_state.round1b_prev.nodes_received;
-    bool completed_prev_round = route_state.round1b_prev.completed_prev_round;
-    pthread_mutex_unlock(&route_state.round1b_prev.mutex);
-    pthread_mutex_lock(&route_state.round1b_next.mutex);
-    nodes_received += route_state.round1b_next.nodes_received;
-    pthread_mutex_unlock(&route_state.round1b_next.mutex);
-    nodest.in_msg_get_buf = [&](NodeCommState &commst,
-            uint32_t tot_enc_chunk_size) {
-        return msgbuffer_get_buf(route_state.round1c, commst,
-            tot_enc_chunk_size);
-    };
-    nodest.in_msg_received = round1c_received;
-    nodenum_t my_node_num = g_teems_config.my_node_num;
-    uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
-    nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
-    nodenum_t adjacent_nodes;
-    if (num_routing_nodes == 1) {
-        adjacent_nodes = 0;
-    } else if ((prev_nodes == 0) || (prev_nodes == num_routing_nodes-1)) {
-        adjacent_nodes = 1;
-    } else {
-        adjacent_nodes = 2;
-    }
-    if (nodes_received == adjacent_nodes && completed_prev_round) {
-        route_state.step = ROUTE_ROUND_1B;
-        void *cbpointer = route_state.cbpointer;
-        route_state.cbpointer = NULL;
-        ocall_routing_round_complete(cbpointer, ROUND_1B);
-    }
-}
-
 // A round 1b message was received from the next routing node
 static void round1b_next_received(NodeCommState &nodest,
     uint8_t *data, uint32_t plaintext_len, uint32_t)
@@ -489,9 +437,6 @@ static void round1b_next_received(NodeCommState &nodest,
     nodenum_t nodes_received = route_state.round1b_next.nodes_received;
     bool completed_prev_round = route_state.round1b_next.completed_prev_round;
     pthread_mutex_unlock(&route_state.round1b_next.mutex);
-    pthread_mutex_lock(&route_state.round1b_prev.mutex);
-    nodes_received += route_state.round1b_prev.nodes_received;
-    pthread_mutex_unlock(&route_state.round1b_prev.mutex);
     nodest.in_msg_get_buf = [&](NodeCommState &commst,
             uint32_t tot_enc_chunk_size) {
         return msgbuffer_get_buf(route_state.round1c, commst,
@@ -504,10 +449,10 @@ static void round1b_next_received(NodeCommState &nodest,
     nodenum_t adjacent_nodes;
     if (num_routing_nodes == 1) {
         adjacent_nodes = 0;
-    } else if ((prev_nodes == 0) || (prev_nodes == num_routing_nodes-1)) {
-        adjacent_nodes = 1;
+    } else if (prev_nodes == num_routing_nodes-1) {
+        adjacent_nodes = 0;
     } else {
-        adjacent_nodes = 2;
+        adjacent_nodes = 1;
     }
     if (nodes_received == adjacent_nodes && completed_prev_round) {
         route_state.step = ROUTE_ROUND_1B;
@@ -656,7 +601,7 @@ bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
     uint32_t start = ingbuf.reserved;
     if (start + num_msgs > route_state.tot_msg_per_ing) {
         pthread_mutex_unlock(&ingbuf.mutex);
-        printf("Max %u messages exceeded\n",
+        printf("Max %u messages exceeded (ecall_ingest_raw)\n",
             route_state.tot_msg_per_ing);
         return false;
     }
@@ -673,10 +618,12 @@ bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
     return true;
 }
 
-// Send messages round-robin, used in rounds 1 and 1c.  Note that N here is not private.
+// Send messages round-robin, used in rounds 1 and 1c.  Note that N here
+// is not private.  Pass indices = nullptr to just send the messages in
+// the order they appear in the msgs buffer.
 template<typename T>
-static void send_round_robin_msgs(MsgBuffer &round, const uint8_t *msgs, const T *indices,
-    uint32_t N)
+static void send_round_robin_msgs(MsgBuffer &round, const uint8_t *msgs,
+    const T *indices, uint32_t N)
 {
     uint16_t msg_size = g_teems_config.msg_size;
     uint16_t tot_weight;
@@ -723,7 +670,8 @@ static void send_round_robin_msgs(MsgBuffer &round, const uint8_t *msgs, const T
             uint32_t start = round.reserved;
             if (start + num_msgs > round.bufsize) {
                 pthread_mutex_unlock(&round.mutex);
-                printf("Max %u messages exceeded\n", round.bufsize);
+                printf("Max %u messages exceeded (send_round_robin_msgs)\n",
+                    round.bufsize);
                 return;
             }
             round.reserved += num_msgs;
@@ -731,16 +679,28 @@ static void send_round_robin_msgs(MsgBuffer &round, const uint8_t *msgs, const T
             uint8_t *buf = round.buf + start * msg_size;
 
             for (uint32_t i=0; i<full_rows; ++i) {
-                const T *idxp = indices + i*tot_weight + start_weight;
-                for (uint32_t j=0; j<weight; ++j) {
+                if (indices) {
+                    const T *idxp = indices + i*tot_weight + start_weight;
+                    for (uint32_t j=0; j<weight; ++j) {
+                        memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
+                        buf += msg_size;
+                    }
+                } else {
+                    size_t idx = i*tot_weight + start_weight;
+                    memmove(buf, msgs + idx*msg_size, weight*msg_size);
+                    buf += weight*msg_size;
+                }
+            }
+            if (indices) {
+                const T *idxp = indices + full_rows*tot_weight + start_weight;
+                for (uint32_t j=0; j<num_msgs_last_row; ++j) {
                     memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
                     buf += msg_size;
                 }
-            }
-            const T *idxp = indices + full_rows*tot_weight + start_weight;
-            for (uint32_t j=0; j<num_msgs_last_row; ++j) {
-                memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
-                buf += msg_size;
+            } else {
+                size_t idx = full_rows*tot_weight + start_weight;
+                memmove(buf, msgs + idx*msg_size, num_msgs_last_row*msg_size);
+                buf += num_msgs_last_row*msg_size;
             }
 
             pthread_mutex_lock(&round.mutex);
@@ -752,14 +712,27 @@ static void send_round_robin_msgs(MsgBuffer &round, const uint8_t *msgs, const T
             NodeCommState &nodecom = g_commstates[routing_node];
             nodecom.message_start(num_msgs * msg_size);
             for (uint32_t i=0; i<full_rows; ++i) {
-                const T *idxp = indices + i*tot_weight + start_weight;
-                for (uint32_t j=0; j<weight; ++j) {
-                    nodecom.message_data(msgs + idxp[j].index()*msg_size, msg_size);
+                if (indices) {
+                    const T *idxp = indices + i*tot_weight + start_weight;
+                    for (uint32_t j=0; j<weight; ++j) {
+                        nodecom.message_data(msgs + idxp[j].index()*msg_size,
+                            msg_size);
+                    }
+                } else {
+                    size_t idx = i*tot_weight + start_weight;
+                    nodecom.message_data(msgs + idx*msg_size, weight*msg_size);
                 }
             }
-            const T *idxp = indices + full_rows*tot_weight + start_weight;
-            for (uint32_t j=0; j<num_msgs_last_row; ++j) {
-                nodecom.message_data(msgs + idxp[j].index()*msg_size, msg_size);
+            if (indices) {
+                const T *idxp = indices + full_rows*tot_weight + start_weight;
+                for (uint32_t j=0; j<num_msgs_last_row; ++j) {
+                    nodecom.message_data(msgs + idxp[j].index()*msg_size,
+                        msg_size);
+                }
+            } else {
+                size_t idx = full_rows*tot_weight + start_weight;
+                nodecom.message_data(msgs + idx*msg_size,
+                    num_msgs_last_row*msg_size);
             }
         }
     }
@@ -829,7 +802,8 @@ static void send_round1a_msgs(const uint8_t *msgs, const UidPriorityKey *indices
 
 // Send the round 1b messages from the round 1a buffer, which only occurs in public-channel routing.
 // msgs points to the message buffer, and N is the number of non-padding items.
-static void send_round1b_msgs(const uint8_t *msgs, uint32_t N) {
+// Return the number of messages sent
+static uint32_t send_round1b_msgs(const uint8_t *msgs, uint32_t N) {
     uint16_t msg_size = g_teems_config.msg_size;
     nodenum_t my_node_num = g_teems_config.my_node_num;
     nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
@@ -841,28 +815,10 @@ static void send_round1b_msgs(const uint8_t *msgs, uint32_t N) {
         NodeCommState &nodecom = g_commstates[prev_node];
         uint32_t num_msgs = min(N, route_state.max_round1b_msgs_to_adj_rtr);
         nodecom.message_start(num_msgs * msg_size);
-        for (uint32_t i=0; i<num_msgs; ++i) {
-            nodecom.message_data(msgs + i*msg_size, msg_size);
-        }
-    }
-    // send to next node
-    if (prev_nodes < num_routing_nodes-1) {
-        nodenum_t next_node = g_teems_config.routing_nodes[1];
-        NodeCommState &nodecom = g_commstates[next_node];
-        if (N <= route_state.max_round1a_msgs - route_state.max_round1b_msgs_to_adj_rtr) {
-            // No messages to exchange with next node
-            nodecom.message_start(0);
-            // No need to call message_data()
-        } else {
-            uint32_t start_msg =
-                route_state.max_round1a_msgs - route_state.max_round1b_msgs_to_adj_rtr;
-            uint32_t num_msgs = N - start_msg;
-            nodecom.message_start(num_msgs * msg_size);
-            for (uint32_t i=0; i<num_msgs; ++i) {
-                nodecom.message_data(msgs + i*msg_size, msg_size);
-            }
-        }
+        nodecom.message_data(msgs, num_msgs * msg_size);
+        return num_msgs;
     }
+    return 0;
 }
 
 // Send the round 2 messages from the previous-round buffer, which are already
@@ -889,7 +845,8 @@ static void send_round2_msgs(uint32_t tot_msgs, uint32_t num_msgs_per_stg, MsgBu
             uint32_t start = round2.reserved;
             if (start + num_msgs_per_stg > round2.bufsize) {
                 pthread_mutex_unlock(&round2.mutex);
-                printf("Max %u messages exceeded\n", round2.bufsize);
+                printf("Max %u messages exceeded (send_round2_msgs)\n",
+                    round2.bufsize);
                 return;
             }
             round2.reserved += num_msgs_per_stg;
@@ -1008,26 +965,43 @@ static void round1b_processing(void *cbpointer) {
         uint32_t inserted = round1a.inserted;
         unsigned long start_round1b =
             printf_with_rtclock("begin round1b processing (%u)\n", inserted);
-        // Sort the messages we've received
 #endif
+        // Sort the messages we've received
         // Sort received messages by increasing user ID and
         // priority. Larger priority number indicates higher priority.
         if (inserted > 0) {
             // copy items in sorted order into round1a_sorted
 #ifdef PROFILE_ROUTING
-        unsigned long start_sort =
-            printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted,
-                route_state.max_round1a_msgs);
+            unsigned long start_sort =
+                printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted,
+                    route_state.max_round1a_msgs);
 #endif
+            route_state.round1a_sorted.reserved = round1a.inserted;
             sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1a.buf,
                 g_teems_config.msg_size, round1a.inserted,
                 route_state.max_round1a_msgs,
                 route_state.round1a_sorted.buf);
+            route_state.round1a_sorted.inserted = round1a.inserted;
 #ifdef PROFILE_ROUTING
         printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n",
             inserted, route_state.max_round1a_msgs);
 #endif
-            send_round1b_msgs(round1a_sorted.buf, round1a.inserted);
+#ifdef TRACE_ROUTING
+        show_messages("In round 1b (sorted)", round1a_sorted.buf,
+            round1a_sorted.inserted);
+#endif
+            uint32_t num_sent = send_round1b_msgs(round1a_sorted.buf,
+                round1a_sorted.inserted);
+            // Remove the sent messages from the buffer
+            memmove(round1a_sorted.buf,
+                round1a_sorted.buf + num_sent * g_teems_config.msg_size,
+                (round1a_sorted.inserted - num_sent) * g_teems_config.msg_size);
+            round1a_sorted.inserted -= num_sent;
+            round1a_sorted.reserved -= num_sent;
+#ifdef TRACE_ROUTING
+        show_messages("In round 1b (after sending initial block)",
+            round1a_sorted.buf, round1a_sorted.inserted);
+#endif
         } else {
             send_round1b_msgs(NULL, 0);
         }
@@ -1039,23 +1013,18 @@ static void round1b_processing(void *cbpointer) {
 
         pthread_mutex_unlock(&round1a_sorted.mutex);
         pthread_mutex_unlock(&round1a.mutex);
-        MsgBuffer &round1b_prev = route_state.round1b_prev;
-        pthread_mutex_lock(&round1b_prev.mutex);
-        round1b_prev.completed_prev_round = true;
-        nodenum_t nodes_received = round1b_prev.nodes_received;
-        pthread_mutex_unlock(&round1b_prev.mutex);
         MsgBuffer &round1b_next = route_state.round1b_next;
         pthread_mutex_lock(&round1b_next.mutex);
         round1b_next.completed_prev_round = true;
-        nodes_received += round1b_next.nodes_received;
+        nodenum_t nodes_received = round1b_next.nodes_received;
         pthread_mutex_unlock(&round1b_next.mutex);
         nodenum_t adjacent_nodes;
         if (num_routing_nodes == 1) {
             adjacent_nodes = 0;
-        } else if ((prev_nodes == 0) || (prev_nodes == num_routing_nodes-1)) {
-            adjacent_nodes = 1;
+        } else if (prev_nodes == num_routing_nodes-1) {
+            adjacent_nodes = 0;
         } else {
-            adjacent_nodes = 2;
+            adjacent_nodes = 1;
         }
         if (nodes_received == adjacent_nodes) {
             route_state.step = ROUTE_ROUND_1B;
@@ -1064,7 +1033,6 @@ static void round1b_processing(void *cbpointer) {
         }
     } else {
         route_state.step = ROUTE_ROUND_1B;
-        route_state.round1b_prev.completed_prev_round = true;
         route_state.round1b_next.completed_prev_round = true;
         ocall_routing_round_complete(cbpointer, ROUND_1B);
     }
@@ -1092,19 +1060,13 @@ static void round1c_processing(void *cbpointer) {
     uint32_t max_round1a_msgs = route_state.max_round1a_msgs;
     MsgBuffer &round1a = route_state.round1a;
     MsgBuffer &round1a_sorted = route_state.round1a_sorted;
-    MsgBuffer &round1b_prev = route_state.round1b_prev;
     MsgBuffer &round1b_next = route_state.round1b_next;
 
     if (my_roles & ROLE_ROUTING) {
         route_state.cbpointer = cbpointer;
-        pthread_mutex_lock(&round1b_prev.mutex);
         pthread_mutex_lock(&round1b_next.mutex);
         // Ensure there are no pending messages currently being inserted
-        // into the round 1b buffers
-        while (round1b_prev.reserved != round1b_prev.inserted) {
-            pthread_mutex_unlock(&round1b_prev.mutex);
-            pthread_mutex_lock(&round1b_prev.mutex);
-        }
+        // into the round 1b buffer
         while (round1b_next.reserved != round1b_next.inserted) {
             pthread_mutex_unlock(&round1b_next.mutex);
             pthread_mutex_lock(&round1b_next.mutex);
@@ -1113,108 +1075,80 @@ static void round1c_processing(void *cbpointer) {
         pthread_mutex_lock(&round1a_sorted.mutex);
 
 #ifdef TRACE_ROUTING
-        show_messages("Start of round 1c", round1a.buf, round1a.inserted);
+        show_messages("Start of round 1c", round1a_sorted.buf,
+            round1a_sorted.inserted);
 #endif
 #ifdef PROFILE_ROUTING
         unsigned long start_round1c = printf_with_rtclock("begin round1c processing (%u)\n", round1a.inserted);
 #endif
 
-        // sort round1b_prev msgs with initial msgs in round1a_sorted
-        if (prev_nodes > 0) {
-            // Copy initial msgs in round1a_sorted to round1b_prev buffer for sorting
-            // Note that all inserted values and buffer sizes are non-secret
-            uint32_t num_init_round1a = min(round1a.inserted,
-                max_round1b_msgs_to_adj_rtr);
-            uint32_t num_round1b_prev = round1b_prev.inserted;
-            if (num_round1b_prev + num_init_round1a  <= max_round1b_msgs_to_adj_rtr) {
-                // all our round 1a messages "belong" to previous router and can be removed here
-                round1a.inserted = 0;
-            } else {
-                // copy initial round1a msgs after round1b_prev msgs
-                memmove(round1b_prev.buf+num_round1b_prev*msg_size, round1a_sorted.buf,
-                    num_init_round1a*msg_size);
-                // sort and take final msgs as initial round1a msgs
-#ifdef PROFILE_ROUTING
-                unsigned long start_sort = printf_with_rtclock("begin round1b_prev oblivious sort (%u,%u)\n", num_round1b_prev + num_init_round1a, 2*max_round1b_msgs_to_adj_rtr);
-#endif
-                uint32_t num_copy = num_round1b_prev+num_init_round1a-max_round1b_msgs_to_adj_rtr;
-                sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1b_prev.buf,
-                    msg_size, num_round1b_prev + num_init_round1a,
-                    2*max_round1b_msgs_to_adj_rtr,
-                    [&](const uint8_t *src, const UidPriorityKey *indices, uint32_t Nr) {
-                        return copy_msgs(round1a_sorted.buf, max_round1b_msgs_to_adj_rtr,
-                            num_copy, src, indices);
-                    }
-                );
-                round1a.inserted -= (max_round1b_msgs_to_adj_rtr-num_round1b_prev);
-#ifdef PROFILE_ROUTING
-                printf_with_rtclock_diff(start_sort, "end round1b_prev oblivious sort (%u,%u)\n", num_round1b_prev + num_init_round1a, 2*max_round1b_msgs_to_adj_rtr);
-#endif
-            }
-        }
         // sort round1b_next msgs with final msgs in round1a_sorted
         if ((prev_nodes < num_routing_nodes-1) && (round1b_next.inserted > 0)) {
             // Copy final msgs in round1a_sorted to round1b_next buffer for sorting
             // Note that all inserted values and buffer sizes are non-secret
-            // round1b_next.inserted>0, so round1a >= max_round1a_msgs-max_round1b_msgs_to_adj_rtr
-            uint32_t round1a_msg_start = max_round1a_msgs-max_round1b_msgs_to_adj_rtr;
-            uint32_t num_final_round1a = round1a.inserted - round1a_msg_start;
+            uint32_t num_final_round1a = std::min(
+                round1a_sorted.inserted, max_round1b_msgs_to_adj_rtr);
+            uint32_t round1a_msg_start =
+                round1a_sorted.inserted-num_final_round1a;
             uint32_t num_round1b_next = round1b_next.inserted;
+            round1b_next.reserved += num_final_round1a;
             memmove(round1b_next.buf+num_round1b_next*msg_size,
                 round1a_sorted.buf + round1a_msg_start*msg_size,
                 num_final_round1a*msg_size);
-            // sort and take initial msgs as final round1a msgs
+            round1b_next.inserted += num_final_round1a;
+            round1a_sorted.reserved -= num_final_round1a;
+            round1a_sorted.inserted -= num_final_round1a;
+#ifdef TRACE_ROUTING
+            show_messages("In round 1c (after setting aside)",
+                round1a_sorted.buf, round1a_sorted.inserted);
+            show_messages("In round 1c (the aside)", round1b_next.buf,
+                round1b_next.inserted);
+#endif
+            // sort and append to round1a msgs
+            uint32_t num_copy = round1b_next.inserted;
 #ifdef PROFILE_ROUTING
-            unsigned long start_sort = printf_with_rtclock("begin round1b_next oblivious sort (%u,%u)\n", num_round1b_next + num_final_round1a, 2*max_round1b_msgs_to_adj_rtr);
+            unsigned long start_sort =
+                printf_with_rtclock("begin round1b_next oblivious sort (%u,%u)\n",
+                    num_copy, 2*max_round1b_msgs_to_adj_rtr);
 #endif
-            uint32_t num_copy = min(num_final_round1a+num_round1b_next,
-                max_round1b_msgs_to_adj_rtr);
-            sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1b_next.buf,
-                msg_size, num_round1b_next + num_final_round1a, 2*max_round1b_msgs_to_adj_rtr,
+            round1a_sorted.reserved += num_copy;
+            sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads,
+                round1b_next.buf, msg_size, num_copy,
+                2*max_round1b_msgs_to_adj_rtr,
                 [&](const uint8_t *src, const UidPriorityKey *indices, uint32_t Nr) {
-                    return copy_msgs(round1a_sorted.buf + round1a_msg_start*msg_size, 0,
-                        num_copy, src, indices);
+                    return copy_msgs(round1a_sorted.buf +
+                        round1a_sorted.inserted * msg_size, 0, num_copy,
+                        src, indices);
                 }
             );
-            round1a.inserted += (num_copy - num_final_round1a);
+            round1a_sorted.inserted += num_copy;
 #ifdef PROFILE_ROUTING
-            printf_with_rtclock_diff(start_sort, "end round1b_next oblivious sort (%u,%u)\n", num_round1b_next + num_final_round1a, 2*max_round1b_msgs_to_adj_rtr);
+            printf_with_rtclock_diff(start_sort,
+                "end round1b_next oblivious sort (%u,%u)\n",
+                num_copy, 2*max_round1b_msgs_to_adj_rtr);
 #endif
         }
 #ifdef TRACE_ROUTING
-        printf("round1a_sorted.inserted = %lu. round1a.inserted = %lu\n",
-            round1a_sorted.inserted, round1a.inserted);
+        printf("round1a_sorted.inserted = %lu, reserved = %lu, bufsize = %lu\n",
+            round1a_sorted.inserted, round1a_sorted.reserved,
+            round1a_sorted.bufsize);
         show_messages("In round 1c", round1a_sorted.buf,
-            round1a.inserted);
-#endif
-#ifdef PROFILE_ROUTING
-        unsigned long start_sort = printf_with_rtclock("begin full oblivious sort (%u,%u)\n", round1a.inserted, route_state.max_round1a_msgs);
+            round1a_sorted.inserted);
 #endif
-        // Sort received messages by increasing user ID and
-        // priority. Larger priority number indicates higher priority.
-        if (round1a.inserted > 0) {
-            sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1a_sorted.buf,
-                msg_size, round1a.inserted, route_state.max_round1a_msgs,
-                [&](const uint8_t *msgs, const UidPriorityKey *indices, uint32_t N) {
-                    send_round_robin_msgs<UidPriorityKey>(route_state.round1c, msgs, indices, N);
-                });
-        } else {
-            send_round_robin_msgs<UidPriorityKey>(route_state.round1c, NULL, NULL, 0);
-        }
+        send_round_robin_msgs<UidPriorityKey>(route_state.round1c,
+            round1a_sorted.buf, NULL, round1a_sorted.inserted);
 
 #ifdef PROFILE_ROUTING
-        printf_with_rtclock_diff(start_sort, "end full oblivious sort (%u,%u)\n", round1a.inserted, route_state.max_round1a_msgs);
-        printf_with_rtclock_diff(start_round1c, "end round1c processing (%u)\n", round1a.inserted);
+        printf_with_rtclock_diff(start_round1c, "end round1c processing (%u)\n",
+            round1a.inserted);
 #endif
 
         round1a.reset();
         round1a_sorted.reset();
-        round1b_prev.reset();
         round1b_next.reset();
         pthread_mutex_unlock(&round1a_sorted.mutex);
         pthread_mutex_unlock(&round1a.mutex);
         pthread_mutex_unlock(&round1b_next.mutex);
-        pthread_mutex_unlock(&round1b_prev.mutex);
 
         MsgBuffer &round1c = route_state.round1c;
         pthread_mutex_lock(&round1c.mutex);
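
One thing the hunks above add is a nullptr-indices path in send_round_robin_msgs: when the buffer is already in the desired order (the merged round1a_sorted buffer in round 1c), each destination's share is copied as a contiguous slice of every row instead of being gathered one message at a time through a sorted index array. A rough, self-contained illustration of that slicing; slice_for_router and its full_rows / last-row arithmetic are guesses at the surrounding bookkeeping (which lies outside the hunks shown), not the enclave's code:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <utility>
    #include <vector>

    // Messages are laid out in rows of tot_weight slots; a router whose slots
    // start at start_weight and span `weight` columns takes the contiguous
    // slice [start_weight, start_weight + weight) of every full row, plus
    // whatever of its columns exists in the final partial row.
    static std::vector<uint8_t> slice_for_router(const uint8_t *msgs, uint32_t N,
        uint32_t msg_size, uint32_t tot_weight, uint32_t start_weight,
        uint32_t weight)
    {
        uint32_t full_rows = N / tot_weight;
        uint32_t last_row = N - full_rows * tot_weight;
        uint32_t num_msgs_last_row = 0;
        if (last_row > start_weight) {
            num_msgs_last_row = std::min(last_row - start_weight, weight);
        }
        std::vector<uint8_t> out;
        out.reserve(((size_t)full_rows * weight + num_msgs_last_row) * msg_size);
        for (uint32_t i = 0; i < full_rows; ++i) {
            size_t idx = (size_t)i * tot_weight + start_weight;
            out.insert(out.end(), msgs + idx * msg_size,
                msgs + (idx + weight) * msg_size);
        }
        if (num_msgs_last_row > 0) {
            size_t idx = (size_t)full_rows * tot_weight + start_weight;
            out.insert(out.end(), msgs + idx * msg_size,
                msgs + (idx + num_msgs_last_row) * msg_size);
        }
        return out;
    }

    int main()
    {
        // Ten one-byte "messages" 0..9 split between a weight-2 router at
        // column 0 and a weight-1 router at column 2 (tot_weight = 3).
        uint8_t msgs[10];
        for (uint8_t i = 0; i < 10; ++i) msgs[i] = i;
        const std::pair<uint32_t, uint32_t> routers[] = {{0, 2}, {2, 1}};
        for (auto [start_weight, weight] : routers) {
            std::vector<uint8_t> s = slice_for_router(msgs, 10, 1, 3,
                start_weight, weight);
            printf("router at column %u gets:", (unsigned)start_weight);
            for (uint8_t b : s) printf(" %u", (unsigned)b);
            printf("\n");
        }
        return 0;
    }

Skipping the index array on this path is what lets round 1c pass the already-sorted round1a_sorted buffer directly, avoiding the full oblivious sort the old code ran before send_round_robin_msgs.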

+ 0 - 1
Enclave/route.hpp

@@ -81,7 +81,6 @@ struct RouteState {
     MsgBuffer round1;
     MsgBuffer round1a;
     MsgBuffer round1a_sorted;
-    MsgBuffer round1b_prev;
     MsgBuffer round1b_next;
     MsgBuffer round1c;
     MsgBuffer round2;

+ 3 - 2
gen_enclave_config.py

@@ -43,7 +43,8 @@ def get_heap_size(N, M, T, B, PRIVATE_ROUTE=True, PRO=1, PRI=1, PUO=1, PUI=1, nu
 
     # 2 Buffers of clients_per_server items of B size each, plus 1 of
     # size (clients_per_server + M) items, for private routing
-    heap_size += (clients_per_server * B * 2) * num_out_mult
+    heap_size += (clients_per_server * B ) * num_out_mult
+    heap_size += ((clients_per_server + (M-1)**2) * B ) * num_out_mult
     heap_size += ((clients_per_server + M) * B) * num_in_mult
     # Additional buffers for public routing
     wn_size = max(clients_per_server, 2*(M-1)**2)
@@ -51,7 +52,7 @@ def get_heap_size(N, M, T, B, PRIVATE_ROUTE=True, PRO=1, PRI=1, PUO=1, PUI=1, nu
     colsort_size = int((wn_size + M - 1) / M) * M
 
     if not PRIVATE_ROUTE:
-        heap_size += (colsort_size * B * 3) + (2 * (M-1)**2 * B * 2)
+        heap_size += (colsort_size * B * 3) + (2 * (M-1)**2 * B)
 
     # num_WN_to_precompute times size of each WN
     heap_size += (num_WN_to_precompute * num_out_mult * \

+ 1 - 1
run_experiments.py

@@ -81,7 +81,7 @@ def run_exp(LOG_FOLDER, PRIVATE_ROUTE, NUM_EPOCHS, N, M, T, B, PRIV_OUT, PRIV_IN
                     if(PRIVATE_ROUTE):
                         num_WN_to_precompute = 2 * 3
                     else:
-                        num_WN_to_precompute = 2 * 8
+                        num_WN_to_precompute = 2 * 6
 
                     # Make the correct output folder for diagnostic/experiment
                     experiment_name = str(n) + "_" + str(m) + "_" + str(t) + "_" + str(b) + "/"