Browse Source

fixing extract min bug

avadapal 2 years ago
parent
commit
2357e8ada0
1 changed files with 121 additions and 67 deletions
  1. 121 67
      heap.cpp

+ 121 - 67
heap.cpp

@@ -234,6 +234,8 @@ int MinHeap::verify_heap_property(MPCTIO tio, yield_t & yield) {
 }
 
 void verify_parent_children_heaps(MPCTIO tio, yield_t & yield, RegAS parent, RegAS leftchild, RegAS rightchild) {
+
+    std::cout << "calling this ... \n";
     
     uint64_t parent_reconstruction = mpc_reconstruct(tio, yield, parent);
     
@@ -246,41 +248,52 @@ void verify_parent_children_heaps(MPCTIO tio, yield_t & yield, RegAS parent, Reg
     assert(parent_reconstruction <= rightchild_reconstruction);
 }
 
-//  Let "x" be the root, and let "y" and "z" be the left and right children
-//  For an array, we have A[i] = x, A[2i] = y, A[2i + 1] = z.
-//  We want x \le y, and x \le z.
-//  The steps are as follows:
-//  Step 1: compare(y,z);  (1st call to to MPC Compare)
-//  Step 2: smaller = min(y,z); This is done with an mpcselect (1st call to mpcselect)
-//  Step 3: if(smaller == y) then smallerindex = 2i else smalleindex = 2i + 1;
-//  Step 4: compare(x,smaller); (2nd call to to MPC Compare)
-//  Step 5: smallest = min(x, smaller);  (2nd call to mpcselect)
-//  Step 6: otherchild = max(x, smaller)   
-//  Step 7: A[i] \gets smallest   (1st Duoam Write)
-//  Step 8: A[smallerindex] \gets otherchild (2nd Duoam Write)
-//  Overall restore_heap_property takes 2 MPC Comparisons, 2 MPC Selects, and 2 Duoram Writes
+ 
+
 RegXS MinHeap::restore_heap_property(MPCTIO tio, yield_t & yield, RegXS index) {
     RegAS smallest;
     auto HeapArray = oram.flat(tio, yield);
 
 
-    RegAS parent = HeapArray[index];
     RegXS leftchildindex = index;
     leftchildindex = index << 1;
 
     RegXS rightchildindex;
     rightchildindex.xshare = leftchildindex.xshare ^ (tio.player());
 
-    RegAS leftchild = HeapArray[leftchildindex];
-    RegAS rightchild = HeapArray[rightchildindex];
+    RegAS parent;     //   = HeapArray[index];
+    RegAS leftchild;  //  = HeapArray[leftchildindex];
+    RegAS rightchild; // = HeapArray[rightchildindex];
  
+    std::vector<coro_t> coroutines_read;
+    coroutines_read.emplace_back( 
+    [&tio, &parent, &HeapArray, index](yield_t &yield) { 
+            auto Acoro = HeapArray.context(yield); 
+            parent = Acoro[index]; //inserted_val;
+     });             
+   
+    coroutines_read.emplace_back( 
+    [&tio, &HeapArray, &leftchild, leftchildindex](yield_t &yield) { 
+            auto Acoro = HeapArray.context(yield); 
+            leftchild  = Acoro[leftchildindex]; //inserted_val;
+     }); 
+
+    coroutines_read.emplace_back( 
+    [&tio, &rightchild, &HeapArray, rightchildindex](yield_t &yield) { 
+            auto Acoro = HeapArray.context(yield); 
+            rightchild = Acoro[rightchildindex];
+     }); 
+
+    run_coroutines(tio, coroutines_read);
+
     //RegAS sum = parent + leftchild + rightchild;
 
     CDPF cdpf = tio.cdpf(yield);
     auto[lt_c, eq_c, gt_c] = cdpf.compare(tio, yield, leftchild - rightchild, tio.aes_ops());
     auto lteq = lt_c ^ eq_c;    
     RegXS smallerindex;
-    
+    RegAS smallerchild;
+
 
     #ifdef VERBOSE
         uint64_t LC_rec = mpc_reconstruct(tio, yield, leftchildindex);
@@ -288,15 +301,22 @@ RegXS MinHeap::restore_heap_property(MPCTIO tio, yield_t & yield, RegXS index) {
     #endif
 
 
-    mpc_select(tio, yield, smallerindex, lteq, rightchildindex, leftchildindex, 64);
+    // mpc_select(tio, yield, smallerindex, lteq, rightchildindex, leftchildindex, 64);
+    // mpc_select(tio, yield, smallerchild, lt_c, rightchild, leftchild, 64);
+
+    run_coroutines(tio, [&tio, &smallerindex, lteq, rightchildindex, leftchildindex](yield_t &yield)
+            { mpc_select(tio, yield, smallerindex, lteq, rightchildindex, leftchildindex, 64);},
+            [&tio, &smallerchild, lt_c, rightchild, leftchild](yield_t &yield)
+            { mpc_select(tio, yield, smallerchild, lt_c, rightchild, leftchild, 64);});
+
 
     #ifdef VERBOSE
         uint64_t smallerindex_rec = mpc_reconstruct(tio, yield, smallerindex);
         std::cout << "smallerindex_rec = " << smallerindex_rec << std::endl; 
     #endif
 
-    RegAS smallerchild;
-    mpc_select(tio, yield, smallerchild, lt_c, rightchild, leftchild, 64);
+    
+
 
     CDPF cdpf0 = tio.cdpf(yield);
 
@@ -308,54 +328,38 @@ RegXS MinHeap::restore_heap_property(MPCTIO tio, yield_t & yield, RegXS index) {
     
     mpc_and(tio, yield, ltlt1, lteq, lt_p_eq_p);
 
-    RegAS z, zz, zzz;
+    RegAS update_index_by, update_leftindex_by;
 
-    run_coroutines(tio, [&tio, &zz, ltlt1, parent, leftchild](yield_t &yield)
-            { mpc_flagmult(tio, yield, zz, ltlt1, (parent - leftchild), 64);},
-            [&tio, &z, lt_p, parent, smallerchild](yield_t &yield)
-            {mpc_flagmult(tio, yield, z, lt_p, smallerchild - parent, 64);},
-            [&tio, &zzz, ltlt1, parent,  rightchild](yield_t &yield)
-            {mpc_flagmult(tio, yield, zzz, ltlt1, (parent - rightchild), 64);}
+    run_coroutines(tio, [&tio, &update_leftindex_by, ltlt1, parent, leftchild](yield_t &yield)
+            { mpc_flagmult(tio, yield, update_leftindex_by, ltlt1, (parent - leftchild), 64);},
+            [&tio, &update_index_by, lt_p, parent, smallerchild](yield_t &yield)
+            {mpc_flagmult(tio, yield, update_index_by, lt_p, smallerchild - parent, 64);}
             );
 
-    
-
-
-
-
-
-    
-    
-    // HeapArray[index]           += z;
-    // HeapArray[leftchildindex]  += zz;    
-    // HeapArray[rightchildindex] += zzz;
-
- 
     std::vector<coro_t> coroutines;
 
-
     coroutines.emplace_back( 
-    [&tio, &HeapArray, index, z](yield_t &yield) { 
+    [&tio, &HeapArray, index, update_index_by](yield_t &yield) { 
             auto Acoro = HeapArray.context(yield); 
-            Acoro[index] += z; //inserted_val;
+            Acoro[index] += update_index_by; //inserted_val;
      });             
    
     coroutines.emplace_back( 
-    [&tio, &HeapArray, leftchildindex, zz](yield_t &yield) { 
+    [&tio, &HeapArray, leftchildindex, update_leftindex_by](yield_t &yield) { 
             auto Acoro = HeapArray.context(yield); 
-            Acoro[leftchildindex] += zz; //inserted_val;
+            Acoro[leftchildindex] += update_leftindex_by; //inserted_val;
      }); 
 
     coroutines.emplace_back( 
-    [&tio, &HeapArray, rightchildindex, zzz](yield_t &yield) { 
+    [&tio, &HeapArray, rightchildindex, update_index_by, update_leftindex_by](yield_t &yield) { 
             auto Acoro = HeapArray.context(yield); 
-            Acoro[rightchildindex] += zzz;
+            Acoro[rightchildindex] += -(update_index_by + update_leftindex_by);
      }); 
 
     run_coroutines(tio, coroutines);
 
     
-    //verify_parent_children_heaps(tio, yield, HeapArray[index], HeapArray[leftchildindex] , HeapArray[rightchildindex]);
+   // verify_parent_children_heaps(tio, yield, HeapArray[index], HeapArray[leftchildindex] , HeapArray[rightchildindex]);
 
     return smallerindex;
 }
@@ -378,9 +382,37 @@ auto MinHeap::restore_heap_property_optimized(MPCTIO tio, yield_t & yield, RegXS
     typename Duoram < RegAS > ::Stride L(C, tio, yield, 0, 2);
     typename Duoram < RegAS > ::Stride R(C, tio, yield, 1, 2);
 
-    RegAS parent_tmp = P[oidx];
-    RegAS leftchild_tmp = L[oidx];
-    RegAS rightchild_tmp = R[oidx];
+    // RegAS parent_tmp = P[oidx];
+    // RegAS leftchild_tmp = L[oidx];
+    // RegAS rightchild_tmp = R[oidx];
+
+    RegAS parent_tmp;     //   = HeapArray[index];
+    RegAS leftchild_tmp;  //  = HeapArray[leftchildindex];
+    RegAS rightchild_tmp; // = HeapArray[rightchildindex];
+ 
+    std::vector<coro_t> coroutines_read;
+    coroutines_read.emplace_back( 
+    [&tio, &parent_tmp, &P, &oidx](yield_t &yield) { 
+           // auto Acoro = P.context(yield); 
+            parent_tmp = P[oidx]; //inserted_val;
+     });             
+   
+    coroutines_read.emplace_back( 
+    [&tio, &L, &leftchild_tmp, &oidx](yield_t &yield) { 
+            //auto Acoro = L.context(yield); 
+            leftchild_tmp  = L[oidx]; //inserted_val;
+     }); 
+
+    coroutines_read.emplace_back( 
+    [&tio, &R, &rightchild_tmp, &oidx](yield_t &yield) { 
+         //  auto Acoro = R.context(yield); 
+           rightchild_tmp = R[oidx];
+     }); 
+
+    run_coroutines(tio, coroutines_read);
+
+
+
 
     //RegAS sum = parent_tmp + leftchild_tmp + rightchild_tmp;
 
@@ -420,23 +452,45 @@ auto MinHeap::restore_heap_property_optimized(MPCTIO tio, yield_t & yield, RegXS
     //         [&tio, &zz, ltlt1, parent_tmp, leftchild_tmp](yield_t &yield)
     //         { mpc_flagmult(tio, yield, zz, ltlt1, (parent_tmp - leftchild_tmp), 64);});
 
-    RegAS z, zz, zzz;
+    RegAS update_index_by, update_leftindex_by;
 
-    run_coroutines(tio, [&tio, &zz, ltlt1, parent_tmp, leftchild_tmp](yield_t &yield)
-            { mpc_flagmult(tio, yield, zz, ltlt1, (parent_tmp - leftchild_tmp), 64);},
-            [&tio, &z, lt1eq1, parent_tmp, smallerchild](yield_t &yield)
-            {mpc_flagmult(tio, yield, z, lt1eq1, smallerchild - parent_tmp, 64);},
-            [&tio, &zzz, ltlt1, parent_tmp,  rightchild_tmp](yield_t &yield)
-            {mpc_flagmult(tio, yield, zzz, ltlt1, (parent_tmp - rightchild_tmp), 64);}
+    run_coroutines(tio, [&tio, &update_leftindex_by, ltlt1, parent_tmp, leftchild_tmp](yield_t &yield)
+            { mpc_flagmult(tio, yield, update_leftindex_by, ltlt1, (parent_tmp - leftchild_tmp), 64);},
+            [&tio, &update_index_by, lt1eq1, parent_tmp, smallerchild](yield_t &yield)
+            {mpc_flagmult(tio, yield, update_index_by, lt1eq1, smallerchild - parent_tmp, 64);} 
             );
     
 
     // RegAS leftchildplusparent = RegAS(HeapArray[index]) + RegAS(HeapArray[leftchildindex]);
     // RegAS tmp = (sum - leftchildplusparent);
 
-    P[oidx] += z;
-    L[oidx] += zz;// - leftchild_tmp;
-    R[oidx] += zzz;//tmp - rightchild_tmp;
+    std::vector<coro_t> coroutines;
+
+    coroutines.emplace_back( 
+    [&tio, &P, &oidx, update_index_by](yield_t &yield) { 
+            auto Acoro = P.context(yield); 
+            Acoro[oidx] += update_index_by; //inserted_val;
+     });             
+   
+    coroutines.emplace_back( 
+    [&tio, &L,  &oidx, update_leftindex_by](yield_t &yield) { 
+            auto Acoro = L.context(yield); 
+            Acoro[oidx] += update_leftindex_by; //inserted_val;
+     }); 
+
+    coroutines.emplace_back( 
+    [&tio, &R,  &oidx, update_leftindex_by, update_index_by](yield_t &yield) { 
+            auto Acoro = R.context(yield); 
+            Acoro[oidx] += -(update_leftindex_by + update_index_by);
+     }); 
+
+    run_coroutines(tio, coroutines);
+
+
+
+    // P[oidx] += z;
+    // L[oidx] += zz; 
+    // R[oidx] += zzz; 
 
     return std::make_pair(smallerindex, gt);
 }
@@ -708,8 +762,8 @@ void Heap(MPCIO & mpcio,
         tio.reset_lamport();
 
 
-   //      tree.heapify(tio, yield);
-     //    tree.print_heap(tio, yield);
+        // tree.heapify(tio, yield);
+        // tree.print_heap(tio, yield);
 
         for (size_t j = 0; j < n_inserts; ++j) {
             RegAS inserted_val;
@@ -731,8 +785,8 @@ void Heap(MPCIO & mpcio,
         tio.sync_lamport();
         mpcio.dump_stats(std::cout);
 
-        mpcio.reset_stats();
-        tio.reset_lamport();
+        // mpcio.reset_stats();
+        // tio.reset_lamport();
 
 
 
@@ -753,8 +807,8 @@ void Heap(MPCIO & mpcio,
         tio.sync_lamport();
         mpcio.dump_stats(std::cout);
 
-        // tree.print_heap(tio, yield);
-        // tree.verify_heap_property(tio, yield);
+        //tree.print_heap(tio, yield);
+        //tree.verify_heap_property(tio, yield);
         
       
     });