Pārlūkot izejas kodu

Generalize sort_mtobliv to be able to sort on different kinds of keys

Ian Goldberg 1 gadu atpakaļ
vecāks
revīzija
5cb57bc4be
6 mainītis faili ar 132 papildinājumiem un 68 dzēšanām
  1. 0 10
      Enclave/OblivAlgs/WaksmanNetwork.tcc
  2. 10 10
      Enclave/route.cpp
  3. 0 39
      Enclave/sort.cpp
  4. 39 3
      Enclave/sort.hpp
  5. 76 0
      Enclave/sort.tcc
  6. 7 6
      Makefile

+ 0 - 10
Enclave/OblivAlgs/WaksmanNetwork.tcc

@@ -3,16 +3,6 @@
 
 template<typename T> static int compare_keys(const void *a, const void *b);
 
-template<>
-int compare_keys<uint64_t>(const void *a, const void *b)
-{
-    uint32_t *a32 = (uint32_t*)a;
-    uint32_t *b32 = (uint32_t*)b;
-    int hi = a32[1]-b32[1];
-    int lo = a32[0]-b32[0];
-    return oselect_uint32_t(hi, lo, !hi);
-}
-
 template<typename T>
 struct MergeArgs {
     T* dst;

+ 10 - 10
Enclave/route.cpp

@@ -352,7 +352,7 @@ bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
 }
 
 // Send the round 1 messages.  Note that N here is not private.
-static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
+static void send_round1_msgs(const uint8_t *msgs, const UidKey *indices,
     uint32_t N)
 {
     uint16_t msg_size = g_teems_config.msg_size;
@@ -406,15 +406,15 @@ static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
             uint8_t *buf = round1.buf + start * msg_size;
 
             for (uint32_t i=0; i<full_rows; ++i) {
-                const uint64_t *idxp = indices + i*tot_weight + start_weight;
+                const UidKey *idxp = indices + i*tot_weight + start_weight;
                 for (uint32_t j=0; j<weight; ++j) {
-                    memmove(buf, msgs + idxp[j]*msg_size, msg_size);
+                    memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
                     buf += msg_size;
                 }
             }
-            const uint64_t *idxp = indices + full_rows*tot_weight + start_weight;
+            const UidKey *idxp = indices + full_rows*tot_weight + start_weight;
             for (uint32_t j=0; j<num_msgs_last_row; ++j) {
-                memmove(buf, msgs + idxp[j]*msg_size, msg_size);
+                memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
                 buf += msg_size;
             }
 
@@ -427,14 +427,14 @@ static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
             NodeCommState &nodecom = g_commstates[routing_node];
             nodecom.message_start(num_msgs * msg_size);
             for (uint32_t i=0; i<full_rows; ++i) {
-                const uint64_t *idxp = indices + i*tot_weight + start_weight;
+                const UidKey *idxp = indices + i*tot_weight + start_weight;
                 for (uint32_t j=0; j<weight; ++j) {
-                    nodecom.message_data(msgs + idxp[j]*msg_size, msg_size);
+                    nodecom.message_data(msgs + idxp[j].index()*msg_size, msg_size);
                 }
             }
-            const uint64_t *idxp = indices + full_rows*tot_weight + start_weight;
+            const UidKey *idxp = indices + full_rows*tot_weight + start_weight;
             for (uint32_t j=0; j<num_msgs_last_row; ++j) {
-                nodecom.message_data(msgs + idxp[j]*msg_size, msg_size);
+                nodecom.message_data(msgs + idxp[j].index()*msg_size, msg_size);
             }
         }
     }
@@ -524,7 +524,7 @@ void ecall_routing_proceed(void *cbpointer)
             unsigned long start_round1 = printf_with_rtclock("begin round1 processing (%u)\n", inserted);
             unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
 #endif
-            sort_mtobliv(g_teems_config.nthreads, ingbuf.buf,
+            sort_mtobliv<UidKey>(g_teems_config.nthreads, ingbuf.buf,
                 g_teems_config.msg_size, ingbuf.inserted,
                 route_state.tot_msg_per_ing, send_round1_msgs);
 #ifdef PROFILE_ROUTING

+ 0 - 39
Enclave/sort.cpp

@@ -128,42 +128,3 @@ uint32_t shuffle_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
     return Nw;
 }
 
-// Perform the sort using up to nthreads threads.  The items to sort are
-// byte arrays of size msg_size.  The key is the 10-bit storage server
-// id concatenated with the 22-bit uid at the storage server.
-void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
-    uint32_t Nr, uint32_t Na,
-    // the arguments to the callback are items, the sorted indices, and
-    // the number of non-padding items
-    std::function<void(const uint8_t*, const uint64_t*, uint32_t Nr)> cb)
-{
-    // Shuffle the items
-    uint32_t Nw = shuffle_mtobliv(nthreads, items, msg_size, Nr, Na);
-
-    // Create the indices
-    uint64_t *idx = new uint64_t[Nr];
-    uint64_t *nextidx = idx;
-    for (uint32_t i=0; i<Nw; ++i) {
-        uint64_t key = (*(uint32_t*)(items+msg_size*i));
-        if (key != uint32_t(-1)) {
-            *nextidx = (key<<32) + i;
-            ++nextidx;
-        }
-    }
-    if (nextidx != idx + Nr) {
-        printf("Found %u non-padding items, expected %u\n",
-            nextidx-idx, Nr);
-        assert(nextidx == idx + Nr);
-    }
-    // Sort the keys and indices
-    uint64_t *backingidx = new uint64_t[Nr];
-    bool whichbuf = mtmergesort<uint64_t>(idx, Nr, backingidx, nthreads);
-    uint64_t *sortedidx = whichbuf ? backingidx : idx;
-    for (uint32_t i=0; i<Nr; ++i) {
-        sortedidx[i] &= uint64_t(0xffffffff);
-    }
-    cb(items, sortedidx, Nr);
-
-    delete[] idx;
-    delete[] backingidx;
-}

+ 39 - 3
Enclave/sort.hpp

@@ -42,12 +42,48 @@ uint32_t shuffle_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
     uint32_t Nr, uint32_t Na);
 
 // Perform the sort using up to nthreads threads.  The items to sort are
-// byte arrays of size msg_size.  The key is the 10-bit storage server
-// id concatenated with the 22-bit uid at the storage server.
+// byte arrays of size msg_size.  The keys are of type T.  T must have
+// set_key<T> and compare_keys<T> defined.
+template<typename T>
 void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
     uint32_t Nr, uint32_t Na,
     // the arguments to the callback are items, the sorted indices, and
     // the number of non-padding items
-    std::function<void(const uint8_t*, const uint64_t*, uint32_t Nr)>);
+    std::function<void(const uint8_t*, const T*, uint32_t Nr)>);
+
+template<typename T>
+inline void set_key(T* key, const uint8_t *item, uint32_t index);
+
+template<typename T>
+int compare_keys(const void* a, const void* b);
+
+// The different kinds of keys we sort on
+
+// The 32-bit userid (which is the 10-bit node id and the 22-bit
+// per-node userid)
+struct UidKey {
+    uint64_t uid_index;
+
+    inline uint32_t index() const { return (uint32_t) uid_index; }
+};
+
+// The above and also the priority (for public routing)
+struct UidPriorityKey {
+    uint64_t uid_priority;
+    uint32_t idx;
+
+    inline uint32_t index() const { return idx; }
+};
+
+// Just the nodeid (not the per-node userid) and the priority (for
+// public routing)
+struct NidPriorityKey {
+    uint64_t nid_priority;
+    uint32_t idx;
+
+    inline uint32_t index() const { return idx; }
+};
+
+#include "sort.tcc"
 
 #endif

+ 76 - 0
Enclave/sort.tcc

@@ -0,0 +1,76 @@
+// set_key for each kind of key we sort on
+
+template<>
+inline void set_key<UidKey>(UidKey *key, const uint8_t *item, uint32_t index)
+{
+    key->uid_index = (uint64_t(*(const uint32_t *)item) << 32) + index;
+}
+
+template<>
+inline void set_key<UidPriorityKey>(UidPriorityKey *key, const uint8_t *item, uint32_t index)
+{
+    key->uid_priority = (uint64_t(*(const uint32_t *)item) << 32) +
+            (*(const uint32_t *)(item+4));
+    key->idx = index;
+}
+
+template<>
+inline void set_key<NidPriorityKey>(NidPriorityKey *key, const uint8_t *item, uint32_t index)
+{
+    constexpr uint32_t nid_mask = (~((1<<DEST_UID_BITS)-1));
+    key->nid_priority = (uint64_t((*(const uint32_t *)item)&nid_mask) << 32) +
+            (*(const uint32_t *)(item+4));
+    key->idx = index;
+}
+
+// compare_keys for each kind of key we sort on.  Note that it must not
+// be possible for any of these functions to return 0.  These functions
+// must also be oblivious.  Return a positive (32-bit signed) int if *a
+// is larger than *b, or a negative (32-bit signed) int otherwise.
+
+template<>
+int compare_keys<UidKey>(const void* a, const void* b)
+{
+    bool alarge = (((const UidKey*)a)->uid_index >
+        ((const UidKey *)b)->uid_index);
+    return oselect_uint32_t(-1, 1, alarge);
+}
+
+// Perform the sort using up to nthreads threads.  The items to sort are
+// byte arrays of size msg_size.  The key is the 10-bit storage server
+// id concatenated with the 22-bit uid at the storage server.
+template<typename T>
+void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
+    uint32_t Nr, uint32_t Na,
+    // the arguments to the callback are items, the sorted indices, and
+    // the number of non-padding items
+    std::function<void(const uint8_t*, const T*, uint32_t Nr)> cb)
+{
+    // Shuffle the items
+    uint32_t Nw = shuffle_mtobliv(nthreads, items, msg_size, Nr, Na);
+
+    // Create the indices
+    T *idx = new T[Nr];
+    T *nextidx = idx;
+    for (uint32_t i=0; i<Nw; ++i) {
+        uint64_t padding = (*(uint32_t*)(items+msg_size*i));
+        if (padding != uint32_t(-1)) {
+            set_key<T>(nextidx, items+msg_size*i, i);
+            ++nextidx;
+        }
+    }
+    if (nextidx != idx + Nr) {
+        printf("Found %u non-padding items, expected %u\n",
+            nextidx-idx, Nr);
+        assert(nextidx == idx + Nr);
+    }
+    // Sort the keys and indices
+    T *backingidx = new T[Nr];
+    bool whichbuf = mtmergesort<T>(idx, Nr, backingidx, nthreads);
+    T *sortedidx = whichbuf ? backingidx : idx;
+    cb(items, sortedidx, Nr);
+
+    delete[] idx;
+    delete[] backingidx;
+}
+

+ 7 - 6
Makefile

@@ -321,8 +321,9 @@ Enclave/route.o: Enclave/OblivAlgs/TightCompaction_v2.hpp
 Enclave/route.o: Enclave/OblivAlgs/TightCompaction_v2.tcc
 Enclave/route.o: Enclave/OblivAlgs/RecursiveShuffle.tcc
 Enclave/route.o: Enclave/OblivAlgs/aes.hpp
-Enclave/route.o: Enclave/OblivAlgs/WaksmanNetwork.tcc Enclave/comms.hpp
-Enclave/route.o: Enclave/obliv.hpp Enclave/storage.hpp Enclave/route.hpp
+Enclave/route.o: Enclave/OblivAlgs/WaksmanNetwork.tcc Enclave/sort.tcc
+Enclave/route.o: Enclave/comms.hpp Enclave/obliv.hpp Enclave/storage.hpp
+Enclave/route.o: Enclave/route.hpp
 Enclave/sort.o: Enclave/sort.hpp Enclave/OblivAlgs/WaksmanNetwork.hpp
 Enclave/sort.o: Enclave/OblivAlgs/oasm_lib.h Enclave/OblivAlgs/CONFIG.h
 Enclave/sort.o: Enclave/OblivAlgs/oasm_lib.tcc Enclave/OblivAlgs/foav.h
@@ -332,7 +333,7 @@ Enclave/sort.o: Enclave/OblivAlgs/TightCompaction_v2.hpp
 Enclave/sort.o: Enclave/OblivAlgs/TightCompaction_v2.tcc
 Enclave/sort.o: Enclave/OblivAlgs/RecursiveShuffle.tcc
 Enclave/sort.o: Enclave/OblivAlgs/aes.hpp
-Enclave/sort.o: Enclave/OblivAlgs/WaksmanNetwork.tcc
+Enclave/sort.o: Enclave/OblivAlgs/WaksmanNetwork.tcc Enclave/sort.tcc
 Enclave/storage.o: Enclave/OblivAlgs/utils.hpp Enclave/Enclave_t.h
 Enclave/storage.o: Enclave/enclave_api.h Enclave/OblivAlgs/CONFIG.h
 Enclave/storage.o: Enclave/OblivAlgs/oasm_lib.h
@@ -341,13 +342,13 @@ Enclave/storage.o: Enclave/config.hpp Enclave/enclave_api.h
 Enclave/storage.o: Enclave/storage.hpp Enclave/OblivAlgs/ORExpand.hpp
 Enclave/storage.o: Enclave/OblivAlgs/ORExpand.tcc
 Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/ORExpand.hpp
-Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/oasm_lib.h
+Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/utils.hpp Enclave/Enclave_t.h
+Enclave/OblivAlgs/ORExpand.o: Enclave/enclave_api.h
 Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/CONFIG.h
+Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/oasm_lib.h
 Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/oasm_lib.tcc
 Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/foav.h
 Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/ORExpand.tcc
-Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/utils.hpp Enclave/Enclave_t.h
-Enclave/OblivAlgs/ORExpand.o: Enclave/enclave_api.h
 Enclave/OblivAlgs/RecursiveShuffle.o: Enclave/OblivAlgs/oasm_lib.h
 Enclave/OblivAlgs/RecursiveShuffle.o: Enclave/OblivAlgs/CONFIG.h
 Enclave/OblivAlgs/RecursiveShuffle.o: Enclave/OblivAlgs/oasm_lib.tcc