
Merging main

Sajin Sasy, 1 year ago
commit 76de7e1ba6
12 changed files with 541 additions and 119 deletions
  1. App/start.cpp (+2 -2)
  2. Enclave/OblivAlgs/WaksmanNetwork.tcc (+6 -16)
  3. Enclave/OblivAlgs/oasm_lib.h (+1 -1)
  4. Enclave/comms.cpp (+2 -2)
  5. Enclave/route.cpp (+58 -42)
  6. Enclave/route.hpp (+1 -0)
  7. Enclave/sort.cpp (+20 -36)
  8. Enclave/sort.hpp (+66 -5)
  9. Enclave/sort.tcc (+197 -0)
  10. Enclave/storage.cpp (+159 -5)
  11. Enclave/storage.hpp (+12 -2)
  12. Makefile (+17 -8)

+ 2 - 2
App/start.cpp

@@ -76,7 +76,7 @@ static void epoch(NetIO &netio, char **args) {
                     // Use a token from node j
                     *((uint32_t*)nextmsg) =
                         (j << DEST_UID_BITS) +
-                            (((r<<8)+(my_node_num&0xff)) & dest_uid_mask);
+                            ((rem_tokens-1) & dest_uid_mask);
                     // Put a bunch of copies of r as the message body
                     for (uint16_t i=1;i<msg_size/4;++i) {
                         ((uint32_t*)nextmsg)[i] = r;
@@ -206,7 +206,7 @@ static void route_test(NetIO &netio, char **args)
 
     // Precompute some WaksmanNetworks
     const Config &config = netio.config();
-    size_t num_sizes = ecall_precompute_sort(-1);
+    size_t num_sizes = ecall_precompute_sort(-2);
     for (int i=0;i<int(num_sizes);++i) {
         std::vector<boost::thread> ts;
         for (int j=0; j<config.nthreads; ++j) {
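
Note: route_test drives the WaksmanNetwork precomputation through the return-value protocol of ecall_precompute_sort (see the Enclave/route.cpp hunks below): any argument that is neither -1 nor a valid size index, such as the -2 used here, rebuilds the list of needed sizes from this node's roles and returns its length, and each index in [0, num_sizes-1] then precomputes one network of that size. A minimal sketch of that refill loop follows; the body of the worker threads is not shown in this hunk, so the per-thread call is an assumption.

    // Sketch only: assumes each worker thread calls
    // ecall_precompute_sort(i) to add one precomputed WaksmanNetwork
    // of size index i to the enclave's pool.
    #include <cstddef>
    #include <vector>
    #include <boost/thread.hpp>

    size_t ecall_precompute_sort(int sizeidx);  // enclave entry point

    static void precompute_waksman_pool(int nthreads)
    {
        // -2 is neither -1 nor a valid index: reseed the size list
        // from this node's roles and return how many sizes there are
        size_t num_sizes = ecall_precompute_sort(-2);
        for (int i = 0; i < int(num_sizes); ++i) {
            std::vector<boost::thread> ts;
            for (int j = 0; j < nthreads; ++j) {
                ts.emplace_back([i] { ecall_precompute_sort(i); });
            }
            for (auto &t : ts) {
                t.join();
            }
        }
    }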

+ 6 - 16
Enclave/OblivAlgs/WaksmanNetwork.tcc

@@ -1,17 +1,7 @@
 
 // #define PROFILE_MTMERGESORT
 
-template<typename T> static int compare(const void *a, const void *b);
-
-template<>
-int compare<uint64_t>(const void *a, const void *b)
-{
-    uint32_t *a32 = (uint32_t*)a;
-    uint32_t *b32 = (uint32_t*)b;
-    int hi = a32[1]-b32[1];
-    int lo = a32[0]-b32[0];
-    return oselect_uint32_t(hi, lo, !hi);
-}
+template<typename T> static int compare_keys(const void *a, const void *b);
 
 template<typename T>
 struct MergeArgs {
@@ -40,7 +30,7 @@ unsigned long start = printf_with_rtclock("begin merge(dst=%p, leftsrc=%p, Nleft
     const T* rightend = args->rightsrc + args->Nright;
 
     while (left != leftend && right != rightend) {
-        if (compare<T>(left, right) < 0) {
+        if (compare_keys<T>(left, right) < 0) {
             *dst = *left;
             ++dst;
             ++left;
@@ -79,10 +69,10 @@ static size_t binsearch(const T* src, size_t len, const T* target)
     if (len == 0) {
         return 0;
     }
-    if (compare<T>(src + left, target) > 0) {
+    if (compare_keys<T>(src + left, target) > 0) {
         return 0;
     }
-    if (len > 0 && compare<T>(src + right - 1, target) < 0) {
+    if (len > 0 && compare_keys<T>(src + right - 1, target) < 0) {
         return len;
     }
 
@@ -90,7 +80,7 @@ static size_t binsearch(const T* src, size_t len, const T* target)
     // src[len] is considered to be greater than all targets)
     while (right - left > 1) {
         size_t mid = left + (right - left)/2;
-        if (compare<T>(src + mid, target) > 0) {
+        if (compare_keys<T>(src + mid, target) > 0) {
             right = mid;
         } else {
             left = mid;
@@ -201,7 +191,7 @@ bool mtmergesort(T* buf, size_t N, T* backing, threadid_t nthreads)
 #ifdef PROFILE_MTMERGESORT
 unsigned long start = printf_with_rtclock("begin qsort(buf=%p, N=%lu)\n", buf, N);
 #endif
-        qsort(buf, N, sizeof(T), compare<T>);
+        qsort(buf, N, sizeof(T), compare_keys<T>);
 #ifdef PROFILE_MTMERGESORT
 printf_with_rtclock_diff(start, "end qsort(buf=%p, N=%lu)\n", buf, N);
 #endif

+ 1 - 1
Enclave/OblivAlgs/oasm_lib.h

@@ -171,7 +171,7 @@
         "mov %[value0], %[out]\n"
         "test %[flag], %[flag]\n"
         "cmovnz %[value1], %[out]\n"
-        : [out] "=r" (out)
+        : [out] "=&r" (out)
         : [value0] "r" (value_0), [value1] "r" (value_1), [flag] "r" (flag)
         : "cc"
     );
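
Note: the one-character constraint change above is the whole fix. Without the early-clobber marker '&', the compiler may allocate %[out] to the same register as %[value1] or %[flag]; since %[out] is written by the first mov before those inputs are read, that aliasing silently corrupts the result. A minimal sketch of the surrounding helper, with the wrapper signature inferred from how oselect_uint32_t is called elsewhere in this commit (treat the exact prototype as an assumption):

    #include <cstdint>

    // Branch-free select: returns value_1 if flag is nonzero, else
    // value_0, so secret-dependent choices never become branches.
    static inline uint32_t oselect_uint32_t(uint32_t value_0,
        uint32_t value_1, uint32_t flag)
    {
        uint32_t out;
        __asm__ (
            "mov %[value0], %[out]\n"
            "test %[flag], %[flag]\n"
            "cmovnz %[value1], %[out]\n"
            // "=&r" (early clobber): %[out] is written before %[value1]
            // and %[flag] are consumed, so it must not share a register
            // with any input operand.
            : [out] "=&r" (out)
            : [value0] "r" (value_0), [value1] "r" (value_1), [flag] "r" (flag)
            : "cc"
        );
        return out;
    }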

+ 2 - 2
Enclave/comms.cpp

@@ -467,8 +467,8 @@ void NodeCommState::message_data(const uint8_t *data, uint32_t len, bool encrypt
         }
         if (encrypt) {
             // Encrypt the data
-            sgx_aes_gcm128_enc_update((uint8_t*)data, bytes_to_process,
-                frame+frame_offset, out_aes_gcm_state);
+            sgx_aes_gcm128_enc_update(const_cast<uint8_t*>(data),
+                bytes_to_process, frame+frame_offset, out_aes_gcm_state);
         } else {
             // Just copy the plaintext data during the handshake
             memmove(frame+frame_offset, data, bytes_to_process);

+ 58 - 42
Enclave/route.cpp

@@ -71,9 +71,11 @@ bool route_init()
         max_round2_msgs = max_round1_msgs;
     }
 
+    // The max number of messages that can arrive at a storage server
+    uint32_t max_stg_msgs = tot_msg_per_stg + g_teems_config.tot_weight;
+
     /*
-    printf("round1_msgs = %u, round2_msgs = %u\n",
-        max_round1_msgs, max_round2_msgs);
+    printf("users_per_ing=%u, tot_msg_per_ing=%u, max_msg_from_each_ing=%u, max_round1_msgs=%u, users_per_stg=%u, tot_msg_per_stg=%u, max_msg_to_each_stg=%u, max_round2_msgs=%u, max_stg_msgs=%u\n", users_per_ing, tot_msg_per_ing, max_msg_from_each_ing, max_round1_msgs, users_per_stg, tot_msg_per_stg, max_msg_to_each_stg, max_round2_msgs, max_stg_msgs);
     */
 
     // Create the route state
@@ -86,8 +88,10 @@ bool route_init()
             route_state.round1.alloc(max_round2_msgs);
         }
         if (my_roles & ROLE_STORAGE) {
-            route_state.round2.alloc(tot_msg_per_stg +
-                g_teems_config.tot_weight);
+            route_state.round2.alloc(max_stg_msgs);
+            if (!storage_init(users_per_stg, max_stg_msgs)) {
+                return false;
+            }
         }
     } catch (std::bad_alloc&) {
         printf("Memory allocation failed in route_init\n");
@@ -97,14 +101,22 @@ bool route_init()
     route_state.tot_msg_per_ing = tot_msg_per_ing;
     route_state.max_msg_to_each_stg = max_msg_to_each_stg;
     route_state.max_round2_msgs = max_round2_msgs;
+    route_state.max_stg_msgs = max_stg_msgs;
     route_state.cbpointer = NULL;
 
     threadid_t nthreads = g_teems_config.nthreads;
 #ifdef PROFILE_ROUTING
     unsigned long start = printf_with_rtclock("begin precompute evalplans (%u,%hu) (%u,%hu)\n", tot_msg_per_ing, nthreads, max_round2_msgs, nthreads);
 #endif
-    sort_precompute_evalplan(tot_msg_per_ing, nthreads);
-    sort_precompute_evalplan(max_round2_msgs, nthreads);
+    if (my_roles & ROLE_INGESTION) {
+        sort_precompute_evalplan(tot_msg_per_ing, nthreads);
+    }
+    if (my_roles & ROLE_ROUTING) {
+        sort_precompute_evalplan(max_round2_msgs, nthreads);
+    }
+    if (my_roles & ROLE_STORAGE) {
+        sort_precompute_evalplan(max_stg_msgs, nthreads);
+    }
 #ifdef PROFILE_ROUTING
     printf_with_rtclock_diff(start, "end precompute evalplans\n");
 #endif
@@ -112,38 +124,45 @@ bool route_init()
 }
 
 // Precompute the WaksmanNetworks needed for the sorts.  If you pass -1,
-// it will return the number of different sizes it needs.  If you pass
-// [0,sizes-1], it will compute one WaksmanNetwork with that size index
-// and return the number of available WaksmanNetworks of that size.
+// it will return the number of different sizes it needs to regenerate.
+// If you pass [0,sizes-1], it will compute one WaksmanNetwork with that
+// size index and return the number of available WaksmanNetworks of that
+// size.  If you pass anything else, it will return the number of
+// different sizes it needs at all.
+
+// The list of sizes that need refilling, updated when you pass -1
+static std::vector<uint32_t> used_sizes;
 
 size_t ecall_precompute_sort(int sizeidx)
 {
     size_t ret = 0;
 
-    switch(sizeidx) {
-    case 0:
-#ifdef PROFILE_ROUTING
-    {unsigned long start = printf_with_rtclock("begin precompute WaksmanNetwork (%u)\n", route_state.tot_msg_per_ing);
-#endif
-        ret = sort_precompute(route_state.tot_msg_per_ing);
-#ifdef PROFILE_ROUTING
-    printf_with_rtclock_diff(start, "end precompute Waksman Network (%u)\n", route_state.tot_msg_per_ing);}
-#endif
-        break;
-    case 1:
+    if (sizeidx == -1) {
+        used_sizes = sort_get_used();
+        ret = used_sizes.size();
+    } else if (sizeidx >= 0 && sizeidx < used_sizes.size()) {
+        uint32_t size = used_sizes[sizeidx];
 #ifdef PROFILE_ROUTING
-    {unsigned long start = printf_with_rtclock("begin precompute WaksmanNetwork (%u)\n", route_state.max_round2_msgs);
+        unsigned long start = printf_with_rtclock("begin precompute WaksmanNetwork (%u)\n", size);
 #endif
-        ret = sort_precompute(route_state.max_round2_msgs);
+        ret = sort_precompute(size);
 #ifdef PROFILE_ROUTING
-    printf_with_rtclock_diff(start, "end precompute Waksman Network (%u)\n", route_state.max_round2_msgs);}
+        printf_with_rtclock_diff(start, "end precompute Waksman Network (%u)\n", size);
 #endif
-        break;
-    default:
-        ret = 2;
-        break;
-    }
+    } else {
+        uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
 
+        if (my_roles & ROLE_INGESTION) {
+            used_sizes.push_back(route_state.tot_msg_per_ing);
+        }
+        if (my_roles & ROLE_ROUTING) {
+            used_sizes.push_back(route_state.max_round2_msgs);
+        }
+        if (my_roles & ROLE_STORAGE) {
+            used_sizes.push_back(route_state.max_stg_msgs);
+        }
+        ret = used_sizes.size();
+    }
     return ret;
 }
 
@@ -330,7 +349,7 @@ bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
 }
 
 // Send the round 1 messages.  Note that N here is not private.
-static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
+static void send_round1_msgs(const uint8_t *msgs, const UidKey *indices,
     uint32_t N)
 {
     uint16_t msg_size = g_teems_config.msg_size;
@@ -384,15 +403,15 @@ static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
             uint8_t *buf = round1.buf + start * msg_size;
 
             for (uint32_t i=0; i<full_rows; ++i) {
-                const uint64_t *idxp = indices + i*tot_weight + start_weight;
+                const UidKey *idxp = indices + i*tot_weight + start_weight;
                 for (uint32_t j=0; j<weight; ++j) {
-                    memmove(buf, msgs + idxp[j]*msg_size, msg_size);
+                    memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
                     buf += msg_size;
                 }
             }
-            const uint64_t *idxp = indices + full_rows*tot_weight + start_weight;
+            const UidKey *idxp = indices + full_rows*tot_weight + start_weight;
             for (uint32_t j=0; j<num_msgs_last_row; ++j) {
-                memmove(buf, msgs + idxp[j]*msg_size, msg_size);
+                memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
                 buf += msg_size;
             }
 
@@ -405,14 +424,14 @@ static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
             NodeCommState &nodecom = g_commstates[routing_node];
             nodecom.message_start(num_msgs * msg_size);
             for (uint32_t i=0; i<full_rows; ++i) {
-                const uint64_t *idxp = indices + i*tot_weight + start_weight;
+                const UidKey *idxp = indices + i*tot_weight + start_weight;
                 for (uint32_t j=0; j<weight; ++j) {
-                    nodecom.message_data(msgs + idxp[j]*msg_size, msg_size);
+                    nodecom.message_data(msgs + idxp[j].index()*msg_size, msg_size);
                 }
             }
-            const uint64_t *idxp = indices + full_rows*tot_weight + start_weight;
+            const UidKey *idxp = indices + full_rows*tot_weight + start_weight;
             for (uint32_t j=0; j<num_msgs_last_row; ++j) {
-                nodecom.message_data(msgs + idxp[j]*msg_size, msg_size);
+                nodecom.message_data(msgs + idxp[j].index()*msg_size, msg_size);
             }
         }
     }
@@ -502,7 +521,7 @@ void ecall_routing_proceed(void *cbpointer)
             unsigned long start_round1 = printf_with_rtclock("begin round1 processing (%u)\n", inserted);
             unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
 #endif
-            sort_mtobliv(g_teems_config.nthreads, ingbuf.buf,
+            sort_mtobliv<UidKey>(g_teems_config.nthreads, ingbuf.buf,
                 g_teems_config.msg_size, ingbuf.inserted,
                 route_state.tot_msg_per_ing, send_round1_msgs);
 #ifdef PROFILE_ROUTING
@@ -657,14 +676,11 @@ void ecall_routing_proceed(void *cbpointer)
 #ifdef PROFILE_ROUTING
             unsigned long start = printf_with_rtclock("begin storage processing (%u)\n", round2.inserted);
 #endif
-            storage_received(round2.buf, round2.inserted);
+            storage_received(round2);
 #ifdef PROFILE_ROUTING
             printf_with_rtclock_diff(start, "end storage processing (%u)\n", round2.inserted);
 #endif
 
-            round2.reset();
-            pthread_mutex_unlock(&round2.mutex);
-
             // We're done
             route_state.step = ROUTE_NOT_STARTED;
             ocall_routing_round_complete(cbpointer, 0);
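
Note: send_round1_msgs now receives the sorted key array instead of raw 64-bit values. Each UidKey packs the 32-bit destination uid in its high half and the item's position in the shuffled buffer in its low half (see set_key<UidKey> in Enclave/sort.tcc below), so .index() recovers where to copy each message from. A stripped-down sketch of that gather step, ignoring the per-routing-node bucketing done above:

    #include <cstdint>
    #include <cstring>

    // Stand-in for the key type declared in Enclave/sort.hpp: high 32
    // bits are the destination uid, low 32 bits the item index.
    struct UidKey {
        uint64_t uid_index;
        uint32_t index() const { return (uint32_t)uid_index; }
    };

    // Copy N messages out of msgs into out in sorted-key order, where
    // indices is the sorted key array passed to the sort callback.
    static void gather_sorted(uint8_t *out, const uint8_t *msgs,
        const UidKey *indices, uint32_t N, uint16_t msg_size)
    {
        for (uint32_t i = 0; i < N; ++i) {
            memmove(out + uint64_t(i) * msg_size,
                msgs + uint64_t(indices[i].index()) * msg_size, msg_size);
        }
    }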

+ 1 - 0
Enclave/route.hpp

@@ -77,6 +77,7 @@ struct RouteState {
     uint32_t tot_msg_per_ing;
     uint32_t max_msg_to_each_stg;
     uint32_t max_round2_msgs;
+    uint32_t max_stg_msgs;
     void *cbpointer;
 };
 

+ 20 - 36
Enclave/sort.cpp

@@ -21,6 +21,16 @@ struct PrecompWNs {
 
 static PrecompWNs precomp_wns;
 
+// A (mutexed) vector of sizes we've used since we were last asked
+struct UsedSizes {
+    pthread_mutex_t mutex;
+    std::vector<uint32_t> used;
+
+    UsedSizes() { pthread_mutex_init(&mutex, NULL); }
+};
+
+static UsedSizes used_sizes;
+
 // A (mutexed) map mapping (N, nthreads) pairs to WNEvalPlans
 struct EvalPlans {
     pthread_mutex_t mutex;
@@ -96,6 +106,9 @@ uint32_t shuffle_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
             pthread_mutex_unlock(&N.second.mutex);
             continue;
         }
+        pthread_mutex_lock(&used_sizes.mutex);
+        used_sizes.used.push_back(N.first);
+        pthread_mutex_unlock(&used_sizes.mutex);
         wn = std::move(N.second.wns.front());
         N.second.wns.pop_front();
         Nw = N.first;
@@ -128,42 +141,13 @@ uint32_t shuffle_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
     return Nw;
 }
 
-// Perform the sort using up to nthreads threads.  The items to sort are
-// byte arrays of size msg_size.  The key is the 10-bit storage server
-// id concatenated with the 22-bit uid at the storage server.
-void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
-    uint32_t Nr, uint32_t Na,
-    // the arguments to the callback are items, the sorted indices, and
-    // the number of non-padding items
-    std::function<void(const uint8_t*, const uint64_t*, uint32_t Nr)> cb)
+std::vector<uint32_t> sort_get_used()
 {
-    // Shuffle the items
-    uint32_t Nw = shuffle_mtobliv(nthreads, items, msg_size, Nr, Na);
-
-    // Create the indices
-    uint64_t *idx = new uint64_t[Nr];
-    uint64_t *nextidx = idx;
-    for (uint32_t i=0; i<Nw; ++i) {
-        uint64_t key = (*(uint32_t*)(items+msg_size*i));
-        if (key != uint32_t(-1)) {
-            *nextidx = (key<<32) + i;
-            ++nextidx;
-        }
-    }
-    if (nextidx != idx + Nr) {
-        printf("Found %u non-padding items, expected %u\n",
-            nextidx-idx, Nr);
-        assert(nextidx == idx + Nr);
-    }
-    // Sort the keys and indices
-    uint64_t *backingidx = new uint64_t[Nr];
-    bool whichbuf = mtmergesort<uint64_t>(idx, Nr, backingidx, nthreads);
-    uint64_t *sortedidx = whichbuf ? backingidx : idx;
-    for (uint32_t i=0; i<Nr; ++i) {
-        sortedidx[i] &= uint64_t(0xffffffff);
-    }
-    cb(items, sortedidx, Nr);
+    std::vector<uint32_t> res;
+
+    pthread_mutex_lock(&used_sizes.mutex);
+    res = std::move(used_sizes.used);
+    pthread_mutex_unlock(&used_sizes.mutex);
 
-    delete[] idx;
-    delete[] backingidx;
+    return res;
 }
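
Note: shuffle_mtobliv now records every WaksmanNetwork size it consumes, and sort_get_used() moves that list out (leaving it empty), which is what lets ecall_precompute_sort(-1) report exactly which sizes need refilling. A single-threaded sketch of the enclave-side refill cycle; the real driver goes through ecall_precompute_sort so the untrusted app can parallelize the work:

    #include <cstdint>
    #include <vector>
    #include "sort.hpp"  // sort_get_used(), sort_precompute()

    void refill_consumed_sizes()
    {
        // Moves the recorded sizes out; the internal list is now empty
        std::vector<uint32_t> sizes = sort_get_used();
        for (uint32_t N : sizes) {
            // Precompute one replacement WaksmanNetwork of size N;
            // returns how many of that size are now available
            sort_precompute(N);
        }
    }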

+ 66 - 5
Enclave/sort.hpp

@@ -27,6 +27,10 @@
 // size.
 size_t sort_precompute(uint32_t N);
 
+// Return a vector of the precomputed sizes we've used since we were
+// last asked
+std::vector<uint32_t> sort_get_used();
+
 // Precompute a WNEvalPlan for a given size and number of threads.
 // These are not consumed as they are used, so you only need to call
 // this once for each (size,nthreads) pair you need.  The precomputation
@@ -41,13 +45,70 @@ void sort_precompute_evalplan(uint32_t N, threadid_t nthreads);
 uint32_t shuffle_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
     uint32_t Nr, uint32_t Na);
 
-// Perform the sort using up to nthreads threads.  The items to sort are
-// byte arrays of size msg_size.  The key is the 10-bit storage server
-// id concatenated with the 22-bit uid at the storage server.
+// Sort Nr items at the beginning of an allocated array of Na items
+// using up to nthreads threads.  The items to sort are byte arrays of
+// size msg_size.  The keys are of type T.  T must have set_key<T> and
+// compare_keys<T> defined.  The items will be shuffled in-place, and a
+// sorted array of keys will be passed to the provided callback
+// function.
+template<typename T>
 void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
     uint32_t Nr, uint32_t Na,
-    // the arguments to the callback are items, the sorted indices, and
+    // the arguments to the callback are items, the sorted keys, and
     // the number of non-padding items
-    std::function<void(const uint8_t*, const uint64_t*, uint32_t Nr)>);
+    std::function<void(const uint8_t*, const T*, uint32_t Nr)>);
+
+// As above, but also pass an Nr*msg_size-byte buffer outbuf to put
+// the sorted values into, instead of passing a callback.  This calls
+// the above function, then copies the data in sorted order into outbuf.
+// Note: the outbuf buffer cannot overlap the items buffer.
+template<typename T>
+void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
+    uint32_t Nr, uint32_t Na, uint8_t *outbuf);
+
+// As above, but the first Nr msg_size-byte entries in the items array
+// will end up with the sorted values.  Note: if Nr < Na, entries beyond
+// Nr may also change, but you should not even look at those values!
+// This calls the above function with a temporary buffer, then copies
+// that buffer back into the items array, so it's less efficient, both
+// in memory and CPU, than the above functions.
+template<typename T>
+void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
+    uint32_t Nr, uint32_t Na);
+
+template<typename T>
+inline void set_key(T* key, const uint8_t *item, uint32_t index);
+
+template<typename T>
+int compare_keys(const void* a, const void* b);
+
+// The different kinds of keys we sort on
+
+// The 32-bit userid (which is the 10-bit node id and the 22-bit
+// per-node userid)
+struct UidKey {
+    uint64_t uid_index;
+
+    inline uint32_t index() const { return (uint32_t) uid_index; }
+};
+
+// The above and also the priority (for public routing)
+struct UidPriorityKey {
+    uint64_t uid_priority;
+    uint32_t idx;
+
+    inline uint32_t index() const { return idx; }
+};
+
+// Just the nodeid (not the per-node userid) and the priority (for
+// public routing)
+struct NidPriorityKey {
+    uint64_t nid_priority;
+    uint32_t idx;
+
+    inline uint32_t index() const { return idx; }
+};
+
+#include "sort.tcc"
 
 #endif
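
Note: the three key structs above are the only ones this commit defines, but they also document the contract a key type T must satisfy for sort_mtobliv<T>: a set_key<T> specialization that builds the key from a message and its index, an oblivious compare_keys<T> that never returns 0, and an index() accessor returning the item's position in the shuffled buffer. A hypothetical fourth key type, sorting by priority only, as an illustration (the byte offset is an assumption borrowed from the UidPriorityKey case):

    #include <cstdint>
    #include "sort.hpp"              // set_key<T>, compare_keys<T> templates
    #include "OblivAlgs/oasm_lib.h"  // oselect_uint32_t (include path assumed)

    // Hypothetical key type: sort by the 32-bit priority at byte
    // offset 4 of each message, breaking ties by item index.
    struct PriorityOnlyKey {
        uint64_t priority;
        uint32_t idx;

        inline uint32_t index() const { return idx; }
    };

    template<>
    inline void set_key<PriorityOnlyKey>(PriorityOnlyKey *key,
        const uint8_t *item, uint32_t index)
    {
        key->priority = *(const uint32_t *)(item + 4);
        key->idx = index;
    }

    template<>
    int compare_keys<PriorityOnlyKey>(const void *a, const void *b)
    {
        const PriorityOnlyKey *ka = (const PriorityOnlyKey *)a;
        const PriorityOnlyKey *kb = (const PriorityOnlyKey *)b;
        // Tie-break on idx so the comparison can never return 0, and
        // keep it branch-free so it stays oblivious
        bool alarge = (ka->priority > kb->priority) |
            ((ka->priority == kb->priority) & (ka->idx > kb->idx));
        return oselect_uint32_t(-1, 1, alarge);
    }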

+ 197 - 0
Enclave/sort.tcc

@@ -0,0 +1,197 @@
+// set_key for each kind of key we sort on
+
+template<>
+inline void set_key<UidKey>(UidKey *key, const uint8_t *item, uint32_t index)
+{
+    key->uid_index = (uint64_t(*(const uint32_t *)item) << 32) + index;
+}
+
+template<>
+inline void set_key<UidPriorityKey>(UidPriorityKey *key, const uint8_t *item, uint32_t index)
+{
+    key->uid_priority = (uint64_t(*(const uint32_t *)item) << 32) +
+            (*(const uint32_t *)(item+4));
+    key->idx = index;
+}
+
+template<>
+inline void set_key<NidPriorityKey>(NidPriorityKey *key, const uint8_t *item, uint32_t index)
+{
+    constexpr uint32_t nid_mask = (~((1<<DEST_UID_BITS)-1));
+    key->nid_priority = (uint64_t((*(const uint32_t *)item)&nid_mask) << 32) +
+            (*(const uint32_t *)(item+4));
+    key->idx = index;
+}
+
+// compare_keys for each kind of key we sort on.  Note that it must not
+// be possible for any of these functions to return 0.  These functions
+// must also be oblivious.  Return a positive (32-bit signed) int if *a
+// is larger than *b, or a negative (32-bit signed) int otherwise.
+
+template<>
+int compare_keys<UidKey>(const void* a, const void* b)
+{
+    bool alarge = (((const UidKey*)a)->uid_index >
+        ((const UidKey *)b)->uid_index);
+    return oselect_uint32_t(-1, 1, alarge);
+}
+
+template<>
+int compare_keys<UidPriorityKey>(const void* a, const void* b)
+{
+    uint64_t aup = ((const UidPriorityKey*)a)->uid_priority;
+    uint64_t bup = ((const UidPriorityKey*)b)->uid_priority;
+    uint32_t aidx = ((const UidPriorityKey*)a)->idx;
+    uint32_t bidx = ((const UidPriorityKey*)b)->idx;
+    bool auplarge = (aup > bup);
+    bool aupeq = (aup == bup);
+    bool aidxlarge = (aidx > bidx);
+    bool alarge = auplarge | (aupeq & aidxlarge);
+    return oselect_uint32_t(-1, 1, alarge);
+}
+
+template<>
+int compare_keys<NidPriorityKey>(const void* a, const void* b)
+{
+    uint64_t anp = ((const NidPriorityKey*)a)->nid_priority;
+    uint64_t bnp = ((const NidPriorityKey*)b)->nid_priority;
+    uint32_t aidx = ((const NidPriorityKey*)a)->idx;
+    uint32_t bidx = ((const NidPriorityKey*)b)->idx;
+    bool anplarge = (anp > bnp);
+    bool anpeq = (anp == bnp);
+    bool aidxlarge = (aidx > bidx);
+    bool alarge = anplarge | (anpeq & aidxlarge);
+    return oselect_uint32_t(-1, 1, alarge);
+}
+
+// Sort Nr items at the beginning of an allocated array of Na items
+// using up to nthreads threads.  The items to sort are byte arrays of
+// size msg_size.  The keys are of type T.  T must have set_key<T> and
+// compare_keys<T> defined.  The items will be shuffled in-place, and a
+// sorted array of keys will be passed to the provided callback
+// function.
+template<typename T>
+void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
+    uint32_t Nr, uint32_t Na,
+    // the arguments to the callback are items, the sorted indices, and
+    // the number of non-padding items
+    std::function<void(const uint8_t*, const T*, uint32_t Nr)> cb)
+{
+    // Shuffle the items
+    uint32_t Nw = shuffle_mtobliv(nthreads, items, msg_size, Nr, Na);
+
+    // Create the indices
+    T *idx = new T[Nr];
+    T *nextidx = idx;
+    for (uint32_t i=0; i<Nw; ++i) {
+        uint64_t padding = (*(uint32_t*)(items+msg_size*i));
+        if (padding != uint32_t(-1)) {
+            set_key<T>(nextidx, items+msg_size*i, i);
+            ++nextidx;
+        }
+    }
+    if (nextidx != idx + Nr) {
+        printf("Found %u non-padding items, expected %u\n",
+            nextidx-idx, Nr);
+        assert(nextidx == idx + Nr);
+    }
+    // Sort the keys and indices
+    T *backingidx = new T[Nr];
+    bool whichbuf = mtmergesort<T>(idx, Nr, backingidx, nthreads);
+    T *sortedidx = whichbuf ? backingidx : idx;
+    cb(items, sortedidx, Nr);
+
+    delete[] idx;
+    delete[] backingidx;
+}
+
+template <typename T>
+struct move_sorted_args {
+    const T* sorted_keys;
+    const uint8_t *items;
+    uint8_t *destbuf;
+    uint32_t start, num;
+    uint16_t msg_size;
+};
+
+template <typename T>
+static void *move_sorted(void *voidargs)
+{
+    const move_sorted_args<T> *args =
+        (move_sorted_args<T> *)voidargs;
+    uint16_t msg_size = args->msg_size;
+    uint32_t start = args->start;
+    uint32_t end = start + args->num;
+    const T *sorted_keys = args->sorted_keys;
+    const uint8_t *items = args->items;
+    uint8_t *destbuf = args->destbuf;
+    for (uint32_t i=start; i<end; ++i) {
+        memmove(destbuf + i * msg_size,
+            items + (sorted_keys[i].index()) * msg_size,
+            msg_size);
+    }
+    return NULL;
+}
+
+// As above, but also pass an Nr*msg_size-byte buffer outbuf to put
+// the sorted values into, instead of passing a callback.  This calls
+// the above function, then copies the data in sorted order into outbuf.
+// Note: the outbuf buffer cannot overlap the items buffer.
+template<typename T>
+void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
+    uint32_t Nr, uint32_t Na, uint8_t *outbuf)
+{
+    sort_mtobliv<T>(nthreads, items, msg_size, Nr, Na,
+        [nthreads, msg_size, outbuf]
+        (const uint8_t* origitems, const T* keys, uint32_t Nr) {
+            // Special-case nthreads=1 for efficiency
+            if (nthreads <= 1) {
+                move_sorted_args<T> args = {
+                    keys, origitems, outbuf, 0, Nr, msg_size
+                };
+                move_sorted<T>(&args);
+            } else {
+                move_sorted_args<T> args[nthreads];
+                uint32_t inc = Nr / nthreads;
+                uint32_t extra = Nr % nthreads;
+                uint32_t last = 0;
+                for (threadid_t i=0; i<nthreads; ++i) {
+                    uint32_t num = inc + (i < extra);
+                    args[i] = {
+                        keys, origitems, outbuf, last, num, msg_size
+                    };
+                    last += num;
+                }
+
+                // Launch all but the first section into other threads
+                for (threadid_t i=1; i<nthreads; ++i) {
+                    threadpool_dispatch(g_thread_id+i,
+                        move_sorted<T>, args+i);
+                }
+
+                // Do the first section ourselves
+                move_sorted<T>(args);
+
+                // Join the threads
+                for (threadid_t i=1; i<nthreads; ++i) {
+                    threadpool_join(g_thread_id+i, NULL);
+                }
+            }
+        });
+}
+
+// As above, but the first Nr msg_size-byte entries in the items array
+// will end up with the sorted values.  Note: if Nr < Na, entries beyond
+// Nr may also change, but you should not even look at those values!
+// This calls the above function with a temporary buffer, then copies
+// that buffer back into the items array, so it's less efficient, both
+// in memory and CPU, than the above functions.
+template<typename T>
+void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
+    uint32_t Nr, uint32_t Na)
+{
+    uint8_t *tempbuf = new uint8_t[Nr * msg_size];
+    sort_mtobliv<T>(nthreads, items, msg_size, Nr, Na, tempbuf);
+    memmove(items, tempbuf, Nr * msg_size);
+    delete[] tempbuf;
+}
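
Note: the three sort_mtobliv overloads trade memory for convenience: the callback form hands back the sorted keys over the shuffled buffer, the outbuf form gathers the messages into a caller-supplied non-overlapping buffer (which is how storage.cpp below uses it), and the in-place form allocates its own temporary and copies back. Each call consumes one precomputed WaksmanNetwork. A short usage sketch with illustrative buffer names:

    #include <cstdint>
    #include "sort.hpp"  // sort_mtobliv<T>, UidKey (threadid_t comes from
                         // the enclave headers sort.hpp pulls in)

    void sort_usage_examples(uint8_t *items, uint8_t *outbuf,
        uint16_t msg_size, uint32_t Nr, uint32_t Na, threadid_t nthreads)
    {
        // 1. Callback form: walk the Nr sorted keys without copying
        sort_mtobliv<UidKey>(nthreads, items, msg_size, Nr, Na,
            [msg_size](const uint8_t *shuffled, const UidKey *keys,
                    uint32_t n) {
                for (uint32_t i = 0; i < n; ++i) {
                    const uint8_t *msg =
                        shuffled + uint64_t(keys[i].index()) * msg_size;
                    (void)msg;  // e.g. stream msg out in sorted order
                }
            });

        // 2. Output-buffer form: outbuf (Nr*msg_size bytes, must not
        //    overlap items) receives the messages in sorted order
        sort_mtobliv<UidKey>(nthreads, items, msg_size, Nr, Na, outbuf);

        // 3. In-place form: the first Nr entries of items end up
        //    sorted, at the cost of an extra buffer and copy
        sort_mtobliv<UidKey>(nthreads, items, msg_size, Nr, Na);
    }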

+ 159 - 5
Enclave/storage.cpp

@@ -1,18 +1,75 @@
 #include "utils.hpp"
 #include "config.hpp"
-#include "storage.hpp"
 #include "ORExpand.hpp"
+#include "sort.hpp"
+#include "storage.hpp"
+
+#define PROFILE_STORAGE
+
+static struct {
+    uint32_t max_users;
+    uint32_t my_storage_node_id;
+    // A local storage buffer, used when we need to do non-in-place
+    // sorts of the messages that have arrived
+    MsgBuffer stg_buf;
+    // The destination vector for ORExpand
+    std::vector<uint32_t> dest;
+} storage_state;
 
-// Handle the messages received by a storage node
-void storage_received(const uint8_t *msgs, uint32_t num_msgs)
+// route_init will call this function; no one else should call it
+// explicitly.  The parameter is the number of messages that can fit in
+// the storage-side MsgBuffer.  Returns true on success, false on
+// failure.
+bool storage_init(uint32_t max_users, uint32_t msg_buf_size)
+{
+    storage_state.max_users = max_users;
+    storage_state.stg_buf.alloc(msg_buf_size);
+    storage_state.dest.resize(msg_buf_size);
+    uint32_t my_storage_node_id = 0;
+    for (nodenum_t i=0; i<g_teems_config.num_nodes; ++i) {
+        if (g_teems_config.roles[i] & ROLE_STORAGE) {
+            if (i == g_teems_config.my_node_num) {
+                storage_state.my_storage_node_id = my_storage_node_id << DEST_UID_BITS;
+            } else {
+                ++my_storage_node_id;
+            }
+        }
+    }
+    return true;
+}
+
+// Handle the messages received by a storage node.  Pass a _locked_
+// MsgBuffer.  This function will itself reset and unlock it when it's
+// done with it.
+void storage_received(MsgBuffer &storage_buf)
 {
-    // A dummy function for now that just counts how many real and
-    // padding messages arrived
     uint16_t msg_size = g_teems_config.msg_size;
     nodenum_t my_node_num = g_teems_config.my_node_num;
+    const uint8_t *msgs = storage_buf.buf;
+    uint32_t num_msgs = storage_buf.inserted;
     uint32_t real = 0, padding = 0;
     uint32_t uid_mask = (1 << DEST_UID_BITS) - 1;
+    uint32_t nid_mask = ~uid_mask;
+
+#ifdef PROFILE_STORAGE
+    unsigned long start_received = printf_with_rtclock("begin storage_received (%u)\n", storage_buf.inserted);
+#endif
+
+    // It's OK to test for errors in a way that's non-oblivious if
+    // there's an error (but it should be oblivious if there are no
+    // errors)
+    for (uint32_t i=0; i<num_msgs; ++i) {
+        uint32_t uid = *(const uint32_t*)(storage_buf.buf+(i*msg_size));
+        bool ok = ((((uid & nid_mask) == storage_state.my_storage_node_id)
+            & ((uid & uid_mask) < storage_state.max_users))
+            | ((uid & uid_mask) == uid_mask));
+        if (!ok) {
+            printf("Received bad uid: %08x\n", uid);
+            assert(ok);
+        }
+    }
 
+    // Testing: report how many real and dummy messages arrived
     printf("Storage server received %u messages:\n", num_msgs);
     for (uint32_t i=0; i<num_msgs; ++i) {
         uint32_t dest_addr = *(const uint32_t*)msgs;
@@ -32,4 +89,101 @@ void storage_received(const uint8_t *msgs, uint32_t num_msgs)
         msgs += msg_size;
     }
     printf("%u real, %u padding\n", real, padding);
+
+    /*
+    for (uint32_t i=0;i<num_msgs; ++i) {
+        printf("%3d: %08x %08x\n", i,
+        *(uint32_t*)(storage_buf.buf+(i*msg_size)),
+        *(uint32_t*)(storage_buf.buf+(i*msg_size+4)));
+    }
+    */
+    // Sort the received messages by userid into the
+    // storage_state.stg_buf MsgBuffer.
+#ifdef PROFILE_STORAGE
+    unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u)\n", storage_buf.inserted);
+#endif
+    sort_mtobliv<UidKey>(g_teems_config.nthreads, storage_buf.buf,
+        msg_size, storage_buf.inserted, storage_buf.bufsize,
+        storage_state.stg_buf.buf);
+#ifdef PROFILE_STORAGE
+    printf_with_rtclock_diff(start_sort, "end oblivious sort (%u)\n", storage_buf.inserted);
+#endif
+
+    /*
+    for (uint32_t i=0;i<num_msgs; ++i) {
+        printf("%3d: %08x %08x\n", i,
+        *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size)),
+        *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size+4)));
+    }
+    */
+
+#ifdef PROFILE_STORAGE
+    unsigned long start_dest = printf_with_rtclock("begin setting dests (%u)\n", storage_state.stg_buf.bufsize);
+#endif
+    // Obliviously set the dest array
+    uint32_t *dests = storage_state.dest.data();
+    uint32_t stg_size = storage_state.stg_buf.bufsize;
+    const uint8_t *buf = storage_state.stg_buf.buf;
+    uint32_t m_priv_in = g_teems_config.m_priv_in;
+
+    uint32_t uid = *(uint32_t*)(buf);
+    uid &= uid_mask;
+    // num_msgs is not a private value
+    if (num_msgs > 0) {
+        dests[0] = oselect_uint32_t(uid * m_priv_in, 0xffffffff,
+            uid == uid_mask);
+    }
+    uint32_t prev_uid = uid;
+    for (uint32_t i=1; i<num_msgs; ++i) {
+        uid = *(uint32_t*)(buf + i*msg_size);
+        uid &= uid_mask;
+        uint32_t next = oselect_uint32_t(uid * m_priv_in, dests[i-1]+1,
+            uid == prev_uid);
+        dests[i] = oselect_uint32_t(next, 0xffffffff, uid == uid_mask);
+        prev_uid = uid;
+    }
+    for (uint32_t i=num_msgs; i<stg_size; ++i) {
+        dests[i] = 0xffffffff;
+        *(uint32_t*)(buf + i*msg_size) = 0xffffffff;
+    }
+#ifdef PROFILE_STORAGE
+    printf_with_rtclock_diff(start_dest, "end setting dests (%u)\n", stg_size);
+#endif
+    /*
+    for (uint32_t i=0;i<stg_size; ++i) {
+        printf("%3d: %08x %08x %u\n", i,
+        *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size)),
+        *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size+4)),
+        dests[i]);
+    }
+    */
+
+#ifdef PROFILE_STORAGE
+    unsigned long start_expand = printf_with_rtclock("begin ORExpand (%u)\n", stg_size);
+#endif
+    ORExpand_parallel<OSWAP_16X>(storage_state.stg_buf.buf, dests,
+        msg_size, stg_size, g_teems_config.nthreads);
+#ifdef PROFILE_STORAGE
+    printf_with_rtclock_diff(start_expand, "end ORExpand (%u)\n", stg_size);
+#endif
+    /*
+    for (uint32_t i=0;i<stg_size; ++i) {
+        printf("%3d: %08x %08x %u\n", i,
+        *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size)),
+        *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size+4)),
+        dests[i]);
+    }
+    */
+
+    // You can do more processing after these lines, as long as they
+    // don't touch storage_buf.  They _can_ touch the backing buffer
+    // storage_state.stg_buf.
+    storage_buf.reset();
+    pthread_mutex_unlock(&storage_buf.mutex);
+
+    storage_state.stg_buf.reset();
+
+#ifdef PROFILE_STORAGE
+    printf_with_rtclock_diff(start_received, "end storage_received (%u)\n", storage_buf.inserted);
+#endif
 }
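
Note: the dests computation above lays messages out so that user u owns the m_priv_in consecutive slots starting at u*m_priv_in: the first message for a uid is assigned slot uid*m_priv_in, each further message for the same uid takes the next slot, and padding (uid == uid_mask) gets the sentinel 0xffffffff. A plain, non-oblivious reference version of that assignment, for understanding only (the real loop must stay branch-free via oselect_uint32_t):

    #include <cstdint>
    #include <vector>

    // Reference (non-oblivious) dests assignment; uids must already be
    // sorted and uid_mask marks padding entries.
    std::vector<uint32_t> reference_dests(const std::vector<uint32_t> &uids,
        uint32_t m_priv_in, uint32_t uid_mask)
    {
        std::vector<uint32_t> dests(uids.size());
        for (size_t i = 0; i < uids.size(); ++i) {
            if (uids[i] == uid_mask) {
                dests[i] = 0xffffffff;           // padding sentinel
            } else if (i > 0 && uids[i] == uids[i-1]) {
                dests[i] = dests[i-1] + 1;       // next slot, same user
            } else {
                dests[i] = uids[i] * m_priv_in;  // first slot for this user
            }
        }
        return dests;
    }
    // Example: uids = {0, 2, 2, 3, uid_mask} with m_priv_in = 2 yields
    // dests = {0, 4, 5, 6, 0xffffffff}.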

+ 12 - 2
Enclave/storage.hpp

@@ -3,7 +3,17 @@
 
 #include <cstdint>
 
-// Handle the messages received by a storage node
-void storage_received(const uint8_t *msgs, uint32_t num_msgs);
+#include "route.hpp"
+
+// route_init will call this function; no one else should call it
+// explicitly.  The parameter is the number of messages that can fit in
+// the storage-side MsgBuffer.  Returns true on success, false on
+// failure.
+bool storage_init(uint32_t max_users, uint32_t msg_buf_size);
+
+// Handle the messages received by a storage node.  Pass a _locked_
+// MsgBuffer.  This function will itself reset and unlock it when it's
+// done with it.
+void storage_received(MsgBuffer &storage_buf);
 
 #endif
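
Note: the locking contract changes with this signature: callers hand over a MsgBuffer they have already locked, and storage_received() resets it and unlocks its mutex before returning (the route.cpp hunk above drops the caller's explicit reset/unlock for exactly this reason). A hypothetical caller, to make the contract concrete:

    #include <pthread.h>
    #include "route.hpp"    // MsgBuffer (buf, inserted, mutex, reset())
    #include "storage.hpp"  // storage_received()

    void deliver_to_storage(MsgBuffer &round2)
    {
        pthread_mutex_lock(&round2.mutex);
        // ... messages have been inserted into round2 under this lock ...
        storage_received(round2);
        // Do not touch round2 here without re-locking: the lock was
        // released inside storage_received().
    }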

+ 17 - 8
Makefile

@@ -348,8 +348,9 @@ Enclave/route.o: Enclave/OblivAlgs/TightCompaction_v2.hpp
 Enclave/route.o: Enclave/OblivAlgs/TightCompaction_v2.tcc
 Enclave/route.o: Enclave/OblivAlgs/RecursiveShuffle.tcc
 Enclave/route.o: Enclave/OblivAlgs/aes.hpp
-Enclave/route.o: Enclave/OblivAlgs/WaksmanNetwork.tcc Enclave/comms.hpp
-Enclave/route.o: Enclave/obliv.hpp Enclave/storage.hpp Enclave/route.hpp
+Enclave/route.o: Enclave/OblivAlgs/WaksmanNetwork.tcc Enclave/sort.tcc
+Enclave/route.o: Enclave/comms.hpp Enclave/obliv.hpp Enclave/storage.hpp
+Enclave/route.o: Enclave/route.hpp
 Enclave/ingest.o: Enclave/Enclave_t.h Enclave/enclave_api.h Enclave/config.hpp
 Enclave/ingest.o: Enclave/route.hpp
 Enclave/sort.o: Enclave/sort.hpp Enclave/OblivAlgs/WaksmanNetwork.hpp
@@ -361,22 +362,30 @@ Enclave/sort.o: Enclave/OblivAlgs/TightCompaction_v2.hpp
 Enclave/sort.o: Enclave/OblivAlgs/TightCompaction_v2.tcc
 Enclave/sort.o: Enclave/OblivAlgs/RecursiveShuffle.tcc
 Enclave/sort.o: Enclave/OblivAlgs/aes.hpp
-Enclave/sort.o: Enclave/OblivAlgs/WaksmanNetwork.tcc
+Enclave/sort.o: Enclave/OblivAlgs/WaksmanNetwork.tcc Enclave/sort.tcc
 Enclave/storage.o: Enclave/OblivAlgs/utils.hpp Enclave/Enclave_t.h
 Enclave/storage.o: Enclave/enclave_api.h Enclave/OblivAlgs/CONFIG.h
 Enclave/storage.o: Enclave/OblivAlgs/oasm_lib.h
 Enclave/storage.o: Enclave/OblivAlgs/oasm_lib.tcc Enclave/OblivAlgs/foav.h
 Enclave/storage.o: Enclave/config.hpp Enclave/enclave_api.h
-Enclave/storage.o: Enclave/storage.hpp Enclave/OblivAlgs/ORExpand.hpp
-Enclave/storage.o: Enclave/OblivAlgs/ORExpand.tcc
+Enclave/storage.o: Enclave/OblivAlgs/ORExpand.hpp
+Enclave/storage.o: Enclave/OblivAlgs/ORExpand.tcc Enclave/sort.hpp
+Enclave/storage.o: Enclave/OblivAlgs/WaksmanNetwork.hpp
+Enclave/storage.o: Enclave/OblivAlgs/RecursiveShuffle.hpp
+Enclave/storage.o: Enclave/OblivAlgs/TightCompaction_v2.hpp
+Enclave/storage.o: Enclave/OblivAlgs/TightCompaction_v2.tcc
+Enclave/storage.o: Enclave/OblivAlgs/RecursiveShuffle.tcc
+Enclave/storage.o: Enclave/OblivAlgs/aes.hpp
+Enclave/storage.o: Enclave/OblivAlgs/WaksmanNetwork.tcc Enclave/sort.tcc
+Enclave/storage.o: Enclave/storage.hpp Enclave/route.hpp
 Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/ORExpand.hpp
-Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/oasm_lib.h
+Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/utils.hpp Enclave/Enclave_t.h
+Enclave/OblivAlgs/ORExpand.o: Enclave/enclave_api.h
 Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/CONFIG.h
+Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/oasm_lib.h
 Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/oasm_lib.tcc
 Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/foav.h
 Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/ORExpand.tcc
-Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/utils.hpp Enclave/Enclave_t.h
-Enclave/OblivAlgs/ORExpand.o: Enclave/enclave_api.h
 Enclave/OblivAlgs/RecursiveShuffle.o: Enclave/OblivAlgs/oasm_lib.h
 Enclave/OblivAlgs/RecursiveShuffle.o: Enclave/OblivAlgs/CONFIG.h
 Enclave/OblivAlgs/RecursiveShuffle.o: Enclave/OblivAlgs/oasm_lib.tcc