Browse Source

Optimization 1: Related (Read + Write) operations updated to use OblivIndex to perform (Read + Update) with the same DPF

sshsshy 1 year ago
parent
commit
bf65e91701
2 changed files with 64 additions and 42 deletions
  1. 56 41
      avl.cpp
  2. 8 1
      avl.hpp

+ 56 - 41
avl.cpp

@@ -360,7 +360,11 @@ std::tuple<RegBS, RegBS, RegXS, RegBS> AVL::insert(MPCTIO &tio, yield_t &yield,
     }
 
     RegBS isReal = isDummy ^ (tio.player());
-    Node cnode = A[ptr];
+    
+    typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, MAX_DEPTH);
+    Node cnode = A[oidx];
+    RegXS old_pointers = cnode.pointers;
+
     // Compare key
     auto [lteq, gt] = compare_keys(tio, yield, cnode.key, insert_key);
 
@@ -443,7 +447,7 @@ std::tuple<RegBS, RegBS, RegXS, RegBS> AVL::insert(MPCTIO &tio, yield_t &yield,
     mpc_select(tio, yield, right, F_ir, right, ins_addr); 
     setAVLLeftPtr(cnode.pointers, left);
     setAVLRightPtr(cnode.pointers, right);
-    A[ptr].NODE_POINTERS = cnode.pointers;
+    A[oidx].NODE_POINTERS+=(cnode.pointers - old_pointers);
 
     // s0 = shares of 0
     RegBS s0;
@@ -504,27 +508,14 @@ void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
               ret_dir_pc, ret_dir_cn);
         */
 
-        // Perform actual insertion
-        /*
-        RegXS ins_pointers = A[ret.i_node].NODE_POINTERS;
-        RegXS left_ptr = getAVLLeftPtr(ins_pointers);
-        RegXS right_ptr = getAVLRightPtr(ins_pointers);
-        mpc_select(tio, yield, right_ptr, ret.dir_i, right_ptr, insert_address, AVL_PTR_SIZE);
-        // ret.dir_i -> !(ret.dir_i)
-        if(player0) {
-            ret.dir_i^=1;
-        }
-        mpc_select(tio, yield, left_ptr, ret.dir_i, left_ptr, insert_address, AVL_PTR_SIZE);
-        // We never use ret.dir_i again, so don't bother reverting the negation above.
-        setAVLLeftPtr(ins_pointers, left_ptr);
-        setAVLRightPtr(ins_pointers, right_ptr);
-        A[ret.i_node].NODE_POINTERS = ins_pointers;
-        */
+        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gp(tio, yield, ret.gp_node, TTL+1);
+        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_p(tio, yield, ret.p_node, TTL+1);
+        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_c(tio, yield, ret.c_node, TTL+1);
 
         // Perform balance procedure
-        RegXS gp_pointers = A[ret.gp_node].NODE_POINTERS;
-        RegXS parent_pointers = A[ret.p_node].NODE_POINTERS;
-        RegXS child_pointers = A[ret.c_node].NODE_POINTERS;
+        RegXS gp_pointers = A[oidx_gp].NODE_POINTERS;
+        RegXS parent_pointers = A[oidx_p].NODE_POINTERS;
+        RegXS child_pointers = A[oidx_c].NODE_POINTERS;
         // n_node (child's next node)
         RegXS child_left = getAVLLeftPtr(child_pointers);
         RegXS child_right = getAVLRightPtr(child_pointers);
@@ -539,7 +530,15 @@ void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
         if(player0) {
             ret.dir_cn^=1;
         }
-        RegXS n_pointers = A[n_node].NODE_POINTERS;
+        //RegXS n_pointers = A[n_node].NODE_POINTERS; 
+        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_n(tio, yield, n_node, TTL+1);
+        RegXS n_pointers = A[oidx_n].NODE_POINTERS;
+
+        RegXS old_gp_pointers, old_parent_pointers, old_child_pointers, old_n_pointers;
+        old_gp_pointers = gp_pointers;
+        old_parent_pointers = parent_pointers;
+        old_child_pointers = child_pointers;
+        old_n_pointers = n_pointers;
 
         // F_dr = (dir_pc != dir_cn) : i.e., double rotation case if
         // (parent->child) and (child->new_node) are not in the same direction
@@ -646,10 +645,10 @@ void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
         setRightBal(n_pointers, n_bal_r);
 
         // Write back update pointers and balances into gp, p, c, and n
-        A[ret.c_node].NODE_POINTERS = child_pointers;
-        A[ret.p_node].NODE_POINTERS = parent_pointers;
-        A[ret.gp_node].NODE_POINTERS = gp_pointers;
-        A[n_node].NODE_POINTERS = n_pointers;
+        A[oidx_n].NODE_POINTERS+=(n_pointers - old_n_pointers);
+        A[oidx_c].NODE_POINTERS+=(child_pointers - old_child_pointers); 
+        A[oidx_p].NODE_POINTERS+=(parent_pointers - old_parent_pointers); 
+        A[oidx_gp].NODE_POINTERS+=(gp_pointers - old_gp_pointers); 
 
         // Handle root pointer switch (if F_gp is true in the return from insert())
         // If F_gp and we did a double rotation: root <-- new node
@@ -764,14 +763,18 @@ void AVL::updateChildPointers(MPCTIO &tio, yield_t &yield, RegXS &left, RegXS &r
     since the the first rotation swaps their positions)
 */
 
-void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A, RegXS ptr, 
+void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A, 
+        Duoram<Node>::OblivIndex<RegXS, 1> oidx, RegXS oidx_oldptrs, RegXS ptr, 
         RegXS nodeptrs, RegBS new_p_bal_l, RegBS new_p_bal_r, RegBS &bal_upd, 
         RegBS c_prime, RegXS cs_ptr, RegBS imb, RegBS &F_ri, 
         avl_del_return &ret_struct) {
     bool player0 = tio.player()==0;
     RegBS s0, s1;
     s1.set(tio.player()==1);
-    Node cs_node = A[cs_ptr];
+
+    typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_cs(tio, yield, cs_ptr, MAX_DEPTH);
+    Node cs_node = A[oidx_cs];
+    RegXS old_cs_ptr = cs_node.pointers;
     //dirpc = dir_pc = dpc = c_prime
     RegBS cs_bal_l, cs_bal_r, cs_bal_dpc, cs_bal_ndpc, F_dr, not_c_prime;
     RegXS gcs_ptr, cs_left, cs_right, cs_dpc, cs_ndpc, null;
@@ -788,7 +791,9 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A, RegXS
     // We need to double rotate (LR or RL case) if cs_bal_dpc is 1
     mpc_and(tio, yield, F_dr, imb, cs_bal_dpc);
     mpc_select(tio, yield, gcs_ptr, cs_bal_dpc, cs_ndpc, cs_dpc, AVL_PTR_SIZE);
-    Node gcs_node = A[gcs_ptr];
+    typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gcs(tio, yield, gcs_ptr, MAX_DEPTH);
+    Node gcs_node = A[oidx_gcs];
+    RegXS old_gcs_ptr = gcs_node.pointers;
 
     not_c_prime = c_prime;
     if(player0) {
@@ -927,13 +932,16 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A, RegXS
     setLeftBal(cs_node.pointers, cs_bal_l);
     setRightBal(cs_node.pointers, cs_bal_r);
 
-    A[cs_ptr].NODE_POINTERS = cs_node.pointers;
-    A[gcs_ptr].NODE_POINTERS = gcs_node.pointers;
+    //A[cs_ptr].NODE_POINTERS = cs_node.pointers;
+    //A[gcs_ptr].NODE_POINTERS = gcs_node.pointers;
+    A[oidx_cs].NODE_POINTERS+= (cs_node.pointers - old_cs_ptr);
+    A[oidx_gcs].NODE_POINTERS+= (gcs_node.pointers - old_gcs_ptr);
 
     // Write back updated pointers correctly accounting for rotations
     setLeftBal(nodeptrs, new_p_bal_l);
     setRightBal(nodeptrs, new_p_bal_r);
-    A[ptr].NODE_POINTERS = nodeptrs;
+    //A[ptr].NODE_POINTERS = nodeptrs;
+    A[oidx].NODE_POINTERS +=(nodeptrs - oidx_oldptrs);
 }
 
 /* Update the return structure
@@ -1024,7 +1032,9 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
         RegBS zero;
         return {success, zero};
     } else {
-        Node node = A[ptr];
+        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, MAX_DEPTH);
+        Node node = A[oidx];
+        RegXS oldptrs = node.pointers;
         // Compare key
 
         CDPF cdpf = tio.cdpf(yield);
@@ -1156,7 +1166,7 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
 
         // F_ri: subflag for F_r. F_ri = returned flag set to 1 from imbalance fix.
         RegBS F_ri;
-        fixImbalance(tio, yield, A, ptr, node.pointers, new_p_bal_l, new_p_bal_r, bal_upd, 
+        fixImbalance(tio, yield, A, oidx, oldptrs, ptr, node.pointers, new_p_bal_l, new_p_bal_r, bal_upd, 
               c_prime, cs_ptr, imb, F_ri, ret_struct);
 
         updateRetStruct(tio, yield, ptr, F_2, F_c2, F_c4, lf, F_ri, found, bal_upd, ret_struct); 
@@ -1208,8 +1218,10 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
                 rec_del_node.key.ashare, rec_suc_node.key.ashare);
             printf("flag_s = %d\n", ret_struct.F_ss.bshare);
             */
-            Node del_node = A[ret_struct.N_d];
-            Node suc_node = A[ret_struct.N_s];
+            typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_nd(tio, yield, ret_struct.N_d, MAX_DEPTH);
+            typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_ns(tio, yield, ret_struct.N_s, MAX_DEPTH);
+            Node del_node = A[oidx_nd];
+            Node suc_node = A[oidx_ns];
             RegAS zero_as; RegXS zero_xs;
             // Update root if needed
             mpc_select(tio, yield, root, ret_struct.F_r, root, ret_struct.ret_ptr);
@@ -1222,13 +1234,15 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
                 rec_F_ss, rec_del_key, rec_suc_key);
             */            
 
+            RegXS old_del_value = del_node.value;
+            RegAS old_del_key = del_node.key;
             mpc_select(tio, yield, del_node.key, ret_struct.F_ss, del_node.key, suc_node.key);
             mpc_select(tio, yield, del_node.value, ret_struct.F_ss, del_node.value, suc_node.value);
-            A[ret_struct.N_d].NODE_KEY = del_node.key;
-            A[ret_struct.N_d].NODE_VALUE = del_node.value;
-            A[ret_struct.N_s].NODE_KEY = zero_as;
-            A[ret_struct.N_s].NODE_VALUE = zero_xs;
-
+            A[oidx_nd].NODE_KEY+=(del_node.key - old_del_key);
+            A[oidx_nd].NODE_VALUE+=(del_node.value - old_del_value);
+            A[oidx_ns].NODE_KEY+=(-suc_node.key);
+            A[oidx_ns].NODE_VALUE+=(suc_node.value);
+          
             RegXS empty_loc;
             mpc_select(tio, yield, empty_loc, ret_struct.F_ss, ret_struct.N_d, ret_struct.N_s);
             //Add deleted (empty) location into the empty_locations vector for reuse in next insert()
@@ -2764,5 +2778,6 @@ void avl_tests(MPCIO &mpcio,
             A.init();
             tree.init();
         }
+
     });
 }

+ 8 - 1
avl.hpp

@@ -123,6 +123,7 @@ class AVL {
 
     size_t num_items = 0;
     size_t MAX_SIZE;
+    int MAX_DEPTH;
 
     std::vector<RegXS> empty_locations;
 
@@ -140,7 +141,8 @@ class AVL {
     void updateChildPointers(MPCTIO &tio, yield_t &yield, RegXS &left, RegXS &right,
           RegBS c_prime, avl_del_return ret_struct);
 
-    void fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A, RegXS ptr, 
+    void fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A, 
+        Duoram<Node>::OblivIndex<RegXS,1> oidx, RegXS oidx_oldptrs, RegXS ptr, 
         RegXS nodeptrs, RegBS p_bal_l, RegBS p_bal_r, RegBS &bal_upd, RegBS c_prime, 
         RegXS cs_ptr, RegBS imb, RegBS &F_ri, avl_del_return &ret_struct);
 
@@ -160,6 +162,11 @@ class AVL {
   public:
     AVL(int num_players, size_t size) : oram(num_players, size) {
         this->MAX_SIZE = size;
+        MAX_DEPTH = 0;
+        while(size>0) {
+          MAX_DEPTH+=1;
+          size=size>>1;
+        }
     };
 
     void init(){