Browse Source

Addressed feedback. Commit midway through converting opt_flag into a command line argument to merge main first for optional oidx.

sshsshy 1 year ago
parent
commit
8cdbada147
3 changed files with 87 additions and 96 deletions
  1. 73 60
      avl.cpp
  2. 3 1
      avl.hpp
  3. 11 35
      bst.cpp

+ 73 - 60
avl.cpp

@@ -493,14 +493,14 @@ std::tuple<RegBS, RegBS, RegXS, RegBS> AVL::insert(MPCTIO &tio, yield_t &yield,
 
 
     RegBS isReal = isDummy ^ (!tio.player());
     RegBS isReal = isDummy ^ (!tio.player());
     Node cnode;
     Node cnode;
-    #ifdef AVL_OPT_ON
+    if(OPTIMIZED) {
         nbits_t width = ceil(log2(cur_max_index+1));
         nbits_t width = ceil(log2(cur_max_index+1));
         typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, width);
         typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, width);
         cnode = A[oidx];
         cnode = A[oidx];
         RegXS old_pointers = cnode.pointers;
         RegXS old_pointers = cnode.pointers;
-    #else
+    } else {
         cnode = A[ptr];
         cnode = A[ptr];
-    #endif
+    }
 
 
     // Compare key
     // Compare key
     auto [lteq, gt] = compare_keys(tio, yield, cnode.key, insert_key);
     auto [lteq, gt] = compare_keys(tio, yield, cnode.key, insert_key);
@@ -615,11 +615,11 @@ std::tuple<RegBS, RegBS, RegXS, RegBS> AVL::insert(MPCTIO &tio, yield_t &yield,
         rec_F_il, rec_left, rec_F_ir, rec_right);
         rec_F_il, rec_left, rec_F_ir, rec_right);
     */
     */
 
 
-    #ifdef AVL_OPT_ON
+    if(OPTIMIZED) {
         A[oidx].NODE_POINTERS+=(cnode.pointers - old_pointers);
         A[oidx].NODE_POINTERS+=(cnode.pointers - old_pointers);
-    #else
+    } else {
         A[ptr].NODE_POINTERS = cnode.pointers;
         A[ptr].NODE_POINTERS = cnode.pointers;
-    #endif
+    }
     // s0 = shares of 0
     // s0 = shares of 0
     RegBS s0;
     RegBS s0;
 
 
@@ -691,7 +691,7 @@ void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
         // Perform balance procedure
         // Perform balance procedure
         RegXS gp_pointers, parent_pointers, child_pointers;
         RegXS gp_pointers, parent_pointers, child_pointers;
         std::vector<coro_t> coroutines;
         std::vector<coro_t> coroutines;
-        #ifdef AVL_OPT_ON
+        if(OPTIMIZED) {
             nbits_t width = ceil(log2(cur_max_index+1));
             nbits_t width = ceil(log2(cur_max_index+1));
             typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gp(tio, yield, ret.gp_node, width);
             typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gp(tio, yield, ret.gp_node, width);
             typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_p(tio, yield, ret.p_node, width);
             typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_p(tio, yield, ret.p_node, width);
@@ -734,31 +734,31 @@ void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
             printf("c_key = %ld, c_left_ptr = %ld, c_right_ptr = %ld\n",
             printf("c_key = %ld, c_left_ptr = %ld, c_right_ptr = %ld\n",
                 rec_c_key, rec_c_lptr, rec_c_rptr);
                 rec_c_key, rec_c_lptr, rec_c_rptr);
             */
             */
-        #else
+        } else {
             gp_pointers = A[ret.gp_node].NODE_POINTERS;
             gp_pointers = A[ret.gp_node].NODE_POINTERS;
             parent_pointers = A[ret.p_node].NODE_POINTERS;
             parent_pointers = A[ret.p_node].NODE_POINTERS;
             child_pointers = A[ret.c_node].NODE_POINTERS;
             child_pointers = A[ret.c_node].NODE_POINTERS;
-        #endif
+        }
         // n_node (child's next node)
         // n_node (child's next node)
         RegXS child_left = getAVLLeftPtr(child_pointers);
         RegXS child_left = getAVLLeftPtr(child_pointers);
         RegXS child_right = getAVLRightPtr(child_pointers);
         RegXS child_right = getAVLRightPtr(child_pointers);
         RegXS n_node, n_pointers;
         RegXS n_node, n_pointers;
         mpc_select(tio, yield, n_node, ret.dir_cn, child_left, child_right, AVL_PTR_SIZE);
         mpc_select(tio, yield, n_node, ret.dir_cn, child_left, child_right, AVL_PTR_SIZE);
 
 
-        #ifdef AVL_OPT_ON
+        if(OPTIMIZED) {
             typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_n(tio, yield, n_node, width);
             typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_n(tio, yield, n_node, width);
             n_pointers = A[oidx_n].NODE_POINTERS;
             n_pointers = A[oidx_n].NODE_POINTERS;
-        #else
+        } else {
             n_pointers = A[n_node].NODE_POINTERS;
             n_pointers = A[n_node].NODE_POINTERS;
-        #endif
+        }
 
 
-        #ifdef AVL_OPT_ON
+        if(OPTIMIZED) { 
             RegXS old_gp_pointers, old_parent_pointers, old_child_pointers, old_n_pointers;
             RegXS old_gp_pointers, old_parent_pointers, old_child_pointers, old_n_pointers;
             old_gp_pointers = gp_pointers;
             old_gp_pointers = gp_pointers;
             old_parent_pointers = parent_pointers;
             old_parent_pointers = parent_pointers;
             old_child_pointers = child_pointers;
             old_child_pointers = child_pointers;
             old_n_pointers = n_pointers;
             old_n_pointers = n_pointers;
-        #endif
+        }
 
 
         // F_dr = (dir_pc != dir_cn) : i.e., double rotation case if
         // F_dr = (dir_pc != dir_cn) : i.e., double rotation case if
         // (parent->child) and (child->new_node) are not in the same direction
         // (parent->child) and (child->new_node) are not in the same direction
@@ -887,17 +887,17 @@ void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
         setRightBal(n_pointers, n_bal_r);
         setRightBal(n_pointers, n_bal_r);
 
 
         // Write back update pointers and balances into gp, p, c, and n
         // Write back update pointers and balances into gp, p, c, and n
-        #ifdef AVL_OPT_ON
+        if(OPTIMIZED) {
             A[oidx_n].NODE_POINTERS+=(n_pointers - old_n_pointers);
             A[oidx_n].NODE_POINTERS+=(n_pointers - old_n_pointers);
             A[oidx_c].NODE_POINTERS+=(child_pointers - old_child_pointers);
             A[oidx_c].NODE_POINTERS+=(child_pointers - old_child_pointers);
             A[oidx_p].NODE_POINTERS+=(parent_pointers - old_parent_pointers);
             A[oidx_p].NODE_POINTERS+=(parent_pointers - old_parent_pointers);
             A[oidx_gp].NODE_POINTERS+=(gp_pointers - old_gp_pointers);
             A[oidx_gp].NODE_POINTERS+=(gp_pointers - old_gp_pointers);
-        #else
+        } else {
             A[ret.c_node].NODE_POINTERS = child_pointers;
             A[ret.c_node].NODE_POINTERS = child_pointers;
             A[ret.p_node].NODE_POINTERS = parent_pointers;
             A[ret.p_node].NODE_POINTERS = parent_pointers;
             A[ret.gp_node].NODE_POINTERS = gp_pointers;
             A[ret.gp_node].NODE_POINTERS = gp_pointers;
             A[n_node].NODE_POINTERS = n_pointers;
             A[n_node].NODE_POINTERS = n_pointers;
-        #endif
+        }
 
 
         // Handle root pointer update (if F_ur is true)
         // Handle root pointer update (if F_ur is true)
         // If F_ur and we did a double rotation: root <-- new node
         // If F_ur and we did a double rotation: root <-- new node
@@ -1045,15 +1045,15 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
     s1.set(tio.player()==1);
     s1.set(tio.player()==1);
 
 
     Node cs_node, gcs_node;
     Node cs_node, gcs_node;
-    #ifdef AVL_OPT_ON
+    if(OPTIMIZED) {
         RegXS old_cs_ptr, old_gcs_ptr;
         RegXS old_cs_ptr, old_gcs_ptr;
         nbits_t width = ceil(log2(cur_max_index+1));
         nbits_t width = ceil(log2(cur_max_index+1));
         typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_cs(tio, yield, cs_ptr, width);
         typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_cs(tio, yield, cs_ptr, width);
         cs_node = A[oidx_cs];
         cs_node = A[oidx_cs];
             old_cs_ptr = cs_node.pointers;
             old_cs_ptr = cs_node.pointers;
-    #else
+    } else {
         cs_node = A[cs_ptr];
         cs_node = A[cs_ptr];
-    #endif
+    }
     //dirpc = dir_pc = dpc = c_prime
     //dirpc = dir_pc = dpc = c_prime
     RegBS cs_bal_l, cs_bal_r, cs_bal_dpc, cs_bal_ndpc, p_bal_ndpc, p_bal_dpc;
     RegBS cs_bal_l, cs_bal_r, cs_bal_dpc, cs_bal_ndpc, p_bal_ndpc, p_bal_dpc;
     RegBS F_dr, not_c_prime;
     RegBS F_dr, not_c_prime;
@@ -1116,13 +1116,13 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
         [&tio, &gcs_ptr, cs_bal_dpc, cs_ndpc, cs_dpc](yield_t &yield)
         [&tio, &gcs_ptr, cs_bal_dpc, cs_ndpc, cs_dpc](yield_t &yield)
         { mpc_select(tio, yield, gcs_ptr, cs_bal_dpc, cs_ndpc, cs_dpc, AVL_PTR_SIZE);});
         { mpc_select(tio, yield, gcs_ptr, cs_bal_dpc, cs_ndpc, cs_dpc, AVL_PTR_SIZE);});
 
 
-    #ifdef AVL_OPT_ON
+    if(OPTIMIZED) {
         typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gcs(tio, yield, gcs_ptr, width);
         typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gcs(tio, yield, gcs_ptr, width);
         gcs_node = A[oidx_gcs];
         gcs_node = A[oidx_gcs];
         old_gcs_ptr = gcs_node.pointers;
         old_gcs_ptr = gcs_node.pointers;
-    #else
+    } else {
         gcs_node = A[gcs_ptr];
         gcs_node = A[gcs_ptr];
-    #endif
+    }
 
 
     RegBS gcs_bal_l = getLeftBal(gcs_node.pointers);
     RegBS gcs_bal_l = getLeftBal(gcs_node.pointers);
     RegBS gcs_bal_r = getRightBal(gcs_node.pointers);
     RegBS gcs_bal_r = getRightBal(gcs_node.pointers);
@@ -1283,7 +1283,7 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
     setRightBal(nodeptrs, new_p_bal_r);
     setRightBal(nodeptrs, new_p_bal_r);
 
 
     // Write back updated pointers correctly accounting for rotations
     // Write back updated pointers correctly accounting for rotations
-    #ifdef AVL_OPT_ON
+    if(OPTIMIZED) {
       coroutines.emplace_back(
       coroutines.emplace_back(
           [&tio, &A, &oidx_cs, &cs_node, old_cs_ptr] (yield_t &yield) {
           [&tio, &A, &oidx_cs, &cs_node, old_cs_ptr] (yield_t &yield) {
               auto acont = A.context(yield);
               auto acont = A.context(yield);
@@ -1298,11 +1298,11 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
               (acont[oidx].NODE_POINTERS)+=(nodeptrs - oidx_oldptrs);});
               (acont[oidx].NODE_POINTERS)+=(nodeptrs - oidx_oldptrs);});
       run_coroutines(tio, coroutines);
       run_coroutines(tio, coroutines);
       coroutines.clear();
       coroutines.clear();
-    #else
+    } else {
       A[cs_ptr].NODE_POINTERS = cs_node.pointers;
       A[cs_ptr].NODE_POINTERS = cs_node.pointers;
       A[gcs_ptr].NODE_POINTERS = gcs_node.pointers;
       A[gcs_ptr].NODE_POINTERS = gcs_node.pointers;
       A[ptr].NODE_POINTERS = nodeptrs;
       A[ptr].NODE_POINTERS = nodeptrs;
-    #endif
+    }
 }
 }
 
 
 /*
 /*
@@ -1376,12 +1376,12 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
         // so that fixImbalance has an oidx to be supplied when in the !AVL_OPT_ON setting.
         // so that fixImbalance has an oidx to be supplied when in the !AVL_OPT_ON setting.
         nbits_t width = ceil(log2(cur_max_index+1));
         nbits_t width = ceil(log2(cur_max_index+1));
         typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, width);
         typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, width);
-        #ifdef AVL_OPT_ON
+        if(OPTIMIZED) {
             node = A[oidx];
             node = A[oidx];
             oldptrs = node.pointers;
             oldptrs = node.pointers;
-        #else
+        } else {
             node = A[ptr];
             node = A[ptr];
-        #endif
+        }
 
 
         RegXS left = getAVLLeftPtr(node.pointers);
         RegXS left = getAVLLeftPtr(node.pointers);
         RegXS right = getAVLRightPtr(node.pointers);
         RegXS right = getAVLRightPtr(node.pointers);
@@ -1622,7 +1622,7 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
 
 
             Node del_node, suc_node;
             Node del_node, suc_node;
             nbits_t width = ceil(log2(cur_max_index+1));
             nbits_t width = ceil(log2(cur_max_index+1));
-            #ifdef AVL_OPT_ON
+            if(OPTIMIZED) {
                 typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_nd(tio, yield, ret_struct.N_d, width);
                 typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_nd(tio, yield, ret_struct.N_d, width);
                 typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_ns(tio, yield, ret_struct.N_s, width);
                 typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_ns(tio, yield, ret_struct.N_s, width);
                 std::vector<coro_t> coroutines;
                 std::vector<coro_t> coroutines;
@@ -1636,10 +1636,10 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
                       suc_node = acont[oidx_ns];});
                       suc_node = acont[oidx_ns];});
                 run_coroutines(tio, coroutines);
                 run_coroutines(tio, coroutines);
                 coroutines.clear();
                 coroutines.clear();
-            #else
+            } else{
                 del_node = A[ret_struct.N_d];
                 del_node = A[ret_struct.N_d];
                 suc_node = A[ret_struct.N_s];
                 suc_node = A[ret_struct.N_s];
-            #endif
+            }
             RegAS zero_as; RegXS zero_xs;
             RegAS zero_as; RegXS zero_xs;
             // Update root if needed
             // Update root if needed
             mpc_select(tio, yield, root, ret_struct.F_r, root, ret_struct.ret_ptr);
             mpc_select(tio, yield, root, ret_struct.F_r, root, ret_struct.ret_ptr);
@@ -1747,22 +1747,30 @@ void AVL::initialize(MPCTIO &tio, yield_t &yield, size_t depth) {
 void avl(MPCIO &mpcio,
 void avl(MPCIO &mpcio,
     const PRACOptions &opts, char **args)
     const PRACOptions &opts, char **args)
 {
 {
-    int argc = 8;
+
+    int nargs = 0;
+    while(args[nargs]!=nullptr) {
+        ++nargs;
+    }
+
     int depth = 0;          // Initialization depth
     int depth = 0;          // Initialization depth
     size_t n_inserts = 0;   // Max ORAM_SIZE = 2^depth + n_inserts
     size_t n_inserts = 0;   // Max ORAM_SIZE = 2^depth + n_inserts
     size_t n_deletes = 0;
     size_t n_deletes = 0;
     bool run_sanity = 0;
     bool run_sanity = 0;
+    bool optimized = false;
 
 
     // Process command line arguments
     // Process command line arguments
-    for (int i = 0; i < argc; i += 2) {
+    for (int i = 0; i < nargs; i += 2) {
         std::string option = args[i];
         std::string option = args[i];
-        if (option == "-m" && i + 1 < argc) {
+        if (option == "-m" && i + 1 < nargs) {
             depth = std::atoi(args[i + 1]);
             depth = std::atoi(args[i + 1]);
-        } else if (option == "-i" && i + 1 < argc) {
+        } else if (option == "-i" && i + 1 < nargs) {
             n_inserts = std::atoi(args[i + 1]);
             n_inserts = std::atoi(args[i + 1]);
-        } else if (option == "-e" && i + 1 < argc) {
+        } else if (option == "-e" && i + 1 < nargs) {
             n_deletes = std::atoi(args[i + 1]);
             n_deletes = std::atoi(args[i + 1]);
-        } else if (option == "-s" && i + 1 < argc) {
+        } else if (option == "-opt" && i + 1 < nargs) {
+            optimized = std::atoi(args[i + 1]);
+        } else if (option == "-s" && i + 1 < nargs) {
             run_sanity = std::atoi(args[i + 1]);
             run_sanity = std::atoi(args[i + 1]);
         }
         }
     }
     }
@@ -1778,7 +1786,7 @@ void avl(MPCIO &mpcio,
     run_coroutines(tio, [&tio, &mpcio, depth, oram_size, init_size, n_inserts, n_deletes, run_sanity] (yield_t &yield) {
     run_coroutines(tio, [&tio, &mpcio, depth, oram_size, init_size, n_inserts, n_deletes, run_sanity] (yield_t &yield) {
         //printf("ORAM init_size = %ld, oram_size = %ld\n", init_size, oram_size);
         //printf("ORAM init_size = %ld, oram_size = %ld\n", init_size, oram_size);
         std::cout << "\n===== SETUP =====\n";
         std::cout << "\n===== SETUP =====\n";
-        AVL tree(tio.player(), oram_size);
+        AVL tree(tio.player(), oram_size, optimized);
         tree.initialize(tio, yield, depth);
         tree.initialize(tio, yield, depth);
         //tree.pretty_print(tio, yield);
         //tree.pretty_print(tio, yield);
         tio.sync_lamport();
         tio.sync_lamport();
@@ -1802,7 +1810,7 @@ void avl(MPCIO &mpcio,
             tree.insert(tio, yield, node);
             tree.insert(tio, yield, node);
             if(run_sanity) {
             if(run_sanity) {
                 tree.pretty_print(tio, yield);
                 tree.pretty_print(tio, yield);
-                tree.check_avl(tio, yield);
+                assert(tree.check_avl(tio, yield));
             }
             }
             //tree.print_oram(tio, yield);
             //tree.print_oram(tio, yield);
         }
         }
@@ -1825,12 +1833,17 @@ void avl(MPCIO &mpcio,
             tree.del(tio, yield, del_key);
             tree.del(tio, yield, del_key);
             if(run_sanity) {
             if(run_sanity) {
                 tree.pretty_print(tio, yield);
                 tree.pretty_print(tio, yield);
-                tree.check_avl(tio, yield);
+                assert(tree.check_avl(tio, yield));
             }
             }
         }
         }
     });
     });
 }
 }
 
 
+/*
+
+  AVL tests by default run the optimized AVL tree protocols.
+
+*/
 
 
 void avl_tests(MPCIO &mpcio,
 void avl_tests(MPCIO &mpcio,
     const PRACOptions &opts, char **args)
     const PRACOptions &opts, char **args)
@@ -1843,7 +1856,7 @@ void avl_tests(MPCIO &mpcio,
     run_coroutines(tio, [&tio, depth, items] (yield_t &yield) {
     run_coroutines(tio, [&tio, depth, items] (yield_t &yield) {
         size_t size = size_t(1)<<depth;
         size_t size = size_t(1)<<depth;
         bool player0 = tio.player()==0;
         bool player0 = tio.player()==0;
-        AVL tree(tio.player(), size);
+        AVL tree(tio.player(), size, true);
 
 
         // (T1) : Test 1 : L rotation (root modified)
         // (T1) : Test 1 : L rotation (root modified)
         /*
         /*
@@ -1863,7 +1876,7 @@ void avl_tests(MPCIO &mpcio,
         {
         {
             bool success = true;
             bool success = true;
             int insert_array[] = {5, 7, 9};
             int insert_array[] = {5, 7, 9};
-            size_t insert_array_size = 3;
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
             Node node;
             Node node;
 
 
             for(size_t i = 0; i<insert_array_size; i++) {
             for(size_t i = 0; i<insert_array_size; i++) {
@@ -1931,7 +1944,7 @@ void avl_tests(MPCIO &mpcio,
         {
         {
             bool success = true;
             bool success = true;
             int insert_array[] = {5, 3, 7, 9, 12};
             int insert_array[] = {5, 3, 7, 9, 12};
-            size_t insert_array_size = 5;
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int);
             Node node;
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
               randomize_node(node);
@@ -2015,7 +2028,7 @@ void avl_tests(MPCIO &mpcio,
         {
         {
             bool success = true;
             bool success = true;
             int insert_array[] = {9, 7, 5};
             int insert_array[] = {9, 7, 5};
-            size_t insert_array_size = 3;
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
             Node node;
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
               randomize_node(node);
@@ -2085,7 +2098,7 @@ void avl_tests(MPCIO &mpcio,
         {
         {
             bool success = true;
             bool success = true;
             int insert_array[] = {9, 12, 7, 5, 3};
             int insert_array[] = {9, 12, 7, 5, 3};
-            size_t insert_array_size = 5;
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
             Node node;
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
               randomize_node(node);
@@ -2169,7 +2182,7 @@ void avl_tests(MPCIO &mpcio,
         {
         {
             bool success = true;
             bool success = true;
             int insert_array[] = {9, 5, 7};
             int insert_array[] = {9, 5, 7};
-            size_t insert_array_size = 3;
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
             Node node;
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
               randomize_node(node);
@@ -2249,7 +2262,7 @@ void avl_tests(MPCIO &mpcio,
         {
         {
             bool success = true;
             bool success = true;
             int insert_array[] = {9, 12, 7, 3, 5};
             int insert_array[] = {9, 12, 7, 3, 5};
-            size_t insert_array_size = 5;
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
             Node node;
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
               randomize_node(node);
@@ -2333,13 +2346,13 @@ void avl_tests(MPCIO &mpcio,
         {
         {
             bool success = true;
             bool success = true;
             int insert_array[] = {5, 9, 7};
             int insert_array[] = {5, 9, 7};
-            size_t insert_array_size = 3;
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
             Node node;
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
               randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
               tree.insert(tio, yield, node);
-             success &= tree.check_avl(tio, yield);
+              success &= tree.check_avl(tio, yield);
             }
             }
 
 
             Duoram<Node>* oram = tree.get_oram();
             Duoram<Node>* oram = tree.get_oram();
@@ -2411,7 +2424,7 @@ void avl_tests(MPCIO &mpcio,
         {
         {
             bool success = true;
             bool success = true;
             int insert_array[] = {5, 3, 12, 7, 9};
             int insert_array[] = {5, 3, 12, 7, 9};
-            size_t insert_array_size = 5;
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
             Node node;
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
               randomize_node(node);
@@ -2495,7 +2508,7 @@ void avl_tests(MPCIO &mpcio,
         {
         {
             bool success = true;
             bool success = true;
             int insert_array[] = {5, 3, 7, 9};
             int insert_array[] = {5, 3, 7, 9};
-            size_t insert_array_size = 4;
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
             Node node;
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
               randomize_node(node);
@@ -2573,7 +2586,7 @@ void avl_tests(MPCIO &mpcio,
         {
         {
             bool success = true;
             bool success = true;
             int insert_array[] = {5, 3, 7, 9, 6, 1, 12};
             int insert_array[] = {5, 3, 7, 9, 6, 1, 12};
-            size_t insert_array_size = 7;
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
             Node node;
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
               randomize_node(node);
@@ -2667,7 +2680,7 @@ void avl_tests(MPCIO &mpcio,
         {
         {
             bool success = true;
             bool success = true;
             int insert_array[] = {9, 7, 12, 5};
             int insert_array[] = {9, 7, 12, 5};
-            size_t insert_array_size = 4;
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
             Node node;
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
               randomize_node(node);
@@ -2749,7 +2762,7 @@ void avl_tests(MPCIO &mpcio,
         {
         {
             bool success = true;
             bool success = true;
             int insert_array[] = {9, 12, 7, 5, 8, 15, 3};
             int insert_array[] = {9, 12, 7, 5, 8, 15, 3};
-            size_t insert_array_size = 7;
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
             Node node;
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
               randomize_node(node);
@@ -2844,7 +2857,7 @@ void avl_tests(MPCIO &mpcio,
         {
         {
             bool success = true;
             bool success = true;
             int insert_array[] = {9, 5, 12, 7};
             int insert_array[] = {9, 5, 12, 7};
-            size_t insert_array_size = 4;
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
             Node node;
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
               randomize_node(node);
@@ -2929,7 +2942,7 @@ void avl_tests(MPCIO &mpcio,
         {
         {
             bool success = true;
             bool success = true;
             int insert_array[] = {9, 12, 7, 3, 5};
             int insert_array[] = {9, 12, 7, 3, 5};
-            size_t insert_array_size = 5;
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
             Node node;
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
               randomize_node(node);
@@ -3020,7 +3033,7 @@ void avl_tests(MPCIO &mpcio,
         {
         {
             bool success = true;
             bool success = true;
             int insert_array[] = {5, 9, 3, 7};
             int insert_array[] = {5, 9, 3, 7};
-            size_t insert_array_size = 4;
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
             Node node;
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
               randomize_node(node);
@@ -3106,7 +3119,7 @@ void avl_tests(MPCIO &mpcio,
         {
         {
             bool success = true;
             bool success = true;
             int insert_array[] = {5, 3, 8, 7, 1, 12, 9};
             int insert_array[] = {5, 3, 8, 7, 1, 12, 9};
-            size_t insert_array_size = 7;
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
             Node node;
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
               randomize_node(node);
@@ -3215,7 +3228,7 @@ void avl_tests(MPCIO &mpcio,
         {
         {
             bool success = true;
             bool success = true;
             int insert_array[] = {9, 5, 12, 7, 3, 10, 15, 2, 4, 6, 8, 20, 1};
             int insert_array[] = {9, 5, 12, 7, 3, 10, 15, 2, 4, 6, 8, 20, 1};
-            size_t insert_array_size = 13;
+            size_t insert_array_size = sizeof(insert_array)/sizeof(int); 
             Node node;
             Node node;
             for(size_t i = 0; i<insert_array_size; i++) {
             for(size_t i = 0; i<insert_array_size; i++) {
               randomize_node(node);
               randomize_node(node);

+ 3 - 1
avl.hpp

@@ -144,6 +144,7 @@ class AVL {
     size_t cur_max_index = 0;
     size_t cur_max_index = 0;
     size_t MAX_SIZE;
     size_t MAX_SIZE;
     int MAX_DEPTH;
     int MAX_DEPTH;
+    bool OPTIMIZED;
 
 
     std::vector<RegXS> empty_locations;
     std::vector<RegXS> empty_locations;
 
 
@@ -186,13 +187,14 @@ class AVL {
         value_t node, value_t min_key, value_t max_key);
         value_t node, value_t min_key, value_t max_key);
 
 
   public:
   public:
-    AVL(int num_players, size_t size) : oram(num_players, size) {
+    AVL(int num_players, size_t size, bool opt_flag) : oram(num_players, size) {
         this->MAX_SIZE = size;
         this->MAX_SIZE = size;
         MAX_DEPTH = 0;
         MAX_DEPTH = 0;
         while(size>0) {
         while(size>0) {
           MAX_DEPTH+=1;
           MAX_DEPTH+=1;
           size=size>>1;
           size=size>>1;
         }
         }
+        OPTIMIZED = opt_flag; 
     };
     };
 
 
     void init(){
     void init(){

+ 11 - 35
bst.cpp

@@ -213,8 +213,11 @@ void BST::check_bst(MPCTIO &tio, yield_t &yield) {
     flat (A), and the Time-To_live TTL, and a shared flag (isDummy) which
     flat (A), and the Time-To_live TTL, and a shared flag (isDummy) which
     tracks if the operation is dummy/real.
     tracks if the operation is dummy/real.
 
 
-    Returns a tuple <ptr, dir>,  
-    
+    Returns a tuple <ptr, dir> where
+    ptr: the pointer to the node where the insertion should happen
+    dir: the bit indicating whether the new node should be inserted as the
+         the left/right child.
+
 */
 */
 std::tuple<RegXS, RegBS> BST::insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
 std::tuple<RegXS, RegBS> BST::insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
     RegAS insertion_key, Duoram<Node>::Flat &A, int TTL, RegBS isDummy) {
     RegAS insertion_key, Duoram<Node>::Flat &A, int TTL, RegBS isDummy) {
@@ -437,12 +440,12 @@ RegBS BST::lookup(MPCTIO &tio, yield_t &yield, RegAS key, Node *ret_node) {
 
 
     Takes as input the pointer to the current node in tree traversal (ptr),
     Takes as input the pointer to the current node in tree traversal (ptr),
     the key to be deleted (del_key), the underlying Duoram as a
     the key to be deleted (del_key), the underlying Duoram as a
-    flat (A), Flags af (already found) and fs (find successor), the
+    flat (A), Flags af (already found) and fs (find successor), thei
     Time-To_live TTL. Finally, a return structure ret_struct that tracks
     Time-To_live TTL. Finally, a return structure ret_struct that tracks
-    the location of the successor node and the node to delete, in order to 
-    perform the actual deletion after the recursive traversal. This is required
-    in the case of a deletion that requires a successor swap (,i.e., when the 
-    node to delete has both children).
+    the location of the successor node and the node to delete, in order
+    to perform the actual deletion after the recursive traversal. This
+    is required in the case of a deletion that requires a successor swap
+    (i.e., when the node to delete has both children).
 
 
     Returns success/fail bit.
     Returns success/fail bit.
 */
 */
@@ -698,7 +701,6 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
         del_return ret_struct;
         del_return ret_struct;
         auto A = oram.flat(tio, yield);
         auto A = oram.flat(tio, yield);
         int success = del(tio, yield, root, del_key, A, af, fs, TTL, ret_struct);
         int success = del(tio, yield, root, del_key, A, af, fs, TTL, ret_struct);
-        //printf ("Success =  %d\n", success);
         if(!success){
         if(!success){
             return 0;
             return 0;
         }
         }
@@ -760,7 +762,7 @@ void bst(MPCIO &mpcio,
     if (*args) {
     if (*args) {
         depth = atoi(*args);
         depth = atoi(*args);
         ++args;
         ++args;
-    } 
+    }
 
 
     MPCTIO tio(mpcio, 0, opts.num_threads);
     MPCTIO tio(mpcio, 0, opts.num_threads);
     run_coroutines(tio, [&tio, depth] (yield_t &yield) {
     run_coroutines(tio, [&tio, depth] (yield_t &yield) {
@@ -793,14 +795,6 @@ void bst(MPCIO &mpcio,
         tree.pretty_print(tio, yield);
         tree.pretty_print(tio, yield);
         tree.check_bst(tio, yield);
         tree.check_bst(tio, yield);
 
 
-        /*
-        printf("\n\nDelete %x\n", 8);
-        del_key.set(8 * tio.player());
-        tree.del(tio, yield, del_key);
-        tree.pretty_print(tio, yield);
-        tree.check_bst(tio, yield);
-        */
-
         printf("\n\nDelete %x\n", 7);
         printf("\n\nDelete %x\n", 7);
         del_key.set(7 * tio.player());
         del_key.set(7 * tio.player());
         tree.del(tio, yield, del_key);
         tree.del(tio, yield, del_key);
@@ -850,23 +844,5 @@ void bst(MPCIO &mpcio,
                 printf("Lookup Failed\n");
                 printf("Lookup Failed\n");
             }
             }
         }
         }
-
-        printf("\n\nLookup %x\n", 63);
-        randomize_node(node);
-        lookup_key.set(63 * tio.player());
-        found = tree.lookup(tio, yield, lookup_key, &node);
-        rec_found = mpc_reconstruct(tio, yield, found);
-        //rec_found = reconstruct_RegBS(tio, yield, found);
-        tree.pretty_print(tio, yield);
-        if(tio.player()!=2) {
-            if(rec_found) {
-                printf("Lookup Success\n");
-                size_t value = mpc_reconstruct(tio, yield, node.value);
-                printf("value = %lx\n", value);
-            } else {
-                printf("Lookup Failed\n");
-            }
-        }
-
     });
     });
 }
 }