duoram.cpp 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. #include "duoram.hpp"
  2. #include "shapes.hpp"
  3. // Assuming the memory is already sorted, do an oblivious binary
  4. // search for the smallest index containing the value at least the
  5. // given one. (The answer will be the length of the Shape if all
  6. // elements are smaller than the target.) Only available for additive
  7. // shared databases for now.
  8. // The basic version uses log(N) ORAM reads of size N, where N is the
  9. // smallest power of 2 strictly larger than the Shape size
  10. template <>
  11. RegAS Duoram<RegAS>::Shape::basic_binary_search(RegAS &target)
  12. {
  13. if (this->shape_size == 0) {
  14. RegAS zero;
  15. return zero;
  16. }
  17. // Create a Pad of the smallest power of 2 size strictly greater
  18. // than the Shape size
  19. address_t padsize = 1;
  20. nbits_t depth = 0;
  21. while (padsize <= this->shape_size) {
  22. padsize *= 2;
  23. ++depth;
  24. }
  25. Duoram<RegAS>::Pad P(*this, tio, yield, padsize);
  26. // Start in the middle
  27. RegAS index;
  28. index.set(this->tio.player() ? 0 : (1<<(depth-1))-1);
  29. // Invariant: index points to the last element of the left half of
  30. // the remaining possible range, which is of width (1<<depth).
  31. while (depth > 0) {
  32. // Obliviously read the value there
  33. RegAS val = P[index];
  34. // Compare it to the target
  35. CDPF cdpf = tio.cdpf(this->yield);
  36. auto [lt, eq, gt] = cdpf.compare(this->tio, this->yield,
  37. val-target, tio.aes_ops());
  38. if (depth > 1) {
  39. // If val >= target, the answer is here or to the left
  40. // and we should subtract 2^{depth-2} from index
  41. // If val < target, the answer is strictly to the right
  42. // and we should add 2^{depth-2} to index
  43. // So we unconditionally subtract 2^{depth-2} from index, and
  44. // add (lt)*2^{depth-1}.
  45. RegAS uncond;
  46. uncond.set(tio.player() ? 0 : address_t(1)<<(depth-2));
  47. RegAS cond;
  48. cond.set(tio.player() ? 0 : address_t(1)<<(depth-1));
  49. RegAS condprod;
  50. mpc_flagmult(this->tio, this->yield, condprod, lt, cond);
  51. index -= uncond;
  52. index += condprod;
  53. } else {
  54. // The possible range is of width 2, and we're pointing to
  55. // the first element of it.
  56. // If val >= target, the answer is here or to the left, so
  57. // it's here.
  58. // If val < target, the answer is strictly to the right
  59. // so add lt to index
  60. RegAS cond;
  61. cond.set(tio.player() ? 0 : 1);
  62. RegAS condprod;
  63. mpc_flagmult(this->tio, this->yield, condprod, lt, cond);
  64. index += condprod;
  65. }
  66. --depth;
  67. }
  68. return index;
  69. }
  70. // This version does 1 ORAM read of size 2, 1 of size 4, 1 of size
  71. // 8, ..., 1 of size N/2, where N is the smallest power of 2 strictly
  72. // larger than the Shape size
  73. template <>
  74. RegXS Duoram<RegAS>::Shape::binary_search(RegAS &target)
  75. {
  76. if (this->shape_size == 0) {
  77. RegXS zero;
  78. return zero;
  79. }
  80. // Create a Pad of the smallest power of 2 size strictly greater
  81. // than the Shape size
  82. address_t padsize = 1;
  83. nbits_t depth = 0;
  84. while (padsize <= this->shape_size) {
  85. padsize *= 2;
  86. ++depth;
  87. }
  88. Duoram<RegAS>::Pad P(*this, tio, yield, padsize);
  89. // Explicitly read the middle item
  90. address_t mid = (1<<(depth-1))-1;
  91. RegAS val = P[mid];
  92. // Compare it to the target
  93. CDPF cdpf = tio.cdpf(this->yield);
  94. auto [lt, eq, gt] = cdpf.compare(this->tio, this->yield,
  95. val-target, tio.aes_ops());
  96. if (depth == 1) {
  97. // There was only one item in the Shape, and mid will equal 0, so
  98. // val is (a share of) that item, P[0]. If val >= target, the
  99. // answer is here or to the left, so it must be 0. If val <
  100. // target, the answer is strictly to the right, so it must be 1.
  101. // So just return lt.
  102. RegXS ret;
  103. ret.xshare = lt.bshare;
  104. return ret;
  105. }
  106. auto oidx = P.oblivindex(depth-1);
  107. oidx.incr(lt);
  108. --depth;
  109. while(depth > 0) {
  110. // Create the Stride shape; the ORAM will operate only over
  111. // elements of the Stride, which will consist of exactly those
  112. // elements of the Pad we could possibly be accessing at this
  113. // depth. Those will be elements start=(1<<(depth-1)-1,
  114. // start+(1<<depth), start+(2<<depth), start+(3<<depth), and so
  115. // on. The invariant is that the range of remaining possible
  116. // answers is of width (1<<depth), and we will look at the
  117. // rightmost element of the left half. If that value (val) has
  118. // val >= target, then the answer is at that position or to the
  119. // left, so we append a 0 to the index. If val < targer, then
  120. // the answer is strictly to the right, so we append a 1 to the
  121. // index. That is, always append lt to the index.
  122. Duoram<RegAS>::Stride S(P, tio, yield, (1<<(depth-1))-1, 1<<depth);
  123. RegAS val = S[oidx];
  124. CDPF cdpf = tio.cdpf(this->yield);
  125. auto [lt, eq, gt] = cdpf.compare(this->tio, this->yield,
  126. val-target, tio.aes_ops());
  127. oidx.incr(lt);
  128. --depth;
  129. }
  130. return oidx.index();
  131. }