Browse Source

Addressed IG's code review on bst

sshsshy 1 year ago
parent
commit
868fec1544
1 changed files with 73 additions and 34 deletions
  1. 73 34
      bst.cpp

+ 73 - 34
bst.cpp

@@ -15,7 +15,7 @@
     }
 #endif
 
-// Helper functions to reconstruct shared RegBS, RegAS or RegXS
+// Helper function to reconstruct shared RegBS
 bool reconstruct_RegBS(MPCTIO &tio, yield_t &yield, RegBS flag) {
     RegBS reconstructed_flag;
     if (tio.player() < 2) {
@@ -37,8 +37,9 @@ bool reconstruct_RegBS(MPCTIO &tio, yield_t &yield, RegBS flag) {
     return reconstructed_flag.bshare;
 }
 
-/* A function to assign a new random 8-bit key to a node, and resets its 
-   pointers to zeroes. The node is assigned a new random 64-bit value.
+/* 
+    A function to assign a new random 8-bit key to a node, and resets its 
+    pointers to zeroes. The node is assigned a new random 64-bit value.
 */
 static void randomize_node(Node &a) {
     a.key.randomize(8);
@@ -64,7 +65,7 @@ std::tuple<RegBS, RegBS> compare_keys(MPCTIO tio, yield_t &yield, RegAS k1,
 // Assuming pointer of 64 bits is split as:
 // - 32 bits Left ptr (L)
 // - 32 bits Right ptr (R)
-// The pointers are stored as: L | R 
+// The pointers are stored as: (L << 32) | R 
 
 inline RegXS extractLeftPtr(RegXS pointer){ 
     return ((pointer&(0xFFFFFFFF00000000))>>32); 
@@ -84,6 +85,7 @@ inline void setRightPtr(RegXS &pointer, RegXS new_ptr){
     pointer+=(new_ptr);
 }
 
+
 // Pretty-print a reconstructed BST, rooted at node. is_left_child and
 // is_right_child indicate whether node is a left or right child of its
 // parent.  They cannot both be true, but the root of the tree has both
@@ -137,7 +139,9 @@ void BST::pretty_print(MPCTIO &tio, yield_t &yield) {
     } else {
         RegXS peer_root;
         yield();
-        tio.recv_peer(&peer_root, sizeof(peer_root));
+        if(tio.player()==0) {
+           tio.recv_peer(&peer_root, sizeof(peer_root));
+        }
         reconstructed_root += peer_root;
     }
 
@@ -185,9 +189,13 @@ void BST::check_bst(MPCTIO &tio, yield_t &yield) {
     RegXS rec_root = this->root;
     if (tio.player() == 1) {
         tio.queue_peer(&(this->root), sizeof(this->root));
+        yield();
     } else {
         RegXS peer_root;
-        tio.recv_peer(&peer_root, sizeof(peer_root));
+        yield();
+        if(tio.player()==0) {
+            tio.recv_peer(&peer_root, sizeof(peer_root));
+        }
         rec_root+= peer_root;
     }
     if (tio.player() == 0) {
@@ -233,7 +241,7 @@ std::tuple<RegXS, RegBS> BST::insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
     RegBS F_z = dpf.is_zero(tio, yield, next_ptr, aes_ops);
     RegBS F_i;
 
-    // F_i: If this was last node on path (F_z) ^ isNotDummy: 
+    // F_i: If this was last node on path (F_z) && isNotDummy: 
     //          insert new_node here.
     mpc_and(tio, yield, F_i, (isNotDummy), F_z);
      
@@ -266,10 +274,9 @@ void BST::insert(MPCTIO &tio, yield_t &yield, const Node &node, Duoram<Node>::Fl
     bool player0 = tio.player()==0;
     // If there are no items in tree. Make this new item the root.
     if (num_items==0) {
-        Node zero;
         A[1] = node;
         // Set root to a secret sharing of the constant value 1
-        (root).set(1*tio.player());
+        root.set(1*tio.player());
         num_items++;
         //printf("num_items == %ld!\n", num_items);
         return;
@@ -422,6 +429,22 @@ RegBS BST::lookup(MPCTIO &tio, yield_t &yield, RegAS key, Node *ret_node) {
     return found;
 }
 
+
+/*
+    The recursive del() call, invoked by the wrapper del() function.
+    
+    Takes as input the pointer to the current node in tree traversal (ptr),
+    the key to be deleted (del_key), the underlying Duoram as a 
+    flat (A), Flags af (already found) and fs (find successor), the 
+    Time-To_live TTL. Finally, a return structure ret_struct that tracks
+    the location of the successor node and the node to delete to perform
+    the actual deletion after the recursive traversal; which is required in
+    the case of a deletion that requires a successor swap (,i.e., when node
+    to delete has both children).
+
+    Returns success/fail bit.
+*/
+
 bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
      Duoram<Node>::Flat &A, RegBS af, RegBS fs, int TTL, 
     del_return &ret_struct) {
@@ -431,8 +454,9 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
         //Reconstruct and return af
         bool success = reconstruct_RegBS(tio, yield, af);
         //printf("Reconstructed flag = %d\n", success); 
-        if(player0) 
-          ret_struct.F_r^=1;
+        if(player0) { 
+            ret_struct.F_r^=1;
+        }
         return success;
     } else {
         // s1: shares of 1 bit, s0: shares of 0 bit
@@ -446,16 +470,17 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
         CDPF cdpf = tio.cdpf(yield);
         size_t &aes_ops = tio.aes_ops();
         RegBS l0, r0, lt, eq, gt;
-        // Check if left and right children are 0, and compute F_0, F_1, F_2
+        // l0: Is left child 0 
+        // r0: Is right child 0
         run_coroutines(tio,
             [&tio, &l0, left, &aes_ops, &cdpf](yield_t &yield)
             { l0 = cdpf.is_zero(tio, yield, left, aes_ops);},
             [&tio, &r0, right, &aes_ops, &cdpf](yield_t &yield)
             { r0 = cdpf.is_zero(tio, yield, right, aes_ops);},
             [&tio, &lt, &eq, &gt, del_key, node, &cdpf](yield_t &yield)
+            // Compare Key
             { auto [a, b, c] = cdpf.compare(tio, yield, del_key - node.key, tio.aes_ops());
               lt = a; eq = b; gt = c;});
-        // Compare Key
 
         /*
         // Reconstruct and Debug Block 0
@@ -469,13 +494,15 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
         printf("node.key = %ld, del_key= %ld\n", node_key_rec, del_key_rec);
         printf("cdpf.compare results: lt = %d, eq = %d, gt = %d\n", lt_rec, eq_rec, gt_rec);
         */
+
         // c is the direction bit for next_ptr 
         // (c=0: go left or c=1: go right)
         RegBS c = gt;
         // lf = local found. We found the key to delete in this level.
         RegBS lf = eq;
 
-
+        // F_{X}: Flags that indicate the number of children this node has
+        // F_0: no children, F_1: one child, F_2: both children
         RegBS F_0, F_1, F_2;
         // F_1 = l0 \xor r0
         F_1 = l0 ^ r0;
@@ -566,13 +593,15 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
         bool key_found = del(tio, yield, next_ptr, del_key, A, af_prime, fs_prime, TTL-1, ret_struct);
 
         // If we didn't find the key, we can end here.
-        if(!key_found)
-          return 0;
+        if(!key_found) {
+            return 0;
+        }
 
         //printf("TTL = %d\n", TTL); 
         RegBS F_rs_right, F_rs_left, not_c_prime=c_prime;
-        if(player0)
-            not_c_prime^=1; 
+        if(player0) {
+            not_c_prime^=1;
+        }
         // Flag here should be direction (c_prime) and F_r i.e. we need to swap return ptr in,
         // F_r needs to be returned in ret_struct
         run_coroutines(tio,
@@ -638,6 +667,12 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
     }
 }
 
+/*
+    The main del() function.
+
+    Takes as input the key to delete (del_key).
+    Returns success/fail bit.
+*/
 
 bool BST::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
     if(num_items==0)
@@ -735,9 +770,9 @@ void bst(MPCIO &mpcio,
 
         int insert_array[] = {10, 10, 13, 11, 14, 8, 15, 20, 17, 19, 7, 12};
         //int insert_array[] = {1, 2, 3, 4, 5, 6};
-        size_t insert_array_size = 11;
+        size_t insert_array_size = sizeof(insert_array)/sizeof(int);
         Node node; 
-        for(size_t i = 0; i<=insert_array_size; i++) {
+        for(size_t i = 0; i<insert_array_size; i++) {
           randomize_node(node);
           node.key.set(insert_array[i] * tio.player());
           tree.insert(tio, yield, node);
@@ -747,7 +782,6 @@ void bst(MPCIO &mpcio,
          
         RegAS del_key;
 
-        /*
         printf("\n\nDelete %x\n", 20);
         del_key.set(20 * tio.player());
         tree.del(tio, yield, del_key);
@@ -760,11 +794,13 @@ void bst(MPCIO &mpcio,
         tree.pretty_print(tio, yield);
         tree.check_bst(tio, yield);
 
+        /*
         printf("\n\nDelete %x\n", 8);
         del_key.set(8 * tio.player());
         tree.del(tio, yield, del_key);
         tree.pretty_print(tio, yield);
         tree.check_bst(tio, yield);
+        */
 
         printf("\n\nDelete %x\n", 7);
         del_key.set(7 * tio.player());
@@ -789,7 +825,6 @@ void bst(MPCIO &mpcio,
         tree.del(tio, yield, del_key);
         tree.pretty_print(tio, yield);
         tree.check_bst(tio, yield);
-        */
 
         printf("\n\nInsert %x\n", 14);
         randomize_node(node);
@@ -807,12 +842,14 @@ void bst(MPCIO &mpcio,
         found = tree.lookup(tio, yield, lookup_key, &node);
         rec_found = mpc_reconstruct(tio, yield, found);
         tree.pretty_print(tio, yield);
-        if(rec_found) {
-          printf("Lookup Success\n");
-          size_t value = mpc_reconstruct(tio, yield, node.value, 64);
-          printf("value = %lx\n", value);
-        } else {
-          printf("Lookup Failed\n");
+        if(tio.player()!=2) {
+            if(rec_found) {
+                printf("Lookup Success\n");
+                size_t value = mpc_reconstruct(tio, yield, node.value, 64);
+                printf("value = %lx\n", value);
+            } else {
+                printf("Lookup Failed\n");
+            }
         }
 
         printf("\n\nLookup %x\n", 63);
@@ -822,12 +859,14 @@ void bst(MPCIO &mpcio,
         rec_found = mpc_reconstruct(tio, yield, found);
         //rec_found = reconstruct_RegBS(tio, yield, found);
         tree.pretty_print(tio, yield);
-        if(rec_found) {
-          printf("Lookup Success\n");
-          size_t value = mpc_reconstruct(tio, yield, node.value, 64);
-          printf("value = %lx\n", value);    
-        } else {
-          printf("Lookup Failed\n");
+        if(tio.player()!=2) {
+            if(rec_found) {
+                printf("Lookup Success\n");
+                size_t value = mpc_reconstruct(tio, yield, node.value, 64);
+                printf("value = %lx\n", value);    
+            } else {
+                printf("Lookup Failed\n");
+            }
         }
 
     });