Browse Source

Merge the computation peer and server online execution paths

This will be important when we involve the server in the online
porttion of the computation
Ian Goldberg 2 years ago
parent
commit
bba5036964
5 changed files with 31 additions and 8 deletions
  1. 1 0
      mpcio.hpp
  2. 14 0
      mpcops.cpp
  3. 2 2
      oblivds.cpp
  4. 13 4
      online.cpp
  5. 1 2
      online.hpp

+ 1 - 0
mpcio.hpp

@@ -350,6 +350,7 @@ public:
     // Accessors
     inline int player() { return mpcio.player; }
     inline bool preprocessing() { return mpcio.preprocessing; }
+    inline bool is_server() { return mpcio.player == 2; }
 };
 
 // Set up the socket connections between the two computational parties

+ 14 - 0
mpcops.cpp

@@ -16,6 +16,8 @@ void mpc_mul(MPCTIO &tio, yield_t &yield,
     value_t &as_z, value_t as_x, value_t as_y,
     nbits_t nbits)
 {
+    if (tio.is_server()) return;
+
     const value_t mask = MASKBITS(nbits);
     // Compute as_z to be an additive share of (x0*y1+y0*x1)
     mpc_cross(tio, yield, as_z, as_x, as_y, nbits);
@@ -34,6 +36,8 @@ void mpc_cross(MPCTIO &tio, yield_t &yield,
     value_t &as_z, value_t as_x, value_t as_y,
     nbits_t nbits)
 {
+    if (tio.is_server()) return;
+
     const value_t mask = MASKBITS(nbits);
     size_t nbytes = BITBYTES(nbits);
     auto [X, Y, Z] = tio.triple();
@@ -67,6 +71,8 @@ void mpc_valuemul(MPCTIO &tio, yield_t &yield,
     value_t &as_z, value_t x,
     nbits_t nbits)
 {
+    if (tio.is_server()) return;
+
     const value_t mask = MASKBITS(nbits);
     size_t nbytes = BITBYTES(nbits);
     auto [X, Z] = tio.halftriple();
@@ -101,6 +107,8 @@ void mpc_flagmult(MPCTIO &tio, yield_t &yield,
     value_t &as_z, bit_t bs_f, value_t as_y,
     nbits_t nbits)
 {
+    if (tio.is_server()) return;
+
     const value_t mask = MASKBITS(nbits);
 
     // Compute additive shares of [(1-2*f0)*y0]*f1 + [(1-2*f1)*y1]*f0
@@ -128,6 +136,8 @@ void mpc_select(MPCTIO &tio, yield_t &yield,
     value_t &as_z, bit_t bs_f, value_t as_x, value_t as_y,
     nbits_t nbits)
 {
+    if (tio.is_server()) return;
+
     const value_t mask = MASKBITS(nbits);
 
     // The desired result is z = x + f * (y-x)
@@ -148,6 +158,8 @@ void mpc_oswap(MPCTIO &tio, yield_t &yield,
     value_t &as_x, value_t &as_y, bit_t bs_f,
     nbits_t nbits)
 {
+    if (tio.is_server()) return;
+
     const value_t mask = MASKBITS(nbits);
 
     // Let s = f*(y-x).  Then the desired result is
@@ -168,6 +180,8 @@ void mpc_xs_to_as(MPCTIO &tio, yield_t &yield,
     value_t &as_x, value_t xs_x,
     nbits_t nbits)
 {
+    if (tio.is_server()) return;
+
     const value_t mask = MASKBITS(nbits);
 
     // We use the fact that for any nbits-bit A and B,

+ 2 - 2
oblivds.cpp

@@ -32,7 +32,7 @@ static void comp_player_main(boost::asio::io_context &io_context,
         if (preprocessing) {
             preprocessing_comp(mpcio, num_threads, args);
         } else {
-            online_comp(mpcio, num_threads, args);
+            online_main(mpcio, num_threads, args);
         }
     });
 
@@ -57,7 +57,7 @@ static void server_player_main(boost::asio::io_context &io_context,
         if (preprocessing) {
             preprocessing_server(mpcserverio, num_threads, args);
         } else {
-            online_server(mpcserverio, num_threads, args);
+            online_main(mpcserverio, num_threads, args);
         }
     });
 

+ 13 - 4
online.cpp

@@ -4,7 +4,7 @@
 #include "mpcops.hpp"
 
 
-void online_comp(MPCIO &mpcio, int num_threads, char **args)
+static void online_test(MPCIO &mpcio, int num_threads, char **args)
 {
     nbits_t nbits = VALUE_BITS;
 
@@ -48,10 +48,10 @@ void online_comp(MPCIO &mpcio, int num_threads, char **args)
     printf("A:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, A[i]);
 
     // Check the answers
-    if (mpcio.player) {
+    if (mpcio.player == 1) {
         tio.queue_peer(A, memsize*sizeof(value_t));
         tio.send();
-    } else {
+    } else if (mpcio.player == 0) {
         value_t *B = new value_t[memsize];
         value_t *S = new value_t[memsize];
         tio.recv_peer(B, memsize*sizeof(value_t));
@@ -66,6 +66,15 @@ void online_comp(MPCIO &mpcio, int num_threads, char **args)
     delete[] A;
 }
 
-void online_server(MPCServerIO &mpcio, int num_threads, char **args)
+void online_main(MPCIO &mpcio, int num_threads, char **args)
 {
+    if (!*args) {
+        std::cerr << "Mode is required as the first argument when not preprocessing.\n";
+        return;
+    } else if (!strcmp(*args, "test")) {
+        ++args;
+        online_test(mpcio, num_threads, args);
+    } else {
+        std::cerr << "Unknown mode " << *args << "\n";
+    }
 }

+ 1 - 2
online.hpp

@@ -3,7 +3,6 @@
 
 #include "mpcio.hpp"
 
-void online_comp(MPCIO &mpcio, int num_threads, char **args);
-void online_server(MPCServerIO &mpcio, int num_threads, char **args);
+void online_main(MPCIO &mpcio, int num_threads, char **args);
 
 #endif