Browse Source

Shape and the basic Flat shape

Ian Goldberg 2 years ago
parent
commit
79b6d52cec
3 changed files with 152 additions and 36 deletions
  1. 81 0
      duoram.hpp
  2. 18 0
      duoram.tcc
  3. 53 36
      online.cpp

+ 81 - 0
duoram.hpp

@@ -50,11 +50,92 @@ public:
     // The type of this Duoram
     using type = T;
 
+    // The different Shapes are subclasses of this inner class
+    class Shape;
+    // These are the different Shapes that exist
+    class Flat;
+
     // Pass the player number and desired size
     Duoram(int player, size_t size);
 
     // Get the size
     inline size_t size() { return oram_size; }
+
+    // Get the basic Flat shape for this Duoram
+    Flat flat(MPCTIO &tio, yield_t &yield, size_t start = 0,
+            size_t len = 0) {
+        return Flat(*this, tio, yield, start, len);
+    }
+};
+
+// The parent class of all Shapes.  This is an abstract class that
+// cannot itself be instantiated.
+
+template <typename T>
+class Duoram<T>::Shape {
+protected:
+    // A reference to the parent shape.  As with ".." in the root
+    // directory of a filesystem, the topmost shape is indicated by
+    // having parent = *this.
+    const Shape &parent;
+
+    // A reference to the backing physical storage
+    Duoram &duoram;
+
+    // The size of this shape
+    size_t shape_size;
+
+    // The Shape's context (MPCTIO and yield_t)
+    MPCTIO &tio;
+    yield_t &yield;
+
+    // We need a constructor because we hold non-static references; this
+    // constructor is called by the subclass constructors
+    Shape(const Shape &parent, Duoram &duoram, MPCTIO &tio,
+        yield_t &yield) : parent(parent), duoram(duoram), shape_size(0),
+        tio(tio), yield(yield) {}
+
+public:
+    // The index-mapping function. Input the index relative to this
+    // shape, and output the corresponding physical address.  The
+    // strategy is to map the index relative to this shape to the index
+    // relative to the parent shape, call the parent's indexmap function
+    // on that (unless this is the topmost shape), and return what it
+    // returns.  If this is the topmost shape, just return what you
+    // would have passed to the parent's indexmap.
+    //
+    // This is a pure virtual function; all subclasses of Shape must
+    // implement it, and of course Shape itself therefore cannot be
+    // instantiated.
+    virtual size_t indexmap(size_t idx) const = 0;
+
+    // Get the size
+    inline size_t size() { return shape_size; }
+};
+
+// The most basic shape is Flat.  It is almost always the topmost shape,
+// and serves to provide MPCTIO and yield_t context to a Duoram without
+// changing the indices or size (but can specify a subrange if desired).
+
+template <typename T>
+class Duoram<T>::Flat : public Duoram<T>::Shape {
+    // If this is a subrange, start may be non-0, but it's usually 0
+    size_t start;
+
+    inline size_t indexmap(size_t idx) const {
+        size_t paridx = idx + start;
+        if (&(this->parent) == this) {
+            return paridx;
+        } else {
+            return this->parent.indexmap(paridx);
+        }
+    }
+
+public:
+    // Constructor.  len=0 means the maximum size (the parent's size
+    // minus start).
+    Flat(Duoram &duoram, MPCTIO &tio, yield_t &yield, size_t start = 0,
+        size_t len = 0);
 };
 
 #include "duoram.tcc"

+ 18 - 0
duoram.tcc

@@ -13,3 +13,21 @@ Duoram<T>::Duoram(int player, size_t size) : player(player),
         p1_blind.resize(size);
     }
 }
+
+// Constructor.  len=0 means the maximum size (the parent's size
+// minus start).
+template <typename T>
+Duoram<T>::Flat::Flat(Duoram &duoram, MPCTIO &tio, yield_t &yield,
+    size_t start, size_t len) : Shape(*this, duoram, tio, yield)
+{
+    size_t parentsize = duoram.size();
+    if (start > parentsize) {
+        start = parentsize;
+    }
+    this->start = start;
+    size_t maxshapesize = parentsize - start;
+    if (len > maxshapesize || len == 0) {
+        len = maxshapesize;
+    }
+    this->shape_size = len;
+}

+ 53 - 36
online.cpp

@@ -6,7 +6,8 @@
 #include "duoram.hpp"
 
 
-static void online_test(MPCIO &mpcio, const PRACOptions &opts, char **args)
+static void online_test(MPCIO &mpcio, yield_t &yield,
+    const PRACOptions &opts, char **args)
 {
     nbits_t nbits = VALUE_BITS;
 
@@ -61,7 +62,7 @@ static void online_test(MPCIO &mpcio, const PRACOptions &opts, char **args)
         [&](yield_t &yield) {
             mpc_xs_to_as(tio, yield, A[8], X, nbits);
         });
-    run_coroutines(tio, coroutines);
+    run_coroutines(yield, coroutines);
     if (!is_server) {
         printf("\n");
         printf("A:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, A[i].ashare);
@@ -107,7 +108,8 @@ static void online_test(MPCIO &mpcio, const PRACOptions &opts, char **args)
     delete[] A;
 }
 
-static void lamport_test(MPCIO &mpcio, const PRACOptions &opts, char **args)
+static void lamport_test(MPCIO &mpcio, yield_t &yield,
+    const PRACOptions &opts, char **args)
 {
     // Create a bunch of threads and send a bunch of data to the other
     // peer, and receive their data.  If an arg is specified, repeat
@@ -152,7 +154,8 @@ static void lamport_test(MPCIO &mpcio, const PRACOptions &opts, char **args)
     pool.join();
 }
 
-static void rdpf_test(MPCIO &mpcio, const PRACOptions &opts, char **args)
+static void rdpf_test(MPCIO &mpcio, yield_t &yield,
+    const PRACOptions &opts, char **args)
 {
     nbits_t depth=6;
 
@@ -233,7 +236,8 @@ static void rdpf_test(MPCIO &mpcio, const PRACOptions &opts, char **args)
     pool.join();
 }
 
-static void rdpf_timing(MPCIO &mpcio, const PRACOptions &opts, char **args)
+static void rdpf_timing(MPCIO &mpcio, yield_t &yield,
+    const PRACOptions &opts, char **args)
 {
     nbits_t depth=6;
 
@@ -285,7 +289,8 @@ static void rdpf_timing(MPCIO &mpcio, const PRACOptions &opts, char **args)
     pool.join();
 }
 
-static void rdpfeval_timing(MPCIO &mpcio, const PRACOptions &opts, char **args)
+static void rdpfeval_timing(MPCIO &mpcio, yield_t &yield,
+    const PRACOptions &opts, char **args)
 {
     nbits_t depth=6;
     address_t start=0;
@@ -342,7 +347,8 @@ static void rdpfeval_timing(MPCIO &mpcio, const PRACOptions &opts, char **args)
     pool.join();
 }
 
-static void tupleeval_timing(MPCIO &mpcio, const PRACOptions &opts, char **args)
+static void tupleeval_timing(MPCIO &mpcio, yield_t &yield,
+    const PRACOptions &opts, char **args)
 {
     nbits_t depth=6;
     address_t start=0;
@@ -408,7 +414,8 @@ static void tupleeval_timing(MPCIO &mpcio, const PRACOptions &opts, char **args)
     pool.join();
 }
 
-static void duoram_test(MPCIO &mpcio, const PRACOptions &opts, char **args)
+static void duoram_test(MPCIO &mpcio, yield_t &yield,
+    const PRACOptions &opts, char **args)
 {
     nbits_t depth=6;
 
@@ -420,11 +427,13 @@ static void duoram_test(MPCIO &mpcio, const PRACOptions &opts, char **args)
     int num_threads = opts.num_threads;
     boost::asio::thread_pool pool(num_threads);
     for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
-        boost::asio::post(pool, [&mpcio, thread_num, depth] {
+        boost::asio::post(pool, [&mpcio, &yield, thread_num, depth] {
             MPCTIO tio(mpcio, thread_num);
             // size_t &op_counter = tio.aes_ops();
             Duoram<RegAS> oram(mpcio.player, size_t(1)<<depth);
             printf("%ld\n", oram.size());
+            auto A = oram.flat(tio, yield);
+            printf("%ld\n", A.size());
             tio.send();
         });
     }
@@ -433,31 +442,39 @@ static void duoram_test(MPCIO &mpcio, const PRACOptions &opts, char **args)
 
 void online_main(MPCIO &mpcio, const PRACOptions &opts, char **args)
 {
-    if (!*args) {
-        std::cerr << "Mode is required as the first argument when not preprocessing.\n";
-        return;
-    } else if (!strcmp(*args, "test")) {
-        ++args;
-        online_test(mpcio, opts, args);
-    } else if (!strcmp(*args, "lamporttest")) {
-        ++args;
-        lamport_test(mpcio, opts, args);
-    } else if (!strcmp(*args, "rdpftest")) {
-        ++args;
-        rdpf_test(mpcio, opts, args);
-    } else if (!strcmp(*args, "rdpftime")) {
-        ++args;
-        rdpf_timing(mpcio, opts, args);
-    } else if (!strcmp(*args, "evaltime")) {
-        ++args;
-        rdpfeval_timing(mpcio, opts, args);
-    } else if (!strcmp(*args, "tupletime")) {
-        ++args;
-        tupleeval_timing(mpcio, opts, args);
-    } else if (!strcmp(*args, "duotest")) {
-        ++args;
-        duoram_test(mpcio, opts, args);
-    } else {
-        std::cerr << "Unknown mode " << *args << "\n";
-    }
+    // Run everything inside a coroutine so that simple tests don't have
+    // to start one themselves
+    MPCTIO tio(mpcio, 0);
+    std::vector<coro_t> coroutines;
+    coroutines.emplace_back(
+        [&](yield_t &yield) {
+            if (!*args) {
+                std::cerr << "Mode is required as the first argument when not preprocessing.\n";
+                return;
+            } else if (!strcmp(*args, "test")) {
+                ++args;
+                online_test(mpcio, yield, opts, args);
+            } else if (!strcmp(*args, "lamporttest")) {
+                ++args;
+                lamport_test(mpcio, yield, opts, args);
+            } else if (!strcmp(*args, "rdpftest")) {
+                ++args;
+                rdpf_test(mpcio, yield, opts, args);
+            } else if (!strcmp(*args, "rdpftime")) {
+                ++args;
+                rdpf_timing(mpcio, yield, opts, args);
+            } else if (!strcmp(*args, "evaltime")) {
+                ++args;
+                rdpfeval_timing(mpcio, yield, opts, args);
+            } else if (!strcmp(*args, "tupletime")) {
+                ++args;
+                tupleeval_timing(mpcio, yield, opts, args);
+            } else if (!strcmp(*args, "duotest")) {
+                ++args;
+                duoram_test(mpcio, yield, opts, args);
+            } else {
+                std::cerr << "Unknown mode " << *args << "\n";
+            }
+        });
+    run_coroutines(tio, coroutines);
 }