Bläddra i källkod

Addressed all the feedback. Converted AVL_OPT_ON to a runtime flag

sshsshy 7 månader sedan
förälder
incheckning
759c96f05a
2 ändrade filer med 90 tillägg och 68 borttagningar
  1. 88 66
      avl.cpp
  2. 2 2
      avl.hpp

+ 88 - 66
avl.cpp

@@ -493,11 +493,14 @@ std::tuple<RegBS, RegBS, RegXS, RegBS> AVL::insert(MPCTIO &tio, yield_t &yield,
 
     RegBS isReal = isDummy ^ (!tio.player());
     Node cnode;
+    std::optional<Duoram<Node>::OblivIndex<RegXS, 1>> oidx;
+    RegXS old_pointers;
+    nbits_t width = ceil(log2(cur_max_index+1));
+
     if(OPTIMIZED) {
-        nbits_t width = ceil(log2(cur_max_index+1));
-        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, width);
-        cnode = A[oidx];
-        RegXS old_pointers = cnode.pointers;
+        oidx.emplace(tio, yield, ptr, width);
+        cnode = A[oidx.value()];
+        old_pointers = cnode.pointers;
     } else {
         cnode = A[ptr];
     }
@@ -616,7 +619,7 @@ std::tuple<RegBS, RegBS, RegXS, RegBS> AVL::insert(MPCTIO &tio, yield_t &yield,
     */
 
     if(OPTIMIZED) {
-        A[oidx].NODE_POINTERS+=(cnode.pointers - old_pointers);
+        A[oidx.value()].NODE_POINTERS+=(cnode.pointers - old_pointers);
     } else {
         A[ptr].NODE_POINTERS = cnode.pointers;
     }
@@ -691,24 +694,28 @@ void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
         // Perform balance procedure
         RegXS gp_pointers, parent_pointers, child_pointers;
         std::vector<coro_t> coroutines;
+        std::optional<Duoram<Node>::template OblivIndex<RegXS, 1>> oidx_gp;
+        std::optional<Duoram<Node>::template OblivIndex<RegXS, 1>> oidx_p;
+        std::optional<Duoram<Node>::template OblivIndex<RegXS, 1>> oidx_c;
+        nbits_t width = ceil(log2(cur_max_index+1));
+
         if(OPTIMIZED) {
-            nbits_t width = ceil(log2(cur_max_index+1));
-            typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gp(tio, yield, ret.gp_node, width);
-            typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_p(tio, yield, ret.p_node, width);
-            typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_c(tio, yield, ret.c_node, width);
+            oidx_gp.emplace(tio, yield, ret.gp_node, width);
+            oidx_p.emplace(tio, yield, ret.p_node, width);
+            oidx_c.emplace(tio, yield, ret.c_node, width);
 
             coroutines.emplace_back(
                 [&tio, &A, &oidx_gp, &gp_pointers](yield_t &yield) {
                   auto acont = A.context(yield);
-                  gp_pointers = acont[oidx_gp].NODE_POINTERS;});
+                  gp_pointers = acont[oidx_gp.value()].NODE_POINTERS;});
             coroutines.emplace_back(
                 [&tio, &A, &oidx_p, &parent_pointers](yield_t &yield) {
                   auto acont = A.context(yield);
-                  parent_pointers = acont[oidx_p].NODE_POINTERS;});
+                  parent_pointers = acont[oidx_p.value()].NODE_POINTERS;});
             coroutines.emplace_back(
                 [&tio, &A, &oidx_c, &child_pointers](yield_t &yield) {
                   auto acont = A.context(yield);
-                  child_pointers = acont[oidx_c].NODE_POINTERS;});
+                  child_pointers = acont[oidx_c.value()].NODE_POINTERS;});
             run_coroutines(tio, coroutines);
             coroutines.clear();
 
@@ -745,15 +752,16 @@ void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
         RegXS n_node, n_pointers;
         mpc_select(tio, yield, n_node, ret.dir_cn, child_left, child_right, AVL_PTR_SIZE);
 
+        std::optional <Duoram<Node>::template OblivIndex<RegXS,1>> oidx_n;
         if(OPTIMIZED) {
-            typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_n(tio, yield, n_node, width);
-            n_pointers = A[oidx_n].NODE_POINTERS;
+            oidx_n.emplace(tio, yield, n_node, width);
+            n_pointers = A[oidx_n.value()].NODE_POINTERS;
         } else {
             n_pointers = A[n_node].NODE_POINTERS;
         }
 
-        if(OPTIMIZED) { 
-            RegXS old_gp_pointers, old_parent_pointers, old_child_pointers, old_n_pointers;
+        RegXS old_gp_pointers, old_parent_pointers, old_child_pointers, old_n_pointers;
+        if(OPTIMIZED) {
             old_gp_pointers = gp_pointers;
             old_parent_pointers = parent_pointers;
             old_child_pointers = child_pointers;
@@ -888,10 +896,10 @@ void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
 
         // Write back update pointers and balances into gp, p, c, and n
         if(OPTIMIZED) {
-            A[oidx_n].NODE_POINTERS+=(n_pointers - old_n_pointers);
-            A[oidx_c].NODE_POINTERS+=(child_pointers - old_child_pointers);
-            A[oidx_p].NODE_POINTERS+=(parent_pointers - old_parent_pointers);
-            A[oidx_gp].NODE_POINTERS+=(gp_pointers - old_gp_pointers);
+            A[oidx_n.value()].NODE_POINTERS+=(n_pointers - old_n_pointers);
+            A[oidx_c.value()].NODE_POINTERS+=(child_pointers - old_child_pointers);
+            A[oidx_p.value()].NODE_POINTERS+=(parent_pointers - old_parent_pointers);
+            A[oidx_gp.value()].NODE_POINTERS+=(gp_pointers - old_gp_pointers);
         } else {
             A[ret.c_node].NODE_POINTERS = child_pointers;
             A[ret.p_node].NODE_POINTERS = parent_pointers;
@@ -1045,11 +1053,13 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
     s1.set(tio.player()==1);
 
     Node cs_node, gcs_node;
+    std::optional<Duoram<Node>::OblivIndex<RegXS,1>> oidx_cs;
+    RegXS old_cs_ptr, old_gcs_ptr;
+    nbits_t width = ceil(log2(cur_max_index+1));
+
     if(OPTIMIZED) {
-        RegXS old_cs_ptr, old_gcs_ptr;
-        nbits_t width = ceil(log2(cur_max_index+1));
-        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_cs(tio, yield, cs_ptr, width);
-        cs_node = A[oidx_cs];
+        oidx_cs.emplace(tio, yield, cs_ptr, width);
+        cs_node = A[oidx_cs.value()];
             old_cs_ptr = cs_node.pointers;
     } else {
         cs_node = A[cs_ptr];
@@ -1116,9 +1126,10 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
         [&tio, &gcs_ptr, cs_bal_dpc, cs_ndpc, cs_dpc](yield_t &yield)
         { mpc_select(tio, yield, gcs_ptr, cs_bal_dpc, cs_ndpc, cs_dpc, AVL_PTR_SIZE);});
 
+    std::optional<Duoram<Node>::template OblivIndex<RegXS,1>> oidx_gcs;
     if(OPTIMIZED) {
-        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gcs(tio, yield, gcs_ptr, width);
-        gcs_node = A[oidx_gcs];
+        oidx_gcs.emplace(tio, yield, gcs_ptr, width);
+        gcs_node = A[oidx_gcs.value()];
         old_gcs_ptr = gcs_node.pointers;
     } else {
         gcs_node = A[gcs_ptr];
@@ -1287,11 +1298,11 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
       coroutines.emplace_back(
           [&tio, &A, &oidx_cs, &cs_node, old_cs_ptr] (yield_t &yield) {
               auto acont = A.context(yield);
-              (acont[oidx_cs].NODE_POINTERS)+= (cs_node.pointers - old_cs_ptr);});
+              (acont[oidx_cs.value()].NODE_POINTERS)+= (cs_node.pointers - old_cs_ptr);});
       coroutines.emplace_back(
           [&tio, &A, &oidx_gcs, &gcs_node, old_gcs_ptr] (yield_t &yield) {
               auto acont = A.context(yield);
-              (acont[oidx_gcs].NODE_POINTERS)+= (gcs_node.pointers - old_gcs_ptr);});
+              (acont[oidx_gcs.value()].NODE_POINTERS)+= (gcs_node.pointers - old_gcs_ptr);});
       coroutines.emplace_back(
           [&tio, &A, &oidx, nodeptrs, oidx_oldptrs] (yield_t &yield) {
               auto acont = A.context(yield);
@@ -1372,8 +1383,9 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
     } else {
         Node node;
         RegXS oldptrs;
-        // This OblivIndex creation is not required if !AVL_OPT_ON, but for convenience we leave it in
-        // so that fixImbalance has an oidx to be supplied when in the !AVL_OPT_ON setting.
+        // This OblivIndex creation is not required if we are not running optimized version,
+        // but for convenience we leave it in, so that fixImbalance has an oidx to be supplied
+        // when we are in the non-optimized setting.
         nbits_t width = ceil(log2(cur_max_index+1));
         typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, width);
         if(OPTIMIZED) {
@@ -1622,18 +1634,20 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
 
             Node del_node, suc_node;
             nbits_t width = ceil(log2(cur_max_index+1));
+            std::optional<Duoram<Node>::template OblivIndex<RegXS,2>> oidx_nd;
+            std::optional<Duoram<Node>::template OblivIndex<RegXS,2>> oidx_ns;
+            std::vector<coro_t> coroutines;
             if(OPTIMIZED) {
-                typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_nd(tio, yield, ret_struct.N_d, width);
-                typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_ns(tio, yield, ret_struct.N_s, width);
-                std::vector<coro_t> coroutines;
+                oidx_nd.emplace(tio, yield, ret_struct.N_d, width);
+                oidx_ns.emplace(tio, yield, ret_struct.N_s, width);
                 coroutines.emplace_back(
                     [&tio, &A, &oidx_nd, &del_node](yield_t &yield) {
                       auto acont = A.context(yield);
-                      del_node = acont[oidx_nd];});
+                      del_node = acont[oidx_nd.value()];});
                 coroutines.emplace_back(
                     [&tio, &A, &oidx_ns, &suc_node](yield_t &yield) {
                       auto acont = A.context(yield);
-                      suc_node = acont[oidx_ns];});
+                      suc_node = acont[oidx_ns.value()];});
                 run_coroutines(tio, coroutines);
                 coroutines.clear();
             } else{
@@ -1651,13 +1665,13 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
             printf("rec_F_ss = %d, del_node.key = %lu, suc_nod.key = %lu\n",
                 rec_F_ss, rec_del_key, rec_suc_key);
             */
-            #ifdef AVL_OPT_ON
-                RegXS old_del_value;
-                RegAS old_del_key;
+            RegXS old_del_value;
+            RegAS old_del_key;
+            RegXS empty_loc;
+            if(OPTIMIZED) {
                 old_del_value = del_node.value;
                 old_del_key = del_node.key;
-            #endif
-            RegXS empty_loc;
+            }
 
             run_coroutines(tio, [&tio, &del_node, ret_struct, suc_node](yield_t &yield)
                 { mpc_select(tio, yield, del_node.key, ret_struct.F_ss, del_node.key, suc_node.key);},
@@ -1666,35 +1680,35 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
                 [&tio, &empty_loc, ret_struct](yield_t &yield)
                 { mpc_select(tio, yield, empty_loc, ret_struct.F_ss, ret_struct.N_d, ret_struct.N_s);});
 
-            #ifdef AVL_OPT_ON
+            if(OPTIMIZED) {
                 coroutines.emplace_back(
                     [&tio, &A, &oidx_nd, &del_node, old_del_key] (yield_t &yield) {
                         auto acont = A.context(yield);
-                        acont[oidx_nd].NODE_KEY+=(del_node.key - old_del_key);
+                        acont[oidx_nd.value()].NODE_KEY+=(del_node.key - old_del_key);
                     });
                 coroutines.emplace_back(
                     [&tio, &A, &oidx_nd, &del_node, old_del_value] (yield_t &yield) {
                         auto acont = A.context(yield);
-                        acont[oidx_nd].NODE_VALUE+=(del_node.value - old_del_value);
+                        acont[oidx_nd.value()].NODE_VALUE+=(del_node.value - old_del_value);
                     });
                 coroutines.emplace_back(
                     [&tio, &A, &oidx_ns, &suc_node] (yield_t &yield) {
                         auto acont = A.context(yield);
-                        acont[oidx_ns].NODE_KEY+=(-suc_node.key);
+                        acont[oidx_ns.value()].NODE_KEY+=(-suc_node.key);
                     });
                 coroutines.emplace_back(
                     [&tio, &A, &oidx_ns, &suc_node] (yield_t &yield) {
                         auto acont = A.context(yield);
-                        acont[oidx_ns].NODE_VALUE+=(-suc_node.value);
+                        acont[oidx_ns.value()].NODE_VALUE+=(-suc_node.value);
                     });
                 run_coroutines(tio, coroutines);
                 coroutines.clear();
-            #else
+            } else {
                 A[ret_struct.N_d].NODE_KEY = del_node.key;
                 A[ret_struct.N_d].NODE_VALUE = del_node.value;
                 A[ret_struct.N_s].NODE_KEY = zero_as;
                 A[ret_struct.N_s].NODE_VALUE = zero_xs;
-            #endif
+            }
 
             //Add deleted (empty) location into the empty_locations vector for reuse in next insert()
             empty_locations.emplace_back(empty_loc);
@@ -1783,7 +1797,7 @@ void avl(MPCIO &mpcio,
     size_t oram_size = init_size + n_inserts;
 
     MPCTIO tio(mpcio, 0, opts.num_threads);
-    run_coroutines(tio, [&tio, &mpcio, depth, oram_size, init_size, n_inserts, n_deletes, run_sanity] (yield_t &yield) {
+    run_coroutines(tio, [&tio, &mpcio, depth, oram_size, init_size, n_inserts, n_deletes, run_sanity, optimized] (yield_t &yield) {
         //printf("ORAM init_size = %ld, oram_size = %ld\n", init_size, oram_size);
         std::cout << "\n===== SETUP =====\n";
         AVL tree(tio.player(), oram_size, optimized);
@@ -1810,7 +1824,11 @@ void avl(MPCIO &mpcio,
             tree.insert(tio, yield, node);
             if(run_sanity) {
                 tree.pretty_print(tio, yield);
-                assert(tree.check_avl(tio, yield));
+                if(tio.player()==0) {
+                    assert(tree.check_avl(tio, yield));
+                } else {
+                  tree.check_avl(tio, yield);
+                }
             }
             //tree.print_oram(tio, yield);
         }
@@ -1833,7 +1851,11 @@ void avl(MPCIO &mpcio,
             tree.del(tio, yield, del_key);
             if(run_sanity) {
                 tree.pretty_print(tio, yield);
-                assert(tree.check_avl(tio, yield));
+                if(tio.player()==0) {
+                    assert(tree.check_avl(tio, yield));
+                } else {
+                  tree.check_avl(tio, yield);
+                }
             }
         }
     });
@@ -1876,7 +1898,7 @@ void avl_tests(MPCIO &mpcio,
         {
             bool success = true;
             int insert_array[] = {5, 7, 9};
-            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int);
             Node node;
 
             for(size_t i = 0; i<insert_array_size; i++) {
@@ -2028,7 +2050,7 @@ void avl_tests(MPCIO &mpcio,
         {
             bool success = true;
             int insert_array[] = {9, 7, 5};
-            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int);
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
@@ -2098,7 +2120,7 @@ void avl_tests(MPCIO &mpcio,
         {
             bool success = true;
             int insert_array[] = {9, 12, 7, 5, 3};
-            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int);
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
@@ -2182,7 +2204,7 @@ void avl_tests(MPCIO &mpcio,
         {
             bool success = true;
             int insert_array[] = {9, 5, 7};
-            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int);
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
@@ -2262,7 +2284,7 @@ void avl_tests(MPCIO &mpcio,
         {
             bool success = true;
             int insert_array[] = {9, 12, 7, 3, 5};
-            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int);
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
@@ -2346,7 +2368,7 @@ void avl_tests(MPCIO &mpcio,
         {
             bool success = true;
             int insert_array[] = {5, 9, 7};
-            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int);
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
@@ -2424,7 +2446,7 @@ void avl_tests(MPCIO &mpcio,
         {
             bool success = true;
             int insert_array[] = {5, 3, 12, 7, 9};
-            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int);
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
@@ -2508,7 +2530,7 @@ void avl_tests(MPCIO &mpcio,
         {
             bool success = true;
             int insert_array[] = {5, 3, 7, 9};
-            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int);
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
@@ -2586,7 +2608,7 @@ void avl_tests(MPCIO &mpcio,
         {
             bool success = true;
             int insert_array[] = {5, 3, 7, 9, 6, 1, 12};
-            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int);
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
@@ -2680,7 +2702,7 @@ void avl_tests(MPCIO &mpcio,
         {
             bool success = true;
             int insert_array[] = {9, 7, 12, 5};
-            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int);
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
@@ -2762,7 +2784,7 @@ void avl_tests(MPCIO &mpcio,
         {
             bool success = true;
             int insert_array[] = {9, 12, 7, 5, 8, 15, 3};
-            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int);
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
@@ -2857,7 +2879,7 @@ void avl_tests(MPCIO &mpcio,
         {
             bool success = true;
             int insert_array[] = {9, 5, 12, 7};
-            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int);
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
@@ -2942,7 +2964,7 @@ void avl_tests(MPCIO &mpcio,
         {
             bool success = true;
             int insert_array[] = {9, 12, 7, 3, 5};
-            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int);
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
@@ -3033,7 +3055,7 @@ void avl_tests(MPCIO &mpcio,
         {
             bool success = true;
             int insert_array[] = {5, 9, 3, 7};
-            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int);
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
@@ -3119,7 +3141,7 @@ void avl_tests(MPCIO &mpcio,
         {
             bool success = true;
             int insert_array[] = {5, 3, 8, 7, 1, 12, 9};
-            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int);
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
@@ -3228,7 +3250,7 @@ void avl_tests(MPCIO &mpcio,
         {
             bool success = true;
             int insert_array[] = {9, 5, 12, 7, 3, 10, 15, 2, 4, 6, 8, 20, 1};
-            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int);
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);

+ 2 - 2
avl.hpp

@@ -1,6 +1,7 @@
 #ifndef __AVL_HPP__
 #define __AVL_HPP__
 
+#include <optional>
 #include <math.h>
 #include <stdio.h>
 #include <string>
@@ -27,7 +28,6 @@
     DEBUG_BB: Debug flag for balance bit computations
 */
 
-#define AVL_OPT_ON
 // #define AVL_RANDOMIZE_INSERTS
 // #define AVL_DEBUG
 // #define AVL_DEBUG_BB
@@ -194,7 +194,7 @@ class AVL {
           MAX_DEPTH+=1;
           size=size>>1;
         }
-        OPTIMIZED = opt_flag; 
+        OPTIMIZED = opt_flag;
     };
 
     void init(){