浏览代码

Improve the API a little bit

Ian Goldberg 2 年之前
父节点
当前提交
995b8c3fdf
共有 4 个文件被更改,包括 34 次插入13 次删除
  1. 26 1
      mpcio.hpp
  2. 4 4
      mpcops.cpp
  3. 2 2
      mpcops.hpp
  4. 2 6
      online.cpp

+ 26 - 1
mpcio.hpp

@@ -180,6 +180,7 @@ public:
 
 struct MPCIO {
     int player;
+    bool preprocessing;
     // We use a deque here instead of a vector because you can't have a
     // vector of a type without a copy constructor (tcp::socket is the
     // culprit), but you can have a deque of those for some reason.
@@ -190,7 +191,8 @@ struct MPCIO {
 
     MPCIO(unsigned player, bool preprocessing,
             std::deque<tcp::socket> &peersocks, tcp::socket &&serversock) :
-        player(player), serverio(std::move(serversock)) {
+        player(player), preprocessing(preprocessing),
+        serverio(std::move(serversock)) {
         unsigned num_threads = unsigned(peersocks.size());
         for (unsigned i=0; i<num_threads; ++i) {
             triples.emplace_back(player, preprocessing, "triples", i);
@@ -209,6 +211,29 @@ struct MPCIO {
         }
         serverio.send();
     }
+
+    // Functions to get precomputed values.  If we're in the online
+    // phase, get them from PreCompStorage.  If we're in the
+    // preprocessing phase, read them from the server.
+    MultTriple triple(unsigned thread_num) {
+        MultTriple val;
+        if (preprocessing) {
+            serverio.recv(boost::asio::buffer(&val, sizeof(val)));
+        } else {
+            triples[thread_num].get(val);
+        }
+        return val;
+    }
+
+    HalfTriple halftriple(unsigned thread_num) {
+        HalfTriple val;
+        if (preprocessing) {
+            serverio.recv(boost::asio::buffer(&val, sizeof(val)));
+        } else {
+            halftriples[thread_num].get(val);
+        }
+        return val;
+    }
 };
 
 // A class to represent all of the server party's IO, either to

+ 4 - 4
mpcops.cpp

@@ -13,11 +13,11 @@
 // consumes 1 MultTriple
 void mpc_mul(MPCIO &mpcio, size_t thread_num, yield_t &yield,
     value_t &as_z, value_t as_x, value_t as_y,
-    MultTriple &T, nbits_t nbits)
+    nbits_t nbits)
 {
     value_t mask = MASKBITS(nbits);
     size_t nbytes = BITBYTES(nbits);
-    auto [X, Y, Z] = T;
+    auto [X, Y, Z] = mpcio.triple(thread_num);
 
     // Send x+X and y+Y
     value_t blind_x = (as_x + X) & mask;
@@ -46,11 +46,11 @@ void mpc_mul(MPCIO &mpcio, size_t thread_num, yield_t &yield,
 // consumes 1 HalfTriple
 void mpc_valuemul(MPCIO &mpcio, size_t thread_num, yield_t &yield,
     value_t &as_z, value_t x,
-    HalfTriple &H, nbits_t nbits)
+    nbits_t nbits)
 {
     value_t mask = MASKBITS(nbits);
     size_t nbytes = BITBYTES(nbits);
-    auto [X, Z] = H;
+    auto [X, Z] = mpcio.halftriple(thread_num);
 
     // Send x+X
     value_t blind_x = (x + X) & mask;

+ 2 - 2
mpcops.hpp

@@ -18,7 +18,7 @@
 // consumes 1 MultTriple
 void mpc_mul(MPCIO &mpcio, size_t thread_num, yield_t &yield,
     value_t &as_z, value_t as_x, value_t as_y,
-    MultTriple &T, nbits_t nbits = VALUE_BITS);
+    nbits_t nbits = VALUE_BITS);
 
 // P0 holds the (complete) value x, P1 holds the (complete) value y;
 // compute additive shares of z = x*y.  x, y, and z are each at most
@@ -30,6 +30,6 @@ void mpc_mul(MPCIO &mpcio, size_t thread_num, yield_t &yield,
 // consumes 1 HalfTriple
 void mpc_valuemul(MPCIO &mpcio, size_t thread_num, yield_t &yield,
     value_t &as_z, value_t x,
-    HalfTriple &H, nbits_t nbits = VALUE_BITS);
+    nbits_t nbits = VALUE_BITS);
 
 #endif

+ 2 - 6
online.cpp

@@ -21,15 +21,11 @@ void online_comp(MPCIO &mpcio, int num_threads, char **args)
     std::vector<coro_t> coroutines;
     coroutines.emplace_back(
         [&](yield_t &yield) {
-            MultTriple T;
-            mpcio.triples[0].get(T);
-            mpc_mul(mpcio, 0, yield, A[3], A[0], A[1], T, nbits);
+            mpc_mul(mpcio, 0, yield, A[3], A[0], A[1], nbits);
         });
     coroutines.emplace_back(
         [&](yield_t &yield) {
-            HalfTriple H;
-            mpcio.halftriples[0].get(H);
-            mpc_valuemul(mpcio, 0, yield, A[4], A[2], H, nbits);
+            mpc_valuemul(mpcio, 0, yield, A[4], A[2], nbits);
         });
     run_coroutines(mpcio, coroutines);
     std::cout << A[3] << "\n";