Browse Source

Tweaks to run unoptimized version

sshsshy 11 months ago
parent
commit
18680a2fe2
2 changed files with 28 additions and 16 deletions
  1. 18 12
      avl.cpp
  2. 10 4
      avl.hpp

+ 18 - 12
avl.cpp

@@ -590,7 +590,8 @@ void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
         // Perform balance procedure
         RegXS gp_pointers, parent_pointers, child_pointers;
         #ifdef OPT_ON
-            int logn = int(ceil(log2(num_items)));
+            int logn = int(ceil(AVL_TTL(num_items)));
+            printf("n = %ld, logn = %d\n", num_items, logn);
             typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gp(tio, yield, ret.gp_node, logn);
             typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_p(tio, yield, ret.p_node, logn);
             typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_c(tio, yield, ret.c_node, logn); 
@@ -598,9 +599,9 @@ void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
             parent_pointers = A[oidx_p].NODE_POINTERS;
             child_pointers = A[oidx_c].NODE_POINTERS;
         #else
-            RegXS gp_pointers = A[ret.gp_node].NODE_POINTERS;
-            RegXS parent_pointers = A[ret.p_node].NODE_POINTERS;
-            RegXS child_pointers = A[ret.c_node].NODE_POINTERS;
+            gp_pointers = A[ret.gp_node].NODE_POINTERS;
+            parent_pointers = A[ret.p_node].NODE_POINTERS;
+            child_pointers = A[ret.c_node].NODE_POINTERS;
         #endif
         // n_node (child's next node)
         RegXS child_left = getAVLLeftPtr(child_pointers);
@@ -1121,17 +1122,18 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
     setRightBal(gcs_node.pointers, gcs_bal_r);
     setLeftBal(cs_node.pointers, cs_bal_l);
     setRightBal(cs_node.pointers, cs_bal_r);
-
-    A[oidx_cs].NODE_POINTERS+= (cs_node.pointers - old_cs_ptr);
-    A[oidx_gcs].NODE_POINTERS+= (gcs_node.pointers - old_gcs_ptr);
-
-    // Write back updated pointers correctly accounting for rotations
     setLeftBal(nodeptrs, new_p_bal_l);
     setRightBal(nodeptrs, new_p_bal_r);
+
+    // Write back updated pointers correctly accounting for rotations
     #ifdef OPT_ON
-        A[oidx].NODE_POINTERS +=(nodeptrs - oidx_oldptrs);
+      A[oidx_cs].NODE_POINTERS+= (cs_node.pointers - old_cs_ptr);
+      A[oidx_gcs].NODE_POINTERS+= (gcs_node.pointers - old_gcs_ptr);
+      A[oidx].NODE_POINTERS +=(nodeptrs - oidx_oldptrs);
     #else
-        A[ptr].NODE_POINTERS = nodeptrs;
+      A[cs_ptr].NODE_POINTERS = cs_node.pointers;
+      A[gcs_ptr].NODE_POINTERS = gcs_node.pointers;
+      A[ptr].NODE_POINTERS = nodeptrs;
     #endif
 }
 
@@ -1221,8 +1223,10 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
     } else {
         Node node;
         RegXS oldptrs;
+        // This OblivIndex creation is not required if !OPT_ON, but for convenience we leave it in
+        // so that fixImbalance has an oidx to be supplied when in the !OPT_ON setting.
+        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, MAX_DEPTH);
         #ifdef OPT_ON
-            typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, MAX_DEPTH);
             node = A[oidx];
             oldptrs = node.pointers;
         #else
@@ -1585,6 +1589,7 @@ void avl_tests(MPCIO &mpcio,
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
               tree.check_avl(tio, yield);
+              tree.pretty_print(tio, yield);
             }
             Duoram<Node>* oram = tree.get_oram();
             RegXS root_xs = tree.get_root();
@@ -1619,6 +1624,7 @@ void avl_tests(MPCIO &mpcio,
                     print_red("T1 : FAIL\n");
                 }
             }
+            tree.pretty_print(tio, yield);
             A.init();
             tree.init();
         }

+ 10 - 4
avl.hpp

@@ -20,7 +20,7 @@
 #define KCYN  "\x1B[36m"
 #define KWHT  "\x1B[37m"
 
-#define OPT_ON 1
+#define OPT_ON 0
 
 /*
   For AVL tree we'll treat the pointers fields as:
@@ -34,9 +34,15 @@
 #define AVL_PTR_SIZE 31
 
 inline int AVL_TTL(size_t n) {
-    double logn = log2(n);
-    double TTL = 1.44 * logn;
-    return (int(ceil(TTL)));
+    if(n==0) {
+        return 0;
+    } else if (n==1) {
+        return 1;
+    } else {
+        double logn = log2(n);
+        double TTL = 1.44 * logn;
+        return (int(ceil(TTL)));
+    }
 }
 
 inline RegXS getAVLLeftPtr(RegXS pointer){