Переглянути джерело

adding public routing to storage servers

Aaron Johnson 1 рік тому
батько
коміт
573d434674
4 змінених файлів з 63 додано та 58 видалено
  1. 22 38
      Enclave/obliv.cpp
  2. 4 12
      Enclave/obliv.hpp
  3. 9 8
      Enclave/route.cpp
  4. 28 0
      Enclave/storage.cpp

+ 22 - 38
Enclave/obliv.cpp

@@ -70,60 +70,44 @@ void obliv_pad_stg(uint8_t *buf, uint32_t msg_size,
     }
 }
 
-// Determine the number of messages exceeding the maximum that can be sent to a 
-// storage server. Oblivious to the contents of tally vector.
-std::vector<uint32_t> obliv_excess_stg(std::vector<uint32_t> &tally,
-    nodenum_t num_storage_nodes, uint32_t msgs_per_stg)
+// For each excess message, convert into padding for nodes that will need some.
+// Oblivious to contents of message buffer and tally vector. May modify message
+// buffer and tally vector.
+void obliv_excess_to_padding(uint8_t *buf, uint32_t msg_size, uint32_t num_msgs,
+    std::vector<uint32_t> &tally, uint32_t msgs_per_stg)
 {
+    const nodenum_t num_storage_nodes = nodenum_t(tally.size());
+
+    // Determine the number of messages exceeding and under the maximum that
+    // can be sent to a storage server. Oblivious to the contents of tally
+    // vector.
     std::vector<uint32_t> excess(num_storage_nodes, 0);
+    std::vector<uint32_t> padding(num_storage_nodes, 0);
     for (nodenum_t i=0; i<num_storage_nodes; ++i) {
         bool exceeds = tally[i] > msgs_per_stg;
-        uint32_t diff = tally[i] - msgs_per_stg; // nonsensical if !exceeds
+        uint32_t diff = tally[i] - msgs_per_stg;
         excess[i] = oselect_uint32_t(0, diff, exceeds);
+        diff = msgs_per_stg - tally[i];
+        padding[i] = oselect_uint32_t(0, diff, !exceeds);
     }
 
-    return excess;
-}
-
-// Determine the number of messages under the maximum that can be sent to a 
-// storage server. Oblivious to the contents of tally vector.
-std::vector<uint32_t> obliv_padding_stg(std::vector<uint32_t> &tally,
-    nodenum_t num_storage_nodes, uint32_t msgs_per_stg)
-{
-    std::vector<uint32_t> padding(num_storage_nodes, 0);
-    for (nodenum_t i=0; i<num_storage_nodes; ++i) {
-        bool under = tally[i] < msgs_per_stg;
-        uint32_t diff = msgs_per_stg - tally[i]; // nonsensical if !under
-        padding[i] = oselect_uint32_t(0, diff, under);
-    }
-
-    return padding;
-}
-
-// For each excess messages, convert into padding for nodes that will need some.
-// Oblivious to contents of excess, padding, and tally vectors. May modify
-// excess, padding, and tally vectors.
-void obliv_excess_to_padding(uint8_t *buf, uint32_t msg_size, uint32_t num_msgs,
-    std::vector<uint32_t> &excess, std::vector<uint32_t> &padding,
-    std::vector<uint32_t> &tally, nodenum_t num_storage_nodes)
-{
     uint8_t *cur_msg = buf + ((num_msgs-1)*msg_size);
     uint32_t pad_user = (1<<DEST_UID_BITS)-1;
     for (uint32_t i=0; i<num_msgs; ++i) {
         // Determine if storage node for current node has excess messages.
         // Also, decrement excess count and tally if so.
         uint32_t storage_node_id = (*(const uint32_t*)cur_msg) >> DEST_UID_BITS;
-        bool stg_node_excess = false;
+        bool node_excess = false;
         for (uint32_t j=0; j<num_storage_nodes; ++j) {
             bool at_msg_node = (storage_node_id == j);
             bool cur_node_excess = (excess[j] > 0);
-            stg_node_excess = oselect_uint32_t(stg_node_excess,
-                cur_node_excess, at_msg_node);
+            node_excess = oselect_uint32_t(node_excess, cur_node_excess,
+                at_msg_node);
             excess[j] -= (at_msg_node & cur_node_excess);
             tally[j] -= (at_msg_node & cur_node_excess);
         }
-        // Find first node that needs padding. Also, decrement padding count
-        // and increment tally if current node has excess messages.
+        // Find first node that needs padding. Decrement padding count and
+        // increment tally for that if current-message node has excess messages.
         bool found_padding = false;
         nodenum_t found_padding_node = 0;
         for (uint32_t j=0; j<num_storage_nodes; ++j) {
@@ -131,13 +115,13 @@ void obliv_excess_to_padding(uint8_t *buf, uint32_t msg_size, uint32_t num_msgs,
             found_padding_node = oselect_uint32_t(found_padding_node, j,
                 found_padding_here);
             found_padding = found_padding | found_padding_here;
-            padding[j] -= (found_padding_here & stg_node_excess);
-            tally[j] += (found_padding_here & stg_node_excess);
+            padding[j] -= (found_padding_here & node_excess);
+            tally[j] += (found_padding_here & node_excess);
         }
         // Convert to padding if excess
         uint32_t pad = ((found_padding_node<<DEST_UID_BITS) | pad_user);
         *(uint32_t*)cur_msg = oselect_uint32_t(*(uint32_t*)cur_msg, pad,
-            stg_node_excess);
+            node_excess);
         // Go to previous message for backwards iteration through messages
         cur_msg -= msg_size;
     }

+ 4 - 12
Enclave/obliv.hpp

@@ -27,17 +27,9 @@ std::vector<uint32_t> obliv_tally_stg(const uint8_t *buf,
 void obliv_pad_stg(uint8_t *buf, uint32_t msg_size,
     std::vector<uint32_t> &tally, uint32_t tot_padding);
 
-// Obliviously determine the number of messages exceeding the maximum that can
-// be sent to a storage server.
-std::vector<uint32_t> obliv_excess_stg(std::vector<uint32_t> &tally,
-    nodenum_t num_storage_nodes, uint32_t msgs_per_stg);
-
-// Obliviously determine the number of messages under the maximum that can
-// be sent to a storage server.
-std::vector<uint32_t> obliv_padding_stg(std::vector<uint32_t> &tally,
-    nodenum_t num_storage_nodes, uint32_t msgs_per_stg);
-
+// For each excess message, convert into padding for nodes that will need some.
+// Oblivious to contents of message buffer and tally vector. May modify message
+// buffer and tally vector.
 void obliv_excess_to_padding(uint8_t *buf, uint32_t msg_size, uint32_t num_msgs,
-    std::vector<uint32_t> &excess, std::vector<uint32_t> &padding,
-    std::vector<uint32_t> &tally, nodenum_t num_storage_nodes);
+    std::vector<uint32_t> &tally, uint32_t msgs_per_stg);
 #endif

+ 9 - 8
Enclave/route.cpp

@@ -621,14 +621,12 @@ void ecall_routing_proceed(void *cbpointer)
             printf_with_rtclock_diff(start_tally, "end tally (%u)\n", inserted);
 #endif
 
-            // For public routing, remove excess messages, making them padding
+            // For public routing, convert excess messages to padding destined
+            // for other storage nodes with fewer messages than the maximum.
             if (!g_teems_config.private_routing) {
-                // How many excess messages to remove per storage server
-                std::vector<uint32_t> excess = obliv_excess_stg(tally,
-                    num_storage_nodes, msgs_per_stg);
-                // How many padding messages to add per storage server
-                std::vector<uint32_t> padding = obliv_padding_stg(tally,
-                    num_storage_nodes, msgs_per_stg);
+#ifdef PROFILE_ROUTING
+                unsigned long start_convert_excess = printf_with_rtclock("begin converting excess messages (%u)\n", round1.inserted);
+#endif
                 // Sort received messages by increasing storage node and
                 // priority. Smaller priority number indicates higher priority.
                 // Sorted messages are put back into source buffer.
@@ -637,7 +635,10 @@ void ecall_routing_proceed(void *cbpointer)
                     round1.bufsize);
                 // Convert excess messages into padding
                 obliv_excess_to_padding(round1.buf, msg_size, round1.inserted,
-                    excess, padding, num_storage_nodes);
+                    tally, msgs_per_stg);
+#ifdef PROFILE_ROUTING
+                printf_with_rtclock_diff(start_convert_excess, "end converting excess messages (%u)\n", round1.inserted);
+#endif
             }
 
             // Note: tally contains private values!  It's OK to

+ 28 - 0
Enclave/storage.cpp

@@ -109,6 +109,34 @@ void storage_received(MsgBuffer &storage_buf)
     printf_with_rtclock_diff(start_sort, "end oblivious sort (%u)\n", storage_buf.inserted);
 #endif
 
+    // For public routing, remove excess per-user messages and compact
+    if (!g_teems_config.private_routing) {
+        bool *selected = new bool[num_msgs];
+        uint8_t *msg = storage_state.stg_buf.buf;
+        uint32_t uid;
+        uint32_t prev_uid = uid;
+        uint32_t num_user_msgs = 0; // number of messages seen to the user
+        uint8_t sel;
+        for (uint32_t i=0; i<num_msgs; ++i) {
+            uid = *(uint32_t*) msg;
+            uid &= uid_mask;
+            sel = ((uint8_t) ((num_user_msgs <= g_teems_config.m_pub_in))) &   
+                ((uint8_t) uid != uid_mask);
+            // Make padding if too many messages for user
+            *(uint32_t *) msg = (*(uint32_t *) msg) & nid_mask;
+            *(uint32_t *) msg += oselect_uint32_t(uid_mask, uid, sel);
+            // Mark as selected only if messages per user not exceeded
+            selected[i] = (bool) oselect_uint32_t(0, 1, sel);
+            prev_uid = uid;
+            num_user_msgs = oselect_uint32_t(1, num_user_msgs+1,
+                uid == prev_uid);
+            msg += msg_size;
+        }
+        TightCompact_parallel((unsigned char *) storage_state.stg_buf.buf,
+            num_msgs, msg_size, selected, g_teems_config.nthreads);
+        delete[] selected;
+    }
+
     /*
     for (uint32_t i=0;i<num_msgs; ++i) {
         printf("%3d: %08x %08x\n", i,