Просмотр исходного кода

Sort using precomputable shuffles

Ian Goldberg 1 год назад
Родитель
Сommit
ae3646cf41

+ 1 - 253
Enclave/OblivAlgs/WaksmanNetwork.cpp

@@ -1158,256 +1158,6 @@ void DecryptAndOWSS(unsigned char *encrypted_buffer, uint32_t N,
 }
 #endif
 
-// #define PROFILE_MTMERGESORT
-
-template<typename T> static int compare(const void *a, const void *b);
-
-template<>
-int compare<uint64_t>(const void *a, const void *b)
-{
-    uint32_t *a32 = (uint32_t*)a;
-    uint32_t *b32 = (uint32_t*)b;
-    int hi = a32[1]-b32[1];
-    int lo = a32[0]-b32[0];
-    return oselect_uint32_t(hi, lo, !hi);
-}
-
-template<typename T>
-struct MergeArgs {
-    T* dst;
-    const T* leftsrc;
-    size_t Nleft;
-    const T* rightsrc;
-    size_t Nright;
-};
-
-// Merge two sorted arrays into one.  The (sorted) source arrays are
-// leftsrc and rightsrc of lengths Nleft and Nright respectively.  Put
-// the merged sorted array into dst[0..Nleft+Nright-1].  Use up to the
-// given number of threads.
-template<typename T>
-static void* merge(void *voidargs)
-{
-    const MergeArgs<T>* args = (const MergeArgs<T>*)voidargs;
-#ifdef PROFILE_MTMERGESORT
-unsigned long start = printf_with_rtclock("begin merge(dst=%p, leftsrc=%p, Nleft=%lu, rightsrc=%p, Nright=%lu, nthreads=%lu)\n", args->dst, args->leftsrc, args->Nleft, args->rightsrc, args->Nright);
-#endif
-    T* dst = args->dst;
-    const T* left = args->leftsrc;
-    const T* right = args->rightsrc;
-    const T* leftend = args->leftsrc + args->Nleft;
-    const T* rightend = args->rightsrc + args->Nright;
-
-    while (left != leftend && right != rightend) {
-        if (compare<T>(left, right) < 0) {
-            *dst = *left;
-            ++dst;
-            ++left;
-        } else {
-            *dst = *right;
-            ++dst;
-            ++right;
-        }
-    }
-
-    if (left != leftend) {
-        memmove(dst, left, (leftend-left)*sizeof(T));
-    }
-    if (right != rightend) {
-        memmove(dst, right, (rightend-right)*sizeof(T));
-    }
-#ifdef PROFILE_MTMERGESORT
-printf_with_rtclock_diff(start, "end merge(dst=%p, leftsrc=%p, Nleft=%lu, rightsrc=%p, Nright=%lu, nthreads=%lu)\n", args->dst, args->leftsrc, args->Nleft, args->rightsrc, args->Nright);
-#endif
-
-    return NULL;
-}
-
-// In the sorted subarray src[0 .. len-1], binary search for the first
-// element that's larger than the target.  The return value is the index
-// into that subarray, so it's 0 if src[0] > target, and it's len if all
-// the elements are less than the target.  Remember that all elements
-// have to be different, so no comparison will ever return that the
-// elements are equal.
-template<typename T>
-static size_t binsearch(const T* src, size_t len, const T* target)
-{
-    size_t left = 0;
-    size_t right = len;
-
-    if (len == 0) {
-        return 0;
-    }
-    if (compare<T>(src + left, target) > 0) {
-        return 0;
-    }
-    if (len > 0 && compare<T>(src + right - 1, target) < 0) {
-        return len;
-    }
-
-    // Invariant: src[left] < target and src[right] > target (where
-    // src[len] is considered to be greater than all targets)
-    while (right - left > 1) {
-        size_t mid = left + (right - left)/2;
-        if (compare<T>(src + mid, target) > 0) {
-            right = mid;
-        } else {
-            left = mid;
-        }
-    }
-
-    return right;
-}
-
-// Merge two sorted arrays into one.  The (sorted) source arrays are
-// leftsrc and rightsrc of lengths Nleft and Nright respectively.  Put
-// the merged sorted array into dst[0..Nleft+Nright-1].  Use up to the
-// given number of threads.
-template<typename T>
-static void mtmerge(T* dst, const T* leftsrc, size_t Nleft,
-    const T* rightsrc, size_t Nright, threadid_t nthreads)
-{
-#ifdef PROFILE_MTMERGESORT
-unsigned long start = printf_with_rtclock("begin mtmerge(dst=%p, leftsrc=%p, Nleft=%lu, rightsrc=%p, Nright=%lu, nthreads=%lu)\n", dst, leftsrc, Nleft, rightsrc, Nright, nthreads);
-#endif
-
-    threadid_t threads_to_use = nthreads;
-    if (Nleft < 500 || Nright < 500) {
-        threads_to_use = 1;
-    }
-
-    // Break the left array into threads_to_use approximately
-    // equal-sized pieces
-
-    MergeArgs<T> margs[threads_to_use];
-    size_t leftinc = Nleft / threads_to_use;
-    size_t leftextra = Nleft % threads_to_use;
-    size_t leftlast = 0;
-    size_t rightlast = 0;
-
-    for (threadid_t t=0; t<threads_to_use; ++t) {
-        size_t leftlen = leftinc + (t < leftextra);
-        // Find the segment in the right array corresponding to this
-        // segment in the lest array.  If this is the last segment of
-        // the left array, that's just the whole remaining right array.
-        size_t rightlen;
-        if (t == threads_to_use - 1) {
-            rightlen = Nright - rightlast;
-        } else {
-            // The first element of the next left segment
-            const T* target = leftsrc + leftlast + leftlen;
-            // In the sorted subarray rightsrc[rightlast .. Nright-1],
-            // binary search for the first element that's larger than
-            // the target.  The return value is the index into that
-            // subarray, so it's 0 if rightsrc[rightlast] > target, and
-            // it's Nright-rightlast if all the elements are less than
-            // the target.
-            rightlen = binsearch<T>(rightsrc + rightlast,
-                Nright-rightlast, target);
-        }
-        margs[t] = { dst + leftlast + rightlast,
-            leftsrc + leftlast, leftlen,
-            rightsrc + rightlast, rightlen };
-        leftlast += leftlen;
-        rightlast += rightlen;
-        if (t > 0) {
-            threadpool_dispatch(g_thread_id+t, merge<T>, &margs[t]);
-        }
-    }
-    // Do the first block ourselves
-    merge<T>(&margs[0]);
-    for (size_t t=1; t<threads_to_use; ++t) {
-        threadpool_join(g_thread_id+t, NULL);
-    }
-
-#ifdef PROFILE_MTMERGESORT
-printf_with_rtclock_diff(start, "end mtmerge(dst=%p, leftsrc=%p, Nleft=%lu, rightsrc=%p, Nright=%lu, nthreads=%lu)\n", dst, leftsrc, Nleft, rightsrc, Nright, nthreads);
-#endif
-}
-
-template<typename T>
-struct MTMergesortArgs {
-    T* buf;
-    size_t N;
-    T* backing;
-    threadid_t nthreads;
-    bool ret;
-};
-
-template<typename T>
-static bool mtmergesort(T* buf, size_t N, T* backing, threadid_t nthreads);
-
-template<typename T>
-static void *mtmergesort_launch(void *voidargs)
-{
-    MTMergesortArgs<T>* args = (MTMergesortArgs<T>*)voidargs;
-    args->ret = mtmergesort<T>(args->buf, args->N, args->backing,
-        args->nthreads);
-    return NULL;
-}
-
-// Multithreaded mergesort.  Pass the data of type T to sort, as a
-// pointer and number of elements.  Also pass a backing store of the
-// same size.  The sorted data will end up in either the original data
-// array or the backing store; this function will return false if it's
-// in the original data and true if it's in the backing store.  Use up
-// to the given number of threads.
-template<typename T>
-static bool mtmergesort(T* buf, size_t N, T* backing, threadid_t nthreads)
-{
-    if (nthreads == 1 || N < 1000) {
-        // Just sort naively
-#ifdef PROFILE_MTMERGESORT
-unsigned long start = printf_with_rtclock("begin qsort(buf=%p, N=%lu)\n", buf, N);
-#endif
-        qsort(buf, N, sizeof(T), compare<T>);
-#ifdef PROFILE_MTMERGESORT
-printf_with_rtclock_diff(start, "end qsort(buf=%p, N=%lu)\n", buf, N);
-#endif
-        return false;
-    }
-#ifdef PROFILE_MTMERGESORT
-unsigned long start = printf_with_rtclock("begin mtmergesort(buf=%p, N=%lu, backing=%p, nthreads=%lu)\n", buf, N, backing, nthreads);
-#endif
-    size_t Nleft = (N+1)/2;
-    size_t Nright = N/2;
-    threadid_t threads_left = (nthreads+1)/2;
-    threadid_t threads_right = nthreads/2;
-
-    MTMergesortArgs<T> ms_right_args =
-        { buf + Nleft, Nright, backing + Nleft, threads_right, false };
-    threadpool_dispatch(g_thread_id+threads_left, mtmergesort_launch<T>,
-        &ms_right_args);
-    bool leftret = mtmergesort<T>(buf, Nleft, backing, threads_left);
-    threadpool_join(g_thread_id+threads_left, NULL);
-    bool rightret = ms_right_args.ret;
-
-    // If the left and right sorts put their answers in different
-    // places, move the right answer to match the left
-    if (leftret != rightret) {
-        if (leftret) {
-            // The left is in backing, and the right is in buf
-            memmove(backing + Nleft, buf + Nleft, Nright * sizeof(T));
-        } else {
-            // The left is in buf, and the right is in backing
-            memmove(buf + Nleft, backing + Nleft, Nright * sizeof(T));
-        }
-    }
-
-    // Merge the two halves
-    if (leftret) {
-        // The recursive outputs are in backing; merge them into buf
-        mtmerge<T>(buf, backing, Nleft, backing+Nleft, Nright, nthreads);
-    } else {
-        // The recursive outputs are in buf; merge them into backing
-        mtmerge<T>(backing, buf, Nleft, buf+Nleft, Nright, nthreads);
-    }
-#ifdef PROFILE_MTMERGESORT
-printf_with_rtclock_diff(start, "end mtmergesort(buf=%p, N=%lu, backing=%p, nthreads=%lu)\n", buf, N, backing, nthreads);
-#endif
-    return !leftret;
-}
-
 struct datacopy_args {
     const unsigned char *inbuf;
     const uint64_t *idx;
@@ -1426,13 +1176,12 @@ static void* datacopy_range(void *voidargs)
     return NULL;
 }
 
+#if 0
 // Sort the given array of N elements, each of size sz, using up to
 // nthreads threads. The output is put into the same memory as the input
 // array.  The first 4 bytes of each element is its key.
 static void mtsort(void *buffer, size_t N, size_t sz, threadid_t nthreads)
 {
-    // No multithreading yet
-
     uint64_t *idx = new uint64_t[N];
     unsigned char *inbuf = (unsigned char *)buffer;
     unsigned char *outbuf = new unsigned char[N*sz];
@@ -1474,7 +1223,6 @@ static void mtsort(void *buffer, size_t N, size_t sz, threadid_t nthreads)
     delete[] outbuf;
 }
 
-#if 0
 void DecryptAndMTSS(unsigned char *encrypted_buffer, uint32_t N,
   size_t encrypted_block_size, threadid_t nthreads,
   unsigned char *result_buffer, enc_ret *ret) {

+ 16 - 0
Enclave/OblivAlgs/WaksmanNetwork.hpp

@@ -43,6 +43,13 @@ struct WNEvalPlan {
   // or 2 (otherwise) items.
   std::vector<WNEvalPlan> subplans;
 
+  // Make WNEvalPlan objects non-copyable for efficiency
+  WNEvalPlan(const WNEvalPlan&) = delete;
+  WNEvalPlan& operator=(const WNEvalPlan&) = delete;
+  // But moves are OK
+  WNEvalPlan(WNEvalPlan &&wn) = default;
+  WNEvalPlan& operator=(WNEvalPlan&&) = default;
+
   WNEvalPlan(uint32_t N, uint32_t nthreads) : N(N), nthreads(nthreads) {
       if (N<2) {
           subtree_num_inswitches = 0;
@@ -202,6 +209,13 @@ class WaksmanNetwork {
 
 public:
 
+  // Make WaksmanNetwork objects non-copyable for efficiency
+  WaksmanNetwork(const WaksmanNetwork&) = delete;
+  WaksmanNetwork& operator=(const WaksmanNetwork&) = delete;
+  // But moves are OK
+  WaksmanNetwork(WaksmanNetwork &&wn) = default;
+  WaksmanNetwork& operator=(WaksmanNetwork&&) = default;
+
   // Set up the WaksmanNetwork for N items.  N need not be a power of 2.
   // N <= 2^31
   WaksmanNetwork(uint32_t N);
@@ -684,4 +698,6 @@ void DecryptAndMTSS(unsigned char *encrypted_buffer, uint32_t N,
   unsigned char *result_buffer, enc_ret *ret);
 #endif
 
+#include "WaksmanNetwork.tcc"
+
 #endif

+ 251 - 0
Enclave/OblivAlgs/WaksmanNetwork.tcc

@@ -0,0 +1,251 @@
+
+// #define PROFILE_MTMERGESORT
+
+template<typename T> static int compare(const void *a, const void *b);
+
+// Three-way comparator for uint64_t elements, with the signature that
+// qsort() expects.  Returns a negative value if *a < *b, zero if they
+// are equal, and a positive value if *a > *b.
+//
+// The two values are compared directly instead of subtracting their
+// 32-bit halves: the difference of two uint32_t values does not fit in
+// an int, so the subtraction reported the wrong sign whenever the high
+// words were 2^31 or more apart.
+template<>
+int compare<uint64_t>(const void *a, const void *b)
+{
+    const uint64_t a64 = *(const uint64_t*)a;
+    const uint64_t b64 = *(const uint64_t*)b;
+    // Each comparison yields 0 or 1, so the result is -1, 0 or 1
+    return (a64 > b64) - (a64 < b64);
+}
+
+template<typename T>
+struct MergeArgs {
+    T* dst;
+    const T* leftsrc;
+    size_t Nleft;
+    const T* rightsrc;
+    size_t Nright;
+};
+
+// Merge two sorted arrays into one, on the calling thread.  The
+// argument is a pointer to a MergeArgs<T>, passed as a void* so that
+// this function can be handed to threadpool_dispatch.  The (sorted)
+// source arrays are leftsrc and rightsrc of lengths Nleft and Nright
+// respectively.  Put the merged sorted array into
+// dst[0..Nleft+Nright-1].  Always returns NULL.
+template<typename T>
+static void* merge(void *voidargs)
+{
+    const MergeArgs<T>* args = (const MergeArgs<T>*)voidargs;
+#ifdef PROFILE_MTMERGESORT
+unsigned long start = printf_with_rtclock("begin merge(dst=%p, leftsrc=%p, Nleft=%lu, rightsrc=%p, Nright=%lu)\n", args->dst, args->leftsrc, args->Nleft, args->rightsrc, args->Nright);
+#endif
+    T* dst = args->dst;
+    const T* left = args->leftsrc;
+    const T* right = args->rightsrc;
+    const T* leftend = args->leftsrc + args->Nleft;
+    const T* rightend = args->rightsrc + args->Nright;
+
+    // Repeatedly emit the smaller of the two head elements until one
+    // of the sources is exhausted
+    while (left != leftend && right != rightend) {
+        if (compare<T>(left, right) < 0) {
+            *dst = *left;
+            ++dst;
+            ++left;
+        } else {
+            *dst = *right;
+            ++dst;
+            ++right;
+        }
+    }
+
+    // At most one source still has elements; copy that tail over
+    // unchanged
+    if (left != leftend) {
+        memmove(dst, left, (leftend-left)*sizeof(T));
+    }
+    if (right != rightend) {
+        memmove(dst, right, (rightend-right)*sizeof(T));
+    }
+#ifdef PROFILE_MTMERGESORT
+printf_with_rtclock_diff(start, "end merge(dst=%p, leftsrc=%p, Nleft=%lu, rightsrc=%p, Nright=%lu)\n", args->dst, args->leftsrc, args->Nleft, args->rightsrc, args->Nright);
+#endif
+
+    return NULL;
+}
+
+// In the sorted subarray src[0 .. len-1], binary search for the first
+// element that's larger than the target.  The return value is the index
+// into that subarray, so it's 0 if src[0] > target, and it's len if all
+// the elements are less than the target.  Remember that all elements
+// have to be different, so no comparison will ever return that the
+// elements are equal.
+template<typename T>
+static size_t binsearch(const T* src, size_t len, const T* target)
+{
+    size_t left = 0;
+    size_t right = len;
+
+    if (len == 0) {
+        return 0;
+    }
+    if (compare<T>(src + left, target) > 0) {
+        return 0;
+    }
+    if (len > 0 && compare<T>(src + right - 1, target) < 0) {
+        return len;
+    }
+
+    // Invariant: src[left] < target and src[right] > target (where
+    // src[len] is considered to be greater than all targets)
+    while (right - left > 1) {
+        size_t mid = left + (right - left)/2;
+        if (compare<T>(src + mid, target) > 0) {
+            right = mid;
+        } else {
+            left = mid;
+        }
+    }
+
+    return right;
+}
+
+// Merge two sorted arrays into one.  The (sorted) source arrays are
+// leftsrc and rightsrc of lengths Nleft and Nright respectively.  Put
+// the merged sorted array into dst[0..Nleft+Nright-1].  Use up to the
+// given number of threads.
+template<typename T>
+static void mtmerge(T* dst, const T* leftsrc, size_t Nleft,
+    const T* rightsrc, size_t Nright, threadid_t nthreads)
+{
+#ifdef PROFILE_MTMERGESORT
+unsigned long start = printf_with_rtclock("begin mtmerge(dst=%p, leftsrc=%p, Nleft=%lu, rightsrc=%p, Nright=%lu, nthreads=%lu)\n", dst, leftsrc, Nleft, rightsrc, Nright, nthreads);
+#endif
+
+    threadid_t threads_to_use = nthreads;
+    if (Nleft < 500 || Nright < 500) {
+        threads_to_use = 1;
+    }
+
+    // Break the left array into threads_to_use approximately
+    // equal-sized pieces
+
+    MergeArgs<T> margs[threads_to_use];
+    size_t leftinc = Nleft / threads_to_use;
+    size_t leftextra = Nleft % threads_to_use;
+    size_t leftlast = 0;
+    size_t rightlast = 0;
+
+    for (threadid_t t=0; t<threads_to_use; ++t) {
+        size_t leftlen = leftinc + (t < leftextra);
+        // Find the segment in the right array corresponding to this
+        // segment in the left array.  If this is the last segment of
+        // the left array, that's just the whole remaining right array.
+        size_t rightlen;
+        if (t == threads_to_use - 1) {
+            rightlen = Nright - rightlast;
+        } else {
+            // The first element of the next left segment
+            const T* target = leftsrc + leftlast + leftlen;
+            // In the sorted subarray rightsrc[rightlast .. Nright-1],
+            // binary search for the first element that's larger than
+            // the target.  The return value is the index into that
+            // subarray, so it's 0 if rightsrc[rightlast] > target, and
+            // it's Nright-rightlast if all the elements are less than
+            // the target.
+            rightlen = binsearch<T>(rightsrc + rightlast,
+                Nright-rightlast, target);
+        }
+        margs[t] = { dst + leftlast + rightlast,
+            leftsrc + leftlast, leftlen,
+            rightsrc + rightlast, rightlen };
+        leftlast += leftlen;
+        rightlast += rightlen;
+        if (t > 0) {
+            threadpool_dispatch(g_thread_id+t, merge<T>, &margs[t]);
+        }
+    }
+    // Do the first block ourselves
+    merge<T>(&margs[0]);
+    for (size_t t=1; t<threads_to_use; ++t) {
+        threadpool_join(g_thread_id+t, NULL);
+    }
+
+#ifdef PROFILE_MTMERGESORT
+printf_with_rtclock_diff(start, "end mtmerge(dst=%p, leftsrc=%p, Nleft=%lu, rightsrc=%p, Nright=%lu, nthreads=%lu)\n", dst, leftsrc, Nleft, rightsrc, Nright, nthreads);
+#endif
+}
+
+template<typename T>
+struct MTMergesortArgs {
+    T* buf;
+    size_t N;
+    T* backing;
+    threadid_t nthreads;
+    bool ret;
+};
+
+template<typename T>
+static bool mtmergesort(T* buf, size_t N, T* backing, threadid_t nthreads);
+
+template<typename T>
+static void *mtmergesort_launch(void *voidargs)
+{
+    MTMergesortArgs<T>* args = (MTMergesortArgs<T>*)voidargs;
+    args->ret = mtmergesort<T>(args->buf, args->N, args->backing,
+        args->nthreads);
+    return NULL;
+}
+
+// Multithreaded mergesort.  Pass the data of type T to sort, as a
+// pointer and number of elements.  Also pass a backing store of the
+// same size.  The sorted data will end up in either the original data
+// array or the backing store; this function will return false if it's
+// in the original data and true if it's in the backing store.  Use up
+// to the given number of threads.
+template<typename T>
+bool mtmergesort(T* buf, size_t N, T* backing, threadid_t nthreads)
+{
+    if (nthreads == 1 || N < 1000) {
+        // Just sort naively
+#ifdef PROFILE_MTMERGESORT
+unsigned long start = printf_with_rtclock("begin qsort(buf=%p, N=%lu)\n", buf, N);
+#endif
+        qsort(buf, N, sizeof(T), compare<T>);
+#ifdef PROFILE_MTMERGESORT
+printf_with_rtclock_diff(start, "end qsort(buf=%p, N=%lu)\n", buf, N);
+#endif
+        return false;
+    }
+#ifdef PROFILE_MTMERGESORT
+unsigned long start = printf_with_rtclock("begin mtmergesort(buf=%p, N=%lu, backing=%p, nthreads=%lu)\n", buf, N, backing, nthreads);
+#endif
+    size_t Nleft = (N+1)/2;
+    size_t Nright = N/2;
+    threadid_t threads_left = (nthreads+1)/2;
+    threadid_t threads_right = nthreads/2;
+
+    MTMergesortArgs<T> ms_right_args =
+        { buf + Nleft, Nright, backing + Nleft, threads_right, false };
+    threadpool_dispatch(g_thread_id+threads_left, mtmergesort_launch<T>,
+        &ms_right_args);
+    bool leftret = mtmergesort<T>(buf, Nleft, backing, threads_left);
+    threadpool_join(g_thread_id+threads_left, NULL);
+    bool rightret = ms_right_args.ret;
+
+    // If the left and right sorts put their answers in different
+    // places, move the right answer to match the left
+    if (leftret != rightret) {
+        if (leftret) {
+            // The left is in backing, and the right is in buf
+            memmove(backing + Nleft, buf + Nleft, Nright * sizeof(T));
+        } else {
+            // The left is in buf, and the right is in backing
+            memmove(buf + Nleft, backing + Nleft, Nright * sizeof(T));
+        }
+    }
+
+    // Merge the two halves
+    if (leftret) {
+        // The recursive outputs are in backing; merge them into buf
+        mtmerge<T>(buf, backing, Nleft, backing+Nleft, Nright, nthreads);
+    } else {
+        // The recursive outputs are in buf; merge them into backing
+        mtmerge<T>(backing, buf, Nleft, buf+Nleft, Nright, nthreads);
+    }
+#ifdef PROFILE_MTMERGESORT
+printf_with_rtclock_diff(start, "end mtmergesort(buf=%p, N=%lu, backing=%p, nthreads=%lu)\n", buf, N, backing, nthreads);
+#endif
+    return !leftret;
+}
+

+ 147 - 0
Enclave/sort.cpp

@@ -0,0 +1,147 @@
+#include <map>
+#include <deque>
+#include <pthread.h>
+#include "sort.hpp"
+
+// A set of precomputed WaksmanNetworks of a given size
+struct SizedWNs {
+    pthread_mutex_t mutex;
+    std::deque<WaksmanNetwork> wns;
+};
+
+// A (mutexed) map mapping sizes to SizedWNs
+struct PrecompWNs {
+    pthread_mutex_t mutex;
+    std::map<uint32_t,SizedWNs> sized_wns;
+};
+
+static PrecompWNs precomp_wns;
+
+// A (mutexed) map mapping (N, nthreads) pairs to WNEvalPlans
+struct EvalPlans {
+    pthread_mutex_t mutex;
+    std::map<std::pair<uint32_t,threadid_t>,WNEvalPlan> eval_plans;
+};
+
+static EvalPlans precomp_eps;
+
+void sort_precompute(uint32_t N)
+{
+    uint32_t *random_permutation = new uint32_t[N];
+    if (!random_permutation) {
+        printf("Allocating memory failed in sort_precompute\n");
+    }
+    assert(random_permutation);
+    for (uint32_t i=0;i<N;++i) {
+        random_permutation[i] = i;
+    }
+    RecursiveShuffle_M2((unsigned char *) random_permutation, N, sizeof(uint32_t));
+
+    WaksmanNetwork wnet(N);
+    wnet.setPermutation(random_permutation);
+
+    // Note that sized_wns[N] creates a map entry for N if it doesn't yet exist
+    pthread_mutex_lock(&precomp_wns.mutex);
+    SizedWNs& szwn = precomp_wns.sized_wns[N];
+    pthread_mutex_unlock(&precomp_wns.mutex);
+    pthread_mutex_lock(&szwn.mutex);
+    szwn.wns.push_back(std::move(wnet));
+    pthread_mutex_unlock(&szwn.mutex);
+}
+
+void sort_precompute_evalplan(uint32_t N, threadid_t nthreads)
+{
+    std::pair<uint32_t,threadid_t> idx = {N, nthreads};
+    pthread_mutex_lock(&precomp_eps.mutex);
+    if (!precomp_eps.eval_plans.count(idx)) {
+        precomp_eps.eval_plans.try_emplace(idx, N, nthreads);
+    }
+    pthread_mutex_unlock(&precomp_eps.mutex);
+}
+
+// Perform the sort using up to nthreads threads.  The items to sort are
+// byte arrays of size msg_size.  The key is the first 4 bytes of each
+// item; the key value uint32_t(-1) is reserved to mark padding.  Nr is
+// the number of items to sort, and Na is the number of items the items
+// array has room for.  One precomputed WaksmanNetwork of some size Nw
+// with Nr <= Nw <= Na is consumed, and a precomputed WNEvalPlan for
+// (Nw, nthreads) must already exist; if either is missing, a message is
+// printed and an assertion fails.
+void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
+    uint32_t Nr, uint32_t Na,
+    // the arguments to the callback are nthreads, items, the sorted
+    // indices, and the number of non-padding items
+    std::function<void(threadid_t, const uint8_t*, const uint64_t*,
+        uint32_t Nr)> cb)
+{
+    // Find the smallest Nw for which we have a precomputed
+    // WaksmanNetwork with Nr <= Nw <= Na
+    pthread_mutex_lock(&precomp_wns.mutex);
+    std::optional<WaksmanNetwork> wn;
+    // Nw is set together with wn below; give it a defined value so that
+    // it is never read uninitialized if the assertions are compiled out
+    uint32_t Nw = 0;
+    for (auto& N : precomp_wns.sized_wns) {
+        if (N.first > Na) {
+            printf("No precomputed WaksmanNetworks of size at most %u\n", Na);
+            assert(false);
+        }
+        if (N.first < Nr) {
+            continue;
+        }
+        // We're in the right range, but see if we have an actual
+        // precomputed WaksmanNetwork
+        pthread_mutex_lock(&N.second.mutex);
+        if (N.second.wns.size() == 0) {
+            pthread_mutex_unlock(&N.second.mutex);
+            continue;
+        }
+        // Take the network out of the queue; each one is used only once
+        wn = std::move(N.second.wns.front());
+        N.second.wns.pop_front();
+        Nw = N.first;
+        pthread_mutex_unlock(&N.second.mutex);
+        break;
+    }
+    pthread_mutex_unlock(&precomp_wns.mutex);
+    if (!wn) {
+        printf("No precomputed WaksmanNetwork of size range [%u,%u] found.\n",
+            Nr, Na);
+        assert(wn);
+    }
+    std::pair<uint32_t,threadid_t> epidx = {Nw, nthreads};
+    pthread_mutex_lock(&precomp_eps.mutex);
+    if (!precomp_eps.eval_plans.count(epidx)) {
+        printf("No precomputed WNEvalPlan with N=%u, nthreads=%hu\n",
+            Nw, nthreads);
+        assert(false);
+    }
+    WNEvalPlan &eval_plan = precomp_eps.eval_plans.at(epidx);
+    pthread_mutex_unlock(&precomp_eps.mutex);
+    // Mark Nw-Nr items as padding (Nr, Na, and Nw are _not_ private)
+    for (uint32_t i=Nr; i<Nw; ++i) {
+        (*(uint32_t*)(items+msg_size*i)) = uint32_t(-1);
+    }
+    // Shuffle Nw items
+    wn.value().applyInversePermutation<OSWAP_16X>(
+        items, msg_size, eval_plan);
+
+    // Create the indices: one 64-bit entry per non-padding item, with
+    // the key in the high 32 bits and the item's position in the
+    // shuffled array in the low 32 bits
+    uint64_t *idx = new uint64_t[Nr];
+    uint64_t *nextidx = idx;
+    for (uint32_t i=0; i<Nw; ++i) {
+        uint64_t key = (*(uint32_t*)(items+msg_size*i));
+        if (key != uint32_t(-1)) {
+            *nextidx = (key<<32) + i;
+            ++nextidx;
+        }
+    }
+    if (nextidx != idx + Nr) {
+        // The pointer difference is a ptrdiff_t; convert it to match %u
+        printf("Found %u non-padding items, expected %u\n",
+            (uint32_t)(nextidx-idx), Nr);
+        assert(nextidx == idx + Nr);
+    }
+    // Sort the keys and indices
+    uint64_t *backingidx = new uint64_t[Nr];
+    bool whichbuf = mtmergesort<uint64_t>(idx, Nr, backingidx, nthreads);
+    uint64_t *sortedidx = whichbuf ? backingidx : idx;
+    // Strip the keys, leaving just the positions in the shuffled array
+    for (uint32_t i=0; i<Nr; ++i) {
+        sortedidx[i] &= uint64_t(0xffffffff);
+    }
+    cb(nthreads, items, sortedidx, Nr);
+
+    delete[] idx;
+    delete[] backingidx;
+}

+ 45 - 0
Enclave/sort.hpp

@@ -0,0 +1,45 @@
+#ifndef __SORT_HPP__
+#define __SORT_HPP__
+
+#include <functional>
+
+#include "WaksmanNetwork.hpp"
+
+// We have Nr elements to sort, at the beginning of an array of
+// allocated size Na items.  Nr and Na are _not_ secret.  The strategy
+// is to get the smallest (precomputed) WaksmanNetwork (on a random
+// permutation) and WNEvalPlan, each of size Nw, where Nr <= Nw <= Na.
+// Nw is also not secret.  Then mark the Nw-Nr items following the given
+// elements as padding, use the WaksmanNetwork and WNEvalPlan to shuffle
+// the Nw items, then non-obliviously sort the non-padding items.  The
+// sort itself is done by making an index array holding just the sort
+// keys and original indices into the array, but only of the non-padding
+// items.  Sort the index array, and call back a std::function with the
+// sorted index array and the shuffled array so that it can read out the
+// elements in sorted order.  The callback does not have to be
+// oblivious to which elements it is reading, because they've already
+// been obliviously shuffled.
+
+// Precompute a WaksmanNetwork of size N for a random permutation.  This
+// call does not itself use threads, but may be called from a background
+// thread.  These are consumed as they are used, so you need to keep
+// making more.
+void sort_precompute(uint32_t N);
+
+// Precompute a WNEvalPlan for a given size and number of threads.
+// These are not consumed as they are used, so you only need to call
+// this once for each size/nthreads you need.  The precomputation itself
+// only uses a single thread, but also may be called from a background
+// thread.
+void sort_precompute_evalplan(uint32_t N, threadid_t nthreads);
+
+// Perform the sort using up to nthreads threads.  The items to sort are
+// byte arrays of size msg_size.  The key is the 10-bit storage server
+// id concatenated with the 22-bit uid at the storage server.
+void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
+    uint32_t Nr, uint32_t Na,
+    // the arguments to the callback are nthreads, items, the sorted
+    // indices, and the number of non-padding items
+    std::function<void(threadid_t, const uint8_t*, const uint64_t*, uint32_t Nr)>);
+
+#endif

+ 2 - 0
Makefile

@@ -297,6 +297,7 @@ Enclave/comms.o: Enclave/Enclave_t.h Enclave/enclave_api.h Enclave/config.hpp
 Enclave/comms.o: Enclave/enclave_api.h
 Enclave/config.o: Enclave/Enclave_t.h Enclave/enclave_api.h Enclave/comms.hpp
 Enclave/config.o: Enclave/enclave_api.h Enclave/config.hpp
+Enclave/sort.o: Enclave/sort.hpp
 Enclave/OblivAlgs/RecursiveShuffle.o: Enclave/OblivAlgs/oasm_lib.h
 Enclave/OblivAlgs/RecursiveShuffle.o: Enclave/OblivAlgs/CONFIG.h
 Enclave/OblivAlgs/RecursiveShuffle.o: Enclave/OblivAlgs/oasm_lib.tcc
@@ -344,3 +345,4 @@ Enclave/OblivAlgs/WaksmanNetwork.o: Enclave/OblivAlgs/TightCompaction_v2.hpp
 Enclave/OblivAlgs/WaksmanNetwork.o: Enclave/OblivAlgs/TightCompaction_v2.tcc
 Enclave/OblivAlgs/WaksmanNetwork.o: Enclave/OblivAlgs/RecursiveShuffle.tcc
 Enclave/OblivAlgs/WaksmanNetwork.o: Enclave/OblivAlgs/aes.hpp
+Enclave/OblivAlgs/WaksmanNetwork.o: Enclave/OblivAlgs/WaksmanNetwork.tcc