Browse Source

Add yield()s in the right places

to ensure that the server stays in step with the computational peers,
even when it's not communicating with them
Ian Goldberg 1 year ago
parent
commit
dfb63bdb5d
8 changed files with 64 additions and 29 deletions
  1. 1 1
      Makefile
  2. 2 2
      duoram.tcc
  3. 14 4
      mpcio.cpp
  4. 4 4
      mpcio.hpp
  5. 3 3
      mpcops.cpp
  6. 3 3
      online.cpp
  7. 35 12
      preproc.cpp
  8. 2 0
      rdpf.cpp

+ 1 - 1
Makefile

@@ -5,7 +5,7 @@ LDFLAGS=-ggdb
 LDLIBS=-lbsd -lboost_system -lboost_context -lboost_chrono -lboost_thread -lpthread
 
 # Enable this to have all communication logged to stdout
-CXXFLAGS += -DVERBOSE_COMMS
+# CXXFLAGS += -DVERBOSE_COMMS
 
 BIN=prac
 SRCS=prac.cpp mpcio.cpp preproc.cpp online.cpp mpcops.cpp rdpf.cpp \

+ 2 - 2
duoram.tcc

@@ -235,7 +235,7 @@ RegAS Duoram<RegAS>::Flat::obliv_binary_search(RegAS &target)
         // Obliviously read the value there
         RegAS val = operator[](index);
         // Compare it to the target
-        CDPF cdpf = tio.cdpf();
+        CDPF cdpf = tio.cdpf(this->yield);
         auto [lt, eq, gt] = cdpf.compare(this->tio, this->yield,
             val-target, tio.aes_ops());
         if (depth > 1) {
@@ -474,7 +474,7 @@ void Duoram<RegAS>::Flat::osort(const U &idx1, const V &idx2, bool dir)
             val2 = Acoro[idx2];
         });
     // Get a CDPF
-    CDPF cdpf = tio.cdpf();
+    CDPF cdpf = tio.cdpf(yield);
     // Use it to compare the values
     RegAS diff = val1-val2;
     auto [lt, eq, gt] = cdpf.compare(tio, yield, diff, tio.aes_ops());

+ 14 - 4
mpcio.cpp

@@ -563,12 +563,13 @@ void MPCTIO::send()
 // Functions to get precomputed values.  If we're in the online
 // phase, get them from PreCompStorage.  If we're in the
 // preprocessing or online-only phase, read them from the server.
-MultTriple MPCTIO::triple()
+MultTriple MPCTIO::triple(yield_t &yield)
 {
     MultTriple val;
     if (mpcio.player < 2) {
         MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
         if (mpcpio.mode != MODE_ONLINE) {
+            yield();
             recv_server(&val, sizeof(val));
             mpcpio.triples[thread_num].inc();
         } else {
@@ -589,16 +590,18 @@ MultTriple MPCTIO::triple()
         T1 = std::make_tuple(X1, Y1, Z1);
         queue_p0(&T0, sizeof(T0));
         queue_p1(&T1, sizeof(T1));
+        yield();
     }
     return val;
 }
 
-HalfTriple MPCTIO::halftriple()
+HalfTriple MPCTIO::halftriple(yield_t &yield)
 {
     HalfTriple val;
     if (mpcio.player < 2) {
         MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
         if (mpcpio.mode != MODE_ONLINE) {
+            yield();
             recv_server(&val, sizeof(val));
             mpcpio.halftriples[thread_num].inc();
         } else {
@@ -617,17 +620,19 @@ HalfTriple MPCTIO::halftriple()
         H1 = std::make_tuple(Y1, Z1);
         queue_p0(&H0, sizeof(H0));
         queue_p1(&H1, sizeof(H1));
+        yield();
     }
     return val;
 }
 
-SelectTriple MPCTIO::selecttriple()
+SelectTriple MPCTIO::selecttriple(yield_t &yield)
 {
     SelectTriple val;
     if (mpcio.player < 2) {
         MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
         if (mpcpio.mode != MODE_ONLINE) {
             uint8_t Xbyte;
+            yield();
             recv_server(&Xbyte, sizeof(Xbyte));
             val.X = Xbyte & 1;
             recv_server(&val.Y, sizeof(val.Y));
@@ -657,6 +662,7 @@ SelectTriple MPCTIO::selecttriple()
         queue_p1(&X1, sizeof(X1));
         queue_p1(&Y1, sizeof(Y1));
         queue_p1(&Z1, sizeof(Z1));
+        yield();
     }
     return val;
 }
@@ -675,6 +681,7 @@ RDPFTriple MPCTIO::rdpftriple(yield_t &yield, nbits_t depth,
             iostream_server() <<
                 val.dpf[(mpcio.player == 0) ? 1 : 2];
             mpcpio.rdpftriples[thread_num][depth-1].inc();
+            yield();
         }
     }
     return val;
@@ -689,6 +696,7 @@ RDPFPair MPCTIO::rdpfpair(yield_t &yield, nbits_t depth)
             mpcsrvio.rdpfpairs[thread_num][depth-1].get(val);
         } else {
             RDPFTriple trip(*this, yield, depth, true);
+            yield();
             iostream_p0() >> val.dpf[0];
             iostream_p1() >> val.dpf[1];
             mpcsrvio.rdpfpairs[thread_num][depth-1].inc();
@@ -697,12 +705,13 @@ RDPFPair MPCTIO::rdpfpair(yield_t &yield, nbits_t depth)
     return val;
 }
 
-CDPF MPCTIO::cdpf()
+CDPF MPCTIO::cdpf(yield_t &yield)
 {
     CDPF val;
     if (mpcio.player < 2) {
         MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
         if (mpcpio.mode != MODE_ONLINE) {
+            yield();
             iostream_server() >> val;
             mpcpio.cdpfs[thread_num].inc();
         } else {
@@ -712,6 +721,7 @@ CDPF MPCTIO::cdpf()
         auto [ cdpf0, cdpf1 ] = CDPF::generate(aes_ops());
         iostream_p0() << cdpf0;
         iostream_p1() << cdpf1;
+        yield();
     }
     return val;
 }

+ 4 - 4
mpcio.hpp

@@ -345,9 +345,9 @@ public:
     // phase, get them from PreCompStorage.  If we're in the
     // preprocessing phase, read them from the server.
 
-    MultTriple triple();
-    HalfTriple halftriple();
-    SelectTriple selecttriple();
+    MultTriple triple(yield_t &yield);
+    HalfTriple halftriple(yield_t &yield);
+    SelectTriple selecttriple(yield_t &yield);
 
     // These ones only work during the online phase
     // Computational peers call:
@@ -356,7 +356,7 @@ public:
     // The server calls:
     RDPFPair rdpfpair(yield_t &yield, nbits_t depth);
     // Anyone can call:
-    CDPF cdpf();
+    CDPF cdpf(yield_t &yield);
 
     // Accessors
 

+ 3 - 3
mpcops.cpp

@@ -32,7 +32,7 @@ void mpc_cross(MPCTIO &tio, yield_t &yield,
 {
     const value_t mask = MASKBITS(nbits);
     size_t nbytes = BITBYTES(nbits);
-    auto [X, Y, Z] = tio.triple();
+    auto [X, Y, Z] = tio.triple(yield);
 
     // Send x+X and y+Y
     value_t blind_x = (x.ashare + X) & mask;
@@ -65,7 +65,7 @@ void mpc_valuemul(MPCTIO &tio, yield_t &yield,
 {
     const value_t mask = MASKBITS(nbits);
     size_t nbytes = BITBYTES(nbits);
-    auto [X, Z] = tio.halftriple();
+    auto [X, Z] = tio.halftriple(yield);
 
     // Send x+X
     value_t blind_x = (x + X) & mask;
@@ -215,7 +215,7 @@ void mpc_reconstruct_choice(MPCTIO &tio, yield_t &yield,
     DPFnode fext = if128_mask[f.bshare];
 
     // Compute XOR shares of f & (x ^ y)
-    auto [X, Y, Z] = tio.selecttriple();
+    auto [X, Y, Z] = tio.selecttriple(yield);
 
     bit_t blind_f = f.bshare ^ X;
     DPFnode d = x ^ y;

+ 3 - 3
online.cpp

@@ -528,19 +528,19 @@ static void cdpf_test(MPCIO &mpcio, yield_t &yield,
     int num_threads = opts.num_threads;
     boost::asio::thread_pool pool(num_threads);
     for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
-        boost::asio::post(pool, [&mpcio, thread_num, &query, &target, &iters] {
+        boost::asio::post(pool, [&mpcio, &yield, thread_num, &query, &target, &iters] {
             MPCTIO tio(mpcio, thread_num);
             size_t &aes_ops = tio.aes_ops();
             for (int i=0;i<iters;++i) {
                 if (mpcio.player == 2) {
-                    tio.cdpf();
+                    tio.cdpf(yield);
                     auto [ dpf0, dpf1 ] = CDPF::generate(target, aes_ops);
                     DPFnode leaf0 = dpf0.leaf(query, aes_ops);
                     DPFnode leaf1 = dpf1.leaf(query, aes_ops);
                     printf("DPFXOR_{%016lx}(%016lx} = ", target, query);
                     dump_node(leaf0 ^ leaf1);
                 } else {
-                    CDPF dpf = tio.cdpf();
+                    CDPF dpf = tio.cdpf(yield);
                     printf("ashare = %016lX\nxshare = %016lX\n",
                         dpf.as_target.ashare, dpf.xs_target.xshare);
                     DPFnode leaf = dpf.leaf(query, aes_ops);

+ 35 - 12
preproc.cpp

@@ -96,20 +96,26 @@ void preprocessing_comp(MPCIO &mpcio, const PRACOptions &opts, char **args)
                     auto tripfile = ofiles.open("triples",
                         mpcio.player, thread_num);
 
-                    MultTriple T;
                     for (unsigned int i=0; i<num; ++i) {
-                        T = tio.triple();
-                        tripfile.os() << T;
+                        coroutines.emplace_back(
+                            [&, tripfile](yield_t &yield) {
+                                yield();
+                                MultTriple T = tio.triple(yield);
+                                tripfile.os() << T;
+                            });
                     }
                 } else if (type == 0x81) {
                     // Multiplication half triples
                     auto halffile = ofiles.open("halves",
                         mpcio.player, thread_num);
 
-                    HalfTriple H;
                     for (unsigned int i=0; i<num; ++i) {
-                        H = tio.halftriple();
-                        halffile.os() << H;
+                        coroutines.emplace_back(
+                            [&, halffile](yield_t &yield) {
+                                yield();
+                                HalfTriple H = tio.halftriple(yield);
+                                halffile.os() << H;
+                            });
                     }
                 } else if (type >= 0x01 && type <= 0x30) {
                     // RAM DPFs
@@ -118,6 +124,7 @@ void preprocessing_comp(MPCIO &mpcio, const PRACOptions &opts, char **args)
                     for (unsigned int i=0; i<num; ++i) {
                         coroutines.emplace_back(
                             [&, tripfile, type](yield_t &yield) {
+                                yield();
                                 RDPFTriple rdpftrip =
                                     tio.rdpftriple(yield, type, opts.expand_rdpfs);
                                 printf("dep  = %d\n", type);
@@ -135,10 +142,13 @@ void preprocessing_comp(MPCIO &mpcio, const PRACOptions &opts, char **args)
                     auto cdpffile = ofiles.open("cdpf",
                         mpcio.player, thread_num);
 
-                    CDPF C;
                     for (unsigned int i=0; i<num; ++i) {
-                        C = tio.cdpf();
-                        cdpffile.os() << C;
+                        coroutines.emplace_back(
+                            [&, cdpffile](yield_t &yield) {
+                                yield();
+                                CDPF C = tio.cdpf(yield);
+                                cdpffile.os() << C;
+                            });
                     }
                 } else if (type == 0x82) {
                     coroutines.emplace_back(
@@ -215,7 +225,11 @@ void preprocessing_server(MPCServerIO &mpcsrvio, const PRACOptions &opts, char *
                     stio.queue_p1(&num, 4);
 
                     for (unsigned int i=0; i<num; ++i) {
-                        stio.triple();
+                        coroutines.emplace_back(
+                            [&](yield_t &yield) {
+                                yield();
+                                stio.triple(yield);
+                            });
                     }
                 } else if (!strcmp(type, "h")) {
                     unsigned char typetag = 0x81;
@@ -225,7 +239,11 @@ void preprocessing_server(MPCServerIO &mpcsrvio, const PRACOptions &opts, char *
                     stio.queue_p1(&num, 4);
 
                     for (unsigned int i=0; i<num; ++i) {
-                        stio.halftriple();
+                        coroutines.emplace_back(
+                            [&](yield_t &yield) {
+                                yield();
+                                stio.halftriple(yield);
+                            });
                     }
                 } else if (type[0] == 'r') {
                     int depth = atoi(type+1);
@@ -243,6 +261,7 @@ void preprocessing_server(MPCServerIO &mpcsrvio, const PRACOptions &opts, char *
                         for (unsigned int i=0; i<num; ++i) {
                             coroutines.emplace_back(
                                 [&, pairfile, depth](yield_t &yield) {
+                                    yield();
                                     RDPFPair rdpfpair = stio.rdpfpair(yield, depth);
                                 printf("usi0 = %016lx\n", rdpfpair.dpf[0].unit_sum_inverse);
                                 printf("sxr0 = %016lx\n", rdpfpair.dpf[0].scaled_xor.xshare);
@@ -266,7 +285,11 @@ void preprocessing_server(MPCServerIO &mpcsrvio, const PRACOptions &opts, char *
                     stio.queue_p1(&num, 4);
 
                     for (unsigned int i=0; i<num; ++i) {
-                        stio.cdpf();
+                        coroutines.emplace_back(
+                            [&](yield_t &yield) {
+                                yield();
+                                stio.cdpf(yield);
+                            });
                     }
                 } else if (!strcmp(type, "i")) {
                     unsigned char typetag = 0x82;

+ 2 - 0
rdpf.cpp

@@ -231,6 +231,8 @@ RDPF::RDPF(MPCTIO &tio, yield_t &yield,
                 unit_sum_inverse = inverse_value_t(low_sum);
             }
             cw.push_back(CW);
+        } else if (level == depth-1) {
+            yield();
         }
 
         ++level;