Pārlūkot izejas kodu

The server now also opens one socket per thread to each computational peer

Also further API improvements
Ian Goldberg 1 gadu atpakaļ
vecāks
revīzija
33a6a54fe4
10 mainītis faili ar 307 papildinājumiem un 163 dzēšanām
  1. 2 2
      coroutine.hpp
  2. 68 17
      mpcio.cpp
  3. 115 26
      mpcio.hpp
  4. 12 12
      mpcops.cpp
  5. 2 2
      mpcops.hpp
  6. 13 13
      oblivds.cpp
  7. 9 7
      online.cpp
  8. 1 1
      online.hpp
  9. 84 82
      preproc.cpp
  10. 1 1
      preproc.hpp

+ 2 - 2
coroutine.hpp

@@ -7,7 +7,7 @@
 typedef boost::coroutines2::coroutine<void>::pull_type  coro_t;
 typedef boost::coroutines2::coroutine<void>::push_type  yield_t;
 
-inline void run_coroutines(MPCIO &mpcio, std::vector<coro_t> &coroutines) {
+inline void run_coroutines(MPCTIO &tio, std::vector<coro_t> &coroutines) {
     // Loop until all the coroutines are finished
     bool finished = false;
     while(!finished) {
@@ -15,7 +15,7 @@ inline void run_coroutines(MPCIO &mpcio, std::vector<coro_t> &coroutines) {
         // this is the top-level function that launches all the
         // coroutines), here's where to call send().  Otherwise, call
         // yield() here to let other coroutines at this level run.
-        mpcio.sendall();
+        tio.send();
         finished = true;
         for (auto &c : coroutines) {
             // This tests if coroutine c still has work to do (is not

+ 68 - 17
mpcio.cpp

@@ -13,7 +13,8 @@ void mpcio_setup_computational(unsigned player,
     boost::asio::io_context &io_context,
     const char *p0addr,  // can be NULL when player=0
     int num_threads,
-    std::deque<tcp::socket> &peersocks, tcp::socket &serversock)
+    std::deque<tcp::socket> &peersocks,
+    std::deque<tcp::socket> &serversocks)
 {
     if (player == 0) {
         // Listen for connections from P1 and from P2
@@ -22,8 +23,11 @@ void mpcio_setup_computational(unsigned player,
         tcp::acceptor acceptor_p2(io_context,
             tcp::endpoint(tcp::v4(), port_p2_p0));
 
+        peersocks.clear();
+        serversocks.clear();
         for (int i=0;i<num_threads;++i) {
             peersocks.emplace_back(io_context);
+            serversocks.emplace_back(io_context);
         }
         for (int i=0;i<num_threads;++i) {
             tcp::socket peersock = acceptor_p1.accept();
@@ -38,7 +42,19 @@ void mpcio_setup_computational(unsigned player,
                 peersocks[thread_num] = std::move(peersock);
             }
         }
-        serversock = acceptor_p2.accept();
+        for (int i=0;i<num_threads;++i) {
+            tcp::socket serversock = acceptor_p2.accept();
+            // Read 2 bytes from the socket, which will be the thread
+            // number
+            unsigned short thread_num;
+            boost::asio::read(serversock,
+                boost::asio::buffer(&thread_num, sizeof(thread_num)));
+            if (thread_num >= num_threads) {
+                std::cerr << "Received bad thread number from server\n";
+            } else {
+                serversocks[thread_num] = std::move(serversock);
+            }
+        }
     } else if (player == 1) {
         // Listen for connections from P2, make num_threads connections to P0
         tcp::acceptor acceptor_p2(io_context,
@@ -47,6 +63,10 @@ void mpcio_setup_computational(unsigned player,
         tcp::resolver resolver(io_context);
         boost::system::error_code err;
         peersocks.clear();
+        serversocks.clear();
+        for (int i=0;i<num_threads;++i) {
+            serversocks.emplace_back(io_context);
+        }
         for (unsigned short thread_num = 0; thread_num < num_threads; ++thread_num) {
             tcp::socket peersock(io_context);
             while(1) {
@@ -62,31 +82,62 @@ void mpcio_setup_computational(unsigned player,
                 boost::asio::buffer(&thread_num, sizeof(thread_num)));
             peersocks.push_back(std::move(peersock));
         }
-        serversock = acceptor_p2.accept();
+        for (int i=0;i<num_threads;++i) {
+            tcp::socket serversock = acceptor_p2.accept();
+            // Read 2 bytes from the socket, which will be the thread
+            // number
+            unsigned short thread_num;
+            boost::asio::read(serversock,
+                boost::asio::buffer(&thread_num, sizeof(thread_num)));
+            if (thread_num >= num_threads) {
+                std::cerr << "Received bad thread number from server\n";
+            } else {
+                serversocks[thread_num] = std::move(serversock);
+            }
+        }
     } else {
         std::cerr << "Invalid player number passed to mpcio_setup_computational\n";
     }
 }
 
 void mpcio_setup_server(boost::asio::io_context &io_context,
-    const char *p0addr, const char *p1addr,
-    tcp::socket &p0sock, tcp::socket &p1sock)
+    const char *p0addr, const char *p1addr, int num_threads,
+    std::deque<tcp::socket> &p0socks,
+    std::deque<tcp::socket> &p1socks)
 {
     // Make connections to P0 and P1
     tcp::resolver resolver(io_context);
     boost::system::error_code err;
-    while(1) {
-        boost::asio::connect(p0sock,
-            resolver.resolve(p0addr, std::to_string(port_p2_p0)), err);
-        if (!err) break;
-        std::cerr << "Connection to p0 refused, will retry.\n";
-        sleep(1);
+    p0socks.clear();
+    p1socks.clear();
+    for (unsigned short thread_num = 0; thread_num < num_threads; ++thread_num) {
+        tcp::socket p0sock(io_context);
+        while(1) {
+            boost::asio::connect(p0sock,
+                resolver.resolve(p0addr, std::to_string(port_p2_p0)), err);
+            if (!err) break;
+            std::cerr << "Connection to p0 refused, will retry.\n";
+            sleep(1);
+        }
+        // Write 2 bytes to the socket indicating which thread
+        // number this socket is for
+        boost::asio::write(p0sock,
+            boost::asio::buffer(&thread_num, sizeof(thread_num)));
+        p0socks.push_back(std::move(p0sock));
     }
-    while(1) {
-        boost::asio::connect(p1sock,
-            resolver.resolve(p1addr, std::to_string(port_p2_p1)), err);
-        if (!err) break;
-        std::cerr << "Connection to p1 refused, will retry.\n";
-        sleep(1);
+    for (unsigned short thread_num = 0; thread_num < num_threads; ++thread_num) {
+        tcp::socket p1sock(io_context);
+        while(1) {
+            boost::asio::connect(p1sock,
+                resolver.resolve(p1addr, std::to_string(port_p2_p1)), err);
+            if (!err) break;
+            std::cerr << "Connection to p1 refused, will retry.\n";
+            sleep(1);
+        }
+        // Write 2 bytes to the socket indicating which thread
+        // number this socket is for
+        boost::asio::write(p1sock,
+            boost::asio::buffer(&thread_num, sizeof(thread_num)));
+        p1socks.push_back(std::move(p1sock));
     }
 }

+ 115 - 26
mpcio.hpp

@@ -185,14 +185,14 @@ struct MPCIO {
     // vector of a type without a copy constructor (tcp::socket is the
     // culprit), but you can have a deque of those for some reason.
     std::deque<MPCSingleIO> peerios;
-    MPCSingleIO serverio;
+    std::deque<MPCSingleIO> serverios;
     std::vector<PreCompStorage<MultTriple>> triples;
     std::vector<PreCompStorage<HalfTriple>> halftriples;
 
     MPCIO(unsigned player, bool preprocessing,
-            std::deque<tcp::socket> &peersocks, tcp::socket &&serversock) :
-        player(player), preprocessing(preprocessing),
-        serverio(std::move(serversock)) {
+            std::deque<tcp::socket> &peersocks,
+            std::deque<tcp::socket> &serversocks):
+        player(player), preprocessing(preprocessing) {
         unsigned num_threads = unsigned(peersocks.size());
         for (unsigned i=0; i<num_threads; ++i) {
             triples.emplace_back(player, preprocessing, "triples", i);
@@ -203,49 +203,136 @@ struct MPCIO {
         for (auto &&sock : peersocks) {
             peerios.emplace_back(std::move(sock));
         }
+        for (auto &&sock : serversocks) {
+            serverios.emplace_back(std::move(sock));
+        }
     }
+};
 
-    void sendall() {
-        for (auto &p: peerios) {
-            p.send();
-        }
-        serverio.send();
+// A handle to one thread's sockets and streams in a MPCIO
+
+class MPCTIO {
+    int thread_num;
+    MPCIO &mpcio;
+
+public:
+    MPCTIO(MPCIO &mpcio, int thread_num):
+        thread_num(thread_num), mpcio(mpcio) {}
+
+    // Queue up data to the peer or to the server
+
+    void queue_peer(const void *data, size_t len) {
+        mpcio.peerios[thread_num].queue(data, len);
+    }
+
+    void queue_server(const void *data, size_t len) {
+        mpcio.serverios[thread_num].queue(data, len);
+    }
+
+    // Receive data from the peer or to the server
+
+    size_t recv_peer(void *data, size_t len) {
+        return mpcio.peerios[thread_num].recv(data, len);
+    }
+
+    size_t recv_server(void *data, size_t len) {
+        return mpcio.serverios[thread_num].recv(data, len);
+    }
+
+    // Send all queued data for this thread
+    void send() {
+        mpcio.peerios[thread_num].send();
+        mpcio.serverios[thread_num].send();
     }
 
     // Functions to get precomputed values.  If we're in the online
     // phase, get them from PreCompStorage.  If we're in the
     // preprocessing phase, read them from the server.
-    MultTriple triple(unsigned thread_num) {
+    MultTriple triple() {
         MultTriple val;
-        if (preprocessing) {
-            serverio.recv(boost::asio::buffer(&val, sizeof(val)));
+        if (mpcio.preprocessing) {
+            mpcio.serverios[thread_num].recv(boost::asio::buffer(&val, sizeof(val)));
         } else {
-            triples[thread_num].get(val);
+            mpcio.triples[thread_num].get(val);
         }
         return val;
     }
 
-    HalfTriple halftriple(unsigned thread_num) {
+    HalfTriple halftriple() {
         HalfTriple val;
-        if (preprocessing) {
-            serverio.recv(boost::asio::buffer(&val, sizeof(val)));
+        if (mpcio.preprocessing) {
+            mpcio.serverios[thread_num].recv(boost::asio::buffer(&val, sizeof(val)));
         } else {
-            halftriples[thread_num].get(val);
+            mpcio.halftriples[thread_num].get(val);
         }
         return val;
     }
+
+    // Accessors
+    inline int player() { return mpcio.player; }
+    inline bool preprocessing() { return mpcio.preprocessing; }
 };
 
 // A class to represent all of the server party's IO, either to
 // computational parties or to local storage
 
 struct MPCServerIO {
-    MPCSingleIO p0io;
-    MPCSingleIO p1io;
+    bool preprocessing;
+    std::deque<MPCSingleIO> p0ios;
+    std::deque<MPCSingleIO> p1ios;
+
+    MPCServerIO(bool preprocessing,
+            std::deque<tcp::socket> &p0socks,
+            std::deque<tcp::socket> &p1socks) :
+        preprocessing(preprocessing) {
+        for (auto &&sock : p0socks) {
+            p0ios.emplace_back(std::move(sock));
+        }
+        for (auto &&sock : p1socks) {
+            p1ios.emplace_back(std::move(sock));
+        }
+    }
+};
+
+// A handle to one thread's sockets and streams in a MPCServerIO
+
+class MPCServerTIO {
+    int thread_num;
+    MPCServerIO &mpcsrvio;
+
+public:
+    MPCServerTIO(MPCServerIO &mpcsrvio, int thread_num):
+        thread_num(thread_num), mpcsrvio(mpcsrvio) { std::cerr <<
+        "Creating " << thread_num << "\n";}
+
+    // Queue up data to p0 or p1
+
+    void queue_p0(const void *data, size_t len) {
+        mpcsrvio.p0ios[thread_num].queue(data, len);
+    }
+
+    void queue_p1(const void *data, size_t len) {
+        mpcsrvio.p1ios[thread_num].queue(data, len);
+    }
+
+    // Receive data from p0 or p1
+
+    size_t recv_p0(void *data, size_t len) {
+        return mpcsrvio.p0ios[thread_num].recv(data, len);
+    }
+
+    size_t recv_p1(void *data, size_t len) {
+        return mpcsrvio.p1ios[thread_num].recv(data, len);
+    }
+
+    // Send all queued data for this thread
+    void send() {
+        mpcsrvio.p0ios[thread_num].send();
+        mpcsrvio.p1ios[thread_num].send();
+    }
 
-    MPCServerIO(bool preprocessing, tcp::socket &&p0sock,
-            tcp::socket &&p1sock) :
-        p0io(std::move(p0sock)), p1io(std::move(p1sock)) {}
+    // Accessors
+    inline bool preprocessing() { return mpcsrvio.preprocessing; }
 };
 
 // Set up the socket connections between the two computational parties
@@ -259,12 +346,14 @@ void mpcio_setup_computational(unsigned player,
     boost::asio::io_context &io_context,
     const char *p0addr,  // can be NULL when player=0
     int num_threads,
-    std::deque<tcp::socket> &peersocks, tcp::socket &serversock);
+    std::deque<tcp::socket> &peersocks,
+    std::deque<tcp::socket> &serversocks);
 
-// Server calls this version with player=2
+// Server calls this version
 
 void mpcio_setup_server(boost::asio::io_context &io_context,
-    const char *p0addr, const char *p1addr,
-    tcp::socket &p0sock, tcp::socket &p1sock);
+    const char *p0addr, const char *p1addr, int num_threads,
+    std::deque<tcp::socket> &p0socks,
+    std::deque<tcp::socket> &p1socks);
 
 #endif

+ 12 - 12
mpcops.cpp

@@ -11,27 +11,27 @@
 // Cost:
 // 1 word sent in 1 message
 // consumes 1 MultTriple
-void mpc_mul(MPCIO &mpcio, size_t thread_num, yield_t &yield,
+void mpc_mul(MPCTIO &tio, yield_t &yield,
     value_t &as_z, value_t as_x, value_t as_y,
     nbits_t nbits)
 {
     value_t mask = MASKBITS(nbits);
     size_t nbytes = BITBYTES(nbits);
-    auto [X, Y, Z] = mpcio.triple(thread_num);
+    auto [X, Y, Z] = tio.triple();
 
     // Send x+X and y+Y
     value_t blind_x = (as_x + X) & mask;
     value_t blind_y = (as_y + Y) & mask;
 
-    mpcio.peerios[thread_num].queue(&blind_x, nbytes);
-    mpcio.peerios[thread_num].queue(&blind_y, nbytes);
+    tio.queue_peer(&blind_x, nbytes);
+    tio.queue_peer(&blind_y, nbytes);
 
     yield();
 
     // Read the peer's x+X and y+Y
     value_t  peer_blind_x, peer_blind_y;
-    mpcio.peerios[thread_num].recv(&peer_blind_x, nbytes);
-    mpcio.peerios[thread_num].recv(&peer_blind_y, nbytes);
+    tio.recv_peer(&peer_blind_x, nbytes);
+    tio.recv_peer(&peer_blind_y, nbytes);
 
     as_z = ((as_x * (as_y + peer_blind_y)) - Y * peer_blind_x + Z) & mask;
 }
@@ -44,28 +44,28 @@ void mpc_mul(MPCIO &mpcio, size_t thread_num, yield_t &yield,
 // Cost:
 // 1 word sent in 1 message
 // consumes 1 HalfTriple
-void mpc_valuemul(MPCIO &mpcio, size_t thread_num, yield_t &yield,
+void mpc_valuemul(MPCTIO &tio, yield_t &yield,
     value_t &as_z, value_t x,
     nbits_t nbits)
 {
     value_t mask = MASKBITS(nbits);
     size_t nbytes = BITBYTES(nbits);
-    auto [X, Z] = mpcio.halftriple(thread_num);
+    auto [X, Z] = tio.halftriple();
 
     // Send x+X
     value_t blind_x = (x + X) & mask;
 
-    mpcio.peerios[thread_num].queue(&blind_x, nbytes);
+    tio.queue_peer(&blind_x, nbytes);
 
     yield();
 
     // Read the peer's y+Y
     value_t  peer_blind_y;
-    mpcio.peerios[thread_num].recv(&peer_blind_y, nbytes);
+    tio.recv_peer(&peer_blind_y, nbytes);
 
-    if (mpcio.player == 0) {
+    if (tio.player() == 0) {
         as_z = ((x * peer_blind_y) + Z) & mask;
-    } else if (mpcio.player == 1) {
+    } else if (tio.player() == 1) {
         as_z = ((-X * peer_blind_y) + Z) & mask;
     }
 }

+ 2 - 2
mpcops.hpp

@@ -16,7 +16,7 @@
 // Cost:
 // 1 word sent in 1 message
 // consumes 1 MultTriple
-void mpc_mul(MPCIO &mpcio, size_t thread_num, yield_t &yield,
+void mpc_mul(MPCTIO &tio, yield_t &yield,
     value_t &as_z, value_t as_x, value_t as_y,
     nbits_t nbits = VALUE_BITS);
 
@@ -28,7 +28,7 @@ void mpc_mul(MPCIO &mpcio, size_t thread_num, yield_t &yield,
 // Cost:
 // 1 word sent in 1 message
 // consumes 1 HalfTriple
-void mpc_valuemul(MPCIO &mpcio, size_t thread_num, yield_t &yield,
+void mpc_valuemul(MPCTIO &tio, yield_t &yield,
     value_t &as_z, value_t x,
     nbits_t nbits = VALUE_BITS);
 

+ 13 - 13
oblivds.cpp

@@ -9,7 +9,7 @@ static void usage(const char *progname)
 {
     std::cerr << "Usage: " << progname << " [-p] [-t num] player_num player_addrs args ...\n";
     std::cerr << "-p: preprocessing mode\n";
-    std::cerr << "-t num: use num threads for the computational players\n";
+    std::cerr << "-t num: use num threads\n";
     std::cerr << "player_num = 0 or 1 for the computational players\n";
     std::cerr << "player_num = 2 for the server player\n";
     std::cerr << "player_addrs is omitted for player 0\n";
@@ -22,11 +22,10 @@ static void comp_player_main(boost::asio::io_context &io_context,
     unsigned player, bool preprocessing, int num_threads, const char *p0addr,
     char **args)
 {
-    tcp::socket serversock(io_context);
-    std::deque<tcp::socket> peersocks;
+    std::deque<tcp::socket> peersocks, serversocks;
     mpcio_setup_computational(player, io_context, p0addr, num_threads,
-        peersocks, serversock);
-    MPCIO mpcio(player, preprocessing, peersocks, std::move(serversock));
+        peersocks, serversocks);
+    MPCIO mpcio(player, preprocessing, peersocks, serversocks);
 
     // Queue up the work to be done
     boost::asio::post(io_context, [&]{
@@ -45,19 +44,20 @@ static void comp_player_main(boost::asio::io_context &io_context,
 }
 
 static void server_player_main(boost::asio::io_context &io_context,
-    bool preprocessing, const char *p0addr, const char *p1addr, char **args)
+    bool preprocessing, int num_threads, const char *p0addr,
+    const char *p1addr, char **args)
 {
-    tcp::socket p0sock(io_context), p1sock(io_context);
-    mpcio_setup_server(io_context, p0addr, p1addr, p0sock, p1sock);
-    MPCServerIO mpcserverio(preprocessing, std::move(p0sock),
-        std::move(p1sock));
+    std::deque<tcp::socket> p0socks, p1socks;
+    mpcio_setup_server(io_context, p0addr, p1addr, num_threads,
+        p0socks, p1socks);
+    MPCServerIO mpcserverio(preprocessing, p0socks, p1socks);
 
     // Queue up the work to be done
     boost::asio::post(io_context, [&]{
         if (preprocessing) {
-            preprocessing_server(mpcserverio, args);
+            preprocessing_server(mpcserverio, num_threads, args);
         } else {
-            online_server(mpcserverio, args);
+            online_server(mpcserverio, num_threads, args);
         }
     });
 
@@ -144,7 +144,7 @@ int main(int argc, char **argv)
     if (player < 2) {
         comp_player_main(io_context, player, preprocessing, num_threads, p0addr, args);
     } else {
-        server_player_main(io_context, preprocessing, p0addr, p1addr, args);
+        server_player_main(io_context, preprocessing, num_threads, p0addr, p1addr, args);
     }
 
     return 0;

+ 9 - 7
online.cpp

@@ -14,6 +14,8 @@ void online_comp(MPCIO &mpcio, int num_threads, char **args)
 
     size_t memsize = 5;
 
+    MPCTIO tio(mpcio, 0);
+
     value_t *A = new value_t[memsize];
 
     arc4random_buf(A, 3*sizeof(value_t));
@@ -23,23 +25,23 @@ void online_comp(MPCIO &mpcio, int num_threads, char **args)
     std::vector<coro_t> coroutines;
     coroutines.emplace_back(
         [&](yield_t &yield) {
-            mpc_mul(mpcio, 0, yield, A[3], A[0], A[1], nbits);
+            mpc_mul(tio, yield, A[3], A[0], A[1], nbits);
         });
     coroutines.emplace_back(
         [&](yield_t &yield) {
-            mpc_valuemul(mpcio, 0, yield, A[4], A[2], nbits);
+            mpc_valuemul(tio, yield, A[4], A[2], nbits);
         });
-    run_coroutines(mpcio, coroutines);
+    run_coroutines(tio, coroutines);
     std::cout << A[3] << "\n";
     std::cout << A[4] << "\n";
 
     // Check the answers
     if (mpcio.player) {
-        mpcio.peerios[0].queue(A, memsize*sizeof(value_t));
-        mpcio.sendall();
+        tio.queue_peer(A, memsize*sizeof(value_t));
+        tio.send();
     } else {
         value_t *B = new value_t[memsize];
-        mpcio.peerios[0].recv(B, memsize*sizeof(value_t));
+        tio.recv_peer(B, memsize*sizeof(value_t));
         printf("%016lx\n", ((A[0]+B[0])*(A[1]+B[1])-(A[3]+B[3])));
         printf("%016lx\n", (A[2]*B[2])-(A[4]+B[4]));
         delete[] B;
@@ -48,6 +50,6 @@ void online_comp(MPCIO &mpcio, int num_threads, char **args)
     delete[] A;
 }
 
-void online_server(MPCServerIO &mpcio, char **args)
+void online_server(MPCServerIO &mpcio, int num_threads, char **args)
 {
 }

+ 1 - 1
online.hpp

@@ -4,6 +4,6 @@
 #include "mpcio.hpp"
 
 void online_comp(MPCIO &mpcio, int num_threads, char **args);
-void online_server(MPCServerIO &mpcio, char **args);
+void online_server(MPCServerIO &mpcio, int num_threads, char **args);
 
 #endif

+ 84 - 82
preproc.cpp

@@ -37,61 +37,53 @@ static std::ofstream openfile(const char *prefix, unsigned player,
 // Then that number of objects
 //
 // Repeat the whole thing until type == 0x00 is received
-//
-// The incoming objects are written into num_threads files in a
-// round-robin manner
 
 void preprocessing_comp(MPCIO &mpcio, int num_threads, char **args)
 {
-    while(1) {
-        unsigned char type = 0;
-        unsigned int num = 0;
-        size_t res = mpcio.serverio.recv(&type, 1);
-        if (res < 1 || type == 0) break;
-        mpcio.serverio.recv(&num, 4);
-        if (type == 0x80) {
-            // Multiplication triples
-            std::vector<std::ofstream> tripfiles;
-            for (int i=0; i<num_threads; ++i) {
-                tripfiles.push_back(openfile("triples", mpcio.player, i));
-            }
-            unsigned thread_num = 0;
+    boost::asio::thread_pool pool(num_threads);
+    for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
+        boost::asio::post(pool, [&mpcio, thread_num] {
+            MPCTIO tio(mpcio, thread_num);
+            while(1) {
+                unsigned char type = 0;
+                unsigned int num = 0;
+                size_t res = tio.recv_server(&type, 1);
+                if (res < 1 || type == 0) break;
+                tio.recv_server(&num, 4);
+                if (type == 0x80) {
+                    // Multiplication triples
+                    std::ofstream tripfile = openfile("triples",
+                        mpcio.player, thread_num);
 
-            MultTriple T;
-            for (unsigned int i=0; i<num; ++i) {
-                res = mpcio.serverio.recv(&T, sizeof(T));
-                if (res < sizeof(T)) break;
-                tripfiles[thread_num].write((const char *)&T, sizeof(T));
-                thread_num = (thread_num + 1) % num_threads;
-            }
-            for (int i=0; i<num_threads; ++i) {
-                tripfiles[i].close();
-            }
-        } else if (type == 0x81) {
-            // Multiplication half triples
-            std::vector<std::ofstream> halffiles;
-            for (int i=0; i<num_threads; ++i) {
-                halffiles.push_back(openfile("halves", mpcio.player, i));
-            }
-            unsigned thread_num = 0;
+                    MultTriple T;
+                    for (unsigned int i=0; i<num; ++i) {
+                        res = tio.recv_server(&T, sizeof(T));
+                        if (res < sizeof(T)) break;
+                        tripfile.write((const char *)&T, sizeof(T));
+                    }
+                    tripfile.close();
+                } else if (type == 0x81) {
+                    // Multiplication half triples
+                    std::ofstream halffile = openfile("halves",
+                        mpcio.player, thread_num);
 
-            HalfTriple H;
-            for (unsigned int i=0; i<num; ++i) {
-                res = mpcio.serverio.recv(&H, sizeof(H));
-                if (res < sizeof(H)) break;
-                halffiles[thread_num].write((const char *)&H, sizeof(H));
-                thread_num = (thread_num + 1) % num_threads;
-            }
-            for (int i=0; i<num_threads; ++i) {
-                halffiles[i].close();
+                    HalfTriple H;
+                    for (unsigned int i=0; i<num; ++i) {
+                        res = tio.recv_server(&H, sizeof(H));
+                        if (res < sizeof(H)) break;
+                        halffile.write((const char *)&H, sizeof(H));
+                    }
+                    halffile.close();
+                }
             }
-        }
+        });
     }
+    pool.join();
 }
 
 // Create triples (X0,Y0,Z0),(X1,Y1,Z1) such that
 // (X0*Y1 + Y0*X1) = (Z0+Z1)
-static void create_triples(MPCServerIO &mpcsrvio, unsigned num)
+static void create_triples(MPCServerTIO &stio, unsigned num)
 {
     for (unsigned int i=0; i<num; ++i) {
         value_t X0, Y0, Z0, X1, Y1, Z1;
@@ -104,14 +96,14 @@ static void create_triples(MPCServerIO &mpcsrvio, unsigned num)
         MultTriple T0, T1;
         T0 = std::make_tuple(X0, Y0, Z0);
         T1 = std::make_tuple(X1, Y1, Z1);
-        mpcsrvio.p0io.queue(&T0, sizeof(T0));
-        mpcsrvio.p1io.queue(&T1, sizeof(T1));
+        stio.queue_p0(&T0, sizeof(T0));
+        stio.queue_p1(&T1, sizeof(T1));
     }
 }
 
 // Create half-triples (X0,Z0),(Y1,Z1) such that
 // X0*Y1 = Z0 + Z1
-static void create_halftriples(MPCServerIO &mpcsrvio, unsigned num)
+static void create_halftriples(MPCServerTIO &stio, unsigned num)
 {
     for (unsigned int i=0; i<num; ++i) {
         value_t X0, Z0, Y1, Z1;
@@ -122,46 +114,56 @@ static void create_halftriples(MPCServerIO &mpcsrvio, unsigned num)
         HalfTriple H0, H1;
         H0 = std::make_tuple(X0, Z0);
         H1 = std::make_tuple(Y1, Z1);
-        mpcsrvio.p0io.queue(&H0, sizeof(H0));
-        mpcsrvio.p1io.queue(&H1, sizeof(H1));
+        stio.queue_p0(&H0, sizeof(H0));
+        stio.queue_p1(&H1, sizeof(H1));
     }
 }
 
-void preprocessing_server(MPCServerIO &mpcsrvio, char **args)
+void preprocessing_server(MPCServerIO &mpcsrvio, int num_threads, char **args)
 {
-    while (*args) {
-        char *colon = strchr(*args, ':');
-        if (!colon) {
-            std::cerr << "Args must be type:num\n";
-            ++args;
-            continue;
-        }
-        unsigned num = atoi(colon+1);
-        *colon = '\0';
-        char *type = *args;
-        if (!strcmp(type, "t")) {
-            unsigned char typetag = 0x80;
-            mpcsrvio.p0io.queue(&typetag, 1);
-            mpcsrvio.p0io.queue(&num, 4);
-            mpcsrvio.p1io.queue(&typetag, 1);
-            mpcsrvio.p1io.queue(&num, 4);
+    boost::asio::thread_pool pool(num_threads);
+    for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
+        boost::asio::post(pool, [&mpcsrvio, thread_num, args] {
+            char **threadargs = args;
+            MPCServerTIO stio(mpcsrvio, thread_num);
+            while (*threadargs) {
+                char *arg = strdup(*threadargs);
+                char *colon = strchr(arg, ':');
+                if (!colon) {
+                    std::cerr << "Args must be type:num\n";
+                    ++threadargs;
+                    free(arg);
+                    continue;
+                }
+                unsigned num = atoi(colon+1);
+                *colon = '\0';
+                char *type = arg;
+                if (!strcmp(type, "t")) {
+                    unsigned char typetag = 0x80;
+                    stio.queue_p0(&typetag, 1);
+                    stio.queue_p0(&num, 4);
+                    stio.queue_p1(&typetag, 1);
+                    stio.queue_p1(&num, 4);
 
-            create_triples(mpcsrvio, num);
-        } else if (!strcmp(type, "h")) {
-            unsigned char typetag = 0x81;
-            mpcsrvio.p0io.queue(&typetag, 1);
-            mpcsrvio.p0io.queue(&num, 4);
-            mpcsrvio.p1io.queue(&typetag, 1);
-            mpcsrvio.p1io.queue(&num, 4);
+                    create_triples(stio, num);
+                } else if (!strcmp(type, "h")) {
+                    unsigned char typetag = 0x81;
+                    stio.queue_p0(&typetag, 1);
+                    stio.queue_p0(&num, 4);
+                    stio.queue_p1(&typetag, 1);
+                    stio.queue_p1(&num, 4);
 
-            create_halftriples(mpcsrvio, num);
-        }
-        ++args;
+                    create_halftriples(stio, num);
+                }
+                free(arg);
+                ++threadargs;
+            }
+            // That's all
+            unsigned char typetag = 0x00;
+            stio.queue_p0(&typetag, 1);
+            stio.queue_p1(&typetag, 1);
+            stio.send();
+        });
     }
-    // That's all
-    unsigned char typetag = 0x00;
-    mpcsrvio.p0io.queue(&typetag, 1);
-    mpcsrvio.p1io.queue(&typetag, 1);
-    mpcsrvio.p0io.send();
-    mpcsrvio.p1io.send();
+    pool.join();
 }

+ 1 - 1
preproc.hpp

@@ -4,6 +4,6 @@
 #include "mpcio.hpp"
 
 void preprocessing_comp(MPCIO &mpcio, int num_threads, char **args);
-void preprocessing_server(MPCServerIO &mpcio, char **args);
+void preprocessing_server(MPCServerIO &mpcio, int num_threads, char **args);
 
 #endif