Przeglądaj źródła

avl_tests seperation, and using init() to refresh ORAM for AVL tests

sshsshy 1 rok temu
rodzic
commit
4a05888c8a
2 zmienionych plików z 69 dodań i 51 usunięć
  1. 66 50
      avl.cpp
  2. 3 1
      online.cpp

+ 66 - 50
avl.cpp

@@ -286,7 +286,7 @@ std::tuple<RegBS, RegBS, RegBS, RegBS> AVL::updateBalanceDel(MPCTIO &tio, yield_
     rec_bal_l = reconstruct_RegBS(tio, yield, bal_l);
     rec_bal_r = reconstruct_RegBS(tio, yield, bal_r);
     rec_bal_upd = reconstruct_RegBS(tio, yield, bal_upd);
-    printf("In updateBalanceDel, afterBalance: rec_bal_l = %d, rec_bal_r = %d, rec_bal_upd = %d\n",
+    printf("In updateBalanceDel, foundterBalance: rec_bal_l = %d, rec_bal_r = %d, rec_bal_upd = %d\n",
         rec_bal_l, rec_bal_r, rec_bal_upd);
     */
 
@@ -431,11 +431,10 @@ std::tuple<RegBS, RegBS, RegXS, RegBS> AVL::insert(MPCTIO &tio, yield_t &yield,
     mpc_select(tio, yield, ret->dir_cn, imbalance, ret->dir_cn, prev_dir);
 
     // Store new_bal_l and new_bal_r for this node
-    // but this can be handled in the rotation component in one shot,
-    // since insertion rotations always resolve with p,c having 0,0 balance
-
     setLeftBal(cnode.pointers, new_bal_l);
     setRightBal(cnode.pointers, new_bal_r);
+    // We have to write the node pointers anyway to resolve balance updates
+    // We still defer the actual insertion to post recursion.
     A[ptr].NODE_POINTERS = cnode.pointers;
 
     // s0 = shares of 0
@@ -693,14 +692,13 @@ bool AVL::lookup(MPCTIO &tio, yield_t &yield, RegAS key, Node *ret_node) {
 
 
 std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
-      Duoram<Node>::Flat &A, RegBS af, RegBS fs, int TTL,
+      Duoram<Node>::Flat &A, RegBS found, RegBS find_successor, int TTL,
       avl_del_return &ret_struct) {
     bool player0 = tio.player()==0;
     if(TTL==0) {
-        //Reconstruct and return af
-        bool success = reconstruct_RegBS(tio, yield, af);
+        //Reconstruct and return found
+        bool success = reconstruct_RegBS(tio, yield, found);
         RegBS zero;
-        //printf("Reconstructed flag = %d\n", success);
         if(player0)
           ret_struct.F_r^=1;
         return {success, zero};
@@ -766,15 +764,15 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
         printf("c_prime = %d, F_c2 = %d, s1 = %d\n", c_prime_rec, F_c2_rec, s1_rec);
         */
 
-        // Case 3: finding successor (fs) and node has both children (F_2)
+        // Case 3: finding successor (find_successor) and node has both children (F_2)
         // Go left.
-        mpc_and(tio, yield, F_c3, fs, F_2);
+        mpc_and(tio, yield, F_c3, find_successor, F_2);
         mpc_select(tio, yield, c_prime, F_c3, c_prime, s0);
 
-        // Case 4: finding successor (fs) and node has no more left children (l0)
+        // Case 4: finding successor (find_successor) and node has no more left children (l0)
         // This is the successor node then.
         // Go right (since no more left)
-        mpc_and(tio, yield, F_c4, fs, l0);
+        mpc_and(tio, yield, F_c4, find_successor, l0);
         mpc_select(tio, yield, c_prime, F_c4, c_prime, l0);
 
         // Set next_ptr
@@ -782,17 +780,17 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
         // cs_ptr: child's sibling pointer
         mpc_select(tio, yield, cs_ptr, c_prime, right, left, AVL_PTR_SIZE);
 
-        RegBS af_prime, fs_prime;
-        mpc_or(tio, yield, af_prime, af, lf);
+        RegBS found_prime, find_successor_prime;
+        mpc_or(tio, yield, found_prime, found, lf);
 
-        // If in Case 2, set fs. We are now finding successor
-        mpc_or(tio, yield, fs_prime, fs, F_c2);
+        // If in Case 2, set find_successor. We are now finding successor
+        mpc_or(tio, yield, find_successor_prime, find_successor, F_c2);
 
-        // If in Case 4. Successor found here already. Toggle fs off
-        fs_prime=fs_prime^F_c4;
+        // If in Case 4. Successor found here already. Toggle find_successor off
+        find_successor_prime=find_successor_prime^F_c4;
 
         TTL-=1;
-        auto [key_found, bal_upd] = del(tio, yield, next_ptr, del_key, A, af_prime, fs_prime, TTL, ret_struct);
+        auto [key_found, bal_upd] = del(tio, yield, next_ptr, del_key, A, found_prime, find_successor_prime, TTL, ret_struct);
 
         // If we didn't find the key, we can end here.
         if(!key_found) {
@@ -807,7 +805,7 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
             ii) F_ns: Found the successor (no more left children while
                 traversing to find successor)
            In cases i and ii we skip the next node, and make the current node
-           point to the node after the next node.
+           point to the node foundter the next node.
 
            iii) We did rotation(s) at the lower level, changing the child in
                 that position. So we update it to the correct node in that
@@ -855,7 +853,7 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
         size_t rec_p_left_0, rec_p_right_0;
         rec_p_left_0 = reconstruct_RegXS(tio, yield, getAVLLeftPtr(node.pointers));
         rec_p_right_0 = reconstruct_RegXS(tio, yield, getAVLRightPtr(node.pointers));
-        printf("parent_ptrs (after read): left = %lu, right = %lu\n", rec_p_left_0, rec_p_right_0);
+        printf("parent_ptrs (foundter read): left = %lu, right = %lu\n", rec_p_left_0, rec_p_right_0);
         printf("F_c1 = %d, F_c2 = %d, F_c3 = %d, F_c4 = %d\n", rec_F_c1, rec_F_c2, rec_F_c3, rec_F_c4);
         printf("bal_upd = %d, new_bal_upd = %d, imb= %d\n", rec_bal_upd, rec_new_bal_upd, rec_imb);
         */
@@ -953,7 +951,7 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
         rec_p_left_1 = reconstruct_RegXS(tio, yield, getAVLLeftPtr(node.pointers));
         rec_p_right_1 = reconstruct_RegXS(tio, yield, getAVLRightPtr(node.pointers));
         printf("flag_imb = %d, flag_dr = %d\n", rec_flag_imb, rec_flag_dr);
-        printf("parent_ptrs (after rotations): left = %lu, right = %lu\n", rec_p_left_1, rec_p_right_1);
+        printf("parent_ptrs (foundter rotations): left = %lu, right = %lu\n", rec_p_left_1, rec_p_right_1);
         */
 
         // If imb (we do some rotation), then update F_r, and ret_ptr, to
@@ -1076,8 +1074,8 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
         RegBS F_dh, F_sf, F_rs;
         mpc_or(tio, yield, ret_struct.F_ss, ret_struct.F_ss, F_c2);
         if(player0)
-            af^=1;
-        mpc_and(tio, yield, F_dh, lf, af);
+            found^=1;
+        mpc_and(tio, yield, F_dh, lf, found);
         mpc_select(tio, yield, ret_struct.N_d, F_dh, ret_struct.N_d, ptr);
         // F_sf = Successor found = F_c4 = Finding successor & no more left child
         F_sf = F_c4;
@@ -1117,7 +1115,7 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
         /*
         rec_F_rs = reconstruct_RegBS(tio, yield, F_rs);
         bool rec_bal_upd_set = reconstruct_RegBS(tio, yield, bal_upd);
-        printf("after bal_upd select from rec_F_rs = %d, rec_bal_upd = %d\n",
+        printf("foundter bal_upd select from rec_F_rs = %d, rec_bal_upd = %d\n",
             rec_F_rs, rec_bal_upd_set);
         */ 
 
@@ -1141,39 +1139,39 @@ bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
         return 1;
     } else {
         int TTL = AVL_TTL(num_items);
-        // Flags for already found (af) item to delete and find successor (fs)
+        // Flags for already found (found) item to delete and find successor (find_successor)
         // if this deletion requires a successor swap
-        RegBS af, fs;
+        RegBS found, find_successor;
         avl_del_return ret_struct;
-        auto [success, bal_upd] = del(tio, yield, root, del_key, A, af, fs, TTL, ret_struct);
+        auto [success, bal_upd] = del(tio, yield, root, del_key, A, found, find_successor, TTL, ret_struct);
         printf ("Success =  %d\n", success);
         if(!success){
             return 0;
         }
         else{
             num_items--;
-
-            
+            /*
             printf("In delete's swap portion\n");
             Node rec_del_node = A.reconstruct(A[ret_struct.N_d]);
             Node rec_suc_node = A.reconstruct(A[ret_struct.N_s]);
             printf("del_node key = %ld, suc_node key = %ld\n",
                 rec_del_node.key.ashare, rec_suc_node.key.ashare);
             printf("flag_s = %d\n", ret_struct.F_ss.bshare);
-            
+            */
             Node del_node = A[ret_struct.N_d];
             Node suc_node = A[ret_struct.N_s];
             RegAS zero_as; RegXS zero_xs;
             // Update root if needed
             mpc_select(tio, yield, root, ret_struct.F_r, root, ret_struct.ret_ptr);
 
-            
+            /*
             bool rec_F_ss = reconstruct_RegBS(tio, yield, ret_struct.F_ss);
             size_t rec_del_key = reconstruct_RegAS(tio, yield, del_node.key);
             size_t rec_suc_key = reconstruct_RegAS(tio, yield, suc_node.key);
             printf("rec_F_ss = %d, del_node.key = %lu, suc_nod.key = %lu\n",
                 rec_F_ss, rec_del_key, rec_suc_key);
-            
+            */            
+
             mpc_select(tio, yield, del_node.key, ret_struct.F_ss, del_node.key, suc_node.key);
             mpc_select(tio, yield, del_node.value, ret_struct.F_ss, del_node.value, suc_node.value);
             A[ret_struct.N_d].NODE_KEY = del_node.key;
@@ -1291,6 +1289,7 @@ void avl_tests(MPCIO &mpcio,
     run_coroutines(tio, [&tio, depth, items] (yield_t &yield) {
         size_t size = size_t(1)<<depth;
         bool player0 = tio.player()==0;
+        AVL tree(tio.player(), size);
 
         // (T1) : Test 1 : L rotation (root modified)
         /*
@@ -1308,7 +1307,6 @@ void avl_tests(MPCIO &mpcio,
             - 5 and 9 have no children and 0 balances
         */
         {
-            AVL tree(tio.player(), size);
             bool success = 1;
             int insert_array[] = {5, 7, 9};
             size_t insert_array_size = 2;
@@ -1353,6 +1351,8 @@ void avl_tests(MPCIO &mpcio,
                     print_red("T1 : FAIL\n");
                 }
             }
+            auto A = oram->flat(tio, yield);
+            A.init();
         }
 
         // (T2) : Test 2 : L rotation (root unmodified)
@@ -1375,7 +1375,6 @@ void avl_tests(MPCIO &mpcio,
 
         */
         {
-            AVL tree(tio.player(), size);
             bool success = 1;
             int insert_array[] = {5, 3, 7, 9, 12};
             size_t insert_array_size = 4;
@@ -1438,6 +1437,8 @@ void avl_tests(MPCIO &mpcio,
                     print_red("T2 : FAIL\n");
                 }
             }
+            auto A = oram->flat(tio, yield);
+            A.init();
         }
 
 
@@ -1458,7 +1459,6 @@ void avl_tests(MPCIO &mpcio,
 
         */
         {
-            AVL tree(tio.player(), size);
             bool success = 1;
             int insert_array[] = {9, 7, 5};
             size_t insert_array_size = 2;
@@ -1503,6 +1503,8 @@ void avl_tests(MPCIO &mpcio,
                     print_red("T3 : FAIL\n");
                 }
             }
+            auto A = oram->flat(tio, yield);
+            A.init();
         }
 
 
@@ -1527,7 +1529,6 @@ void avl_tests(MPCIO &mpcio,
 
         */
         {
-            AVL tree(tio.player(), size);
             bool success = 1;
             int insert_array[] = {9, 12, 7, 5, 3};
             size_t insert_array_size = 4;
@@ -1589,6 +1590,8 @@ void avl_tests(MPCIO &mpcio,
                     print_red("T4 : FAIL\n");
                 }
             }
+            auto A = oram->flat(tio, yield);
+            A.init();
         }
 
 
@@ -1610,7 +1613,6 @@ void avl_tests(MPCIO &mpcio,
         */
 
         {
-            AVL tree(tio.player(), size);
             bool success = 1;
             int insert_array[] = {9, 5, 7};
             size_t insert_array_size = 2;
@@ -1663,6 +1665,8 @@ void avl_tests(MPCIO &mpcio,
                     print_red("T5 : FAIL\n");
                 }
             }
+            auto A = oram->flat(tio, yield);
+            A.init();
         }
 
 
@@ -1688,7 +1692,6 @@ void avl_tests(MPCIO &mpcio,
 
         */
         {
-            AVL tree(tio.player(), size);
             bool success = 1;
             int insert_array[] = {9, 12, 7, 3, 5};
             size_t insert_array_size = 4;
@@ -1750,6 +1753,8 @@ void avl_tests(MPCIO &mpcio,
                     print_red("T6 : FAIL\n");
                 }
             }
+            auto A = oram->flat(tio, yield);
+            A.init();
         }
 
 
@@ -1771,7 +1776,6 @@ void avl_tests(MPCIO &mpcio,
         */
 
         {
-            AVL tree(tio.player(), size);
             bool success = 1;
             int insert_array[] = {5, 9, 7};
             size_t insert_array_size = 2;
@@ -1824,6 +1828,8 @@ void avl_tests(MPCIO &mpcio,
                     print_red("T7 : FAIL\n");
                 }
             }
+            auto A = oram->flat(tio, yield);
+            A.init();
         }
 
         // (T8) : Test 8 : RL rotation (root unmodified)
@@ -1848,7 +1854,6 @@ void avl_tests(MPCIO &mpcio,
 
         */
         {
-            AVL tree(tio.player(), size);
             bool success = 1;
             int insert_array[] = {5, 3, 12, 7, 9};
             size_t insert_array_size = 4;
@@ -1910,6 +1915,8 @@ void avl_tests(MPCIO &mpcio,
                     print_red("T8 : FAIL\n");
                 }
             }
+            auto A = oram->flat(tio, yield);
+            A.init();
         }
 
         // Deletion Tests:
@@ -1931,7 +1938,6 @@ void avl_tests(MPCIO &mpcio,
             - 7 has 0 balances
         */
         {
-            AVL tree(tio.player(), size);
             bool success = 1;
             int insert_array[] = {5, 3, 7, 9};
             size_t insert_array_size = 3;
@@ -1981,6 +1987,8 @@ void avl_tests(MPCIO &mpcio,
                     print_red("T9 : FAIL\n");
                 }
             }
+            auto A = oram->flat(tio, yield);
+            A.init();
         }
 
         // (T10) : Test 10 : L rotation (root unmodified)
@@ -2003,7 +2011,6 @@ void avl_tests(MPCIO &mpcio,
 
         */
         {
-            AVL tree(tio.player(), size);
             bool success = 1;
             int insert_array[] = {5, 3, 7, 9, 6, 1, 12};
             size_t insert_array_size = 6;
@@ -2075,6 +2082,8 @@ void avl_tests(MPCIO &mpcio,
                     print_red("T10 : FAIL\n");
                 }
             }
+            auto A = oram->flat(tio, yield);
+            A.init();
         }
 
         // (T11) : Test 11 : R rotation (root modified)
@@ -2094,7 +2103,6 @@ void avl_tests(MPCIO &mpcio,
 
         */
         {
-            AVL tree(tio.player(), size);
             bool success = 1;
             int insert_array[] = {9, 7, 12, 5};
             size_t insert_array_size = 3;
@@ -2148,6 +2156,8 @@ void avl_tests(MPCIO &mpcio,
                     print_red("T11 : FAIL\n");
                 }
             }
+            auto A = oram->flat(tio, yield);
+            A.init();
         }
 
 
@@ -2172,7 +2182,6 @@ void avl_tests(MPCIO &mpcio,
             - 12 bal = 0 1
         */
         {
-            AVL tree(tio.player(), size);
             bool success = 1;
             int insert_array[] = {9, 12, 7, 5, 8, 15, 3};
             size_t insert_array_size = 6;
@@ -2242,6 +2251,8 @@ void avl_tests(MPCIO &mpcio,
                     print_red("T12 : FAIL\n");
                 }
             }
+            auto A = oram->flat(tio, yield);
+            A.init();
         }
 
 
@@ -2263,7 +2274,6 @@ void avl_tests(MPCIO &mpcio,
         */
 
         {
-            AVL tree(tio.player(), size);
             bool success = 1;
             int insert_array[] = {9, 5, 12, 7};
             size_t insert_array_size = 3;
@@ -2321,6 +2331,8 @@ void avl_tests(MPCIO &mpcio,
                     print_red("T13 : FAIL\n");
                 }
             }
+            auto A = oram->flat(tio, yield);
+            A.init();
         }
 
 
@@ -2346,7 +2358,6 @@ void avl_tests(MPCIO &mpcio,
 
         */
         {
-            AVL tree(tio.player(), size);
             bool success = 1;
             int insert_array[] = {9, 12, 7, 3, 5};
             size_t insert_array_size = 4;
@@ -2413,6 +2424,8 @@ void avl_tests(MPCIO &mpcio,
                     print_red("T14 : FAIL\n");
                 }
             }
+            auto A = oram->flat(tio, yield);
+            A.init();
         }
 
         // (T15) : Test 15 : RL rotation (root modified)
@@ -2433,7 +2446,6 @@ void avl_tests(MPCIO &mpcio,
         */
 
         {
-            AVL tree(tio.player(), size);
             bool success = 1;
             int insert_array[] = {5, 9, 3, 7};
             size_t insert_array_size = 3;
@@ -2491,6 +2503,8 @@ void avl_tests(MPCIO &mpcio,
                     print_red("T15 : FAIL\n");
                 }
             }
+            auto A = oram->flat(tio, yield);
+            A.init();
         }
 
         // (T16) : Test 16 : RL rotation (root unmodified)
@@ -2515,7 +2529,6 @@ void avl_tests(MPCIO &mpcio,
 
         */
         {
-            AVL tree(tio.player(), size);
             bool success = 1;
             int insert_array[] = {5, 3, 12, 7, 1, 9};
             size_t insert_array_size = 5;
@@ -2582,6 +2595,8 @@ void avl_tests(MPCIO &mpcio,
                     print_red("T16 : FAIL\n");
                 }
             }
+            auto A = oram->flat(tio, yield);
+            A.init();
         }
 
 
@@ -2617,7 +2632,6 @@ void avl_tests(MPCIO &mpcio,
             - balances and children are correct
         */
         {
-            AVL tree(tio.player(), size);
             bool success = 1;
             int insert_array[] = {9, 5, 12, 7, 3, 10, 15, 2, 4, 6, 8, 20, 1};
             size_t insert_array_size = 12;
@@ -2721,6 +2735,8 @@ void avl_tests(MPCIO &mpcio,
                     print_red("T17 : FAIL\n");
                 }
             }
+            auto A = oram->flat(tio, yield);
+            A.init();
         }
     });
 }

+ 3 - 1
online.cpp

@@ -1623,7 +1623,9 @@ void online_main(MPCIO &mpcio, const PRACOptions &opts, char **args)
         bst(mpcio, opts, args);
     } else if (!strcmp(*args, "avl")) {
         ++args;
-        //avl(mpcio, opts, args);
+        avl(mpcio, opts, args);
+    } else if (!strcmp(*args, "avl_tests")) {
+        ++args;
         avl_tests(mpcio, opts, args);
     } else {
         std::cerr << "Unknown mode " << *args << "\n";