Просмотр исходного кода

Oblivious part of storage processing for private routing

Ian Goldberg 1 год назад
Родитель
Сommit
f6640a173e
6 измененных файлов с 150 добавлено и 15 удалено
  1. 1 1
      App/start.cpp
  2. 18 7
      Enclave/route.cpp
  3. 1 0
      Enclave/sort.tcc
  4. 119 4
      Enclave/storage.cpp
  5. 1 1
      Enclave/storage.hpp
  6. 10 2
      Makefile

+ 1 - 1
App/start.cpp

@@ -76,7 +76,7 @@ static void epoch(NetIO &netio, char **args) {
                     // Use a token from node j
                     *((uint32_t*)nextmsg) =
                         (j << DEST_UID_BITS) +
-                            (((r<<8)+(my_node_num&0xff)) & dest_uid_mask);
+                            ((rem_tokens-1) & dest_uid_mask);
                     // Put a bunch of copies of r as the message body
                     for (uint16_t i=1;i<msg_size/4;++i) {
                         ((uint32_t*)nextmsg)[i] = r;

+ 18 - 7
Enclave/route.cpp

@@ -30,6 +30,7 @@ static struct RouteState {
     uint32_t tot_msg_per_ing;
     uint32_t max_msg_to_each_stg;
     uint32_t max_round2_msgs;
+    uint32_t max_stg_msgs;
     void *cbpointer;
 } route_state;
 
@@ -93,9 +94,11 @@ bool route_init()
         max_round2_msgs = max_round1_msgs;
     }
 
+    // The max number of messages that can arrive at a storage server
+    uint32_t max_stg_msgs = tot_msg_per_stg + g_teems_config.tot_weight;
+
     /*
-    printf("round1_msgs = %u, round2_msgs = %u\n",
-        max_round1_msgs, max_round2_msgs);
+    printf("users_per_ing=%u, tot_msg_per_ing=%u, max_msg_from_each_ing=%u, max_round1_msgs=%u, users_per_stg=%u, tot_msg_per_stg=%u, max_msg_to_each_stg=%u, max_round2_msgs=%u, max_stg_msgs=%u\n", users_per_ing, tot_msg_per_ing, max_msg_from_each_ing, max_round1_msgs, users_per_stg, tot_msg_per_stg, max_msg_to_each_stg, max_round2_msgs, max_stg_msgs);
     */
 
     // Create the route state
@@ -108,10 +111,8 @@ bool route_init()
             route_state.round1.alloc(max_round2_msgs);
         }
         if (my_roles & ROLE_STORAGE) {
-            route_state.round2.alloc(tot_msg_per_stg +
-                g_teems_config.tot_weight);
-            if (!storage_init(tot_msg_per_stg +
-                g_teems_config.tot_weight)) {
+            route_state.round2.alloc(max_stg_msgs);
+            if (!storage_init(users_per_stg, max_stg_msgs)) {
                 return false;
             }
         }
@@ -123,6 +124,7 @@ bool route_init()
     route_state.tot_msg_per_ing = tot_msg_per_ing;
     route_state.max_msg_to_each_stg = max_msg_to_each_stg;
     route_state.max_round2_msgs = max_round2_msgs;
+    route_state.max_stg_msgs = max_stg_msgs;
     route_state.cbpointer = NULL;
 
     threadid_t nthreads = g_teems_config.nthreads;
@@ -163,10 +165,19 @@ size_t ecall_precompute_sort(int sizeidx)
         ret = sort_precompute(route_state.max_round2_msgs);
 #ifdef PROFILE_ROUTING
     printf_with_rtclock_diff(start, "end precompute Waksman Network (%u)\n", route_state.max_round2_msgs);}
+#endif
+        break;
+    case 2:
+#ifdef PROFILE_ROUTING
+    {unsigned long start = printf_with_rtclock("begin precompute WaksmanNetwork (%u)\n", route_state.max_stg_msgs);
+#endif
+        ret = sort_precompute(route_state.max_stg_msgs);
+#ifdef PROFILE_ROUTING
+    printf_with_rtclock_diff(start, "end precompute Waksman Network (%u)\n", route_state.max_stg_msgs);}
 #endif
         break;
     default:
-        ret = 2;
+        ret = 3;
         break;
     }
 

+ 1 - 0
Enclave/sort.tcc

@@ -130,6 +130,7 @@ static void *move_sorted(void *voidargs)
             items + (sorted_keys[i].index()) * msg_size,
             msg_size);
     }
+    return NULL;
 }
 
 // As above, but also pass an Nr*msg_size-byte buffer outbuf to put

+ 119 - 4
Enclave/storage.cpp

@@ -1,21 +1,40 @@
 #include "utils.hpp"
 #include "config.hpp"
-#include "storage.hpp"
 #include "ORExpand.hpp"
+#include "sort.hpp"
+#include "storage.hpp"
+
+#define PROFILE_STORAGE
 
 static struct {
+    uint32_t max_users;
+    uint32_t my_storage_node_id;
     // A local storage buffer, used when we need to do non-in-place
     // sorts of the messages that have arrived
     MsgBuffer stg_buf;
+    // The destination vector for ORExpand
+    std::vector<uint32_t> dest;
 } storage_state;
 
 // route_init will call this function; no one else should call it
 // explicitly.  The parameter is the number of messages that can fit in
 // the storage-side MsgBuffer.  Returns true on success, false on
 // failure.
-bool storage_init(uint32_t msg_buf_size)
+bool storage_init(uint32_t max_users, uint32_t msg_buf_size)
 {
+    storage_state.max_users = max_users;
     storage_state.stg_buf.alloc(msg_buf_size);
+    storage_state.dest.resize(msg_buf_size);
+    uint32_t my_storage_node_id = 0;
+    for (nodenum_t i=0; i<g_teems_config.num_nodes; ++i) {
+        if (g_teems_config.roles[i] & ROLE_STORAGE) {
+            if (i == g_teems_config.my_node_num) {
+                storage_state.my_storage_node_id = my_storage_node_id << DEST_UID_BITS;
+            } else {
+                ++my_storage_node_id;
+            }
+        }
+    }
     return true;
 }
 
@@ -24,15 +43,33 @@ bool storage_init(uint32_t msg_buf_size)
 // done with it.
 void storage_received(MsgBuffer &storage_buf)
 {
-    // A dummy function for now that just counts how many real and
-    // padding messages arrived
     uint16_t msg_size = g_teems_config.msg_size;
     nodenum_t my_node_num = g_teems_config.my_node_num;
     const uint8_t *msgs = storage_buf.buf;
     uint32_t num_msgs = storage_buf.inserted;
     uint32_t real = 0, padding = 0;
     uint32_t uid_mask = (1 << DEST_UID_BITS) - 1;
+    uint32_t nid_mask = ~uid_mask;
 
+#ifdef PROFILE_STORAGE
+    unsigned long start_received = printf_with_rtclock("begin storage_received (%u)\n", storage_buf.inserted);
+#endif
+
+    // It's OK to test for errors in a way that's non-oblivous if
+    // there's an error (but it should be oblivious if there are no
+    // errors)
+    for (uint32_t i=0; i<num_msgs; ++i) {
+        uint32_t uid = *(const uint32_t*)(storage_buf.buf+(i*msg_size));
+        bool ok = ((((uid & nid_mask) == storage_state.my_storage_node_id)
+            & ((uid & uid_mask) < storage_state.max_users))
+            | ((uid & uid_mask) == uid_mask));
+        if (!ok) {
+            printf("Received bad uid: %08x\n", uid);
+            assert(ok);
+        }
+    }
+
+    // Testing: report how many real and dummy messages arrived
     printf("Storage server received %u messages:\n", num_msgs);
     for (uint32_t i=0; i<num_msgs; ++i) {
         uint32_t dest_addr = *(const uint32_t*)msgs;
@@ -53,6 +90,84 @@ void storage_received(MsgBuffer &storage_buf)
     }
     printf("%u real, %u padding\n", real, padding);
 
+    for (uint32_t i=0;i<num_msgs; ++i) {
+        printf("%3d: %08x %08x\n", i,
+        *(uint32_t*)(storage_buf.buf+(i*msg_size)),
+        *(uint32_t*)(storage_buf.buf+(i*msg_size+4)));
+    }
+    // Sort the received messages by userid into the
+    // storage_state.stg_buf MsgBuffer.
+#ifdef PROFILE_STORAGE
+    unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u)\n", storage_buf.inserted);
+#endif
+    sort_mtobliv<UidKey>(g_teems_config.nthreads, storage_buf.buf,
+        msg_size, storage_buf.inserted, storage_buf.bufsize,
+        storage_state.stg_buf.buf);
+#ifdef PROFILE_STORAGE
+    printf_with_rtclock_diff(start_sort, "end oblivious sort (%u)\n", storage_buf.inserted);
+#endif
+
+    for (uint32_t i=0;i<num_msgs; ++i) {
+        printf("%3d: %08x %08x\n", i,
+        *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size)),
+        *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size+4)));
+    }
+
+#ifdef PROFILE_STORAGE
+    unsigned long start_dest = printf_with_rtclock("begin setting dests (%u)\n", storage_state.stg_buf.bufsize);
+#endif
+    // Obliviously set the dest array
+    uint32_t *dests = storage_state.dest.data();
+    uint32_t stg_size = storage_state.stg_buf.bufsize;
+    const uint8_t *buf = storage_state.stg_buf.buf;
+    uint32_t m_priv_in = g_teems_config.m_priv_in;
+
+    uint32_t uid = *(uint32_t*)(buf);
+    // num_msgs is not a private value
+    if (num_msgs > 0) {
+        uid &= uid_mask;
+        dests[0] = oselect_uint32_t(uid * m_priv_in, 0xffffffff,
+            uid == uid_mask);
+    }
+    uint32_t prev_uid = uid;
+    for (uint32_t i=1; i<num_msgs; ++i) {
+        uid = *(uint32_t*)(buf + i*msg_size);
+        uid &= uid_mask;
+        dests[i] = oselect_uint32_t(
+            oselect_uint32_t(uid * m_priv_in, dests[i-1]+1, uid==prev_uid),
+            0xffffffff, uid == uid_mask);
+        prev_uid = uid;
+    }
+    for (uint32_t i=num_msgs; i<stg_size; ++i) {
+        dests[i] = 0xffffffff;
+        *(uint32_t*)(buf + i*msg_size) = 0xffffffff;
+    }
+#ifdef PROFILE_STORAGE
+    printf_with_rtclock_diff(start_dest, "end setting dests (%u)\n", stg_size);
+#endif
+#ifdef PROFILE_STORAGE
+    unsigned long start_expand = printf_with_rtclock("begin ORExpand (%u)\n", stg_size);
+#endif
+    ORExpand_parallel<OSWAP_16X>(storage_state.stg_buf.buf, dests,
+        msg_size, stg_size, g_teems_config.nthreads);
+#ifdef PROFILE_STORAGE
+    printf_with_rtclock_diff(start_expand, "end ORExpand (%u)\n", stg_size);
+#endif
+    for (uint32_t i=0;i<stg_size; ++i) {
+        printf("%3d: %08x %08x\n", i,
+        *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size)),
+        *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size+4)));
+    }
+
+    // You can do more processing after these lines, as long as they
+    // don't touch storage_buf.  They _can_ touch the backing buffer
+    // storage_state.stg_buf.
     storage_buf.reset();
     pthread_mutex_unlock(&storage_buf.mutex);
+
+    storage_state.stg_buf.reset();
+
+#ifdef PROFILE_STORAGE
+    printf_with_rtclock_diff(start_received, "end storage_received (%u)\n", storage_buf.inserted);
+#endif
 }

+ 1 - 1
Enclave/storage.hpp

@@ -9,7 +9,7 @@
 // explicitly.  The parameter is the number of messages that can fit in
 // the storage-side MsgBuffer.  Returns true on success, false on
 // failure.
-bool storage_init(uint32_t msg_buf_size);
+bool storage_init(uint32_t max_users, uint32_t msg_buf_size);
 
 // Handle the messages received by a storage node.  Pass a _locked_
 // MsgBuffer.  This function will itself reset and unlock it when it's

+ 10 - 2
Makefile

@@ -339,8 +339,16 @@ Enclave/storage.o: Enclave/enclave_api.h Enclave/OblivAlgs/CONFIG.h
 Enclave/storage.o: Enclave/OblivAlgs/oasm_lib.h
 Enclave/storage.o: Enclave/OblivAlgs/oasm_lib.tcc Enclave/OblivAlgs/foav.h
 Enclave/storage.o: Enclave/config.hpp Enclave/enclave_api.h
-Enclave/storage.o: Enclave/storage.hpp Enclave/OblivAlgs/ORExpand.hpp
-Enclave/storage.o: Enclave/OblivAlgs/ORExpand.tcc
+Enclave/storage.o: Enclave/OblivAlgs/ORExpand.hpp
+Enclave/storage.o: Enclave/OblivAlgs/ORExpand.tcc Enclave/sort.hpp
+Enclave/storage.o: Enclave/OblivAlgs/WaksmanNetwork.hpp
+Enclave/storage.o: Enclave/OblivAlgs/RecursiveShuffle.hpp
+Enclave/storage.o: Enclave/OblivAlgs/TightCompaction_v2.hpp
+Enclave/storage.o: Enclave/OblivAlgs/TightCompaction_v2.tcc
+Enclave/storage.o: Enclave/OblivAlgs/RecursiveShuffle.tcc
+Enclave/storage.o: Enclave/OblivAlgs/aes.hpp
+Enclave/storage.o: Enclave/OblivAlgs/WaksmanNetwork.tcc Enclave/sort.tcc
+Enclave/storage.o: Enclave/storage.hpp Enclave/route.hpp
 Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/ORExpand.hpp
 Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/utils.hpp Enclave/Enclave_t.h
 Enclave/OblivAlgs/ORExpand.o: Enclave/enclave_api.h