Parcourir la source

Add yield()s in the right places

to ensure that the server stays in step with the computational peers,
even when it's not communicating with them
Ian Goldberg il y a 2 ans
Parent
commit
dfb63bdb5d
8 fichiers modifiés avec 64 ajouts et 29 suppressions
  1. 1 1
      Makefile
  2. 2 2
      duoram.tcc
  3. 14 4
      mpcio.cpp
  4. 4 4
      mpcio.hpp
  5. 3 3
      mpcops.cpp
  6. 3 3
      online.cpp
  7. 35 12
      preproc.cpp
  8. 2 0
      rdpf.cpp

+ 1 - 1
Makefile

@@ -5,7 +5,7 @@ LDFLAGS=-ggdb
 LDLIBS=-lbsd -lboost_system -lboost_context -lboost_chrono -lboost_thread -lpthread
 
 # Enable this to have all communication logged to stdout
-CXXFLAGS += -DVERBOSE_COMMS
+# CXXFLAGS += -DVERBOSE_COMMS
 
 BIN=prac
 SRCS=prac.cpp mpcio.cpp preproc.cpp online.cpp mpcops.cpp rdpf.cpp \

+ 2 - 2
duoram.tcc

@@ -235,7 +235,7 @@ RegAS Duoram<RegAS>::Flat::obliv_binary_search(RegAS &target)
         // Obliviously read the value there
         RegAS val = operator[](index);
         // Compare it to the target
-        CDPF cdpf = tio.cdpf();
+        CDPF cdpf = tio.cdpf(this->yield);
         auto [lt, eq, gt] = cdpf.compare(this->tio, this->yield,
             val-target, tio.aes_ops());
         if (depth > 1) {
@@ -474,7 +474,7 @@ void Duoram<RegAS>::Flat::osort(const U &idx1, const V &idx2, bool dir)
             val2 = Acoro[idx2];
         });
     // Get a CDPF
-    CDPF cdpf = tio.cdpf();
+    CDPF cdpf = tio.cdpf(yield);
     // Use it to compare the values
     RegAS diff = val1-val2;
     auto [lt, eq, gt] = cdpf.compare(tio, yield, diff, tio.aes_ops());

+ 14 - 4
mpcio.cpp

@@ -563,12 +563,13 @@ void MPCTIO::send()
 // Functions to get precomputed values.  If we're in the online
 // phase, get them from PreCompStorage.  If we're in the
 // preprocessing or online-only phase, read them from the server.
-MultTriple MPCTIO::triple()
+MultTriple MPCTIO::triple(yield_t &yield)
 {
     MultTriple val;
     if (mpcio.player < 2) {
         MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
         if (mpcpio.mode != MODE_ONLINE) {
+            yield();
             recv_server(&val, sizeof(val));
             mpcpio.triples[thread_num].inc();
         } else {
@@ -589,16 +590,18 @@ MultTriple MPCTIO::triple()
         T1 = std::make_tuple(X1, Y1, Z1);
         queue_p0(&T0, sizeof(T0));
         queue_p1(&T1, sizeof(T1));
+        yield();
     }
     return val;
 }
 
-HalfTriple MPCTIO::halftriple()
+HalfTriple MPCTIO::halftriple(yield_t &yield)
 {
     HalfTriple val;
     if (mpcio.player < 2) {
         MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
         if (mpcpio.mode != MODE_ONLINE) {
+            yield();
             recv_server(&val, sizeof(val));
             mpcpio.halftriples[thread_num].inc();
         } else {
@@ -617,17 +620,19 @@ HalfTriple MPCTIO::halftriple()
         H1 = std::make_tuple(Y1, Z1);
         queue_p0(&H0, sizeof(H0));
         queue_p1(&H1, sizeof(H1));
+        yield();
     }
     return val;
 }
 
-SelectTriple MPCTIO::selecttriple()
+SelectTriple MPCTIO::selecttriple(yield_t &yield)
 {
     SelectTriple val;
     if (mpcio.player < 2) {
         MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
         if (mpcpio.mode != MODE_ONLINE) {
             uint8_t Xbyte;
+            yield();
             recv_server(&Xbyte, sizeof(Xbyte));
             val.X = Xbyte & 1;
             recv_server(&val.Y, sizeof(val.Y));
@@ -657,6 +662,7 @@ SelectTriple MPCTIO::selecttriple()
         queue_p1(&X1, sizeof(X1));
         queue_p1(&Y1, sizeof(Y1));
         queue_p1(&Z1, sizeof(Z1));
+        yield();
     }
     return val;
 }
@@ -675,6 +681,7 @@ RDPFTriple MPCTIO::rdpftriple(yield_t &yield, nbits_t depth,
             iostream_server() <<
                 val.dpf[(mpcio.player == 0) ? 1 : 2];
             mpcpio.rdpftriples[thread_num][depth-1].inc();
+            yield();
         }
     }
     return val;
@@ -689,6 +696,7 @@ RDPFPair MPCTIO::rdpfpair(yield_t &yield, nbits_t depth)
             mpcsrvio.rdpfpairs[thread_num][depth-1].get(val);
         } else {
             RDPFTriple trip(*this, yield, depth, true);
+            yield();
             iostream_p0() >> val.dpf[0];
             iostream_p1() >> val.dpf[1];
             mpcsrvio.rdpfpairs[thread_num][depth-1].inc();
@@ -697,12 +705,13 @@ RDPFPair MPCTIO::rdpfpair(yield_t &yield, nbits_t depth)
     return val;
 }
 
-CDPF MPCTIO::cdpf()
+CDPF MPCTIO::cdpf(yield_t &yield)
 {
     CDPF val;
     if (mpcio.player < 2) {
         MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
         if (mpcpio.mode != MODE_ONLINE) {
+            yield();
             iostream_server() >> val;
             mpcpio.cdpfs[thread_num].inc();
         } else {
@@ -712,6 +721,7 @@ CDPF MPCTIO::cdpf()
         auto [ cdpf0, cdpf1 ] = CDPF::generate(aes_ops());
         iostream_p0() << cdpf0;
         iostream_p1() << cdpf1;
+        yield();
     }
     return val;
 }

+ 4 - 4
mpcio.hpp

@@ -345,9 +345,9 @@ public:
     // phase, get them from PreCompStorage.  If we're in the
     // preprocessing phase, read them from the server.
 
-    MultTriple triple();
-    HalfTriple halftriple();
-    SelectTriple selecttriple();
+    MultTriple triple(yield_t &yield);
+    HalfTriple halftriple(yield_t &yield);
+    SelectTriple selecttriple(yield_t &yield);
 
     // These ones only work during the online phase
     // Computational peers call:
@@ -356,7 +356,7 @@ public:
     // The server calls:
     RDPFPair rdpfpair(yield_t &yield, nbits_t depth);
     // Anyone can call:
-    CDPF cdpf();
+    CDPF cdpf(yield_t &yield);
 
     // Accessors
 

+ 3 - 3
mpcops.cpp

@@ -32,7 +32,7 @@ void mpc_cross(MPCTIO &tio, yield_t &yield,
 {
     const value_t mask = MASKBITS(nbits);
     size_t nbytes = BITBYTES(nbits);
-    auto [X, Y, Z] = tio.triple();
+    auto [X, Y, Z] = tio.triple(yield);
 
     // Send x+X and y+Y
     value_t blind_x = (x.ashare + X) & mask;
@@ -65,7 +65,7 @@ void mpc_valuemul(MPCTIO &tio, yield_t &yield,
 {
     const value_t mask = MASKBITS(nbits);
     size_t nbytes = BITBYTES(nbits);
-    auto [X, Z] = tio.halftriple();
+    auto [X, Z] = tio.halftriple(yield);
 
     // Send x+X
     value_t blind_x = (x + X) & mask;
@@ -215,7 +215,7 @@ void mpc_reconstruct_choice(MPCTIO &tio, yield_t &yield,
     DPFnode fext = if128_mask[f.bshare];
 
     // Compute XOR shares of f & (x ^ y)
-    auto [X, Y, Z] = tio.selecttriple();
+    auto [X, Y, Z] = tio.selecttriple(yield);
 
     bit_t blind_f = f.bshare ^ X;
     DPFnode d = x ^ y;

+ 3 - 3
online.cpp

@@ -528,19 +528,19 @@ static void cdpf_test(MPCIO &mpcio, yield_t &yield,
     int num_threads = opts.num_threads;
     boost::asio::thread_pool pool(num_threads);
     for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
-        boost::asio::post(pool, [&mpcio, thread_num, &query, &target, &iters] {
+        boost::asio::post(pool, [&mpcio, &yield, thread_num, &query, &target, &iters] {
             MPCTIO tio(mpcio, thread_num);
             size_t &aes_ops = tio.aes_ops();
             for (int i=0;i<iters;++i) {
                 if (mpcio.player == 2) {
-                    tio.cdpf();
+                    tio.cdpf(yield);
                     auto [ dpf0, dpf1 ] = CDPF::generate(target, aes_ops);
                     DPFnode leaf0 = dpf0.leaf(query, aes_ops);
                     DPFnode leaf1 = dpf1.leaf(query, aes_ops);
                     printf("DPFXOR_{%016lx}(%016lx} = ", target, query);
                     dump_node(leaf0 ^ leaf1);
                 } else {
-                    CDPF dpf = tio.cdpf();
+                    CDPF dpf = tio.cdpf(yield);
                     printf("ashare = %016lX\nxshare = %016lX\n",
                         dpf.as_target.ashare, dpf.xs_target.xshare);
                     DPFnode leaf = dpf.leaf(query, aes_ops);

+ 35 - 12
preproc.cpp

@@ -96,20 +96,26 @@ void preprocessing_comp(MPCIO &mpcio, const PRACOptions &opts, char **args)
                     auto tripfile = ofiles.open("triples",
                         mpcio.player, thread_num);
 
-                    MultTriple T;
                     for (unsigned int i=0; i<num; ++i) {
-                        T = tio.triple();
-                        tripfile.os() << T;
+                        coroutines.emplace_back(
+                            [&, tripfile](yield_t &yield) {
+                                yield();
+                                MultTriple T = tio.triple(yield);
+                                tripfile.os() << T;
+                            });
                     }
                 } else if (type == 0x81) {
                     // Multiplication half triples
                     auto halffile = ofiles.open("halves",
                         mpcio.player, thread_num);
 
-                    HalfTriple H;
                     for (unsigned int i=0; i<num; ++i) {
-                        H = tio.halftriple();
-                        halffile.os() << H;
+                        coroutines.emplace_back(
+                            [&, halffile](yield_t &yield) {
+                                yield();
+                                HalfTriple H = tio.halftriple(yield);
+                                halffile.os() << H;
+                            });
                     }
                 } else if (type >= 0x01 && type <= 0x30) {
                     // RAM DPFs
@@ -118,6 +124,7 @@ void preprocessing_comp(MPCIO &mpcio, const PRACOptions &opts, char **args)
                     for (unsigned int i=0; i<num; ++i) {
                         coroutines.emplace_back(
                             [&, tripfile, type](yield_t &yield) {
+                                yield();
                                 RDPFTriple rdpftrip =
                                     tio.rdpftriple(yield, type, opts.expand_rdpfs);
                                 printf("dep  = %d\n", type);
@@ -135,10 +142,13 @@ void preprocessing_comp(MPCIO &mpcio, const PRACOptions &opts, char **args)
                     auto cdpffile = ofiles.open("cdpf",
                         mpcio.player, thread_num);
 
-                    CDPF C;
                     for (unsigned int i=0; i<num; ++i) {
-                        C = tio.cdpf();
-                        cdpffile.os() << C;
+                        coroutines.emplace_back(
+                            [&, cdpffile](yield_t &yield) {
+                                yield();
+                                CDPF C = tio.cdpf(yield);
+                                cdpffile.os() << C;
+                            });
                     }
                 } else if (type == 0x82) {
                     coroutines.emplace_back(
@@ -215,7 +225,11 @@ void preprocessing_server(MPCServerIO &mpcsrvio, const PRACOptions &opts, char *
                     stio.queue_p1(&num, 4);
 
                     for (unsigned int i=0; i<num; ++i) {
-                        stio.triple();
+                        coroutines.emplace_back(
+                            [&](yield_t &yield) {
+                                yield();
+                                stio.triple(yield);
+                            });
                     }
                 } else if (!strcmp(type, "h")) {
                     unsigned char typetag = 0x81;
@@ -225,7 +239,11 @@ void preprocessing_server(MPCServerIO &mpcsrvio, const PRACOptions &opts, char *
                     stio.queue_p1(&num, 4);
 
                     for (unsigned int i=0; i<num; ++i) {
-                        stio.halftriple();
+                        coroutines.emplace_back(
+                            [&](yield_t &yield) {
+                                yield();
+                                stio.halftriple(yield);
+                            });
                     }
                 } else if (type[0] == 'r') {
                     int depth = atoi(type+1);
@@ -243,6 +261,7 @@ void preprocessing_server(MPCServerIO &mpcsrvio, const PRACOptions &opts, char *
                         for (unsigned int i=0; i<num; ++i) {
                             coroutines.emplace_back(
                                 [&, pairfile, depth](yield_t &yield) {
+                                    yield();
                                     RDPFPair rdpfpair = stio.rdpfpair(yield, depth);
                                 printf("usi0 = %016lx\n", rdpfpair.dpf[0].unit_sum_inverse);
                                 printf("sxr0 = %016lx\n", rdpfpair.dpf[0].scaled_xor.xshare);
@@ -266,7 +285,11 @@ void preprocessing_server(MPCServerIO &mpcsrvio, const PRACOptions &opts, char *
                     stio.queue_p1(&num, 4);
 
                     for (unsigned int i=0; i<num; ++i) {
-                        stio.cdpf();
+                        coroutines.emplace_back(
+                            [&](yield_t &yield) {
+                                yield();
+                                stio.cdpf(yield);
+                            });
                     }
                 } else if (!strcmp(type, "i")) {
                     unsigned char typetag = 0x82;

+ 2 - 0
rdpf.cpp

@@ -231,6 +231,8 @@ RDPF::RDPF(MPCTIO &tio, yield_t &yield,
                 unit_sum_inverse = inverse_value_t(low_sum);
             }
             cw.push_back(CW);
+        } else if (level == depth-1) {
+            yield();
         }
 
         ++level;