|
@@ -89,28 +89,45 @@ void AVL::pretty_print(MPCTIO &tio, yield_t &yield) {
|
|
|
// tuple<bool,address_t>, where the bool says whether the BST invariant
|
|
|
// holds, and the address_t is the height of the tree (which will be
|
|
|
// useful later when we check AVL trees).
|
|
|
-std::tuple<bool, bool, address_t> AVL::check_avl(const std::vector<Node> &R,
|
|
|
+std::tuple<bool, bool, bool, address_t> AVL::check_avl(const std::vector<Node> &R,
|
|
|
value_t node, value_t min_key = 0, value_t max_key = ~0)
|
|
|
{
|
|
|
if (node == 0) {
|
|
|
- return { true, true, 0 };
|
|
|
+ return { true, true, true, 0};
|
|
|
}
|
|
|
const Node &n = R[node];
|
|
|
value_t key = n.key.ashare;
|
|
|
value_t left_ptr = getAVLLeftPtr(n.pointers).xshare;
|
|
|
value_t right_ptr = getAVLRightPtr(n.pointers).xshare;
|
|
|
- auto [leftok, leftavlok, leftheight ] = check_avl(R, left_ptr, min_key, key);
|
|
|
- auto [rightok, rightavlok, rightheight ] = check_avl(R, right_ptr, key, max_key);
|
|
|
+ auto [leftok, leftavlok, leftbbok, leftheight ] = check_avl(R, left_ptr, min_key, key);
|
|
|
+ auto [rightok, rightavlok, rightbbok, rightheight ] = check_avl(R, right_ptr, key, max_key);
|
|
|
address_t height = leftheight;
|
|
|
if (rightheight > height) {
|
|
|
height = rightheight;
|
|
|
}
|
|
|
height += 1;
|
|
|
int heightgap = leftheight - rightheight;
|
|
|
+ bool leftbal = (getLeftBal(n.pointers)).bshare;
|
|
|
+ bool rightbal = (getRightBal(n.pointers)).bshare;
|
|
|
bool avlok = (abs(heightgap)<2);
|
|
|
+ bool bb_ok = false;
|
|
|
+
|
|
|
+ if(heightgap==-1) {
|
|
|
+ if(rightbal==1 && leftbal==0){
|
|
|
+ bb_ok = true;
|
|
|
+ }
|
|
|
+ } else if(heightgap==1){
|
|
|
+ if(leftbal==1 && rightbal==0){
|
|
|
+ bb_ok = true;
|
|
|
+ }
|
|
|
+ } else if(heightgap==0){
|
|
|
+ if(rightbal==0 && leftbal==0) {
|
|
|
+ bb_ok = true;
|
|
|
+ }
|
|
|
+ }
|
|
|
//printf("node = %ld, leftok = %d, rightok = %d\n", node, leftok, rightok);
|
|
|
return { leftok && rightok && key >= min_key && key <= max_key,
|
|
|
- avlok && leftavlok && rightavlok, height};
|
|
|
+ avlok && leftavlok && rightavlok, bb_ok && leftbbok && rightbbok, height};
|
|
|
}
|
|
|
|
|
|
void AVL::check_avl(MPCTIO &tio, yield_t &yield) {
|
|
@@ -126,9 +143,9 @@ void AVL::check_avl(MPCTIO &tio, yield_t &yield) {
|
|
|
rec_root+= peer_root;
|
|
|
}
|
|
|
if (tio.player() == 0) {
|
|
|
- auto [ bst_ok, avl_ok, height ] = check_avl(R, rec_root.xshare);
|
|
|
- printf("BST structure %s\nAVL structure %s\nTree height = %u\n",
|
|
|
- bst_ok ? "ok" : "NOT OK", avl_ok ? "ok" : "NOT OK", height);
|
|
|
+ auto [ bst_ok, avl_ok, bb_ok, height ] = check_avl(R, rec_root.xshare);
|
|
|
+ printf("BST structure %s\nAVL structure %s\nBalance Bits %s\nTree height = %u\n",
|
|
|
+ bst_ok ? "ok" : "NOT OK", avl_ok ? "ok" : "NOT OK", bb_ok? "ok" : "NOT OK", height);
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -268,7 +285,7 @@ std::tuple<RegBS, RegBS, RegBS, RegBS> AVL::updateBalanceDel(MPCTIO &tio, yield_
|
|
|
RegBS bal_l, RegBS bal_r, RegBS bal_upd, RegBS child_dir) {
|
|
|
bool player0 = tio.player()==0;
|
|
|
RegBS s0;
|
|
|
- RegBS F_rs, F_ls, balanced, imbalance;
|
|
|
+ RegBS F_rs, F_ls, balanced, imbalance, not_imbalance;
|
|
|
RegBS nt_child_dir = child_dir;
|
|
|
if(player0) {
|
|
|
nt_child_dir^=1;
|
|
@@ -308,11 +325,26 @@ std::tuple<RegBS, RegBS, RegBS, RegBS> AVL::updateBalanceDel(MPCTIO &tio, yield_
|
|
|
{ mpc_select(tio, yield, bal_l, F_rs, bal_l, s0);});
|
|
|
|
|
|
// if(bal_upd) and not imbalance bal_upd<-0
|
|
|
+ /*
|
|
|
RegBS bu0;
|
|
|
- mpc_and(tio, yield, bu0, bal_upd, balanced);
|
|
|
+ not_imbalance = imbalance;
|
|
|
+ if(player0){
|
|
|
+ not_imbalance^=1;
|
|
|
+ }
|
|
|
+ mpc_and(tio, yield, bu0, bal_upd, not_imbalance);
|
|
|
mpc_select(tio, yield, bal_upd, bu0, bal_upd, s0);
|
|
|
+ */
|
|
|
|
|
|
- // Any bal_upd, propogates all the way up to root
|
|
|
+ // if(bal_upd) and this node turns balanced, the height has decreased, so continue propogating bal_upd.
|
|
|
+ // if(bal_upd) and node turns imbalanced, fixImbalance will update bal_upd correctly.
|
|
|
+ // if(bal_upd) and node moves out of balanced to left/right heavy, the height of this subtree has not changed,
|
|
|
+ // so don't propogate bal_upd.
|
|
|
+ // if(bal_upd && bal_l ^ bal_r
|
|
|
+ RegBS LR_heavy, bu0;
|
|
|
+ LR_heavy = bal_l ^ bal_r;
|
|
|
+ mpc_and(tio, yield, bu0, bal_upd, LR_heavy);
|
|
|
+ mpc_select(tio, yield, bal_upd, bu0, bal_upd, s0);
|
|
|
+
|
|
|
return {bal_l, bal_r, bal_upd, imbalance};
|
|
|
}
|
|
|
|
|
@@ -1020,20 +1052,29 @@ void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
|
|
|
|
|
|
*/
|
|
|
RegBS IC1, IC2, IC3; // Imbalance Case 1, 2 or 3
|
|
|
+ RegBS cs_zero_bal = cs_bal_dpc ^ cs_bal_ndpc;
|
|
|
+ if(player0) {
|
|
|
+ cs_zero_bal^=1;
|
|
|
+ }
|
|
|
run_coroutines(tio, [&tio, &IC1, imb, cs_bal_ndpc] (yield_t &yield) {
|
|
|
// IC1 = Single rotation (L/R). L/R = dpc
|
|
|
- mpc_and(tio, yield, IC1, imb, cs_bal_ndpc);
|
|
|
+ mpc_and(tio, yield, IC1, imb, cs_bal_ndpc);
|
|
|
},
|
|
|
[&tio, &IC3, imb, cs_bal_dpc](yield_t &yield) {
|
|
|
// IC3 = Double rotation (LR/RL). 1st rotate direction = ndpc, 2nd direction = dpc
|
|
|
- mpc_and(tio, yield, IC3, imb, cs_bal_dpc);
|
|
|
+ mpc_and(tio, yield, IC3, imb, cs_bal_dpc);
|
|
|
+ },
|
|
|
+ [&tio, &IC2, imb, cs_zero_bal](yield_t &yield) {
|
|
|
+ mpc_and(tio, yield, IC2, imb, cs_zero_bal);
|
|
|
});
|
|
|
|
|
|
// IC2 = Single rotation (L/R).
|
|
|
+ /*
|
|
|
IC2 = IC1 ^ IC3;
|
|
|
if(player0) {
|
|
|
IC2^=1;
|
|
|
}
|
|
|
+ */
|
|
|
|
|
|
RegBS p_bal_dpc, p_bal_ndpc;
|
|
|
RegBS IC2_ndpc_l, IC2_ndpc_r, IC2_dpc_l, IC2_dpc_r;
|
|
@@ -1375,16 +1416,40 @@ std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS d
|
|
|
RegBS p_bal_l, p_bal_r;
|
|
|
p_bal_l = getLeftBal(node.pointers);
|
|
|
p_bal_r = getRightBal(node.pointers);
|
|
|
+
|
|
|
+ #ifdef DEBUG
|
|
|
+ size_t rec_key = reconstruct_RegAS(tio, yield, node.key);
|
|
|
+ bool rec_bal_upd = reconstruct_RegBS(tio, yield, bal_upd);
|
|
|
+ printf("current_key = %ld, bal_upd (before updateBalanceDel) = %d\n", rec_key, rec_bal_upd);
|
|
|
+ #endif
|
|
|
+
|
|
|
auto [new_p_bal_l, new_p_bal_r, new_bal_upd, imb] =
|
|
|
updateBalanceDel(tio, yield, p_bal_l, p_bal_r, bal_upd, c_prime);
|
|
|
+ bal_upd = new_bal_upd;
|
|
|
+
|
|
|
+ #ifdef DEBUG
|
|
|
+ bool rec_imb = reconstruct_RegBS(tio, yield, imb);
|
|
|
+ bool rec_new_bal_upd = reconstruct_RegBS(tio, yield, new_bal_upd);
|
|
|
+ printf("new_bal_upd (after updateBalanceDel) = %d, imb = %d\n", rec_new_bal_upd, rec_imb);
|
|
|
+ #endif
|
|
|
|
|
|
// F_ri: subflag for F_r. F_ri = returned flag set to 1 from imbalance fix.
|
|
|
RegBS F_ri;
|
|
|
fixImbalance(tio, yield, A, oidx, oldptrs, ptr, node.pointers, new_p_bal_l, new_p_bal_r, bal_upd,
|
|
|
c_prime, cs_ptr, imb, F_ri, ret_struct);
|
|
|
|
|
|
+ #ifdef DEBUG
|
|
|
+ rec_imb = reconstruct_RegBS(tio, yield, imb);
|
|
|
+ rec_bal_upd = reconstruct_RegBS(tio, yield, bal_upd);
|
|
|
+ printf("imb (after fixImbalance) = %d, bal_upd = %d\n", rec_imb, rec_bal_upd);
|
|
|
+ #endif
|
|
|
updateRetStruct(tio, yield, ptr, F_2, F_c2, F_c4, lf, F_ri, found, bal_upd, ret_struct);
|
|
|
|
|
|
+ #ifdef DEBUG
|
|
|
+ rec_bal_upd = reconstruct_RegBS(tio, yield, bal_upd);
|
|
|
+ printf("bal_upd (after updateRetStruct) = %d\n", rec_bal_upd);
|
|
|
+ #endif
|
|
|
+
|
|
|
return {key_found, bal_upd};
|
|
|
}
|
|
|
}
|
|
@@ -2856,25 +2921,25 @@ void avl_tests(MPCIO &mpcio,
|
|
|
|
|
|
5 5 5
|
|
|
/ \ / \ / \
|
|
|
- 3 12 Del 1 3 12 3 9
|
|
|
- / / ------> / ---> / \
|
|
|
- 1 7 9 7 12
|
|
|
- \ /
|
|
|
- 9 7
|
|
|
+ 3 8 Del 7 3 8 3 9
|
|
|
+ / / \ ------> / \ ---> / / \
|
|
|
+ 1 7 12 1 9 1 8 12
|
|
|
+ / \
|
|
|
+ 9 12
|
|
|
|
|
|
|
|
|
T8 checks:
|
|
|
- root is 5
|
|
|
- - 3,9,7,12 are in correct positions
|
|
|
- - Nodes 3,7,12 have 0 balance
|
|
|
- - Nodes 3,7,12 have no children
|
|
|
- - 5's bal = 0 1
|
|
|
+ - 3,9,8,12 are in correct positions
|
|
|
+ - Nodes 1,5,8,9,12 have 0 balance
|
|
|
+ - Nodes 1,5,8,9,12 have no children
|
|
|
+ - Node 3 has 1 0 balance
|
|
|
|
|
|
*/
|
|
|
{
|
|
|
bool success = 1;
|
|
|
- int insert_array[] = {5, 3, 12, 7, 1, 9};
|
|
|
- size_t insert_array_size = 5;
|
|
|
+ int insert_array[] = {5, 3, 8, 7, 1, 12, 9};
|
|
|
+ size_t insert_array_size = 6;
|
|
|
Node node;
|
|
|
for(size_t i = 0; i<=insert_array_size; i++) {
|
|
|
newnode(node);
|
|
@@ -2884,7 +2949,7 @@ void avl_tests(MPCIO &mpcio,
|
|
|
}
|
|
|
|
|
|
RegAS del_key;
|
|
|
- del_key.set(1 * tio.player());
|
|
|
+ del_key.set(7 * tio.player());
|
|
|
tree.del(tio, yield, del_key);
|
|
|
tree.check_avl(tio, yield);
|
|
|
|
|
@@ -2893,8 +2958,8 @@ void avl_tests(MPCIO &mpcio,
|
|
|
size_t root = reconstruct_RegXS(tio, yield, root_xs);
|
|
|
auto A = oram->flat(tio, yield);
|
|
|
auto R = A.reconstruct();
|
|
|
- Node root_node, n3, n7, n9, n12;
|
|
|
- size_t n3_index, n7_index, n9_index, n12_index;
|
|
|
+ Node root_node, n1, n3, n8, n9, n12;
|
|
|
+ size_t n1_index, n3_index, n8_index, n9_index, n12_index;
|
|
|
root_node = R[root];
|
|
|
if((root_node.key).share()!=5) {
|
|
|
success = false;
|
|
@@ -2903,34 +2968,37 @@ void avl_tests(MPCIO &mpcio,
|
|
|
n9_index = (getAVLRightPtr(root_node.pointers)).share();
|
|
|
n3 = R[n3_index];
|
|
|
n9 = R[n9_index];
|
|
|
- n7_index = getAVLLeftPtr(n9.pointers).share();
|
|
|
+ n1_index = getAVLLeftPtr(n3.pointers).share();
|
|
|
+ n8_index = getAVLLeftPtr(n9.pointers).share();
|
|
|
n12_index = getAVLRightPtr(n9.pointers).share();
|
|
|
- n7 = R[n7_index];
|
|
|
+ n1 = R[n1_index];
|
|
|
+ n8 = R[n8_index];
|
|
|
n12 = R[n12_index];
|
|
|
|
|
|
// Node value checks
|
|
|
+ if(n1.key.share()!=1) {
|
|
|
+ success = false;
|
|
|
+ }
|
|
|
if(n3.key.share()!=3 || n9.key.share()!=9) {
|
|
|
success = false;
|
|
|
}
|
|
|
- if(n7.key.share()!=7 || n12.key.share()!=12) {
|
|
|
+ if(n8.key.share()!=8 || n12.key.share()!=12) {
|
|
|
success = false;
|
|
|
}
|
|
|
|
|
|
// Node balance checks
|
|
|
size_t zero = 0;
|
|
|
- zero+=(n3.pointers.share());
|
|
|
- zero+=(n7.pointers.share());
|
|
|
+ zero+=(n1.pointers.share());
|
|
|
+ zero+=(getRightBal(n3.pointers).share());
|
|
|
+ zero+=(n8.pointers.share());
|
|
|
zero+=(n12.pointers.share());
|
|
|
zero+=(getLeftBal(root_node.pointers).share());
|
|
|
+ zero+=(getRightBal(root_node.pointers).share());
|
|
|
zero+=(getLeftBal(n9.pointers).share());
|
|
|
zero+=(getRightBal(n9.pointers).share());
|
|
|
if(zero!=0) {
|
|
|
success = false;
|
|
|
}
|
|
|
- int one = (getRightBal(root_node.pointers).share());
|
|
|
- if(one!=1) {
|
|
|
- success = false;
|
|
|
- }
|
|
|
if(player0) {
|
|
|
if(success) {
|
|
|
print_green("T16 : SUCCESS\n");
|