duoram.cpp 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. #include "duoram.hpp"
  2. #include "shapes.hpp"
  3. // Assuming the memory is already sorted, do an oblivious binary
  4. // search for the smallest index containing the value at least the
  5. // given one. (The answer will be the length of the Flat if all
  6. // elements are smaller than the target.) Only available for additive
  7. // shared databases for now.
  8. template <>
  9. RegAS Duoram<RegAS>::Flat::obliv_binary_search(RegAS &target)
  10. {
  11. if (this->shape_size == 0) {
  12. RegAS zero;
  13. return zero;
  14. }
  15. // Create a Pad of the smallest power of 2 size strictly greater
  16. // than the Flat size
  17. address_t padsize = 1;
  18. nbits_t depth = 0;
  19. while (padsize <= this->shape_size) {
  20. padsize *= 2;
  21. ++depth;
  22. }
  23. Duoram<RegAS>::Pad P(*this, tio, yield, padsize);
  24. // Start in the middle
  25. RegAS index;
  26. index.set(this->tio.player() ? 0 : (1<<(depth-1))-1);
  27. // Invariant: index points to the last element of the left half of
  28. // the remaining possible range, which is of width (1<<depth).
  29. while (depth > 0) {
  30. // Obliviously read the value there
  31. RegAS val = P[index];
  32. // Compare it to the target
  33. CDPF cdpf = tio.cdpf(this->yield);
  34. auto [lt, eq, gt] = cdpf.compare(this->tio, this->yield,
  35. val-target, tio.aes_ops());
  36. if (depth > 1) {
  37. // If val >= target, the answer is here or to the left
  38. // and we should subtract 2^{depth-2} from index
  39. // If val < target, the answer is strictly to the right
  40. // and we should add 2^{depth-2} to index
  41. // So we unconditionally subtract 2^{depth-2} from index, and
  42. // add (lt)*2^{depth-1}.
  43. RegAS uncond;
  44. uncond.set(tio.player() ? 0 : address_t(1)<<(depth-2));
  45. RegAS cond;
  46. cond.set(tio.player() ? 0 : address_t(1)<<(depth-1));
  47. RegAS condprod;
  48. mpc_flagmult(this->tio, this->yield, condprod, lt, cond);
  49. index -= uncond;
  50. index += condprod;
  51. } else {
  52. // The possible range is of width 2, and we're pointing to
  53. // the first element of it.
  54. // If val >= target, the answer is here or to the left, so
  55. // it's here.
  56. // If val < target, the answer is strictly to the right
  57. // so add lt to index
  58. RegAS cond;
  59. cond.set(tio.player() ? 0 : 1);
  60. RegAS condprod;
  61. mpc_flagmult(this->tio, this->yield, condprod, lt, cond);
  62. index += condprod;
  63. }
  64. --depth;
  65. }
  66. return index;
  67. }