Преглед на файлове

Start filling in some MPC operations

The API could still use some work
Ian Goldberg преди 1 година
родител
ревизия
8077f953e7
променени са 9 файла, в които са добавени 215 реда и са изтрити 2 реда
  1. 3 1
      Makefile
  2. 32 0
      coroutine.hpp
  3. 7 1
      mpcio.hpp
  4. 71 0
      mpcops.cpp
  5. 35 0
      mpcops.hpp
  6. 5 0
      oblivds.cpp
  7. 41 0
      online.cpp
  8. 9 0
      online.hpp
  9. 12 0
      types.hpp

+ 3 - 1
Makefile

@@ -5,7 +5,7 @@ LDFLAGS=-ggdb
 LDLIBS=-lbsd -lboost_system -lboost_context -lboost_thread -lpthread
 
 BIN=oblivds
-OBJS=oblivds.o mpcio.o preproc.o
+OBJS=oblivds.o mpcio.o preproc.o online.o mpcops.o
 
 $(BIN): $(OBJS)
 	g++ $(LDFLAGS) -o $@ $^ $(LDLIBS)
@@ -13,6 +13,8 @@ $(BIN): $(OBJS)
 oblivds.o: preproc.hpp mpcio.hpp types.hpp
 mpcio.o: mpcio.hpp types.hpp
 preproc.o: preproc.hpp mpcio.hpp types.hpp
+online.o: online.hpp mpcops.hpp
+mpcops.o: mpcops.hpp
 
 clean:
 	-rm -f $(BIN) $(OBJS) *.p[01].t*

+ 32 - 0
coroutine.hpp

@@ -0,0 +1,32 @@
+#ifndef __COROUTINE_HPP__
+#define __COROUTINE_HPP__
+
+#include <vector>
+#include <boost/coroutine2/coroutine.hpp>
+
+typedef boost::coroutines2::coroutine<void>::pull_type  coro_t;
+typedef boost::coroutines2::coroutine<void>::push_type  yield_t;
+
+inline void run_coroutines(MPCIO &mpcio, std::vector<coro_t> &coroutines) {
+    // Loop until all the coroutines are finished
+    bool finished = false;
+    while(!finished) {
+        // If this current function is not itself a coroutine (i.e.,
+        // this is the top-level function that launches all the
+        // coroutines), here's where to call send().  Otherwise, call
+        // yield() here to let other coroutines at this level run.
+        mpcio.sendall();
+        finished = true;
+        for (auto &c : coroutines) {
+            // This tests if coroutine c still has work to do (is not
+            // finished)
+            if (c) {
+                finished = false;
+                // Resume coroutine c from the point it yield()ed
+                c();
+            }
+        }
+    }
+}
+
+#endif

+ 7 - 1
mpcio.hpp

@@ -9,7 +9,6 @@
 #include <string>
 
 #include <boost/asio.hpp>
-#include <boost/coroutine2/all.hpp>
 #include <boost/thread.hpp>
 
 #include "types.hpp"
@@ -203,6 +202,13 @@ struct MPCIO {
             peerios.emplace_back(std::move(sock));
         }
     }
+
+    void sendall() {
+        for (auto &p: peerios) {
+            p.send();
+        }
+        serverio.send();
+    }
 };
 
 // A class to represent all of the server party's IO, either to

+ 71 - 0
mpcops.cpp

@@ -0,0 +1,71 @@
+#include "mpcops.hpp"
+
+// as_ denotes additive shares
+// xs_ denotes xor shares
+// bs_ denotes a share of a single bit (which is effectively both an xor
+//     share and an additive share mod 2)
+
+// P0 and P1 both hold additive shares of x and y; compute additive
+// shares of z = x*y. x, y, and z are each at most nbits bits long.
+//
+// Cost:
+// 1 word sent in 1 message
+// consumes 1 MultTriple
+void mpc_mul(MPCIO &mpcio, size_t thread_num, yield_t &yield,
+    value_t &as_z, value_t as_x, value_t as_y,
+    MultTriple &T, nbits_t nbits)
+{
+    value_t mask = MASKBITS(nbits);
+    size_t nbytes = BITBYTES(nbits);
+    auto [X, Y, Z] = T;
+
+    // Send x+X and y+Y
+    value_t blind_x = (as_x + X) & mask;
+    value_t blind_y = (as_y + Y) & mask;
+
+    mpcio.peerios[thread_num].queue(&blind_x, nbytes);
+    mpcio.peerios[thread_num].queue(&blind_y, nbytes);
+
+    yield();
+
+    // Read the peer's x+X and y+Y
+    value_t  peer_blind_x, peer_blind_y;
+    mpcio.peerios[thread_num].recv(&peer_blind_x, nbytes);
+    mpcio.peerios[thread_num].recv(&peer_blind_y, nbytes);
+
+    as_z = ((as_x * (as_y + peer_blind_y)) - Y * peer_blind_x + Z) & mask;
+}
+
+// P0 holds the (complete) value x, P1 holds the (complete) value y;
+// compute additive shares of z = x*y.  x, y, and z are each at most
+// nbits bits long.  The parameter is called x, but P1 will pass y
+// there.
+//
+// Cost:
+// 1 word sent in 1 message
+// consumes 1 HalfTriple
+void mpc_valuemul(MPCIO &mpcio, size_t thread_num, yield_t &yield,
+    value_t &as_z, value_t x,
+    HalfTriple &H, nbits_t nbits)
+{
+    value_t mask = MASKBITS(nbits);
+    size_t nbytes = BITBYTES(nbits);
+    auto [X, Z] = H;
+
+    // Send x+X
+    value_t blind_x = (x + X) & mask;
+
+    mpcio.peerios[thread_num].queue(&blind_x, nbytes);
+
+    yield();
+
+    // Read the peer's y+Y
+    value_t  peer_blind_y;
+    mpcio.peerios[thread_num].recv(&peer_blind_y, nbytes);
+
+    if (mpcio.player == 0) {
+        as_z = ((x * peer_blind_y) + Z) & mask;
+    } else if (mpcio.player == 1) {
+        as_z = ((-X * peer_blind_y) + Z) & mask;
+    }
+}

+ 35 - 0
mpcops.hpp

@@ -0,0 +1,35 @@
+#ifndef __MPCOPS_HPP__
+#define __MPCOPS_HPP__
+
+#include "types.hpp"
+#include "mpcio.hpp"
+#include "coroutine.hpp"
+
+// as_ denotes additive shares
+// xs_ denotes xor shares
+// bs_ denotes a share of a single bit (which is effectively both an xor
+//     share and an additive share mod 2)
+
+// P0 and P1 both hold additive shares of x and y; compute additive
+// shares of z = x*y. x, y, and z are each at most nbits bits long.
+//
+// Cost:
+// 1 word sent in 1 message
+// consumes 1 MultTriple
+void mpc_mul(MPCIO &mpcio, size_t thread_num, yield_t &yield,
+    value_t &as_z, value_t as_x, value_t as_y,
+    MultTriple &T, nbits_t nbits = VALUE_BITS);
+
+// P0 holds the (complete) value x, P1 holds the (complete) value y;
+// compute additive shares of z = x*y.  x, y, and z are each at most
+// nbits bits long.  The parameter is called x, but P1 will pass y
+// there.
+//
+// Cost:
+// 1 word sent in 1 message
+// consumes 1 HalfTriple
+void mpc_valuemul(MPCIO &mpcio, size_t thread_num, yield_t &yield,
+    value_t &as_z, value_t x,
+    HalfTriple &H, nbits_t nbits = VALUE_BITS);
+
+#endif

+ 5 - 0
oblivds.cpp

@@ -3,6 +3,7 @@
 
 #include "mpcio.hpp"
 #include "preproc.hpp"
+#include "online.hpp"
 
 static void usage(const char *progname)
 {
@@ -31,6 +32,8 @@ static void comp_player_main(boost::asio::io_context &io_context,
     boost::asio::post(io_context, [&]{
         if (preprocessing) {
             preprocessing_comp(mpcio, num_threads, args);
+        } else {
+            online_comp(mpcio, num_threads, args);
         }
     });
 
@@ -53,6 +56,8 @@ static void server_player_main(boost::asio::io_context &io_context,
     boost::asio::post(io_context, [&]{
         if (preprocessing) {
             preprocessing_server(mpcserverio, args);
+        } else {
+            online_server(mpcserverio, args);
         }
     });
 

+ 41 - 0
online.cpp

@@ -0,0 +1,41 @@
+#include <bsd/stdlib.h> // arc4random_buf
+
+#include "online.hpp"
+#include "mpcops.hpp"
+
+
+void online_comp(MPCIO &mpcio, int num_threads, char **args)
+{
+    nbits_t nbits = VALUE_BITS;
+
+    if (*args) {
+        nbits = atoi(*args);
+    }
+
+    value_t A[5];
+
+    arc4random_buf(A, 3*sizeof(value_t));
+    std::cout << A[0] << "\n";
+    std::cout << A[1] << "\n";
+    std::cout << A[2] << "\n";
+    std::vector<coro_t> coroutines;
+    coroutines.emplace_back(
+        [&](yield_t &yield) {
+            MultTriple T;
+            mpcio.triples[0].get(T);
+            mpc_mul(mpcio, 0, yield, A[3], A[0], A[1], T, nbits);
+        });
+    coroutines.emplace_back(
+        [&](yield_t &yield) {
+            HalfTriple H;
+            mpcio.halftriples[0].get(H);
+            mpc_valuemul(mpcio, 0, yield, A[4], A[2], H, nbits);
+        });
+    run_coroutines(mpcio, coroutines);
+    std::cout << A[3] << "\n";
+    std::cout << A[4] << "\n";
+}
+
+void online_server(MPCServerIO &mpcio, char **args)
+{
+}

+ 9 - 0
online.hpp

@@ -0,0 +1,9 @@
+#ifndef __ONLINE_HPP__
+#define __ONLINE_HPP__
+
+#include "mpcio.hpp"
+
+void online_comp(MPCIO &mpcio, int num_threads, char **args);
+void online_server(MPCServerIO &mpcio, char **args);
+
+#endif

+ 12 - 0
types.hpp

@@ -49,6 +49,18 @@ typedef uint64_t address_t;
 
 typedef bool bit_t;
 
+// Counts of the number of bits in a value are of this type, which must
+// be large enough to store the _value_ VALUE_BITS
+typedef uint8_t nbits_t;
+
+// Convert a number of bits to the number of bytes required to store (or
+// more to the point, send) them.
+#define BITBYTES(nbits) (((nbits)+7)>>3)
+
+// A mask of this many bits; the test is to prevent 1<<nbits from
+// overflowing if nbits == VALUE_BITS
+#define MASKBITS(nbits) (((nbits) < VALUE_BITS) ? (value_t(1)<<(nbits))-1 : ~0)
+
 // A multiplication triple is a triple (X0,Y0,Z0) held by P0 (and
 // correspondingly (X1,Y1,Z1) held by P1), with all values random,
 // but subject to the relation that X0*Y1 + Y0*X1 = Z0+Z1