Browse Source

Fixed OblivIndex to use the correct width DPFS. Support randomized inserts and deletes (cur_max_index). Sanity test option for inserts/deletes

sshsshy 11 months ago
parent
commit
11c4290539
2 changed files with 109 additions and 35 deletions
  1. 103 34
      avl.cpp
  2. 6 1
      avl.hpp

+ 103 - 34
avl.cpp

@@ -408,7 +408,8 @@ std::tuple<RegBS, RegBS, RegXS, RegBS> AVL::insert(MPCTIO &tio, yield_t &yield,
     RegBS isReal = isDummy ^ (tio.player());
     Node cnode;
     #ifdef OPT_ON
-        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, MAX_DEPTH);
+        nbits_t width = ceil(log2(cur_max_index+1));
+        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, width);
         cnode = A[oidx];
     #else
         cnode = A[ptr];
@@ -424,11 +425,12 @@ std::tuple<RegBS, RegBS, RegXS, RegBS> AVL::insert(MPCTIO &tio, yield_t &yield,
     RegXS right = getAVLRightPtr(cnode.pointers);
     RegBS bal_l = getLeftBal(cnode.pointers);
     RegBS bal_r = getRightBal(cnode.pointers);
+   
     /*
     size_t rec_left = reconstruct_RegXS(tio, yield, left);
     size_t rec_right = reconstruct_RegXS(tio, yield, right);
     size_t rec_key = reconstruct_RegAS(tio, yield, cnode.key);
-    printf("\n\nKey = %ld\n", rec_key);
+    printf("\n\n(Before recursing) Key = %ld\n", rec_key);
     printf("rec_left = %ld, rec_right = %ld\n", rec_left, rec_right);
     */
 
@@ -459,13 +461,6 @@ std::tuple<RegBS, RegBS, RegXS, RegBS> AVL::insert(MPCTIO &tio, yield_t &yield,
     printf("\nrec_ptr = %ld\n", rec_ptr);
     */
 
-    // Save insertion pointer and direction
-    /*
-    mpc_select(tio, yield, ret->i_node, F_i, ret->i_node, ptr, AVL_PTR_SIZE);
-    mpc_select(tio, yield, ret->dir_i, F_i, ret->dir_i, gt);
-    */
-    
-
     // Update balance
     // If we inserted at this level (F_i), bal_upd is set to 1
     mpc_or(tio, yield, bal_upd, bal_upd, F_i);
@@ -508,12 +503,13 @@ std::tuple<RegBS, RegBS, RegXS, RegBS> AVL::insert(MPCTIO &tio, yield_t &yield,
     // Store new_bal_l and new_bal_r for this node
     setLeftBal(cnode.pointers, new_bal_l);
     setRightBal(cnode.pointers, new_bal_r);
-    // We have to write the node pointers anyway to resolve balance updates
+    // We have to write the node pointers anyway to handle balance updates,
+    // so we perform insertion along with it by modifying pointers appropriately.
     RegBS F_ir, F_il;
     run_coroutines(tio, [&tio, &F_ir, F_i, gt](yield_t &yield) 
         { mpc_and(tio, yield, F_ir, F_i, gt); },
         [&tio, &F_il, F_i, lteq](yield_t &yield)
-        { mpc_and(tio, yield, F_il, F_i, lteq); });
+        { mpc_and(tio, yield, F_il, F_i, lteq); }); 
 
     run_coroutines(tio, [&tio, &left, F_il, ins_addr](yield_t &yield) 
         { mpc_select(tio, yield, left, F_il, left, ins_addr);},
@@ -522,6 +518,17 @@ std::tuple<RegBS, RegBS, RegXS, RegBS> AVL::insert(MPCTIO &tio, yield_t &yield,
 
     setAVLLeftPtr(cnode.pointers, left);
     setAVLRightPtr(cnode.pointers, right);
+
+    /*
+    bool rec_F_ir, rec_F_il;
+    rec_F_ir = reconstruct_RegBS(tio, yield, F_ir);
+    rec_F_il = reconstruct_RegBS(tio, yield, F_il);
+    rec_left = reconstruct_RegXS(tio, yield, left);
+    rec_right = reconstruct_RegXS(tio, yield, right);
+    printf("(After recursing) F_il = %d, left = %ld, F_ir = %d, right = %ld\n",
+        rec_F_il, rec_left, rec_F_ir, rec_right);
+    */
+
     #ifdef OPT_ON
         A[oidx].NODE_POINTERS+=(cnode.pointers - old_pointers);
     #else
@@ -541,14 +548,15 @@ std::tuple<RegBS, RegBS, RegXS, RegBS> AVL::insert(MPCTIO &tio, yield_t &yield,
 // Insert(root, ptr, key, TTL, isDummy) -> (new_ptr, wptr, wnode, f_p)
 void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
     bool player0 = tio.player()==0;
-    auto A = oram.flat(tio, yield);
     // If there are no items in tree. Make this new item the root.
     if(num_items==0) {
+        auto A = oram.flat(tio, yield);
         Node zero;
         A[0] = zero;
         A[1] = node;
         (root).set(1*tio.player());
         num_items++;
+        cur_max_index++;
         return;
     } else {
         // Insert node into next free slot in the ORAM
@@ -557,6 +565,10 @@ void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
         num_items++;
         int TTL = AVL_TTL(num_items);
         bool insertAtEmptyLocation = (numEmptyLocations() > 0);
+        if(!insertAtEmptyLocation) {
+            cur_max_index++;
+        }
+        auto A = oram.flat(tio, yield, 0, cur_max_index+1);
         if(insertAtEmptyLocation) {
             insert_address = empty_locations.back();
             empty_locations.pop_back();
@@ -590,14 +602,30 @@ void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
         // Perform balance procedure
         RegXS gp_pointers, parent_pointers, child_pointers;
         #ifdef OPT_ON
-            int logn = int(ceil(AVL_TTL(num_items)));
-            printf("n = %ld, logn = %d\n", num_items, logn);
-            typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gp(tio, yield, ret.gp_node, logn);
-            typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_p(tio, yield, ret.p_node, logn);
-            typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_c(tio, yield, ret.c_node, logn); 
+            nbits_t width = ceil(log2(cur_max_index+1));
+            typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gp(tio, yield, ret.gp_node, width);
+            typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_p(tio, yield, ret.p_node, width);
+            typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_c(tio, yield, ret.c_node, width);
             gp_pointers = A[oidx_gp].NODE_POINTERS;
             parent_pointers = A[oidx_p].NODE_POINTERS;
             child_pointers = A[oidx_c].NODE_POINTERS;
+            /*
+            size_t rec_gp_key = reconstruct_RegAS(tio, yield, A[oidx_gp].NODE_KEY);
+            size_t rec_p_key = reconstruct_RegAS(tio, yield, A[oidx_p].NODE_KEY);
+            size_t rec_c_key = reconstruct_RegAS(tio, yield, A[oidx_c].NODE_KEY);
+            size_t rec_gp_lptr = reconstruct_RegXS(tio, yield, getAVLLeftPtr(A[oidx_gp].NODE_POINTERS));
+            size_t rec_gp_rptr = reconstruct_RegXS(tio, yield, getAVLRightPtr(A[oidx_gp].NODE_POINTERS));
+            size_t rec_p_lptr = reconstruct_RegXS(tio, yield, getAVLLeftPtr(A[oidx_p].NODE_POINTERS));
+            size_t rec_p_rptr = reconstruct_RegXS(tio, yield, getAVLRightPtr(A[oidx_p].NODE_POINTERS));
+            size_t rec_c_lptr = reconstruct_RegXS(tio, yield, getAVLLeftPtr(A[oidx_c].NODE_POINTERS));
+            size_t rec_c_rptr = reconstruct_RegXS(tio, yield, getAVLRightPtr(A[oidx_c].NODE_POINTERS));
+            printf("Reconstructed:\ngp_key = %ld, gp_left_ptr = %ld, gp_right_ptr = %ld\n", 
+                rec_gp_key, rec_gp_lptr, rec_gp_rptr);
+            printf("p_key = %ld, p_left_ptr = %ld, p_right_ptr = %ld\n", 
+                rec_p_key, rec_p_lptr, rec_p_rptr);
+            printf("c_key = %ld, c_left_ptr = %ld, c_right_ptr = %ld\n", 
+                rec_c_key, rec_c_lptr, rec_c_rptr);
+            */
         #else
             gp_pointers = A[ret.gp_node].NODE_POINTERS;
             parent_pointers = A[ret.p_node].NODE_POINTERS;
@@ -610,7 +638,7 @@ void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
         mpc_select(tio, yield, n_node, ret.dir_cn, child_left, child_right, AVL_PTR_SIZE);
 
         #ifdef OPT_ON
-            typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_n(tio, yield, n_node, logn);
+            typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_n(tio, yield, n_node, width);
             n_pointers = A[oidx_n].NODE_POINTERS;
         #else
             n_pointers = A[n_node].NODE_POINTERS;  
@@ -899,7 +927,8 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
     RegXS old_cs_ptr;
     Node cs_node;
     #ifdef OPT_ON
-        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_cs(tio, yield, cs_ptr, MAX_DEPTH);
+        nbits_t width = ceil(log2(cur_max_index+1));
+        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_cs(tio, yield, cs_ptr, width);
         cs_node = A[oidx_cs];
         old_cs_ptr = cs_node.pointers;
     #else
@@ -932,7 +961,7 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
     Node gcs_node;
     RegXS old_gcs_ptr;
     #ifdef OPT_ON
-        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gcs(tio, yield, gcs_ptr, MAX_DEPTH);
+        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gcs(tio, yield, gcs_ptr, width);
         gcs_node = A[oidx_gcs];
         old_gcs_ptr = gcs_node.pointers;
     #else
@@ -1175,9 +1204,10 @@ void AVL::updateRetStruct(MPCTIO &tio, yield_t &yield, RegXS ptr, RegBS F_2, Reg
     run_coroutines(tio, [&tio, &ret_struct, F_c2](yield_t &yield)
         { mpc_or(tio, yield, ret_struct.F_ss, ret_struct.F_ss, F_c2);},
         [&tio, &F_dh, lf, not_found] (yield_t &yield)
-        { mpc_and(tio, yield, F_dh, lf, not_found);},
-        [&tio, &ret_struct, F_dh, ptr] (yield_t &yield)
-        { mpc_select(tio, yield, ret_struct.N_d, F_dh, ret_struct.N_d, ptr);});
+        { mpc_and(tio, yield, F_dh, lf, not_found);});
+
+    //[&tio, &ret_struct, F_dh, ptr] (yield_t &yield)
+    mpc_select(tio, yield, ret_struct.N_d, F_dh, ret_struct.N_d, ptr);
 
     // F_sf = Successor found = F_c4 = Finding successor & no more left child
     F_sf = F_c4;
@@ -1225,7 +1255,8 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
         RegXS oldptrs;
         // This OblivIndex creation is not required if !OPT_ON, but for convenience we leave it in
         // so that fixImbalance has an oidx to be supplied when in the !OPT_ON setting.
-        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, MAX_DEPTH);
+        nbits_t width = ceil(log2(cur_max_index+1));
+        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, width);
         #ifdef OPT_ON
             node = A[oidx];
             oldptrs = node.pointers;
@@ -1363,11 +1394,12 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
     if(num_items==0)
         return 0;
 
-    auto A = oram.flat(tio, yield);
+    auto A = oram.flat(tio, yield, 0, cur_max_index+1);
     if(num_items==1) {
         //Delete root if root's key = del_key
         Node zero;
-        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, root, MAX_DEPTH);
+        nbits_t width = ceil(log2(cur_max_index+1));
+        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, root, width);
         Node node = A[oidx];
         // Compare key
         CDPF cdpf = tio.cdpf(yield);
@@ -1394,6 +1426,7 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
         }
         else{
             num_items--;
+            
             /*
             printf("In delete's swap portion\n");
             Node rec_del_node = A.reconstruct(A[ret_struct.N_d]);
@@ -1402,9 +1435,11 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
                 rec_del_node.key.ashare, rec_suc_node.key.ashare);
             printf("flag_s = %d\n", ret_struct.F_ss.bshare);
             */
+
             Node del_node, suc_node;
-            typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_nd(tio, yield, ret_struct.N_d, MAX_DEPTH);
-            typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_ns(tio, yield, ret_struct.N_s, MAX_DEPTH);
+            nbits_t width = ceil(log2(cur_max_index+1));
+            typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_nd(tio, yield, ret_struct.N_d, width);
+            typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_ns(tio, yield, ret_struct.N_s, width);
             #ifdef OPT_ON
                 del_node = A[oidx_nd];
                 suc_node = A[oidx_ns];
@@ -1492,6 +1527,7 @@ void AVL::initialize(MPCTIO &tio, yield_t &yield, size_t depth) {
 
     // Set num_items to init_size after they have been initialized;
     num_items = init_size;
+    cur_max_index = num_items;
     // Set root correctly
     root.set(tio.player() * size_t(1)<<(depth-1));
 } 
@@ -1513,11 +1549,12 @@ void avl(MPCIO &mpcio,
        The ORAM size is set to 2^depth-1 + n_insert.
     */
     size_t init_size = (size_t(1)<<(depth));
-    size_t oram_size = init_size + 1 + n_inserts; // +1 because init_size does not account for slot at 0.
+    size_t oram_size = init_size + n_inserts; 
+    // +1 because init_size does not account for slot at 0.
 
     MPCTIO tio(mpcio, 0, opts.num_threads);
     run_coroutines(tio, [&tio, &mpcio, depth, oram_size, init_size, n_inserts, n_deletes] (yield_t &yield) {
-
+        //printf("ORAM init_size = %ld, oram_size = %ld\n", init_size, oram_size);
         std::cout << "\n===== SETUP =====\n";
         AVL tree(tio.player(), oram_size);
         tree.initialize(tio, yield, depth);
@@ -1532,8 +1569,20 @@ void avl(MPCIO &mpcio,
 
         for(size_t i = 1; i<=n_inserts; i++) {
             newnode(node);
-            node.key.set((i+init_size) * tio.player());
+            size_t ikey;
+            #ifdef RANDOMIZE
+                ikey = (1+(rand()%oram_size));
+            #else
+                ikey = (i+init_size);
+            #endif
+            printf("Insert key = %ld\n", ikey);
+            node.key.set(ikey * tio.player());
             tree.insert(tio, yield, node);
+            #ifdef SANITY_TEST
+                tree.pretty_print(tio, yield);
+                tree.check_avl(tio, yield);
+            #endif
+            //tree.print_oram(tio, yield);
         }
 
         tio.sync_lamport();
@@ -1543,8 +1592,19 @@ void avl(MPCIO &mpcio,
         tio.reset_lamport();
         for(size_t i = 1; i<=n_deletes; i++) {
             RegAS del_key;
-            del_key.set((i+init_size) * tio.player());
+            size_t dkey;
+            #ifdef RANDOMIZE
+                dkey = 1 + (rand()%init_size);
+            #else
+                dkey = i + init_size;
+            #endif
+            del_key.set(dkey * tio.player());
+            printf("Deletion key = %ld\n", dkey);
             tree.del(tio, yield, del_key);
+            #ifdef SANITY_TEST
+                tree.pretty_print(tio, yield);
+                tree.check_avl(tio, yield);
+            #endif
         }
     });
 }
@@ -1589,7 +1649,6 @@ void avl_tests(MPCIO &mpcio,
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
               tree.check_avl(tio, yield);
-              tree.pretty_print(tio, yield);
             }
             Duoram<Node>* oram = tree.get_oram();
             RegXS root_xs = tree.get_root();
@@ -1624,7 +1683,16 @@ void avl_tests(MPCIO &mpcio,
                     print_red("T1 : FAIL\n");
                 }
             }
-            tree.pretty_print(tio, yield);
+            /*
+            //MY_tests:
+            // OblivIndex read on the ORAM:
+            RegXS mptr;
+            mptr.set(tio.player() * 1);
+            nbits_t width_bits = 2;
+            typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, mptr, width_bits);
+            size_t rec_key = reconstruct_RegAS(tio, yield, A[oidx].NODE_KEY);
+            printf("RI: Retrieved key (3) = %ld\n", rec_key);
+            */
             A.init();
             tree.init();
         }
@@ -1903,6 +1971,7 @@ void avl_tests(MPCIO &mpcio,
             size_t root = reconstruct_RegXS(tio, yield, root_xs);
             auto A = oram->flat(tio, yield);
             auto R = A.reconstruct();
+            
             Node root_node, n9, n5;
             size_t n9_index, n5_index;
             root_node = R[root];

+ 6 - 1
avl.hpp

@@ -21,6 +21,8 @@
 #define KWHT  "\x1B[37m"
 
 #define OPT_ON 0
+#define RANDOMIZE 0
+#define SANITY_TEST 0
 
 /*
   For AVL tree we'll treat the pointers fields as:
@@ -95,7 +97,8 @@ inline void dumpAVL(Node n) {
     RegBS left_bal, right_bal;
     left_bal = getLeftBal(n.pointers);
     right_bal = getRightBal(n.pointers);
-    printf("[%016lx %016lx %d %d %016lx]", n.key.share(), n.pointers.share(),
+    printf("[%016lx %016lx(L:%ld, R:%ld) %d %d %016lx]", n.key.share(), n.pointers.share(),
+          getAVLLeftPtr(n.pointers).xshare, getAVLRightPtr(n.pointers).xshare,
           left_bal.share(), right_bal.share(), n.value.share());
 }
 
@@ -130,6 +133,7 @@ class AVL {
     RegXS root;
 
     size_t num_items = 0;
+    size_t cur_max_index = 0;
     size_t MAX_SIZE;
     int MAX_DEPTH;
 
@@ -179,6 +183,7 @@ class AVL {
 
     void init(){
         num_items=0;
+        cur_max_index=0;
         empty_locations.clear();
     }