Browse Source

Non-oblivious initiailization for AVL trees. Toggle optimizations on and off with OPT_ON in avl.hpp

sshsshy 1 year ago
parent
commit
e9cde7de60
2 changed files with 152 additions and 59 deletions
  1. 147 59
      avl.cpp
  2. 5 0
      avl.hpp

+ 147 - 59
avl.cpp

@@ -406,9 +406,13 @@ std::tuple<RegBS, RegBS, RegXS, RegBS> AVL::insert(MPCTIO &tio, yield_t &yield,
     }
 
     RegBS isReal = isDummy ^ (tio.player());
-    
-    typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, MAX_DEPTH);
-    Node cnode = A[oidx];
+    Node cnode;
+    #ifdef OPT_ON
+        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, MAX_DEPTH);
+        cnode = A[oidx];
+    #else
+        cnode = A[ptr];
+    #endif
     RegXS old_pointers = cnode.pointers;
 
     // Compare key
@@ -518,8 +522,11 @@ std::tuple<RegBS, RegBS, RegXS, RegBS> AVL::insert(MPCTIO &tio, yield_t &yield,
 
     setAVLLeftPtr(cnode.pointers, left);
     setAVLRightPtr(cnode.pointers, right);
-    A[oidx].NODE_POINTERS+=(cnode.pointers - old_pointers);
-
+    #ifdef OPT_ON
+        A[oidx].NODE_POINTERS+=(cnode.pointers - old_pointers);
+    #else
+        A[ptr].NODE_POINTERS = cnode.pointers;
+    #endif
     // s0 = shares of 0
     RegBS s0;
 
@@ -579,27 +586,40 @@ void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
               ret_dir_pc, ret_dir_cn);
         */
 
-        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gp(tio, yield, ret.gp_node, TTL+1);
-        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_p(tio, yield, ret.p_node, TTL+1);
-        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_c(tio, yield, ret.c_node, TTL+1);
-
         // Perform balance procedure
-        RegXS gp_pointers = A[oidx_gp].NODE_POINTERS;
-        RegXS parent_pointers = A[oidx_p].NODE_POINTERS;
-        RegXS child_pointers = A[oidx_c].NODE_POINTERS;
+        RegXS gp_pointers, parent_pointers, child_pointers;
+        #ifdef OPT_ON
+            typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gp(tio, yield, ret.gp_node, TTL+1);
+            typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_p(tio, yield, ret.p_node, TTL+1);
+            typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_c(tio, yield, ret.c_node, TTL+1); 
+            gp_pointers = A[oidx_gp].NODE_POINTERS;
+            parent_pointers = A[oidx_p].NODE_POINTERS;
+            child_pointers = A[oidx_c].NODE_POINTERS;
+        #else
+            RegXS gp_pointers = A[ret.gp_node].NODE_POINTERS;
+            RegXS parent_pointers = A[ret.p_node].NODE_POINTERS;
+            RegXS child_pointers = A[ret.c_node].NODE_POINTERS;
+        #endif
         // n_node (child's next node)
         RegXS child_left = getAVLLeftPtr(child_pointers);
         RegXS child_right = getAVLRightPtr(child_pointers);
-        RegXS n_node;
+        RegXS n_node, n_pointers;
         mpc_select(tio, yield, n_node, ret.dir_cn, child_left, child_right, AVL_PTR_SIZE);
-        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_n(tio, yield, n_node, TTL+1);
-        RegXS n_pointers = A[oidx_n].NODE_POINTERS;
+
+        #ifdef OPT_ON
+            typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_n(tio, yield, n_node, TTL+1);
+            n_pointers = A[oidx_n].NODE_POINTERS;
+        #else
+            n_pointers = A[n_node].NODE_POINTERS;  
+        #endif
 
         RegXS old_gp_pointers, old_parent_pointers, old_child_pointers, old_n_pointers;
-        old_gp_pointers = gp_pointers;
-        old_parent_pointers = parent_pointers;
-        old_child_pointers = child_pointers;
-        old_n_pointers = n_pointers;
+        #ifdef OPT_ON
+            old_gp_pointers = gp_pointers;
+            old_parent_pointers = parent_pointers;
+            old_child_pointers = child_pointers;
+            old_n_pointers = n_pointers;
+        #endif
 
         // F_dr = (dir_pc != dir_cn) : i.e., double rotation case if
         // (parent->child) and (child->new_node) are not in the same direction
@@ -728,10 +748,17 @@ void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
         setRightBal(n_pointers, n_bal_r);
 
         // Write back update pointers and balances into gp, p, c, and n
-        A[oidx_n].NODE_POINTERS+=(n_pointers - old_n_pointers);
-        A[oidx_c].NODE_POINTERS+=(child_pointers - old_child_pointers); 
-        A[oidx_p].NODE_POINTERS+=(parent_pointers - old_parent_pointers); 
-        A[oidx_gp].NODE_POINTERS+=(gp_pointers - old_gp_pointers); 
+        #ifdef OPT_ON
+            A[oidx_n].NODE_POINTERS+=(n_pointers - old_n_pointers);
+            A[oidx_c].NODE_POINTERS+=(child_pointers - old_child_pointers); 
+            A[oidx_p].NODE_POINTERS+=(parent_pointers - old_parent_pointers); 
+            A[oidx_gp].NODE_POINTERS+=(gp_pointers - old_gp_pointers); 
+        #else
+            A[ret.c_node].NODE_POINTERS = child_pointers;
+            A[ret.p_node].NODE_POINTERS = parent_pointers;
+            A[ret.gp_node].NODE_POINTERS = gp_pointers;
+            A[n_node].NODE_POINTERS = n_pointers;
+        #endif
 
         // Handle root pointer update (if F_ur is true)
         // If F_ur and we did a double rotation: root <-- new node
@@ -866,9 +893,15 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
     RegBS s0, s1;
     s1.set(tio.player()==1);
 
-    typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_cs(tio, yield, cs_ptr, MAX_DEPTH);
-    Node cs_node = A[oidx_cs];
-    RegXS old_cs_ptr = cs_node.pointers;
+    RegXS old_cs_ptr;
+    Node cs_node;
+    #ifdef OPT_ON
+        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_cs(tio, yield, cs_ptr, MAX_DEPTH);
+        cs_node = A[oidx_cs];
+        old_cs_ptr = cs_node.pointers;
+    #else
+        cs_node = A[cs_ptr];
+    #endif
     //dirpc = dir_pc = dpc = c_prime
     RegBS cs_bal_l, cs_bal_r, cs_bal_dpc, cs_bal_ndpc, F_dr, not_c_prime;
     RegXS gcs_ptr, cs_left, cs_right, cs_dpc, cs_ndpc, null;
@@ -893,9 +926,15 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
         [&tio, &gcs_ptr, cs_bal_dpc, cs_ndpc, cs_dpc](yield_t &yield)
         { mpc_select(tio, yield, gcs_ptr, cs_bal_dpc, cs_ndpc, cs_dpc, AVL_PTR_SIZE);});
 
-    typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gcs(tio, yield, gcs_ptr, MAX_DEPTH);
-    Node gcs_node = A[oidx_gcs];
-    RegXS old_gcs_ptr = gcs_node.pointers;
+    Node gcs_node;
+    RegXS old_gcs_ptr;
+    #ifdef OPT_ON
+        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gcs(tio, yield, gcs_ptr, MAX_DEPTH);
+        gcs_node = A[oidx_gcs];
+        old_gcs_ptr = gcs_node.pointers;
+    #else
+        gcs_node = A[gcs_ptr];
+    #endif
 
     not_c_prime = c_prime;
     if(player0) {
@@ -1087,7 +1126,11 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
     // Write back updated pointers correctly accounting for rotations
     setLeftBal(nodeptrs, new_p_bal_l);
     setRightBal(nodeptrs, new_p_bal_r);
-    A[oidx].NODE_POINTERS +=(nodeptrs - oidx_oldptrs);
+    #ifdef OPT_ON
+        A[oidx].NODE_POINTERS +=(nodeptrs - oidx_oldptrs);
+    #else
+        A[ptr].NODE_POINTERS = nodeptrs;
+    #endif
 }
 
 /* Update the return structure
@@ -1174,11 +1217,17 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
         RegBS zero;
         return {success, zero};
     } else {
-        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, MAX_DEPTH);
-        Node node = A[oidx];
-        RegXS oldptrs = node.pointers;
-        // Compare key
+        Node node;
+        RegXS oldptrs;
+        #ifdef OPT_ON
+            typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, MAX_DEPTH);
+            node = A[oidx];
+            oldptrs = node.pointers;
+        #else
+            node = A[ptr];
+        #endif
 
+        // Compare key
         CDPF cdpf = tio.cdpf(yield);
         auto [lt, eq, gt] = cdpf.compare(tio, yield, del_key - node.key, tio.aes_ops());
 
@@ -1347,10 +1396,16 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
                 rec_del_node.key.ashare, rec_suc_node.key.ashare);
             printf("flag_s = %d\n", ret_struct.F_ss.bshare);
             */
+            Node del_node, suc_node;
             typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_nd(tio, yield, ret_struct.N_d, MAX_DEPTH);
             typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_ns(tio, yield, ret_struct.N_s, MAX_DEPTH);
-            Node del_node = A[oidx_nd];
-            Node suc_node = A[oidx_ns];
+            #ifdef OPT_ON
+                del_node = A[oidx_nd];
+                suc_node = A[oidx_ns];
+            #else
+                del_node = A[ret_struct.N_d];
+                suc_node = A[ret_struct.N_s];
+            #endif
             RegAS zero_as; RegXS zero_xs;
             // Update root if needed
             mpc_select(tio, yield, root, ret_struct.F_r, root, ret_struct.ret_ptr);
@@ -1362,9 +1417,12 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
             printf("rec_F_ss = %d, del_node.key = %lu, suc_nod.key = %lu\n",
                 rec_F_ss, rec_del_key, rec_suc_key);
             */            
-
-            RegXS old_del_value = del_node.value;
-            RegAS old_del_key = del_node.key;
+            RegXS old_del_value;
+            RegAS old_del_key;
+            #ifdef OPT_ON
+                old_del_value = del_node.value;
+                old_del_key = del_node.key;
+            #endif
             RegXS empty_loc;
 
             run_coroutines(tio, [&tio, &del_node, ret_struct, suc_node](yield_t &yield)
@@ -1374,11 +1432,18 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
                 [&tio, &empty_loc, ret_struct](yield_t &yield)
                 { mpc_select(tio, yield, empty_loc, ret_struct.F_ss, ret_struct.N_d, ret_struct.N_s);});
 
-            A[oidx_nd].NODE_KEY+=(del_node.key - old_del_key);
-            A[oidx_nd].NODE_VALUE+=(del_node.value - old_del_value);
-            A[oidx_ns].NODE_KEY+=(-suc_node.key);
-            A[oidx_ns].NODE_VALUE+=(suc_node.value);
-          
+            #ifdef OPT_ON
+                A[oidx_nd].NODE_KEY+=(del_node.key - old_del_key);
+                A[oidx_nd].NODE_VALUE+=(del_node.value - old_del_value);
+                A[oidx_ns].NODE_KEY+=(-suc_node.key);
+                A[oidx_ns].NODE_VALUE+=(suc_node.value);
+            #else
+                A[ret_struct.N_d].NODE_KEY = del_node.key;
+                A[ret_struct.N_d].NODE_VALUE = del_node.value;
+                A[ret_struct.N_s].NODE_KEY = zero_as;
+                A[ret_struct.N_s].NODE_VALUE = zero_xs;
+            #endif
+
             //Add deleted (empty) location into the empty_locations vector for reuse in next insert()
             empty_locations.emplace_back(empty_loc);
         }
@@ -1387,14 +1452,45 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
     }
 }
 
+void AVL::initialize(MPCTIO &tio, yield_t &yield, size_t depth) {
+    size_t init_size = (size_t(1)<<depth) - 1;
+    auto A = oram.flat(tio, yield);
+
+    for(size_t i=1; i<=depth; i++) {
+        size_t start = size_t(1)<<(i-1);
+        size_t gap = size_t(1)<<i;
+        size_t current = start;
+        for(size_t j=1; j<=(size_t(1)<<(depth-i)); j++) {
+            //printf("current = %ld ", current);
+            Node node;
+            node.key.set(current * tio.player());
+            if(i!=1) {
+                //Set left and right child pointers and balance bits
+                size_t ptr_gap = start/2;
+                RegXS lptr, rptr;
+                lptr.set(tio.player() * (current-(ptr_gap)));
+                rptr.set(tio.player() * (current+(ptr_gap)));
+                setAVLLeftPtr(node.pointers, lptr);
+                setAVLRightPtr(node.pointers, rptr);
+            }
+            printf("\n");
+            A[current] = node;
+            current+=gap;
+        }
+    }
+
+    // Set num_items to init_size after they have been initialized;
+    num_items = init_size;
+    // Set root correctly
+    root.set(tio.player() * size_t(1)<<(depth-1));
+} 
 
 // Now we use the AVL class in various ways.  This function is called by
 // online.cpp.
 void avl(MPCIO &mpcio,
     const PRACOptions &opts, char **args)
 {
-    nbits_t depth=4;
-    size_t n_inserts=0, n_deletes=0;
+    size_t depth=4, n_inserts=0, n_deletes=0;
     if (*args) {
         depth = atoi(args[0]);
         n_inserts = atoi(args[1]);
@@ -1405,28 +1501,20 @@ void avl(MPCIO &mpcio,
        So we initialize (initial inserts) with 2^depth-2 items.
        The ORAM size is set to 2^depth-1 + n_insert.
     */
-    size_t init_size = (size_t(1)<<depth) - 2;
+    size_t init_size = (size_t(1)<<(depth));
     size_t oram_size = init_size + 1 + n_inserts; // +1 because init_size does not account for slot at 0.
+    printf("oram_size = %ld\n", oram_size);
 
     MPCTIO tio(mpcio, 0, opts.num_threads);
     run_coroutines(tio, [&tio, &mpcio, depth, oram_size, init_size, n_inserts, n_deletes] (yield_t &yield) {
 
         std::cout << "\n===== SETUP =====\n";
         AVL tree(tio.player(), oram_size);
+        tree.initialize(tio, yield, depth);
+        //tree.pretty_print(tio, yield);
+        tio.sync_lamport();
 
-        // Insert 2^depth-1 items
         Node node;
-        for(size_t i = 1; i<=init_size; i++) {
-            newnode(node);
-            node.key.set(i * tio.player());
-            //printf("Insert %d\n", insert_array[i]);
-            tree.insert(tio, yield, node);
-            //tree.print_oram(tio, yield);
-            //tree.pretty_print(tio, yield);
-            //tree.check_avl(tio, yield);
-        }
-
-        tio.sync_lamport();
         mpcio.dump_stats(std::cout);
         std::cout << "\n===== INSERTS =====\n";
         mpcio.reset_stats();

+ 5 - 0
avl.hpp

@@ -20,6 +20,8 @@
 #define KCYN  "\x1B[36m"
 #define KWHT  "\x1B[37m"
 
+#define OPT_ON 0
+
 /*
   For AVL tree we'll treat the pointers fields as:
   < L_ptr (31 bits), R_ptr (31 bits), bal_L (1 bit), bal_R (1 bit)>
@@ -186,6 +188,9 @@ class AVL {
     // Returns the first node that matches key
     bool lookup(MPCTIO &tio, yield_t &yield, RegAS key, Node *ret_node);
 
+    // Non-obliviously initialize an AVL tree of a particular size
+    void initialize(MPCTIO &tio, yield_t &yield, size_t depth);
+
     // Display and correctness check functions
     void pretty_print(MPCTIO &tio, yield_t &yield);
     void pretty_print(const std::vector<Node> &R, value_t node,