Browse Source

Some more fixes from the code review

sshsshy 1 year ago
parent
commit
5e8c06df4e
2 changed files with 71 additions and 59 deletions
  1. 70 58
      avl.cpp
  2. 1 1
      avl.hpp

+ 70 - 58
avl.cpp

@@ -161,7 +161,9 @@ std::tuple<bool, bool, bool, address_t> AVL::check_avl(const std::vector<Node> &
         avlok && leftavlok && rightavlok, bb_ok && leftbbok && rightbbok, height};
 }
 
-void AVL::check_avl(MPCTIO &tio, yield_t &yield) {
+// Note only P0 gets the correct result of check_AVL.
+// That's fine since P0 outputs all the correctness outputs for the test suite.
+bool AVL::check_avl(MPCTIO &tio, yield_t &yield) {
     auto A = oram.flat(tio, yield);
     auto R = A.reconstruct();
 
@@ -174,9 +176,13 @@ void AVL::check_avl(MPCTIO &tio, yield_t &yield) {
         rec_root+= peer_root;
     }
     if (tio.player() == 0) {
-      auto [ bst_ok, avl_ok, bb_ok, height ] = check_avl(R, rec_root.xshare);
-      printf("BST structure %s\nAVL structure %s\nBalance Bits %s\nTree height = %u\n",
-          bst_ok ? "ok" : "NOT OK", avl_ok ? "ok" : "NOT OK", bb_ok? "ok" : "NOT OK", height);
+        auto [ bst_ok, avl_ok, bb_ok, height ] = check_avl(R, rec_root.xshare);
+        printf("BST structure %s\nAVL structure %s\nBalance Bits %s\nTree height = %u\n",
+            bst_ok ? "ok" : "NOT OK", avl_ok ? "ok" : "NOT OK", bb_ok? "ok" : "NOT OK", height);
+        return (bst_ok && avl_ok && bb_ok);
+    }
+    else {
+        return 0;
     }
 }
 
@@ -932,7 +938,7 @@ bool AVL::lookup(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS key, Duoram<Node>
     // then we found the node to return
 
     // If multiple keys in the tree match the lookup key, this returns the last match.
-    // Extracting the first match would add an extra round here, since the 
+    // Extracting the first match would add an extra round here, since the
     // F_found flag will have to be computed first, then the next two based on F_found
     // instead of eq
     run_coroutines(tio,
@@ -1502,7 +1508,7 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
 
         // If we didn't find the key, we can end here.
         if(!key_found) {
-          return {0, s0};
+          return {false, s0};
         }
 
         updateChildPointers(tio, yield, left, right, c_prime, ret_struct);
@@ -1561,7 +1567,7 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
 */
 bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
     if(num_items==0)
-        return 0;
+        return false;
 
     auto A = oram.flat(tio, yield, 0, cur_max_index+1);
     if(num_items==1) {
@@ -1578,9 +1584,9 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
             empty_locations.emplace_back(root);
             A[oidx] = zero;
             num_items--;
-            return 1;
+            return true;
         } else {
-            return 0;
+            return false;
         }
     } else {
         int TTL = AVL_TTL(num_items);
@@ -1591,7 +1597,7 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
         auto [success, bal_upd] = del(tio, yield, root, del_key, A, found, find_successor, TTL, ret_struct);
         //printf ("Success =  %d\n", success);
         if(!success){
-            return 0;
+            return false;
         }
         else{
             num_items--;
@@ -1607,10 +1613,10 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
 
             Node del_node, suc_node;
             nbits_t width = ceil(log2(cur_max_index+1));
-            typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_nd(tio, yield, ret_struct.N_d, width);
-            typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_ns(tio, yield, ret_struct.N_s, width);
-            std::vector<coro_t> coroutines;
             #ifdef AVL_OPT_ON
+                typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_nd(tio, yield, ret_struct.N_d, width);
+                typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_ns(tio, yield, ret_struct.N_s, width);
+                std::vector<coro_t> coroutines;
                 coroutines.emplace_back(
                     [&tio, &A, &oidx_nd, &del_node](yield_t &yield) {
                       auto acont = A.context(yield);
@@ -1636,9 +1642,9 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
             printf("rec_F_ss = %d, del_node.key = %lu, suc_nod.key = %lu\n",
                 rec_F_ss, rec_del_key, rec_suc_key);
             */
-            RegXS old_del_value;
-            RegAS old_del_key;
             #ifdef AVL_OPT_ON
+                RegXS old_del_value;
+                RegAS old_del_key;
                 old_del_value = del_node.value;
                 old_del_key = del_node.key;
             #endif
@@ -1685,7 +1691,7 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
             empty_locations.emplace_back(empty_loc);
         }
 
-      return 1;
+      return true;
     }
 }
 
@@ -1846,7 +1852,7 @@ void avl_tests(MPCIO &mpcio,
             - 5 and 9 have no children and 0 balances
         */
         {
-            bool success = 1;
+            bool success = 1, check_avl;
             int insert_array[] = {5, 7, 9};
             size_t insert_array_size = 2;
             Node node;
@@ -1855,7 +1861,7 @@ void avl_tests(MPCIO &mpcio,
               randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
-              tree.check_avl(tio, yield);
+              check_avl = tree.check_avl(tio, yield);
             }
             Duoram<Node>* oram = tree.get_oram();
             RegXS root_xs = tree.get_root();
@@ -1884,22 +1890,13 @@ void avl_tests(MPCIO &mpcio,
                 success = false;
             }
             if(player0) {
+                success &= check_avl;
                 if(success) {
                     print_green("T1 : SUCCESS\n");
                 } else {
                     print_red("T1 : FAIL\n");
                 }
             }
-            /*
-            //MY_tests:
-            // OblivIndex read on the ORAM:
-            RegXS mptr;
-            mptr.set(tio.player() * 1);
-            nbits_t width_bits = 2;
-            typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, mptr, width_bits);
-            size_t rec_key = reconstruct_RegAS(tio, yield, A[oidx].NODE_KEY);
-            printf("RI: Retrieved key (3) = %ld\n", rec_key);
-            */
             A.init();
             tree.init();
         }
@@ -1924,7 +1921,7 @@ void avl_tests(MPCIO &mpcio,
 
         */
         {
-            bool success = 1;
+            bool success = 1, check_avl;
             int insert_array[] = {5, 3, 7, 9, 12};
             size_t insert_array_size = 4;
             Node node;
@@ -1932,7 +1929,7 @@ void avl_tests(MPCIO &mpcio,
               randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
-              tree.check_avl(tio, yield);
+              check_avl = tree.check_avl(tio, yield);
             }
 
             Duoram<Node>* oram = tree.get_oram();
@@ -1980,6 +1977,7 @@ void avl_tests(MPCIO &mpcio,
             }
 
             if(player0) {
+                success &= check_avl;
                 if(success) {
                     print_green("T2 : SUCCESS\n");
                 } else {
@@ -2008,7 +2006,7 @@ void avl_tests(MPCIO &mpcio,
 
         */
         {
-            bool success = 1;
+            bool success = 1, check_avl;
             int insert_array[] = {9, 7, 5};
             size_t insert_array_size = 2;
             Node node;
@@ -2016,7 +2014,7 @@ void avl_tests(MPCIO &mpcio,
               randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
-              tree.check_avl(tio, yield);
+              check_avl = tree.check_avl(tio, yield);
             }
 
             Duoram<Node>* oram = tree.get_oram();
@@ -2046,6 +2044,7 @@ void avl_tests(MPCIO &mpcio,
                 success = false;
             }
             if(player0) {
+                success &= check_avl;
                 if(success) {
                     print_green("T3 : SUCCESS\n");
                 } else{
@@ -2078,7 +2077,7 @@ void avl_tests(MPCIO &mpcio,
 
         */
         {
-            bool success = 1;
+            bool success = 1, check_avl;
             int insert_array[] = {9, 12, 7, 5, 3};
             size_t insert_array_size = 4;
             Node node;
@@ -2086,7 +2085,7 @@ void avl_tests(MPCIO &mpcio,
               randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
-              tree.check_avl(tio, yield);
+              check_avl = tree.check_avl(tio, yield);
             }
 
             Duoram<Node>* oram = tree.get_oram();
@@ -2133,6 +2132,7 @@ void avl_tests(MPCIO &mpcio,
                 success = false;
             }
             if(player0) {
+                success &= check_avl;
                 if(success) {
                     print_green("T4 : SUCCESS\n");
                 } else {
@@ -2162,7 +2162,7 @@ void avl_tests(MPCIO &mpcio,
         */
 
         {
-            bool success = 1;
+            bool success = 1, check_avl;
             int insert_array[] = {9, 5, 7};
             size_t insert_array_size = 2;
             Node node;
@@ -2170,7 +2170,7 @@ void avl_tests(MPCIO &mpcio,
               randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
-              tree.check_avl(tio, yield);
+              check_avl = tree.check_avl(tio, yield);
             }
 
             Duoram<Node>* oram = tree.get_oram();
@@ -2209,6 +2209,7 @@ void avl_tests(MPCIO &mpcio,
                 success = false;
             }
             if(player0) {
+                success &= check_avl;
                 if(success) {
                     print_green("T5 : SUCCESS\n");
                 } else {
@@ -2242,7 +2243,7 @@ void avl_tests(MPCIO &mpcio,
 
         */
         {
-            bool success = 1;
+            bool success = 1, check_avl;
             int insert_array[] = {9, 12, 7, 3, 5};
             size_t insert_array_size = 4;
             Node node;
@@ -2250,7 +2251,7 @@ void avl_tests(MPCIO &mpcio,
               randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
-              tree.check_avl(tio, yield);
+              check_avl = tree.check_avl(tio, yield);
             }
 
             Duoram<Node>* oram = tree.get_oram();
@@ -2297,6 +2298,7 @@ void avl_tests(MPCIO &mpcio,
                 success = false;
             }
             if(player0) {
+                success &= check_avl;
                 if(success) {
                     print_green("T6 : SUCCESS\n");
                 } else {
@@ -2404,7 +2406,7 @@ void avl_tests(MPCIO &mpcio,
 
         */
         {
-            bool success = 1;
+            bool success = 1, check_avl;
             int insert_array[] = {5, 3, 12, 7, 9};
             size_t insert_array_size = 4;
             Node node;
@@ -2412,7 +2414,7 @@ void avl_tests(MPCIO &mpcio,
               randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
-              tree.check_avl(tio, yield);
+              check_avl = tree.check_avl(tio, yield);
             }
 
             Duoram<Node>* oram = tree.get_oram();
@@ -2459,6 +2461,7 @@ void avl_tests(MPCIO &mpcio,
                 success = false;
             }
             if(player0) {
+                success &= check_avl;
                 if(success) {
                     print_green("T8 : SUCCESS\n");
                 } else {
@@ -2488,7 +2491,7 @@ void avl_tests(MPCIO &mpcio,
             - 7 has 0 balances
         */
         {
-            bool success = 1;
+            bool success = 1, check_avl;
             int insert_array[] = {5, 3, 7, 9};
             size_t insert_array_size = 3;
             Node node;
@@ -2496,7 +2499,7 @@ void avl_tests(MPCIO &mpcio,
               randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
-              tree.check_avl(tio, yield);
+              check_avl = tree.check_avl(tio, yield);
             }
 
             RegAS del_key;
@@ -2535,6 +2538,7 @@ void avl_tests(MPCIO &mpcio,
             success &= del_ret;
 
             if(player0) {
+                success &= check_avl;
                 if(success) {
                     print_green("T9 : SUCCESS\n");
                 } else {
@@ -2566,7 +2570,7 @@ void avl_tests(MPCIO &mpcio,
 
         */
         {
-            bool success = 1;
+            bool success = 1, check_avl;
             int insert_array[] = {5, 3, 7, 9, 6, 1, 12};
             size_t insert_array_size = 6;
             Node node;
@@ -2574,7 +2578,7 @@ void avl_tests(MPCIO &mpcio,
               randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
-              tree.check_avl(tio, yield);
+              check_avl = tree.check_avl(tio, yield);
             }
 
             RegAS del_key;
@@ -2633,6 +2637,7 @@ void avl_tests(MPCIO &mpcio,
             success &= del_ret;
 
             if(player0) {
+                success &= check_avl;
                 if(success) {
                     print_green("T10 : SUCCESS\n");
                 } else {
@@ -2660,7 +2665,7 @@ void avl_tests(MPCIO &mpcio,
 
         */
         {
-            bool success = 1;
+            bool success = 1, check_avl;
             int insert_array[] = {9, 7, 12, 5};
             size_t insert_array_size = 3;
             Node node;
@@ -2668,7 +2673,7 @@ void avl_tests(MPCIO &mpcio,
               randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
-              tree.check_avl(tio, yield);
+              check_avl = tree.check_avl(tio, yield);
             }
 
             RegAS del_key;
@@ -2710,6 +2715,7 @@ void avl_tests(MPCIO &mpcio,
             success &= del_ret;
 
             if(player0) {
+                success &= check_avl;
                 if(success) {
                     print_green("T11 : SUCCESS\n");
                 } else{
@@ -2742,7 +2748,7 @@ void avl_tests(MPCIO &mpcio,
             - 12 bal = 0 1
         */
         {
-            bool success = 1;
+            bool success = 1, check_avl;
             int insert_array[] = {9, 12, 7, 5, 8, 15, 3};
             size_t insert_array_size = 6;
             Node node;
@@ -2750,7 +2756,7 @@ void avl_tests(MPCIO &mpcio,
               randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
-              tree.check_avl(tio, yield);
+              check_avl = tree.check_avl(tio, yield);
             }
 
             RegAS del_key;
@@ -2808,6 +2814,7 @@ void avl_tests(MPCIO &mpcio,
             success &= del_ret;
 
             if(player0) {
+                success &= check_avl;
                 if(success) {
                     print_green("T12 : SUCCESS\n");
                 } else {
@@ -2837,7 +2844,7 @@ void avl_tests(MPCIO &mpcio,
         */
 
         {
-            bool success = 1;
+            bool success = 1, check_avl;
             int insert_array[] = {9, 5, 12, 7};
             size_t insert_array_size = 3;
             Node node;
@@ -2845,7 +2852,7 @@ void avl_tests(MPCIO &mpcio,
               randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
-              tree.check_avl(tio, yield);
+              check_avl = tree.check_avl(tio, yield);
             }
 
             RegAS del_key;
@@ -2891,6 +2898,7 @@ void avl_tests(MPCIO &mpcio,
             success &= del_ret;
 
             if(player0) {
+                success &= check_avl;
                 if(success) {
                     print_green("T13 : SUCCESS\n");
                 } else {
@@ -2922,7 +2930,7 @@ void avl_tests(MPCIO &mpcio,
 
         */
         {
-            bool success = 1;
+            bool success = 1, check_avl;
             int insert_array[] = {9, 12, 7, 3, 5};
             size_t insert_array_size = 4;
             Node node;
@@ -2930,7 +2938,7 @@ void avl_tests(MPCIO &mpcio,
               randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
-              tree.check_avl(tio, yield);
+              check_avl = tree.check_avl(tio, yield);
             }
 
             RegAS del_key;
@@ -2985,6 +2993,7 @@ void avl_tests(MPCIO &mpcio,
             success &=(!del_ret);
 
             if(player0) {
+                success &= check_avl;
                 if(success) {
                     print_green("T14 : SUCCESS\n");
                 } else {
@@ -3013,7 +3022,7 @@ void avl_tests(MPCIO &mpcio,
         */
 
         {
-            bool success = 1;
+            bool success = 1, check_avl;
             int insert_array[] = {5, 9, 3, 7};
             size_t insert_array_size = 3;
             Node node;
@@ -3021,7 +3030,7 @@ void avl_tests(MPCIO &mpcio,
               randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
-              tree.check_avl(tio, yield);
+              check_avl = tree.check_avl(tio, yield);
             }
 
             RegAS del_key;
@@ -3067,6 +3076,7 @@ void avl_tests(MPCIO &mpcio,
             success &= del_ret;
 
             if(player0) {
+                success &= check_avl;
                 if(success) {
                     print_green("T15 : SUCCESS\n");
                 } else {
@@ -3099,7 +3109,7 @@ void avl_tests(MPCIO &mpcio,
 
         */
         {
-            bool success = 1;
+            bool success = 1, check_avl;
             int insert_array[] = {5, 3, 8, 7, 1, 12, 9};
             size_t insert_array_size = 6;
             Node node;
@@ -3107,7 +3117,7 @@ void avl_tests(MPCIO &mpcio,
               randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
-              tree.check_avl(tio, yield);
+              check_avl = tree.check_avl(tio, yield);
             }
 
             RegAS del_key;
@@ -3165,6 +3175,7 @@ void avl_tests(MPCIO &mpcio,
             success &= del_ret;
 
             if(player0) {
+                success &= check_avl;
                 if(success) {
                     print_green("T16 : SUCCESS\n");
                 } else {
@@ -3208,7 +3219,7 @@ void avl_tests(MPCIO &mpcio,
             - balances and children are correct
         */
         {
-            bool success = 1;
+            bool success = 1, check_avl;
             int insert_array[] = {9, 5, 12, 7, 3, 10, 15, 2, 4, 6, 8, 20, 1};
             size_t insert_array_size = 12;
             Node node;
@@ -3216,7 +3227,7 @@ void avl_tests(MPCIO &mpcio,
               randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
-              tree.check_avl(tio, yield);
+              check_avl = tree.check_avl(tio, yield);
             }
 
             RegAS del_key;
@@ -3307,6 +3318,7 @@ void avl_tests(MPCIO &mpcio,
             }
             success &= del_ret;
             if(player0) {
+                success &= check_avl;
                 if(success) {
                     print_green("T17 : SUCCESS\n");
                 } else {

+ 1 - 1
avl.hpp

@@ -216,7 +216,7 @@ class AVL {
 
     // Display and correctness check functions
     void pretty_print(MPCTIO &tio, yield_t &yield);
-    void check_avl(MPCTIO &tio, yield_t &yield);
+    bool check_avl(MPCTIO &tio, yield_t &yield);
     void print_oram(MPCTIO &tio, yield_t &yield);
 
     // For test functions ONLY: