Explorar o código

Splitting del() into fixImbalance() and updateChildPointers()

sshsshy hai 1 ano
pai
achega
d016ffb9d6
Modificáronse 2 ficheiros con 262 adicións e 235 borrados
  1. 255 235
      avl.cpp
  2. 7 0
      avl.hpp

+ 255 - 235
avl.cpp

@@ -702,6 +702,237 @@ bool AVL::lookup(MPCTIO &tio, yield_t &yield, RegAS key, Node *ret_node) {
     return found;
 }
 
+void AVL::updateChildPointers(MPCTIO &tio, yield_t &yield, RegXS &left, RegXS &right,
+          RegBS c_prime, avl_del_return ret_struct) {
+    bool player0 = tio.player()==0;
+    RegBS F_rr; // Flag to resolve F_r by updating correct child ptr
+    mpc_and(tio, yield, F_rr, c_prime, ret_struct.F_r);
+    mpc_select(tio, yield, right, F_rr, right, ret_struct.ret_ptr);
+    if(player0)
+        c_prime^=1;
+    mpc_and(tio, yield, F_rr, c_prime, ret_struct.F_r);
+    mpc_select(tio, yield, left, F_rr, left, ret_struct.ret_ptr);
+    if(player0)
+        c_prime^=1;
+}
+
+
+// Perform rotations if imbalance (else dummy rotations)
+/*
+   For capturing both the symmetric L and R cases of rotations, we'll capture directions with
+   dpc  = dir_pc = direction from parent to child, and
+   ndpc = not(dir_pc)
+   When we travelled down the stack, we went from p->c. But in deletions to handle any imbalance
+   we look at c's sibling cs (child's sibling). And the rotation is between p and cs if there
+   was an imbalance at p, and perhaps even cs and it's child (the child in dir_pc, as that's the
+   only case that results in a double rotation when deleting).
+
+   In case of an imbalance we have to always rotate p->cs link. (L or R case)
+   If cs.bal_(dir_pc), then we have a double rotation (LR or RL) case.
+   In such cases, first rotate cs->gcs link, and then p->cs link. gcs = grandchild on cs path
+
+   Layout: In the R (or LR) case:
+
+         p
+       /   \
+      cs    c
+     /  \
+    a   gcs
+        /  \
+       x    y
+
+   - One of x or y must exist for it to be an LR case,
+     since then cs.bal_(dir_pc) = cs.bal_r = 1
+
+   Layout: In the L (or RL) case:
+
+         p
+       /   \
+      c     cs
+           /  \
+         gcs   a
+        /   \
+       x     y
+
+   - One of x or y must exist for it to be an RL case,
+     since then cs.bal_(dir_pc) = cs.bal_l = 1
+
+   (Note: if double rotation case, in the second rotation cs is actually gcs,
+    since the the first rotation swaps their positions)
+*/
+
+void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A, RegXS ptr, 
+        RegXS nodeptrs, RegBS new_p_bal_l, RegBS new_p_bal_r, RegBS &bal_upd, 
+        RegBS c_prime, RegXS cs_ptr, RegBS imb, RegBS &F_ri, 
+        avl_del_return &ret_struct) {
+    bool player0 = tio.player()==0;
+    RegBS s0, s1;
+    s1.set(tio.player()==1);
+    Node cs_node = A[cs_ptr];
+    //dirpc = dir_pc = dpc = c_prime
+    RegBS cs_bal_l, cs_bal_r, cs_bal_dpc, cs_bal_ndpc, F_dr, not_c_prime;
+    RegXS gcs_ptr, cs_left, cs_right, cs_dpc, cs_ndpc, null;
+    // child's sibling node's balances in dir_pc (dpc), and not_dir_pc (ndpc)
+    cs_bal_l = getLeftBal(cs_node.pointers);
+    cs_bal_r = getRightBal(cs_node.pointers);
+    cs_left = getAVLLeftPtr(cs_node.pointers);
+    cs_right = getAVLRightPtr(cs_node.pointers);
+    mpc_select(tio, yield, cs_bal_dpc, c_prime, cs_bal_l, cs_bal_r);
+    mpc_select(tio, yield, cs_bal_ndpc, c_prime, cs_bal_r, cs_bal_l);
+    mpc_select(tio, yield, cs_dpc, c_prime, cs_left, cs_right);
+    mpc_select(tio, yield, cs_ndpc, c_prime, cs_right, cs_left);
+
+    // We need to double rotate (LR or RL case) if cs_bal_dpc is 1
+    mpc_and(tio, yield, F_dr, imb, cs_bal_dpc);
+    mpc_select(tio, yield, gcs_ptr, cs_bal_dpc, cs_ndpc, cs_dpc, AVL_PTR_SIZE);
+    Node gcs_node = A[gcs_ptr];
+
+    not_c_prime = c_prime;
+    if(player0) {
+      not_c_prime^=1;
+    }
+    // First rotation: cs->gcs link
+    rotate(tio, yield, nodeptrs, cs_ptr, cs_node.pointers, gcs_ptr,
+        gcs_node.pointers, not_c_prime, c_prime, F_dr, s0);
+
+    // If F_dr, we did first rotation. Then cs and gcs need to swap before the second rotate.
+    RegXS new_cs_pointers, new_cs, new_ptr;
+    mpc_select(tio, yield, new_cs_pointers, F_dr, cs_node.pointers, gcs_node.pointers);
+    mpc_select(tio, yield, new_cs, F_dr, cs_ptr, gcs_ptr, AVL_PTR_SIZE);
+
+    // Second rotation: p->cs link
+    // Since we don't have access to gp node here we just send a null and s0
+    // for gp_pointers and dir_gpp. Instead this pointer fix is handled by F_r
+    // and ret_struct.ret_ptr.
+    rotate(tio, yield, null, ptr, nodeptrs, new_cs,
+        new_cs_pointers, s0, not_c_prime, imb, s1);
+
+    /*
+    size_t rec_p_left_1, rec_p_right_1;
+    bool rec_flag_imb, rec_flag_dr;
+    rec_flag_imb = reconstruct_RegBS(tio, yield, imb);
+    rec_flag_dr = reconstruct_RegBS(tio, yield, F_dr);
+    rec_p_left_1 = reconstruct_RegXS(tio, yield, getAVLLeftPtr(node.pointers));
+    rec_p_right_1 = reconstruct_RegXS(tio, yield, getAVLRightPtr(node.pointers));
+    printf("flag_imb = %d, flag_dr = %d\n", rec_flag_imb, rec_flag_dr);
+    printf("parent_ptrs (foundter rotations): left = %lu, right = %lu\n", rec_p_left_1, rec_p_right_1);
+    */
+
+    // If imb (we do some rotation), then update F_r, and ret_ptr, to
+    // fix the gp->p link (The F_r clauses later, and this are mutually
+    // exclusive events. They will never trigger together.)
+    mpc_select(tio, yield, new_ptr, F_dr, cs_ptr, gcs_ptr);
+    mpc_select(tio, yield, F_ri, imb, s0, s1);
+    mpc_select(tio, yield, ret_struct.ret_ptr, imb, ret_struct.ret_ptr, new_ptr);
+
+    // Write back new_cs_pointers correctly to (cs_node/gcs_node).pointers
+    // and then balance the nodes
+    mpc_select(tio, yield, cs_node.pointers, F_dr, new_cs_pointers, cs_node.pointers);
+    mpc_select(tio, yield, gcs_node.pointers, F_dr, gcs_node.pointers, new_cs_pointers);
+
+    /*
+       Update balances based on imbalance and type of rotations that happen.
+       In the case of an imbalance, updateBalance() sets bal_l and bal_r of p to 0.
+
+    */
+    RegBS IC1, IC2, IC3; // Imbalance Case 1, 2 or 3
+    // IC1 = Single rotation (L/R). L/R = dpc
+    mpc_and(tio, yield, IC1, imb, cs_bal_ndpc);
+    // IC3 = Double rotation (LR/RL). 1st rotate direction = ndpc, 2nd direction = dpc
+    mpc_and(tio, yield, IC3, imb, cs_bal_dpc);
+    // IC2 = Single rotation (L/R).
+    IC2 = IC1 ^ IC3;
+    if(player0) {
+      IC2^=1;
+    }
+    mpc_and(tio, yield, IC2, imb, IC2);
+
+    /*
+    bool rec_IC1, rec_IC2, rec_IC3;
+    rec_IC1 = reconstruct_RegBS(tio, yield, IC1);
+    rec_IC2 = reconstruct_RegBS(tio, yield, IC2);
+    rec_IC3 = reconstruct_RegBS(tio, yield, IC3);
+    printf("rec_IC1 = %d, rec_IC2 = %d, rec_IC3 = %d\n", rec_IC1, rec_IC2, rec_IC3);
+    */
+
+    // IC1, IC2, IC3: CS.bal = 0 0
+    mpc_select(tio, yield, cs_bal_dpc, imb, cs_bal_dpc, s0);
+    mpc_select(tio, yield, cs_bal_ndpc, imb, cs_bal_ndpc, s0);
+    mpc_select(tio, yield, cs_bal_r, c_prime, cs_bal_ndpc, cs_bal_dpc);
+    mpc_select(tio, yield, cs_bal_l, c_prime, cs_bal_dpc, cs_bal_ndpc);
+
+    // IC2: p.bal_ndpc = 1, cs.bal_dpc = 1
+    // (IC2 & not_c_prime)
+    cs_bal_dpc^=IC2;
+    RegBS p_bal_dpc, p_bal_ndpc;
+    mpc_select(tio, yield, p_bal_ndpc, c_prime, new_p_bal_r, new_p_bal_l);
+    p_bal_ndpc^=IC2;
+    RegBS IC2_ndpc_l, IC2_ndpc_r, IC2_dpc_l, IC2_dpc_r;
+    mpc_and(tio, yield, IC2_ndpc_l, IC2, c_prime);
+    mpc_and(tio, yield, IC2_ndpc_r, IC2, not_c_prime);
+    mpc_and(tio, yield, IC2_dpc_l, IC2, not_c_prime);
+    mpc_and(tio, yield, IC2_dpc_r, IC2, c_prime);
+
+    mpc_select(tio, yield, new_p_bal_l, IC2_ndpc_l, new_p_bal_l, p_bal_ndpc);
+    mpc_select(tio, yield, new_p_bal_r, IC2_ndpc_r, new_p_bal_r, p_bal_ndpc);
+    mpc_select(tio, yield, cs_bal_l, IC2_dpc_l, cs_bal_l, cs_bal_dpc);
+    mpc_select(tio, yield, cs_bal_r, IC2_dpc_r, cs_bal_r, cs_bal_dpc);
+    // In the IC2 case bal_upd = 0 (The rotation doesn't end up
+    // decreasing height of this subtree.
+    mpc_select(tio, yield, bal_upd, IC2, bal_upd, s0);
+
+    // IC3:
+    // To set balance in this case we need to know if gcs.dpc child exists
+    // and similarly if gcs.ndpc child exitst.
+    // if(gcs.ndpc child exists): cs.bal_ndpc = 1
+    // if(gcs.dpc child exists): p.bal_dpc = 1
+    RegBS gcs_dpc_exists, gcs_ndpc_exists;
+    RegXS gcs_l = getAVLLeftPtr(gcs_node.pointers);
+    RegXS gcs_r = getAVLRightPtr(gcs_node.pointers);
+    RegBS gcs_bal_l = getLeftBal(gcs_node.pointers);
+    RegBS gcs_bal_r = getRightBal(gcs_node.pointers);
+    RegXS gcs_dpc, gcs_ndpc;
+    mpc_select(tio, yield, gcs_dpc, c_prime, gcs_l, gcs_r);
+    mpc_select(tio, yield, gcs_ndpc, not_c_prime, gcs_l, gcs_r);
+
+    CDPF cdpf = tio.cdpf(yield);
+    gcs_dpc_exists = cdpf.is_zero(tio, yield, gcs_dpc, tio.aes_ops());
+    gcs_ndpc_exists = cdpf.is_zero(tio, yield, gcs_ndpc, tio.aes_ops());
+    cs_bal_ndpc^=IC3;
+    RegBS IC3_ndpc_l, IC3_ndpc_r, IC3_dpc_l, IC3_dpc_r;
+    mpc_and(tio, yield, IC3_ndpc_l, IC3, c_prime);
+    mpc_and(tio, yield, IC3_ndpc_r, IC3, not_c_prime);
+    mpc_and(tio, yield, IC3_dpc_l, IC3, not_c_prime);
+    mpc_and(tio, yield, IC3_dpc_r, IC3, c_prime);
+    RegBS f0, f1, f2, f3;
+    mpc_and(tio, yield, f0, IC3_dpc_l, gcs_dpc_exists);
+    mpc_and(tio, yield, f1, IC3_dpc_r, gcs_dpc_exists);
+    mpc_and(tio, yield, f2, IC3_ndpc_l, gcs_ndpc_exists);
+    mpc_and(tio, yield, f3, IC3_ndpc_r, gcs_ndpc_exists);
+
+    mpc_select(tio, yield, new_p_bal_l, f0, new_p_bal_l, IC3);
+    mpc_select(tio, yield, new_p_bal_r, f1, new_p_bal_r, IC3);
+    mpc_select(tio, yield, cs_bal_l, f2, cs_bal_l, IC3);
+    mpc_select(tio, yield, cs_bal_r, f3, cs_bal_r, IC3);
+    // In IC3 gcs.bal = 0 0
+    mpc_select(tio, yield, gcs_bal_l, IC3, gcs_bal_l, s0);
+    mpc_select(tio, yield, gcs_bal_r, IC3, gcs_bal_r, s0);
+
+    // Write back <cs_bal_dpc, cs_bal_ndpc> and <gcs_bal_l, gcs_bal_r>
+    setLeftBal(gcs_node.pointers, gcs_bal_l);
+    setRightBal(gcs_node.pointers, gcs_bal_r);
+    setLeftBal(cs_node.pointers, cs_bal_l);
+    setRightBal(cs_node.pointers, cs_bal_r);
+
+    A[cs_ptr].NODE_POINTERS = cs_node.pointers;
+    A[gcs_ptr].NODE_POINTERS = gcs_node.pointers;
+
+    // Write back updated pointers correctly accounting for rotations
+    setLeftBal(nodeptrs, new_p_bal_l);
+    setRightBal(nodeptrs, new_p_bal_r);
+    A[ptr].NODE_POINTERS = nodeptrs;
+}
+
 
 std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
       Duoram<Node>::Flat &A, RegBS found, RegBS find_successor, int TTL,
@@ -711,8 +942,6 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
         //Reconstruct and return found
         bool success = reconstruct_RegBS(tio, yield, found);
         RegBS zero;
-        if(player0)
-          ret_struct.F_r^=1;
         return {success, zero};
     } else {
         Node node = A[ptr];
@@ -808,38 +1037,13 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
         if(!key_found) {
           return {0, s0};
         }
-        /* F_rs: Flag for updating the correct child pointer of this node
-           This happens if F_r is set in ret_struct. F_r indicates if we need
-           to update a child pointer at this level by skipping the current
-           child in the direction of traversal. We do this in two cases:
-             i) F_d & (!F_2) : If we delete here, and this node does not have
-                2 children (;i.e., we are not in the finding successor case)
-            ii) F_ns: Found the successor (no more left children while
-                traversing to find successor)
-           In cases i and ii we skip the next node, and make the current node
-           point to the node foundter the next node.
-
-           iii) We did rotation(s) at the lower level, changing the child in
-                that position. So we update it to the correct node in that
-                position now.
-           Whether skip happens or just update happens is handled by how
-           ret_struct.ret_ptr is set.
-        */
-
-        RegBS F_rr; // Flag to resolve F_r by updating correct child ptr
-        mpc_and(tio, yield, F_rr, c_prime, ret_struct.F_r);
-        mpc_select(tio, yield, right, F_rr, right, ret_struct.ret_ptr);
-        if(player0)
-            c_prime^=1;
-        mpc_and(tio, yield, F_rr, c_prime, ret_struct.F_r);
-        mpc_select(tio, yield, left, F_rr, left, ret_struct.ret_ptr);
-        if(player0)
-            c_prime^=1;
 
+        updateChildPointers(tio, yield, left, right, c_prime, ret_struct);
         setAVLLeftPtr(node.pointers, left);
         setAVLRightPtr(node.pointers, right);
         // Delay storing pointers back until balance updates are done as well.
-        // Since we resolved the F_r flag returned, we set it back to 0.
+        // Since we resolved the F_r flag returned with updateChildPointers(),
+        // we set it back to 0.
         ret_struct.F_r = s0;
 
         RegBS p_bal_l, p_bal_r;
@@ -872,217 +1076,33 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
 
         // F_ri: subflag for F_r. F_ri = returned flag set to 1 from imbalance fix.
         RegBS F_ri;
-        // Perform rotations if imbalance (else dummy rotations)
-        /*
-           For capturing both the symmetric L and R cases of rotations, we'll capture directions with
-           dpc  = dir_pc = direction from parent to child, and
-           ndpc = not(dir_pc)
-           When we travelled down the stack, we went from p->c. But in deletions to handle any imbalance
-           we look at c's sibling cs (child's sibling). And the rotation is between p and cs if there
-           was an imbalance at p, and perhaps even cs and it's child (the child in dir_pc, as that's the
-           only case that results in a double rotation when deleting).
-
-           In case of an imbalance we have to always rotate p->cs link. (L or R case)
-           If cs.bal_(dir_pc), then we have a double rotation (LR or RL) case.
-           In such cases, first rotate cs->gcs link, and then p->cs link. gcs = grandchild on cs path
-
-           Layout: In the R (or LR) case:
-
-                 p
-               /   \
-              cs    c
-             /  \
-            a   gcs
-                /  \
-               x    y
-
-           - One of x or y must exist for it to be an LR case,
-             since then cs.bal_(dir_pc) = cs.bal_r = 1
-
-           Layout: In the L (or RL) case:
-
-                 p
-               /   \
-              c     cs
-                   /  \
-                 gcs   a
-                /   \
-               x     y
-
-           - One of x or y must exist for it to be an RL case,
-             since then cs.bal_(dir_pc) = cs.bal_l = 1
-
-           (Note: if double rotation case, in the second rotation cs is actually gcs,
-            since the the first rotation swaps their positions)
-        */
-
-        Node cs_node = A[cs_ptr];
-        //dirpc = dir_pc = dpc = c_prime
-        RegBS cs_bal_l, cs_bal_r, cs_bal_dpc, cs_bal_ndpc, F_dr, not_c_prime;
-        RegXS gcs_ptr, cs_left, cs_right, cs_dpc, cs_ndpc, null;
-        // child's sibling node's balances in dir_pc (dpc), and not_dir_pc (ndpc)
-        cs_bal_l = getLeftBal(cs_node.pointers);
-        cs_bal_r = getRightBal(cs_node.pointers);
-        cs_left = getAVLLeftPtr(cs_node.pointers);
-        cs_right = getAVLRightPtr(cs_node.pointers);
-        mpc_select(tio, yield, cs_bal_dpc, c_prime, cs_bal_l, cs_bal_r);
-        mpc_select(tio, yield, cs_bal_ndpc, c_prime, cs_bal_r, cs_bal_l);
-        mpc_select(tio, yield, cs_dpc, c_prime, cs_left, cs_right);
-        mpc_select(tio, yield, cs_ndpc, c_prime, cs_right, cs_left);
-
-        // We need to double rotate (LR or RL case) if cs_bal_dpc is 1
-        mpc_and(tio, yield, F_dr, imb, cs_bal_dpc);
-        mpc_select(tio, yield, gcs_ptr, cs_bal_dpc, cs_ndpc, cs_dpc, AVL_PTR_SIZE);
-        Node gcs_node = A[gcs_ptr];
-
-        not_c_prime = c_prime;
-        if(player0) {
-          not_c_prime^=1;
-        }
-        // First rotation: cs->gcs link
-        rotate(tio, yield, node.pointers, cs_ptr, cs_node.pointers, gcs_ptr,
-            gcs_node.pointers, not_c_prime, c_prime, F_dr, s0);
-
-        // If F_dr, we did first rotation. Then cs and gcs need to swap before the second rotate.
-        RegXS new_cs_pointers, new_cs, new_ptr;
-        mpc_select(tio, yield, new_cs_pointers, F_dr, cs_node.pointers, gcs_node.pointers);
-        mpc_select(tio, yield, new_cs, F_dr, cs_ptr, gcs_ptr, AVL_PTR_SIZE);
-
-        // Second rotation: p->cs link
-        // Since we don't have access to gp node here we just send a null and s0
-        // for gp_pointers and dir_gpp. Instead this pointer fix is handled by F_r
-        // and ret_struct.ret_ptr.
-        rotate(tio, yield, null, ptr, node.pointers, new_cs,
-            new_cs_pointers, s0, not_c_prime, imb, s1);
-
-        /*
-        size_t rec_p_left_1, rec_p_right_1;
-        bool rec_flag_imb, rec_flag_dr;
-        rec_flag_imb = reconstruct_RegBS(tio, yield, imb);
-        rec_flag_dr = reconstruct_RegBS(tio, yield, F_dr);
-        rec_p_left_1 = reconstruct_RegXS(tio, yield, getAVLLeftPtr(node.pointers));
-        rec_p_right_1 = reconstruct_RegXS(tio, yield, getAVLRightPtr(node.pointers));
-        printf("flag_imb = %d, flag_dr = %d\n", rec_flag_imb, rec_flag_dr);
-        printf("parent_ptrs (foundter rotations): left = %lu, right = %lu\n", rec_p_left_1, rec_p_right_1);
-        */
-
-        // If imb (we do some rotation), then update F_r, and ret_ptr, to
-        // fix the gp->p link (The F_r clauses later, and this are mutually
-        // exclusive events. They will never trigger together.)
-        mpc_select(tio, yield, new_ptr, F_dr, cs_ptr, gcs_ptr);
-        mpc_select(tio, yield, F_ri, imb, s0, s1);
-        mpc_select(tio, yield, ret_struct.ret_ptr, imb, ret_struct.ret_ptr, new_ptr);
-
-        // Write back new_cs_pointers correctly to (cs_node/gcs_node).pointers
-        // and then balance the nodes
-        mpc_select(tio, yield, cs_node.pointers, F_dr, new_cs_pointers, cs_node.pointers);
-        mpc_select(tio, yield, gcs_node.pointers, F_dr, gcs_node.pointers, new_cs_pointers);
-
-        /*
-           Update balances based on imbalance and type of rotations that happen.
-           In the case of an imbalance, updateBalance() sets bal_l and bal_r of p to 0.
-
-        */
-        RegBS IC1, IC2, IC3; // Imbalance Case 1, 2 or 3
-        // IC1 = Single rotation (L/R). L/R = dpc
-        mpc_and(tio, yield, IC1, imb, cs_bal_ndpc);
-        // IC3 = Double rotation (LR/RL). 1st rotate direction = ndpc, 2nd direction = dpc
-        mpc_and(tio, yield, IC3, imb, cs_bal_dpc);
-        // IC2 = Single rotation (L/R).
-        IC2 = IC1 ^ IC3;
-        if(player0) {
-          IC2^=1;
-        }
-        mpc_and(tio, yield, IC2, imb, IC2);
-
-        /*
-        bool rec_IC1, rec_IC2, rec_IC3;
-        rec_IC1 = reconstruct_RegBS(tio, yield, IC1);
-        rec_IC2 = reconstruct_RegBS(tio, yield, IC2);
-        rec_IC3 = reconstruct_RegBS(tio, yield, IC3);
-        printf("rec_IC1 = %d, rec_IC2 = %d, rec_IC3 = %d\n", rec_IC1, rec_IC2, rec_IC3);
-        */
-
-        // IC1, IC2, IC3: CS.bal = 0 0
-        mpc_select(tio, yield, cs_bal_dpc, imb, cs_bal_dpc, s0);
-        mpc_select(tio, yield, cs_bal_ndpc, imb, cs_bal_ndpc, s0);
-        mpc_select(tio, yield, cs_bal_r, c_prime, cs_bal_ndpc, cs_bal_dpc);
-        mpc_select(tio, yield, cs_bal_l, c_prime, cs_bal_dpc, cs_bal_ndpc);
-
-        // IC2: p.bal_ndpc = 1, cs.bal_dpc = 1
-        // (IC2 & not_c_prime)
-        cs_bal_dpc^=IC2;
-        RegBS p_bal_dpc, p_bal_ndpc;
-        mpc_select(tio, yield, p_bal_ndpc, c_prime, new_p_bal_r, new_p_bal_l);
-        p_bal_ndpc^=IC2;
-        RegBS IC2_ndpc_l, IC2_ndpc_r, IC2_dpc_l, IC2_dpc_r;
-        mpc_and(tio, yield, IC2_ndpc_l, IC2, c_prime);
-        mpc_and(tio, yield, IC2_ndpc_r, IC2, not_c_prime);
-        mpc_and(tio, yield, IC2_dpc_l, IC2, not_c_prime);
-        mpc_and(tio, yield, IC2_dpc_r, IC2, c_prime);
-
-        mpc_select(tio, yield, new_p_bal_l, IC2_ndpc_l, new_p_bal_l, p_bal_ndpc);
-        mpc_select(tio, yield, new_p_bal_r, IC2_ndpc_r, new_p_bal_r, p_bal_ndpc);
-        mpc_select(tio, yield, cs_bal_l, IC2_dpc_l, cs_bal_l, cs_bal_dpc);
-        mpc_select(tio, yield, cs_bal_r, IC2_dpc_r, cs_bal_r, cs_bal_dpc);
-        // In the IC2 case bal_upd = 0 (The rotation doesn't end up
-        // decreasing height of this subtree.
-        mpc_select(tio, yield, bal_upd, IC2, bal_upd, s0);
-
-        // IC3:
-        // To set balance in this case we need to know if gcs.dpc child exists
-        // and similarly if gcs.ndpc child exitst.
-        // if(gcs.ndpc child exists): cs.bal_ndpc = 1
-        // if(gcs.dpc child exists): p.bal_dpc = 1
-        RegBS gcs_dpc_exists, gcs_ndpc_exists;
-        RegXS gcs_l = getAVLLeftPtr(gcs_node.pointers);
-        RegXS gcs_r = getAVLRightPtr(gcs_node.pointers);
-        RegBS gcs_bal_l = getLeftBal(gcs_node.pointers);
-        RegBS gcs_bal_r = getRightBal(gcs_node.pointers);
-        RegXS gcs_dpc, gcs_ndpc;
-        mpc_select(tio, yield, gcs_dpc, c_prime, gcs_l, gcs_r);
-        mpc_select(tio, yield, gcs_ndpc, not_c_prime, gcs_l, gcs_r);
-        gcs_dpc_exists = cdpf.is_zero(tio, yield, gcs_dpc, aes_ops);
-        gcs_ndpc_exists = cdpf.is_zero(tio, yield, gcs_ndpc, aes_ops);
-        cs_bal_ndpc^=IC3;
-        RegBS IC3_ndpc_l, IC3_ndpc_r, IC3_dpc_l, IC3_dpc_r;
-        mpc_and(tio, yield, IC3_ndpc_l, IC3, c_prime);
-        mpc_and(tio, yield, IC3_ndpc_r, IC3, not_c_prime);
-        mpc_and(tio, yield, IC3_dpc_l, IC3, not_c_prime);
-        mpc_and(tio, yield, IC3_dpc_r, IC3, c_prime);
-        RegBS f0, f1, f2, f3;
-        mpc_and(tio, yield, f0, IC3_dpc_l, gcs_dpc_exists);
-        mpc_and(tio, yield, f1, IC3_dpc_r, gcs_dpc_exists);
-        mpc_and(tio, yield, f2, IC3_ndpc_l, gcs_ndpc_exists);
-        mpc_and(tio, yield, f3, IC3_ndpc_r, gcs_ndpc_exists);
-
-        mpc_select(tio, yield, new_p_bal_l, f0, new_p_bal_l, IC3);
-        mpc_select(tio, yield, new_p_bal_r, f1, new_p_bal_r, IC3);
-        mpc_select(tio, yield, cs_bal_l, f2, cs_bal_l, IC3);
-        mpc_select(tio, yield, cs_bal_r, f3, cs_bal_r, IC3);
-        // In IC3 gcs.bal = 0 0
-        mpc_select(tio, yield, gcs_bal_l, IC3, gcs_bal_l, s0);
-        mpc_select(tio, yield, gcs_bal_r, IC3, gcs_bal_r, s0);
-
-        // Write back <cs_bal_dpc, cs_bal_ndpc> and <gcs_bal_l, gcs_bal_r>
-        setLeftBal(gcs_node.pointers, gcs_bal_l);
-        setRightBal(gcs_node.pointers, gcs_bal_r);
-        setLeftBal(cs_node.pointers, cs_bal_l);
-        setRightBal(cs_node.pointers, cs_bal_r);
-
-        A[cs_ptr].NODE_POINTERS = cs_node.pointers;
-        A[gcs_ptr].NODE_POINTERS = gcs_node.pointers;
-
-        // Write back updated pointers correctly accounting for rotations
-        setLeftBal(node.pointers, new_p_bal_l);
-        setRightBal(node.pointers, new_p_bal_r);
-        A[ptr].NODE_POINTERS = node.pointers;
+        fixImbalance(tio, yield, A, ptr, node.pointers, new_p_bal_l, new_p_bal_r, bal_upd, 
+              c_prime, cs_ptr, imb, F_ri, ret_struct);
  
         // Update the return structure
         // F_dh = Delete Here flag,
         // F_sf = successor found (no more left children while trying to find successor)
         // F_rs = subflag for F_r. F_rs = flag for F_r set to 1 from handling a skip fix
         // (deleting a node with single child, or found successor cases)
+
+        /* F_rs: Flag for updating the correct child pointer of this node
+           This happens if F_r is set in ret_struct. F_r indicates if we need
+           to update a child pointer at this level by skipping the current
+           child in the direction of traversal. We do this in two cases:
+             i) F_d & (!F_2) : If we delete here, and this node does not have
+                2 children (;i.e., we are not in the finding successor case)
+            ii) F_ns: Found the successor (no more left children while
+                traversing to find successor)
+           In cases i and ii we skip the next node, and make the current node
+           point to the node after the next node.
+
+           iii) We did rotation(s) at the lower level, changing the child in
+                that position. So we update it to the correct node in that
+                position now.
+           Whether skip happens or just update happens is handled by how
+           ret_struct.ret_ptr is set.
+        */
+
         RegBS F_dh, F_sf, F_rs;
         mpc_or(tio, yield, ret_struct.F_ss, ret_struct.F_ss, F_c2);
         if(player0)

+ 7 - 0
avl.hpp

@@ -137,6 +137,13 @@ class AVL {
     std::tuple<RegBS, RegBS, RegBS, RegBS> updateBalanceIns(MPCTIO &tio, yield_t &yield,
         RegBS bal_l, RegBS bal_r, RegBS bal_upd, RegBS child_dir);
 
+    void updateChildPointers(MPCTIO &tio, yield_t &yield, RegXS &left, RegXS &right,
+          RegBS c_prime, avl_del_return ret_struct);
+
+    void fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A, RegXS ptr, 
+        RegXS nodeptrs, RegBS p_bal_l, RegBS p_bal_r, RegBS &bal_upd, RegBS c_prime, 
+        RegXS cs_ptr, RegBS imb, RegBS &F_ri, avl_del_return &ret_struct);
+
     std::tuple<bool, RegBS> del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
         Duoram<Node>::Flat &A, RegBS F_af, RegBS F_fs, int TTL,
         avl_del_return &ret_struct);