Browse Source

Restructing to BST Class (with MPCTIO and yield passed around) + code cleanup

Sajin Sasy 1 year ago
parent
commit
e9f037d222
2 changed files with 345 additions and 247 deletions
  1. 166 247
      bst.cpp
  2. 179 0
      bst.hpp

+ 166 - 247
bst.cpp

@@ -1,8 +1,5 @@
 #include <functional>
 
-#include "types.hpp"
-#include "duoram.hpp"
-#include "cdpf.hpp"
 #include "bst.hpp"
 
 // This file demonstrates how to implement custom ORAM wide cell types.
@@ -14,145 +11,11 @@
 // tree balancing information into one field) and the value doesn't
 // really matter, but XOR shared is usually slightly more efficient.
 
-struct Node {
-    RegAS key;
-    RegXS pointers;
-    RegXS value;
-
-// Field-access macros so we can write A[i].NODE_KEY instead of
-// A[i].field(&Node::key)
-
-#define NODE_KEY field(&Node::key)
-#define NODE_POINTERS field(&Node::pointers)
-#define NODE_VALUE field(&Node::value)
-
-    // For debugging and checking answers
-    void dump() const {
-        printf("[%016lx %016lx %016lx]", key.share(), pointers.share(),
-            value.share());
-    }
-
-    // You'll need to be able to create a random element, and do the
-    // operations +=, +, -=, - (binary and unary).  Note that for
-    // XOR-shared fields, + and - are both really XOR.
-
-    inline void randomize() {
-        key.randomize();
-        pointers.randomize();
-        value.randomize();
-    }
-
-    inline Node &operator+=(const Node &rhs) {
-        this->key += rhs.key;
-        this->pointers += rhs.pointers;
-        this->value += rhs.value;
-        return *this;
-    }
-
-    inline Node operator+(const Node &rhs) const {
-        Node res = *this;
-        res += rhs;
-        return res;
-    }
-
-    inline Node &operator-=(const Node &rhs) {
-        this->key -= rhs.key;
-        this->pointers -= rhs.pointers;
-        this->value -= rhs.value;
-        return *this;
-    }
-
-    inline Node operator-(const Node &rhs) const {
-        Node res = *this;
-        res -= rhs;
-        return res;
-    }
-
-    inline Node operator-() const {
-        Node res;
-        res.key = -this->key;
-        res.pointers = -this->pointers;
-        res.value = -this->value;
-        return res;
-    }
-
-    // Multiply each field by the local share of the corresponding field
-    // in the argument
-    inline Node mulshare(const Node &rhs) const {
-        Node res = *this;
-        res.key.mulshareeq(rhs.key);
-        res.pointers.mulshareeq(rhs.pointers);
-        res.value.mulshareeq(rhs.value);
-        return res;
-    }
-
-    // You need a method to turn a leaf node of a DPF into a share of a
-    // unit element of your type.  Typically set each RegAS to
-    // dpf.unit_as(leaf) and each RegXS or RegBS to dpf.unit_bs(leaf).
-    // Note that RegXS will extend a RegBS of 1 to the all-1s word, not
-    // the word with value 1.  This is used for ORAM reads, where the
-    // same DPF is used for all the fields.
-    inline void unit(const RDPF &dpf, DPFnode leaf) {
-        key = dpf.unit_as(leaf);
-        pointers = dpf.unit_bs(leaf);
-        value = dpf.unit_bs(leaf);
-    }
-
-    // Perform an update on each of the fields, using field-specific
-    // MemRefs constructed from the Shape shape and the index idx
-    template <typename Sh, typename U>
-    inline static void update(Sh &shape, yield_t &shyield, U idx,
-            const Node &M) {
-        run_coroutines(shyield,
-            [&shape, &idx, &M] (yield_t &yield) {
-                Sh Sh_coro = shape.context(yield);
-                Sh_coro[idx].NODE_KEY += M.key;
-            },
-            [&shape, &idx, &M] (yield_t &yield) {
-                Sh Sh_coro = shape.context(yield);
-                Sh_coro[idx].NODE_POINTERS += M.pointers;
-            },
-            [&shape, &idx, &M] (yield_t &yield) {
-                Sh Sh_coro = shape.context(yield);
-                Sh_coro[idx].NODE_VALUE += M.value;
-            });
-    }
-};
-
-// I/O operations (for sending over the network)
-
-template <typename T>
-T& operator>>(T& is, Node &x)
-{
-    is >> x.key >> x.pointers >> x.value;
-    return is;
-}
-
-template <typename T>
-T& operator<<(T& os, const Node &x)
-{
-    os << x.key << x.pointers << x.value;
-    return os;
-}
-
-// This macro will define I/O on tuples of two or three of the cell type
-
-DEFAULT_TUPLE_IO(Node)
-
 std::tuple<RegBS, RegBS> compare_keys(Node n1, Node n2, MPCTIO tio, yield_t &yield) {
-  CDPF cdpf = tio.cdpf(yield);
-  auto [lt, eq, gt] = cdpf.compare(tio, yield, n2.key - n1.key, tio.aes_ops());
-  RegBS lteq = lt^eq;
-  return {lteq, gt};
-}
-
-RegBS check_ptr_zero(MPCTIO tio, yield_t &yield, RegXS ptr) {
-  CDPF cdpf = tio.cdpf(yield);
-  RegAS ptr_as;
-  mpc_xs_to_as(tio, yield, ptr_as, ptr);
-  RegAS zero;
-  auto [lt, eq, gt] = cdpf.compare(tio, yield, ptr_as - zero, tio.aes_ops());
-  return eq;
+    CDPF cdpf = tio.cdpf(yield);
+    auto [lt, eq, gt] = cdpf.compare(tio, yield, n2.key - n1.key, tio.aes_ops());
+    RegBS lteq = lt^eq;
+    return {lteq, gt};
 }
 
 // Assuming pointer of 64 bits is split as:
@@ -161,114 +24,74 @@ RegBS check_ptr_zero(MPCTIO tio, yield_t &yield, RegXS ptr) {
 // < Left, Right>
 
 inline RegXS extractLeftPtr(RegXS pointer){ 
-  return ((pointer&(0xFFFFFFFF00000000))>>32); 
+    return ((pointer&(0xFFFFFFFF00000000))>>32); 
 }
 
 inline RegXS extractRightPtr(RegXS pointer){
-  return (pointer&(0x00000000FFFFFFFF)); 
+    return (pointer&(0x00000000FFFFFFFF)); 
 }
 
 inline void setLeftPtr(RegXS &pointer, RegXS new_ptr){ 
-  pointer&=(0x00000000FFFFFFFF);
-  pointer+=(new_ptr<<32);
+    pointer&=(0x00000000FFFFFFFF);
+    pointer+=(new_ptr<<32);
 }
 
 inline void setRightPtr(RegXS &pointer, RegXS new_ptr){
-  pointer&=(0xFFFFFFFF00000000);
-  pointer+=(new_ptr);
-}
-
-std::tuple<RegXS, RegBS> insert(MPCTIO &tio, yield_t &yield, RegXS ptr, const Node &new_node, Duoram<Node>::Flat &A, int TTL, RegBS isDummy) {
-  if(TTL==0) {
-    RegBS zero;
-    return {ptr, zero};
-  }
-
-  RegBS isNotDummy = isDummy ^ (tio.player());
-  Node cnode = A[ptr];
-  // Compare key
-  auto [lteq, gt] = compare_keys(cnode, new_node, tio, yield);
-
-  // Depending on [lteq, gt] select the next ptr/index as
-  // upper 32 bits of cnode.pointers if lteq
-  // lower 32 bits of cnode.pointers if gt 
-  RegXS left = extractLeftPtr(cnode.pointers);
-  RegXS right = extractRightPtr(cnode.pointers);
-  
-  RegXS next_ptr;
-  mpc_select(tio, yield, next_ptr, gt, left, right, 32);
-
-  CDPF dpf = tio.cdpf(yield);
-  size_t &aes_ops = tio.aes_ops();
-  // F_z: Check if this is last node on path
-  RegBS F_z = dpf.is_zero(tio, yield, next_ptr, aes_ops);
-  RegBS F_i;
-
-  // F_i: If this was last node on path (F_z), and isNotDummy insert.
-  mpc_and(tio, yield, F_i, (isNotDummy), F_z);
-   
-  isDummy^=F_i;
-  auto [wptr, direction] = insert(tio, yield, next_ptr, new_node, A, TTL-1, isDummy);
-  
-  RegXS ret_ptr;
-  RegBS ret_direction;
-  // If we insert here (F_i), return the ptr to this node as wptr
-  // and update direction to the direction taken by compare_keys
-  mpc_select(tio, yield, ret_ptr, F_i, wptr, ptr);
-  //ret_direction = direction + F_p(direction - gt)
-  mpc_and(tio, yield, ret_direction, F_i, direction^gt);
-  ret_direction^=direction;  
-
-  return {ret_ptr, ret_direction};
+    pointer&=(0xFFFFFFFF00000000);
+    pointer+=(new_ptr);
 }
 
-
-// Insert(root, ptr, key, TTL, isDummy) -> (new_ptr, wptr, wnode, f_p)
-void insert(MPCTIO &tio, yield_t &yield, RegXS &root, const Node &node, Duoram<Node>::Flat &A, size_t &num_items) {
-  bool player0 = tio.player()==0;
-  // If there are no items in tree. Make this new item the root.
-  if(num_items==0) {
-    Node zero;
-    A[0] = zero;
-    A[1] = node;
-    (root).set(1*tio.player());
-    num_items++;
-    return;
-  }
-  else {
-    // Insert node into next free slot in the ORAM
-    int new_id = 1 + num_items;
-    int TTL = num_items++;
-    A[new_id] = node;
-    RegXS new_addr;
-    new_addr.set(new_id * tio.player());
-    RegBS isDummy;
-
-    //Do a recursive insert
-    auto [wptr, direction] = insert(tio, yield, root, node, A, TTL, isDummy);
-
-    //Complete the insertion by reading wptr and updating its pointers
-    RegXS pointers = A[wptr].NODE_POINTERS;
-    RegXS left_ptr = extractLeftPtr(pointers);
-    RegXS right_ptr = extractRightPtr(pointers);
-    RegXS new_right_ptr, new_left_ptr;
-    mpc_select(tio, yield, new_right_ptr, direction, right_ptr, new_addr);
-    if(player0) {
-      direction^=1;
+std::tuple<RegXS, RegBS> BST::insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
+    const Node &new_node, Duoram<Node>::Flat &A, int TTL, RegBS isDummy) {
+    if(TTL==0) {
+        RegBS zero;
+        return {ptr, zero};
     }
-    mpc_select(tio, yield, new_left_ptr, direction, left_ptr, new_addr);
-    setLeftPtr(pointers, new_left_ptr);
-    setRightPtr(pointers, new_right_ptr);
-    A[wptr].NODE_POINTERS = pointers;
-  }
-  
+
+    RegBS isNotDummy = isDummy ^ (tio.player());
+    Node cnode = A[ptr];
+    // Compare key
+    auto [lteq, gt] = compare_keys(cnode, new_node, tio, yield);
+
+    // Depending on [lteq, gt] select the next ptr/index as
+    // upper 32 bits of cnode.pointers if lteq
+    // lower 32 bits of cnode.pointers if gt 
+    RegXS left = extractLeftPtr(cnode.pointers);
+    RegXS right = extractRightPtr(cnode.pointers);
+    
+    RegXS next_ptr;
+    mpc_select(tio, yield, next_ptr, gt, left, right, 32);
+
+    CDPF dpf = tio.cdpf(yield);
+    size_t &aes_ops = tio.aes_ops();
+    // F_z: Check if this is last node on path
+    RegBS F_z = dpf.is_zero(tio, yield, next_ptr, aes_ops);
+    RegBS F_i;
+
+    // F_i: If this was last node on path (F_z), and isNotDummy insert.
+    mpc_and(tio, yield, F_i, (isNotDummy), F_z);
+     
+    isDummy^=F_i;
+    auto [wptr, direction] = insert(tio, yield, next_ptr, new_node, A, TTL-1, isDummy);
+    
+    RegXS ret_ptr;
+    RegBS ret_direction;
+    // If we insert here (F_i), return the ptr to this node as wptr
+    // and update direction to the direction taken by compare_keys
+    mpc_select(tio, yield, ret_ptr, F_i, wptr, ptr);
+    //ret_direction = direction + F_p(direction - gt)
+    mpc_and(tio, yield, ret_direction, F_i, direction^gt);
+    ret_direction^=direction;  
+
+    return {ret_ptr, ret_direction};
 }
 
+
 // Pretty-print a reconstructed BST, rooted at node. is_left_child and
 // is_right_child indicate whether node is a left or right child of its
 // parent.  They cannot both be true, but the root of the tree has both
 // of them false.
-void pretty_print(const std::vector<Node> &R, value_t node,
+void BST::pretty_print(const std::vector<Node> &R, value_t node,
     const std::string &prefix = "", bool is_left_child = false,
     bool is_right_child = false)
 {
@@ -308,13 +131,31 @@ void pretty_print(const std::vector<Node> &R, value_t node,
     pretty_print(R, left_ptr, leftprefix, true, false);
 }
 
+void BST::pretty_print(MPCTIO &tio, yield_t &yield) {
+    RegXS peer_root;
+    RegXS reconstructed_root = root;
+    if (tio.player() == 1) {
+        tio.queue_peer(&root, sizeof(root));
+    } else {
+        RegXS peer_root;
+        tio.recv_peer(&peer_root, sizeof(peer_root));
+        reconstructed_root += peer_root;
+    }
+
+    auto A = oram->flat(tio, yield);
+    auto R = A.reconstruct();
+    if(tio.player()==0) {
+        pretty_print(R, reconstructed_root.xshare);
+    }
+}
+
 // Check the BST invariant of the tree (that all keys to the left are
 // less than or equal to this key, all keys to the right are strictly
 // greater, and this is true recursively).  Returns a
 // tuple<bool,address_t>, where the bool says whether the BST invariant
 // holds, and the address_t is the height of the tree (which will be
 // useful later when we check AVL trees).
-std::tuple<bool, address_t> check_bst(const std::vector<Node> &R,
+std::tuple<bool, address_t> BST::check_bst(const std::vector<Node> &R,
     value_t node, value_t min_key = 0, value_t max_key = ~0)
 {
     if (node == 0) {
@@ -335,18 +176,94 @@ std::tuple<bool, address_t> check_bst(const std::vector<Node> &R,
         height };
 }
 
+void BST::check_bst(MPCTIO &tio, yield_t &yield) {
+    auto A = oram->flat(tio, yield);
+    auto R = A.reconstruct();
+
+    auto [ ok, height ] = check_bst(R, root.xshare);
+    printf("BST structure %s\nBST height = %u\n",
+        ok ? "ok" : "NOT OK", height);
+}  
+
 void newnode(Node &a) {
-  a.key.randomize(8);
-  a.pointers.set(0);
-  a.value.randomize();
+    a.key.randomize(8);
+    a.pointers.set(0);
+    a.value.randomize();
+}
+
+void BST::initialize(int num_players, size_t size) {
+    this->MAX_SIZE = size;
+    oram = new Duoram<Node>(num_players, size);
 }
 
+
+// Insert(root, ptr, key, TTL, isDummy) -> (new_ptr, wptr, wnode, f_p)
+void BST::insert(MPCTIO &tio, yield_t &yield, const Node &node, Duoram<Node>::Flat &A) {
+    bool player0 = tio.player()==0;
+    // If there are no items in tree. Make this new item the root.
+    if(num_items==0) {
+        Node zero;
+        A[0] = zero;
+        A[1] = node;
+        (root).set(1*tio.player());
+        num_items++;
+        //printf("num_items == %ld!\n", num_items);
+        return;
+    } else {
+        // Insert node into next free slot in the ORAM
+        int new_id = 1 + num_items;
+        int TTL = num_items++;
+        A[new_id] = node;
+        RegXS new_addr;
+        new_addr.set(new_id * tio.player());
+        RegBS isDummy;
+
+        //Do a recursive insert
+        auto [wptr, direction] = insert(tio, yield, root, node, A, TTL, isDummy);
+
+        //Complete the insertion by reading wptr and updating its pointers
+        RegXS pointers = A[wptr].NODE_POINTERS;
+        RegXS left_ptr = extractLeftPtr(pointers);
+        RegXS right_ptr = extractRightPtr(pointers);
+        RegXS new_right_ptr, new_left_ptr;
+        mpc_select(tio, yield, new_right_ptr, direction, right_ptr, new_addr);
+        if(player0) {
+            direction^=1;
+        }
+        mpc_select(tio, yield, new_left_ptr, direction, left_ptr, new_addr);
+        setLeftPtr(pointers, new_left_ptr);
+        setRightPtr(pointers, new_right_ptr);
+        A[wptr].NODE_POINTERS = pointers;
+        //printf("num_items == %ld!\n", num_items);
+    } 
+}
+
+
+void BST::insert(MPCTIO &tio, yield_t &yield, Node &node) {
+    auto A = oram->flat(tio, yield);
+    auto R = A.reconstruct();
+
+    insert(tio, yield, node, A);
+    /*
+    // To visualize database and tree after each insert:
+    if (tio.player() == 0) {
+        for(size_t i=0;i<R.size();++i) {
+            printf("\n%04lx ", i);
+            R[i].dump();
+        }
+        printf("\n");
+    }
+    pretty_print(R, 1);
+    */
+}
+
+
 // Now we use the node in various ways.  This function is called by
 // online.cpp.
 void bst(MPCIO &mpcio,
     const PRACOptions &opts, char **args)
 {
-    nbits_t depth=5;
+    nbits_t depth=3;
 
     if (*args) {
         depth = atoi(*args);
@@ -361,20 +278,21 @@ void bst(MPCIO &mpcio,
     MPCTIO tio(mpcio, 0, opts.num_threads);
     run_coroutines(tio, [&tio, depth, items] (yield_t &yield) {
         size_t size = size_t(1)<<depth;
-        Duoram<Node> oram(tio.player(), size);
-        auto A = oram.flat(tio, yield);
-
-        size_t num_items = 0;
-        RegXS root;
+        BST tree(tio.player(), size);
 
-        Node c; 
-        for(size_t i = 0; i<items; i++) {
-          newnode(c);
-          insert(tio, yield, root, c, A, num_items);
+        Node node; 
+        for(size_t i = 1; i<=items; i++) {
+          newnode(node);
+          node.key.set(i * tio.player());
+          tree.insert(tio, yield, node);
         }
-        
+       
+        tree.pretty_print(tio, yield);
+
+
+        /*
         if (depth < 10) {
-            oram.dump();
+            //oram.dump();
             auto R = A.reconstruct();
             // Reconstruct the root
             if (tio.player() == 1) {
@@ -396,5 +314,6 @@ void bst(MPCIO &mpcio,
                     ok ? "ok" : "NOT OK", height);
             }
         }
+        */ 
     });
 }

+ 179 - 0
bst.hpp

@@ -1,9 +1,188 @@
 #ifndef __NODE_HPP__
 #define __NODE_HPP__
 
+#include "types.hpp"
+#include "duoram.hpp"
+#include "cdpf.hpp"
 #include "mpcio.hpp"
 #include "options.hpp"
 
+struct Node {
+    RegAS key;
+    RegXS pointers;
+    RegXS value;
+
+// Field-access macros so we can write A[i].NODE_KEY instead of
+// A[i].field(&Node::key)
+
+#define NODE_KEY field(&Node::key)
+#define NODE_POINTERS field(&Node::pointers)
+#define NODE_VALUE field(&Node::value)
+
+    // For debugging and checking answers
+    void dump() const {
+        printf("[%016lx %016lx %016lx]", key.share(), pointers.share(),
+            value.share());
+    }
+
+    // You'll need to be able to create a random element, and do the
+    // operations +=, +, -=, - (binary and unary).  Note that for
+    // XOR-shared fields, + and - are both really XOR.
+
+    inline void randomize() {
+        key.randomize();
+        pointers.randomize();
+        value.randomize();
+    }
+
+    inline Node &operator+=(const Node &rhs) {
+        this->key += rhs.key;
+        this->pointers += rhs.pointers;
+        this->value += rhs.value;
+        return *this;
+    }
+
+    inline Node operator+(const Node &rhs) const {
+        Node res = *this;
+        res += rhs;
+        return res;
+    }
+
+    inline Node &operator-=(const Node &rhs) {
+        this->key -= rhs.key;
+        this->pointers -= rhs.pointers;
+        this->value -= rhs.value;
+        return *this;
+    }
+
+    inline Node operator-(const Node &rhs) const {
+        Node res = *this;
+        res -= rhs;
+        return res;
+    }
+
+    inline Node operator-() const {
+        Node res;
+        res.key = -this->key;
+        res.pointers = -this->pointers;
+        res.value = -this->value;
+        return res;
+    }
+
+    // Multiply each field by the local share of the corresponding field
+    // in the argument
+    inline Node mulshare(const Node &rhs) const {
+        Node res = *this;
+        res.key.mulshareeq(rhs.key);
+        res.pointers.mulshareeq(rhs.pointers);
+        res.value.mulshareeq(rhs.value);
+        return res;
+    }
+
+    // You need a method to turn a leaf node of a DPF into a share of a
+    // unit element of your type.  Typically set each RegAS to
+    // dpf.unit_as(leaf) and each RegXS or RegBS to dpf.unit_bs(leaf).
+    // Note that RegXS will extend a RegBS of 1 to the all-1s word, not
+    // the word with value 1.  This is used for ORAM reads, where the
+    // same DPF is used for all the fields.
+    inline void unit(const RDPF &dpf, DPFnode leaf) {
+        key = dpf.unit_as(leaf);
+        pointers = dpf.unit_bs(leaf);
+        value = dpf.unit_bs(leaf);
+    }
+
+    // Perform an update on each of the fields, using field-specific
+    // MemRefs constructed from the Shape shape and the index idx
+    template <typename Sh, typename U>
+    inline static void update(Sh &shape, yield_t &shyield, U idx,
+            const Node &M) {
+        run_coroutines(shyield,
+            [&shape, &idx, &M] (yield_t &yield) {
+                Sh Sh_coro = shape.context(yield);
+                Sh_coro[idx].NODE_KEY += M.key;
+            },
+            [&shape, &idx, &M] (yield_t &yield) {
+                Sh Sh_coro = shape.context(yield);
+                Sh_coro[idx].NODE_POINTERS += M.pointers;
+            },
+            [&shape, &idx, &M] (yield_t &yield) {
+                Sh Sh_coro = shape.context(yield);
+                Sh_coro[idx].NODE_VALUE += M.value;
+            });
+    }
+};
+
+// I/O operations (for sending over the network)
+
+template <typename T>
+T& operator>>(T& is, Node &x)
+{
+    is >> x.key >> x.pointers >> x.value;
+    return is;
+}
+
+template <typename T>
+T& operator<<(T& os, const Node &x)
+{
+    os << x.key << x.pointers << x.value;
+    return os;
+}
+
+// This macro will define I/O on tuples of two or three of the node type
+
+DEFAULT_TUPLE_IO(Node)
+
+class BST {
+  private: 
+    Duoram<Node> *oram;
+    RegXS root;
+
+    size_t num_items = 0;
+    size_t MAX_SIZE;
+
+    std::tuple<RegXS, RegBS> insert(MPCTIO &tio, yield_t &yield, RegXS ptr, const Node &new_node, Duoram<Node>::Flat &A, int TTL, RegBS isDummy);
+    void insert(MPCTIO &tio, yield_t &yield, const Node &node, Duoram<Node>::Flat &A);
+
+  public:
+    BST(int num_players, size_t size) {
+      this->initialize(num_players, size);
+    };
+
+    ~BST() {
+      if(oram)
+        delete oram;
+    };
+
+    void initialize(int num_players, size_t size);
+    void insert(MPCTIO &tio, yield_t &yield, Node &node);
+
+    // Display and correctness check functions
+    void pretty_print(MPCTIO &tio, yield_t &yield);
+    void pretty_print(const std::vector<Node> &R, value_t node,
+        const std::string &prefix, bool is_left_child, bool is_right_child);
+    void check_bst(MPCTIO &tio, yield_t &yield);
+    std::tuple<bool, address_t> check_bst(const std::vector<Node> &R,
+        value_t node, value_t min_key, value_t max_key);
+
+};
+
+/*
+class BST_OP {
+  private:
+    MPCTIO tio;
+    yield_t yield;    
+    BST *ptr; 
+
+  public:
+    BST_OP* init(MPCTIO &tio, yield_t &yield, BST *bst_ptr) {
+      this->tio = tio;
+      this->yield = yield;
+      this->ptr = bst_ptr;
+      return this;
+    }
+};
+*/
+
 void bst(MPCIO &mpcio,
     const PRACOptions &opts, char **args);