Browse Source

Optimizing (Parallelizing) MPC operations in the updated fixImbalanceDel

sshsshy 11 months ago
parent
commit
9d98318520
1 changed files with 111 additions and 192 deletions
  1. 111 192
      avl.cpp

+ 111 - 192
avl.cpp

@@ -329,22 +329,14 @@ std::tuple<RegBS, RegBS, RegBS, RegBS> AVL::updateBalanceDel(MPCTIO &tio, yield_
         [&tio, &bal_l, F_rs, s0](yield_t &yield) 
         { mpc_select(tio, yield, bal_l, F_rs, bal_l, s0);});
 
-    // if(bal_upd) and not imbalance bal_upd<-0
     /*
-    RegBS bu0;
-    not_imbalance = imbalance;
-    if(player0){
-        not_imbalance^=1;
-    } 
-    mpc_and(tio, yield, bu0, bal_upd, not_imbalance);
-    mpc_select(tio, yield, bal_upd, bu0, bal_upd, s0);
+       if(bal_upd) and this node:
+        (i) becomes balanced: the height has decreased, so continue propagating bal_upd.
+        (ii) becomes imbalanced: fixImbalance will update bal_upd correctly.
+        (iii) updates from balanced to left/right heavy: the height of this subtree has not changed,
+              so don't propagate bal_upd.
+        We handle (iii) below.
     */
-
-    // if(bal_upd) and this node turns balanced, the height has decreased, so continue propogating bal_upd.
-    // if(bal_upd) and node turns imbalanced, fixImbalance will update bal_upd correctly.
-    // if(bal_upd) and node moves out of balanced to left/right heavy, the height of this subtree has not changed,
-    //   so don't propogate bal_upd.
-    // if(bal_upd && bal_l ^ bal_r
     RegBS LR_heavy, bu0;
     LR_heavy = bal_l ^ bal_r;
     mpc_and(tio, yield, bu0, bal_upd, LR_heavy);
@@ -975,12 +967,13 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
         RegXS nodeptrs, RegBS new_p_bal_l, RegBS new_p_bal_r, RegBS &bal_upd, 
         RegBS c_prime, RegXS cs_ptr, RegBS imb, RegBS &F_ri, 
         avl_del_return &ret_struct) {
+
     bool player0 = tio.player()==0;
     RegBS s0, s1;
     s1.set(tio.player()==1);
 
-    RegXS old_cs_ptr;
-    Node cs_node;
+    RegXS old_cs_ptr, old_gcs_ptr;
+    Node cs_node, gcs_node;
     #ifdef OPT_ON
         nbits_t width = ceil(log2(cur_max_index+1));
         typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_cs(tio, yield, cs_ptr, width);
@@ -990,22 +983,60 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
         cs_node = A[cs_ptr];
     #endif
     //dirpc = dir_pc = dpc = c_prime
-    RegBS cs_bal_l, cs_bal_r, cs_bal_dpc, cs_bal_ndpc, F_dr, not_c_prime;
+    RegBS cs_bal_l, cs_bal_r, cs_bal_dpc, cs_bal_ndpc, p_bal_ndpc, p_bal_dpc;
+    RegBS F_dr, not_c_prime;
     RegXS gcs_ptr, cs_left, cs_right, cs_dpc, cs_ndpc, null;
+
+    not_c_prime = c_prime;
+    if(player0) {
+      not_c_prime^=1;
+    }
+
     // child's sibling node's balances in dir_pc (dpc), and not_dir_pc (ndpc)
     cs_bal_l = getLeftBal(cs_node.pointers);
     cs_bal_r = getRightBal(cs_node.pointers);
     cs_left = getAVLLeftPtr(cs_node.pointers);
     cs_right = getAVLRightPtr(cs_node.pointers);
 
-    run_coroutines(tio, [&tio, &cs_bal_dpc, c_prime, cs_bal_l, cs_bal_r](yield_t &yield)
-        { mpc_select(tio, yield, cs_bal_dpc, c_prime, cs_bal_l, cs_bal_r);},
-        [&tio, &cs_bal_ndpc, c_prime, cs_bal_r, cs_bal_l](yield_t &yield)
-        { mpc_select(tio, yield, cs_bal_ndpc, c_prime, cs_bal_r, cs_bal_l);},
-        [&tio, &cs_dpc, c_prime, cs_left, cs_right](yield_t &yield)
-        { mpc_select(tio, yield, cs_dpc, c_prime, cs_left, cs_right);},
-        [&tio, &cs_ndpc, c_prime, cs_right, cs_left](yield_t &yield)
-        { mpc_select(tio, yield, cs_ndpc, c_prime, cs_right, cs_left);});
+    std::vector<coro_t> coroutines;
+    RegBS gcs_balanced, gcs_bal_dpc, gcs_bal_ndpc;
+    RegBS ndpc_is_l, ndpc_is_r, dpc_is_l, dpc_is_r;
+
+    // First, compute flags indicating whether dpc = L/R, and similarly ndpc = L/R.
+    // If there is no imbalance, all of these flags are zero, so the final
+    // write-back of the post-imbalance-fix pointers and balances leaves
+    // them unchanged.
+    coroutines.emplace_back([&tio, &ndpc_is_l, c_prime, imb] (yield_t &yield)
+        { mpc_and(tio, yield, ndpc_is_l, imb, c_prime);});
+    coroutines.emplace_back([&tio, &ndpc_is_r, imb, not_c_prime](yield_t &yield)
+        { mpc_and(tio, yield, ndpc_is_r, imb, not_c_prime);});
+    coroutines.emplace_back([&tio, &dpc_is_l, imb, not_c_prime](yield_t &yield)
+        { mpc_and(tio, yield, dpc_is_l, imb, not_c_prime);});
+    coroutines.emplace_back([&tio, &dpc_is_r, imb, c_prime](yield_t &yield)
+        { mpc_and(tio, yield, dpc_is_r, imb, c_prime);});
+    run_coroutines(tio, coroutines); 
+    coroutines.clear();
+
+    coroutines.emplace_back(
+        [&tio, &cs_bal_dpc, dpc_is_r, cs_bal_l, cs_bal_r] (yield_t &yield)
+        { mpc_select(tio, yield, cs_bal_dpc, dpc_is_r, cs_bal_l, cs_bal_r);});
+    coroutines.emplace_back(
+        [&tio, &cs_bal_ndpc, ndpc_is_l, cs_bal_r, cs_bal_l](yield_t &yield)
+        { mpc_select(tio, yield, cs_bal_ndpc, ndpc_is_l, cs_bal_r, cs_bal_l);});
+    coroutines.emplace_back(
+        [&tio, &cs_dpc, dpc_is_r, cs_left, cs_right](yield_t &yield)
+        { mpc_select(tio, yield, cs_dpc, dpc_is_r, cs_left, cs_right);});
+    coroutines.emplace_back(
+        [&tio, &cs_ndpc, ndpc_is_l, cs_right, cs_left](yield_t &yield)
+        { mpc_select(tio, yield, cs_ndpc, ndpc_is_l, cs_right, cs_left);});
+    coroutines.emplace_back(
+        [&tio, &p_bal_ndpc, ndpc_is_r, new_p_bal_l, new_p_bal_r](yield_t &yield)
+        { mpc_select(tio, yield, p_bal_ndpc, ndpc_is_r, new_p_bal_l, new_p_bal_r);});
+    coroutines.emplace_back(
+        [&tio, &p_bal_dpc, dpc_is_r, new_p_bal_l, new_p_bal_r] (yield_t &yield)
+        { mpc_select(tio, yield, p_bal_dpc, dpc_is_r, new_p_bal_l, new_p_bal_r);});
+    run_coroutines(tio, coroutines);
+    coroutines.clear();
 
     // We need to double rotate (LR or RL case) if cs_bal_dpc is 1
     run_coroutines(tio, [&tio, &F_dr, imb, cs_bal_dpc] (yield_t &yield)
@@ -1013,8 +1044,6 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
         [&tio, &gcs_ptr, cs_bal_dpc, cs_ndpc, cs_dpc](yield_t &yield)
         { mpc_select(tio, yield, gcs_ptr, cs_bal_dpc, cs_ndpc, cs_dpc, AVL_PTR_SIZE);});
 
-    Node gcs_node;
-    RegXS old_gcs_ptr;
     #ifdef OPT_ON
         typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gcs(tio, yield, gcs_ptr, width);
         gcs_node = A[oidx_gcs];
@@ -1025,11 +1054,12 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
 
     RegBS gcs_bal_l = getLeftBal(gcs_node.pointers);
     RegBS gcs_bal_r = getRightBal(gcs_node.pointers);
+    
+    run_coroutines(tio, [&tio, &gcs_bal_dpc, dpc_is_r, gcs_bal_l, gcs_bal_r](yield_t &yield)
+        { mpc_select(tio, yield, gcs_bal_dpc, dpc_is_r, gcs_bal_l, gcs_bal_r);},
+        [&tio, &gcs_bal_ndpc, ndpc_is_r, gcs_bal_l, gcs_bal_r](yield_t &yield)
+        { mpc_select(tio, yield, gcs_bal_ndpc, ndpc_is_r, gcs_bal_l, gcs_bal_r);});
 
-    not_c_prime = c_prime;
-    if(player0) {
-      not_c_prime^=1;
-    }
     // First rotation: cs->gcs link
     rotate(tio, yield, nodeptrs, cs_ptr, cs_node.pointers, gcs_ptr,
         gcs_node.pointers, not_c_prime, c_prime, F_dr, s0);
@@ -1054,7 +1084,6 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
     // fix the gp->p link (There are F_r clauses later, but they are mutually
     // exclusive events. They will never trigger together.)
 
-    std::vector<coro_t> coroutines;
     coroutines.emplace_back([&tio, &F_ri, imb, s0, s1](yield_t &yield) {
         mpc_select(tio, yield, F_ri, imb, s0, s1);
     });
@@ -1094,40 +1123,6 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
             mpc_and(tio, yield, IC3, imb, cs_bal_dpc);
         });
 
-    RegBS p_bal_ndpc, p_bal_dpc;
-    RegBS ndpc_is_l, ndpc_is_r, dpc_is_l, dpc_is_r;
-    RegBS IC2_n_cprime;
-
-    // First flags to check dpc = L/R, and similarly ndpc = L/R
-    coroutines.emplace_back([&tio, &ndpc_is_l, c_prime, imb] (yield_t &yield)
-        { mpc_and(tio, yield, ndpc_is_l, imb, c_prime);});
-    coroutines.emplace_back([&tio, &ndpc_is_r, imb, not_c_prime](yield_t &yield)
-        { mpc_and(tio, yield, ndpc_is_r, imb, not_c_prime);});
-    coroutines.emplace_back([&tio, &dpc_is_l, imb, not_c_prime](yield_t &yield)
-        { mpc_and(tio, yield, dpc_is_l, imb, not_c_prime);});
-    coroutines.emplace_back([&tio, &dpc_is_r, imb, c_prime](yield_t &yield)
-        { mpc_and(tio, yield, dpc_is_r, imb, c_prime);});
-    run_coroutines(tio, coroutines); 
-    coroutines.clear();
-
-
-    mpc_and(tio, yield, IC2_n_cprime, IC2, c_prime);
-    mpc_select(tio, yield, p_bal_ndpc, ndpc_is_r, new_p_bal_l, new_p_bal_r);
-    mpc_select(tio, yield, p_bal_dpc, dpc_is_r, new_p_bal_l, new_p_bal_r);
-
-    /*
-    run_coroutines(tio, [&tio, &IC2, imb] (yield_t &yield)
-        { mpc_and(tio, yield, IC2, imb, IC2);},
-        [&tio, &cs_bal_dpc, imb, s0](yield_t &yield) 
-        { // IC1, IC2, IC3: CS.bal = 0 0
-          mpc_select(tio, yield, cs_bal_dpc, imb, cs_bal_dpc, s0);},
-        [&tio, &cs_bal_ndpc, c_prime, imb, s0](yield_t &yield) {
-        mpc_select(tio, yield, cs_bal_ndpc, imb, cs_bal_ndpc, s0);});
-    */
-    mpc_select(tio, yield, cs_bal_ndpc, IC1, cs_bal_ndpc, s0);
-    mpc_select(tio, yield, cs_bal_dpc, IC2, cs_bal_dpc, s1);
-    mpc_select(tio, yield, cs_bal_dpc, IC3, cs_bal_dpc, s0);
-    mpc_select(tio, yield, p_bal_ndpc, IC2, p_bal_ndpc, s1);
 
     /* IC3 has 3 subcases:
           IC3_S1: gcs_bal_dpc = 0, gcs_bal_ndpc = 1
@@ -1143,147 +1138,71 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
           IC3_S3: cs_dpc <- 0
                   gcs_bal stays same
     */
-    RegBS IC3_S1, IC3_S2, IC3_S3, gcs_balanced, gcs_bal_dpc, gcs_bal_ndpc;
-    mpc_select(tio, yield, gcs_bal_dpc, c_prime, gcs_bal_l, gcs_bal_r);
-    mpc_select(tio, yield, gcs_bal_ndpc, c_prime, gcs_bal_r, gcs_bal_l);
-    mpc_and(tio, yield, IC3_S1, IC3, gcs_bal_ndpc);
-    mpc_and(tio, yield, IC3_S2, IC3, gcs_bal_dpc);
+    RegBS IC3_S1, IC3_S2, IC3_S3;
     gcs_balanced = gcs_bal_dpc ^ gcs_bal_ndpc;
     if(player0) {
         gcs_balanced^=1;
     }
-    mpc_and(tio, yield, IC3_S3, IC3, gcs_balanced);
-
-    //mpc_select(tio, yield, gcs_bal_l, IC3, gcs_bal_l, s0);
-    //mpc_select(tio, yield, gcs_bal_r, IC3, gcs_bal_r, s0);
-    RegBS imb_n_cprime;
-    mpc_and(tio, yield, imb_n_cprime, imb, c_prime); 
-
-    
-    mpc_select(tio, yield, p_bal_dpc, IC3_S1, p_bal_dpc, s1);
-    mpc_select(tio, yield, cs_bal_dpc, IC3, cs_bal_dpc, s0);
 
-    mpc_select(tio, yield, cs_bal_ndpc, IC3_S2, cs_bal_ndpc, s1);
-    mpc_select(tio, yield, gcs_bal_dpc, IC3_S2, gcs_bal_dpc, s0);
+    // Updating balance bits of p, cs, and gcs.
+    // Parallel Ops 1
+    coroutines.emplace_back([&tio, &cs_bal_ndpc, IC1, s0](yield_t &yield)
+        { mpc_select(tio, yield, cs_bal_ndpc, IC1, cs_bal_ndpc, s0);});
+    coroutines.emplace_back([&tio, &cs_bal_dpc, IC2, s1](yield_t &yield)
+        { mpc_select(tio, yield, cs_bal_dpc, IC2, cs_bal_dpc, s1);});
+    coroutines.emplace_back([&tio, &p_bal_ndpc, IC2, s1](yield_t &yield)
+        { mpc_select(tio, yield, p_bal_ndpc, IC2, p_bal_ndpc, s1);});
+    coroutines.emplace_back([&tio, &IC3_S1, IC3, gcs_bal_ndpc](yield_t &yield)
+        { mpc_and(tio, yield, IC3_S1, IC3, gcs_bal_ndpc);});
+    coroutines.emplace_back([&tio, &IC3_S2, IC3, gcs_bal_dpc](yield_t &yield)
+        { mpc_and(tio, yield, IC3_S2, IC3, gcs_bal_dpc);});
+    coroutines.emplace_back([&tio, &IC3_S3, IC3, gcs_balanced](yield_t &yield)
+        { mpc_and(tio, yield, IC3_S3, IC3, gcs_balanced);});
+    // In the IC2 case bal_upd = 0 (the rotation doesn't end up
+    // decreasing the height of this subtree).
+    coroutines.emplace_back([&tio, &bal_upd, IC2, s0](yield_t &yield)
+        { mpc_select(tio, yield, bal_upd, IC2, bal_upd, s0);});
+    run_coroutines(tio, coroutines);
+    coroutines.clear();
 
-    // Updating balance bits of p and cs.
+    // Parallel Ops 2
+    coroutines.emplace_back([&tio, &cs_bal_dpc, IC3, s0](yield_t &yield)
+        { mpc_select(tio, yield, cs_bal_dpc, IC3, cs_bal_dpc, s0);});
+    coroutines.emplace_back([&tio, &p_bal_dpc, IC3_S1, s1](yield_t &yield)
+        { mpc_select(tio, yield, p_bal_dpc, IC3_S1, p_bal_dpc, s1);});
+    coroutines.emplace_back([&tio, &cs_bal_ndpc, IC3_S2, s1](yield_t &yield)
+        { mpc_select(tio, yield, cs_bal_ndpc, IC3_S2, cs_bal_ndpc, s1);});
+    coroutines.emplace_back([&tio, &gcs_bal_dpc, IC3_S2, s0](yield_t &yield)
+        { mpc_select(tio, yield, gcs_bal_dpc, IC3_S2, gcs_bal_dpc, s0);});
+    run_coroutines(tio, coroutines);
+    coroutines.clear();
 
-    // Write back updated balance bits
+    // Write back updated balance bits (Parallel batch 1)
     // Updating gcs_bal_l/r
-    run_coroutines(tio, [&tio, &gcs_bal_r, dpc_is_r, gcs_bal_dpc] (yield_t &yield) 
-        { mpc_select(tio, yield, gcs_bal_r, dpc_is_r, gcs_bal_r, gcs_bal_dpc);},
-        [&tio, &gcs_bal_l, dpc_is_l, gcs_bal_dpc](yield_t &yield)
+    coroutines.emplace_back([&tio, &gcs_bal_r, dpc_is_r, gcs_bal_dpc](yield_t &yield)
+        { mpc_select(tio, yield, gcs_bal_r, dpc_is_r, gcs_bal_r, gcs_bal_dpc);});
+    coroutines.emplace_back([&tio, &gcs_bal_l, dpc_is_l, gcs_bal_dpc](yield_t &yield)
         { mpc_select(tio, yield, gcs_bal_l, dpc_is_l, gcs_bal_l, gcs_bal_dpc);});
-
     // Updating cs_bal_l/r (cs_bal_dpc affected by IC3, cs_bal_ndpc affected by IC1,2)
-    run_coroutines(tio, [&tio, &cs_bal_r, dpc_is_r, cs_bal_dpc] (yield_t &yield) 
-        { mpc_select(tio, yield, cs_bal_r, dpc_is_r, cs_bal_r, cs_bal_dpc);},
-        [&tio, &cs_bal_l, dpc_is_l, cs_bal_dpc](yield_t &yield)
+    coroutines.emplace_back([&tio, &cs_bal_r, dpc_is_r, cs_bal_dpc](yield_t &yield)
+        { mpc_select(tio, yield, cs_bal_r, dpc_is_r, cs_bal_r, cs_bal_dpc);});
+    coroutines.emplace_back([&tio, &cs_bal_l, dpc_is_l, cs_bal_dpc](yield_t &yield)
         { mpc_select(tio, yield, cs_bal_l, dpc_is_l, cs_bal_l, cs_bal_dpc);});
-
-    run_coroutines(tio, [&tio, &cs_bal_r, ndpc_is_r, cs_bal_ndpc] (yield_t &yield) 
-        { mpc_select(tio, yield, cs_bal_r, ndpc_is_r, cs_bal_r, cs_bal_ndpc);},
-        [&tio, &cs_bal_l, ndpc_is_l, cs_bal_ndpc](yield_t &yield)
-        { mpc_select(tio, yield, cs_bal_l, ndpc_is_l, cs_bal_l, cs_bal_ndpc);});
-
     // Updating new_p_bal_l/r (p_bal_ndpc affected by IC2)
-    run_coroutines(tio, [&tio, &new_p_bal_r, ndpc_is_r, p_bal_ndpc] (yield_t &yield) 
-        { mpc_select(tio, yield, new_p_bal_r, ndpc_is_r, new_p_bal_r, p_bal_ndpc);},
-        [&tio, &new_p_bal_l, ndpc_is_l, p_bal_ndpc](yield_t &yield)
+    coroutines.emplace_back([&tio, &new_p_bal_r, ndpc_is_r, p_bal_ndpc] (yield_t &yield)
+        { mpc_select(tio, yield, new_p_bal_r, ndpc_is_r, new_p_bal_r, p_bal_ndpc);});
+    coroutines.emplace_back([&tio, &new_p_bal_l, ndpc_is_l, p_bal_ndpc](yield_t &yield)
         { mpc_select(tio, yield, new_p_bal_l, ndpc_is_l, new_p_bal_l, p_bal_ndpc);});
-
-
-    // IC2: p.bal_ndpc = 1, cs.bal_dpc = 1
-    // (IC2 & not_c_prime)
-
-    /*
-    cs_bal_dpc^=IC2;
-    p_bal_ndpc^=IC2;
-  
-    coroutines.emplace_back([&tio, &new_p_bal_l, IC2_ndpc_l, p_bal_ndpc](yield_t &yield)
-        { mpc_select(tio, yield, new_p_bal_l, IC2_ndpc_l, new_p_bal_l, p_bal_ndpc);});
-    coroutines.emplace_back([&tio, &new_p_bal_r, IC2_ndpc_r, p_bal_ndpc](yield_t &yield)
-        { mpc_select(tio, yield, new_p_bal_r, IC2_ndpc_r, new_p_bal_r, p_bal_ndpc);});
-    coroutines.emplace_back([&tio, &cs_bal_l, IC2_dpc_l, cs_bal_dpc](yield_t &yield)
-        { mpc_select(tio, yield, cs_bal_l, IC2_dpc_l, cs_bal_l, cs_bal_dpc);});
-    coroutines.emplace_back([&tio, &cs_bal_r, IC2_dpc_r, cs_bal_dpc](yield_t &yield)
-        { mpc_select(tio, yield, cs_bal_r, IC2_dpc_r, cs_bal_r, cs_bal_dpc);});
-    coroutines.emplace_back([&tio, &bal_upd, IC2, s0](yield_t &yield)
-        {
-        // In the IC2 case bal_upd = 0 (The rotation doesn't end up
-        // decreasing height of this subtree.
-        mpc_select(tio, yield, bal_upd, IC2, bal_upd, s0);});
     run_coroutines(tio, coroutines);
     coroutines.clear();
-    */
-
-    // In the IC2 case bal_upd = 0 (The rotation doesn't end up
-    // decreasing height of this subtree.
-    mpc_select(tio, yield, bal_upd, IC2, bal_upd, s0);
-
-    /*
-    // IC3:
-    // To set balance in this case we need to know if gcs.dpc child exists
-    // and similarly if gcs.ndpc child exitst.
-    // if(gcs.ndpc child exists): cs.bal_ndpc = 1
-    // if(gcs.dpc child exists): p.bal_dpc = 1
-    RegBS gcs_dpc_exists, gcs_ndpc_exists;
-    RegXS gcs_l = getAVLLeftPtr(gcs_node.pointers);
-    RegXS gcs_r = getAVLRightPtr(gcs_node.pointers);
-    RegBS gcs_bal_l = getLeftBal(gcs_node.pointers);
-    RegBS gcs_bal_r = getRightBal(gcs_node.pointers);
-    RegXS gcs_dpc, gcs_ndpc;
-
-    run_coroutines(tio, [&tio, &gcs_dpc, c_prime, gcs_l, gcs_r] (yield_t &yield)
-        { mpc_select(tio, yield, gcs_dpc, c_prime, gcs_l, gcs_r);},
-        [&tio, &gcs_ndpc, not_c_prime, gcs_l, gcs_r] (yield_t &yield)
-        { mpc_select(tio, yield, gcs_ndpc, not_c_prime, gcs_l, gcs_r);});
-
-    CDPF cdpf = tio.cdpf(yield);
-    run_coroutines(tio, [&tio, &gcs_dpc_exists, gcs_dpc, &cdpf](yield_t &yield) 
-        { gcs_dpc_exists = cdpf.is_zero(tio, yield, gcs_dpc, tio.aes_ops());},
-        [&tio, &gcs_ndpc_exists, gcs_ndpc, &cdpf](yield_t &yield)
-        { gcs_ndpc_exists = cdpf.is_zero(tio, yield, gcs_ndpc, tio.aes_ops());});
-
-    cs_bal_ndpc^=IC3;
-    RegBS IC3_ndpc_l, IC3_ndpc_r, IC3_dpc_l, IC3_dpc_r;
-    
-    run_coroutines(tio, [&tio, &IC3_ndpc_l, IC3, c_prime](yield_t &yield)
-        { mpc_and(tio, yield, IC3_ndpc_l, IC3, c_prime);},
-        [&tio, &IC3_ndpc_r, IC3, not_c_prime](yield_t &yield)
-        { mpc_and(tio, yield, IC3_ndpc_r, IC3, not_c_prime);},
-        [&tio, &IC3_dpc_l, IC3, not_c_prime](yield_t &yield)
-        { mpc_and(tio, yield, IC3_dpc_l, IC3, not_c_prime);},
-        [&tio, &IC3_dpc_r, IC3, c_prime](yield_t &yield)
-        { mpc_and(tio, yield, IC3_dpc_r, IC3, c_prime);});
-
-    RegBS f0, f1, f2, f3;
-    run_coroutines(tio, [&tio, &f0, IC3_dpc_l, gcs_dpc_exists] (yield_t &yield)
-        { mpc_and(tio, yield, f0, IC3_dpc_l, gcs_dpc_exists);},
-        [&tio, &f1, IC3_dpc_r, gcs_dpc_exists] (yield_t &yield)
-        { mpc_and(tio, yield, f1, IC3_dpc_r, gcs_dpc_exists);},
-        [&tio, &f2, IC3_ndpc_l, gcs_ndpc_exists] (yield_t &yield)
-        { mpc_and(tio, yield, f2, IC3_ndpc_l, gcs_ndpc_exists);},
-        [&tio, &f3, IC3_ndpc_r, gcs_ndpc_exists] (yield_t &yield) 
-        { mpc_and(tio, yield, f3, IC3_ndpc_r, gcs_ndpc_exists);});
-
-    
-    coroutines.emplace_back([&tio, &new_p_bal_l, f0, IC3](yield_t &yield) {
-        mpc_select(tio, yield, new_p_bal_l, f0, new_p_bal_l, IC3);});
-    coroutines.emplace_back([&tio, &new_p_bal_r, f1, IC3](yield_t &yield) {
-        mpc_select(tio, yield, new_p_bal_r, f1, new_p_bal_r, IC3);});
-    coroutines.emplace_back([&tio, &cs_bal_l, f2, IC3](yield_t &yield) {
-        mpc_select(tio, yield, cs_bal_l, f2, cs_bal_l, IC3);});
-    coroutines.emplace_back([&tio, &cs_bal_r, f3, IC3](yield_t &yield) {
-        mpc_select(tio, yield, cs_bal_r, f3, cs_bal_r, IC3);});
-    // In IC3 gcs.bal = 0 0
-    coroutines.emplace_back([&tio, &gcs_bal_l, IC3, s0](yield_t &yield) {
-        mpc_select(tio, yield, gcs_bal_l, IC3, gcs_bal_l, s0);});
-    coroutines.emplace_back([&tio, &gcs_bal_r, IC3, s0](yield_t &yield) {
-        mpc_select(tio, yield, gcs_bal_r, IC3, gcs_bal_r, s0);});
-    run_coroutines(tio, coroutines); 
-    */
-
+ 
+    // Write back updated balance bits (Parallel batch 2)
+    coroutines.emplace_back([&tio, &cs_bal_r, ndpc_is_r, cs_bal_ndpc] (yield_t &yield)
+        { mpc_select(tio, yield, cs_bal_r, ndpc_is_r, cs_bal_r, cs_bal_ndpc);});
+    coroutines.emplace_back([&tio, &cs_bal_l, ndpc_is_l, cs_bal_ndpc](yield_t &yield)
+        { mpc_select(tio, yield, cs_bal_l, ndpc_is_l, cs_bal_l, cs_bal_ndpc);});
+    run_coroutines(tio, coroutines);
+    coroutines.clear();
 
     // Write back <cs_bal_dpc, cs_bal_ndpc> and <gcs_bal_l, gcs_bal_r>
     setLeftBal(gcs_node.pointers, gcs_bal_l);