Browse Source

Parallelizing BST functions and compressing rounds.

sshsshy 10 months ago
parent
commit
18fa9b8c04
1 changed files with 170 additions and 93 deletions
  1. 170 93
      bst.cpp

+ 170 - 93
bst.cpp

@@ -244,9 +244,12 @@ std::tuple<RegXS, RegBS> BST::insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
     RegBS ret_direction;
     // If we insert here (F_i), return the ptr to this node as wptr
     // and update direction to the direction taken by compare_keys
-    mpc_select(tio, yield, ret_ptr, F_i, wptr, ptr);
-    //ret_direction = direction + F_i (direction - gt)
-    mpc_and(tio, yield, ret_direction, F_i, direction^gt);
+    run_coroutines(tio, [&tio, &ret_ptr, F_i, wptr, ptr](yield_t &yield)
+        { mpc_select(tio, yield, ret_ptr, F_i, wptr, ptr);},
+        [&tio, &ret_direction, F_i, direction, gt](yield_t &yield)
+        //ret_direction = direction + F_i (direction - gt)
+        { mpc_and(tio, yield, ret_direction, F_i, direction^gt);});
+
     ret_direction^=direction;  
 
     return {ret_ptr, ret_direction};
@@ -295,16 +298,21 @@ void BST::insert(MPCTIO &tio, yield_t &yield, const Node &node, Duoram<Node>::Fl
         RegXS left_ptr = extractLeftPtr(pointers);
         RegXS right_ptr = extractRightPtr(pointers);
         RegXS new_right_ptr, new_left_ptr;
-      
-        mpc_select(tio, yield, new_right_ptr, direction, right_ptr, insert_address);
+    
+        RegBS not_direction = direction; 
         if (player0) {
-            direction^=1;
+            not_direction^=1;
         }
-        mpc_select(tio, yield, new_left_ptr, direction, left_ptr, insert_address);
+    
+        run_coroutines(tio, 
+            [&tio, &new_right_ptr, direction, right_ptr, insert_address](yield_t &yield) 
+            { mpc_select(tio, yield, new_right_ptr, direction, right_ptr, insert_address);},
+            [&tio, &new_left_ptr, not_direction, left_ptr, insert_address](yield_t &yield)
+            { mpc_select(tio, yield, new_left_ptr, not_direction, left_ptr, insert_address);});
+
         setLeftPtr(pointers, new_left_ptr);
         setRightPtr(pointers, new_right_ptr);
         A[wptr].NODE_POINTERS = pointers;
-        //printf("num_items == %ld!\n", num_items);
     } 
 }
 
@@ -350,16 +358,31 @@ RegBS BST::lookup(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS key, Duoram<Node
     RegXS right = extractRightPtr(cnode.pointers);
     
     RegXS next_ptr;
-    mpc_select(tio, yield, next_ptr, gt, left, right, 32);
 
     RegBS F_found;
     // If we haven't found the key yet, and the lookup matches the current node key,
     // then we found the node to return
     RegBS isNotDummy = isDummy ^ (!tio.player());
-    mpc_and(tio, yield, F_found, isNotDummy, eq);
-    mpc_select(tio, yield, ret_node->key, eq, ret_node->key, cnode.key); 
-    mpc_select(tio, yield, ret_node->value, eq, ret_node->value, cnode.value); 
-    mpc_or(tio, yield, isDummy, isDummy, eq);
+
+    // Note: This logic returns the last matched key and value. 
+    // Returning the first one incurs an additional round.
+    std::vector<coro_t> coroutines;
+    coroutines.emplace_back(
+        [&tio, &next_ptr, gt, left, right](yield_t &yield)
+        { mpc_select(tio, yield, next_ptr, gt, left, right, 32);});
+    coroutines.emplace_back(
+        [&tio, &F_found, isNotDummy, eq](yield_t &yield)
+        { mpc_and(tio, yield, F_found, isNotDummy, eq);});
+    coroutines.emplace_back(
+        [&tio, &ret_node, eq, cnode](yield_t &yield)
+        { mpc_select(tio, yield, ret_node->key, eq, ret_node->key, cnode.key);});
+    coroutines.emplace_back(
+        [&tio, &ret_node, eq, cnode](yield_t &yield)
+        { mpc_select(tio, yield, ret_node->value, eq, ret_node->value, cnode.value);});
+    coroutines.emplace_back(
+        [&tio, &isDummy, eq](yield_t &yield)
+        { mpc_or(tio, yield, isDummy, isDummy, eq);});
+    run_coroutines(tio, coroutines);
 
     #ifdef BST_DEBUG 
         size_t ckey = mpc_reconstruct(tio, yield, cnode.key, 64);
@@ -412,12 +435,28 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
           ret_struct.F_r^=1;
         return success;
     } else {
-        Node node = A[ptr];
-        // Compare key
+        // s1: shares of 1 bit, s0: shares of 0 bit
+        RegBS s1, s0;
+        s1.set(tio.player()==1);
 
+        Node node = A[ptr];
+        RegXS left = extractLeftPtr(node.pointers);
+        RegXS right = extractRightPtr(node.pointers);
+        
         CDPF cdpf = tio.cdpf(yield);
-        auto [lt, eq, gt] = cdpf.compare(tio, yield, del_key - node.key, tio.aes_ops());
-      
+        size_t &aes_ops = tio.aes_ops();
+        RegBS l0, r0, lt, eq, gt;
+        // Check if left and right children are 0, and compute F_0, F_1, F_2
+        run_coroutines(tio,
+            [&tio, &l0, left, &aes_ops, &cdpf](yield_t &yield)
+            { l0 = cdpf.is_zero(tio, yield, left, aes_ops);},
+            [&tio, &r0, right, &aes_ops, &cdpf](yield_t &yield)
+            { r0 = cdpf.is_zero(tio, yield, right, aes_ops);},
+            [&tio, &lt, &eq, &gt, del_key, node, &cdpf](yield_t &yield)
+            { auto [a, b, c] = cdpf.compare(tio, yield, del_key - node.key, tio.aes_ops());
+              lt = a; eq = b; gt = c;});
+        // Compare Key
+
         /*
         // Reconstruct and Debug Block 0
         bool lt_rec, eq_rec, gt_rec;
@@ -436,26 +475,10 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
         // lf = local found. We found the key to delete in this level.
         RegBS lf = eq;
 
-        // Depending on [lteq, gt] select the next ptr/index as
-        // upper 32 bits of cnode.pointers if lteq
-        // lower 32 bits of cnode.pointers if gt 
-        RegXS left = extractLeftPtr(node.pointers);
-        RegXS right = extractRightPtr(node.pointers);
-        
-        CDPF dpf = tio.cdpf(yield);
-        size_t &aes_ops = tio.aes_ops();
-        // Check if left and right children are 0, and compute F_0, F_1, F_2
-        RegBS l0 = dpf.is_zero(tio, yield, left, aes_ops);
-        RegBS r0 = dpf.is_zero(tio, yield, right, aes_ops);
+
         RegBS F_0, F_1, F_2;
-        // F_0 = l0 & r0
-        mpc_and(tio, yield, F_0, l0, r0);
         // F_1 = l0 \xor r0
         F_1 = l0 ^ r0;
-        // F_2 = !(F_0 + F_1) (Only 1 of F_0, F_1, and F_2 can be true)
-        F_2 = F_0 ^ F_1;
-        if(player0)
-            F_2^=1;
 
         // We set next ptr based on c, but we need to handle three 
         // edge cases where we do not go by just the comparison result
@@ -464,10 +487,19 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
         // Case 1: found the node here (lf): we traverse down the lone child path.
         // or we are finding successor (fs) and there is no left child. 
         RegBS F_c1, F_c2, F_c3, F_c4;
+
         // Case 1: lf & F_1
-        mpc_and(tio, yield, F_c1, lf, F_1);
-        // Set c_prime for Case 1
-        mpc_select(tio, yield, c_prime, F_c1, c, l0);
+        run_coroutines(tio,
+            [&tio, &F_c1, lf, F_1](yield_t &yield)
+            { mpc_and(tio, yield, F_c1, lf, F_1);},
+            [&tio, &F_0, l0, r0] (yield_t &yield)
+            // F_0 = l0 & r0
+            { mpc_and(tio, yield, F_0, l0, r0);});
+        
+        // F_2 = !(F_0 + F_1) (Only 1 of F_0, F_1, and F_2 can be true)
+        F_2 = F_0 ^ F_1;
+        if(player0)
+            F_2^=1;
 
         /*
         // Reconstruct and Debug Block 1
@@ -479,15 +511,16 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
         printf("F_0 = %d, F_1 = %d, F_2 = %d, c_prime = %d\n", F_0_rec, F_1_rec, F_2_rec, c_prime_rec);
         */
 
-        // s1: shares of 1 bit, s0: shares of 0 bit
-        RegBS s1, s0;
-        s1.set(tio.player()==1);
-        // Case 2: found the node here (lf) and node has both children (F_2)
-        // In find successor case, so find inorder successor
-        // (Go right and then find leftmost child.)
-        mpc_and(tio, yield, F_c2, lf, F_2);
-        mpc_select(tio, yield, c_prime, F_c2, c_prime, s1);
-
+        run_coroutines(tio,
+            [&tio, &c_prime, F_c1, c, l0](yield_t &yield)
+            // Set c_prime for Case 1
+            { mpc_select(tio, yield, c_prime, F_c1, c, l0);},
+            [&tio, &F_c2, lf, F_2](yield_t &yield)
+            // Case 2: found the node here (lf) and node has both children (F_2)
+            // In find successor case, so find inorder successor
+            // (Go right and then find leftmost child.)
+            { mpc_and(tio, yield, F_c2, lf, F_2);});
+        
         /*
         // Reconstruct and Debug Block 2
         bool F_c2_rec, s1_rec;
@@ -497,28 +530,37 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
         printf("c_prime = %d, F_c2 = %d, s1 = %d\n", c_prime_rec, F_c2_rec, s1_rec);
         */
 
+        run_coroutines(tio,
+            [&tio, &c_prime, F_c2, s1](yield_t &yield)
+            { mpc_select(tio, yield, c_prime, F_c2, c_prime, s1);},
+            [&tio, &F_c3, fs, F_2](yield_t &yield)
+            // Case 3: finding successor (fs) and node has both children (F_2)
+            // Go left. 
+            { mpc_and(tio, yield, F_c3, fs, F_2);});
+
+        run_coroutines(tio,
+            [&tio, &c_prime, F_c3, s0](yield_t &yield)
+            { mpc_select(tio, yield, c_prime, F_c3, c_prime, s0);},
+            // Case 4: finding successor (fs) and node has no more left children (l0)
+            // This is the successor node then.
+            // Go right (since no more left)
+            [&tio, &F_c4, fs, l0] (yield_t &yield)
+            { mpc_and(tio, yield, F_c4, fs, l0);});
 
-        // Case 3: finding successor (fs) and node has both children (F_2)
-        // Go left. 
-        mpc_and(tio, yield, F_c3, fs, F_2);
-        mpc_select(tio, yield, c_prime, F_c3, c_prime, s0);
-
-        // Case 4: finding successor (fs) and node has no more left children (l0)
-        // This is the successor node then.
-        // Go right (since no more left) 
-        mpc_and(tio, yield, F_c4, fs, l0);
         mpc_select(tio, yield, c_prime, F_c4, c_prime, l0);
 
-        // Set next_ptr
-        mpc_select(tio, yield, next_ptr, c_prime, left, right, 32);
-        
         RegBS af_prime, fs_prime;
-        mpc_or(tio, yield, af_prime, af, lf);
-
-        // If in Case 2, set fs. We are now finding successor
-        mpc_or(tio, yield, fs_prime, fs, F_c2);
-
-        // If in Case 3. Successor found here already. Toggle fs off
+        run_coroutines(tio, 
+            [&tio, &next_ptr, c_prime, left, right](yield_t &yield)
+            // Set next_ptr
+            { mpc_select(tio, yield, next_ptr, c_prime, left, right, 32);},
+            [&tio, &af_prime, af, lf](yield_t &yield)
+            { mpc_or(tio, yield, af_prime, af, lf);},
+            [&tio, &fs_prime, fs, F_c2](yield_t &yield)
+            // If in Case 2, set fs. We are now finding successor
+            { mpc_or(tio, yield, fs_prime, fs, F_c2);});
+
+        // If in Case 4. Successor found here already. Toggle fs off
         fs_prime=fs_prime^F_c4;
 
         bool key_found = del(tio, yield, next_ptr, del_key, A, af_prime, fs_prime, TTL-1, ret_struct);
@@ -528,15 +570,22 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
           return 0;
 
         //printf("TTL = %d\n", TTL); 
-        RegBS F_rs;
+        RegBS F_rs_right, F_rs_left, not_c_prime=c_prime;
+        if(player0)
+            not_c_prime^=1; 
         // Flag here should be direction (c_prime) and F_r i.e. we need to swap return ptr in,
         // F_r needs to be returned in ret_struct
-        mpc_and(tio, yield, F_rs, c_prime, ret_struct.F_r);
-        mpc_select(tio, yield, right, F_rs, right, ret_struct.ret_ptr);
-        if(player0)
-            c_prime^=1; 
-        mpc_and(tio, yield, F_rs, c_prime, ret_struct.F_r);
-        mpc_select(tio, yield, left, F_rs, left, ret_struct.ret_ptr); 
+        run_coroutines(tio,
+            [&tio, &F_rs_right, c_prime, ret_struct](yield_t &yield)
+            { mpc_and(tio, yield, F_rs_right, c_prime, ret_struct.F_r);},
+            [&tio, &F_rs_left, not_c_prime, left, ret_struct](yield_t &yield)
+            { mpc_and(tio, yield, F_rs_left, not_c_prime, ret_struct.F_r);});
+
+        run_coroutines(tio,
+            [&tio, &right, F_rs_right, ret_struct](yield_t &yield)
+            { mpc_select(tio, yield, right, F_rs_right, right, ret_struct.ret_ptr);},
+            [&tio, &left, F_rs_left, ret_struct](yield_t &yield)
+            { mpc_select(tio, yield, left, F_rs_left, left, ret_struct.ret_ptr);});
 
         /*
         // Reconstruct and Debug Block 3
@@ -553,26 +602,37 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
         A[ptr].NODE_POINTERS = new_ptr;
 
         // Update the return structure 
-        RegBS F_nd, F_ns, F_r;
-        mpc_or(tio, yield, ret_struct.F_ss, ret_struct.F_ss, F_c2);
-        if(player0)
-            af^=1;
-        mpc_and(tio, yield, F_nd, lf, af);
+        RegBS F_nd, F_ns, F_r, not_af = af, not_F_2 = F_2;
+        if(player0) {
+            not_af^=1;
+            not_F_2^=1;
+        }
         // F_ns = fs & l0 
-        // Finding successor flag & no more left child
+        // Finding successor flag & no more left child = F_c4
         F_ns = F_c4;
+    
+        run_coroutines(tio, 
+            [&tio, &ret_struct, F_c2](yield_t &yield)
+            { mpc_or(tio, yield, ret_struct.F_ss, ret_struct.F_ss, F_c2);},
+            [&tio, &F_nd, lf, not_af](yield_t &yield)
+            { mpc_and(tio, yield, F_nd, lf, not_af);});
+            
+
         // F_r = F_d.(!F_2)
-        if(player0)
-            F_2^=1;
         // If we have to delete here, and it doesn't have two children we have to
         // update child pointer in parent with the returned pointer
-        mpc_and(tio, yield, F_r, F_nd, F_2);
+        mpc_and(tio, yield, F_r, F_nd, not_F_2);
         mpc_or(tio, yield, F_r, F_r, F_ns);
         ret_struct.F_r = F_r;
 
-        mpc_select(tio, yield, ret_struct.N_d, F_nd, ret_struct.N_d, ptr);
-        mpc_select(tio, yield, ret_struct.N_s, F_ns, ret_struct.N_s, ptr);
-        mpc_select(tio, yield, ret_struct.ret_ptr, F_r, ptr, ret_struct.ret_ptr);
+        run_coroutines(tio,
+            [&tio, &ret_struct, F_nd, ptr](yield_t &yield)
+            { mpc_select(tio, yield, ret_struct.N_d, F_nd, ret_struct.N_d, ptr);},
+            [&tio, &ret_struct, F_ns, ptr](yield_t &yield)
+            { mpc_select(tio, yield, ret_struct.N_s, F_ns, ret_struct.N_s, ptr);},
+            [&tio, &ret_struct, F_r, ptr](yield_t &yield)
+            { mpc_select(tio, yield, ret_struct.ret_ptr, F_r, ptr, ret_struct.ret_ptr);});
+
         //We don't empty the key and value of the node with del_key in the ORAM 
         return 1;
     }
@@ -616,16 +676,33 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
             Node del_node = A[ret_struct.N_d];
             Node suc_node = A[ret_struct.N_s];
             RegAS zero_as; RegXS zero_xs;
-            mpc_select(tio, yield, root, ret_struct.F_r, root, ret_struct.ret_ptr);
-            mpc_select(tio, yield, del_node.key, ret_struct.F_ss, del_node.key, suc_node.key);
-            mpc_select(tio, yield, del_node.value, ret_struct.F_ss, del_node.value, suc_node.value);           
-            A[ret_struct.N_d].NODE_KEY = del_node.key;
-            A[ret_struct.N_d].NODE_VALUE = del_node.value;
-            A[ret_struct.N_s].NODE_KEY = zero_as;
-            A[ret_struct.N_s].NODE_VALUE = zero_xs;
-
-            RegXS empty_loc;
-            mpc_select(tio, yield, empty_loc, ret_struct.F_ss, ret_struct.N_d, ret_struct.N_s);
+            RegXS empty_loc, temp_root = root;
+         
+            run_coroutines(tio, 
+                [&tio, &temp_root, ret_struct](yield_t &yield)
+                { mpc_select(tio, yield, temp_root, ret_struct.F_r, temp_root, ret_struct.ret_ptr);},
+                [&tio, &del_node, ret_struct, suc_node](yield_t &yield)
+                { mpc_select(tio, yield, del_node.key, ret_struct.F_ss, del_node.key, suc_node.key);},
+                [&tio, &del_node, ret_struct, suc_node](yield_t & yield)
+                { mpc_select(tio, yield, del_node.value, ret_struct.F_ss, del_node.value, suc_node.value);},
+                [&tio, &empty_loc, ret_struct](yield_t &yield)
+                { mpc_select(tio, yield, empty_loc, ret_struct.F_ss, ret_struct.N_d, ret_struct.N_s);});
+            root = temp_root;
+
+            run_coroutines(tio,
+                [&tio, &A, ret_struct, del_node](yield_t &yield)
+                {   auto acont = A.context(yield);
+                    acont[ret_struct.N_d].NODE_KEY = del_node.key;},
+                [&tio, &A, ret_struct, del_node](yield_t &yield)
+                {   auto acont = A.context(yield);
+                    acont[ret_struct.N_d].NODE_VALUE = del_node.value;},
+                [&tio, &A, ret_struct, zero_as](yield_t &yield)
+                {   auto acont = A.context(yield);
+                    acont[ret_struct.N_s].NODE_KEY = zero_as;},
+                [&tio, &A, ret_struct, zero_xs](yield_t &yield)
+                {   auto acont = A.context(yield);
+                    acont[ret_struct.N_s].NODE_VALUE = zero_xs;});
+
             //Add deleted (empty) location into the empty_locations vector for reuse in next insert()
             empty_locations.emplace_back(empty_loc);
         }