Browse Source

Online-only mode

In this mode, there is no preprocessing, and any required values that
would normally be preprocessed are computed online.  Counts are kept and
displayed of the number of such values computed, so that that
information can be given to a subsequent preprocessing mode for future
runs.
Ian Goldberg 1 year ago
parent
commit
62855f7b92
13 changed files with 163 additions and 113 deletions
  1. 15 13
      Makefile
  2. 2 0
      cdpf.cpp
  3. 11 0
      corotypes.hpp
  4. 1 5
      coroutine.hpp
  5. 8 8
      duoram.tcc
  6. 48 28
      mpcio.cpp
  7. 13 10
      mpcio.hpp
  8. 2 2
      mpcops.cpp
  9. 40 32
      online.cpp
  10. 4 2
      options.hpp
  11. 9 5
      prac.cpp
  12. 4 8
      preproc.cpp
  13. 6 0
      types.hpp

+ 15 - 13
Makefile

@@ -28,16 +28,18 @@ depend:
 
 # DO NOT DELETE THIS LINE -- make depend depends on it.
 
-prac.o: mpcio.hpp types.hpp preproc.hpp options.hpp online.hpp
-mpcio.o: mpcio.hpp types.hpp rdpf.hpp coroutine.hpp bitutils.hpp dpf.hpp
-mpcio.o: prg.hpp aes.hpp rdpf.tcc cdpf.hpp cdpf.tcc
-preproc.o: types.hpp coroutine.hpp mpcio.hpp preproc.hpp options.hpp rdpf.hpp
-preproc.o: bitutils.hpp dpf.hpp prg.hpp aes.hpp rdpf.tcc cdpf.hpp cdpf.tcc
-online.o: online.hpp mpcio.hpp types.hpp options.hpp mpcops.hpp coroutine.hpp
-online.o: rdpf.hpp bitutils.hpp dpf.hpp prg.hpp aes.hpp rdpf.tcc duoram.hpp
-online.o: duoram.tcc cdpf.hpp cdpf.tcc
-mpcops.o: mpcops.hpp types.hpp mpcio.hpp coroutine.hpp bitutils.hpp
-rdpf.o: rdpf.hpp mpcio.hpp types.hpp coroutine.hpp bitutils.hpp dpf.hpp
-rdpf.o: prg.hpp aes.hpp rdpf.tcc mpcops.hpp
-cdpf.o: bitutils.hpp cdpf.hpp mpcio.hpp types.hpp coroutine.hpp dpf.hpp
-cdpf.o: prg.hpp aes.hpp cdpf.tcc
+prac.o: mpcio.hpp types.hpp corotypes.hpp preproc.hpp options.hpp online.hpp
+mpcio.o: mpcio.hpp types.hpp corotypes.hpp rdpf.hpp coroutine.hpp
+mpcio.o: bitutils.hpp dpf.hpp prg.hpp aes.hpp rdpf.tcc cdpf.hpp cdpf.tcc
+preproc.o: types.hpp coroutine.hpp corotypes.hpp mpcio.hpp preproc.hpp
+preproc.o: options.hpp rdpf.hpp bitutils.hpp dpf.hpp prg.hpp aes.hpp rdpf.tcc
+preproc.o: cdpf.hpp cdpf.tcc
+online.o: online.hpp mpcio.hpp types.hpp corotypes.hpp options.hpp mpcops.hpp
+online.o: coroutine.hpp rdpf.hpp bitutils.hpp dpf.hpp prg.hpp aes.hpp
+online.o: rdpf.tcc duoram.hpp duoram.tcc cdpf.hpp cdpf.tcc
+mpcops.o: mpcops.hpp types.hpp mpcio.hpp corotypes.hpp coroutine.hpp
+mpcops.o: bitutils.hpp
+rdpf.o: rdpf.hpp mpcio.hpp types.hpp corotypes.hpp coroutine.hpp bitutils.hpp
+rdpf.o: dpf.hpp prg.hpp aes.hpp rdpf.tcc mpcops.hpp
+cdpf.o: bitutils.hpp cdpf.hpp mpcio.hpp types.hpp corotypes.hpp coroutine.hpp
+cdpf.o: dpf.hpp prg.hpp aes.hpp cdpf.tcc

+ 2 - 0
cdpf.cpp

@@ -188,6 +188,8 @@ std::tuple<RegBS,RegBS,RegBS> CDPF::compare(MPCTIO &tio, yield_t &yield,
         // After that one single-word exchange, the rest of this
         // algorithm is entirely a local computation.
         return compare(S, aes_ops);
+    } else {
+        yield();
     }
     // The server gets three shares of 0 (which is not a valid output
     // for the computational players)

+ 11 - 0
corotypes.hpp

@@ -0,0 +1,11 @@
+#ifndef __COROTYPES_HPP__
+#define __COROTYPES_HPP__
+
+#include <functional>
+#include <boost/coroutine2/coroutine.hpp>
+
+using coro_t = boost::coroutines2::coroutine<void>::pull_type;
+using yield_t = boost::coroutines2::coroutine<void>::push_type;
+using coro_lambda_t = std::function<void(yield_t&)>;
+
+#endif

+ 1 - 5
coroutine.hpp

@@ -2,14 +2,10 @@
 #define __COROUTINE_HPP__
 
 #include <vector>
-#include <functional>
-#include <boost/coroutine2/coroutine.hpp>
 
+#include "corotypes.hpp"
 #include "mpcio.hpp"
 
-using coro_t = boost::coroutines2::coroutine<void>::pull_type;
-using yield_t = boost::coroutines2::coroutine<void>::push_type;
-using coro_lambda_t = std::function<void(yield_t&)>;
 
 // The top-level coroutine runner will call run_coroutines with
 // a MPCTIO, and we should call its send() method.  Subcoroutines that

+ 8 - 8
duoram.tcc

@@ -281,7 +281,7 @@ Duoram<T>::Shape::MemRefAS::operator T()
     if (player < 2) {
         // Computational players do this
 
-        RDPFTriple dt = shape.tio.rdpftriple(shape.addr_size);
+        RDPFTriple dt = shape.tio.rdpftriple(shape.yield, shape.addr_size);
 
         // Compute the index offset
         RegAS indoffset = dt.as_target;
@@ -324,7 +324,7 @@ Duoram<T>::Shape::MemRefAS::operator T()
     } else {
         // The server does this
 
-        RDPFPair dp = shape.tio.rdpfpair(shape.addr_size);
+        RDPFPair dp = shape.tio.rdpfpair(shape.yield, shape.addr_size);
         RegAS p0indoffset, p1indoffset;
 
         // Receive the index offset from the computational players and
@@ -376,7 +376,7 @@ typename Duoram<T>::Shape::MemRefAS
     if (player < 2) {
         // Computational players do this
 
-        RDPFTriple dt = shape.tio.rdpftriple(shape.addr_size);
+        RDPFTriple dt = shape.tio.rdpftriple(shape.yield, shape.addr_size);
 
         // Compute the index and message offsets
         RegAS indoffset = dt.as_target;
@@ -425,7 +425,7 @@ typename Duoram<T>::Shape::MemRefAS
     } else {
         // The server does this
 
-        RDPFPair dp = shape.tio.rdpfpair(shape.addr_size);
+        RDPFPair dp = shape.tio.rdpfpair(shape.yield, shape.addr_size);
         RegAS p0indoffset, p1indoffset;
         std::tuple<T,T> p0Moffset, p1Moffset;
 
@@ -509,7 +509,7 @@ Duoram<T>::Shape::MemRefXS::operator T()
     if (player < 2) {
         // Computational players do this
 
-        RDPFTriple dt = shape.tio.rdpftriple(shape.addr_size);
+        RDPFTriple dt = shape.tio.rdpftriple(shape.yield, shape.addr_size);
 
         // Compute the index offset
         RegXS indoffset = dt.xs_target;
@@ -550,7 +550,7 @@ Duoram<T>::Shape::MemRefXS::operator T()
     } else {
         // The server does this
 
-        RDPFPair dp = shape.tio.rdpfpair(shape.addr_size);
+        RDPFPair dp = shape.tio.rdpfpair(shape.yield, shape.addr_size);
         RegXS p0indoffset, p1indoffset;
 
         // Receive the index offset from the computational players and
@@ -602,7 +602,7 @@ typename Duoram<T>::Shape::MemRefXS
     if (player < 2) {
         // Computational players do this
 
-        RDPFTriple dt = shape.tio.rdpftriple(shape.addr_size);
+        RDPFTriple dt = shape.tio.rdpftriple(shape.yield, shape.addr_size);
 
         // Compute the index and message offsets
         RegXS indoffset = dt.xs_target;
@@ -651,7 +651,7 @@ typename Duoram<T>::Shape::MemRefXS
     } else {
         // The server does this
 
-        RDPFPair dp = shape.tio.rdpfpair(shape.addr_size);
+        RDPFPair dp = shape.tio.rdpfpair(shape.yield, shape.addr_size);
         RegXS p0indoffset, p1indoffset;
         std::tuple<T,T> p0Moffset, p1Moffset;
 

+ 48 - 28
mpcio.cpp

@@ -4,24 +4,25 @@
 #include "rdpf.hpp"
 #include "cdpf.hpp"
 #include "bitutils.hpp"
+#include "coroutine.hpp"
 
 // T is the type being stored
 // N is a type whose "name" static member is a string naming the type
 //   so that we can report something useful to the user if they try
 //   to read a type that we don't have any more values for
 template<typename T, typename N>
-PreCompStorage<T,N>::PreCompStorage(unsigned player, bool preprocessing,
+PreCompStorage<T,N>::PreCompStorage(unsigned player, ProcessingMode mode,
         const char *filenameprefix, unsigned thread_num) :
         name(N::name), depth(0)
 {
-    init(player, preprocessing, filenameprefix, thread_num);
+    init(player, mode, filenameprefix, thread_num);
 }
 
 template<typename T, typename N>
-void PreCompStorage<T,N>::init(unsigned player, bool preprocessing,
+void PreCompStorage<T,N>::init(unsigned player, ProcessingMode mode,
         const char *filenameprefix, unsigned thread_num, nbits_t depth)
 {
-    if (preprocessing) return;
+    if (mode != MODE_ONLINE) return;
     std::string filename(filenameprefix);
     char suffix[20];
     if (depth) {
@@ -255,27 +256,27 @@ void MPCIO::dump_stats(std::ostream &os)
     dump_memusage(os);
 }
 
-MPCPeerIO::MPCPeerIO(unsigned player, bool preprocessing,
+MPCPeerIO::MPCPeerIO(unsigned player, ProcessingMode mode,
         std::deque<tcp::socket> &peersocks,
         std::deque<tcp::socket> &serversocks) :
-    MPCIO(player, preprocessing, peersocks.size())
+    MPCIO(player, mode, peersocks.size())
 {
     unsigned num_threads = unsigned(peersocks.size());
     for (unsigned i=0; i<num_threads; ++i) {
-        triples.emplace_back(player, preprocessing, "triples", i);
+        triples.emplace_back(player, mode, "triples", i);
     }
     for (unsigned i=0; i<num_threads; ++i) {
-        halftriples.emplace_back(player, preprocessing, "halves", i);
+        halftriples.emplace_back(player, mode, "halves", i);
     }
     rdpftriples.resize(num_threads);
     for (unsigned i=0; i<num_threads; ++i) {
         for (unsigned depth=1; depth<=ADDRESS_MAX_BITS; ++depth) {
-            rdpftriples[i][depth-1].init(player, preprocessing,
+            rdpftriples[i][depth-1].init(player, mode,
                 "rdpf", i, depth);
         }
     }
     for (unsigned i=0; i<num_threads; ++i) {
-        cdpfs.emplace_back(player, preprocessing, "cdpf", i);
+        cdpfs.emplace_back(player, mode, "cdpf", i);
     }
     for (auto &&sock : peersocks) {
         peerios.emplace_back(std::move(sock));
@@ -325,15 +326,15 @@ void MPCPeerIO::dump_stats(std::ostream &os)
     dump_precomp_stats(os);
 }
 
-MPCServerIO::MPCServerIO(bool preprocessing,
+MPCServerIO::MPCServerIO(ProcessingMode mode,
         std::deque<tcp::socket> &p0socks,
         std::deque<tcp::socket> &p1socks) :
-    MPCIO(2, preprocessing, p0socks.size())
+    MPCIO(2, mode, p0socks.size())
 {
     rdpfpairs.resize(num_threads);
     for (unsigned i=0; i<num_threads; ++i) {
         for (unsigned depth=1; depth<=ADDRESS_MAX_BITS; ++depth) {
-            rdpfpairs[i][depth-1].init(player, preprocessing,
+            rdpfpairs[i][depth-1].init(player, mode,
                 "rdpf", i, depth);
         }
     }
@@ -532,18 +533,19 @@ void MPCTIO::send()
 
 // Functions to get precomputed values.  If we're in the online
 // phase, get them from PreCompStorage.  If we're in the
-// preprocessing phase, read them from the server.
+// preprocessing or online-only phase, read them from the server.
 MultTriple MPCTIO::triple()
 {
     MultTriple val;
     if (mpcio.player < 2) {
         MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
-        if (mpcpio.preprocessing) {
+        if (mpcpio.mode != MODE_ONLINE) {
             recv_server(&val, sizeof(val));
+            mpcpio.triples[thread_num].inc();
         } else {
             mpcpio.triples[thread_num].get(val);
         }
-    } else if (mpcio.preprocessing) {
+    } else if (mpcio.mode != MODE_ONLINE) {
         // Create triples (X0,Y0,Z0),(X1,Y1,Z1) such that
         // (X0*Y1 + Y0*X1) = (Z0+Z1)
         value_t X0, Y0, Z0, X1, Y1, Z1;
@@ -567,12 +569,13 @@ HalfTriple MPCTIO::halftriple()
     HalfTriple val;
     if (mpcio.player < 2) {
         MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
-        if (mpcpio.preprocessing) {
+        if (mpcpio.mode != MODE_ONLINE) {
             recv_server(&val, sizeof(val));
+            mpcpio.halftriples[thread_num].inc();
         } else {
             mpcpio.halftriples[thread_num].get(val);
         }
-    } else if (mpcio.preprocessing) {
+    } else if (mpcio.mode != MODE_ONLINE) {
         // Create half-triples (X0,Z0),(Y1,Z1) such that
         // X0*Y1 = Z0 + Z1
         value_t X0, Z0, Y1, Z1;
@@ -594,7 +597,7 @@ SelectTriple MPCTIO::selecttriple()
     SelectTriple val;
     if (mpcio.player < 2) {
         MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
-        if (mpcpio.preprocessing) {
+        if (mpcpio.mode != MODE_ONLINE) {
             uint8_t Xbyte;
             recv_server(&Xbyte, sizeof(Xbyte));
             val.X = Xbyte & 1;
@@ -603,7 +606,7 @@ SelectTriple MPCTIO::selecttriple()
         } else {
             std::cerr << "Attempted to read SelectTriple in online phase\n";
         }
-    } else if (mpcio.preprocessing) {
+    } else if (mpcio.mode != MODE_ONLINE) {
         // Create triples (X0,Y0,Z0),(X1,Y1,Z1) such that
         // (X0*Y1 ^ Y0*X1) = (Z0^Z1)
         bit_t X0, X1;
@@ -629,22 +632,38 @@ SelectTriple MPCTIO::selecttriple()
     return val;
 }
 
-RDPFTriple MPCTIO::rdpftriple(nbits_t depth)
+RDPFTriple MPCTIO::rdpftriple(yield_t &yield, nbits_t depth,
+    bool keep_expansion)
 {
     RDPFTriple val;
-    if (!mpcio.preprocessing && mpcio.player <= 2) {
+    if (mpcio.player <= 2) {
         MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
-        mpcpio.rdpftriples[thread_num][depth-1].get(val);
+        if (mpcio.mode == MODE_ONLINE) {
+            mpcpio.rdpftriples[thread_num][depth-1].get(val);
+        } else {
+            val = RDPFTriple(*this, yield, depth,
+                keep_expansion);
+            iostream_server() <<
+                val.dpf[(mpcio.player == 0) ? 1 : 2];
+            mpcpio.rdpftriples[thread_num][depth-1].inc();
+        }
     }
     return val;
 }
 
-RDPFPair MPCTIO::rdpfpair(nbits_t depth)
+RDPFPair MPCTIO::rdpfpair(yield_t &yield, nbits_t depth)
 {
     RDPFPair val;
-    if (!mpcio.preprocessing && mpcio.player == 2) {
+    if (mpcio.player == 2) {
         MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
-        mpcsrvio.rdpfpairs[thread_num][depth-1].get(val);
+        if (mpcio.mode == MODE_ONLINE) {
+            mpcsrvio.rdpfpairs[thread_num][depth-1].get(val);
+        } else {
+            RDPFTriple trip(*this, yield, depth, true);
+            iostream_p0() >> val.dpf[0];
+            iostream_p1() >> val.dpf[1];
+            mpcsrvio.rdpfpairs[thread_num][depth-1].inc();
+        }
     }
     return val;
 }
@@ -654,12 +673,13 @@ CDPF MPCTIO::cdpf()
     CDPF val;
     if (mpcio.player < 2) {
         MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
-        if (mpcpio.preprocessing) {
+        if (mpcpio.mode != MODE_ONLINE) {
             iostream_server() >> val;
+            mpcpio.cdpfs[thread_num].inc();
         } else {
             mpcpio.cdpfs[thread_num].get(val);
         }
-    } else if (mpcio.preprocessing) {
+    } else if (mpcio.mode != MODE_ONLINE) {
         auto [ cdpf0, cdpf1 ] = CDPF::generate(aes_ops());
         iostream_p0() << cdpf0;
         iostream_p1() << cdpf1;

+ 13 - 10
mpcio.hpp

@@ -17,6 +17,7 @@
 #include <boost/chrono.hpp>
 
 #include "types.hpp"
+#include "corotypes.hpp"
 
 using boost::asio::ip::tcp;
 
@@ -26,12 +27,13 @@ template<typename T, typename N>
 class PreCompStorage {
 public:
     PreCompStorage() : name(N::name), depth(0), count(0) {}
-    PreCompStorage(unsigned player, bool preprocessing,
+    PreCompStorage(unsigned player, ProcessingMode mode,
         const char *filenameprefix, unsigned thread_num);
-    void init(unsigned player, bool preprocessing,
+    void init(unsigned player, ProcessingMode mode,
         const char *filenameprefix, unsigned thread_num, nbits_t depth = 0);
     void get(T& nextval);
 
+    inline void inc() { ++count; }
     inline size_t get_stats() { return count; }
     inline void reset_stats() { count = 0; }
 private:
@@ -166,7 +168,7 @@ public:
 
 struct MPCIO {
     int player;
-    bool preprocessing;
+    ProcessingMode mode;
     size_t num_threads;
     atomic_lamport_t lamport;
     std::vector<size_t> msgs_sent;
@@ -175,8 +177,8 @@ struct MPCIO {
     boost::chrono::steady_clock::time_point steady_start;
     boost::chrono::process_cpu_clock::time_point cpu_start;
 
-    MPCIO(int player, bool preprocessing, size_t num_threads) :
-        player(player), preprocessing(preprocessing),
+    MPCIO(int player, ProcessingMode mode, size_t num_threads) :
+        player(player), mode(mode),
         num_threads(num_threads), lamport(0)
     {
         reset_stats();
@@ -205,7 +207,7 @@ struct MPCPeerIO : public MPCIO {
     // The inner array is indexed by DPF depth (depth d is at entry d-1)
     std::vector<std::array<PreCompStorage<RDPFTriple, RDPFTripleName>,ADDRESS_MAX_BITS>> rdpftriples;
 
-    MPCPeerIO(unsigned player, bool preprocessing,
+    MPCPeerIO(unsigned player, ProcessingMode mode,
             std::deque<tcp::socket> &peersocks,
             std::deque<tcp::socket> &serversocks);
 
@@ -226,7 +228,7 @@ struct MPCServerIO : public MPCIO {
     // The inner array is indexed by DPF depth (depth d is at entry d-1)
     std::vector<std::array<PreCompStorage<RDPFPair, RDPFPairName>,ADDRESS_MAX_BITS>> rdpfpairs;
 
-    MPCServerIO(bool preprocessing,
+    MPCServerIO(ProcessingMode mode,
             std::deque<tcp::socket> &p0socks,
             std::deque<tcp::socket> &p1socks);
 
@@ -339,16 +341,17 @@ public:
 
     // These ones only work during the online phase
     // Computational peers call:
-    RDPFTriple rdpftriple(nbits_t depth);
+    RDPFTriple rdpftriple(yield_t &yield, nbits_t depth,
+        bool keep_expansion = true);
     // The server calls:
-    RDPFPair rdpfpair(nbits_t depth);
+    RDPFPair rdpfpair(yield_t &yield, nbits_t depth);
     // Anyone can call:
     CDPF cdpf();
 
     // Accessors
 
     inline int player() { return mpcio.player; }
-    inline bool preprocessing() { return mpcio.preprocessing; }
+    inline bool preprocessing() { return mpcio.mode == MODE_PREPROCESSING; }
     inline bool is_server() { return mpcio.player == 2; }
     inline size_t& aes_ops() { return mpcio.aes_ops[thread_num]; }
     inline size_t msgs_sent() { return mpcio.msgs_sent[thread_num]; }

+ 2 - 2
mpcops.cpp

@@ -44,7 +44,7 @@ void mpc_cross(MPCTIO &tio, yield_t &yield,
     yield();
 
     // Read the peer's x+X and y+Y
-    value_t  peer_blind_x, peer_blind_y;
+    value_t  peer_blind_x=0, peer_blind_y=0;
     tio.recv_peer(&peer_blind_x, nbytes);
     tio.recv_peer(&peer_blind_y, nbytes);
 
@@ -75,7 +75,7 @@ void mpc_valuemul(MPCTIO &tio, yield_t &yield,
     yield();
 
     // Read the peer's y+Y
-    value_t  peer_blind_y;
+    value_t  peer_blind_y=0;
     tio.recv_peer(&peer_blind_y, nbytes);
 
     if (tio.player() == 0) {

+ 40 - 32
online.cpp

@@ -168,11 +168,11 @@ static void rdpf_test(MPCIO &mpcio, yield_t &yield,
     int num_threads = opts.num_threads;
     boost::asio::thread_pool pool(num_threads);
     for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
-        boost::asio::post(pool, [&mpcio, thread_num, depth] {
+        boost::asio::post(pool, [&mpcio, &yield, thread_num, depth] {
             MPCTIO tio(mpcio, thread_num);
             size_t &aes_ops = tio.aes_ops();
             if (mpcio.player == 2) {
-                RDPFPair dp = tio.rdpfpair(depth);
+                RDPFPair dp = tio.rdpfpair(yield, depth);
                 for (int i=0;i<2;++i) {
                     const RDPF &dpf = dp.dpf[i];
                     for (address_t x=0;x<(address_t(1)<<depth);++x) {
@@ -187,7 +187,7 @@ static void rdpf_test(MPCIO &mpcio, yield_t &yield,
                     printf("\n");
                 }
             } else {
-                RDPFTriple dt = tio.rdpftriple(depth);
+                RDPFTriple dt = tio.rdpftriple(yield, depth);
                 for (int i=0;i<3;++i) {
                     const RDPF &dpf = dt.dpf[i];
                     RegXS peer_scaled_xor;
@@ -250,11 +250,11 @@ static void rdpf_timing(MPCIO &mpcio, yield_t &yield,
     int num_threads = opts.num_threads;
     boost::asio::thread_pool pool(num_threads);
     for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
-        boost::asio::post(pool, [&mpcio, thread_num, depth] {
+        boost::asio::post(pool, [&mpcio, &yield, thread_num, depth] {
             MPCTIO tio(mpcio, thread_num);
             size_t &aes_ops = tio.aes_ops();
             if (mpcio.player == 2) {
-                RDPFPair dp = tio.rdpfpair(depth);
+                RDPFPair dp = tio.rdpfpair(yield, depth);
                 for (int i=0;i<2;++i) {
                     RDPF &dpf = dp.dpf[i];
                     dpf.expand(aes_ops);
@@ -269,7 +269,7 @@ static void rdpf_timing(MPCIO &mpcio, yield_t &yield,
                     printf("\n");
                 }
             } else {
-                RDPFTriple dt = tio.rdpftriple(depth);
+                RDPFTriple dt = tio.rdpftriple(yield, depth);
                 for (int i=0;i<3;++i) {
                     RDPF &dpf = dt.dpf[i];
                     dpf.expand(aes_ops);
@@ -308,11 +308,11 @@ static void rdpfeval_timing(MPCIO &mpcio, yield_t &yield,
     int num_threads = opts.num_threads;
     boost::asio::thread_pool pool(num_threads);
     for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
-        boost::asio::post(pool, [&mpcio, thread_num, depth, start] {
+        boost::asio::post(pool, [&mpcio, &yield, thread_num, depth, start] {
             MPCTIO tio(mpcio, thread_num);
             size_t &aes_ops = tio.aes_ops();
             if (mpcio.player == 2) {
-                RDPFPair dp = tio.rdpfpair(depth);
+                RDPFPair dp = tio.rdpfpair(yield, depth);
                 for (int i=0;i<2;++i) {
                     RDPF &dpf = dp.dpf[i];
                     RegXS scaled_xor;
@@ -327,7 +327,7 @@ static void rdpfeval_timing(MPCIO &mpcio, yield_t &yield,
                     printf("\n");
                 }
             } else {
-                RDPFTriple dt = tio.rdpftriple(depth);
+                RDPFTriple dt = tio.rdpftriple(yield, depth);
                 for (int i=0;i<3;++i) {
                     RDPF &dpf = dt.dpf[i];
                     RegXS scaled_xor;
@@ -366,11 +366,11 @@ static void tupleeval_timing(MPCIO &mpcio, yield_t &yield,
     int num_threads = opts.num_threads;
     boost::asio::thread_pool pool(num_threads);
     for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
-        boost::asio::post(pool, [&mpcio, thread_num, depth, start] {
+        boost::asio::post(pool, [&mpcio, &yield, thread_num, depth, start] {
             MPCTIO tio(mpcio, thread_num);
             size_t &aes_ops = tio.aes_ops();
             if (mpcio.player == 2) {
-                RDPFPair dp = tio.rdpfpair(depth);
+                RDPFPair dp = tio.rdpfpair(yield, depth);
                 RegXS scaled_xor0, scaled_xor1;
                 auto ev = StreamEval(dp, start, 0, aes_ops, false);
                 for (address_t x=0;x<(address_t(1)<<depth);++x) {
@@ -387,7 +387,7 @@ static void tupleeval_timing(MPCIO &mpcio, yield_t &yield,
                     dp.dpf[1].scaled_xor.xshare);
                 printf("\n");
             } else {
-                RDPFTriple dt = tio.rdpftriple(depth);
+                RDPFTriple dt = tio.rdpftriple(yield, depth);
                 RegXS scaled_xor0, scaled_xor1, scaled_xor2;
                 auto ev = StreamEval(dt, start, 0, aes_ops, false);
                 for (address_t x=0;x<(address_t(1)<<depth);++x) {
@@ -501,6 +501,7 @@ static void cdpf_test(MPCIO &mpcio, yield_t &yield,
     const PRACOptions &opts, char **args)
 {
     value_t query, target;
+    int iters = 1;
     arc4random_buf(&query, sizeof(query));
     arc4random_buf(&target, sizeof(target));
 
@@ -512,33 +513,40 @@ static void cdpf_test(MPCIO &mpcio, yield_t &yield,
         target = strtoull(*args, NULL, 16);
         ++args;
     }
+    if (*args) {
+        iters = atoi(*args);
+        ++args;
+    }
 
     int num_threads = opts.num_threads;
     boost::asio::thread_pool pool(num_threads);
     for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
-        boost::asio::post(pool, [&mpcio, thread_num, &query, &target] {
+        boost::asio::post(pool, [&mpcio, thread_num, &query, &target, &iters] {
             MPCTIO tio(mpcio, thread_num);
             size_t &aes_ops = tio.aes_ops();
-            if (mpcio.player == 2) {
-                auto [ dpf0, dpf1 ] = CDPF::generate(target, aes_ops);
-                DPFnode leaf0 = dpf0.leaf(query, aes_ops);
-                DPFnode leaf1 = dpf1.leaf(query, aes_ops);
-                printf("DPFXOR_{%016lx}(%016lx} = ", target, query);
-                dump_node(leaf0 ^ leaf1);
-            } else {
-                CDPF dpf = tio.cdpf();
-                printf("ashare = %016lX\nxshare = %016lX\n",
-                    dpf.as_target.ashare, dpf.xs_target.xshare);
-                DPFnode leaf = dpf.leaf(query, aes_ops);
-                printf("DPF(%016lx) = ", query);
-                dump_node(leaf);
-                if (mpcio.player == 1) {
-                    tio.iostream_peer() << leaf;
+            for (int i=0;i<iters;++i) {
+                if (mpcio.player == 2) {
+                    tio.cdpf();
+                    auto [ dpf0, dpf1 ] = CDPF::generate(target, aes_ops);
+                    DPFnode leaf0 = dpf0.leaf(query, aes_ops);
+                    DPFnode leaf1 = dpf1.leaf(query, aes_ops);
+                    printf("DPFXOR_{%016lx}(%016lx} = ", target, query);
+                    dump_node(leaf0 ^ leaf1);
                 } else {
-                    DPFnode peerleaf;
-                    tio.iostream_peer() >> peerleaf;
-                    printf("XOR = ");
-                    dump_node(leaf ^ peerleaf);
+                    CDPF dpf = tio.cdpf();
+                    printf("ashare = %016lX\nxshare = %016lX\n",
+                        dpf.as_target.ashare, dpf.xs_target.xshare);
+                    DPFnode leaf = dpf.leaf(query, aes_ops);
+                    printf("DPF(%016lx) = ", query);
+                    dump_node(leaf);
+                    if (mpcio.player == 1) {
+                        tio.iostream_peer() << leaf;
+                    } else {
+                        DPFnode peerleaf;
+                        tio.iostream_peer() >> peerleaf;
+                        printf("XOR = ");
+                        dump_node(leaf ^ peerleaf);
+                    }
                 }
             }
             tio.send();

+ 4 - 2
options.hpp

@@ -1,13 +1,15 @@
 #ifndef __OPTIONS_HPP__
 #define __OPTIONS_HPP__
 
+#include "types.hpp"
+
 struct PRACOptions {
-    bool preprocessing;
+    ProcessingMode mode;
     int num_threads;
     bool expand_rdpfs;
     bool use_xor_db;
 
-    PRACOptions() : preprocessing(false), num_threads(1),
+    PRACOptions() : mode(MODE_ONLINE), num_threads(1),
         expand_rdpfs(true), use_xor_db(false) {}
 };
 

+ 9 - 5
prac.cpp

@@ -10,6 +10,7 @@ static void usage(const char *progname)
 {
     std::cerr << "Usage: " << progname << " [-p] [-t num] player_num player_addrs args ...\n";
     std::cerr << "-p: preprocessing mode\n";
+    std::cerr << "-o: online-only mode\n";
     std::cerr << "-t num: use num threads\n";
     std::cerr << "-c: store DPFs compressed (default is expanded)\n";
     std::cerr << "-x: use XOR-shared database (default is additive)\n";
@@ -28,11 +29,11 @@ static void comp_player_main(boost::asio::io_context &io_context,
     std::deque<tcp::socket> peersocks, serversocks;
     mpcio_setup_computational(player, io_context, p0addr,
         opts.num_threads, peersocks, serversocks);
-    MPCPeerIO mpcio(player, opts.preprocessing, peersocks, serversocks);
+    MPCPeerIO mpcio(player, opts.mode, peersocks, serversocks);
 
     // Queue up the work to be done
     boost::asio::post(io_context, [&]{
-        if (opts.preprocessing) {
+        if (opts.mode == MODE_PREPROCESSING) {
             preprocessing_comp(mpcio, opts, args);
         } else {
             online_main(mpcio, opts, args);
@@ -54,11 +55,11 @@ static void server_player_main(boost::asio::io_context &io_context,
     std::deque<tcp::socket> p0socks, p1socks;
     mpcio_setup_server(io_context, p0addr, p1addr,
         opts.num_threads, p0socks, p1socks);
-    MPCServerIO mpcserverio(opts.preprocessing, p0socks, p1socks);
+    MPCServerIO mpcserverio(opts.mode, p0socks, p1socks);
 
     // Queue up the work to be done
     boost::asio::post(io_context, [&]{
-        if (opts.preprocessing) {
+        if (opts.mode == MODE_PREPROCESSING) {
             preprocessing_server(mpcserverio, opts, args);
         } else {
             online_main(mpcserverio, opts, args);
@@ -83,7 +84,10 @@ int main(int argc, char **argv)
     // Get the options
     while (*args && *args[0] == '-') {
         if (!strcmp("-p", *args)) {
-            opts.preprocessing = true;
+            opts.mode = MODE_PREPROCESSING;
+            ++args;
+        } else if (!strcmp("-o", *args)) {
+            opts.mode = MODE_ONLINEONLY;
             ++args;
         } else if (!strcmp("-t", *args)) {
             if (args[1]) {

+ 4 - 8
preproc.cpp

@@ -117,16 +117,15 @@ void preprocessing_comp(MPCIO &mpcio, const PRACOptions &opts, char **args)
                     for (unsigned int i=0; i<num; ++i) {
                         coroutines.emplace_back(
                             [&, tripfile, type](yield_t &yield) {
-                                RDPFTriple rdpftrip(tio, yield, type,
-                                    opts.expand_rdpfs);
+                                RDPFTriple rdpftrip =
+                                    tio.rdpftriple(yield, type, opts.expand_rdpfs);
+                                printf("dep  = %d\n", type);
                                 printf("usi0 = %016lx\n", rdpftrip.dpf[0].unit_sum_inverse);
                                 printf("sxr0 = %016lx\n", rdpftrip.dpf[0].scaled_xor.xshare);
                                 printf("usi1 = %016lx\n", rdpftrip.dpf[1].unit_sum_inverse);
                                 printf("sxr1 = %016lx\n", rdpftrip.dpf[1].scaled_xor.xshare);
                                 printf("usi2 = %016lx\n", rdpftrip.dpf[2].unit_sum_inverse);
                                 printf("sxr2 = %016lx\n", rdpftrip.dpf[2].scaled_xor.xshare);
-                                tio.iostream_server() <<
-                                    rdpftrip.dpf[(mpcio.player == 0) ? 1 : 2];
                                 tripfile.os() << rdpftrip;
                             });
                     }
@@ -224,10 +223,7 @@ void preprocessing_server(MPCServerIO &mpcsrvio, const PRACOptions &opts, char *
                         for (unsigned int i=0; i<num; ++i) {
                             coroutines.emplace_back(
                                 [&, pairfile, depth](yield_t &yield) {
-                                    RDPFTriple rdpftrip(stio, yield, depth);
-                                    RDPFPair rdpfpair;
-                                    stio.iostream_p0() >> rdpfpair.dpf[0];
-                                    stio.iostream_p1() >> rdpfpair.dpf[1];
+                                    RDPFPair rdpfpair = stio.rdpfpair(yield, depth);
                                 printf("usi0 = %016lx\n", rdpfpair.dpf[0].unit_sum_inverse);
                                 printf("sxr0 = %016lx\n", rdpfpair.dpf[0].scaled_xor.xshare);
                                 printf("dep0 = %d\n", rdpfpair.dpf[0].depth());

+ 6 - 0
types.hpp

@@ -574,4 +574,10 @@ DEFAULT_IO(HalfTriple)
 DEFAULT_TUPLE_IO(RegAS)
 DEFAULT_TUPLE_IO(RegXS)
 
+enum ProcessingMode {
+    MODE_ONLINE,        // Online mode, after preprocessing has been done
+    MODE_PREPROCESSING, // Preprocessing mode
+    MODE_ONLINEONLY     // Online-only mode, where all computations are
+};                      // done online
+
 #endif