Quellcode durchsuchen

temp commit to get mpc_select for RegBS from main

Sajin Sasy vor 1 Jahr
Ursprung
Commit
dd814146a2
2 geänderte Dateien mit 228 neuen und 49 gelöschten Zeilen
  1. 208 48
      bst.cpp
  2. 20 1
      bst.hpp

+ 208 - 48
bst.cpp

@@ -11,7 +11,7 @@
 // tree balancing information into one field) and the value doesn't
 // really matter, but XOR shared is usually slightly more efficient.
 
-std::tuple<RegBS, RegBS> compare_keys(Node n1, Node n2, MPCTIO tio, yield_t &yield) {
+std::tuple<RegBS, RegBS> compare_keys(MPCTIO tio, yield_t &yield, Node n1, Node n2) {
     CDPF cdpf = tio.cdpf(yield);
     auto [lt, eq, gt] = cdpf.compare(tio, yield, n2.key - n1.key, tio.aes_ops());
     RegBS lteq = lt^eq;
@@ -41,52 +41,6 @@ inline void setRightPtr(RegXS &pointer, RegXS new_ptr){
     pointer+=(new_ptr);
 }
 
-std::tuple<RegXS, RegBS> BST::insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
-    const Node &new_node, Duoram<Node>::Flat &A, int TTL, RegBS isDummy) {
-    if(TTL==0) {
-        RegBS zero;
-        return {ptr, zero};
-    }
-
-    RegBS isNotDummy = isDummy ^ (tio.player());
-    Node cnode = A[ptr];
-    // Compare key
-    auto [lteq, gt] = compare_keys(cnode, new_node, tio, yield);
-
-    // Depending on [lteq, gt] select the next ptr/index as
-    // upper 32 bits of cnode.pointers if lteq
-    // lower 32 bits of cnode.pointers if gt 
-    RegXS left = extractLeftPtr(cnode.pointers);
-    RegXS right = extractRightPtr(cnode.pointers);
-    
-    RegXS next_ptr;
-    mpc_select(tio, yield, next_ptr, gt, left, right, 32);
-
-    CDPF dpf = tio.cdpf(yield);
-    size_t &aes_ops = tio.aes_ops();
-    // F_z: Check if this is last node on path
-    RegBS F_z = dpf.is_zero(tio, yield, next_ptr, aes_ops);
-    RegBS F_i;
-
-    // F_i: If this was last node on path (F_z), and isNotDummy insert.
-    mpc_and(tio, yield, F_i, (isNotDummy), F_z);
-     
-    isDummy^=F_i;
-    auto [wptr, direction] = insert(tio, yield, next_ptr, new_node, A, TTL-1, isDummy);
-    
-    RegXS ret_ptr;
-    RegBS ret_direction;
-    // If we insert here (F_i), return the ptr to this node as wptr
-    // and update direction to the direction taken by compare_keys
-    mpc_select(tio, yield, ret_ptr, F_i, wptr, ptr);
-    //ret_direction = direction + F_p(direction - gt)
-    mpc_and(tio, yield, ret_direction, F_i, direction^gt);
-    ret_direction^=direction;  
-
-    return {ret_ptr, ret_direction};
-}
-
-
 // Pretty-print a reconstructed BST, rooted at node. is_left_child and
 // is_right_child indicate whether node is a left or right child of its
 // parent.  They cannot both be true, but the root of the tree has both
@@ -131,6 +85,28 @@ void BST::pretty_print(const std::vector<Node> &R, value_t node,
     pretty_print(R, left_ptr, leftprefix, true, false);
 }
 
+bool reconstruct_flag(MPCTIO &tio, yield_t &yield, RegBS flag) {
+    RegBS peer_flag;
+    RegBS reconstructed_flag;
+    if (tio.player() == 1) {
+        tio.queue_peer(&flag, sizeof(flag));
+    } else {
+        RegBS peer_flag;
+        tio.recv_peer(&peer_flag, sizeof(peer_flag));
+        reconstructed_flag ^= peer_flag;
+    }
+
+    if (tio.player() == 0) {
+        tio.queue_peer(&flag, sizeof(flag));
+    } else {
+        RegBS peer_flag;
+        tio.recv_peer(&peer_flag, sizeof(peer_flag));
+        reconstructed_flag ^= peer_flag;
+    }
+    
+    return reconstructed_flag.bshare;
+}
+
 void BST::pretty_print(MPCTIO &tio, yield_t &yield) {
     RegXS peer_root;
     RegXS reconstructed_root = root;
@@ -197,6 +173,52 @@ void BST::initialize(int num_players, size_t size) {
 }
 
 
+std::tuple<RegXS, RegBS> BST::insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
+    const Node &new_node, Duoram<Node>::Flat &A, int TTL, RegBS isDummy) {
+    if(TTL==0) {
+        RegBS zero;
+        return {ptr, zero};
+    }
+
+    RegBS isNotDummy = isDummy ^ (tio.player());
+    Node cnode = A[ptr];
+    // Compare key
+    auto [lteq, gt] = compare_keys(tio, yield, cnode, new_node);
+
+    // Depending on [lteq, gt] select the next ptr/index as
+    // upper 32 bits of cnode.pointers if lteq
+    // lower 32 bits of cnode.pointers if gt 
+    RegXS left = extractLeftPtr(cnode.pointers);
+    RegXS right = extractRightPtr(cnode.pointers);
+    
+    RegXS next_ptr;
+    mpc_select(tio, yield, next_ptr, gt, left, right, 32);
+
+    CDPF dpf = tio.cdpf(yield);
+    size_t &aes_ops = tio.aes_ops();
+    // F_z: Check if this is last node on path
+    RegBS F_z = dpf.is_zero(tio, yield, next_ptr, aes_ops);
+    RegBS F_i;
+
+    // F_i: If this was last node on path (F_z), and isNotDummy insert.
+    mpc_and(tio, yield, F_i, (isNotDummy), F_z);
+     
+    isDummy^=F_i;
+    auto [wptr, direction] = insert(tio, yield, next_ptr, new_node, A, TTL-1, isDummy);
+    
+    RegXS ret_ptr;
+    RegBS ret_direction;
+    // If we insert here (F_i), return the ptr to this node as wptr
+    // and update direction to the direction taken by compare_keys
+    mpc_select(tio, yield, ret_ptr, F_i, wptr, ptr);
+    //ret_direction = direction + F_p(direction - gt)
+    mpc_and(tio, yield, ret_direction, F_i, direction^gt);
+    ret_direction^=direction;  
+
+    return {ret_ptr, ret_direction};
+}
+
+
 // Insert(root, ptr, key, TTL, isDummy) -> (new_ptr, wptr, wnode, f_p)
 void BST::insert(MPCTIO &tio, yield_t &yield, const Node &node, Duoram<Node>::Flat &A) {
     bool player0 = tio.player()==0;
@@ -257,6 +279,140 @@ void BST::insert(MPCTIO &tio, yield_t &yield, Node &node) {
     */
 }
 
+/*
+// Compute in MPC a | b 
+void mpc_or(MPCTIO &tio, yield_t &yield, RegBS &result, RegBS a, RegBS b) {
+    int player0 = tio.player();
+    if(player0) {
+        a^=1;
+        b^=1;    
+    } 
+
+    mpc_and(tio, yield, result, a, b);
+    if(player0)
+        result^=1;
+}
+*/
+
+int BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
+     Duoram<Node>::Flat &A, RegBS af, RegBS fs, int TTL, 
+    del_return &ret_struct) {
+
+    if(TTL==0) {
+        //Reconstruct and return af
+        bool af = reconstruct_flag(tio, yield, af);
+        printf("Reconstructed flag = %d\n", af);
+        return af;
+    } else {
+        bool player0 = tio.player()==0;
+        Node node = A[ptr];
+        // Compare key
+
+        CDPF cdpf = tio.cdpf(yield);
+        auto [lt, eq, gt] = cdpf.compare(tio, yield, node.key - del_key, tio.aes_ops());
+        // c is the direction bit for next_ptr 
+        // (c=0: go left or c=1: go right)
+        RegBS c = gt;
+        // lf = local found. We found the key to delete in this level.
+        RegBS lf = eq;
+
+        // Depending on [lteq, gt] select the next ptr/index as
+        // upper 32 bits of cnode.pointers if lteq
+        // lower 32 bits of cnode.pointers if gt 
+        RegXS left = extractLeftPtr(node.pointers);
+        RegXS right = extractRightPtr(node.pointers);
+        
+        CDPF dpf = tio.cdpf(yield);
+        size_t &aes_ops = tio.aes_ops();
+        // Check if left and right children are 0, and compute F_0, F_1, F_2
+        RegBS l0 = dpf.is_zero(tio, yield, left, aes_ops);
+        RegBS r0 = dpf.is_zero(tio, yield, right, aes_ops);
+        RegBS F_0, F_1, F_2;
+        // F_0 = l0 & r0
+        mpc_and(tio, yield, F_0, l0, r0);
+        // F_1 = l0 \xor r0
+        F_1 = l0 ^ r0;
+        // F_2 = !(F_0 + F_1) (Only 1 of F_0, F_1, and F_2 can be true)
+        F_2 = F_0 ^ F_1;
+        if(player0)
+            F_2^=1;
+
+        // We set next ptr based on c, but we need to handle three 
+        // edge cases where we do not go by just the comparison result
+        RegXS next_ptr;
+        RegBS c_prime;
+        // Case 1: found the node here (lf) or we are finding successor (fs)
+        // and there is only one child. We traverse down the lone child path.
+        RegBS F_c11, F_c12, F_c2, F_c3;
+        // Case 1a: lf & F_1
+        mpc_and(tio, yield, F_c11, lf, F_1);
+        // Case 1b: fs & F_1
+        mpc_and(tio, yield, F_c12, fs, F_1);
+        // Set c_prime for Case 1a and 1b
+        mpc_select(tio, yield, c_prime, F_c1, c, l0);
+        mpc_select(tio, yield, c_prime, F_c2, c, l0);
+
+        // s1: shares of 1 bit, s0: shares of 0 bit
+        RegBS s1, s0;
+        s1.set(tio.player()==1);
+        // Case 2: found the node here (lf) and node has both children (F_2)
+        // In find successor case, so find inorder successor
+        // (Go right and then find leftmost child.)
+        mpc_and(tio, yield, F_c2, lf, F_2);
+        mpc_select(tio, yield, c_prime, F_c2, c, s1);
+
+        // Case 3: finding successor (fs) and node has both children (F_2)
+        // Go left. 
+        mpc_and(tio, yield, F_c3, fs, F_2);
+        mpc_select(tio, yield, c_prime, F_c3, c, s0);
+
+        // Set next_ptr
+        mpc_select(tio, yield, next_ptr, c_prime, left, right, 32);
+        
+        RegBS af_prime, fs_prime;
+        mpc_or(tio, yield, af_prime, af, lf);
+
+        // If in Case 2, set fs. We are now finding successor
+        mpc_or(tio, yield, fs_prime, fs, F_c2); 
+        int key_found = del(tio, yield, next_ptr, del_key, A, af_prime, fs_prime, TTL-1, ret_struct);
+
+        // If we didn't find the key, we can end here.
+        if(!key_found)
+          return 0;
+
+        // Update node.left and node.right with ret_struct.rptr and [c] as slct bit
+
+        // Update the return structure        
+
+    }
+
+    return 1;
+}
+
+
+int BST::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
+    if(num_items==0)
+        return 0;
+    if(num_items==1) {
+        //Delete root
+        auto A = oram->flat(tio, yield);
+        Node zero;
+        A[0] = zero;
+        num_items--;
+        return 1; 
+    } else {
+        int TTL = num_items;
+        // Flags for already found (af) item to delete and find successor (fs)
+        // if this deletion requires a successor swap
+        RegBS af;
+        RegBS fs;
+        del_return ret_struct; 
+        auto A = oram->flat(tio, yield);
+        int success = del(tio, yield, root, del_key, A, af, fs, TTL, ret_struct); 
+        printf ("Success =  %d\n", success); 
+      return 1;
+    }
+}
 
 // Now we use the node in various ways.  This function is called by
 // online.cpp.
@@ -280,6 +436,7 @@ void bst(MPCIO &mpcio,
         size_t size = size_t(1)<<depth;
         BST tree(tio.player(), size);
 
+        /*
         Node node; 
         for(size_t i = 1; i<=items; i++) {
           newnode(node);
@@ -288,7 +445,10 @@ void bst(MPCIO &mpcio,
         }
        
         tree.pretty_print(tio, yield);
-
+        */
+        
+        RegAS del_key;
+        tree.del(tio, yield, del_key);
 
         /*
         if (depth < 10) {

+ 20 - 1
bst.hpp

@@ -132,6 +132,18 @@ T& operator<<(T& os, const Node &x)
 
 DEFAULT_TUPLE_IO(Node)
 
+struct del_return {
+    // Flag to indicate if the key to delete was found in tree
+    RegBS F_f;
+    RegXS ret_ptr;
+    // Flag to indicate if the key this deletion requires a successor swap
+    RegBS F_ss;
+    // Pointers to node to delete and successor node that would replace
+    // deleted node
+    RegXS N_d;
+    RegXS N_s;
+};
+
 class BST {
   private: 
     Duoram<Node> *oram;
@@ -140,9 +152,15 @@ class BST {
     size_t num_items = 0;
     size_t MAX_SIZE;
 
-    std::tuple<RegXS, RegBS> insert(MPCTIO &tio, yield_t &yield, RegXS ptr, const Node &new_node, Duoram<Node>::Flat &A, int TTL, RegBS isDummy);
+    std::tuple<RegXS, RegBS> insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
+        const Node &new_node, Duoram<Node>::Flat &A, int TTL, RegBS isDummy);
     void insert(MPCTIO &tio, yield_t &yield, const Node &node, Duoram<Node>::Flat &A);
 
+
+    int del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
+        Duoram<Node>::Flat &A, RegBS F_af, RegBS F_fs, int TTL, 
+        del_return &ret_struct);
+
   public:
     BST(int num_players, size_t size) {
       this->initialize(num_players, size);
@@ -155,6 +173,7 @@ class BST {
 
     void initialize(int num_players, size_t size);
     void insert(MPCTIO &tio, yield_t &yield, Node &node);
+    int del(MPCTIO &tio, yield_t &yield, RegAS del_key); 
 
     // Display and correctness check functions
     void pretty_print(MPCTIO &tio, yield_t &yield);