// See ORExpand.hpp for explanations of notation and inputs. // Particularly note that all subarrays [lo..hi] are _inclusive_ of lo // but _exclusive_ of hi. // buf is an array of block_size-byte blocks. dest is an array of 32-bit // words. We are given two (contiguous) subarrays [lo..mid] and // [mid..hi], and indices a and b with a in [lo..mid] and b in // [mid..hi]. If (mid <= dest[a] < hi) or (lo <= dest[b] <= mid), then // swap dest[a] with dest[b] and buf[a] with buf[b]; otherwise, do not. // However, all tests and swaps must be done obliviously to the values // of dest[a] and dest[b] (and the contents of buf). It's OK to not be // oblivious to the values of lo, mid, hi, a, and b themselves, however. template static inline void mid_oswap(unsigned char *buf, uint32_t *dest, size_t block_size, uint32_t lo, uint32_t mid, uint32_t hi, uint32_t a, uint32_t b) { uint32_t desta = dest[a]; uint32_t destb = dest[b]; uint8_t swap_flag = ((mid <= desta) & (desta < hi)) | ((lo <= destb) & (destb < mid)); // The next line could be optimized with some inline assembly, since // we've already loaded desta and destb, so we could obliviously // swap those registers, and then non-obliviously write them back // out to dest[a] and dest[b]. oswap_buffer((unsigned char *)(dest+a), (unsigned char *)(dest+b), 4, swap_flag); oswap_buffer(buf+a*block_size, buf+b*block_size, (uint32_t)block_size, swap_flag); } template void ORExpand(unsigned char *buf, uint32_t *dest, size_t block_size, uint32_t lo, uint32_t hi) { // Passing hi < lo is an illegal input assert(hi >= lo); // The length of the subarray const uint32_t N = hi-lo; // Nothing to do on inputs where [lo..hi] has length 0 or 1 if (N < 2) { return; } // The largest power of 2 strictly less than N const uint32_t N2 = uint32_t(pow2_lt(N)); // We divide [lo..hi] (of length N) into two pieces: // [lo..mid] and [mid..hi], where [mid..hi] has length N2 (the // largest power of 2 strictly less than N). Note that mid is just // _somewhere_ in the middle of [lo..hi]; it will not be the exact // midpoint if N is not itself a power of 2. (It will be the exact // midpoint if N is a power of 2, however.) const uint32_t mid = hi-N2; // N1 is the length of [lo..mid]. Note that N1 <= N2, with equality // if and only if N is a power of 2. const uint32_t N1 = N-N2; // We're going to do N1 oblivious swaps on the buf and dest arrays, // between items lo+i and hi-N1+i for 0 <= i < N1. If dest[lo+i] // lies in [mid..hi] (and is not 0xffffffff to indicate padding) // _or_ if dest[hi-N1+i] lies in [lo..mid] (again and is not // 0xffffffff), then we swap them and their corresponding buf // blocks. The cool part is that it cannot be the case that both // dest[lo+i] and dest[hi-N1+i] are not padding and they both have // values on the same side of mid. Why is that? // Case 1: If dest[lo+i] < dest[hi-N1+i], then all of the blocks // from lo+i to hi-N1+1 inclusive must be non-padding blocks, and // since this contiguous block has strictly increasing values, it // must be that dest[hi-N1+i] - dest[lo+i] >= (hi-N1+i)-(lo+i) = // N-N1 = N2. Since the lengths of [lo..mid] and [mid..hi] are each // at most N2, dest[lo+i] and dest[hi-N1+i] cannot be both in the // same one of those subarrays. // Case 2: If dest[hi-N1+i] < dest[lo+i], then the contiguous range // of non-padding blocks wraps around hi back to lo, so we must have // that dest[lo+i] - dest[hi-N1+i] >= (hi+i) - (hi-N1+i) = N1, and // also since the range wraps around, it must start at a non-zero // offset z, which means that N had to be a power of 2, and so // N1=N2. Therefore dest[lo+i] - dest[hi-N1+i] >= N2, and as above, // dest[lo+i] and dest[hi-N1+i] cannot both be in [lo..mid] or both // be in [mid..hi], each of which has length N1=N2. // So these oblivious swaps will ensure that all the blocks with // dest in [lo..mid] end up in [lo..mid] and all the blocks with // dest in [mid..hi] end up in [mid..hi]. In addition, the property // that all the non-padding blocks are contiguous (possibly wrapping // around for the [mid..hi] subarray which has length a power of 2) // and monotonicly increasing are preserved for both the [lo..mid] // and [mid..hi] subarrays. for (uint32_t i=0; i(buf, dest, block_size, lo, mid, hi, lo+i, hi-N1+i); } // And now we just recurse on the two subarrays. ORExpand(buf, dest, block_size, lo, mid); ORExpand(buf, dest, block_size, mid, hi); } // Multithreaded version of ORExpand struct mid_oswap_range_args { unsigned char *buf; uint32_t *dest; size_t block_size; uint32_t lo, mid, hi, a, b, num; }; template static void *mid_oswap_range_launch(void *voidargs) { const mid_oswap_range_args *args = (mid_oswap_range_args*)voidargs; for (uint32_t i=0; inum; ++i) { mid_oswap(args->buf, args->dest, args->block_size, args->lo, args->mid, args->hi, args->a + i, args->b + i); } return NULL; } struct ORExpand_parallel_args { unsigned char *buf; uint32_t *dest; size_t block_size; uint32_t lo, hi; threadid_t nthreads; }; template static void* ORExpand_parallel_launch(void *voidargs) { const ORExpand_parallel_args* args = (ORExpand_parallel_args*)voidargs; ORExpand_parallel(args->buf, args->dest, args->block_size, args->lo, args->hi, args->nthreads); return NULL; } // See ORExpand, above, for detailed comments as to how this algorithm // works. template void ORExpand_parallel(unsigned char *buf, uint32_t *dest, size_t block_size, uint32_t lo, uint32_t hi, threadid_t nthreads) { // Passing hi < lo is an illegal input assert(hi >= lo); // The length of the subarray const uint32_t N = hi-lo; // Nothing to do on inputs where [lo..hi] has length 0 or 1 if (N < 2) { return; } // Use the single-threaded version if nthreads <= 1 if (nthreads <= 1) { #ifdef PROFILE_OREXPAND unsigned long start = printf_with_rtclock("Thread %u starting ORExpand(N=%u, lo=%u, hi=%u, nthreads=%hu)\n", g_thread_id, N, lo, hi, nthreads); #endif ORExpand(buf, dest, block_size, lo, hi); #ifdef PROFILE_OREXPAND printf_with_rtclock_diff(start, "Thread %u ending ORExpand(N=%u, lo=%u, hi=%u, nthreads=%hu)\n", g_thread_id, N, lo, hi, nthreads); #endif return; } #ifdef PROFILE_OREXPAND unsigned long start = printf_with_rtclock("Thread %u starting ORExpand(N=%u, lo=%u, hi=%u, nthreads=%hu)\n", g_thread_id, N, lo, hi, nthreads); #endif // The largest power of 2 strictly less than N const uint32_t N2 = uint32_t(pow2_lt(N)); const uint32_t mid = hi-N2; const uint32_t N1 = N-N2; mid_oswap_range_args args[nthreads]; uint32_t inc = N1 / nthreads; uint32_t extra = N1 % nthreads; uint32_t last = 0; for (threadid_t i=0; i, args+i); } // Do the first section ourselves mid_oswap_range_launch(args); // Join the threads for (threadid_t i=1; i, &rightargs); // Do the left subarray ourselves (with lthreads threads) ORExpand_parallel(buf, dest, block_size, lo, mid, lthreads); // Join the thread threadpool_join(rightthreadid, NULL); #ifdef PROFILE_OREXPAND printf_with_rtclock_diff(start, "Thread %u ending ORExpand(N=%u, lo=%u, hi=%u, nthreads=%hu)\n", g_thread_id, N, lo, hi, nthreads); #endif }