ORExpand.tcc 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. // See ORExpand.hpp for explanations of notation and inputs.
  2. // Particularly note that all subarrays [lo..hi] are _inclusive_ of lo
  3. // but _exclusive_ of hi.
  4. // buf is an array of block_size-byte blocks. dest is an array of 32-bit
  5. // words. We are given two (contiguous) subarrays [lo..mid] and
  6. // [mid..hi], and indices a and b with a in [lo..mid] and b in
  7. // [mid..hi]. If (mid <= dest[a] < hi) or (lo <= dest[b] <= mid), then
  8. // swap dest[a] with dest[b] and buf[a] with buf[b]; otherwise, do not.
  9. // However, all tests and swaps must be done obliviously to the values
  10. // of dest[a] and dest[b] (and the contents of buf). It's OK to not be
  11. // oblivious to the values of lo, mid, hi, a, and b themselves, however.
  12. template <OSwap_Style oswap_style>
  13. static inline void mid_oswap(unsigned char *buf, uint32_t *dest,
  14. size_t block_size, uint32_t lo, uint32_t mid, uint32_t hi,
  15. uint32_t a, uint32_t b)
  16. {
  17. uint32_t desta = dest[a];
  18. uint32_t destb = dest[b];
  19. uint8_t swap_flag = ((mid <= desta) & (desta < hi))
  20. | ((lo <= destb) & (destb < mid));
  21. // The next line could be optimized with some inline assembly, since
  22. // we've already loaded desta and destb, so we could obliviously
  23. // swap those registers, and then non-obliviously write them back
  24. // out to dest[a] and dest[b].
  25. oswap_buffer<OSWAP_4>((unsigned char *)(dest+a),
  26. (unsigned char *)(dest+b), 4, swap_flag);
  27. oswap_buffer<oswap_style>(buf+a*block_size, buf+b*block_size,
  28. (uint32_t)block_size, swap_flag);
  29. }
  30. template <OSwap_Style oswap_style>
  31. void ORExpand(unsigned char *buf, uint32_t *dest, size_t block_size,
  32. uint32_t lo, uint32_t hi)
  33. {
  34. // Passing hi < lo is an illegal input
  35. assert(hi >= lo);
  36. // The length of the subarray
  37. const uint32_t N = hi-lo;
  38. // Nothing to do on inputs where [lo..hi] has length 0 or 1
  39. if (N < 2) {
  40. return;
  41. }
  42. // The largest power of 2 strictly less than N
  43. const uint32_t N2 = uint32_t(pow2_lt(N));
  44. // We divide [lo..hi] (of length N) into two pieces:
  45. // [lo..mid] and [mid..hi], where [mid..hi] has length N2 (the
  46. // largest power of 2 strictly less than N). Note that mid is just
  47. // _somewhere_ in the middle of [lo..hi]; it will not be the exact
  48. // midpoint if N is not itself a power of 2. (It will be the exact
  49. // midpoint if N is a power of 2, however.)
  50. const uint32_t mid = hi-N2;
  51. // N1 is the length of [lo..mid]. Note that N1 <= N2, with equality
  52. // if and only if N is a power of 2.
  53. const uint32_t N1 = N-N2;
  54. // We're going to do N1 oblivious swaps on the buf and dest arrays,
  55. // between items lo+i and hi-N1+i for 0 <= i < N1. If dest[lo+i]
  56. // lies in [mid..hi] (and is not 0xffffffff to indicate padding)
  57. // _or_ if dest[hi-N1+i] lies in [lo..mid] (again and is not
  58. // 0xffffffff), then we swap them and their corresponding buf
  59. // blocks. The cool part is that it cannot be the case that both
  60. // dest[lo+i] and dest[hi-N1+i] are not padding and they both have
  61. // values on the same side of mid. Why is that?
  62. // Case 1: If dest[lo+i] < dest[hi-N1+i], then all of the blocks
  63. // from lo+i to hi-N1+1 inclusive must be non-padding blocks, and
  64. // since this contiguous block has strictly increasing values, it
  65. // must be that dest[hi-N1+i] - dest[lo+i] >= (hi-N1+i)-(lo+i) =
  66. // N-N1 = N2. Since the lengths of [lo..mid] and [mid..hi] are each
  67. // at most N2, dest[lo+i] and dest[hi-N1+i] cannot be both in the
  68. // same one of those subarrays.
  69. // Case 2: If dest[hi-N1+i] < dest[lo+i], then the contiguous range
  70. // of non-padding blocks wraps around hi back to lo, so we must have
  71. // that dest[lo+i] - dest[hi-N1+i] >= (hi+i) - (hi-N1+i) = N1, and
  72. // also since the range wraps around, it must start at a non-zero
  73. // offset z, which means that N had to be a power of 2, and so
  74. // N1=N2. Therefore dest[lo+i] - dest[hi-N1+i] >= N2, and as above,
  75. // dest[lo+i] and dest[hi-N1+i] cannot both be in [lo..mid] or both
  76. // be in [mid..hi], each of which has length N1=N2.
  77. // So these oblivious swaps will ensure that all the blocks with
  78. // dest in [lo..mid] end up in [lo..mid] and all the blocks with
  79. // dest in [mid..hi] end up in [mid..hi]. In addition, the property
  80. // that all the non-padding blocks are contiguous (possibly wrapping
  81. // around for the [mid..hi] subarray which has length a power of 2)
  82. // and monotonicly increasing are preserved for both the [lo..mid]
  83. // and [mid..hi] subarrays.
  84. for (uint32_t i=0; i<N1; ++i) {
  85. mid_oswap<oswap_style>(buf, dest, block_size, lo, mid, hi,
  86. lo+i, hi-N1+i);
  87. }
  88. // And now we just recurse on the two subarrays.
  89. ORExpand<oswap_style>(buf, dest, block_size, lo, mid);
  90. ORExpand<oswap_style>(buf, dest, block_size, mid, hi);
  91. }
  92. // Multithreaded version of ORExpand
  // Arguments for one run of consecutive mid_oswap calls (see
  // mid_oswap_range_launch). The field order matters: the struct is
  // filled by positional aggregate initialization.
  struct mid_oswap_range_args {
      // Array of block_size-byte blocks
      unsigned char *buf;
      // Array of 32-bit destination words, parallel to buf
      uint32_t *dest;
      // Size in bytes of each block in buf
      size_t block_size;
      // Bounds passed through unchanged to every mid_oswap call
      uint32_t lo;
      uint32_t mid;
      uint32_t hi;
      // First index of the run on the a side
      uint32_t a;
      // First index of the run on the b side
      uint32_t b;
      // Number of consecutive (a+i, b+i) pairs to process
      uint32_t num;
  };
  99. template <OSwap_Style oswap_style>
  100. static void *mid_oswap_range_launch(void *voidargs)
  101. {
  102. const mid_oswap_range_args *args =
  103. (mid_oswap_range_args*)voidargs;
  104. for (uint32_t i=0; i<args->num; ++i) {
  105. mid_oswap<oswap_style>(args->buf, args->dest, args->block_size,
  106. args->lo, args->mid, args->hi,
  107. args->a + i, args->b + i);
  108. }
  109. return NULL;
  110. }
  111. struct ORExpand_parallel_args {
  112. unsigned char *buf;
  113. uint32_t *dest;
  114. size_t block_size;
  115. uint32_t lo, hi;
  116. threadid_t nthreads;
  117. };
  118. template <OSwap_Style oswap_style>
  119. static void* ORExpand_parallel_launch(void *voidargs)
  120. {
  121. const ORExpand_parallel_args* args =
  122. (ORExpand_parallel_args*)voidargs;
  123. ORExpand_parallel<oswap_style>(args->buf, args->dest,
  124. args->block_size, args->lo, args->hi, args->nthreads);
  125. return NULL;
  126. }
  127. // See ORExpand, above, for detailed comments as to how this algorithm
  128. // works.
  129. template <OSwap_Style oswap_style>
  130. void ORExpand_parallel(unsigned char *buf, uint32_t *dest,
  131. size_t block_size, uint32_t lo, uint32_t hi, threadid_t nthreads)
  132. {
  133. // Passing hi < lo is an illegal input
  134. assert(hi >= lo);
  135. // The length of the subarray
  136. const uint32_t N = hi-lo;
  137. // Nothing to do on inputs where [lo..hi] has length 0 or 1
  138. if (N < 2) {
  139. return;
  140. }
  141. // Use the single-threaded version if nthreads <= 1
  142. if (nthreads <= 1) {
  143. #ifdef PROFILE_OREXPAND
  144. unsigned long start = printf_with_rtclock("Thread %u starting ORExpand(N=%u, lo=%u, hi=%u, nthreads=%hu)\n", g_thread_id, N, lo, hi, nthreads);
  145. #endif
  146. ORExpand<oswap_style>(buf, dest, block_size, lo, hi);
  147. #ifdef PROFILE_OREXPAND
  148. printf_with_rtclock_diff(start, "Thread %u ending ORExpand(N=%u, lo=%u, hi=%u, nthreads=%hu)\n", g_thread_id, N, lo, hi, nthreads);
  149. #endif
  150. return;
  151. }
  152. #ifdef PROFILE_OREXPAND
  153. unsigned long start = printf_with_rtclock("Thread %u starting ORExpand(N=%u, lo=%u, hi=%u, nthreads=%hu)\n", g_thread_id, N, lo, hi, nthreads);
  154. #endif
  155. // The largest power of 2 strictly less than N
  156. const uint32_t N2 = uint32_t(pow2_lt(N));
  157. const uint32_t mid = hi-N2;
  158. const uint32_t N1 = N-N2;
  159. mid_oswap_range_args args[nthreads];
  160. uint32_t inc = N1 / nthreads;
  161. uint32_t extra = N1 % nthreads;
  162. uint32_t last = 0;
  163. for (threadid_t i=0; i<nthreads; ++i) {
  164. uint32_t num = inc + (i < extra);
  165. args[i] = { buf, dest, block_size, lo, mid, hi, lo+last,
  166. hi-N1+last, num };
  167. last += num;
  168. }
  169. // Launch all but the first section into other threads
  170. for (threadid_t i=1; i<nthreads; ++i) {
  171. threadpool_dispatch(g_thread_id+i,
  172. mid_oswap_range_launch<oswap_style>, args+i);
  173. }
  174. // Do the first section ourselves
  175. mid_oswap_range_launch<oswap_style>(args);
  176. // Join the threads
  177. for (threadid_t i=1; i<nthreads; ++i) {
  178. threadpool_join(g_thread_id+i, NULL);
  179. }
  180. // Use half the threads for the left subarray and half for the right
  181. // subarray (this choice could be improved if N1 << N2, perhaps).
  182. threadid_t lthreads = nthreads / 2;
  183. threadid_t rthreads = nthreads - lthreads;
  184. threadid_t rightthreadid = g_thread_id + lthreads;
  185. ORExpand_parallel_args rightargs = {
  186. buf, dest, block_size, mid, hi, rthreads
  187. };
  188. threadpool_dispatch(rightthreadid,
  189. ORExpand_parallel_launch<oswap_style>, &rightargs);
  190. // Do the left subarray ourselves (with lthreads threads)
  191. ORExpand_parallel<oswap_style>(buf, dest, block_size, lo, mid,
  192. lthreads);
  193. // Join the thread
  194. threadpool_join(rightthreadid, NULL);
  195. #ifdef PROFILE_OREXPAND
  196. printf_with_rtclock_diff(start, "Thread %u ending ORExpand(N=%u, lo=%u, hi=%u, nthreads=%hu)\n", g_thread_id, N, lo, hi, nthreads);
  197. #endif
  198. }