Browse Source

removing the pointer for ORAM, take in params for number of extractmins and inserts

avadapal 1 year ago
parent
commit
10679b916a
2 changed files with 33 additions and 134 deletions
  1. 30 110
      heap.cpp
  2. 3 24
      heap.hpp

+ 30 - 110
heap.cpp

@@ -1,20 +1,11 @@
 #include <functional>
 
 #include "types.hpp"
-
 #include "duoram.hpp"
-
 #include "cell.hpp"
-
 #include  "heap.hpp"
 
 
-void MinHeap::initialize(int num_players, size_t size) {
-    this -> MAX_SIZE = size;
-    this -> num_items = 0;
-    oram = new Duoram < RegAS > (num_players, size);
-}
-
 RegAS reconstruct_AS(MPCTIO & tio, yield_t & yield, RegAS AS) {
     RegAS peer_AS;
     RegAS reconstructed_AS = AS;
@@ -86,8 +77,12 @@ bool reconstruct_flag(MPCTIO & tio, yield_t & yield, RegBS flag) {
     return reconstructed_flag.bshare;
 }
 
+// The insert protocol works as follows:
+// It adds a new element in the last entry of the array
+// From the leaf (the element added), compare with its parent (1 oblivious compare)
+// If the child is larger, then we do an OSWAP.
 int MinHeap::insert(MPCTIO tio, yield_t & yield, RegAS val) {
-    auto HeapArray = oram -> flat(tio, yield);
+    auto HeapArray = oram.flat(tio, yield);
     num_items++;
     std::cout << "num_items = " << num_items << std::endl;
 
@@ -96,104 +91,27 @@ int MinHeap::insert(MPCTIO tio, yield_t & yield, RegAS val) {
     val.dump();
     yield();
     size_t childindex = num_items;
-    size_t parent = childindex / 2;
+    size_t parentindex = childindex / 2;
     std::cout << "childindex = " << childindex << std::endl;
-    std::cout << "parent = " << parent << std::endl;
+    std::cout << "parentindex = " << parentindex << std::endl;
     HeapArray[num_items] = val;
-    RegAS tmp = HeapArray[num_items];
-    reconstruct_AS(tio, yield, tmp);
-    yield();
-    while (parent > 0) {
-        //std::cout << "while loop\n";
+   
+    while (parentindex > 0) {
         RegAS sharechild = HeapArray[childindex];
-        RegAS shareparent = HeapArray[parent];
-        // RegAS sharechildrec = reconstruct_AS(tio, yield, sharechild);
-        // yield();
-        // RegAS shareparentrec = reconstruct_AS(tio, yield, shareparent);
-        // yield();
-        // std::cout << "\nchild reconstruct_AS = \n";
-        // sharechildrec.dump();
-        // std::cout << "\nparent reconstruct_AS = \n";
-        // shareparentrec.dump();
-        // std::cout << "\n----\n";
+        RegAS shareparent = HeapArray[parentindex];
+
         CDPF cdpf = tio.cdpf(yield);
         RegAS diff = sharechild - shareparent;
-        // std::cout << "diff = " << std::endl;
-        // RegAS diff_rec = reconstruct_AS(tio, yield, diff);
-        // diff_rec.dump();
-        // std::cout << std::endl << std::endl;
+
         auto[lt, eq, gt] = cdpf.compare(tio, yield, diff, tio.aes_ops());
-        //auto lteq = lt ^ eq;
-        // bool lteq_rec = reconstruct_flag(tio, yield, lteq);
-        // yield();   
-        // bool lt_rec = reconstruct_flag(tio, yield, lt);
-        // yield();
-        // std::cout <<"lt_rec = " << (int) lt_rec << std::endl;
-        // std::cout << std::endl;
-        // bool eq_rec = reconstruct_flag(tio, yield, eq);
-        // yield();
-        // std::cout <<"eq_rec = " << (int) eq_rec << std::endl;
-        // std::cout << std::endl;
-        // yield();   
-        // bool gt_rec = reconstruct_flag(tio, yield, gt);
-        // yield();
-        // std::cout <<"gt_rec = " << (int) gt_rec << std::endl;
-        // std::cout << std::endl;
-        // if(lteq_rec) {
-        //   if(sharechildrec.ashare > shareparentrec.ashare) {
-        //     std::cout << "\nchild reconstruct_AS = \n";
-        //     sharechildrec.dump();
-        //     std::cout << "\nchild reconstruct_AS = \n";
-        //     shareparentrec.dump();
-        //     std::cout << "\n----\n";
-        //   }
-        //   assert(sharechildrec.ashare <= shareparentrec.ashare);
-        // }
-        // if(gt_rec) {
-        //   if(sharechildrec.ashare < shareparentrec.ashare) {
-        //     std::cout << "\nchild reconstruct_AS = \n";
-        //     sharechildrec.dump();
-        //     std::cout << "\nchild reconstruct_AS = \n";
-        //     shareparentrec.dump();
-        //     std::cout << "\n----\n";
-        //   }
-        //     assert(sharechildrec.ashare > shareparentrec.ashare);
-        //   }
-
-        //   std::cout << "child = " << child << std::endl;
-        //   sharechildrec  =  reconstruct_AS(tio, yield, sharechild);
-        //   sharechildrec.dump();
-        //   yield();
-        //   std::cout << "parent = " << parent << std::endl;
-        //   shareparentrec =  reconstruct_AS(tio, yield, shareparent);
-        //   shareparentrec.dump();
-        //   yield();
-
-        // std::cout << "\n^^^ before mpc_oswap\n";
-        mpc_oswap(tio, yield, sharechild, shareparent, lt, 64);
-
-        HeapArray[childindex] = sharechild;
-        HeapArray[parent] = shareparent;
-
-        // std::cout << "child = " << child << std::endl;
-        // sharechildrec  =  reconstruct_AS(tio, yield, sharechild);
-        // sharechildrec.dump();
-        // yield();
-        // std::cout << "parent = " << parent << std::endl;
-        // shareparentrec =  reconstruct_AS(tio, yield, shareparent);
-        // shareparentrec.dump();
-        // yield();
-        // std::cout << "\n^^^after mpc_oswap\n";
-        // assert(sharechildrec.ashare >= shareparentrec.ashare);
-
-        // std::cout << "we asserted that: \n";
-        // sharechildrec.dump();
-        // std::cout << std::endl << " < " << std::endl;
-        // shareparentrec.dump();
-        // std::cout << "\n ----- \n";
-
-        childindex = parent;
-        parent = childindex / 2;
+        auto lteq = lt ^ eq;
+        mpc_oswap(tio, yield, sharechild, shareparent, lteq, 64);
+
+        HeapArray[childindex]  = sharechild;
+        HeapArray[parentindex] = shareparent;
+
+        childindex = parentindex;
+        parentindex = parentindex / 2;
     }
 
     return 1;
@@ -201,7 +119,7 @@ int MinHeap::insert(MPCTIO tio, yield_t & yield, RegAS val) {
 
 int MinHeap::verify_heap_property(MPCTIO tio, yield_t & yield) {
     std::cout << std::endl << std::endl << "verify_heap_property is being called " << std::endl;
-    auto HeapArray = oram -> flat(tio, yield);
+    auto HeapArray = oram.flat(tio, yield);
 
     RegAS heapreconstruction[num_items];
     for (size_t j = 0; j <= num_items; ++j) {
@@ -258,7 +176,7 @@ void verify_parent_children_heaps(MPCTIO tio, yield_t & yield, RegAS parent, Reg
 //  Overall restore_heap_property takes 2 MPC Comparisons, 2 MPC Selects, and 2 Duoram Writes
 RegXS MinHeap::restore_heap_property(MPCTIO tio, yield_t & yield, RegXS index) {
     RegAS smallest;
-    auto HeapArray = oram -> flat(tio, yield);
+    auto HeapArray = oram.flat(tio, yield);
     RegAS parent = HeapArray[index];
     RegXS leftchildindex = index;
     leftchildindex = index << 1;
@@ -307,7 +225,7 @@ RegXS MinHeap::restore_heap_property(MPCTIO tio, yield_t & yield, RegXS index) {
 */
 RegXS MinHeap::restore_heap_property_at_root(MPCTIO tio, yield_t & yield) {
     size_t index = 1;
-    auto HeapArray = oram -> flat(tio, yield);
+    auto HeapArray = oram.flat(tio, yield);
     RegAS parent = HeapArray[index];
     RegAS leftchild = HeapArray[2 * index];
     RegAS rightchild = HeapArray[2 * index + 1];
@@ -344,7 +262,7 @@ RegXS MinHeap::restore_heap_property_at_root(MPCTIO tio, yield_t & yield) {
 RegAS MinHeap::extract_min(MPCTIO tio, yield_t & yield) {
 
     RegAS minval;
-    auto HeapArray = oram -> flat(tio, yield);
+    auto HeapArray = oram.flat(tio, yield);
     minval = HeapArray[1];
     HeapArray[1] = RegAS(HeapArray[num_items]);
 
@@ -361,7 +279,9 @@ RegAS MinHeap::extract_min(MPCTIO tio, yield_t & yield) {
 
 void Heap(MPCIO & mpcio,
     const PRACOptions & opts, char ** args) {
-    nbits_t depth = atoi(args[0]);
+    nbits_t depth     = atoi(args[0]);
+    size_t n_inserts  = atoi(args[1]);
+    size_t n_extracts = atoi(args[2]);
     std::cout << "print arguements " << std::endl;
     std::cout << args[0] << std::endl;
 
@@ -381,13 +301,13 @@ void Heap(MPCIO & mpcio,
 
     MPCTIO tio(mpcio, 0, opts.num_threads);
 
-    run_coroutines(tio, [ & tio, depth, items](yield_t & yield) {
+    run_coroutines(tio, [ & tio, depth, items, n_inserts, n_extracts](yield_t & yield) {
         size_t size = size_t(1) << depth;
         std::cout << "size = " << size << std::endl;
 
         MinHeap tree(tio.player(), size);
 
-        for (size_t j = 0; j < 60; ++j) {
+        for (size_t j = 0; j < n_inserts; ++j) {
             RegAS inserted_val;
             inserted_val.randomize(62);
             inserted_val.ashare = inserted_val.ashare;
@@ -397,7 +317,7 @@ void Heap(MPCIO & mpcio,
         std::cout << std::endl << "=============[Insert Done]================" << std::endl << std::endl;
         tree.verify_heap_property(tio, yield);
 
-        for (size_t j = 0; j < 10; ++j) {
+        for (size_t j = 0; j < n_extracts; ++j) {
             RegAS minval = tree.extract_min(tio, yield);
             tree.verify_heap_property(tio, yield);
             RegAS minval_reconstruction = reconstruct_AS(tio, yield, minval);

+ 3 - 24
heap.hpp

@@ -2,40 +2,21 @@
 #define __HEAP_HPP__
 
 #include "types.hpp"
-
 #include "mpcio.hpp"
-
 #include "coroutine.hpp"
-
 #include "options.hpp"
-
 #include "mpcops.hpp"
 
 class MinHeap {
-    private: Duoram < RegAS > * oram;
-    // RegXS root;
+    private: Duoram < RegAS > oram;
 
     size_t MAX_SIZE;
 
-    // std::tuple<RegXS, RegBS> insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
-    //     const Node &new_node, Duoram<Node>::Flat &A, int TTL, RegBS isDummy);
-    // void insert(MPCTIO &tio, yield_t &yield, const Node &node, Duoram<Node>::Flat &A);
-
-    // int del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
-    //     Duoram<Node>::Flat &A, RegBS F_af, RegBS F_fs, int TTL, 
-    //     del_return &ret_struct);
-
     public: size_t num_items = 0;
-    MinHeap(int num_players, size_t size) {
-        this -> initialize(num_players, size);
-    };
-
-    ~MinHeap() {
-        if (oram)
-            delete oram;
+    MinHeap(int num_players, size_t size) : oram(num_players, size), MAX_SIZE(size){
     };
 
-    void initialize(int num_players, size_t size);
+    
     RegAS extract_min(MPCTIO tio, yield_t & yield);
     int insert(MPCTIO tio, yield_t & yield, RegAS val);
     int verify_heap_property(MPCTIO tio, yield_t & yield);
@@ -45,6 +26,4 @@ class MinHeap {
 
 // void MinHeap(MPCIO &mpcio,
 //    const PRACOptions &opts, char **args);
-
-//#include "heap.tcc"
 #endif