Procházet zdrojové kódy

fixing errors in public routing

Aaron Johnson před 1 rokem
rodič
revize
f8337931f1
3 změnil soubory, kde provedl 57 přidání a 32 odebrání
  1. 54 32
      Enclave/route.cpp
  2. 1 0
      Enclave/route.hpp
  3. 2 0
      gen_enclave_config.py

+ 54 - 32
Enclave/route.cpp

@@ -106,9 +106,12 @@ bool route_init()
         }
         if (my_roles & ROLE_ROUTING) {
             route_state.round1.alloc(max_round2_msgs);
-            route_state.round1a.alloc(max_round1a_msgs);
-            route_state.round1b.alloc(2*max_round1b_msgs); // double space for sorting with 1a msgs
-            route_state.round1c.alloc(max_round1c_msgs);
+            if (!g_teems_config.private_routing) {
+                route_state.round1a.alloc(max_round1a_msgs);
+                route_state.round1a_sorted.alloc(max_round1a_msgs);
+                route_state.round1b.alloc(2*max_round1b_msgs); // double to sort with 1a msgs
+                route_state.round1c.alloc(max_round1c_msgs);
+            }
         }
         if (my_roles & ROLE_STORAGE) {
             route_state.round2.alloc(max_stg_msgs);
@@ -125,6 +128,7 @@ bool route_init()
     route_state.max_round1_msgs = max_round1_msgs;
     route_state.max_round1a_msgs = max_round1a_msgs;
     route_state.max_round1b_msgs_to_adj_rtr = max_round1b_msgs_to_adj_rtr;
+    route_state.max_round1b_msgs = max_round1b_msgs;
     route_state.max_round1c_msgs = max_round1c_msgs;
     route_state.max_msg_to_each_stg = max_msg_to_each_stg;
     route_state.max_round2_msgs = max_round2_msgs;
@@ -433,6 +437,13 @@ static void round1c_received(NodeCommState &nodest, uint8_t *data,
                 tot_enc_chunk_size);
         };
         nodest.in_msg_received = round1_received;
+    } else {
+        nodest.in_msg_get_buf = [&](NodeCommState &commst,
+                uint32_t tot_enc_chunk_size) {
+            return msgbuffer_get_buf(route_state.round1a, commst,
+                tot_enc_chunk_size);
+        };
+        nodest.in_msg_received = round1a_received;
     }
     if (nodes_received == g_teems_config.num_routing_nodes &&
             completed_prev_round) {
@@ -707,9 +718,8 @@ static void send_round1a_msgs(const uint8_t *msgs, const UidPriorityKey *indices
 }
 
 // Send the round 1b messages from the round 1a buffer, which only occurs in public-channel routing.
-// msgs points to the message buffer, indices points to the the sorted indices, and N is the number
-// of non-padding items.
-static void send_round1b_msgs(const uint8_t *msgs, const UidPriorityKey *indices, uint32_t N) {
+// msgs points to the message buffer, and N is the number of non-padding items.
+static void send_round1b_msgs(const uint8_t *msgs, uint32_t N) {
     uint16_t msg_size = g_teems_config.msg_size;
     nodenum_t my_node_num = g_teems_config.my_node_num;
     nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
@@ -719,12 +729,10 @@ static void send_round1b_msgs(const uint8_t *msgs, const UidPriorityKey *indices
     if (prev_nodes > 0) {
         nodenum_t prev_node = g_teems_config.routing_nodes[num_routing_nodes-1];
         NodeCommState &nodecom = g_commstates[prev_node];
-        uint32_t num_msgs = min(route_state.max_round1a_msgs,
-            route_state.max_round1b_msgs_to_adj_rtr);
+        uint32_t num_msgs = min(N, route_state.max_round1b_msgs_to_adj_rtr);
         nodecom.message_start(num_msgs * msg_size);
         for (uint32_t i=0; i<num_msgs; ++i) {
-            const UidPriorityKey *idxp = indices + i;
-            nodecom.message_data(msgs + idxp->index()*msg_size, msg_size);
+            nodecom.message_data(msgs + i*msg_size, msg_size);
         }
     }
     // send to next node
@@ -741,8 +749,7 @@ static void send_round1b_msgs(const uint8_t *msgs, const UidPriorityKey *indices
             uint32_t num_msgs = N - start_msg;
             nodecom.message_start(num_msgs * msg_size);
             for (uint32_t i=0; i<num_msgs; ++i) {
-                const UidPriorityKey *idxp = indices + start_msg + i;
-                nodecom.message_data(msgs + idxp->index()*msg_size, msg_size);
+                nodecom.message_data(msgs + i*msg_size, msg_size);
             }
         }
     }
@@ -841,17 +848,16 @@ static void send_round1c_msgs(const uint8_t *msgs, const UidPriorityKey *indices
     }
 }
 
-// Send the round 2 messages from the round 1 buffer, which are already
+// Send the round 2 messages from the previous-round buffer, which are already
 // padded and shuffled, so this can be done non-obliviously.  tot_msgs
 // is the total number of messages in the input buffer, which may
 // include padding messages added by the shuffle.  Those messages are
 // not sent anywhere.  There are num_msgs_per_stg messages for each
 // storage node labelled for that node.
-static void send_round2_msgs(uint32_t tot_msgs, uint32_t num_msgs_per_stg)
+static void send_round2_msgs(uint32_t tot_msgs, uint32_t num_msgs_per_stg, MsgBuffer &prevround)
 {
     uint16_t msg_size = g_teems_config.msg_size;
-    MsgBuffer &round1 = route_state.round1;
-    const uint8_t* buf = round1.buf;
+    const uint8_t* buf = prevround.buf;
     nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
     nodenum_t my_node_num = g_teems_config.my_node_num;
     uint8_t *myself_buf = NULL;
@@ -957,6 +963,7 @@ static void round1b_processing(void *cbpointer) {
     uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
     nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
     MsgBuffer &round1a = route_state.round1a;
+    MsgBuffer &round1a_sorted = route_state.round1a_sorted;
 
     if (my_roles & ROLE_ROUTING) {
         route_state.cbpointer = cbpointer;
@@ -967,6 +974,7 @@ static void round1b_processing(void *cbpointer) {
             pthread_mutex_unlock(&round1a.mutex);
             pthread_mutex_lock(&round1a.mutex);
         }
+        pthread_mutex_lock(&round1a_sorted.mutex);
 
 #ifdef PROFILE_ROUTING
         uint32_t inserted = round1a.inserted;
@@ -976,16 +984,22 @@ static void round1b_processing(void *cbpointer) {
 #endif
         // Sort received messages by increasing user ID and
         // priority. Smaller priority number indicates higher priority.
-        sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1a.buf,
-            g_teems_config.msg_size, round1a.inserted, route_state.max_round1a_msgs,
-            send_round1b_msgs);
+        if (inserted > 0) {
+            // copy items in sorted order into round1a_sorted
+            sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1a.buf,
+                g_teems_config.msg_size, round1a.inserted, route_state.max_round1a_msgs,
+                route_state.round1a_sorted.buf);
+            send_round1b_msgs(round1a.buf, round1a.inserted);
+        } else {
+            send_round1b_msgs(NULL, 0);
+        }
 
 #ifdef PROFILE_ROUTING
         printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.max_round1a_msgs);
         printf_with_rtclock_diff(start_round1b, "end round1b processing (%u)\n", inserted);
 #endif
 
-        //round1a.reset(); // Don't reset until end of round 1c
+        pthread_mutex_unlock(&round1a_sorted.mutex);
         pthread_mutex_unlock(&round1a.mutex);
         MsgBuffer &round1b = route_state.round1b;
         pthread_mutex_lock(&round1b.mutex);
@@ -1012,6 +1026,7 @@ static void round1c_processing(void *cbpointer) {
     nodenum_t my_node_num = g_teems_config.my_node_num;
     nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
     MsgBuffer &round1a = route_state.round1a;
+    MsgBuffer &round1a_sorted = route_state.round1a_sorted;
     MsgBuffer &round1b = route_state.round1b;
 
     if (my_roles & ROLE_ROUTING) {
@@ -1023,7 +1038,8 @@ static void round1c_processing(void *cbpointer) {
             pthread_mutex_unlock(&round1b.mutex);
             pthread_mutex_lock(&round1b.mutex);
         }
-
+        pthread_mutex_lock(&round1a.mutex);
+        pthread_mutex_lock(&round1a_sorted.mutex);
 #ifdef PROFILE_ROUTING
         uint32_t inserted = round1a.inserted;
         unsigned long start_round1c = printf_with_rtclock("begin round1c processing (%u)\n", inserted);
@@ -1033,9 +1049,13 @@ static void round1c_processing(void *cbpointer) {
 
         // Sort received messages by increasing user ID and
         // priority. Smaller priority number indicates higher priority.
-        sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1a.buf,
-            g_teems_config.msg_size, round1a.inserted, route_state.max_round1a_msgs,
-            send_round1c_msgs);
+        if (inserted > 0) {
+            sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1a_sorted.buf,
+                g_teems_config.msg_size, round1a.inserted, route_state.max_round1a_msgs,
+                send_round1c_msgs);
+        } else {
+            send_round1c_msgs(NULL, NULL, 0);
+        }
 
 #ifdef PROFILE_ROUTING
         printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.max_round1a_msgs);
@@ -1043,9 +1063,10 @@ static void round1c_processing(void *cbpointer) {
 #endif
 
         round1a.reset();
-        pthread_mutex_unlock(&round1a.mutex);
-        pthread_mutex_lock(&round1b.mutex);
+        round1a_sorted.reset();
         round1b.reset();
+        pthread_mutex_unlock(&round1a_sorted.mutex);
+        pthread_mutex_unlock(&round1a.mutex);
         pthread_mutex_unlock(&round1b.mutex);
 
         MsgBuffer &round1c = route_state.round1c;
@@ -1093,16 +1114,15 @@ static void round2_processing(uint8_t my_roles, void *cbpointer, MsgBuffer &prev
         // Obliviously tally the number of messages we received in
         // the previous round destined for each storage node
 #ifdef PROFILE_ROUTING
-        uint32_t inserted = prevround.inserted;
-        unsigned long start_round2 = printf_with_rtclock("begin round2 processing (%u,%u)\n", inserted, prevround.bufsize);
-        unsigned long start_tally = printf_with_rtclock("begin tally (%u)\n", inserted);
+        unsigned long start_round2 = printf_with_rtclock("begin round2 processing (%u,%u)\n", prevround.inserted, prevround.bufsize);
+        unsigned long start_tally = printf_with_rtclock("begin tally (%u)\n", prevround.inserted);
 #endif
         uint16_t msg_size = g_teems_config.msg_size;
         nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
         std::vector<uint32_t> tally = obliv_tally_stg(
             prevround.buf, msg_size, prevround.inserted, num_storage_nodes);
 #ifdef PROFILE_ROUTING
-        printf_with_rtclock_diff(start_tally, "end tally (%u)\n", inserted);
+        printf_with_rtclock_diff(start_tally, "end tally (%u)\n", prevround.inserted);
 #endif
         // Note: tally contains private values!  It's OK to
         // non-obliviously check for an error condition, though.
@@ -1143,7 +1163,7 @@ static void round2_processing(uint8_t my_roles, void *cbpointer, MsgBuffer &prev
             prevround.buf, msg_size, prevround.inserted, prevround.bufsize);
 #ifdef PROFILE_ROUTING
         printf_with_rtclock_diff(start_shuffle, "end shuffle (%u,%u)\n", prevround.inserted, prevround.bufsize);
-        printf_with_rtclock_diff(start_round2, "end round2 processing (%u,%u)\n", inserted, prevround.bufsize);
+        printf_with_rtclock_diff(start_round2, "end round2 processing (%u,%u)\n", prevround.inserted, prevround.bufsize);
 #endif
 
         // Now we can handle the messages non-obliviously, since we
@@ -1151,7 +1171,7 @@ static void round2_processing(uint8_t my_roles, void *cbpointer, MsgBuffer &prev
         // storage node, and the oblivious shuffle broke the
         // connection between where each message came from and where
         // it's going.
-        send_round2_msgs(num_shuffled, msgs_per_stg);
+        send_round2_msgs(num_shuffled, msgs_per_stg, prevround);
 
         prevround.reset();
         pthread_mutex_unlock(&prevround.mutex);
@@ -1249,6 +1269,8 @@ void ecall_routing_proceed(void *cbpointer)
         round1b_processing(cbpointer);
     } else if (route_state.step == ROUTE_ROUND_1B) {
         round1c_processing(cbpointer);
+    } else if (route_state.step == ROUTE_ROUND_1C) {
+        round2_processing(my_roles, cbpointer, route_state.round1c);
     } else if (route_state.step == ROUTE_ROUND_2) {
         if (my_roles & ROLE_STORAGE) {
             MsgBuffer &round2 = route_state.round2;

+ 1 - 0
Enclave/route.hpp

@@ -80,6 +80,7 @@ struct RouteState {
     MsgBuffer ingbuf;
     MsgBuffer round1;
     MsgBuffer round1a;
+    MsgBuffer round1a_sorted;
     MsgBuffer round1b;
     MsgBuffer round1c;
     MsgBuffer round2;

+ 2 - 0
gen_enclave_config.py

@@ -30,6 +30,8 @@ def generate_config(N, M, T, B):
     heap_size += clients_per_server * (B + 60)
     # 4 Buffers of clients_per_server items of B size each
     heap_size += (clients_per_server * B * 5)
+    # Public routing uses 4 additional Buffers
+    heap_size += (clients_per_server * B * 3) + ((M-1)**2 * B * 4)
     # 3 x WN
     heap_size += (5 * T * (clients_per_server * math.ceil(math.log(clients_per_server,2)) * 8))
     # 60 MB