ORExpand.tcc 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. #include "utils.hpp"
  2. // See ORExpand.hpp for explanations of notation and inputs.
  3. // Particularly note that all subarrays [lo..hi] are _inclusive_ of lo
  4. // but _exclusive_ of hi.
  5. // buf is an array of block_size-byte blocks. dest is an array of 32-bit
  6. // words. We are given two (contiguous) subarrays [lo..mid] and
  7. // [mid..hi], and indices a and b with a in [lo..mid] and b in
  8. // [mid..hi]. If (mid <= dest[a] < hi) or (lo <= dest[b] <= mid), then
  9. // swap dest[a] with dest[b] and buf[a] with buf[b]; otherwise, do not.
  10. // However, all tests and swaps must be done obliviously to the values
  11. // of dest[a] and dest[b] (and the contents of buf). It's OK to not be
  12. // oblivious to the values of lo, mid, hi, a, and b themselves, however.
  13. template <OSwap_Style oswap_style>
  14. static inline void mid_oswap(unsigned char *buf, uint32_t *dest,
  15. size_t block_size, uint32_t lo, uint32_t mid, uint32_t hi,
  16. uint32_t a, uint32_t b)
  17. {
  18. uint32_t desta = dest[a];
  19. uint32_t destb = dest[b];
  20. uint8_t swap_flag = ((mid <= desta) & (desta < hi))
  21. | ((lo <= destb) & (destb < mid));
  22. oswap_buffer<OSWAP_4>((unsigned char *)(dest+a),
  23. (unsigned char *)(dest+b), 4, swap_flag);
  24. oswap_buffer<oswap_style>(buf+a*block_size, buf+b*block_size,
  25. (uint32_t)block_size, swap_flag);
  26. }
  27. template <OSwap_Style oswap_style>
  28. void ORExpand(unsigned char *buf, uint32_t *dest, size_t block_size,
  29. uint32_t lo, uint32_t hi)
  30. {
  31. // Passing hi < lo is an illegal input
  32. assert(hi >= lo);
  33. // The length of the subarray
  34. const uint32_t N = hi-lo;
  35. // Nothing to do on inputs where [lo..hi] has length 0 or 1
  36. if (N < 2) {
  37. return;
  38. }
  39. // The largest power of 2 strictly less than N
  40. const uint32_t N2 = uint32_t(pow2_lt(N));
  41. // We divide [lo..hi] (of length N) into two pieces:
  42. // [lo..mid] and [mid..hi], where [mid..hi] has length N2 (the
  43. // largest power of 2 strictly less than N). Note that mid is just
  44. // _somewhere_ in the middle of [lo..hi]; it will not be the exact
  45. // midpoint if N is not itself a power of 2. (It will be the exact
  46. // midpoint if N is a power of 2, however.)
  47. const uint32_t mid = hi-N2;
  48. // N1 is the length of [lo..mid]. Note that N1 <= N2, with equality
  49. // if and only if N is a power of 2.
  50. const uint32_t N1 = N-N2;
  51. // We're going to do N1 oblivious swaps on the buf and dest arrays,
  52. // between items lo+i and hi-N1+i for 0 <= i < N1. If dest[lo+i]
  53. // lies in [mid..hi] (and is not 0xffffffff to indicate padding)
  54. // _or_ if dest[hi-N1+i] lies in [lo..mid] (again and is not
  55. // 0xffffffff), then we swap them and their corresponding buf
  56. // blocks. The cool part is that it cannot be the case that both
  57. // dest[lo+i] and dest[hi-N1+i] are not padding and they both have
  58. // values on the same side of mid. Why is that?
  59. // Case 1: If dest[lo+i] < dest[hi-N1+i], then all of the blocks
  60. // from lo+i to hi-N1+1 inclusive must be non-padding blocks, and
  61. // since this contiguous block has strictly increasing values, it
  62. // must be that dest[hi-N1+i] - dest[lo+i] >= (hi-N1+i)-(lo+i) =
  63. // N-N1 = N2. Since the lengths of [lo..mid] and [mid..hi] are each
  64. // at most N2, dest[lo+i] and dest[hi-N1+i] cannot be both in the
  65. // same one of those subarrays.
  66. // Case 2: If dest[hi-N1+i] < dest[lo+i], then the contiguous range
  67. // of non-padding blocks wraps around hi back to lo, so we must have
  68. // that dest[lo+i] - dest[hi-N1+i] >= (hi+i) - (hi-N1+i) = N1, and
  69. // also since the range wraps around, it must start at a non-zero
  70. // offset z, which means that N had to be a power of 2, and so
  71. // N1=N2. Therefore dest[lo+i] - dest[hi-N1+i] >= N2, and as above,
  72. // dest[lo+i] and dest[hi-N1+i] cannot both be in [lo..mid] or both
  73. // be in [mid..hi], each of which has length N1=N2.
  74. // So these oblivious swaps will ensure that all the blocks with
  75. // dest in [lo..mid] end up in [lo..mid] and all the blocks with
  76. // dest in [mid..hi] end up in [mid..hi]. In addition, the property
  77. // that all the non-padding blocks are contiguous (possibly wrapping
  78. // around for the [mid..hi] subarray which has length a power of 2)
  79. // and monotonicly increasing are preserved for both the [lo..mid]
  80. // and [mid..hi] subarrays.
  81. for (uint32_t i=0; i<N1; ++i) {
  82. mid_oswap<oswap_style>(buf, dest, block_size, lo, mid, hi,
  83. lo+i, hi-N1+i);
  84. }
  85. // And now we just recurse on the two subarrays.
  86. ORExpand<oswap_style>(buf, dest, block_size, lo, mid);
  87. ORExpand<oswap_style>(buf, dest, block_size, mid, hi);
  88. }