Browse Source

Add a function to check the BST invariant

Ian Goldberg 1 year ago
parent
commit
5900b98086
1 changed files with 33 additions and 2 deletions
  1. 33 2
      bst.cpp

+ 33 - 2
bst.cpp

@@ -269,7 +269,8 @@ void insert(MPCTIO &tio, yield_t &yield, RegXS &root, const Node &node, Duoram<N
 // parent.  They cannot both be true, but the root of the tree has both
 // of them false.
 void pretty_print(const std::vector<Node> &R, value_t node,
-    const std::string &prefix, bool is_left_child, bool is_right_child)
+    const std::string &prefix = "", bool is_left_child = false,
+    bool is_right_child = false)
 {
     if (node == 0) {
         // NULL pointer
@@ -307,6 +308,33 @@ void pretty_print(const std::vector<Node> &R, value_t node,
     pretty_print(R, left_ptr, leftprefix, true, false);
 }
 
+// Check the BST invariant of the tree (that all keys to the left are
+// less than or equal to this key, all keys to the right are strictly
+// greater, and this is true recursively).  Returns a
+// tuple<bool,address_t>, where the bool says whether the BST invariant
+// holds, and the address_t is the height of the tree (which will be
+// useful later when we check AVL trees).
+std::tuple<bool, address_t> check_bst(const std::vector<Node> &R,
+    value_t node, value_t min_key = 0, value_t max_key = ~0)
+{
+    if (node == 0) {
+        return { true, 0 };
+    }
+    const Node &n = R[node];
+    value_t key = n.key.ashare;
+    value_t left_ptr = extractLeftPtr(n.pointers).xshare;
+    value_t right_ptr = extractRightPtr(n.pointers).xshare;
+    auto [leftok, leftheight ] = check_bst(R, left_ptr, min_key, key);
+    auto [rightok, rightheight ] = check_bst(R, right_ptr, key+1, max_key);
+    address_t height = leftheight;
+    if (rightheight > height) {
+        height = rightheight;
+    }
+    height += 1;
+    return { leftok && rightok && key >= min_key && key <= max_key,
+        height };
+}
+
 void newnode(Node &a) {
   a.key.randomize(8);
   a.pointers.set(0);
@@ -362,7 +390,10 @@ void bst(MPCIO &mpcio,
                     R[i].dump();
                 }
                 printf("\n");
-                pretty_print(R, root.xshare, "", false, false);
+                pretty_print(R, root.xshare);
+                auto [ ok, height ] = check_bst(R, root.xshare);
+                printf("BST structure %s\nBST height = %u\n",
+                    ok ? "ok" : "NOT OK", height);
             }
         }
     });