Browse Source

Parallelizing ORAM operations for OPT_ON

sshsshy 10 months ago
parent
commit
1670482517
1 changed files with 66 additions and 10 deletions
  1. 66 10
      avl.cpp

+ 66 - 10
avl.cpp

@@ -638,14 +638,33 @@ void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
 
         // Perform balance procedure
         RegXS gp_pointers, parent_pointers, child_pointers;
+        std::vector<coro_t> coroutines;
         #ifdef OPT_ON
             nbits_t width = ceil(log2(cur_max_index+1));
             typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gp(tio, yield, ret.gp_node, width);
             typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_p(tio, yield, ret.p_node, width);
             typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_c(tio, yield, ret.c_node, width);
+
+            coroutines.emplace_back( 
+                [&tio, &A, &oidx_gp, &gp_pointers](yield_t &yield) { 
+                  auto acont = A.context(yield);
+                  gp_pointers = acont[oidx_gp].NODE_POINTERS;});
+            coroutines.emplace_back(
+                [&tio, &A, &oidx_p, &parent_pointers](yield_t &yield) {
+                  auto acont = A.context(yield);
+                  parent_pointers = acont[oidx_p].NODE_POINTERS;});
+            coroutines.emplace_back(
+                [&tio, &A, &oidx_c, &child_pointers](yield_t &yield) {
+                  auto acont = A.context(yield);
+                  child_pointers = acont[oidx_c].NODE_POINTERS;});
+            run_coroutines(tio, coroutines);
+            coroutines.clear();
+
+            /*
             gp_pointers = A[oidx_gp].NODE_POINTERS;
             parent_pointers = A[oidx_p].NODE_POINTERS;
             child_pointers = A[oidx_c].NODE_POINTERS;
+            */
             /*
             size_t rec_gp_key = reconstruct_RegAS(tio, yield, A[oidx_gp].NODE_KEY);
             size_t rec_p_key = reconstruct_RegAS(tio, yield, A[oidx_p].NODE_KEY);
@@ -791,7 +810,6 @@ void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
             [&tio, &c_bal_dpc, p_c_update, n_bal_dpc] (yield_t &yield)
             {mpc_select(tio, yield, c_bal_dpc, p_c_update, c_bal_dpc, n_bal_dpc);});
 
-        std::vector<coro_t> coroutines;
         coroutines.emplace_back([&tio, &p_bal_r, ret, p_bal_ndpc] (yield_t &yield)
             {mpc_select(tio, yield, p_bal_r, ret.dir_pc, p_bal_ndpc, p_bal_r);});
         coroutines.emplace_back([&tio, &p_bal_l, ret, p_bal_ndpc] (yield_t &yield)
@@ -1277,9 +1295,20 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
 
     // Write back updated pointers correctly accounting for rotations
     #ifdef OPT_ON
-      A[oidx_cs].NODE_POINTERS+= (cs_node.pointers - old_cs_ptr);
-      A[oidx_gcs].NODE_POINTERS+= (gcs_node.pointers - old_gcs_ptr);
-      A[oidx].NODE_POINTERS +=(nodeptrs - oidx_oldptrs);
+      coroutines.emplace_back(
+          [&tio, &A, &oidx_cs, &cs_node, old_cs_ptr] (yield_t &yield) {
+              auto acont = A.context(yield);
+              (acont[oidx_cs].NODE_POINTERS)+= (cs_node.pointers - old_cs_ptr);});
+      coroutines.emplace_back(
+          [&tio, &A, &oidx_gcs, &gcs_node, old_gcs_ptr] (yield_t &yield) {
+              auto acont = A.context(yield);
+              (acont[oidx_gcs].NODE_POINTERS)+= (gcs_node.pointers - old_gcs_ptr);});
+      coroutines.emplace_back(
+          [&tio, &A, &oidx, nodeptrs, oidx_oldptrs] (yield_t &yield) {
+              auto acont = A.context(yield);
+              (acont[oidx].NODE_POINTERS)+=(nodeptrs - oidx_oldptrs);});
+      run_coroutines(tio, coroutines);
+      coroutines.clear();
     #else
       A[cs_ptr].NODE_POINTERS = cs_node.pointers;
       A[gcs_ptr].NODE_POINTERS = gcs_node.pointers;
@@ -1585,9 +1614,18 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
             nbits_t width = ceil(log2(cur_max_index+1));
             typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_nd(tio, yield, ret_struct.N_d, width);
             typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_ns(tio, yield, ret_struct.N_s, width);
+            std::vector<coro_t> coroutines;
             #ifdef OPT_ON
-                del_node = A[oidx_nd];
-                suc_node = A[oidx_ns];
+                coroutines.emplace_back( 
+                    [&tio, &A, &oidx_nd, &del_node](yield_t &yield) { 
+                      auto acont = A.context(yield);
+                      del_node = acont[oidx_nd];});
+                coroutines.emplace_back(
+                    [&tio, &A, &oidx_ns, &suc_node](yield_t &yield) {
+                      auto acont = A.context(yield);
+                      suc_node = acont[oidx_ns];});
+                run_coroutines(tio, coroutines);
+                coroutines.clear();
             #else
                 del_node = A[ret_struct.N_d];
                 suc_node = A[ret_struct.N_s];
@@ -1619,10 +1657,28 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
                 { mpc_select(tio, yield, empty_loc, ret_struct.F_ss, ret_struct.N_d, ret_struct.N_s);});
 
             #ifdef OPT_ON
-                A[oidx_nd].NODE_KEY+=(del_node.key - old_del_key);
-                A[oidx_nd].NODE_VALUE+=(del_node.value - old_del_value);
-                A[oidx_ns].NODE_KEY+=(-suc_node.key);
-                A[oidx_ns].NODE_VALUE+=(suc_node.value);
+                coroutines.emplace_back(
+                    [&tio, &A, &oidx_nd, &del_node, old_del_key] (yield_t &yield) {
+                        auto acont = A.context(yield);
+                        acont[oidx_nd].NODE_KEY+=(del_node.key - old_del_key);
+                    });
+                coroutines.emplace_back(
+                    [&tio, &A, &oidx_nd, &del_node, old_del_value] (yield_t &yield) {
+                        auto acont = A.context(yield);
+                        acont[oidx_nd].NODE_VALUE+=(del_node.value - old_del_value);
+                    });
+                coroutines.emplace_back(
+                    [&tio, &A, &oidx_ns, &suc_node] (yield_t &yield) {
+                        auto acont = A.context(yield);
+                        acont[oidx_ns].NODE_KEY+=(-suc_node.key);
+                    });
+                coroutines.emplace_back(
+                    [&tio, &A, &oidx_ns, &suc_node] (yield_t &yield) {
+                        auto acont = A.context(yield);
+                        acont[oidx_ns].NODE_VALUE+=(-suc_node.value);
+                    });
+                run_coroutines(tio, coroutines);
+                coroutines.clear();   
             #else
                 A[ret_struct.N_d].NODE_KEY = del_node.key;
                 A[ret_struct.N_d].NODE_VALUE = del_node.value;