Browse Source

Implementation of ORExpand (multi-threaded version)

Ian Goldberg 1 year ago
parent
commit
d122d04a6d
3 changed files with 193 additions and 2 deletions
  1. 49 0
      Enclave/OblivAlgs/ORExpand.cpp
  2. 16 0
      Enclave/OblivAlgs/ORExpand.hpp
  3. 128 2
      Enclave/OblivAlgs/ORExpand.tcc

+ 49 - 0
Enclave/OblivAlgs/ORExpand.cpp

@@ -42,4 +42,53 @@ void test_ORExpand()
     delete[] buf;
     delete[] dest;
 }
+
+// Test ORExpand_parallel on N blocks of block_size bytes each, using
+// nthreads threads.
+//
+// A random subset of the indices 0..N-1 is selected as the "real"
+// items; the selected indices are packed, in increasing order, at the
+// front of dest, and the remainder of dest is filled with the dummy
+// marker 0xffffffff.  After ORExpand_parallel returns, every dest[i]
+// must be either the dummy marker or i itself.  If N is small (< 200),
+// the resulting dest and buf contents are also printed.
+void test_ORExpand_parallel(threadid_t nthreads)
+{
+    size_t block_size = 48;
+    uint32_t N = 1000000;
+
+    unsigned char *buf = new unsigned char[N*block_size];
+    uint32_t *dest = new uint32_t[N];
+    uint32_t next_dest = 0;
+    for (uint32_t i=0;i<N;++i) {
+        // Randomly decide whether i is real or dummy
+        bool keep = getRandomBit();
+        if (keep) {
+            dest[next_dest] = i;
+            ++next_dest;
+        }
+    }
+    // Everything after the packed real items is a dummy
+    for (uint32_t i=next_dest;i<N;++i) {
+        dest[i] = 0xffffffff;
+    }
+
+    // Fill block i with alternating bytes: the low byte of i, then a
+    // counter of the byte pairs within the block
+    for (size_t i=0;i<N;++i) {
+        for (size_t j=0;j<block_size;j+=2) {
+            buf[i*block_size+j] = (unsigned char)i;
+            // Divide before truncating to a byte, so that the counter
+            // is still the low byte of j/2 when block_size > 256
+            buf[i*block_size+j+1] = (unsigned char)(j/2);
+        }
+    }
+
+    ORExpand_parallel<OSWAP_16X>(buf, dest, block_size, N, nthreads);
+
+    for(uint32_t i=0;i<N;++i) {
+        assert(dest[i] == 0xffffffff || dest[i] == i);
+        if (N < 200) {
+            // i and dest[i] are unsigned, so print them with %u
+            printf("%2u ", i);
+            if (dest[i] == 0xffffffff) {
+                printf("   ");
+            } else {
+                printf("%2u ", dest[i]);
+            }
+            for (size_t j=0;j<block_size;++j) {
+                printf("%02x", buf[i*block_size+j]);
+            }
+            printf("\n");
+        }
+    }
+
+    delete[] buf;
+    delete[] dest;
+}
 #endif

+ 16 - 0
Enclave/OblivAlgs/ORExpand.hpp

@@ -4,6 +4,7 @@
 #include <cstddef>
 #include <cstdint>
 
+#include "utils.hpp"
 #include "oasm_lib.h"
 
 // Notation: for an array A, let A[lo..hi] denote the subarray of
@@ -88,11 +89,26 @@ void ORExpand(unsigned char *buf, uint32_t *dest, size_t block_size,
     ORExpand<oswap_style>(buf, dest, block_size, 0, N);
 }
 
+// Multithreaded versions of ORExpand and its convenience wrapper
+
+// Multithreaded ORExpand on the subarray [lo..hi] of buf and dest
+// (inclusive of lo, exclusive of hi), using nthreads threads.  The
+// definition is in ORExpand.tcc.
+template <OSwap_Style oswap_style>
+void ORExpand_parallel(
+    unsigned char *buf,
+    uint32_t *dest,
+    size_t block_size,
+    uint32_t lo,
+    uint32_t hi,
+    threadid_t nthreads);
+
+// Convenience wrapper: run the multithreaded ORExpand over the whole
+// array of N items, i.e., the subarray [0..N].
+template <OSwap_Style oswap_style>
+void ORExpand_parallel(unsigned char *buf, uint32_t *dest,
+    size_t block_size, uint32_t N, threadid_t nthreads)
+{
+    const uint32_t lo = 0;
+    const uint32_t hi = N;
+    ORExpand_parallel<oswap_style>(buf, dest, block_size, lo, hi,
+        nthreads);
+}
+
+// #define PROFILE_OREXPAND
+// #define TEST_OREXPAND
 
 #include "ORExpand.tcc"
 
 #ifdef TEST_OREXPAND
 void test_ORExpand();
+void test_ORExpand_parallel(threadid_t nthreads);
 #endif
 
 #endif

+ 128 - 2
Enclave/OblivAlgs/ORExpand.tcc

@@ -1,5 +1,3 @@
-#include "utils.hpp"
-
 // See ORExpand.hpp for explanations of notation and inputs.
 // Particularly note that all subarrays [lo..hi] are _inclusive_ of lo
 // but _exclusive_ of hi.
@@ -100,3 +98,131 @@ void ORExpand(unsigned char *buf, uint32_t *dest, size_t block_size,
     ORExpand<oswap_style>(buf, dest, block_size, lo, mid);
     ORExpand<oswap_style>(buf, dest, block_size, mid, hi);
 }
+
+// Multithreaded version of ORExpand
+
+// Arguments for one worker's share of the mid_oswap calls.  The first
+// six fields are passed through to mid_oswap unchanged; the worker
+// makes num calls, on the index pairs (a, b), (a+1, b+1), ...,
+// (a+num-1, b+num-1).  The field order matters: the struct is filled
+// in with a positional brace-initializer.
+struct mid_oswap_range_args {
+    unsigned char *buf;
+    uint32_t *dest;
+    size_t block_size;
+    uint32_t lo;
+    uint32_t mid;
+    uint32_t hi;
+    uint32_t a;
+    uint32_t b;
+    uint32_t num;
+};
+
+// Thread entry point for one worker's share of the mid_oswap calls.
+// voidargs must point to a mid_oswap_range_args; mid_oswap is called
+// on each of the num consecutive index pairs starting at (a, b).
+// Always returns NULL.
+template <OSwap_Style oswap_style>
+static void *mid_oswap_range_launch(void *voidargs)
+{
+    const mid_oswap_range_args &range =
+        *static_cast<const mid_oswap_range_args *>(voidargs);
+
+    uint32_t left = range.a;
+    uint32_t right = range.b;
+    for (uint32_t remaining = range.num; remaining > 0; --remaining) {
+        mid_oswap<oswap_style>(range.buf, range.dest, range.block_size,
+            range.lo, range.mid, range.hi, left, right);
+        ++left;
+        ++right;
+    }
+    return NULL;
+}
+
+// Arguments for a call to ORExpand_parallel made from another thread
+// via ORExpand_parallel_launch.  The fields mirror the parameters of
+// ORExpand_parallel, in the same order; the struct is filled in with a
+// positional brace-initializer, so the field order matters.
+struct ORExpand_parallel_args {
+    unsigned char *buf;
+    uint32_t *dest;
+    size_t block_size;
+    uint32_t lo;
+    uint32_t hi;
+    threadid_t nthreads;
+};
+
+// Thread entry point that unpacks an ORExpand_parallel_args (pointed
+// to by voidargs) and forwards its fields to ORExpand_parallel.
+// Always returns NULL.
+template <OSwap_Style oswap_style>
+static void *ORExpand_parallel_launch(void *voidargs)
+{
+    const ORExpand_parallel_args &call =
+        *static_cast<const ORExpand_parallel_args *>(voidargs);
+    ORExpand_parallel<oswap_style>(call.buf, call.dest, call.block_size,
+        call.lo, call.hi, call.nthreads);
+    return NULL;
+}
+
+// Multithreaded ORExpand on the subarray [lo..hi] of buf and dest,
+// using the threads g_thread_id .. g_thread_id+nthreads-1.  See
+// ORExpand, above, for detailed comments as to how this algorithm
+// works.
+//
+// The work is done in two phases:
+//   1. The N1 = N-N2 mid_oswap calls (where N2 is the largest power of
+//      2 strictly less than N = hi-lo) are split as evenly as possible
+//      among the nthreads threads; the calling thread does the first
+//      share itself and waits for the others.
+//   2. The left subarray [lo..mid] and the right subarray [mid..hi]
+//      are processed recursively and concurrently: the right one in
+//      thread g_thread_id+lthreads with rthreads threads, the left one
+//      in the calling thread with lthreads threads.
+//
+// If nthreads <= 1, the single-threaded ORExpand is used instead.
+template <OSwap_Style oswap_style>
+void ORExpand_parallel(unsigned char *buf, uint32_t *dest,
+    size_t block_size, uint32_t lo, uint32_t hi, threadid_t nthreads)
+{
+    // Passing hi < lo is an illegal input
+    assert(hi >= lo);
+
+    // The length of the subarray
+    const uint32_t N = hi-lo;
+
+    // Nothing to do on inputs where [lo..hi] has length 0 or 1
+    if (N < 2) {
+        return;
+    }
+
+    // Use the single-threaded version if nthreads <= 1
+    if (nthreads <= 1) {
+#ifdef PROFILE_OREXPAND
+        unsigned long start = printf_with_rtclock("Thread %u starting ORExpand(N=%u, lo=%u, hi=%u, nthreads=%hu)\n", g_thread_id, N, lo, hi, nthreads);
+#endif
+        ORExpand<oswap_style>(buf, dest, block_size, lo, hi);
+#ifdef PROFILE_OREXPAND
+        printf_with_rtclock_diff(start, "Thread %u ending ORExpand(N=%u, lo=%u, hi=%u, nthreads=%hu)\n", g_thread_id, N, lo, hi, nthreads);
+#endif
+        return;
+    }
+
+#ifdef PROFILE_OREXPAND
+    unsigned long start = printf_with_rtclock("Thread %u starting ORExpand(N=%u, lo=%u, hi=%u, nthreads=%hu)\n", g_thread_id, N, lo, hi, nthreads);
+#endif
+
+    // The largest power of 2 strictly less than N
+    const uint32_t N2 = uint32_t(pow2_lt(N));
+    const uint32_t mid = hi-N2;
+    const uint32_t N1 = N-N2;
+
+    // One argument record per thread.  These are heap-allocated rather
+    // than declared as "args[nthreads]": a variable-length array is not
+    // standard C++, and it would put a caller-sized array on the stack
+    // at every level of the recursion.
+    mid_oswap_range_args *args = new mid_oswap_range_args[nthreads];
+
+    // Split the N1 mid_oswap calls as evenly as possible: the first
+    // "extra" threads get inc+1 calls each and the rest get inc.
+    const uint32_t inc = N1 / nthreads;
+    const uint32_t extra = N1 % nthreads;
+    uint32_t last = 0;
+    for (threadid_t i=0; i<nthreads; ++i) {
+        const uint32_t num = inc + (i < extra);
+        args[i] = { buf, dest, block_size, lo, mid, hi, lo+last,
+            hi-N1+last, num };
+        last += num;
+    }
+
+    // Launch all but the first section into other threads.  Sections
+    // with no work (possible when N1 < nthreads) are not dispatched.
+    for (threadid_t i=1; i<nthreads; ++i) {
+        if (args[i].num > 0) {
+            threadpool_dispatch(g_thread_id+i,
+                mid_oswap_range_launch<oswap_style>, args+i);
+        }
+    }
+
+    // Do the first section ourselves
+    mid_oswap_range_launch<oswap_style>(args);
+
+    // Join exactly the threads that were dispatched above
+    for (threadid_t i=1; i<nthreads; ++i) {
+        if (args[i].num > 0) {
+            threadpool_join(g_thread_id+i, NULL);
+        }
+    }
+
+    // All of the workers are done with their argument records
+    delete[] args;
+
+    // Use half the threads for the left subarray and half for the right
+    // subarray (this choice could be improved if N1 << N2, perhaps).
+    threadid_t lthreads = nthreads / 2;
+    threadid_t rthreads = nthreads - lthreads;
+
+    // The right subarray gets the threads just after the lthreads
+    // threads used for the left subarray
+    threadid_t rightthreadid = g_thread_id + lthreads;
+
+    ORExpand_parallel_args rightargs = {
+        buf, dest, block_size, mid, hi, rthreads
+    };
+    threadpool_dispatch(rightthreadid,
+        ORExpand_parallel_launch<oswap_style>, &rightargs);
+
+    // Do the left subarray ourselves (with lthreads threads)
+    ORExpand_parallel<oswap_style>(buf, dest, block_size, lo, mid,
+        lthreads);
+
+    // Join the thread
+    threadpool_join(rightthreadid, NULL);
+
+#ifdef PROFILE_OREXPAND
+    printf_with_rtclock_diff(start, "Thread %u ending ORExpand(N=%u, lo=%u, hi=%u, nthreads=%hu)\n", g_thread_id, N, lo, hi, nthreads);
+#endif
+}