Browse Source

Prepare round 2 messages

Ian Goldberg 1 year ago
parent
commit
c3207004ce
6 changed files with 177 additions and 11 deletions
  1. 2 0
      Enclave/config.cpp
  2. 3 0
      Enclave/config.hpp
  3. 71 0
      Enclave/obliv.cpp
  4. 30 0
      Enclave/obliv.hpp
  5. 64 10
      Enclave/route.cpp
  6. 7 1
      Makefile

+ 2 - 0
Enclave/config.cpp

@@ -33,6 +33,7 @@ bool ecall_config_load(threadid_t nthreads, bool private_routing,
     g_teems_config.ingestion_nodes.clear();
     g_teems_config.routing_nodes.clear();
     g_teems_config.storage_nodes.clear();
+    g_teems_config.storage_map.clear();
     for (nodenum_t i=0; i<num_nodes; ++i) {
         NodeWeight nw;
         nw.startweight = cumul_weight;
@@ -62,6 +63,7 @@ bool ecall_config_load(threadid_t nthreads, bool private_routing,
             } else {
                 g_teems_config.storage_nodes.push_back(i);
             }
+            g_teems_config.storage_map.push_back(i);
         }
         cumul_weight += nw.weight;
         g_teems_config.weights.push_back(nw);

+ 3 - 0
Enclave/config.hpp

@@ -35,6 +35,9 @@ struct Config {
     std::vector<nodenum_t> ingestion_nodes;
     std::vector<nodenum_t> routing_nodes;
     std::vector<nodenum_t> storage_nodes;
+    // storage_map[i] is the node number of the storage node responsible
+    // for the destination adddresses with storage node field i.
+    std::vector<nodenum_t> storage_map;
 };
 
 extern Config g_teems_config;

+ 71 - 0
Enclave/obliv.cpp

@@ -0,0 +1,71 @@
+#include "oasm_lib.h"
+#include "enclave_api.h"
+#include "obliv.hpp"
+
+// Routines for processing private data obliviously
+
+// Obliviously tally the number of messages in the given buffer destined
+// for each storage node.  Each message is of size msg_size bytes.
+// There are num_msgs messages in the buffer.  There are
+// num_storage_nodes storage nodes in total.  The destination storage
+// node of each message is determined by looking at the top
+// DEST_STORAGE_NODE_BITS bits of the (little-endian) 32-bit word at the
+// beginning of the message; this will be a number between 0 and
+// num_storage_nodes-1, which is not necessarily the node number of the
+// storage node, which may be larger if, for example, there are a bunch
+// of routing or ingestion nodes that are not also storage nodes.  The
+// return value is a vector of length num_storage_nodes containing the
+// tally.
+std::vector<uint32_t> obliv_tally_stg(const uint8_t *buf,
+    uint32_t msg_size, uint32_t num_msgs, uint32_t num_storage_nodes)
+{
+    // The _contents_ of buf are private, but everything else in the
+    // input is public.  The contents of the output tally (but not its
+    // length) are also private.
+    std::vector<uint32_t> tally(num_storage_nodes, 0);
+
+    // This part must all be oblivious except for the length checks on
+    // num_msgs and num_storage_nodes
+    while (num_msgs) {
+        uint32_t storage_node_id = (*(const uint32_t*)buf) >> DEST_UID_BITS;
+        for (uint32_t i=0; i<num_storage_nodes; ++i) {
+            tally[i] += (storage_node_id == i);
+        }
+        buf += msg_size;
+        --num_msgs;
+    }
+
+    return tally;
+}
+
+// Obliviously create padding messages destined for the various storage
+// nodes, using the (private) counts in the tally vector.  The tally
+// vector may be modified by this function.  tot_padding must be the sum
+// of the elements in tally, which need _not_ be private.
+void obliv_pad_stg(uint8_t *buf, uint32_t msg_size,
+    std::vector<uint32_t> &tally, uint32_t tot_padding)
+{
+    // A value with 0 in the top DEST_STORAGE_NODE_BITS and all 1s in
+    // the bottom DEST_UID_BITS.
+    uint32_t pad_user = (1<<DEST_UID_BITS)-1;
+
+    // This value is not oblivious
+    const uint32_t num_storage_nodes = uint32_t(tally.size());
+
+    // This part must all be oblivious except for the length checks on
+    // tot_padding and num_storage_nodes
+    while (tot_padding) {
+        bool found = false;
+        uint32_t found_node = 0;
+        for (uint32_t i=0; i<num_storage_nodes; ++i) {
+            bool found_here = (!found) & (!!tally[i]);
+            found_node = oselect_uint32_t(found_node, i, found_here);
+            found = found | found_here;
+            tally[i] -= found_here;
+        }
+        *(uint32_t*)buf = ((found_node<<DEST_UID_BITS) | pad_user);
+
+        buf += msg_size;
+        --tot_padding;
+    }
+}

+ 30 - 0
Enclave/obliv.hpp

@@ -0,0 +1,30 @@
+#ifndef __OBLIV_HPP__
+#define __OBLIV_HPP__
+
+#include <vector>
+
+// Routines for processing private data obliviously
+
+// Obliviously tally the number of messages in the given buffer destined
+// for each storage node.  Each message is of size msg_size bytes.
+// There are num_msgs messages in the buffer.  There are
+// num_storage_nodes storage nodes in total.  The destination storage
+// node of each message is determined by looking at the top
+// DEST_STORAGE_NODE_BITS bits of the (little-endian) 32-bit word at the
+// beginning of the message; this will be a number between 0 and
+// num_storage_nodes-1, which is not necessarily the node number of the
+// storage node, which may be larger if, for example, there are a bunch
+// of routing or ingestion nodes that are not also storage nodes.  The
+// return value is a vector of length num_storage_nodes containing the
+// tally.
+std::vector<uint32_t> obliv_tally_stg(const uint8_t *buf,
+    uint32_t msg_size, uint32_t num_msgs, uint32_t num_storage_nodes);
+
+// Obliviously create padding messages destined for the various storage
+// nodes, using the (private) counts in the tally vector.  The tally
+// vector may be modified by this function.  tot_padding must be the sum
+// of the elements in tally, which need _not_ be private.
+void obliv_pad_stg(uint8_t *buf, uint32_t msg_size,
+    std::vector<uint32_t> &tally, uint32_t tot_padding);
+
+#endif

+ 64 - 10
Enclave/route.cpp

@@ -4,6 +4,7 @@
 #include "utils.hpp"
 #include "sort.hpp"
 #include "comms.hpp"
+#include "obliv.hpp"
 #include "route.hpp"
 
 #define PROFILE_ROUTING
@@ -75,7 +76,7 @@ static struct RouteState {
     MsgBuffer round2;
     RouteStep step;
     uint32_t tot_msg_per_ing;
-    uint32_t max_msg_to_each_str;
+    uint32_t max_msg_to_each_stg;
     uint32_t max_round2_msgs;
     void *cbpointer;
 } route_state;
@@ -112,25 +113,25 @@ bool route_init()
     // Compute the maximum number of messages we could send in round 2
 
     // Each storage node has at most this many users
-    uint32_t users_per_str = CEILDIV(g_teems_config.user_count,
+    uint32_t users_per_stg = CEILDIV(g_teems_config.user_count,
         g_teems_config.num_storage_nodes);
 
     // And so can receive at most this many messages
-    uint32_t tot_msg_per_str = users_per_str *
+    uint32_t tot_msg_per_stg = users_per_stg *
         g_teems_config.m_priv_in;
 
     // Which will be at most this many from us
-    uint32_t max_msg_to_each_str = CEILDIV(tot_msg_per_str,
+    uint32_t max_msg_to_each_stg = CEILDIV(tot_msg_per_stg,
         g_teems_config.tot_weight) * g_teems_config.my_weight;
 
     // But we can't send more messages to each storage server than we
     // could receive in total
-    if (max_msg_to_each_str > max_round1_msgs) {
-        max_msg_to_each_str = max_round1_msgs;
+    if (max_msg_to_each_stg > max_round1_msgs) {
+        max_msg_to_each_stg = max_round1_msgs;
     }
 
     // And the max total number of outgoing messages in round 2 is then
-    uint32_t max_round2_msgs = max_msg_to_each_str *
+    uint32_t max_round2_msgs = max_msg_to_each_stg *
         g_teems_config.num_storage_nodes;
 
     // In case we have a weird configuration where users can send more
@@ -155,7 +156,7 @@ bool route_init()
             route_state.round1.alloc(max_round2_msgs);
         }
         if (my_roles & ROLE_STORAGE) {
-            route_state.round2.alloc(tot_msg_per_str +
+            route_state.round2.alloc(tot_msg_per_stg +
                 g_teems_config.tot_weight);
         }
     } catch (std::bad_alloc&) {
@@ -164,7 +165,7 @@ bool route_init()
     }
     route_state.step = ROUTE_NOT_STARTED;
     route_state.tot_msg_per_ing = tot_msg_per_ing;
-    route_state.max_msg_to_each_str = max_msg_to_each_str;
+    route_state.max_msg_to_each_stg = max_msg_to_each_stg;
     route_state.max_round2_msgs = max_round2_msgs;
     route_state.cbpointer = NULL;
 
@@ -489,8 +490,61 @@ void ecall_routing_proceed(void *cbpointer)
                 pthread_mutex_lock(&round1.mutex);
             }
 
+            // If the _total_ number of messages we received in round 1
+            // is less than the max number of messages we could send to
+            // _each_ storage node, then cap the number of messages we
+            // will send to each storage node to that number.
+            uint32_t msgs_per_stg = route_state.max_msg_to_each_stg;
+            if (round1.inserted < msgs_per_stg) {
+                msgs_per_stg = round1.inserted;
+            }
+
+            // Note: at this point, it is required that each message in
+            // the round1 buffer have a _valid_ storage node id field.
+
+            // Obliviously tally the number of messages we received in
+            // round1 destined for each storage node
             uint32_t msg_size = g_teems_config.msg_size;
-            for(uint32_t i=0;i<round1.inserted;++i) {
+            nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
+            std::vector<uint32_t> tally = obliv_tally_stg(
+                round1.buf, msg_size, round1.inserted, num_storage_nodes);
+
+            // Note: tally contains private values!  It's OK to
+            // non-obliviously check for an error condition, though.
+            // While we're at it, obliviously change the tally of
+            // messages received to a tally of padding messages
+            // required.
+            uint32_t tot_padding = 0;
+            for (nodenum_t i=0; i<num_storage_nodes; ++i) {
+                if (tally[i] > msgs_per_stg) {
+                    printf("Received too many messages for storage node %u\n", i);
+                    assert(tally[i] <= msgs_per_stg);
+                }
+                tally[i] = msgs_per_stg - tally[i];
+                tot_padding += tally[i];
+            }
+
+            round1.reserved += tot_padding;
+            assert(round1.reserved <= round1.bufsize);
+
+            // Obliviously add padding for each storage node according
+            // to the (private) padding tally.
+            obliv_pad_stg(round1.buf + round1.inserted * msg_size,
+                msg_size, tally, tot_padding);
+
+            round1.inserted += tot_padding;
+
+            // Obliviously shuffle the messages
+            uint32_t num_shuffled = shuffle_mtobliv(g_teems_config.nthreads,
+                round1.buf, msg_size, round1.inserted, round1.bufsize);
+
+            // Now we can handle the messages non-obliviously, since we
+            // know there will be exactly msgs_per_stg messages to each
+            // storage node, and the oblivious shuffle broke the
+            // connection between where each message came from and where
+            // it's going.
+
+            for(uint32_t i=0;i<num_shuffled;++i) {
                 uint32_t destaddr = *(uint32_t*)(round1.buf+i*msg_size);
                 printf("%08x\n", destaddr);
             }

+ 7 - 1
Makefile

@@ -256,6 +256,11 @@ Enclave/%.o: Enclave/%.cpp
 	@echo "CXX  <=  $<"
 	@$(CXX) $(SGX_COMMON_CXXFLAGS) $(Enclave_Cpp_Flags) -c $< -o $@
 
+Enclave/asm/%.s: Enclave/%.cpp
+	@echo "CXXASM  <=  $<"
+	@mkdir -p $$(dirname $@)
+	@$(CXX) $(SGX_COMMON_CXXFLAGS) $(Enclave_Cpp_Flags) -S $< -o $@
+
 $(Enclave_Cpp_Objects): Enclave/Enclave_t.h
 
 $(Enclave_Name): Enclave/Enclave_t.o $(Enclave_Cpp_Objects)
@@ -297,9 +302,10 @@ Enclave/comms.o: Enclave/Enclave_t.h Enclave/enclave_api.h Enclave/config.hpp
 Enclave/comms.o: Enclave/enclave_api.h Enclave/route.hpp Enclave/comms.hpp
 Enclave/config.o: Enclave/Enclave_t.h Enclave/enclave_api.h Enclave/comms.hpp
 Enclave/config.o: Enclave/enclave_api.h Enclave/config.hpp Enclave/route.hpp
+Enclave/obliv.o: Enclave/enclave_api.h Enclave/obliv.hpp
 Enclave/route.o: Enclave/Enclave_t.h Enclave/enclave_api.h Enclave/config.hpp
 Enclave/route.o: Enclave/enclave_api.h Enclave/sort.hpp Enclave/comms.hpp
-Enclave/route.o: Enclave/route.hpp
+Enclave/route.o: Enclave/obliv.hpp Enclave/route.hpp
 Enclave/sort.o: Enclave/sort.hpp
 Enclave/OblivAlgs/RecursiveShuffle.o: Enclave/OblivAlgs/oasm_lib.h
 Enclave/OblivAlgs/RecursiveShuffle.o: Enclave/OblivAlgs/CONFIG.h