瀏覽代碼

Make MPCIO variants for computation peers and the server into subclasses

Ian Goldberg 2 年之前
父節點
當前提交
36c4f65049
共有 3 個文件被更改,包括 99 次插入81 次删除
  1. 95 77
      mpcio.hpp
  2. 1 1
      oblivds.cpp
  3. 3 3
      preproc.cpp

+ 95 - 77
mpcio.hpp

@@ -175,12 +175,22 @@ public:
     }
 };
 
-// A class to represent all of a computation party's IO, either to other
-// parties or to local storage
+// A base class to represent all of a computation peer or server's IO,
+// either to other parties or to local storage (the computation and
+// server cases are separate subclasses below).
 
 struct MPCIO {
     int player;
     bool preprocessing;
+
+    MPCIO(int player, bool preprocessing) :
+        player(player), preprocessing(preprocessing) {}
+};
+
+// A class to represent all of a computation peer's IO, either to other
+// parties or to local storage
+
+struct MPCPeerIO : public MPCIO {
     // We use a deque here instead of a vector because you can't have a
     // vector of a type without a copy constructor (tcp::socket is the
     // culprit), but you can have a deque of those for some reason.
@@ -189,10 +199,11 @@ struct MPCIO {
     std::vector<PreCompStorage<MultTriple>> triples;
     std::vector<PreCompStorage<HalfTriple>> halftriples;
 
-    MPCIO(unsigned player, bool preprocessing,
+    MPCPeerIO(unsigned player, bool preprocessing,
             std::deque<tcp::socket> &peersocks,
-            std::deque<tcp::socket> &serversocks):
-        player(player), preprocessing(preprocessing) {
+            std::deque<tcp::socket> &serversocks) :
+        MPCIO(player, preprocessing)
+    {
         unsigned num_threads = unsigned(peersocks.size());
         for (unsigned i=0; i<num_threads; ++i) {
             triples.emplace_back(player, preprocessing, "triples", i);
@@ -209,6 +220,27 @@ struct MPCIO {
     }
 };
 
+// A class to represent all of the server party's IO, either to
+// computational parties or to local storage
+
+struct MPCServerIO : public MPCIO {
+    std::deque<MPCSingleIO> p0ios;
+    std::deque<MPCSingleIO> p1ios;
+
+    MPCServerIO(bool preprocessing,
+            std::deque<tcp::socket> &p0socks,
+            std::deque<tcp::socket> &p1socks) :
+        MPCIO(2, preprocessing)
+    {
+        for (auto &&sock : p0socks) {
+            p0ios.emplace_back(std::move(sock));
+        }
+        for (auto &&sock : p1socks) {
+            p1ios.emplace_back(std::move(sock));
+        }
+    }
+};
+
 // A handle to one thread's sockets and streams in a MPCIO
 
 class MPCTIO {
@@ -222,116 +254,102 @@ public:
     // Queue up data to the peer or to the server
 
     void queue_peer(const void *data, size_t len) {
-        mpcio.peerios[thread_num].queue(data, len);
+        assert(mpcio.player < 2);
+        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
+        mpcpio.peerios[thread_num].queue(data, len);
     }
 
     void queue_server(const void *data, size_t len) {
-        mpcio.serverios[thread_num].queue(data, len);
+        assert(mpcio.player < 2);
+        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
+        mpcpio.serverios[thread_num].queue(data, len);
     }
 
     // Receive data from the peer or to the server
 
     size_t recv_peer(void *data, size_t len) {
-        return mpcio.peerios[thread_num].recv(data, len);
+        assert(mpcio.player < 2);
+        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
+        return mpcpio.peerios[thread_num].recv(data, len);
     }
 
     size_t recv_server(void *data, size_t len) {
-        return mpcio.serverios[thread_num].recv(data, len);
+        assert(mpcio.player < 2);
+        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
+        return mpcpio.serverios[thread_num].recv(data, len);
     }
 
-    // Send all queued data for this thread
-    void send() {
-        mpcio.peerios[thread_num].send();
-        mpcio.serverios[thread_num].send();
-    }
-
-    // Functions to get precomputed values.  If we're in the online
-    // phase, get them from PreCompStorage.  If we're in the
-    // preprocessing phase, read them from the server.
-    MultTriple triple() {
-        MultTriple val;
-        if (mpcio.preprocessing) {
-            mpcio.serverios[thread_num].recv(boost::asio::buffer(&val, sizeof(val)));
-        } else {
-            mpcio.triples[thread_num].get(val);
-        }
-        return val;
-    }
-
-    HalfTriple halftriple() {
-        HalfTriple val;
-        if (mpcio.preprocessing) {
-            mpcio.serverios[thread_num].recv(boost::asio::buffer(&val, sizeof(val)));
-        } else {
-            mpcio.halftriples[thread_num].get(val);
-        }
-        return val;
-    }
-
-    // Accessors
-    inline int player() { return mpcio.player; }
-    inline bool preprocessing() { return mpcio.preprocessing; }
-};
-
-// A class to represent all of the server party's IO, either to
-// computational parties or to local storage
-
-struct MPCServerIO {
-    bool preprocessing;
-    std::deque<MPCSingleIO> p0ios;
-    std::deque<MPCSingleIO> p1ios;
-
-    MPCServerIO(bool preprocessing,
-            std::deque<tcp::socket> &p0socks,
-            std::deque<tcp::socket> &p1socks) :
-        preprocessing(preprocessing) {
-        for (auto &&sock : p0socks) {
-            p0ios.emplace_back(std::move(sock));
-        }
-        for (auto &&sock : p1socks) {
-            p1ios.emplace_back(std::move(sock));
-        }
-    }
-};
-
-// A handle to one thread's sockets and streams in a MPCServerIO
-
-class MPCServerTIO {
-    int thread_num;
-    MPCServerIO &mpcsrvio;
-
-public:
-    MPCServerTIO(MPCServerIO &mpcsrvio, int thread_num):
-        thread_num(thread_num), mpcsrvio(mpcsrvio) {}
-
     // Queue up data to p0 or p1
 
     void queue_p0(const void *data, size_t len) {
+        assert(mpcio.player == 2);
+        MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
         mpcsrvio.p0ios[thread_num].queue(data, len);
     }
 
     void queue_p1(const void *data, size_t len) {
+        assert(mpcio.player == 2);
+        MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
         mpcsrvio.p1ios[thread_num].queue(data, len);
     }
 
     // Receive data from p0 or p1
 
     size_t recv_p0(void *data, size_t len) {
+        assert(mpcio.player == 2);
+        MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
         return mpcsrvio.p0ios[thread_num].recv(data, len);
     }
 
     size_t recv_p1(void *data, size_t len) {
+        assert(mpcio.player == 2);
+        MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
         return mpcsrvio.p1ios[thread_num].recv(data, len);
     }
 
     // Send all queued data for this thread
     void send() {
-        mpcsrvio.p0ios[thread_num].send();
-        mpcsrvio.p1ios[thread_num].send();
+        if (mpcio.player < 2) {
+            MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
+            mpcpio.peerios[thread_num].send();
+            mpcpio.serverios[thread_num].send();
+        } else {
+            MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
+            mpcsrvio.p0ios[thread_num].send();
+            mpcsrvio.p1ios[thread_num].send();
+        }
+    }
+
+    // Functions to get precomputed values.  If we're in the online
+    // phase, get them from PreCompStorage.  If we're in the
+    // preprocessing phase, read them from the server.
+    MultTriple triple() {
+        assert(mpcio.player < 2);
+        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
+        MultTriple val;
+        if (mpcpio.preprocessing) {
+            mpcpio.serverios[thread_num].recv(boost::asio::buffer(&val, sizeof(val)));
+        } else {
+            mpcpio.triples[thread_num].get(val);
+        }
+        return val;
+    }
+
+    HalfTriple halftriple() {
+        assert(mpcio.player < 2);
+        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
+        HalfTriple val;
+        if (mpcpio.preprocessing) {
+            mpcpio.serverios[thread_num].recv(boost::asio::buffer(&val, sizeof(val)));
+        } else {
+            mpcpio.halftriples[thread_num].get(val);
+        }
+        return val;
     }
 
     // Accessors
-    inline bool preprocessing() { return mpcsrvio.preprocessing; }
+    inline int player() { return mpcio.player; }
+    inline bool preprocessing() { return mpcio.preprocessing; }
 };
 
 // Set up the socket connections between the two computational parties

+ 1 - 1
oblivds.cpp

@@ -25,7 +25,7 @@ static void comp_player_main(boost::asio::io_context &io_context,
     std::deque<tcp::socket> peersocks, serversocks;
     mpcio_setup_computational(player, io_context, p0addr, num_threads,
         peersocks, serversocks);
-    MPCIO mpcio(player, preprocessing, peersocks, serversocks);
+    MPCPeerIO mpcio(player, preprocessing, peersocks, serversocks);
 
     // Queue up the work to be done
     boost::asio::post(io_context, [&]{

+ 3 - 3
preproc.cpp

@@ -83,7 +83,7 @@ void preprocessing_comp(MPCIO &mpcio, int num_threads, char **args)
 
 // Create triples (X0,Y0,Z0),(X1,Y1,Z1) such that
 // (X0*Y1 + Y0*X1) = (Z0+Z1)
-static void create_triples(MPCServerTIO &stio, unsigned num)
+static void create_triples(MPCTIO &stio, unsigned num)
 {
     for (unsigned int i=0; i<num; ++i) {
         value_t X0, Y0, Z0, X1, Y1, Z1;
@@ -103,7 +103,7 @@ static void create_triples(MPCServerTIO &stio, unsigned num)
 
 // Create half-triples (X0,Z0),(Y1,Z1) such that
 // X0*Y1 = Z0 + Z1
-static void create_halftriples(MPCServerTIO &stio, unsigned num)
+static void create_halftriples(MPCTIO &stio, unsigned num)
 {
     for (unsigned int i=0; i<num; ++i) {
         value_t X0, Z0, Y1, Z1;
@@ -125,7 +125,7 @@ void preprocessing_server(MPCServerIO &mpcsrvio, int num_threads, char **args)
     for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
         boost::asio::post(pool, [&mpcsrvio, thread_num, args] {
             char **threadargs = args;
-            MPCServerTIO stio(mpcsrvio, thread_num);
+            MPCTIO stio(mpcsrvio, thread_num);
             while (*threadargs) {
                 char *arg = strdup(*threadargs);
                 char *colon = strchr(arg, ':');