Browse Source

More merging of computational peer and server execution paths

Ian Goldberg 1 year ago
parent
commit
7af573e9da
4 changed files with 97 additions and 99 deletions
  1. 79 36
      mpcio.hpp
  2. 0 14
      mpcops.cpp
  3. 11 6
      online.cpp
  4. 7 43
      preproc.cpp

+ 79 - 36
mpcio.hpp

@@ -7,6 +7,7 @@
 #include <deque>
 #include <queue>
 #include <string>
+#include <bsd/stdlib.h> // arc4random_buf
 
 #include <boost/asio.hpp>
 #include <boost/thread.hpp>
@@ -254,57 +255,69 @@ public:
     // Queue up data to the peer or to the server
 
     void queue_peer(const void *data, size_t len) {
-        assert(mpcio.player < 2);
-        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
-        mpcpio.peerios[thread_num].queue(data, len);
+        if (mpcio.player < 2) {
+            MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
+            mpcpio.peerios[thread_num].queue(data, len);
+        }
     }
 
     void queue_server(const void *data, size_t len) {
-        assert(mpcio.player < 2);
-        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
-        mpcpio.serverios[thread_num].queue(data, len);
+        if (mpcio.player < 2) {
+            MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
+            mpcpio.serverios[thread_num].queue(data, len);
+        }
     }
 
     // Receive data from the peer or to the server
 
     size_t recv_peer(void *data, size_t len) {
-        assert(mpcio.player < 2);
-        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
-        return mpcpio.peerios[thread_num].recv(data, len);
+        if (mpcio.player < 2) {
+            MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
+            return mpcpio.peerios[thread_num].recv(data, len);
+        }
+        return 0;
     }
 
     size_t recv_server(void *data, size_t len) {
-        assert(mpcio.player < 2);
-        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
-        return mpcpio.serverios[thread_num].recv(data, len);
+        if (mpcio.player < 2) {
+            MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
+            return mpcpio.serverios[thread_num].recv(data, len);
+        }
+        return 0;
     }
 
     // Queue up data to p0 or p1
 
     void queue_p0(const void *data, size_t len) {
-        assert(mpcio.player == 2);
-        MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
-        mpcsrvio.p0ios[thread_num].queue(data, len);
+        if (mpcio.player == 2) {
+            MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
+            mpcsrvio.p0ios[thread_num].queue(data, len);
+        }
     }
 
     void queue_p1(const void *data, size_t len) {
-        assert(mpcio.player == 2);
-        MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
-        mpcsrvio.p1ios[thread_num].queue(data, len);
+        if (mpcio.player == 2) {
+            MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
+            mpcsrvio.p1ios[thread_num].queue(data, len);
+        }
     }
 
     // Receive data from p0 or p1
 
     size_t recv_p0(void *data, size_t len) {
-        assert(mpcio.player == 2);
-        MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
-        return mpcsrvio.p0ios[thread_num].recv(data, len);
+        if (mpcio.player == 2) {
+            MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
+            return mpcsrvio.p0ios[thread_num].recv(data, len);
+        }
+        return 0;
     }
 
     size_t recv_p1(void *data, size_t len) {
-        assert(mpcio.player == 2);
-        MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
-        return mpcsrvio.p1ios[thread_num].recv(data, len);
+        if (mpcio.player == 2) {
+            MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
+            return mpcsrvio.p1ios[thread_num].recv(data, len);
+        }
+        return 0;
     }
 
     // Send all queued data for this thread
@@ -324,25 +337,55 @@ public:
     // phase, get them from PreCompStorage.  If we're in the
     // preprocessing phase, read them from the server.
     MultTriple triple() {
-        assert(mpcio.player < 2);
-        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
         MultTriple val;
-        if (mpcpio.preprocessing) {
-            mpcpio.serverios[thread_num].recv(boost::asio::buffer(&val, sizeof(val)));
-        } else {
-            mpcpio.triples[thread_num].get(val);
+        if (mpcio.player < 2) {
+            MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
+            if (mpcpio.preprocessing) {
+                recv_server(&val, sizeof(val));
+            } else {
+                mpcpio.triples[thread_num].get(val);
+            }
+        } else if (mpcio.preprocessing) {
+            // Create triples (X0,Y0,Z0),(X1,Y1,Z1) such that
+            // (X0*Y1 + Y0*X1) = (Z0+Z1)
+            value_t X0, Y0, Z0, X1, Y1, Z1;
+            arc4random_buf(&X0, sizeof(X0));
+            arc4random_buf(&Y0, sizeof(Y0));
+            arc4random_buf(&Z0, sizeof(Z0));
+            arc4random_buf(&X1, sizeof(X1));
+            arc4random_buf(&Y1, sizeof(Y1));
+            Z1 = X0 * Y1 + X1 * Y0 - Z0;
+            MultTriple T0, T1;
+            T0 = std::make_tuple(X0, Y0, Z0);
+            T1 = std::make_tuple(X1, Y1, Z1);
+            queue_p0(&T0, sizeof(T0));
+            queue_p1(&T1, sizeof(T1));
         }
         return val;
     }
 
     HalfTriple halftriple() {
-        assert(mpcio.player < 2);
-        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
         HalfTriple val;
-        if (mpcpio.preprocessing) {
-            mpcpio.serverios[thread_num].recv(boost::asio::buffer(&val, sizeof(val)));
-        } else {
-            mpcpio.halftriples[thread_num].get(val);
+        if (mpcio.player < 2) {
+            MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
+            if (mpcpio.preprocessing) {
+                mpcpio.serverios[thread_num].recv(boost::asio::buffer(&val, sizeof(val)));
+            } else {
+                mpcpio.halftriples[thread_num].get(val);
+            }
+        } else if (mpcio.preprocessing) {
+            // Create half-triples (X0,Z0),(Y1,Z1) such that
+            // X0*Y1 = Z0 + Z1
+            value_t X0, Z0, Y1, Z1;
+            arc4random_buf(&X0, sizeof(X0));
+            arc4random_buf(&Z0, sizeof(Z0));
+            arc4random_buf(&Y1, sizeof(Y1));
+            Z1 = X0 * Y1 - Z0;
+            HalfTriple H0, H1;
+            H0 = std::make_tuple(X0, Z0);
+            H1 = std::make_tuple(Y1, Z1);
+            queue_p0(&H0, sizeof(H0));
+            queue_p1(&H1, sizeof(H1));
         }
         return val;
     }

+ 0 - 14
mpcops.cpp

@@ -16,8 +16,6 @@ void mpc_mul(MPCTIO &tio, yield_t &yield,
     value_t &as_z, value_t as_x, value_t as_y,
     nbits_t nbits)
 {
-    if (tio.is_server()) return;
-
     const value_t mask = MASKBITS(nbits);
     // Compute as_z to be an additive share of (x0*y1+y0*x1)
     mpc_cross(tio, yield, as_z, as_x, as_y, nbits);
@@ -36,8 +34,6 @@ void mpc_cross(MPCTIO &tio, yield_t &yield,
     value_t &as_z, value_t as_x, value_t as_y,
     nbits_t nbits)
 {
-    if (tio.is_server()) return;
-
     const value_t mask = MASKBITS(nbits);
     size_t nbytes = BITBYTES(nbits);
     auto [X, Y, Z] = tio.triple();
@@ -71,8 +67,6 @@ void mpc_valuemul(MPCTIO &tio, yield_t &yield,
     value_t &as_z, value_t x,
     nbits_t nbits)
 {
-    if (tio.is_server()) return;
-
     const value_t mask = MASKBITS(nbits);
     size_t nbytes = BITBYTES(nbits);
     auto [X, Z] = tio.halftriple();
@@ -107,8 +101,6 @@ void mpc_flagmult(MPCTIO &tio, yield_t &yield,
     value_t &as_z, bit_t bs_f, value_t as_y,
     nbits_t nbits)
 {
-    if (tio.is_server()) return;
-
     const value_t mask = MASKBITS(nbits);
 
     // Compute additive shares of [(1-2*f0)*y0]*f1 + [(1-2*f1)*y1]*f0
@@ -136,8 +128,6 @@ void mpc_select(MPCTIO &tio, yield_t &yield,
     value_t &as_z, bit_t bs_f, value_t as_x, value_t as_y,
     nbits_t nbits)
 {
-    if (tio.is_server()) return;
-
     const value_t mask = MASKBITS(nbits);
 
     // The desired result is z = x + f * (y-x)
@@ -158,8 +148,6 @@ void mpc_oswap(MPCTIO &tio, yield_t &yield,
     value_t &as_x, value_t &as_y, bit_t bs_f,
     nbits_t nbits)
 {
-    if (tio.is_server()) return;
-
     const value_t mask = MASKBITS(nbits);
 
     // Let s = f*(y-x).  Then the desired result is
@@ -180,8 +168,6 @@ void mpc_xs_to_as(MPCTIO &tio, yield_t &yield,
     value_t &as_x, value_t xs_x,
     nbits_t nbits)
 {
-    if (tio.is_server()) return;
-
     const value_t mask = MASKBITS(nbits);
 
     // We use the fact that for any nbits-bit A and B,

+ 11 - 6
online.cpp

@@ -15,13 +15,16 @@ static void online_test(MPCIO &mpcio, int num_threads, char **args)
     size_t memsize = 13;
 
     MPCTIO tio(mpcio, 0);
+    bool is_server = (mpcio.player == 2);
 
     value_t *A = new value_t[memsize];
 
-    arc4random_buf(A, memsize*sizeof(value_t));
-    A[5] &= 1;
-    A[8] &= 1;
-    printf("A:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, A[i]);
+    if (!is_server) {
+        arc4random_buf(A, memsize*sizeof(value_t));
+        A[5] &= 1;
+        A[8] &= 1;
+        printf("A:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, A[i]);
+    }
     std::vector<coro_t> coroutines;
     coroutines.emplace_back(
         [&](yield_t &yield) {
@@ -44,8 +47,10 @@ static void online_test(MPCIO &mpcio, int num_threads, char **args)
             mpc_xs_to_as(tio, yield, A[12], A[11], nbits);
         });
     run_coroutines(tio, coroutines);
-    printf("\n");
-    printf("A:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, A[i]);
+    if (!is_server) {
+        printf("\n");
+        printf("A:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, A[i]);
+    }
 
     // Check the answers
     if (mpcio.player == 1) {

+ 7 - 43
preproc.cpp

@@ -1,5 +1,4 @@
 #include <vector>
-#include <bsd/stdlib.h> // arc4random_buf
 
 #include "types.hpp"
 #include "preproc.hpp"
@@ -57,8 +56,7 @@ void preprocessing_comp(MPCIO &mpcio, int num_threads, char **args)
 
                     MultTriple T;
                     for (unsigned int i=0; i<num; ++i) {
-                        res = tio.recv_server(&T, sizeof(T));
-                        if (res < sizeof(T)) break;
+                        T = tio.triple();
                         tripfile.write((const char *)&T, sizeof(T));
                     }
                     tripfile.close();
@@ -81,44 +79,6 @@ void preprocessing_comp(MPCIO &mpcio, int num_threads, char **args)
     pool.join();
 }
 
-// Create triples (X0,Y0,Z0),(X1,Y1,Z1) such that
-// (X0*Y1 + Y0*X1) = (Z0+Z1)
-static void create_triples(MPCTIO &stio, unsigned num)
-{
-    for (unsigned int i=0; i<num; ++i) {
-        value_t X0, Y0, Z0, X1, Y1, Z1;
-        arc4random_buf(&X0, sizeof(X0));
-        arc4random_buf(&Y0, sizeof(Y0));
-        arc4random_buf(&Z0, sizeof(Z0));
-        arc4random_buf(&X1, sizeof(X1));
-        arc4random_buf(&Y1, sizeof(Y1));
-        Z1 = X0 * Y1 + X1 * Y0 - Z0;
-        MultTriple T0, T1;
-        T0 = std::make_tuple(X0, Y0, Z0);
-        T1 = std::make_tuple(X1, Y1, Z1);
-        stio.queue_p0(&T0, sizeof(T0));
-        stio.queue_p1(&T1, sizeof(T1));
-    }
-}
-
-// Create half-triples (X0,Z0),(Y1,Z1) such that
-// X0*Y1 = Z0 + Z1
-static void create_halftriples(MPCTIO &stio, unsigned num)
-{
-    for (unsigned int i=0; i<num; ++i) {
-        value_t X0, Z0, Y1, Z1;
-        arc4random_buf(&X0, sizeof(X0));
-        arc4random_buf(&Z0, sizeof(Z0));
-        arc4random_buf(&Y1, sizeof(Y1));
-        Z1 = X0 * Y1 - Z0;
-        HalfTriple H0, H1;
-        H0 = std::make_tuple(X0, Z0);
-        H1 = std::make_tuple(Y1, Z1);
-        stio.queue_p0(&H0, sizeof(H0));
-        stio.queue_p1(&H1, sizeof(H1));
-    }
-}
-
 void preprocessing_server(MPCServerIO &mpcsrvio, int num_threads, char **args)
 {
     boost::asio::thread_pool pool(num_threads);
@@ -145,7 +105,9 @@ void preprocessing_server(MPCServerIO &mpcsrvio, int num_threads, char **args)
                     stio.queue_p1(&typetag, 1);
                     stio.queue_p1(&num, 4);
 
-                    create_triples(stio, num);
+                    for (unsigned int i=0; i<num; ++i) {
+                        stio.triple();
+                    }
                 } else if (!strcmp(type, "h")) {
                     unsigned char typetag = 0x81;
                     stio.queue_p0(&typetag, 1);
@@ -153,7 +115,9 @@ void preprocessing_server(MPCServerIO &mpcsrvio, int num_threads, char **args)
                     stio.queue_p1(&typetag, 1);
                     stio.queue_p1(&num, 4);
 
-                    create_halftriples(stio, num);
+                    for (unsigned int i=0; i<num; ++i) {
+                        stio.halftriple();
+                    }
                 }
                 free(arg);
                 ++threadargs;