|
@@ -0,0 +1,232 @@
|
|
|
|
+// See ORExpand.hpp for explanations of notation and inputs.
|
|
|
|
+// Particularly note that all subarrays [lo..hi] are _inclusive_ of lo
|
|
|
|
+// but _exclusive_ of hi.
|
|
|
|
+
|
|
|
|
+// buf is an array of block_size-byte blocks. dest is an array of 32-bit
|
|
|
|
+// words. We are given two (contiguous) subarrays [lo..mid] and
|
|
|
|
+// [mid..hi], and indices a and b with a in [lo..mid] and b in
|
|
|
|
+// [mid..hi]. If (mid <= dest[a] < hi) or (lo <= dest[b] <= mid), then
|
|
|
|
+// swap dest[a] with dest[b] and buf[a] with buf[b]; otherwise, do not.
|
|
|
|
+// However, all tests and swaps must be done obliviously to the values
|
|
|
|
+// of dest[a] and dest[b] (and the contents of buf). It's OK to not be
|
|
|
|
+// oblivious to the values of lo, mid, hi, a, and b themselves, however.
|
|
|
|
+template <OSwap_Style oswap_style>
|
|
|
|
+static inline void mid_oswap(unsigned char *buf, uint32_t *dest,
|
|
|
|
+ size_t block_size, uint32_t lo, uint32_t mid, uint32_t hi,
|
|
|
|
+ uint32_t a, uint32_t b)
|
|
|
|
+{
|
|
|
|
+ uint32_t desta = dest[a];
|
|
|
|
+ uint32_t destb = dest[b];
|
|
|
|
+ uint8_t swap_flag = ((mid <= desta) & (desta < hi))
|
|
|
|
+ | ((lo <= destb) & (destb < mid));
|
|
|
|
+ // The next line could be optimized with some inline assembly, since
|
|
|
|
+ // we've already loaded desta and destb, so we could obliviously
|
|
|
|
+ // swap those registers, and then non-obliviously write them back
|
|
|
|
+ // out to dest[a] and dest[b].
|
|
|
|
+ oswap_buffer<OSWAP_4>((unsigned char *)(dest+a),
|
|
|
|
+ (unsigned char *)(dest+b), 4, swap_flag);
|
|
|
|
+ oswap_buffer<oswap_style>(buf+a*block_size, buf+b*block_size,
|
|
|
|
+ (uint32_t)block_size, swap_flag);
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+template <OSwap_Style oswap_style>
|
|
|
|
+void ORExpand(unsigned char *buf, uint32_t *dest, size_t block_size,
|
|
|
|
+ uint32_t lo, uint32_t hi)
|
|
|
|
+{
|
|
|
|
+ // Passing hi < lo is an illegal input
|
|
|
|
+ assert(hi >= lo);
|
|
|
|
+
|
|
|
|
+ // The length of the subarray
|
|
|
|
+ const uint32_t N = hi-lo;
|
|
|
|
+
|
|
|
|
+ // Nothing to do on inputs where [lo..hi] has length 0 or 1
|
|
|
|
+ if (N < 2) {
|
|
|
|
+ return;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // The largest power of 2 strictly less than N
|
|
|
|
+ const uint32_t N2 = uint32_t(pow2_lt(N));
|
|
|
|
+
|
|
|
|
+ // We divide [lo..hi] (of length N) into two pieces:
|
|
|
|
+ // [lo..mid] and [mid..hi], where [mid..hi] has length N2 (the
|
|
|
|
+ // largest power of 2 strictly less than N). Note that mid is just
|
|
|
|
+ // _somewhere_ in the middle of [lo..hi]; it will not be the exact
|
|
|
|
+ // midpoint if N is not itself a power of 2. (It will be the exact
|
|
|
|
+ // midpoint if N is a power of 2, however.)
|
|
|
|
+ const uint32_t mid = hi-N2;
|
|
|
|
+
|
|
|
|
+ // N1 is the length of [lo..mid]. Note that N1 <= N2, with equality
|
|
|
|
+ // if and only if N is a power of 2.
|
|
|
|
+ const uint32_t N1 = N-N2;
|
|
|
|
+
|
|
|
|
+ // We're going to do N1 oblivious swaps on the buf and dest arrays,
|
|
|
|
+ // between items lo+i and hi-N1+i for 0 <= i < N1. If dest[lo+i]
|
|
|
|
+ // lies in [mid..hi] (and is not 0xffffffff to indicate padding)
|
|
|
|
+ // _or_ if dest[hi-N1+i] lies in [lo..mid] (again and is not
|
|
|
|
+ // 0xffffffff), then we swap them and their corresponding buf
|
|
|
|
+ // blocks. The cool part is that it cannot be the case that both
|
|
|
|
+ // dest[lo+i] and dest[hi-N1+i] are not padding and they both have
|
|
|
|
+ // values on the same side of mid. Why is that?
|
|
|
|
+
|
|
|
|
+ // Case 1: If dest[lo+i] < dest[hi-N1+i], then all of the blocks
|
|
|
|
+ // from lo+i to hi-N1+1 inclusive must be non-padding blocks, and
|
|
|
|
+ // since this contiguous block has strictly increasing values, it
|
|
|
|
+ // must be that dest[hi-N1+i] - dest[lo+i] >= (hi-N1+i)-(lo+i) =
|
|
|
|
+ // N-N1 = N2. Since the lengths of [lo..mid] and [mid..hi] are each
|
|
|
|
+ // at most N2, dest[lo+i] and dest[hi-N1+i] cannot be both in the
|
|
|
|
+ // same one of those subarrays.
|
|
|
|
+
|
|
|
|
+ // Case 2: If dest[hi-N1+i] < dest[lo+i], then the contiguous range
|
|
|
|
+ // of non-padding blocks wraps around hi back to lo, so we must have
|
|
|
|
+ // that dest[lo+i] - dest[hi-N1+i] >= (hi+i) - (hi-N1+i) = N1, and
|
|
|
|
+ // also since the range wraps around, it must start at a non-zero
|
|
|
|
+ // offset z, which means that N had to be a power of 2, and so
|
|
|
|
+ // N1=N2. Therefore dest[lo+i] - dest[hi-N1+i] >= N2, and as above,
|
|
|
|
+ // dest[lo+i] and dest[hi-N1+i] cannot both be in [lo..mid] or both
|
|
|
|
+ // be in [mid..hi], each of which has length N1=N2.
|
|
|
|
+
|
|
|
|
+ // So these oblivious swaps will ensure that all the blocks with
|
|
|
|
+ // dest in [lo..mid] end up in [lo..mid] and all the blocks with
|
|
|
|
+ // dest in [mid..hi] end up in [mid..hi]. In addition, the property
|
|
|
|
+ // that all the non-padding blocks are contiguous (possibly wrapping
|
|
|
|
+ // around for the [mid..hi] subarray which has length a power of 2)
|
|
|
|
+ // and monotonicly increasing are preserved for both the [lo..mid]
|
|
|
|
+ // and [mid..hi] subarrays.
|
|
|
|
+
|
|
|
|
+ for (uint32_t i=0; i<N1; ++i) {
|
|
|
|
+ mid_oswap<oswap_style>(buf, dest, block_size, lo, mid, hi,
|
|
|
|
+ lo+i, hi-N1+i);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // And now we just recurse on the two subarrays.
|
|
|
|
+ ORExpand<oswap_style>(buf, dest, block_size, lo, mid);
|
|
|
|
+ ORExpand<oswap_style>(buf, dest, block_size, mid, hi);
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// Multithreaded version of ORExpand
|
|
|
|
+
|
|
|
|
+struct mid_oswap_range_args {
|
|
|
|
+ unsigned char *buf;
|
|
|
|
+ uint32_t *dest;
|
|
|
|
+ size_t block_size;
|
|
|
|
+ uint32_t lo, mid, hi, a, b, num;
|
|
|
|
+};
|
|
|
|
+
|
|
|
|
+template <OSwap_Style oswap_style>
|
|
|
|
+static void *mid_oswap_range_launch(void *voidargs)
|
|
|
|
+{
|
|
|
|
+ const mid_oswap_range_args *args =
|
|
|
|
+ (mid_oswap_range_args*)voidargs;
|
|
|
|
+ for (uint32_t i=0; i<args->num; ++i) {
|
|
|
|
+ mid_oswap<oswap_style>(args->buf, args->dest, args->block_size,
|
|
|
|
+ args->lo, args->mid, args->hi,
|
|
|
|
+ args->a + i, args->b + i);
|
|
|
|
+ }
|
|
|
|
+ return NULL;
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+struct ORExpand_parallel_args {
|
|
|
|
+ unsigned char *buf;
|
|
|
|
+ uint32_t *dest;
|
|
|
|
+ size_t block_size;
|
|
|
|
+ uint32_t lo, hi;
|
|
|
|
+ threadid_t nthreads;
|
|
|
|
+};
|
|
|
|
+
|
|
|
|
+template <OSwap_Style oswap_style>
|
|
|
|
+static void* ORExpand_parallel_launch(void *voidargs)
|
|
|
|
+{
|
|
|
|
+ const ORExpand_parallel_args* args =
|
|
|
|
+ (ORExpand_parallel_args*)voidargs;
|
|
|
|
+ ORExpand_parallel<oswap_style>(args->buf, args->dest,
|
|
|
|
+ args->block_size, args->lo, args->hi, args->nthreads);
|
|
|
|
+ return NULL;
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// See ORExpand, above, for detailed comments as to how this algorithm
|
|
|
|
+// works.
|
|
|
|
+template <OSwap_Style oswap_style>
|
|
|
|
+void ORExpand_parallel(unsigned char *buf, uint32_t *dest,
|
|
|
|
+ size_t block_size, uint32_t lo, uint32_t hi, threadid_t nthreads)
|
|
|
|
+{
|
|
|
|
+ // Passing hi < lo is an illegal input
|
|
|
|
+ assert(hi >= lo);
|
|
|
|
+
|
|
|
|
+ // The length of the subarray
|
|
|
|
+ const uint32_t N = hi-lo;
|
|
|
|
+
|
|
|
|
+ // Nothing to do on inputs where [lo..hi] has length 0 or 1
|
|
|
|
+ if (N < 2) {
|
|
|
|
+ return;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Use the single-threaded version if nthreads <= 1
|
|
|
|
+ if (nthreads <= 1) {
|
|
|
|
+#ifdef PROFILE_OREXPAND
|
|
|
|
+ unsigned long start = printf_with_rtclock("Thread %u starting ORExpand(N=%u, lo=%u, hi=%u, nthreads=%hu)\n", g_thread_id, N, lo, hi, nthreads);
|
|
|
|
+#endif
|
|
|
|
+ ORExpand<oswap_style>(buf, dest, block_size, lo, hi);
|
|
|
|
+#ifdef PROFILE_OREXPAND
|
|
|
|
+ printf_with_rtclock_diff(start, "Thread %u ending ORExpand(N=%u, lo=%u, hi=%u, nthreads=%hu)\n", g_thread_id, N, lo, hi, nthreads);
|
|
|
|
+#endif
|
|
|
|
+ return;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+#ifdef PROFILE_OREXPAND
|
|
|
|
+ unsigned long start = printf_with_rtclock("Thread %u starting ORExpand(N=%u, lo=%u, hi=%u, nthreads=%hu)\n", g_thread_id, N, lo, hi, nthreads);
|
|
|
|
+#endif
|
|
|
|
+
|
|
|
|
+ // The largest power of 2 strictly less than N
|
|
|
|
+ const uint32_t N2 = uint32_t(pow2_lt(N));
|
|
|
|
+ const uint32_t mid = hi-N2;
|
|
|
|
+ const uint32_t N1 = N-N2;
|
|
|
|
+
|
|
|
|
+ mid_oswap_range_args args[nthreads];
|
|
|
|
+ uint32_t inc = N1 / nthreads;
|
|
|
|
+ uint32_t extra = N1 % nthreads;
|
|
|
|
+ uint32_t last = 0;
|
|
|
|
+ for (threadid_t i=0; i<nthreads; ++i) {
|
|
|
|
+ uint32_t num = inc + (i < extra);
|
|
|
|
+ args[i] = { buf, dest, block_size, lo, mid, hi, lo+last,
|
|
|
|
+ hi-N1+last, num };
|
|
|
|
+ last += num;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Launch all but the first section into other threads
|
|
|
|
+ for (threadid_t i=1; i<nthreads; ++i) {
|
|
|
|
+ threadpool_dispatch(g_thread_id+i,
|
|
|
|
+ mid_oswap_range_launch<oswap_style>, args+i);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Do the first section ourselves
|
|
|
|
+ mid_oswap_range_launch<oswap_style>(args);
|
|
|
|
+
|
|
|
|
+ // Join the threads
|
|
|
|
+ for (threadid_t i=1; i<nthreads; ++i) {
|
|
|
|
+ threadpool_join(g_thread_id+i, NULL);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Use half the threads for the left subarray and half for the right
|
|
|
|
+ // subarray (this choice could be improved if N1 << N2, perhaps).
|
|
|
|
+ threadid_t lthreads = nthreads / 2;
|
|
|
|
+ threadid_t rthreads = nthreads - lthreads;
|
|
|
|
+
|
|
|
|
+ threadid_t rightthreadid = g_thread_id + lthreads;
|
|
|
|
+
|
|
|
|
+ ORExpand_parallel_args rightargs = {
|
|
|
|
+ buf, dest, block_size, mid, hi, rthreads
|
|
|
|
+ };
|
|
|
|
+ threadpool_dispatch(rightthreadid,
|
|
|
|
+ ORExpand_parallel_launch<oswap_style>, &rightargs);
|
|
|
|
+
|
|
|
|
+ // Do the left subarray ourselves (with lthreads threads)
|
|
|
|
+ ORExpand_parallel<oswap_style>(buf, dest, block_size, lo, mid,
|
|
|
|
+ lthreads);
|
|
|
|
+
|
|
|
|
+ // Join the thread
|
|
|
|
+ threadpool_join(rightthreadid, NULL);
|
|
|
|
+
|
|
|
|
+#ifdef PROFILE_OREXPAND
|
|
|
|
+ printf_with_rtclock_diff(start, "Thread %u ending ORExpand(N=%u, lo=%u, hi=%u, nthreads=%hu)\n", g_thread_id, N, lo, hi, nthreads);
|
|
|
|
+#endif
|
|
|
|
+}
|