sshsshy пре 1 година
родитељ
комит
66956bf43d
2 измењених фајлова са 259 додато и 115 уклоњено
  1. 252 115
      bst.cpp
  2. 7 0
      bst.hpp

+ 252 - 115
bst.cpp

@@ -2,14 +2,70 @@
 
 #include "bst.hpp"
 
-// This file demonstrates how to implement custom ORAM wide cell types.
-// Such types can be structures of arbitrary numbers of RegAS and RegXS
-// fields.  The example here imagines a node of a binary search tree,
-// where you would want the key to be additively shared (so that you can
-// easily do comparisons), the pointers field to be XOR shared (so that
-// you can easily do bit operations to pack two pointers and maybe some
-// tree balancing information into one field) and the value doesn't
-// really matter, but XOR shared is usually slightly more efficient.
+// Helper functions to reconstruct shared RegBS, RegAS or RegXS
+
+bool reconstruct_RegBS(MPCTIO &tio, yield_t &yield, RegBS flag) {
+    RegBS reconstructed_flag;
+    if (tio.player() < 2) {
+        RegBS peer_flag;
+        tio.queue_peer(&flag, 1);
+        tio.queue_server(&flag, 1);
+        yield();
+        tio.recv_peer(&peer_flag, 1);
+        reconstructed_flag = flag;
+        reconstructed_flag ^= peer_flag;
+    } else {
+        RegBS p0_flag, p1_flag;
+        yield();
+        tio.recv_p0(&p0_flag, 1);
+        tio.recv_p1(&p1_flag, 1);
+        reconstructed_flag = p0_flag;
+        reconstructed_flag ^= p1_flag;
+    }
+    return reconstructed_flag.bshare;
+}
+
+size_t reconstruct_RegAS(MPCTIO &tio, yield_t &yield, RegAS variable) {
+    RegAS reconstructed_var;
+    if (tio.player() < 2) {
+        RegAS peer_var;
+        tio.queue_peer(&variable, sizeof(variable));
+        tio.queue_server(&variable, sizeof(variable));
+        yield();
+        tio.recv_peer(&peer_var, sizeof(variable));
+        reconstructed_var = variable;
+        reconstructed_var += peer_var;
+    } else {
+        RegAS p0_var, p1_var;
+        yield();
+        tio.recv_p0(&p0_var, sizeof(variable));
+        tio.recv_p1(&p1_var, sizeof(variable));
+        reconstructed_var = p0_var;
+        reconstructed_var += p1_var;
+    }
+    return reconstructed_var.ashare;
+}
+
+size_t reconstruct_RegXS(MPCTIO &tio, yield_t &yield, RegXS variable) {
+    RegXS reconstructed_var;
+    if (tio.player() < 2) {
+        RegXS peer_var;
+        tio.queue_peer(&variable, sizeof(variable));
+        tio.queue_server(&variable, sizeof(variable));
+        yield();
+        tio.recv_peer(&peer_var, sizeof(variable));
+        reconstructed_var = variable;
+        reconstructed_var ^= peer_var;
+    } else {
+        RegXS p0_var, p1_var;
+        yield();
+        tio.recv_p0(&p0_var, sizeof(variable));
+        tio.recv_p1(&p1_var, sizeof(variable));
+        reconstructed_var = p0_var;
+        reconstructed_var ^= p1_var;
+    }
+    return reconstructed_var.xshare;
+}
 
 std::tuple<RegBS, RegBS> compare_keys(MPCTIO tio, yield_t &yield, Node n1, Node n2) {
     CDPF cdpf = tio.cdpf(yield);
@@ -85,24 +141,15 @@ void BST::pretty_print(const std::vector<Node> &R, value_t node,
     pretty_print(R, left_ptr, leftprefix, true, false);
 }
 
-bool reconstruct_flag(MPCTIO &tio, yield_t &yield, RegBS flag) {
-    RegBS peer_flag;
-    RegBS reconstructed_flag = flag;
-    tio.queue_peer(&flag, 1);
-    yield();
-    tio.recv_peer(&peer_flag, 1);
-    reconstructed_flag ^= peer_flag;
-    //return 1;
-    /* 
-    //Opt 1:
-    if(reconstructed_flag.bshare)
-      return 1;
-    else
-      return 0;
-    */
-    //Opt 2:
-    return reconstructed_flag.bshare;
-    
+void BST::print_oram(MPCTIO &tio, yield_t &yield) {
+    auto A = oram->flat(tio, yield);
+    auto R = A.reconstruct();
+
+    for(size_t i=0;i<R.size();++i) {
+        printf("\n%04lx ", i);
+        R[i].dump();
+    }
+    printf("\n");
 }
 
 void BST::pretty_print(MPCTIO &tio, yield_t &yield) {
@@ -132,6 +179,7 @@ void BST::pretty_print(MPCTIO &tio, yield_t &yield) {
 std::tuple<bool, address_t> BST::check_bst(const std::vector<Node> &R,
     value_t node, value_t min_key = 0, value_t max_key = ~0)
 {
+    //printf("node = %ld\n", node);
     if (node == 0) {
         return { true, 0 };
     }
@@ -146,6 +194,7 @@ std::tuple<bool, address_t> BST::check_bst(const std::vector<Node> &R,
         height = rightheight;
     }
     height += 1;
+    //printf("node = %ld, leftok = %d, rightok = %d\n", node, leftok, rightok);
     return { leftok && rightok && key >= min_key && key <= max_key,
         height };
 }
@@ -154,10 +203,20 @@ void BST::check_bst(MPCTIO &tio, yield_t &yield) {
     auto A = oram->flat(tio, yield);
     auto R = A.reconstruct();
 
-    auto [ ok, height ] = check_bst(R, root.xshare);
-    printf("BST structure %s\nBST height = %u\n",
-        ok ? "ok" : "NOT OK", height);
-}  
+    RegXS rec_root = this->root;
+    if (tio.player() == 1) {
+        tio.queue_peer(&(this->root), sizeof(this->root));
+    } else {
+        RegXS peer_root;
+        tio.recv_peer(&peer_root, sizeof(peer_root));
+        rec_root+= peer_root;
+    }
+    if (tio.player() == 0) {
+      auto [ ok, height ] = check_bst(R, rec_root.xshare);
+      printf("BST structure %s\nBST height = %u\n",
+          ok ? "ok" : "NOT OK", height);
+    }
+}
 
 void newnode(Node &a) {
     a.key.randomize(8);
@@ -277,29 +336,17 @@ void BST::insert(MPCTIO &tio, yield_t &yield, Node &node) {
     */
 }
 
-/*
-// Compute in MPC a | b 
-void mpc_or(MPCTIO &tio, yield_t &yield, RegBS &result, RegBS a, RegBS b) {
-    int player0 = tio.player();
-    if(player0) {
-        a^=1;
-        b^=1;    
-    } 
-
-    mpc_and(tio, yield, result, a, b);
-    if(player0)
-        result^=1;
-}
-*/
-
 bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
      Duoram<Node>::Flat &A, RegBS af, RegBS fs, int TTL, 
     del_return &ret_struct) {
-    printf("TTL = %d\n", TTL);
+    bool player0 = tio.player()==0;
+    //printf("TTL = %d\n", TTL);
     if(TTL==0) {
         //Reconstruct and return af
-        bool success = reconstruct_flag(tio, yield, af);
-        printf("Reconstructed flag = %d\n", success);
+        bool success = reconstruct_RegBS(tio, yield, af);
+        //printf("Reconstructed flag = %d\n", success); 
+        if(player0) 
+          ret_struct.F_r^=1;
         return success;
     } else {
         bool player0 = tio.player()==0;
@@ -308,6 +355,19 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
 
         CDPF cdpf = tio.cdpf(yield);
         auto [lt, eq, gt] = cdpf.compare(tio, yield, del_key - node.key, tio.aes_ops());
+      
+        /*
+        // Reconstruct and Debug Block 0
+        bool lt_rec, eq_rec, gt_rec;
+        lt_rec = reconstruct_RegBS(tio, yield, lt);
+        eq_rec = reconstruct_RegBS(tio, yield, eq);
+        gt_rec = reconstruct_RegBS(tio, yield, gt);
+        size_t del_key_rec, node_key_rec;
+        del_key_rec = reconstruct_RegAS(tio, yield, del_key);
+        node_key_rec = reconstruct_RegAS(tio, yield, node.key);
+        printf("node.key = %ld, del_key= %ld\n", node_key_rec, del_key_rec);
+        printf("cdpf.compare results: lt = %d, eq = %d, gt = %d\n", lt_rec, eq_rec, gt_rec);
+        */
         // c is the direction bit for next_ptr 
         // (c=0: go left or c=1: go right)
         RegBS c = gt;
@@ -339,16 +399,23 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
         // edge cases where we do not go by just the comparison result
         RegXS next_ptr;
         RegBS c_prime;
-        // Case 1: found the node here (lf) or we are finding successor (fs)
-        // and there is only one child. We traverse down the lone child path.
-        RegBS F_c1a, F_c1b, F_c2, F_c3;
-        // Case 1a: lf & F_1
-        mpc_and(tio, yield, F_c1a, lf, F_1);
-        // Case 1b: fs & F_1
-        mpc_and(tio, yield, F_c1b, fs, F_1);
-        // Set c_prime for Case 1a and 1b
-        mpc_select(tio, yield, c_prime, F_c1a, c, l0);
-        mpc_select(tio, yield, c_prime, F_c1b, c, l0);
+        // Case 1: found the node here (lf): we traverse down the lone child path.
+        // or we are finding successor (fs) and there is no left child. 
+        RegBS F_c1, F_c2, F_c3, F_c4;
+        // Case 1: lf & F_1
+        mpc_and(tio, yield, F_c1, lf, F_1);
+        // Set c_prime for Case 1
+        mpc_select(tio, yield, c_prime, F_c1, c, l0);
+
+        /*
+        // Reconstruct and Debug Block 1
+        bool F_0_rec, F_1_rec, F_2_rec, c_prime_rec;
+        F_0_rec = reconstruct_RegBS(tio, yield, F_0);
+        F_1_rec = reconstruct_RegBS(tio, yield, F_1);
+        F_2_rec = reconstruct_RegBS(tio, yield, F_2);
+        c_prime_rec = reconstruct_RegBS(tio, yield, c_prime);
+        printf("F_0 = %d, F_1 = %d, F_2 = %d, c_prime = %d\n", F_0_rec, F_1_rec, F_2_rec, c_prime_rec);
+        */
 
         // s1: shares of 1 bit, s0: shares of 0 bit
         RegBS s1, s0;
@@ -357,12 +424,28 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
         // In find successor case, so find inorder successor
         // (Go right and then find leftmost child.)
         mpc_and(tio, yield, F_c2, lf, F_2);
-        mpc_select(tio, yield, c_prime, F_c2, c, s1);
+        mpc_select(tio, yield, c_prime, F_c2, c_prime, s1);
+
+        /*
+        // Reconstruct and Debug Block 2
+        bool F_c2_rec, s1_rec;
+        F_c2_rec = reconstruct_RegBS(tio, yield, F_c2);
+        s1_rec = reconstruct_RegBS(tio, yield, s1); 
+        c_prime_rec = reconstruct_RegBS(tio, yield, c_prime); 
+        printf("c_prime = %d, F_c2 = %d, s1 = %d\n", c_prime_rec, F_c2_rec, s1_rec);
+        */
+
 
         // Case 3: finding successor (fs) and node has both children (F_2)
         // Go left. 
         mpc_and(tio, yield, F_c3, fs, F_2);
-        mpc_select(tio, yield, c_prime, F_c3, c, s0);
+        mpc_select(tio, yield, c_prime, F_c3, c_prime, s0);
+
+        // Case 4: finding successor (fs) and node has no more left children (l0)
+        // This is the successor node then.
+        // Go left (to end the traversal without triggering flags on the real path to the right.
+        mpc_and(tio, yield, F_c4, fs, l0);
+        mpc_select(tio, yield, c_prime, F_c4, c_prime, l0);
 
         // Set next_ptr
         mpc_select(tio, yield, next_ptr, c_prime, left, right, 32);
@@ -370,18 +453,19 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
         RegBS af_prime, fs_prime;
         mpc_or(tio, yield, af_prime, af, lf);
 
-        //A[ptr].dump();
-        printf("af = %d, lf = %d\n", af.bshare, lf.bshare);
         // If in Case 2, set fs. We are now finding successor
-        mpc_or(tio, yield, fs_prime, fs, F_c2); 
+        mpc_or(tio, yield, fs_prime, fs, F_c2);
+
+        // If in Case 3. Successor found here already. Toggle fs off
+        fs_prime=fs_prime^F_c4;
+
         bool key_found = del(tio, yield, next_ptr, del_key, A, af_prime, fs_prime, TTL-1, ret_struct);
 
         // If we didn't find the key, we can end here.
         if(!key_found)
           return 0;
 
-        printf("TTL = %d\n", TTL);
-        /*
+        //printf("TTL = %d\n", TTL); 
         RegBS F_rs;
         // Flag here should be direction (c_prime) and F_r i.e. we need to swap return ptr in,
         // F_r needs to be returned in ret_struct
@@ -392,29 +476,42 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
         mpc_and(tio, yield, F_rs, c_prime, ret_struct.F_r);
         mpc_select(tio, yield, left, F_rs, left, ret_struct.ret_ptr); 
 
-        // Update the return structure        
-        RegBS F_nd, F_ns, F_r, F_rp, F_rp0;
+        /*
+        // Reconstruct and Debug Block 3
+        bool F_rs_rec, F_ls_rec;
+        size_t ret_ptr_rec;
+        F_rs_rec = reconstruct_RegBS(tio, yield, F_rs);
+        F_ls_rec = reconstruct_RegBS(tio, yield, F_rs);
+        ret_ptr_rec = reconstruct_RegXS(tio, yield, ret_struct.ret_ptr);
+        printf("F_rs_rec = %d, F_ls_rec = %d, ret_ptr_rec = %ld\n", F_rs_rec, F_ls_rec, ret_ptr_rec);
+        */
+        RegXS new_ptr;
+        setLeftPtr(new_ptr, left);
+        setRightPtr(new_ptr, right);
+        A[ptr].NODE_POINTERS = new_ptr;
+
+        // Update the return structure 
+        RegBS F_nd, F_ns, F_r;
         mpc_or(tio, yield, ret_struct.F_ss, ret_struct.F_ss, F_c2);
         if(player0)
-            af^=1; 
+            af^=1;
         mpc_and(tio, yield, F_nd, lf, af);
-        mpc_and(tio, yield, F_ns, fs, F_0);
+        // F_ns = fs & l0 
+        // Finding successor flag & no more left child
+        F_ns = F_c4;
         // F_r = F_d.(!F_2)
         if(player0)
             F_2^=1;
         // If we have to delete here, and it doesn't have two children we have to
         // update child pointer in parent with the returned pointer
         mpc_and(tio, yield, F_r, F_nd, F_2);
-        mpc_and(tio, yield, F_rp, F_nd, F_1);
-        mpc_and(tio, yield, F_rp0, F_nd, F_0);
+        mpc_or(tio, yield, F_r, F_r, F_ns);
+        ret_struct.F_r = F_r;
 
-        mpc_select(tio, yield, ret_struct.N_d, F_nd, ret_struct.N_s, ptr);
+        mpc_select(tio, yield, ret_struct.N_d, F_nd, ret_struct.N_d, ptr);
         mpc_select(tio, yield, ret_struct.N_s, F_ns, ret_struct.N_s, ptr);
-        
-        mpc_select(tio, yield, ret_struct.ret_ptr, F_rp, ptr, ret_struct.ret_ptr);
-        mpc_select(tio, yield, ret_struct.ret_ptr, F_rp0, ptr, s0);
-        ret_struct.F_r = F_r;
-        */
+        mpc_select(tio, yield, ret_struct.ret_ptr, F_r, ptr, ret_struct.ret_ptr);
+        //We don't empty the key and value of the node with del_key in the ORAM 
         return 1;
     }
 }
@@ -427,7 +524,8 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
         //Delete root
         auto A = oram->flat(tio, yield);
         Node zero;
-        A[0] = zero;
+        empty_locations.emplace_back(root);
+        A[root] = zero;
         num_items--;
         return 1; 
     } else {
@@ -444,12 +542,29 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
             return 0;
         }
         else{
-            //Fix up the actual deletion and succesor swap (if needed) here
+            num_items--;
+            
+            //Add deleted (empty) location into the empty_locations vector for reuse in next insert()
+            empty_locations.emplace_back(ret_struct.N_d);
+
+            /*
+            printf("In delete's swap portion\n");
             Node del_node = A.reconstruct(A[ret_struct.N_d]);
-            Node suc_ptr = A.reconstruct(A[ret_struct.N_s]);
-            printf("del_node key = %ld, suc_node key = %ld\n", 
-                del_node.key.ashare, suc_ptr.key.ashare); 
-            //print("flag_s = %d\n", rec_struct.F_ss);
+            Node suc_node = A.reconstruct(A[ret_struct.N_s]);
+            printf("del_node key = %ld, suc_node key = %ld\n",
+                del_node.key.ashare, suc_node.key.ashare);
+            printf("flag_s = %d\n", ret_struct.F_ss.bshare);
+            */
+            Node del_node = A[ret_struct.N_d];
+            Node suc_node = A[ret_struct.N_s];
+            RegAS zero_as; RegXS zero_xs;
+            mpc_select(tio, yield, root, ret_struct.F_r, root, ret_struct.ret_ptr);
+            mpc_select(tio, yield, del_node.key, ret_struct.F_ss, del_node.key, suc_node.key);
+            mpc_select(tio, yield, del_node.value, ret_struct.F_ss, del_node.value, suc_node.value);           
+            A[ret_struct.N_d].NODE_KEY = del_node.key;
+            A[ret_struct.N_d].NODE_VALUE = del_node.value;
+            A[ret_struct.N_s].NODE_KEY = zero_as;
+            A[ret_struct.N_s].NODE_VALUE = zero_xs;
         }
 
       return 1;
@@ -461,7 +576,7 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
 void bst(MPCIO &mpcio,
     const PRACOptions &opts, char **args)
 {
-    nbits_t depth=3;
+    nbits_t depth=4;
 
     if (*args) {
         depth = atoi(*args);
@@ -478,46 +593,68 @@ void bst(MPCIO &mpcio,
         size_t size = size_t(1)<<depth;
         BST tree(tio.player(), size);
 
-        
+        int insert_array[] = {10, 10, 13, 11, 14, 8, 15, 20, 17, 19, 7, 12};
+        //int insert_array[] = {1, 2, 3, 4, 5, 6};
+        size_t insert_array_size = 11;
         Node node; 
-        for(size_t i = 1; i<=items; i++) {
+        for(size_t i = 0; i<=insert_array_size; i++) {
           newnode(node);
-          node.key.set(i * tio.player());
+          node.key.set(insert_array[i] * tio.player());
           tree.insert(tio, yield, node);
         }
        
+        tree.print_oram(tio, yield);
         tree.pretty_print(tio, yield);
-        
-        
+         
         RegAS del_key;
-        del_key.set(1);
+
+        printf("\n\nDelete %x\n", 20);
+        del_key.set(20 * tio.player());
         tree.del(tio, yield, del_key);
+        tree.print_oram(tio, yield);
+        tree.pretty_print(tio, yield);
+        tree.check_bst(tio, yield);
 
+        printf("\n\nDelete %x\n", 10);
+        del_key.set(10 * tio.player());
+        tree.del(tio, yield, del_key);
+        tree.print_oram(tio, yield);
         tree.pretty_print(tio, yield);
-        /*
-        if (depth < 10) {
-            //oram.dump();
-            auto R = A.reconstruct();
-            // Reconstruct the root
-            if (tio.player() == 1) {
-                tio.queue_peer(&root, sizeof(root));
-            } else {
-                RegXS peer_root;
-                tio.recv_peer(&peer_root, sizeof(peer_root));
-                root += peer_root;
-            }
-            if (tio.player() == 0) {
-                for(size_t i=0;i<R.size();++i) {
-                    printf("\n%04lx ", i);
-                    R[i].dump();
-                }
-                printf("\n");
-                pretty_print(R, root.xshare);
-                auto [ ok, height ] = check_bst(R, root.xshare);
-                printf("BST structure %s\nBST height = %u\n",
-                    ok ? "ok" : "NOT OK", height);
-            }
-        }
-        */ 
+        tree.check_bst(tio, yield);
+
+        printf("\n\nDelete %x\n", 8);
+        del_key.set(8 * tio.player());
+        tree.del(tio, yield, del_key);
+        tree.print_oram(tio, yield);
+        tree.pretty_print(tio, yield);
+        tree.check_bst(tio, yield);
+
+        printf("\n\nDelete %x\n", 7);
+        del_key.set(7 * tio.player());
+        tree.del(tio, yield, del_key);
+        tree.print_oram(tio, yield);
+        tree.pretty_print(tio, yield);
+        tree.check_bst(tio, yield);
+
+        printf("\n\nDelete %x\n", 17);
+        del_key.set(17 * tio.player());
+        tree.del(tio, yield, del_key);
+        tree.print_oram(tio, yield);
+        tree.pretty_print(tio, yield);
+        tree.check_bst(tio, yield);
+
+        printf("\n\nDelete %x\n", 15);
+        del_key.set(15 * tio.player());
+        tree.del(tio, yield, del_key);
+        tree.print_oram(tio, yield);
+        tree.pretty_print(tio, yield);
+        tree.check_bst(tio, yield);
+
+        printf("\n\nDelete %x\n", 5);
+        del_key.set(5 * tio.player());
+        tree.del(tio, yield, del_key);
+        tree.print_oram(tio, yield);
+        tree.pretty_print(tio, yield);
+        tree.check_bst(tio, yield);
     });
 }

+ 7 - 0
bst.hpp

@@ -154,6 +154,8 @@ class BST {
     size_t num_items = 0;
     size_t MAX_SIZE;
 
+    std::vector<RegXS> empty_locations;
+
     std::tuple<RegXS, RegBS> insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
         const Node &new_node, Duoram<Node>::Flat &A, int TTL, RegBS isDummy);
     void insert(MPCTIO &tio, yield_t &yield, const Node &node, Duoram<Node>::Flat &A);
@@ -184,6 +186,10 @@ class BST {
     void check_bst(MPCTIO &tio, yield_t &yield);
     std::tuple<bool, address_t> check_bst(const std::vector<Node> &R,
         value_t node, value_t min_key, value_t max_key);
+    void print_oram(MPCTIO &tio, yield_t &yield);
+    size_t numEmptyLocations(){
+      return(empty_locations.size());
+    };
 
 };
 
@@ -207,4 +213,5 @@ class BST_OP {
 void bst(MPCIO &mpcio,
     const PRACOptions &opts, char **args);
 
+
 #endif