ソースを参照

Tweaks to incorporate AVL balance bit checks. Correcting updateBalanceDel's bal_upd propogation clause

sshsshy 11 ヶ月 前
コミット
f3b38e8a79
2 ファイル変更106 行追加37 行削除
  1. 104 36
      avl.cpp
  2. 2 1
      avl.hpp

+ 104 - 36
avl.cpp

@@ -89,28 +89,45 @@ void AVL::pretty_print(MPCTIO &tio, yield_t &yield) {
 // tuple<bool,address_t>, where the bool says whether the BST invariant
 // holds, and the address_t is the height of the tree (which will be
 // useful later when we check AVL trees).
-std::tuple<bool, bool, address_t> AVL::check_avl(const std::vector<Node> &R,
+std::tuple<bool, bool, bool, address_t> AVL::check_avl(const std::vector<Node> &R,
     value_t node, value_t min_key = 0, value_t max_key = ~0)
 {
     if (node == 0) {
-        return { true, true, 0 };
+        return { true, true, true, 0};
     }
     const Node &n = R[node];
     value_t key = n.key.ashare;
     value_t left_ptr = getAVLLeftPtr(n.pointers).xshare;
     value_t right_ptr = getAVLRightPtr(n.pointers).xshare;
-    auto [leftok, leftavlok, leftheight ] = check_avl(R, left_ptr, min_key, key);
-    auto [rightok, rightavlok, rightheight ] = check_avl(R, right_ptr, key, max_key);
+    auto [leftok, leftavlok, leftbbok, leftheight ] = check_avl(R, left_ptr, min_key, key);
+    auto [rightok, rightavlok, rightbbok, rightheight ] = check_avl(R, right_ptr, key, max_key);
     address_t height = leftheight;
     if (rightheight > height) {
         height = rightheight;
     }
     height += 1;
     int heightgap = leftheight - rightheight;
+    bool leftbal = (getLeftBal(n.pointers)).bshare;
+    bool rightbal = (getRightBal(n.pointers)).bshare;
     bool avlok = (abs(heightgap)<2);
+    bool bb_ok = false;
+
+    if(heightgap==-1) {
+        if(rightbal==1 && leftbal==0){
+            bb_ok = true;
+        }
+    } else if(heightgap==1){
+        if(leftbal==1 && rightbal==0){
+            bb_ok = true;
+        }
+    } else if(heightgap==0){
+        if(rightbal==0 && leftbal==0) {
+            bb_ok = true;
+        }
+    }
     //printf("node = %ld, leftok = %d, rightok = %d\n", node, leftok, rightok);
     return { leftok && rightok && key >= min_key && key <= max_key,
-        avlok && leftavlok && rightavlok, height};
+        avlok && leftavlok && rightavlok, bb_ok && leftbbok && rightbbok, height};
 }
 
 void AVL::check_avl(MPCTIO &tio, yield_t &yield) {
@@ -126,9 +143,9 @@ void AVL::check_avl(MPCTIO &tio, yield_t &yield) {
         rec_root+= peer_root;
     }
     if (tio.player() == 0) {
-      auto [ bst_ok, avl_ok, height ] = check_avl(R, rec_root.xshare);
-      printf("BST structure %s\nAVL structure %s\nTree height = %u\n",
-          bst_ok ? "ok" : "NOT OK", avl_ok ? "ok" : "NOT OK", height);
+      auto [ bst_ok, avl_ok, bb_ok, height ] = check_avl(R, rec_root.xshare);
+      printf("BST structure %s\nAVL structure %s\nBalance Bits %s\nTree height = %u\n",
+          bst_ok ? "ok" : "NOT OK", avl_ok ? "ok" : "NOT OK", bb_ok? "ok" : "NOT OK", height);
     }
 }
 
@@ -268,7 +285,7 @@ std::tuple<RegBS, RegBS, RegBS, RegBS> AVL::updateBalanceDel(MPCTIO &tio, yield_
         RegBS bal_l, RegBS bal_r, RegBS bal_upd, RegBS child_dir) {
     bool player0 = tio.player()==0;
     RegBS s0;
-    RegBS F_rs, F_ls, balanced, imbalance;
+    RegBS F_rs, F_ls, balanced, imbalance, not_imbalance;
     RegBS nt_child_dir = child_dir;
     if(player0) {
         nt_child_dir^=1;
@@ -308,11 +325,26 @@ std::tuple<RegBS, RegBS, RegBS, RegBS> AVL::updateBalanceDel(MPCTIO &tio, yield_
         { mpc_select(tio, yield, bal_l, F_rs, bal_l, s0);});
 
     // if(bal_upd) and not imbalance bal_upd<-0
+    /*
     RegBS bu0;
-    mpc_and(tio, yield, bu0, bal_upd, balanced);
+    not_imbalance = imbalance;
+    if(player0){
+        not_imbalance^=1;
+    } 
+    mpc_and(tio, yield, bu0, bal_upd, not_imbalance);
     mpc_select(tio, yield, bal_upd, bu0, bal_upd, s0);
+    */
 
-    // Any bal_upd, propogates all the way up to root
+    // if(bal_upd) and this node turns balanced, the height has decreased, so continue propogating bal_upd.
+    // if(bal_upd) and node turns imbalanced, fixImbalance will update bal_upd correctly.
+    // if(bal_upd) and node moves out of balanced to left/right heavy, the height of this subtree has not changed,
+    //   so don't propogate bal_upd.
+    // if(bal_upd && bal_l ^ bal_r
+    RegBS LR_heavy, bu0;
+    LR_heavy = bal_l ^ bal_r;
+    mpc_and(tio, yield, bu0, bal_upd, LR_heavy);
+    mpc_select(tio, yield, bal_upd, bu0, bal_upd, s0);
+    
     return {bal_l, bal_r, bal_upd, imbalance};
 }
 
@@ -1020,20 +1052,29 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
 
     */
     RegBS IC1, IC2, IC3; // Imbalance Case 1, 2 or 3
+    RegBS cs_zero_bal = cs_bal_dpc ^ cs_bal_ndpc;
+    if(player0) {
+        cs_zero_bal^=1;
+    }
     run_coroutines(tio, [&tio, &IC1, imb, cs_bal_ndpc] (yield_t &yield) {
         // IC1 = Single rotation (L/R). L/R = dpc
-        mpc_and(tio, yield, IC1, imb, cs_bal_ndpc);
+            mpc_and(tio, yield, IC1, imb, cs_bal_ndpc);
         },
         [&tio, &IC3, imb, cs_bal_dpc](yield_t &yield) {
         // IC3 = Double rotation (LR/RL). 1st rotate direction = ndpc, 2nd direction = dpc
-        mpc_and(tio, yield, IC3, imb, cs_bal_dpc);
+            mpc_and(tio, yield, IC3, imb, cs_bal_dpc);
+        },
+        [&tio, &IC2, imb, cs_zero_bal](yield_t &yield) {
+            mpc_and(tio, yield, IC2, imb, cs_zero_bal);  
         });
 
     // IC2 = Single rotation (L/R).
+    /*
     IC2 = IC1 ^ IC3;
     if(player0) {
       IC2^=1;
     }
+    */
 
     RegBS p_bal_dpc, p_bal_ndpc;
     RegBS IC2_ndpc_l, IC2_ndpc_r, IC2_dpc_l, IC2_dpc_r;
@@ -1375,16 +1416,40 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
         RegBS p_bal_l, p_bal_r;
         p_bal_l = getLeftBal(node.pointers);
         p_bal_r = getRightBal(node.pointers);
+
+        #ifdef DEBUG
+          size_t rec_key = reconstruct_RegAS(tio, yield, node.key);
+          bool rec_bal_upd = reconstruct_RegBS(tio, yield, bal_upd);
+          printf("current_key = %ld, bal_upd (before updateBalanceDel) = %d\n", rec_key, rec_bal_upd);
+        #endif
+
         auto [new_p_bal_l, new_p_bal_r, new_bal_upd, imb] =
             updateBalanceDel(tio, yield, p_bal_l, p_bal_r, bal_upd, c_prime);
+        bal_upd = new_bal_upd;
+
+        #ifdef DEBUG
+          bool rec_imb = reconstruct_RegBS(tio, yield, imb);
+          bool rec_new_bal_upd = reconstruct_RegBS(tio, yield, new_bal_upd);
+          printf("new_bal_upd (after updateBalanceDel) = %d, imb = %d\n", rec_new_bal_upd, rec_imb);
+        #endif
         
         // F_ri: subflag for F_r. F_ri = returned flag set to 1 from imbalance fix.
         RegBS F_ri;
         fixImbalance(tio, yield, A, oidx, oldptrs, ptr, node.pointers, new_p_bal_l, new_p_bal_r, bal_upd, 
               c_prime, cs_ptr, imb, F_ri, ret_struct);
 
+        #ifdef DEBUG
+          rec_imb = reconstruct_RegBS(tio, yield, imb);
+          rec_bal_upd = reconstruct_RegBS(tio, yield, bal_upd); 
+          printf("imb (after fixImbalance) = %d, bal_upd = %d\n", rec_imb, rec_bal_upd);
+        #endif
         updateRetStruct(tio, yield, ptr, F_2, F_c2, F_c4, lf, F_ri, found, bal_upd, ret_struct); 
 
+        #ifdef DEBUG
+          rec_bal_upd = reconstruct_RegBS(tio, yield, bal_upd);
+          printf("bal_upd (after updateRetStruct) = %d\n", rec_bal_upd);
+        #endif
+
         return {key_found, bal_upd};
     }
 }
@@ -2856,25 +2921,25 @@ void avl_tests(MPCIO &mpcio,
 
                  5                     5                   5
                /   \                 /   \               /   \
-              3     12    Del 1     3     12            3     9
-             /      /    ------>          /     --->         / \
-            1      7                     9                  7   12
-                    \                   /
-                     9                 7
+              3     8     Del 7     3     8             3     9
+             /     /  \   ------>  /       \     --->  /     /  \
+            1     7    12         1         9         1     8    12
+                      /                      \
+                     9                        12
 
 
             T8 checks:
             - root is 5
-            - 3,9,7,12 are in correct positions
-            - Nodes 3,7,12 have 0 balance
-            - Nodes 3,7,12 have no children
-            - 5's bal = 0 1
+            - 3,9,8,12 are in correct positions
+            - Nodes 1,5,8,9,12 have 0 balance
+            - Nodes 1,5,8,9,12 have no children
+            - Node 3 has 1 0 balance 
 
         */
         {
             bool success = 1;
-            int insert_array[] = {5, 3, 12, 7, 1, 9};
-            size_t insert_array_size = 5;
+            int insert_array[] = {5, 3, 8, 7, 1, 12, 9};
+            size_t insert_array_size = 6;
             Node node;
             for(size_t i = 0; i<=insert_array_size; i++) {
               newnode(node);
@@ -2884,7 +2949,7 @@ void avl_tests(MPCIO &mpcio,
             }
 
             RegAS del_key;
-            del_key.set(1 * tio.player());
+            del_key.set(7 * tio.player());
             tree.del(tio, yield, del_key);
             tree.check_avl(tio, yield);
 
@@ -2893,8 +2958,8 @@ void avl_tests(MPCIO &mpcio,
             size_t root = reconstruct_RegXS(tio, yield, root_xs);
             auto A = oram->flat(tio, yield);
             auto R = A.reconstruct();
-            Node root_node, n3, n7, n9, n12;
-            size_t n3_index, n7_index, n9_index, n12_index;
+            Node root_node, n1, n3, n8, n9, n12;
+            size_t n1_index, n3_index, n8_index, n9_index, n12_index;
             root_node = R[root];
             if((root_node.key).share()!=5) {
                 success = false;
@@ -2903,34 +2968,37 @@ void avl_tests(MPCIO &mpcio,
             n9_index = (getAVLRightPtr(root_node.pointers)).share();
             n3 = R[n3_index];
             n9 = R[n9_index];
-            n7_index = getAVLLeftPtr(n9.pointers).share();
+            n1_index = getAVLLeftPtr(n3.pointers).share();
+            n8_index = getAVLLeftPtr(n9.pointers).share();
             n12_index = getAVLRightPtr(n9.pointers).share();
-            n7 = R[n7_index];
+            n1 = R[n1_index];
+            n8 = R[n8_index];
             n12 = R[n12_index];
 
             // Node value checks
+            if(n1.key.share()!=1) {
+                success = false;
+            }
             if(n3.key.share()!=3 || n9.key.share()!=9) {
                 success = false;
             }
-            if(n7.key.share()!=7 || n12.key.share()!=12) {
+            if(n8.key.share()!=8 || n12.key.share()!=12) {
                 success = false;
             }
 
             // Node balance checks
             size_t zero = 0;
-            zero+=(n3.pointers.share());
-            zero+=(n7.pointers.share());
+            zero+=(n1.pointers.share());
+            zero+=(getRightBal(n3.pointers).share());
+            zero+=(n8.pointers.share());
             zero+=(n12.pointers.share());
             zero+=(getLeftBal(root_node.pointers).share());
+            zero+=(getRightBal(root_node.pointers).share());
             zero+=(getLeftBal(n9.pointers).share());
             zero+=(getRightBal(n9.pointers).share());
             if(zero!=0) {
                 success = false;
             }
-            int one = (getRightBal(root_node.pointers).share());
-            if(one!=1) {
-                success = false;
-            }
             if(player0) {
                 if(success) {
                     print_green("T16 : SUCCESS\n");

+ 2 - 1
avl.hpp

@@ -23,6 +23,7 @@
 #define OPT_ON 0
 #define RANDOMIZE 0
 #define SANITY_TEST 0
+//#define DEBUG 0
 
 /*
   For AVL tree we'll treat the pointers fields as:
@@ -207,7 +208,7 @@ class AVL {
     void pretty_print(const std::vector<Node> &R, value_t node,
         const std::string &prefix, bool is_left_child, bool is_right_child);
     void check_avl(MPCTIO &tio, yield_t &yield);
-    std::tuple<bool, bool, address_t> check_avl(const std::vector<Node> &R,
+    std::tuple<bool, bool, bool, address_t> check_avl(const std::vector<Node> &R,
         value_t node, value_t min_key, value_t max_key);
     void print_oram(MPCTIO &tio, yield_t &yield);