Explorar el Código

Optionally add headers to all messages containing length and a Lamport clock

This is controlled by #define SEND_LAMPORT_CLOCKS

The clocks are currently sent, but not yet ever updated
Ian Goldberg hace 2 años
padre
commit
08ce010f6d
Se han modificado 1 ficheros con 102 adiciones y 25 borrados
  1. 102 25
      mpcio.hpp

+ 102 - 25
mpcio.hpp

@@ -7,6 +7,7 @@
 #include <deque>
 #include <queue>
 #include <string>
+#include <atomic>
 #include <bsd/stdlib.h> // arc4random_buf
 
 #include <boost/asio.hpp>
@@ -58,6 +59,29 @@ void PreCompStorage<T>::get(T& nextval) {
     ++count;
 }
 
+// If we want to send Lamport clocks in messages, define this.  It adds
+// an 8-byte header to each message (length and Lamport clock), so it
+// has a small network cost.  We always define and pass the Lamport
+// clock member of MPCIO to the IO functions for simplicity, but they're
+// ignored if this isn't defined
+#define SEND_LAMPORT_CLOCKS
+using lamport_t = std::atomic<uint32_t>;
+#ifdef SEND_LAMPORT_CLOCKS
+struct MessageWithHeader {
+    std::string header;
+    std::string message;
+
+    MessageWithHeader(std::string &&msg, lamport_t &lamport) :
+        message(std::move(msg)) {
+            char hdr[sizeof(uint32_t) + sizeof(lamport_t)];
+            uint32_t msglen = uint32_t(message.size());
+            memmove(hdr, &msglen, sizeof(msglen));
+            memmove(hdr+sizeof(msglen), &lamport, sizeof(lamport));
+            header.assign(hdr, sizeof(hdr));
+    }
+};
+#endif
+
 // A class to wrap a socket to another MPC party.  This wrapping allows
 // us to do some useful logging, and perform async_writes transparently
 // to the application.
@@ -93,7 +117,22 @@ class MPCSingleIO {
     // any more elements.  If so, it will start another async_write.
     // The invariant is that there is an async_write currently running
     // iff messagequeue is nonempty.
+#ifdef SEND_LAMPORT_CLOCKS
+    std::queue<MessageWithHeader> messagequeue;
+#else
     std::queue<std::string> messagequeue;
+#endif
+
+#ifdef SEND_LAMPORT_CLOCKS
+    // If Lamport clocks are being sent, then the data stream is divided
+    // into chunks, each with a header containing the length of the
+    // chunk and the Lamport clock.  So when we read, we'll read a whole
+    // chunk, and store it here.  Then calls to recv() will read pieces
+    // of this buffer until it has all been read, and then read the next
+    // header and chunk.
+    std::string recvdata;
+    size_t recvdataremain;
+#endif
 
     // Never touch the above messagequeue without holding this lock (you
     // _can_ touch the strings it contains, though, if you looked one up
@@ -105,8 +144,17 @@ class MPCSingleIO {
     // This method may be called from either thread (the work thread or
     // the async_write handler thread).
     void async_send_from_msgqueue() {
+#ifdef SEND_LAMPORT_CLOCKS
+        std::vector<boost::asio::const_buffer> tosend;
+        tosend.push_back(boost::asio::buffer(messagequeue.front().header));
+        tosend.push_back(boost::asio::buffer(messagequeue.front().message));
+#endif
         boost::asio::async_write(sock,
+#ifdef SEND_LAMPORT_CLOCKS
+            tosend,
+#else
             boost::asio::buffer(messagequeue.front()),
+#endif
             [&](boost::system::error_code ec, std::size_t amt){
                 messagequeuelock.lock();
                 messagequeue.pop();
@@ -121,17 +169,17 @@ public:
     MPCSingleIO(tcp::socket &&sock) :
         sock(std::move(sock)), totread(0), totwritten(0) {}
 
-    void queue(const void *data, size_t len) {
+    void queue(const void *data, size_t len, lamport_t &lamport) {
         dataqueue.append((const char *)data, len);
 
         // If we already have some full packets worth of data, may as
         // well send it.
         if (dataqueue.size() > 28800) {
-            send();
+            send(lamport);
         }
     }
 
-    void send() {
+    void send(lamport_t &lamport) {
         size_t thissize = dataqueue.size();
         // Ignore spurious calls to send()
         if (thissize == 0) return;
@@ -143,7 +191,11 @@ public:
         messagequeuelock.lock();
         // Move the current message to send into the message queue (this
         // moves a pointer to the data, not copying the data itself)
+#ifdef SEND_LAMPORT_CLOCKS
+        messagequeue.emplace(std::move(dataqueue), lamport);
+#else
         messagequeue.emplace(std::move(dataqueue));
+#endif
         // If this is now the first thing in the message queue, launch
         // an async_write to write it
         if (messagequeue.size() == 1) {
@@ -152,20 +204,44 @@ public:
         messagequeuelock.unlock();
     }
 
-    size_t recv(const boost::asio::mutable_buffer& buffer) {
-        size_t res = boost::asio::read(sock, buffer);
-#ifdef RECORD_IOTRACE
-        iotrace.push_back(-(ssize_t(res)));
-#endif
+    size_t recv(void *data, size_t len, lamport_t &lamport) {
+#ifdef SEND_LAMPORT_CLOCKS
+        char *cdata = (char *)data;
+        size_t res = 0;
+        while (len > 0) {
+            while (recvdataremain == 0) {
+                // Read a new header
+                char hdr[sizeof(uint32_t) + sizeof(lamport_t)];
+                uint32_t datalen;
+                lamport_t::value_type recv_lamport;
+                boost::asio::read(sock, boost::asio::buffer(hdr, sizeof(hdr)));
+                memmove(&datalen, hdr, sizeof(datalen));
+                memmove(&recv_lamport, hdr+sizeof(datalen), sizeof(lamport_t));
+                if (datalen > 0) {
+                    recvdata.resize(datalen, '\0');
+                    boost::asio::read(sock, boost::asio::buffer(recvdata));
+                    recvdataremain = datalen;
+                }
+            }
+            size_t amttoread = len;
+            if (amttoread > recvdataremain) {
+                amttoread = recvdataremain;
+            }
+            memmove(cdata, recvdata.data()+recvdata.size()-recvdataremain,
+                amttoread);
+            cdata += amttoread;
+            len -= amttoread;
+            recvdataremain -= amttoread;
+            res += amttoread;
+        }
         return res;
-    }
-
-    size_t recv(void *data, size_t len) {
+#else
         size_t res = boost::asio::read(sock, boost::asio::buffer(data, len));
 #ifdef RECORD_IOTRACE
         iotrace.push_back(-(ssize_t(res)));
 #endif
         return res;
+#endif
     }
 
 #ifdef RECORD_IOTRACE
@@ -193,9 +269,10 @@ public:
 struct MPCIO {
     int player;
     bool preprocessing;
+    lamport_t lamport;
 
     MPCIO(int player, bool preprocessing) :
-        player(player), preprocessing(preprocessing) {}
+        player(player), preprocessing(preprocessing), lamport(0) {}
 };
 
 // A class to represent all of a computation peer's IO, either to other
@@ -287,14 +364,14 @@ public:
     void queue_peer(const void *data, size_t len) {
         if (mpcio.player < 2) {
             MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
-            mpcpio.peerios[thread_num].queue(data, len);
+            mpcpio.peerios[thread_num].queue(data, len, mpcio.lamport);
         }
     }
 
     void queue_server(const void *data, size_t len) {
         if (mpcio.player < 2) {
             MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
-            mpcpio.serverios[thread_num].queue(data, len);
+            mpcpio.serverios[thread_num].queue(data, len, mpcio.lamport);
         }
     }
 
@@ -303,7 +380,7 @@ public:
     size_t recv_peer(void *data, size_t len) {
         if (mpcio.player < 2) {
             MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
-            return mpcpio.peerios[thread_num].recv(data, len);
+            return mpcpio.peerios[thread_num].recv(data, len, mpcio.lamport);
         }
         return 0;
     }
@@ -311,7 +388,7 @@ public:
     size_t recv_server(void *data, size_t len) {
         if (mpcio.player < 2) {
             MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
-            return mpcpio.serverios[thread_num].recv(data, len);
+            return mpcpio.serverios[thread_num].recv(data, len, mpcio.lamport);
         }
         return 0;
     }
@@ -321,14 +398,14 @@ public:
     void queue_p0(const void *data, size_t len) {
         if (mpcio.player == 2) {
             MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
-            mpcsrvio.p0ios[thread_num].queue(data, len);
+            mpcsrvio.p0ios[thread_num].queue(data, len, mpcio.lamport);
         }
     }
 
     void queue_p1(const void *data, size_t len) {
         if (mpcio.player == 2) {
             MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
-            mpcsrvio.p1ios[thread_num].queue(data, len);
+            mpcsrvio.p1ios[thread_num].queue(data, len, mpcio.lamport);
         }
     }
 
@@ -337,7 +414,7 @@ public:
     size_t recv_p0(void *data, size_t len) {
         if (mpcio.player == 2) {
             MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
-            return mpcsrvio.p0ios[thread_num].recv(data, len);
+            return mpcsrvio.p0ios[thread_num].recv(data, len, mpcio.lamport);
         }
         return 0;
     }
@@ -345,7 +422,7 @@ public:
     size_t recv_p1(void *data, size_t len) {
         if (mpcio.player == 2) {
             MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
-            return mpcsrvio.p1ios[thread_num].recv(data, len);
+            return mpcsrvio.p1ios[thread_num].recv(data, len, mpcio.lamport);
         }
         return 0;
     }
@@ -354,12 +431,12 @@ public:
     void send() {
         if (mpcio.player < 2) {
             MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
-            mpcpio.peerios[thread_num].send();
-            mpcpio.serverios[thread_num].send();
+            mpcpio.peerios[thread_num].send(mpcio.lamport);
+            mpcpio.serverios[thread_num].send(mpcio.lamport);
         } else {
             MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
-            mpcsrvio.p0ios[thread_num].send();
-            mpcsrvio.p1ios[thread_num].send();
+            mpcsrvio.p0ios[thread_num].send(mpcio.lamport);
+            mpcsrvio.p1ios[thread_num].send(mpcio.lamport);
         }
     }
 
@@ -399,7 +476,7 @@ public:
         if (mpcio.player < 2) {
             MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
             if (mpcpio.preprocessing) {
-                mpcpio.serverios[thread_num].recv(boost::asio::buffer(&val, sizeof(val)));
+                recv_server(&val, sizeof(val));
             } else {
                 mpcpio.halftriples[thread_num].get(val);
             }