Browse Source

BST clean up

sshsshy 10 months ago
parent
commit
75d8183265
3 changed files with 138 additions and 167 deletions
  1. 26 18
      avl.cpp
  2. 74 111
      bst.cpp
  3. 38 38
      bst.hpp

+ 26 - 18
avl.cpp

@@ -2,6 +2,12 @@
 
 #include "avl.hpp"
 
+static void randomize_node(Node &a) {
+    a.key.randomize(8);
+    a.pointers.set(0);
+    a.value.randomize();
+}
+
 void print_green(std::string line) {
     printf("%s%s%s", KGRN, line.c_str(), KNRM);
 }
@@ -70,8 +76,10 @@ void AVL::pretty_print(MPCTIO &tio, yield_t &yield) {
     RegXS reconstructed_root = root;
     if (tio.player() == 1) {
         tio.queue_peer(&root, sizeof(root));
+        yield();
     } else {
         RegXS peer_root;
+        yield();
         tio.recv_peer(&peer_root, sizeof(peer_root));
         reconstructed_root += peer_root;
     }
@@ -1709,7 +1717,7 @@ void avl(MPCIO &mpcio,
         tio.reset_lamport();
 
         for(size_t i = 1; i<=n_inserts; i++) {
-            newnode(node);
+            randomize_node(node);
             size_t ikey;
             #ifdef RANDOMIZE
                 ikey = (1+(rand()%oram_size));
@@ -1786,7 +1794,7 @@ void avl_tests(MPCIO &mpcio,
             Node node;
 
             for(size_t i = 0; i<=insert_array_size; i++) {
-              newnode(node);
+              randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
               tree.check_avl(tio, yield);
@@ -1863,7 +1871,7 @@ void avl_tests(MPCIO &mpcio,
             size_t insert_array_size = 4;
             Node node;
             for(size_t i = 0; i<=insert_array_size; i++) {
-              newnode(node);
+              randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
               tree.check_avl(tio, yield);
@@ -1947,7 +1955,7 @@ void avl_tests(MPCIO &mpcio,
             size_t insert_array_size = 2;
             Node node;
             for(size_t i = 0; i<=insert_array_size; i++) {
-              newnode(node);
+              randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
               tree.check_avl(tio, yield);
@@ -2017,7 +2025,7 @@ void avl_tests(MPCIO &mpcio,
             size_t insert_array_size = 4;
             Node node;
             for(size_t i = 0; i<=insert_array_size; i++) {
-              newnode(node);
+              randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
               tree.check_avl(tio, yield);
@@ -2101,7 +2109,7 @@ void avl_tests(MPCIO &mpcio,
             size_t insert_array_size = 2;
             Node node;
             for(size_t i = 0; i<=insert_array_size; i++) {
-              newnode(node);
+              randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
               tree.check_avl(tio, yield);
@@ -2181,7 +2189,7 @@ void avl_tests(MPCIO &mpcio,
             size_t insert_array_size = 4;
             Node node;
             for(size_t i = 0; i<=insert_array_size; i++) {
-              newnode(node);
+              randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
               tree.check_avl(tio, yield);
@@ -2265,7 +2273,7 @@ void avl_tests(MPCIO &mpcio,
             size_t insert_array_size = 2;
             Node node;
             for(size_t i = 0; i<=insert_array_size; i++) {
-              newnode(node);
+              randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
               tree.check_avl(tio, yield);
@@ -2343,7 +2351,7 @@ void avl_tests(MPCIO &mpcio,
             size_t insert_array_size = 4;
             Node node;
             for(size_t i = 0; i<=insert_array_size; i++) {
-              newnode(node);
+              randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
               tree.check_avl(tio, yield);
@@ -2427,7 +2435,7 @@ void avl_tests(MPCIO &mpcio,
             size_t insert_array_size = 3;
             Node node;
             for(size_t i = 0; i<=insert_array_size; i++) {
-              newnode(node);
+              randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
               tree.check_avl(tio, yield);
@@ -2502,7 +2510,7 @@ void avl_tests(MPCIO &mpcio,
             size_t insert_array_size = 6;
             Node node;
             for(size_t i = 0; i<=insert_array_size; i++) {
-              newnode(node);
+              randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
               tree.check_avl(tio, yield);
@@ -2594,7 +2602,7 @@ void avl_tests(MPCIO &mpcio,
             size_t insert_array_size = 3;
             Node node;
             for(size_t i = 0; i<=insert_array_size; i++) {
-              newnode(node);
+              randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
               tree.check_avl(tio, yield);
@@ -2673,7 +2681,7 @@ void avl_tests(MPCIO &mpcio,
             size_t insert_array_size = 6;
             Node node;
             for(size_t i = 0; i<=insert_array_size; i++) {
-              newnode(node);
+              randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
               tree.check_avl(tio, yield);
@@ -2765,7 +2773,7 @@ void avl_tests(MPCIO &mpcio,
             size_t insert_array_size = 3;
             Node node;
             for(size_t i = 0; i<=insert_array_size; i++) {
-              newnode(node);
+              randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
               tree.check_avl(tio, yield);
@@ -2849,7 +2857,7 @@ void avl_tests(MPCIO &mpcio,
             size_t insert_array_size = 4;
             Node node;
             for(size_t i = 0; i<=insert_array_size; i++) {
-              newnode(node);
+              randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
               tree.check_avl(tio, yield);
@@ -2937,7 +2945,7 @@ void avl_tests(MPCIO &mpcio,
             size_t insert_array_size = 3;
             Node node;
             for(size_t i = 0; i<=insert_array_size; i++) {
-              newnode(node);
+              randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
               tree.check_avl(tio, yield);
@@ -3020,7 +3028,7 @@ void avl_tests(MPCIO &mpcio,
             size_t insert_array_size = 6;
             Node node;
             for(size_t i = 0; i<=insert_array_size; i++) {
-              newnode(node);
+              randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
               tree.check_avl(tio, yield);
@@ -3126,7 +3134,7 @@ void avl_tests(MPCIO &mpcio,
             size_t insert_array_size = 12;
             Node node;
             for(size_t i = 0; i<=insert_array_size; i++) {
-              newnode(node);
+              randomize_node(node);
               node.key.set(insert_array[i] * tio.player());
               tree.insert(tio, yield, node);
               tree.check_avl(tio, yield);

+ 74 - 111
bst.cpp

@@ -2,6 +2,19 @@
 
 #include "bst.hpp"
 
+#ifdef BST_DEBUG
+    void BST::print_oram(MPCTIO &tio, yield_t &yield) {
+        auto A = oram.flat(tio, yield);
+        auto R = A.reconstruct();
+
+        for(size_t i=0;i<R.size();++i) {
+            printf("\n%04lx ", i);
+            R[i].dump();
+        }
+        printf("\n");
+    }
+#endif
+
 // Helper functions to reconstruct shared RegBS, RegAS or RegXS
 bool reconstruct_RegBS(MPCTIO &tio, yield_t &yield, RegBS flag) {
     RegBS reconstructed_flag;
@@ -23,59 +36,25 @@ bool reconstruct_RegBS(MPCTIO &tio, yield_t &yield, RegBS flag) {
     }
     return reconstructed_flag.bshare;
 }
-    
-size_t reconstruct_RegAS(MPCTIO &tio, yield_t &yield, RegAS variable) {
-    RegAS reconstructed_var;
-    if (tio.player() < 2) {
-        RegAS peer_var;
-        tio.queue_peer(&variable, sizeof(variable));
-        tio.queue_server(&variable, sizeof(variable));
-        yield();
-        tio.recv_peer(&peer_var, sizeof(variable));
-        reconstructed_var = variable;
-        reconstructed_var += peer_var;
-    } else {
-        RegAS p0_var, p1_var;
-        yield();
-        tio.recv_p0(&p0_var, sizeof(variable));
-        tio.recv_p1(&p1_var, sizeof(variable));
-        reconstructed_var = p0_var;
-        reconstructed_var += p1_var;
-    }   
-    return reconstructed_var.ashare;
-}
-    
-size_t reconstruct_RegXS(MPCTIO &tio, yield_t &yield, RegXS variable) {
-    RegXS reconstructed_var;
-    if (tio.player() < 2) {
-        RegXS peer_var;
-        tio.queue_peer(&variable, sizeof(variable));
-        tio.queue_server(&variable, sizeof(variable));
-        yield();
-        tio.recv_peer(&peer_var, sizeof(variable));
-        reconstructed_var = variable;
-        reconstructed_var ^= peer_var;
-    } else {
-        RegXS p0_var, p1_var;
-        yield();
-        tio.recv_p0(&p0_var, sizeof(variable));
-        tio.recv_p1(&p1_var, sizeof(variable));
-        reconstructed_var = p0_var;
-        reconstructed_var ^= p1_var;
-    }   
-    return reconstructed_var.xshare;
-}
-
 
-
-std::tuple<RegBS, RegBS> compare_keys(MPCTIO tio, yield_t &yield, Node n1, Node n2) {
-    CDPF cdpf = tio.cdpf(yield);
-    auto [lt, eq, gt] = cdpf.compare(tio, yield, n2.key - n1.key, tio.aes_ops());
-    RegBS lteq = lt^eq;
-    return {lteq, gt};
+/* A function to assign a new random 8-bit key to a node, and resets its 
+   pointers to zeroes. The node is assigned a new random 64-bit value.
+*/
+static void randomize_node(Node &a) {
+    a.key.randomize(8);
+    a.pointers.set(0);
+    a.value.randomize();
 }
 
-std::tuple<RegBS, RegBS> compare_keys(MPCTIO tio, yield_t &yield, RegAS k1, RegAS k2) {
+/*
+    A function to perform key comparsions for BST traversal. 
+    Inputs: k1 = key of node in the tree, k2 = insertion/deletion/lookup key.
+    Evaluates (k2-k1), and combines the lt and eq flag into one (flag to go 
+    left), and keeps the gt flag as is (flag to go right) during traversal.
+    
+*/
+std::tuple<RegBS, RegBS> compare_keys(MPCTIO tio, yield_t &yield, RegAS k1, 
+        RegAS k2) {
     CDPF cdpf = tio.cdpf(yield);
     auto [lt, eq, gt] = cdpf.compare(tio, yield, k2 - k1, tio.aes_ops());
     RegBS lteq = lt^eq;
@@ -83,9 +62,9 @@ std::tuple<RegBS, RegBS> compare_keys(MPCTIO tio, yield_t &yield, RegAS k1, RegA
 }
 
 // Assuming pointer of 64 bits is split as:
-// - 32 bits Left ptr
-// - 32 bits Right ptr
-// < Left, Right>
+// - 32 bits Left ptr (L)
+// - 32 bits Right ptr (R)
+// The pointers are stored as: L | R 
 
 inline RegXS extractLeftPtr(RegXS pointer){ 
     return ((pointer&(0xFFFFFFFF00000000))>>32); 
@@ -149,29 +128,20 @@ void BST::pretty_print(const std::vector<Node> &R, value_t node,
     pretty_print(R, left_ptr, leftprefix, true, false);
 }
 
-void BST::print_oram(MPCTIO &tio, yield_t &yield) {
-    auto A = oram->flat(tio, yield);
-    auto R = A.reconstruct();
-
-    for(size_t i=0;i<R.size();++i) {
-        printf("\n%04lx ", i);
-        R[i].dump();
-    }
-    printf("\n");
-}
-
 void BST::pretty_print(MPCTIO &tio, yield_t &yield) {
     RegXS peer_root;
     RegXS reconstructed_root = root;
     if (tio.player() == 1) {
         tio.queue_peer(&root, sizeof(root));
+        yield();
     } else {
         RegXS peer_root;
+        yield();
         tio.recv_peer(&peer_root, sizeof(peer_root));
         reconstructed_root += peer_root;
     }
 
-    auto A = oram->flat(tio, yield);
+    auto A = oram.flat(tio, yield);
     auto R = A.reconstruct();
     if(tio.player()==0) {
         pretty_print(R, reconstructed_root.xshare);
@@ -207,8 +177,9 @@ std::tuple<bool, address_t> BST::check_bst(const std::vector<Node> &R,
         height };
 }
 
+
 void BST::check_bst(MPCTIO &tio, yield_t &yield) {
-    auto A = oram->flat(tio, yield);
+    auto A = oram.flat(tio, yield);
     auto R = A.reconstruct();
 
     RegXS rec_root = this->root;
@@ -226,18 +197,15 @@ void BST::check_bst(MPCTIO &tio, yield_t &yield) {
     }
 }
 
-void newnode(Node &a) {
-    a.key.randomize(8);
-    a.pointers.set(0);
-    a.value.randomize();
-}
-
-void BST::initialize(int num_players, size_t size) {
-    this->MAX_SIZE = size;
-    oram = new Duoram<Node>(num_players, size);
-}
-
+/*
+    The recursive insert() call, invoked by the wrapper insert() function.
+    
+    Takes as input the pointer to the current node in tree traversal (ptr),
+    the new node to be inserted (new_node), the underlying Duoram as a 
+    flat (A), and the Time-To_live TTL, and a shared flag (isDummy) which
+    tracks if the operation is dummy/real.
 
+*/
 std::tuple<RegXS, RegBS> BST::insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
     const Node &new_node, Duoram<Node>::Flat &A, int TTL, RegBS isDummy) {
     if(TTL==0) {
@@ -245,10 +213,10 @@ std::tuple<RegXS, RegBS> BST::insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
         return {ptr, zero};
     }
 
-    RegBS isNotDummy = isDummy ^ (tio.player());
+    RegBS isNotDummy = isDummy ^ (!tio.player());
     Node cnode = A[ptr];
     // Compare key
-    auto [lteq, gt] = compare_keys(tio, yield, cnode, new_node);
+    auto [lteq, gt] = compare_keys(tio, yield, cnode.key, new_node.key);
 
     // Depending on [lteq, gt] select the next ptr/index as
     // upper 32 bits of cnode.pointers if lteq
@@ -265,7 +233,8 @@ std::tuple<RegXS, RegBS> BST::insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
     RegBS F_z = dpf.is_zero(tio, yield, next_ptr, aes_ops);
     RegBS F_i;
 
-    // F_i: If this was last node on path (F_z), and isNotDummy insert.
+    // F_i: If this was last node on path (F_z) ^ isNotDummy: 
+    //          insert new_node here.
     mpc_and(tio, yield, F_i, (isNotDummy), F_z);
      
     isDummy^=F_i;
@@ -276,7 +245,7 @@ std::tuple<RegXS, RegBS> BST::insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
     // If we insert here (F_i), return the ptr to this node as wptr
     // and update direction to the direction taken by compare_keys
     mpc_select(tio, yield, ret_ptr, F_i, wptr, ptr);
-    //ret_direction = direction + F_p(direction - gt)
+    //ret_direction = direction + F_i (direction - gt)
     mpc_and(tio, yield, ret_direction, F_i, direction^gt);
     ret_direction^=direction;  
 
@@ -284,14 +253,19 @@ std::tuple<RegXS, RegBS> BST::insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
 }
 
 
-// Insert(root, ptr, key, TTL, isDummy) -> (new_ptr, wptr, wnode, f_p)
+/*
+    The wrapper insert() operation invoked by the main insert call
+    BST::insert(tio, yield, Node& new_node);
+
+    Takes as input the new node (node), the underlying Duoram as a flat (A).
+*/
 void BST::insert(MPCTIO &tio, yield_t &yield, const Node &node, Duoram<Node>::Flat &A) {
     bool player0 = tio.player()==0;
     // If there are no items in tree. Make this new item the root.
-    if(num_items==0) {
+    if (num_items==0) {
         Node zero;
-        A[0] = zero;
         A[1] = node;
+        // Set root to a secret sharing of the constant value 1
         (root).set(1*tio.player());
         num_items++;
         //printf("num_items == %ld!\n", num_items);
@@ -301,7 +275,7 @@ void BST::insert(MPCTIO &tio, yield_t &yield, const Node &node, Duoram<Node>::Fl
         int new_id;
         RegXS insert_address;
         int TTL = num_items++;
-        bool insertAtEmptyLocation = (numEmptyLocations() > 0);
+        bool insertAtEmptyLocation = (empty_locations.size() > 0);
         if(insertAtEmptyLocation) {
             insert_address = empty_locations.back();
             empty_locations.pop_back(); 
@@ -323,7 +297,7 @@ void BST::insert(MPCTIO &tio, yield_t &yield, const Node &node, Duoram<Node>::Fl
         RegXS new_right_ptr, new_left_ptr;
       
         mpc_select(tio, yield, new_right_ptr, direction, right_ptr, insert_address);
-        if(player0) {
+        if (player0) {
             direction^=1;
         }
         mpc_select(tio, yield, new_left_ptr, direction, left_ptr, insert_address);
@@ -334,14 +308,17 @@ void BST::insert(MPCTIO &tio, yield_t &yield, const Node &node, Duoram<Node>::Fl
     } 
 }
 
-
+/*
+    Insert a new node into the BST.
+    Takes as input the new node (node).
+*/
 void BST::insert(MPCTIO &tio, yield_t &yield, Node &node) {
-    auto A = oram->flat(tio, yield);
-    auto R = A.reconstruct();
+    auto A = oram.flat(tio, yield);
 
     insert(tio, yield, node, A);
     /*
     // To visualize database and tree after each insert:
+    auto R = A.reconstruct();
     if (tio.player() == 0) {
         for(size_t i=0;i<R.size();++i) {
             printf("\n%04lx ", i);
@@ -391,14 +368,14 @@ bool BST::lookup(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS key, Duoram<Node>
 }
 
 bool BST::lookup(MPCTIO &tio, yield_t &yield, RegAS key, Node *ret_node) {
-    auto A = oram->flat(tio, yield);
-    auto R = A.reconstruct();
+    auto A = oram.flat(tio, yield);
 
     RegBS isDummy;
 
     bool found = lookup(tio, yield, root, key, A, num_items, isDummy, ret_node);
     /*
     // To visualize database and tree after each lookup:
+    auto R = A.reconstruct();
     if (tio.player() == 0) {
         for(size_t i=0;i<R.size();++i) {
             printf("\n%04lx ", i);
@@ -596,7 +573,7 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
         return 0;
     if(num_items==1) {
         //Delete root
-        auto A = oram->flat(tio, yield);
+        auto A = oram.flat(tio, yield);
         Node zero;
         empty_locations.emplace_back(root);
         A[root] = zero;
@@ -609,7 +586,7 @@ bool BST::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
         RegBS af;
         RegBS fs;
         del_return ret_struct; 
-        auto A = oram->flat(tio, yield);
+        auto A = oram.flat(tio, yield);
         int success = del(tio, yield, root, del_key, A, af, fs, TTL, ret_struct); 
         printf ("Success =  %d\n", success); 
         if(!success){
@@ -673,12 +650,11 @@ void bst(MPCIO &mpcio,
         size_t insert_array_size = 11;
         Node node; 
         for(size_t i = 0; i<=insert_array_size; i++) {
-          newnode(node);
+          randomize_node(node);
           node.key.set(insert_array[i] * tio.player());
           tree.insert(tio, yield, node);
         }
        
-        tree.print_oram(tio, yield);
         tree.pretty_print(tio, yield);
          
         RegAS del_key;
@@ -687,71 +663,59 @@ void bst(MPCIO &mpcio,
         printf("\n\nDelete %x\n", 20);
         del_key.set(20 * tio.player());
         tree.del(tio, yield, del_key);
-        tree.print_oram(tio, yield);
         tree.pretty_print(tio, yield);
         tree.check_bst(tio, yield);
 
         printf("\n\nDelete %x\n", 10);
         del_key.set(10 * tio.player());
         tree.del(tio, yield, del_key);
-        tree.print_oram(tio, yield);
         tree.pretty_print(tio, yield);
         tree.check_bst(tio, yield);
 
         printf("\n\nDelete %x\n", 8);
         del_key.set(8 * tio.player());
         tree.del(tio, yield, del_key);
-        tree.print_oram(tio, yield);
         tree.pretty_print(tio, yield);
         tree.check_bst(tio, yield);
 
         printf("\n\nDelete %x\n", 7);
         del_key.set(7 * tio.player());
         tree.del(tio, yield, del_key);
-        tree.print_oram(tio, yield);
         tree.pretty_print(tio, yield);
         tree.check_bst(tio, yield);
 
         printf("\n\nDelete %x\n", 17);
         del_key.set(17 * tio.player());
         tree.del(tio, yield, del_key);
-        tree.print_oram(tio, yield);
         tree.pretty_print(tio, yield);
         tree.check_bst(tio, yield);
 
         printf("\n\nDelete %x\n", 15);
         del_key.set(15 * tio.player());
         tree.del(tio, yield, del_key);
-        tree.print_oram(tio, yield);
         tree.pretty_print(tio, yield);
         tree.check_bst(tio, yield);
-        printf("Num empty_locations = %ld\n", tree.numEmptyLocations());
 
         printf("\n\nDelete %x\n", 5);
         del_key.set(5 * tio.player());
         tree.del(tio, yield, del_key);
-        tree.print_oram(tio, yield);
         tree.pretty_print(tio, yield);
         tree.check_bst(tio, yield);
-        printf("Num empty_locations = %ld\n", tree.numEmptyLocations());
         */
 
         printf("\n\nInsert %x\n", 14);
-        newnode(node);
+        randomize_node(node);
         node.key.set(14 * tio.player());
         tree.insert(tio, yield, node);
-        tree.print_oram(tio, yield);
         tree.pretty_print(tio, yield);
         tree.check_bst(tio, yield);
-        printf("Num empty_locations = %ld\n", tree.numEmptyLocations());
 
         printf("\n\nLookup %x\n", 8);
-        newnode(node);
+        randomize_node(node);
         RegAS lookup_key;
         bool found;
         lookup_key.set(8 * tio.player());
         found = tree.lookup(tio, yield, lookup_key, &node);
-        tree.print_oram(tio, yield);
         tree.pretty_print(tio, yield);
         if(found) {
           printf("Lookup Success\n");
@@ -762,10 +726,9 @@ void bst(MPCIO &mpcio,
         }
 
         printf("\n\nLookup %x\n", 99);
-        newnode(node);
+        randomize_node(node);
         lookup_key.set(99 * tio.player());
         found = tree.lookup(tio, yield, lookup_key, &node);
-        tree.print_oram(tio, yield);
         tree.pretty_print(tio, yield);
         if(found) {
           printf("Lookup Success\n");

+ 38 - 38
bst.hpp

@@ -7,6 +7,8 @@
 #include "mpcio.hpp"
 #include "options.hpp"
 
+// #define BST_DEBUG
+
 // Some simple utility functions:
 bool reconstruct_RegBS(MPCTIO &tio, yield_t &yield, RegBS flag);
 
@@ -117,9 +119,13 @@ struct Node {
     }
 };
 
-void newnode(Node &a);
-
-std::tuple<RegBS, RegBS> compare_keys(MPCTIO tio, yield_t &yield, Node n1, Node n2);
+/*
+    A function to perform key comparsions for BST traversal. 
+    Inputs: k1 = key of node in the tree, k2 = insertion/deletion/lookup key.
+    Evaluates (k2-k1), and combines the lt and eq flag into one (flag to go 
+    left), and keeps the gt flag as is (flag to go right) during traversal.
+    Returns the shared bit flags lteq (go left) and gt (go right).
+*/
 std::tuple<RegBS, RegBS> compare_keys(MPCTIO tio, yield_t &yield, RegAS k1, RegAS k2);
 
 // I/O operations (for sending over the network)
@@ -156,7 +162,7 @@ struct del_return {
 
 class BST {
   private: 
-    Duoram<Node> *oram;
+    Duoram<Node> oram;
     RegXS root;
 
     size_t num_items = 0;
@@ -175,58 +181,52 @@ class BST {
     bool lookup(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS key, 
         Duoram<Node>::Flat &A, int TTL, RegBS isDummy, Node *ret_node);
 
-  public:
-    BST(int num_players, size_t size) {
-      this->initialize(num_players, size);
-    };
+    void pretty_print(const std::vector<Node> &R, value_t node,
+        const std::string &prefix, bool is_left_child, bool is_right_child);
 
-    ~BST() {
-      if(oram)
-        delete oram;
-    };
+    std::tuple<bool, address_t> check_bst(const std::vector<Node> &R,
+        value_t node, value_t min_key, value_t max_key);
 
-    size_t numEmptyLocations(){
-      return(empty_locations.size());
+  public:
+    BST(int num_players, size_t size) : oram(num_players, size) {  
+        this->MAX_SIZE = size;
     };
 
-    void initialize(int num_players, size_t size);
+
+    // Inserts the provided node into the BST
     void insert(MPCTIO &tio, yield_t &yield, Node &node);
 
-    // Deletes the first node that matches del_key
+    // Deletes the first node that matches del_key from the BST
     bool del(MPCTIO &tio, yield_t &yield, RegAS del_key); 
 
-    // Returns the first node that matches key 
+    // Returns the first node that matches key in the BST
     bool lookup(MPCTIO &tio, yield_t &yield, RegAS key, Node *ret_node);
 
     // Display and correctness check functions
+
+    // Print the BST
     void pretty_print(MPCTIO &tio, yield_t &yield);
-    void pretty_print(const std::vector<Node> &R, value_t node,
-        const std::string &prefix, bool is_left_child, bool is_right_child);
+
+    // Check BST correctness
     void check_bst(MPCTIO &tio, yield_t &yield);
-    std::tuple<bool, address_t> check_bst(const std::vector<Node> &R,
-        value_t node, value_t min_key, value_t max_key);
-    void print_oram(MPCTIO &tio, yield_t &yield);
-};
 
-/*
-class BST_OP {
-  private:
-    MPCTIO tio;
-    yield_t yield;    
-    BST *ptr; 
+    // Debugging Functions    
 
-  public:
-    BST_OP* init(MPCTIO &tio, yield_t &yield, BST *bst_ptr) {
-      this->tio = tio;
-      this->yield = yield;
-      this->ptr = bst_ptr;
-      return this;
-    }
+    #ifdef BST_DEBUG
+
+        // Print the underlying ORAM state
+        void print_oram(MPCTIO &tio, yield_t &yield);
+
+        // Check the number of empty locations in ORAM 
+        // (Locations freed up after a delete operation, reusable for next insert.)
+        size_t numEmptyLocations(){
+            return(empty_locations.size());
+        };
+
+    #endif
 };
-*/
 
 void bst(MPCIO &mpcio,
     const PRACOptions &opts, char **args);
 
-
 #endif