// avl.hpp
  1. #ifndef __AVL_HPP__
  2. #define __AVL_HPP__
#include <math.h>
#include <stdio.h>
#include <string>
#include <tuple>
#include <vector>
#include "types.hpp"
#include "duoram.hpp"
#include "cdpf.hpp"
#include "mpcio.hpp"
#include "options.hpp"
#include "bst.hpp"
  12. #define KNRM "\x1B[0m"
  13. #define KRED "\x1B[31m"
  14. #define KGRN "\x1B[32m"
  15. #define KYEL "\x1B[33m"
  16. #define KBLU "\x1B[34m"
  17. #define KMAG "\x1B[35m"
  18. #define KCYN "\x1B[36m"
  19. #define KWHT "\x1B[37m"
/*
For the AVL tree we treat each node's 64-bit pointers field as:
  < L_ptr (31 bits), R_ptr (31 bits), bal_L (1 bit), bal_R (1 bit) >
i.e. L_ptr occupies bits 63..33, R_ptr bits 32..2, bal_L bit 1 and bal_R bit 0.
L_ptr and R_ptr are the pointers to the left and right child respectively,
and bal_L and bal_R are the balance bits.
Consequently, AVL has its own versions of the functions that extract and set
the child pointers.
*/
  27. #define AVL_PTR_SIZE 31
  28. inline int AVL_TTL(size_t n) {
  29. double logn = log2(n);
  30. double TTL = 1.44 * logn;
  31. return (int(ceil(TTL)));
  32. }
  33. inline RegXS getAVLLeftPtr(RegXS pointer){
  34. return ((pointer&(0xFFFFFFFF00000000))>>33);
  35. }
  36. inline RegXS getAVLRightPtr(RegXS pointer){
  37. return ((pointer&(0x00000001FFFFFFFF))>>2);
  38. }
  39. inline void setAVLLeftPtr(RegXS &pointer, RegXS new_ptr){
  40. pointer&=(0x00000001FFFFFFFF);
  41. pointer+=(new_ptr<<33);
  42. }
  43. inline void setAVLRightPtr(RegXS &pointer, RegXS new_ptr){
  44. pointer&=(0xFFFFFFFE00000003);
  45. pointer+=(new_ptr<<2);
  46. }
  47. inline RegBS getLeftBal(RegXS pointer){
  48. RegBS bal_l;
  49. bool bal_l_bit = ((pointer.share() & (0x0000000000000002))>>1) & 1;
  50. bal_l.set(bal_l_bit);
  51. return bal_l;
  52. }
  53. inline RegBS getRightBal(RegXS pointer){
  54. RegBS bal_r;
  55. bool bal_r_bit = (pointer.share() & (0x0000000000000001)) & 1;
  56. bal_r.set(bal_r_bit);
  57. return bal_r;
  58. }
  59. inline void setLeftBal(RegXS &pointer, RegBS bal_l){
  60. value_t temp_ptr = pointer.share();
  61. temp_ptr&=(0xFFFFFFFFFFFFFFFD);
  62. temp_ptr^=((value_t)(bal_l.share()<<1));
  63. pointer.set(temp_ptr);
  64. }
  65. inline void setRightBal(RegXS &pointer, RegBS bal_r){
  66. value_t temp_ptr = pointer.share();
  67. temp_ptr&=(0xFFFFFFFFFFFFFFFE);
  68. temp_ptr^=((value_t)(bal_r.share()));
  69. pointer.set(temp_ptr);
  70. }
  71. inline void dumpAVL(Node n) {
  72. RegBS left_bal, right_bal;
  73. left_bal = getLeftBal(n.pointers);
  74. right_bal = getRightBal(n.pointers);
  75. printf("[%016lx %016lx %d %d %016lx]", n.key.share(), n.pointers.share(),
  76. left_bal.share(), right_bal.share(), n.value.share());
  77. }
  78. struct avl_del_return {
  79. // Flag to indicate if the key this deletion targets requires a successor swap
  80. RegBS F_ss;
  81. // Pointers to node to be deleted that would be replaced by successor node
  82. RegXS N_d;
  83. // Pointers to successor node that would replace deleted node
  84. RegXS N_s;
  85. // F_rs: Flag for updating child pointer with returned pointer
  86. RegBS F_r;
  87. RegXS ret_ptr;
  88. };
  89. struct avl_insert_return {
  90. RegXS gp_node; // grandparent node
  91. RegXS p_node; // parent node
  92. RegXS c_node; // child node
  93. RegXS i_node; // insertion node
  94. // Direction bits: 0 = Left, 1 = Right
  95. RegBS dir_gpp; // Direction bit from grandparent to parent node
  96. RegBS dir_pc; // Direction bit from p_node to c_node
  97. RegBS dir_cn; // Direction bit from c_node to new_node
  98. RegBS dir_i;
  99. RegBS imbalance;
  100. };
  101. class AVL {
  102. private:
  103. Duoram<Node> oram;
  104. RegXS root;
  105. size_t num_items = 0;
  106. size_t MAX_SIZE;
  107. std::vector<RegXS> empty_locations;
  108. std::tuple<RegBS, RegBS, RegXS, RegBS> insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
  109. RegAS ins_key, Duoram<Node>::Flat &A, int TTL, RegBS isDummy, avl_insert_return *ret);
  110. void rotate(MPCTIO &tio, yield_t &yield, RegXS &gp_pointers, RegXS p_ptr,
  111. RegXS &p_pointers, RegXS c_ptr, RegXS &c_pointers, RegBS dir_gpp,
  112. RegBS dir_pc, RegBS isNotDummy, RegBS F_gp);
  113. std::tuple<RegBS, RegBS, RegBS, RegBS> updateBalanceIns(MPCTIO &tio, yield_t &yield,
  114. RegBS bal_l, RegBS bal_r, RegBS bal_upd, RegBS child_dir);
  115. std::tuple<bool, RegBS> del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
  116. Duoram<Node>::Flat &A, RegBS F_af, RegBS F_fs, int TTL,
  117. avl_del_return &ret_struct);
  118. std::tuple<RegBS, RegBS, RegBS, RegBS> updateBalanceDel(MPCTIO &tio, yield_t &yield,
  119. RegBS bal_l, RegBS bal_r, RegBS bal_upd, RegBS child_dir);
  120. bool lookup(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS key,
  121. Duoram<Node>::Flat &A, int TTL, RegBS isDummy, Node *ret_node);
  122. public:
  123. AVL(int num_players, size_t size) : oram(num_players, size) {
  124. this->MAX_SIZE = size;
  125. };
  126. void init(){
  127. num_items=0;
  128. empty_locations.clear();
  129. }
  130. size_t numEmptyLocations(){
  131. return(empty_locations.size());
  132. };
  133. void insert(MPCTIO &tio, yield_t &yield, const Node &node);
  134. // Deletes the first node that matches del_key
  135. bool del(MPCTIO &tio, yield_t &yield, RegAS del_key);
  136. // Returns the first node that matches key
  137. bool lookup(MPCTIO &tio, yield_t &yield, RegAS key, Node *ret_node);
  138. // Display and correctness check functions
  139. void pretty_print(MPCTIO &tio, yield_t &yield);
  140. void pretty_print(const std::vector<Node> &R, value_t node,
  141. const std::string &prefix, bool is_left_child, bool is_right_child);
  142. void check_avl(MPCTIO &tio, yield_t &yield);
  143. std::tuple<bool, bool, address_t> check_avl(const std::vector<Node> &R,
  144. value_t node, value_t min_key, value_t max_key);
  145. void print_oram(MPCTIO &tio, yield_t &yield);
  146. // For test functions ONLY:
  147. Duoram<Node>* get_oram() {
  148. return &oram;
  149. };
  150. RegXS get_root() {
  151. return root;
  152. };
  153. };
  154. void avl(MPCIO &mpcio, const PRACOptions &opts, char **args);
  155. void avl_tests(MPCIO &mpcio, const PRACOptions &opts, char **args);
  156. #endif