浏览代码

Deletion correctness fix for incorrect balance bits in IC3. Expanded IC3 to IC3_S1 to IC3_S3

sshsshy 2 年之前
父节点
当前提交
5993e613dc
共有 2 个文件被更改,包括 112 次插入31 次删除
  1. 110 30
      avl.cpp
  2. 2 1
      avl.hpp

+ 110 - 30
avl.cpp

@@ -125,6 +125,11 @@ std::tuple<bool, bool, bool, address_t> AVL::check_avl(const std::vector<Node> &
             bb_ok = true;
             bb_ok = true;
         }
         }
     }
     }
+    #ifdef DEBUG_BB
+      if(bb_ok == false){
+          printf("BB check failed at node with key = %ld\n", key);
+      }
+    #endif
     //printf("node = %ld, leftok = %d, rightok = %d\n", node, leftok, rightok);
     //printf("node = %ld, leftok = %d, rightok = %d\n", node, leftok, rightok);
     return { leftok && rightok && key >= min_key && key <= max_key,
     return { leftok && rightok && key >= min_key && key <= max_key,
         avlok && leftavlok && rightavlok, bb_ok && leftbbok && rightbbok, height};
         avlok && leftavlok && rightavlok, bb_ok && leftbbok && rightbbok, height};
@@ -1000,6 +1005,9 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
         gcs_node = A[gcs_ptr];
         gcs_node = A[gcs_ptr];
     #endif
     #endif
 
 
+    RegBS gcs_bal_l = getLeftBal(gcs_node.pointers);
+    RegBS gcs_bal_r = getRightBal(gcs_node.pointers);
+
     not_c_prime = c_prime;
     not_c_prime = c_prime;
     if(player0) {
     if(player0) {
       not_c_prime^=1;
       not_c_prime^=1;
@@ -1025,7 +1033,7 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
         new_cs_pointers, s0, not_c_prime, imb, s1);
         new_cs_pointers, s0, not_c_prime, imb, s1);
 
 
     // If imb (we do some rotation), then update F_r, and ret_ptr, to
     // If imb (we do some rotation), then update F_r, and ret_ptr, to
-    // fix the gp->p link (The F_r clauses later, and this are mutually
+    // fix the gp->p link (There are F_r clauses later, but they are mutually
     // exclusive events. They will never trigger together.)
     // exclusive events. They will never trigger together.)
 
 
     std::vector<coro_t> coroutines;
     std::vector<coro_t> coroutines;
@@ -1049,7 +1057,6 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
     /*
     /*
        Update balances based on imbalance and type of rotations that happen.
        Update balances based on imbalance and type of rotations that happen.
        In the case of an imbalance, updateBalance() sets bal_l and bal_r of p to 0.
        In the case of an imbalance, updateBalance() sets bal_l and bal_r of p to 0.
-
     */
     */
     RegBS IC1, IC2, IC3; // Imbalance Case 1, 2 or 3
     RegBS IC1, IC2, IC3; // Imbalance Case 1, 2 or 3
     RegBS cs_zero_bal = cs_bal_dpc ^ cs_bal_ndpc;
     RegBS cs_zero_bal = cs_bal_dpc ^ cs_bal_ndpc;
@@ -1060,25 +1067,37 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
         // IC1 = Single rotation (L/R). L/R = dpc
         // IC1 = Single rotation (L/R). L/R = dpc
             mpc_and(tio, yield, IC1, imb, cs_bal_ndpc);
             mpc_and(tio, yield, IC1, imb, cs_bal_ndpc);
         },
         },
+        // IC2 = Single rotation (L/R). L/R = dpc
+        [&tio, &IC2, imb, cs_zero_bal](yield_t &yield) {
+            mpc_and(tio, yield, IC2, imb, cs_zero_bal);  
+        },
         [&tio, &IC3, imb, cs_bal_dpc](yield_t &yield) {
         [&tio, &IC3, imb, cs_bal_dpc](yield_t &yield) {
         // IC3 = Double rotation (LR/RL). 1st rotate direction = ndpc, 2nd direction = dpc
         // IC3 = Double rotation (LR/RL). 1st rotate direction = ndpc, 2nd direction = dpc
             mpc_and(tio, yield, IC3, imb, cs_bal_dpc);
             mpc_and(tio, yield, IC3, imb, cs_bal_dpc);
-        },
-        [&tio, &IC2, imb, cs_zero_bal](yield_t &yield) {
-            mpc_and(tio, yield, IC2, imb, cs_zero_bal);  
         });
         });
 
 
-    // IC2 = Single rotation (L/R).
-    /*
-    IC2 = IC1 ^ IC3;
-    if(player0) {
-      IC2^=1;
-    }
-    */
+    RegBS p_bal_ndpc, p_bal_dpc;
+    RegBS ndpc_is_l, ndpc_is_r, dpc_is_l, dpc_is_r;
+    RegBS IC2_n_cprime;
+
+    // First flags to check dpc = L/R, and similarly ndpc = L/R
+    coroutines.emplace_back([&tio, &ndpc_is_l, c_prime, imb] (yield_t &yield)
+        { mpc_and(tio, yield, ndpc_is_l, imb, c_prime);});
+    coroutines.emplace_back([&tio, &ndpc_is_r, imb, not_c_prime](yield_t &yield)
+        { mpc_and(tio, yield, ndpc_is_r, imb, not_c_prime);});
+    coroutines.emplace_back([&tio, &dpc_is_l, imb, not_c_prime](yield_t &yield)
+        { mpc_and(tio, yield, dpc_is_l, imb, not_c_prime);});
+    coroutines.emplace_back([&tio, &dpc_is_r, imb, c_prime](yield_t &yield)
+        { mpc_and(tio, yield, dpc_is_r, imb, c_prime);});
+    run_coroutines(tio, coroutines); 
+    coroutines.clear();
+
 
 
-    RegBS p_bal_dpc, p_bal_ndpc;
-    RegBS IC2_ndpc_l, IC2_ndpc_r, IC2_dpc_l, IC2_dpc_r;
+    mpc_and(tio, yield, IC2_n_cprime, IC2, c_prime);
+    mpc_select(tio, yield, p_bal_ndpc, ndpc_is_r, new_p_bal_l, new_p_bal_r);
+    mpc_select(tio, yield, p_bal_dpc, dpc_is_r, new_p_bal_l, new_p_bal_r);
 
 
+    /*
     run_coroutines(tio, [&tio, &IC2, imb] (yield_t &yield)
     run_coroutines(tio, [&tio, &IC2, imb] (yield_t &yield)
         { mpc_and(tio, yield, IC2, imb, IC2);},
         { mpc_and(tio, yield, IC2, imb, IC2);},
         [&tio, &cs_bal_dpc, imb, s0](yield_t &yield) 
         [&tio, &cs_bal_dpc, imb, s0](yield_t &yield) 
@@ -1086,27 +1105,80 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
           mpc_select(tio, yield, cs_bal_dpc, imb, cs_bal_dpc, s0);},
           mpc_select(tio, yield, cs_bal_dpc, imb, cs_bal_dpc, s0);},
         [&tio, &cs_bal_ndpc, c_prime, imb, s0](yield_t &yield) {
         [&tio, &cs_bal_ndpc, c_prime, imb, s0](yield_t &yield) {
         mpc_select(tio, yield, cs_bal_ndpc, imb, cs_bal_ndpc, s0);});
         mpc_select(tio, yield, cs_bal_ndpc, imb, cs_bal_ndpc, s0);});
+    */
+    mpc_select(tio, yield, cs_bal_ndpc, IC1, cs_bal_ndpc, s0);
+    mpc_select(tio, yield, cs_bal_dpc, IC2, cs_bal_dpc, s1);
+    mpc_select(tio, yield, cs_bal_dpc, IC3, cs_bal_dpc, s0);
+    mpc_select(tio, yield, p_bal_ndpc, IC2, p_bal_ndpc, s1);
+
+    /* IC3 has 3 subcases:
+          IC3_S1: gcs_bal_dpc = 0, gcs_bal_ndpc = 1
+          IC3_S2: gcs_bal_dpc = 1, gc_bal_ndpc = 0
+          IC3_S3: gcs_bal_dpc = 0, gcs_bal_ndpc = 0
+    
+          IC3_S1: p_dpc <- 1
+                  cs_dpc <- 0 
+                  (gcs_bal stays same)
+          IC3_S2: Swap cs_dpc and cs_ndpc (1 0 -> - 1).
+                  cs_dpc <- 0, cs_ndpc <- 1
+                  gcs_bal_dpc <- 0
+          IC3_S3: cs_dpc <- 0
+                  gcs_bal stays same
+    */
+    RegBS IC3_S1, IC3_S2, IC3_S3, gcs_balanced, gcs_bal_dpc, gcs_bal_ndpc;
+    mpc_select(tio, yield, gcs_bal_dpc, c_prime, gcs_bal_l, gcs_bal_r);
+    mpc_select(tio, yield, gcs_bal_ndpc, c_prime, gcs_bal_r, gcs_bal_l);
+    mpc_and(tio, yield, IC3_S1, IC3, gcs_bal_ndpc);
+    mpc_and(tio, yield, IC3_S2, IC3, gcs_bal_dpc);
+    gcs_balanced = gcs_bal_dpc ^ gcs_bal_ndpc;
+    if(player0) {
+        gcs_balanced^=1;
+    }
+    mpc_and(tio, yield, IC3_S3, IC3, gcs_balanced);
+
+    //mpc_select(tio, yield, gcs_bal_l, IC3, gcs_bal_l, s0);
+    //mpc_select(tio, yield, gcs_bal_r, IC3, gcs_bal_r, s0);
+    RegBS imb_n_cprime;
+    mpc_and(tio, yield, imb_n_cprime, imb, c_prime); 
+
+    
+    mpc_select(tio, yield, p_bal_dpc, IC3_S1, p_bal_dpc, s1);
+    mpc_select(tio, yield, cs_bal_dpc, IC3, cs_bal_dpc, s0);
+
+    mpc_select(tio, yield, cs_bal_ndpc, IC3_S2, cs_bal_ndpc, s1);
+    mpc_select(tio, yield, gcs_bal_dpc, IC3_S2, gcs_bal_dpc, s0);
+
+    // Updating balance bits of p and cs.
+
+    // Write back updated balance bits
+    // Updating gcs_bal_l/r
+    run_coroutines(tio, [&tio, &gcs_bal_r, dpc_is_r, gcs_bal_dpc] (yield_t &yield) 
+        { mpc_select(tio, yield, gcs_bal_r, dpc_is_r, gcs_bal_r, gcs_bal_dpc);},
+        [&tio, &gcs_bal_l, dpc_is_l, gcs_bal_dpc](yield_t &yield)
+        { mpc_select(tio, yield, gcs_bal_l, dpc_is_l, gcs_bal_l, gcs_bal_dpc);});
+
+    // Updating cs_bal_l/r (cs_bal_dpc effected by IC3, cs_bal_ndpc effected by IC1,2)
+    run_coroutines(tio, [&tio, &cs_bal_r, dpc_is_r, cs_bal_dpc] (yield_t &yield) 
+        { mpc_select(tio, yield, cs_bal_r, dpc_is_r, cs_bal_r, cs_bal_dpc);},
+        [&tio, &cs_bal_l, dpc_is_l, cs_bal_dpc](yield_t &yield)
+        { mpc_select(tio, yield, cs_bal_l, dpc_is_l, cs_bal_l, cs_bal_dpc);});
+
+    run_coroutines(tio, [&tio, &cs_bal_r, ndpc_is_r, cs_bal_ndpc] (yield_t &yield) 
+        { mpc_select(tio, yield, cs_bal_r, ndpc_is_r, cs_bal_r, cs_bal_ndpc);},
+        [&tio, &cs_bal_l, ndpc_is_l, cs_bal_ndpc](yield_t &yield)
+        { mpc_select(tio, yield, cs_bal_l, ndpc_is_l, cs_bal_l, cs_bal_ndpc);});
+
+    // Updating new_p_bal_l/r (p_bal_ndpc effected by IC2)
+    run_coroutines(tio, [&tio, &new_p_bal_r, ndpc_is_r, p_bal_ndpc] (yield_t &yield) 
+        { mpc_select(tio, yield, new_p_bal_r, ndpc_is_r, new_p_bal_r, p_bal_ndpc);},
+        [&tio, &new_p_bal_l, ndpc_is_l, p_bal_ndpc](yield_t &yield)
+        { mpc_select(tio, yield, new_p_bal_l, ndpc_is_l, new_p_bal_l, p_bal_ndpc);});
 
 
-    run_coroutines(tio, [&tio, &cs_bal_r, c_prime, cs_bal_ndpc, cs_bal_dpc] (yield_t &yield) 
-        { mpc_select(tio, yield, cs_bal_r, c_prime, cs_bal_ndpc, cs_bal_dpc);},
-        [&tio, &cs_bal_l, c_prime, cs_bal_dpc, cs_bal_ndpc](yield_t &yield)
-        { mpc_select(tio, yield, cs_bal_l, c_prime, cs_bal_dpc, cs_bal_ndpc);});
 
 
     // IC2: p.bal_ndpc = 1, cs.bal_dpc = 1
     // IC2: p.bal_ndpc = 1, cs.bal_dpc = 1
     // (IC2 & not_c_prime)
     // (IC2 & not_c_prime)
-    coroutines.emplace_back([&tio, &p_bal_ndpc, c_prime, new_p_bal_r, new_p_bal_l](yield_t &yield)
-        { mpc_select(tio, yield, p_bal_ndpc, c_prime, new_p_bal_r, new_p_bal_l);});
-    coroutines.emplace_back([&tio, &IC2_ndpc_l, c_prime, IC2] (yield_t &yield)
-        { mpc_and(tio, yield, IC2_ndpc_l, IC2, c_prime);});
-    coroutines.emplace_back([&tio, &IC2_ndpc_r, IC2, not_c_prime](yield_t &yield)
-        { mpc_and(tio, yield, IC2_ndpc_r, IC2, not_c_prime);});
-    coroutines.emplace_back([&tio, &IC2_dpc_l, IC2, not_c_prime](yield_t &yield)
-        { mpc_and(tio, yield, IC2_dpc_l, IC2, not_c_prime);});
-    coroutines.emplace_back([&tio, &IC2_dpc_r, IC2, c_prime](yield_t &yield)
-        { mpc_and(tio, yield, IC2_dpc_r, IC2, c_prime);});
-    run_coroutines(tio, coroutines); 
-    coroutines.clear();
 
 
+    /*
     cs_bal_dpc^=IC2;
     cs_bal_dpc^=IC2;
     p_bal_ndpc^=IC2;
     p_bal_ndpc^=IC2;
   
   
@@ -1125,7 +1197,13 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
         mpc_select(tio, yield, bal_upd, IC2, bal_upd, s0);});
         mpc_select(tio, yield, bal_upd, IC2, bal_upd, s0);});
     run_coroutines(tio, coroutines);
     run_coroutines(tio, coroutines);
     coroutines.clear();
     coroutines.clear();
+    */
+
+    // In the IC2 case bal_upd = 0 (The rotation doesn't end up
+    // decreasing height of this subtree.
+    mpc_select(tio, yield, bal_upd, IC2, bal_upd, s0);
 
 
+    /*
     // IC3:
     // IC3:
     // To set balance in this case we need to know if gcs.dpc child exists
     // To set balance in this case we need to know if gcs.dpc child exists
     // and similarly if gcs.ndpc child exitst.
     // and similarly if gcs.ndpc child exitst.
@@ -1186,6 +1264,8 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
     coroutines.emplace_back([&tio, &gcs_bal_r, IC3, s0](yield_t &yield) {
     coroutines.emplace_back([&tio, &gcs_bal_r, IC3, s0](yield_t &yield) {
         mpc_select(tio, yield, gcs_bal_r, IC3, gcs_bal_r, s0);});
         mpc_select(tio, yield, gcs_bal_r, IC3, gcs_bal_r, s0);});
     run_coroutines(tio, coroutines); 
     run_coroutines(tio, coroutines); 
+    */
+
 
 
     // Write back <cs_bal_dpc, cs_bal_ndpc> and <gcs_bal_l, gcs_bal_r>
     // Write back <cs_bal_dpc, cs_bal_ndpc> and <gcs_bal_l, gcs_bal_r>
     setLeftBal(gcs_node.pointers, gcs_bal_l);
     setLeftBal(gcs_node.pointers, gcs_bal_l);

+ 2 - 1
avl.hpp

@@ -23,7 +23,8 @@
 #define OPT_ON 0
 #define OPT_ON 0
 #define RANDOMIZE 0
 #define RANDOMIZE 0
 #define SANITY_TEST 0
 #define SANITY_TEST 0
-//#define DEBUG 0
+// #define DEBUG 0
+// #define DEBUG_BB 0
 
 
 /*
 /*
   For AVL tree we'll treat the pointers fields as:
   For AVL tree we'll treat the pointers fields as: