// bst.hpp (6.8 KB)
  1. #ifndef __BST_HPP__
  2. #define __BST_HPP__
  3. #include "types.hpp"
  4. #include "duoram.hpp"
  5. #include "cdpf.hpp"
  6. #include "mpcio.hpp"
  7. #include "options.hpp"
  8. // #define BST_DEBUG
  9. // Some simple utility functions:
  10. bool reconstruct_RegBS(MPCTIO &tio, yield_t &yield, RegBS flag);
  11. struct Node {
  12. RegAS key;
  13. RegXS pointers;
  14. RegXS value;
  15. // Field-access macros so we can write A[i].NODE_KEY instead of
  16. // A[i].field(&Node::key)
  17. #define NODE_KEY field(&Node::key)
  18. #define NODE_POINTERS field(&Node::pointers)
  19. #define NODE_VALUE field(&Node::value)
  20. // For debugging and checking answers
  21. void dump() const {
  22. printf("[%016lx %016lx %016lx]", key.share(), pointers.share(),
  23. value.share());
  24. }
  25. // You'll need to be able to create a random element, and do the
  26. // operations +=, +, -=, - (binary and unary). Note that for
  27. // XOR-shared fields, + and - are both really XOR.
  28. inline void randomize() {
  29. key.randomize();
  30. pointers.randomize();
  31. value.randomize();
  32. }
  33. inline Node &operator+=(const Node &rhs) {
  34. this->key += rhs.key;
  35. this->pointers += rhs.pointers;
  36. this->value += rhs.value;
  37. return *this;
  38. }
  39. inline Node operator+(const Node &rhs) const {
  40. Node res = *this;
  41. res += rhs;
  42. return res;
  43. }
  44. inline Node &operator-=(const Node &rhs) {
  45. this->key -= rhs.key;
  46. this->pointers -= rhs.pointers;
  47. this->value -= rhs.value;
  48. return *this;
  49. }
  50. inline Node operator-(const Node &rhs) const {
  51. Node res = *this;
  52. res -= rhs;
  53. return res;
  54. }
  55. inline Node operator-() const {
  56. Node res;
  57. res.key = -this->key;
  58. res.pointers = -this->pointers;
  59. res.value = -this->value;
  60. return res;
  61. }
  62. // Multiply each field by the local share of the corresponding field
  63. // in the argument
  64. inline Node mulshare(const Node &rhs) const {
  65. Node res = *this;
  66. res.key.mulshareeq(rhs.key);
  67. res.pointers.mulshareeq(rhs.pointers);
  68. res.value.mulshareeq(rhs.value);
  69. return res;
  70. }
  71. // You need a method to turn a leaf node of a DPF into a share of a
  72. // unit element of your type. Typically set each RegAS to
  73. // dpf.unit_as(leaf) and each RegXS or RegBS to dpf.unit_bs(leaf).
  74. // Note that RegXS will extend a RegBS of 1 to the all-1s word, not
  75. // the word with value 1. This is used for ORAM reads, where the
  76. // same DPF is used for all the fields.
  77. template <nbits_t WIDTH>
  78. inline void unit(const RDPF<WIDTH> &dpf,
  79. typename RDPF<WIDTH>::LeafNode leaf) {
  80. key = dpf.unit_as(leaf);
  81. pointers = dpf.unit_bs(leaf);
  82. value = dpf.unit_bs(leaf);
  83. }
  84. // Perform an update on each of the fields, using field-specific
  85. // MemRefs constructed from the Shape shape and the index idx
  86. template <typename Sh, typename U>
  87. inline static void update(Sh &shape, yield_t &shyield, U idx,
  88. const Node &M) {
  89. run_coroutines(shyield,
  90. [&shape, &idx, &M] (yield_t &yield) {
  91. Sh Sh_coro = shape.context(yield);
  92. Sh_coro[idx].NODE_KEY += M.key;
  93. },
  94. [&shape, &idx, &M] (yield_t &yield) {
  95. Sh Sh_coro = shape.context(yield);
  96. Sh_coro[idx].NODE_POINTERS += M.pointers;
  97. },
  98. [&shape, &idx, &M] (yield_t &yield) {
  99. Sh Sh_coro = shape.context(yield);
  100. Sh_coro[idx].NODE_VALUE += M.value;
  101. });
  102. }
  103. };
  104. /*
  105. A function to perform key comparsions for BST traversal.
  106. Inputs: k1 = key of node in the tree, k2 = insertion/deletion/lookup key.
  107. Evaluates (k2-k1), and combines the lt and eq flag into one (flag to go
  108. left), and keeps the gt flag as is (flag to go right) during traversal.
  109. Returns the shared bit flags lteq (go left) and gt (go right).
  110. */
  111. std::tuple<RegBS, RegBS> compare_keys(MPCTIO &tio, yield_t &yield, RegAS k1, RegAS k2);
  112. // I/O operations (for sending over the network)
  113. template <typename T>
  114. T& operator>>(T& is, Node &x)
  115. {
  116. is >> x.key >> x.pointers >> x.value;
  117. return is;
  118. }
  119. template <typename T>
  120. T& operator<<(T& os, const Node &x)
  121. {
  122. os << x.key << x.pointers << x.value;
  123. return os;
  124. }
  125. // This macro will define I/O on tuples of two or three of the node type
  126. DEFAULT_TUPLE_IO(Node)
  127. struct del_return {
  128. // Flag to indicate if the key this deletion targets requires a successor swap
  129. RegBS F_ss;
  130. // Pointers to node to delete and successor node that would replace
  131. // deleted node
  132. RegXS N_d;
  133. RegXS N_s;
  134. // Flag for updating child pointer with returned pointer
  135. RegBS F_r;
  136. RegXS ret_ptr;
  137. };
  138. class BST {
  139. private:
  140. Duoram<Node> oram;
  141. RegXS root;
  142. size_t num_items = 0;
  143. size_t MAX_SIZE;
  144. std::vector<RegXS> empty_locations;
  145. std::tuple<RegXS, RegBS> insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
  146. RegAS insertion_key, Duoram<Node>::Flat &A, int TTL, RegBS isDummy);
  147. void insert(MPCTIO &tio, yield_t &yield, const Node &node, Duoram<Node>::Flat &A);
  148. bool del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
  149. Duoram<Node>::Flat &A, RegBS F_af, RegBS F_fs, int TTL,
  150. del_return &ret_struct);
  151. RegBS lookup(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS key,
  152. Duoram<Node>::Flat &A, int TTL, RegBS isDummy, Node *ret_node);
  153. void pretty_print(const std::vector<Node> &R, value_t node,
  154. const std::string &prefix, bool is_left_child, bool is_right_child);
  155. std::tuple<bool, address_t> check_bst(const std::vector<Node> &R,
  156. value_t node, value_t min_key, value_t max_key);
  157. public:
  158. BST(int num_players, size_t size) : oram(num_players, size) {
  159. this->MAX_SIZE = size;
  160. };
  161. // Inserts the provided node into the BST
  162. void insert(MPCTIO &tio, yield_t &yield, Node &node);
  163. // Deletes the first node that matches del_key from the BST.
  164. // If an item with del_key does not exist in the tree, it results in an
  165. // explicit (non-oblivious) failure.
  166. bool del(MPCTIO &tio, yield_t &yield, RegAS del_key);
  167. // Returns the first node that matches key in the BST
  168. RegBS lookup(MPCTIO &tio, yield_t &yield, RegAS key, Node *ret_node);
  169. // Display and correctness check functions
  170. // Print the BST
  171. void pretty_print(MPCTIO &tio, yield_t &yield);
  172. // Check BST correctness
  173. void check_bst(MPCTIO &tio, yield_t &yield);
  174. // Debugging Functions
  175. #ifdef BST_DEBUG
  176. // Print the underlying ORAM state
  177. void print_oram(MPCTIO &tio, yield_t &yield);
  178. // Check the number of empty locations in ORAM
  179. // (Locations freed up after a delete operation, reusable for next insert.)
  180. size_t numEmptyLocations(){
  181. return(empty_locations.size());
  182. };
  183. #endif
  184. };
  185. void bst(MPCIO &mpcio,
  186. const PRACOptions &opts, char **args);
  187. #endif