Ver código fonte

Convert global padding to per-storage-server padding in round 2

Ian Goldberg 10 meses atrás
pai
commit
a70720716d
3 arquivos alterados com 56 adições e 95 exclusões
  1. 27 72
      Enclave/obliv.cpp
  2. 5 11
      Enclave/obliv.hpp
  3. 24 12
      Enclave/route.cpp

+ 27 - 72
Enclave/obliv.cpp

@@ -38,24 +38,31 @@ std::vector<uint32_t> obliv_tally_stg(const uint8_t *buf,
     return tally;
 }
 
-// Obliviously create padding messages destined for the various storage
-// nodes, using the (private) counts in the tally vector.  The tally
-// vector may be modified by this function.  tot_padding must be the sum
-// of the elements in tally, which need _not_ be private.
-void obliv_pad_stg(uint8_t *buf, uint32_t msg_size,
-    std::vector<uint32_t> &tally, uint32_t tot_padding)
+// Obliviously convert global padding (receiver id 0xffffffff) into
+// padding for each storage node according to the (private) padding
+// tally.
+void obliv_stg_padding(uint8_t *buf, uint32_t msg_size,
+    std::vector<uint32_t> &tally, uint32_t num_msgs)
 {
     // A value with 0 in the top DEST_STORAGE_NODE_BITS and all 1s in
     // the bottom DEST_UID_BITS.
-    uint32_t pad_user = (1<<DEST_UID_BITS)-1;
+    const uint32_t pad_user = (1<<DEST_UID_BITS)-1;
 
     // This value is not oblivious
-    const uint32_t num_storage_nodes = uint32_t(tally.size());
+    const nodenum_t num_storage_nodes = nodenum_t(tally.size());
 
-    // This part must all be oblivious except for the length checks on
-    // tot_padding and num_storage_nodes
-    while (tot_padding) {
-        bool found = false;
+    uint8_t *cur_msg = buf;
+    // For each message, obliviously turn global padding into padding
+    // for some storage node whose tally shows it still needs more
+    // padding.
+    for (uint32_t m=0; m<num_msgs; ++m) {
+        uint32_t receiver_id = *(uint32_t*)cur_msg;
+        bool is_padding = (receiver_id == 0xffffffff);
+
+        // Obliviously find a storage node that still needs more
+        // padding, if is_padding is true.  If is_padding is false, this
+        // whole block is a no-op.
+        bool found = !is_padding;
         uint32_t found_node = 0;
         for (uint32_t i=0; i<num_storage_nodes; ++i) {
             bool found_here = (!found) & (!!tally[i]);
@@ -63,66 +70,14 @@ void obliv_pad_stg(uint8_t *buf, uint32_t msg_size,
             found = found | found_here;
             tally[i] -= found_here;
         }
-        *(uint32_t*)buf = ((found_node<<DEST_UID_BITS) | pad_user);
 
-        buf += msg_size;
-        --tot_padding;
-    }
-}
+        // If this was padding, overwrite the receiver id with the
+        // padding id specific to the found storage node; otherwise
+        // write the original receiver id back.
+        receiver_id = oselect_uint32_t(receiver_id,
+            ((found_node<<DEST_UID_BITS) | pad_user), is_padding);
+        *(uint32_t*)cur_msg = receiver_id;
 
-// For each excess message, convert into padding for nodes that will need some.
-// Oblivious to contents of message buffer and tally vector. May modify message
-// buffer and tally vector.
-void obliv_excess_to_padding(uint8_t *buf, uint32_t msg_size, uint32_t num_msgs,
-    std::vector<uint32_t> &tally, uint32_t msgs_per_stg)
-{
-    const nodenum_t num_storage_nodes = nodenum_t(tally.size());
-
-    // Determine the number of messages exceeding and under the maximum that
-    // can be sent to a storage server. Oblivious to the contents of tally
-    // vector.
-    std::vector<uint32_t> excess(num_storage_nodes, 0);
-    std::vector<uint32_t> padding(num_storage_nodes, 0);
-    for (nodenum_t i=0; i<num_storage_nodes; ++i) {
-        bool exceeds = tally[i] > msgs_per_stg;
-        uint32_t diff = tally[i] - msgs_per_stg;
-        excess[i] = oselect_uint32_t(0, diff, exceeds);
-        diff = msgs_per_stg - tally[i];
-        padding[i] = oselect_uint32_t(0, diff, !exceeds);
+        cur_msg += msg_size;
     }
-
-    uint8_t *cur_msg = buf + ((num_msgs-1)*msg_size);
-    uint32_t pad_user = (1<<DEST_UID_BITS)-1;
-    for (uint32_t i=0; i<num_msgs; ++i) {
-        // Determine if storage node for current node has excess messages.
-        // Also, decrement excess count and tally if so.
-        uint32_t storage_node_id = (*(const uint32_t*)cur_msg) >> DEST_UID_BITS;
-        bool node_excess = false;
-        for (uint32_t j=0; j<num_storage_nodes; ++j) {
-            bool at_msg_node = (storage_node_id == j);
-            bool cur_node_excess = (excess[j] > 0);
-            node_excess = oselect_uint32_t(node_excess, cur_node_excess,
-                at_msg_node);
-            excess[j] -= (at_msg_node & cur_node_excess);
-            tally[j] -= (at_msg_node & cur_node_excess);
-        }
-        // Find first node that needs padding. Decrement padding count and
-        // increment tally for that if current-message node has excess messages.
-        bool found_padding = false;
-        nodenum_t found_padding_node = 0;
-        for (uint32_t j=0; j<num_storage_nodes; ++j) {
-            bool found_padding_here = (!found_padding) & (!!padding[j]);
-            found_padding_node = oselect_uint32_t(found_padding_node, j,
-                found_padding_here);
-            found_padding = found_padding | found_padding_here;
-            padding[j] -= (found_padding_here & node_excess);
-            tally[j] += (found_padding_here & node_excess);
-        }
-        // Convert to padding if excess
-        uint32_t pad = ((found_padding_node<<DEST_UID_BITS) | pad_user);
-        *(uint32_t*)cur_msg = oselect_uint32_t(*(uint32_t*)cur_msg, pad,
-            node_excess);
-        // Go to previous message for backwards iteration through messages
-        cur_msg -= msg_size;
-    }
-}
+}

+ 5 - 11
Enclave/obliv.hpp

@@ -20,16 +20,10 @@
 std::vector<uint32_t> obliv_tally_stg(const uint8_t *buf,
     uint32_t msg_size, uint32_t num_msgs, uint32_t num_storage_nodes);
 
-// Obliviously create padding messages destined for the various storage
-// nodes, using the (private) counts in the tally vector.  The tally
-// vector may be modified by this function.  tot_padding must be the sum
-// of the elements in tally, which need _not_ be private.
-void obliv_pad_stg(uint8_t *buf, uint32_t msg_size,
-    std::vector<uint32_t> &tally, uint32_t tot_padding);
+// Obliviously convert global padding (receiver id 0xffffffff) into
+// padding for each storage node according to the (private) padding
+// tally.
+void obliv_stg_padding(uint8_t *buf, uint32_t msg_size,
+    std::vector<uint32_t> &tally, uint32_t num_msgs);
 
-// For each excess message, convert into padding for nodes that will need some.
-// Oblivious to contents of message buffer and tally vector. May modify message
-// buffer and tally vector.
-void obliv_excess_to_padding(uint8_t *buf, uint32_t msg_size, uint32_t num_msgs,
-    std::vector<uint32_t> &tally, uint32_t msgs_per_stg);
 #endif

+ 24 - 12
Enclave/route.cpp

@@ -1331,32 +1331,44 @@ static void round2_processing(uint8_t my_roles, void *cbpointer, MsgBuffer &prev
         // While we're at it, obliviously change the tally of
         // messages received to a tally of padding messages
         // required.
-        uint32_t tot_padding = 0;
+        uint32_t tot_messages = msgs_per_stg * num_storage_nodes;
         for (nodenum_t i=0; i<num_storage_nodes; ++i) {
             if (tally[i] > msgs_per_stg) {
-                printf("Received too many messages for storage node %u\n", i);
+                printf("Received too many messages for storage node "
+                    "%u (%u > %u)\n", i, tally[i], msgs_per_stg);
                 assert(tally[i] <= msgs_per_stg);
             }
             tally[i] = msgs_per_stg - tally[i];
-            tot_padding += tally[i];
         }
 
-        prevround.reserved += tot_padding;
-        assert(prevround.reserved <= prevround.bufsize);
+        // Allocate extra padding messages (not yet destined for a
+        // particular storage node)
+        assert(prevround.inserted <= tot_messages &&
+            tot_messages <= prevround.bufsize);
+        prevround.reserved = tot_messages;
+        for (uint32_t i=prevround.inserted; i<tot_messages; ++i) {
+            uint8_t *header = prevround.buf + i*msg_size;
+            *(uint32_t*)header = 0xffffffff;
+        }
 
-        // Obliviously add padding for each storage node according
-        // to the (private) padding tally.
+        // Obliviously convert global padding (receiver id 0xffffffff)
+        // into padding for each storage node according to the (private)
+        // padding tally.
 #ifdef PROFILE_ROUTING
-        unsigned long start_pad = printf_with_rtclock("begin pad (%u)\n", tot_padding);
+        unsigned long start_pad =
+            printf_with_rtclock("begin pad (%u)\n", tot_messages);
 #endif
-        obliv_pad_stg(prevround.buf + prevround.inserted * msg_size,
-            msg_size, tally, tot_padding);
+        obliv_stg_padding(prevround.buf, msg_size, tally, tot_messages);
 #ifdef PROFILE_ROUTING
-        printf_with_rtclock_diff(start_pad, "end pad (%u)\n", tot_padding);
+        printf_with_rtclock_diff(start_pad, "end pad (%u)\n",
+            tot_messages);
 #endif
 
-        prevround.inserted += tot_padding;
+        prevround.inserted = tot_messages;
 
+#ifdef TRACE_ROUTING
+        show_messages("In round 2 after padding", prevround.buf, prevround.inserted);
+#endif
         // Obliviously shuffle the messages
 #ifdef PROFILE_ROUTING
         unsigned long start_shuffle = printf_with_rtclock("begin shuffle (%u,%u)\n", prevround.inserted, prevround.bufsize);