Parcourir la source

Merge the computation peer and server online execution paths

This will be important when we involve the server in the online
porttion of the computation
Ian Goldberg il y a 2 ans
Parent
commit
bba5036964
5 fichiers modifiés avec 31 ajouts et 8 suppressions
  1. 1 0
      mpcio.hpp
  2. 14 0
      mpcops.cpp
  3. 2 2
      oblivds.cpp
  4. 13 4
      online.cpp
  5. 1 2
      online.hpp

+ 1 - 0
mpcio.hpp

@@ -350,6 +350,7 @@ public:
     // Accessors
     inline int player() { return mpcio.player; }
     inline bool preprocessing() { return mpcio.preprocessing; }
+    inline bool is_server() { return mpcio.player == 2; }
 };
 
 // Set up the socket connections between the two computational parties

+ 14 - 0
mpcops.cpp

@@ -16,6 +16,8 @@ void mpc_mul(MPCTIO &tio, yield_t &yield,
     value_t &as_z, value_t as_x, value_t as_y,
     nbits_t nbits)
 {
+    if (tio.is_server()) return;
+
     const value_t mask = MASKBITS(nbits);
     // Compute as_z to be an additive share of (x0*y1+y0*x1)
     mpc_cross(tio, yield, as_z, as_x, as_y, nbits);
@@ -34,6 +36,8 @@ void mpc_cross(MPCTIO &tio, yield_t &yield,
     value_t &as_z, value_t as_x, value_t as_y,
     nbits_t nbits)
 {
+    if (tio.is_server()) return;
+
     const value_t mask = MASKBITS(nbits);
     size_t nbytes = BITBYTES(nbits);
     auto [X, Y, Z] = tio.triple();
@@ -67,6 +71,8 @@ void mpc_valuemul(MPCTIO &tio, yield_t &yield,
     value_t &as_z, value_t x,
     nbits_t nbits)
 {
+    if (tio.is_server()) return;
+
     const value_t mask = MASKBITS(nbits);
     size_t nbytes = BITBYTES(nbits);
     auto [X, Z] = tio.halftriple();
@@ -101,6 +107,8 @@ void mpc_flagmult(MPCTIO &tio, yield_t &yield,
     value_t &as_z, bit_t bs_f, value_t as_y,
     nbits_t nbits)
 {
+    if (tio.is_server()) return;
+
     const value_t mask = MASKBITS(nbits);
 
     // Compute additive shares of [(1-2*f0)*y0]*f1 + [(1-2*f1)*y1]*f0
@@ -128,6 +136,8 @@ void mpc_select(MPCTIO &tio, yield_t &yield,
     value_t &as_z, bit_t bs_f, value_t as_x, value_t as_y,
     nbits_t nbits)
 {
+    if (tio.is_server()) return;
+
     const value_t mask = MASKBITS(nbits);
 
     // The desired result is z = x + f * (y-x)
@@ -148,6 +158,8 @@ void mpc_oswap(MPCTIO &tio, yield_t &yield,
     value_t &as_x, value_t &as_y, bit_t bs_f,
     nbits_t nbits)
 {
+    if (tio.is_server()) return;
+
     const value_t mask = MASKBITS(nbits);
 
     // Let s = f*(y-x).  Then the desired result is
@@ -168,6 +180,8 @@ void mpc_xs_to_as(MPCTIO &tio, yield_t &yield,
     value_t &as_x, value_t xs_x,
     nbits_t nbits)
 {
+    if (tio.is_server()) return;
+
     const value_t mask = MASKBITS(nbits);
 
     // We use the fact that for any nbits-bit A and B,

+ 2 - 2
oblivds.cpp

@@ -32,7 +32,7 @@ static void comp_player_main(boost::asio::io_context &io_context,
         if (preprocessing) {
             preprocessing_comp(mpcio, num_threads, args);
         } else {
-            online_comp(mpcio, num_threads, args);
+            online_main(mpcio, num_threads, args);
         }
     });
 
@@ -57,7 +57,7 @@ static void server_player_main(boost::asio::io_context &io_context,
         if (preprocessing) {
             preprocessing_server(mpcserverio, num_threads, args);
         } else {
-            online_server(mpcserverio, num_threads, args);
+            online_main(mpcserverio, num_threads, args);
         }
     });
 

+ 13 - 4
online.cpp

@@ -4,7 +4,7 @@
 #include "mpcops.hpp"
 
 
-void online_comp(MPCIO &mpcio, int num_threads, char **args)
+static void online_test(MPCIO &mpcio, int num_threads, char **args)
 {
     nbits_t nbits = VALUE_BITS;
 
@@ -48,10 +48,10 @@ void online_comp(MPCIO &mpcio, int num_threads, char **args)
     printf("A:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, A[i]);
 
     // Check the answers
-    if (mpcio.player) {
+    if (mpcio.player == 1) {
         tio.queue_peer(A, memsize*sizeof(value_t));
         tio.send();
-    } else {
+    } else if (mpcio.player == 0) {
         value_t *B = new value_t[memsize];
         value_t *S = new value_t[memsize];
         tio.recv_peer(B, memsize*sizeof(value_t));
@@ -66,6 +66,15 @@ void online_comp(MPCIO &mpcio, int num_threads, char **args)
     delete[] A;
 }
 
-void online_server(MPCServerIO &mpcio, int num_threads, char **args)
+void online_main(MPCIO &mpcio, int num_threads, char **args)
 {
+    if (!*args) {
+        std::cerr << "Mode is required as the first argument when not preprocessing.\n";
+        return;
+    } else if (!strcmp(*args, "test")) {
+        ++args;
+        online_test(mpcio, num_threads, args);
+    } else {
+        std::cerr << "Unknown mode " << *args << "\n";
+    }
 }

+ 1 - 2
online.hpp

@@ -3,7 +3,6 @@
 
 #include "mpcio.hpp"
 
-void online_comp(MPCIO &mpcio, int num_threads, char **args);
-void online_server(MPCServerIO &mpcio, int num_threads, char **args);
+void online_main(MPCIO &mpcio, int num_threads, char **args);
 
 #endif