avl.hpp 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. #ifndef __AVL_HPP__
  2. #define __AVL_HPP__
  3. #include <math.h>
  4. #include <stdio.h>
  5. #include <string>
  6. #include "types.hpp"
  7. #include "duoram.hpp"
  8. #include "cdpf.hpp"
  9. #include "mpcio.hpp"
  10. #include "options.hpp"
  11. #include "bst.hpp"
  /*
  Macro definitions:
    AVL_OPT_ON: Turn AVL optimizations on.
      Optimizations:
        - Use incremental DPFs for traversing the tree
        - Use updates instead of writes when possible
    AVL_RANDOMIZE_INSERTS: Randomize the keys of the items inserted. When
      turned off, items with incremental keys are inserted.
    AVL_DEBUG: General debug flag
    AVL_DEBUG_BB: Debug flag for balance bit computations
  */
  // Build-time switches for the AVL code. AVL_OPT_ON is defined by this
  // header; the three switches below it are commented out, i.e. this header
  // leaves them undefined. Uncomment a line to define that switch.
  23. #define AVL_OPT_ON
  24. // #define AVL_RANDOMIZE_INSERTS
  25. // #define AVL_DEBUG
  26. // #define AVL_DEBUG_BB
  /*
  For the AVL tree, we treat the pointers field of a node as:
    < L_ptr (31 bits), R_ptr (31 bits), bal_L (1 bit), bal_R (1 bit) >
  where L_ptr and R_ptr are the pointers to the left and right child,
  respectively, and bal_L and bal_R are the balance bits. L_ptr occupies
  the most significant bits and bal_R the least significant bit.
  Consequently, AVL has its own versions of the functions that extract and
  set the child pointers.
  */
  // Width, in bits, of each child-pointer field (L_ptr and R_ptr) in the
  // packed AVL pointers word.
  // NOTE(review): the accessor functions in this header hard-code their
  // shifts and masks (33, 2, 0x00000001FFFFFFFC, ...) rather than deriving
  // them from this macro; keep them in sync if this value ever changes.
  34. #define AVL_PTR_SIZE 31
  // Computes the TTL value for a tree of n items:
  //   n == 0  ->  0
  //   n == 1  ->  1
  //   n >= 2  ->  ceil(1.44 * log2(n))
  // NOTE(review): presumably this is the depth budget handed to the
  // insert/del/lookup helpers (their `int TTL` parameter), with 1.44*log2(n)
  // being the usual AVL height bound -- confirm against the callers.
  inline int AVL_TTL(size_t n) {
      // With zero or one item the result is simply the item count.
      if (n <= 1) {
          return static_cast<int>(n);
      }
      // Round the scaled logarithm up to a whole number of levels.
      const double levels = ceil(1.44 * log2(n));
      return static_cast<int>(levels);
  }
  46. inline RegXS getAVLLeftPtr(RegXS pointer){
  47. return (pointer>>33);
  48. }
  49. inline RegXS getAVLRightPtr(RegXS pointer){
  50. return ((pointer&(0x00000001FFFFFFFC))>>2);
  51. }
  52. inline void setAVLLeftPtr(RegXS &pointer, RegXS new_ptr){
  53. pointer&=(0x00000001FFFFFFFF);
  54. pointer+=(new_ptr<<33);
  55. }
  56. inline void setAVLRightPtr(RegXS &pointer, RegXS new_ptr){
  57. pointer&=(0xFFFFFFFE00000003);
  58. pointer+=(new_ptr<<2);
  59. }
  60. inline RegBS getLeftBal(RegXS pointer){
  61. RegBS bal_l;
  62. bool bal_l_bit = ((pointer.share() & (0x0000000000000002))>>1);
  63. bal_l.set(bal_l_bit);
  64. return bal_l;
  65. }
  66. inline RegBS getRightBal(RegXS pointer){
  67. RegBS bal_r;
  68. bool bal_r_bit = (pointer.share() & (0x0000000000000001));
  69. bal_r.set(bal_r_bit);
  70. return bal_r;
  71. }
  72. inline void setLeftBal(RegXS &pointer, RegBS bal_l){
  73. value_t temp_ptr = pointer.share();
  74. temp_ptr&=(0xFFFFFFFFFFFFFFFD);
  75. temp_ptr^=((value_t)(bal_l.share()<<1));
  76. pointer.set(temp_ptr);
  77. }
  78. inline void setRightBal(RegXS &pointer, RegBS bal_r){
  79. value_t temp_ptr = pointer.share();
  80. temp_ptr&=(0xFFFFFFFFFFFFFFFE);
  81. temp_ptr^=((value_t)(bal_r.share()));
  82. pointer.set(temp_ptr);
  83. }
  84. inline void dumpAVL(Node n) {
  85. RegBS left_bal, right_bal;
  86. left_bal = getLeftBal(n.pointers);
  87. right_bal = getRightBal(n.pointers);
  88. printf("[%016lx %016lx(L:%ld, R:%ld) %d %d %016lx]", n.key.share(), n.pointers.share(),
  89. getAVLLeftPtr(n.pointers).xshare, getAVLRightPtr(n.pointers).xshare,
  90. left_bal.share(), right_bal.share(), n.value.share());
  91. }
  92. struct avl_del_return {
  93. // Flag to indicate if the key this deletion targets requires a successor swap
  94. RegBS F_ss;
  95. // Pointers to node to be deleted that would be replaced by successor node
  96. RegXS N_d;
  97. // Pointers to successor node that would replace deleted node
  98. RegXS N_s;
  99. // F_r: Flag for updating child pointer with returned pointer
  100. RegBS F_r;
  101. RegXS ret_ptr;
  102. };
  103. struct avl_insert_return {
  104. RegXS gp_node; // grandparent node
  105. RegXS p_node; // parent node
  106. RegXS c_node; // child node
  107. // Direction bits: 0 = Left, 1 = Right
  108. RegBS dir_gpp; // Direction bit from grandparent to parent node
  109. RegBS dir_pc; // Direction bit from p_node to c_node
  110. RegBS dir_cn; // Direction bit from c_node to new_node
  111. RegBS imbalance;
  112. };
  113. class AVL {
  114. private:
  115. Duoram<Node> oram;
  116. RegXS root;
  117. size_t num_items = 0;
  118. size_t cur_max_index = 0;
  119. size_t MAX_SIZE;
  120. int MAX_DEPTH;
  121. std::vector<RegXS> empty_locations;
  122. std::tuple<RegBS, RegBS, RegXS, RegBS> insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
  123. RegXS ins_addr, RegAS ins_key, Duoram<Node>::Flat &A, int TTL, RegBS isDummy,
  124. avl_insert_return &ret);
  125. void rotate(MPCTIO &tio, yield_t &yield, RegXS &gp_pointers, RegXS p_ptr,
  126. RegXS &p_pointers, RegXS c_ptr, RegXS &c_pointers, RegBS dir_gpp,
  127. RegBS dir_pc, RegBS isNotDummy, RegBS F_gp);
  128. std::tuple<RegBS, RegBS, RegBS, RegBS> updateBalanceIns(MPCTIO &tio, yield_t &yield,
  129. RegBS bal_l, RegBS bal_r, RegBS bal_upd, RegBS child_dir);
  130. void updateChildPointers(MPCTIO &tio, yield_t &yield, RegXS &left, RegXS &right,
  131. RegBS c_prime, const avl_del_return &ret_struct);
  132. void fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
  133. Duoram<Node>::OblivIndex<RegXS,1> oidx, RegXS oidx_oldptrs, RegXS ptr,
  134. RegXS nodeptrs, RegBS p_bal_l, RegBS p_bal_r, RegBS &bal_upd, RegBS c_prime,
  135. RegXS cs_ptr, RegBS imb, RegBS &F_ri, avl_del_return &ret_struct);
  136. void updateRetStruct(MPCTIO &tio, yield_t &yield, RegXS ptr, RegBS F_rs,
  137. RegBS F_dh, RegBS F_ri, RegBS &bal_upd, avl_del_return &ret_struct);
  138. std::tuple<bool, RegBS> del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
  139. Duoram<Node>::Flat &A, RegBS F_af, RegBS F_fs, int TTL,
  140. avl_del_return &ret_struct);
  141. std::tuple<RegBS, RegBS, RegBS, RegBS> updateBalanceDel(MPCTIO &tio, yield_t &yield,
  142. RegBS bal_l, RegBS bal_r, RegBS bal_upd, RegBS child_dir);
  143. bool lookup(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS key,
  144. Duoram<Node>::Flat &A, int TTL, RegBS isDummy, Node *ret_node);
  145. void pretty_print(const std::vector<Node> &R, value_t node,
  146. const std::string &prefix, bool is_left_child, bool is_right_child);
  147. std::tuple<bool, bool, bool, address_t> check_avl(const std::vector<Node> &R,
  148. value_t node, value_t min_key, value_t max_key);
  149. public:
  150. AVL(int num_players, size_t size) : oram(num_players, size) {
  151. this->MAX_SIZE = size;
  152. MAX_DEPTH = 0;
  153. while(size>0) {
  154. MAX_DEPTH+=1;
  155. size=size>>1;
  156. }
  157. };
  158. void init(){
  159. num_items=0;
  160. cur_max_index=0;
  161. empty_locations.clear();
  162. }
  163. void insert(MPCTIO &tio, yield_t &yield, const Node &node);
  164. // Deletes the first node that matches del_key. If an item with del_key
  165. // does not exist in the tree, it results in an explicit (non-oblivious)
  166. // failure.
  167. bool del(MPCTIO &tio, yield_t &yield, RegAS del_key);
  168. // Returns the first node that matches key
  169. bool lookup(MPCTIO &tio, yield_t &yield, RegAS key, Node *ret_node);
  170. // Non-obliviously initialize an AVL tree of a particular size
  171. void initialize(MPCTIO &tio, yield_t &yield, size_t depth);
  172. // Display and correctness check functions
  173. void pretty_print(MPCTIO &tio, yield_t &yield);
  174. void check_avl(MPCTIO &tio, yield_t &yield);
  175. void print_oram(MPCTIO &tio, yield_t &yield);
  176. // For test functions ONLY:
  177. Duoram<Node>* get_oram() {
  178. return &oram;
  179. };
  180. RegXS get_root() {
  181. return root;
  182. };
  183. };
  // Free-function entry points; both are only declared in this header and
  // are defined elsewhere. They take the MPC I/O object, the PRAC options
  // and a char** argument array.
  // NOTE(review): presumably `avl` is the main AVL driver and `avl_tests`
  // runs the AVL test suite, with `args` being the remaining command-line
  // arguments -- confirm against the definitions.
  184. void avl(MPCIO &mpcio, const PRACOptions &opts, char **args);
  185. void avl_tests(MPCIO &mpcio, const PRACOptions &opts, char **args);
  186. #endif