Browse Source

Fix the code to pass the test

Two changes:

Each thread now keeps its own Lamport clock, syncing with the master
clock when the thread joins or when threads locally communicate.

When large messages are automatically split into chunks in order to get
the start of them onto the wire while the rest are still being computed
(but no messages are being received), all the chunks should have the
Lamport clock that was current when the message started to be created.
Ian Goldberg 2 years ago
parent
commit
3478cbf398
1 changed files with 100 additions and 37 deletions
  1. 100 37
      mpcio.hpp

+ 100 - 37
mpcio.hpp

@@ -8,6 +8,7 @@
 #include <queue>
 #include <string>
 #include <atomic>
+#include <optional>
 #include <bsd/stdlib.h> // arc4random_buf
 
 #include <boost/asio.hpp>
@@ -67,6 +68,7 @@ void PreCompStorage<T>::get(T& nextval) {
 #define SEND_LAMPORT_CLOCKS
 using lamport_t = uint32_t;
 using atomic_lamport_t = std::atomic<lamport_t>;
+using opt_lamport_t = std::optional<lamport_t>;
 #ifdef SEND_LAMPORT_CLOCKS
 struct MessageWithHeader {
     std::string header;
@@ -120,6 +122,17 @@ class MPCSingleIO {
     // iff messagequeue is nonempty.
 #ifdef SEND_LAMPORT_CLOCKS
     std::queue<MessageWithHeader> messagequeue;
+
+    // If a single message is broken into chunks in order to get the
+    // first part of it out on the wire while the rest of it is still
+    // being computed, we want the Lamport clock of all the chunks to be
+    // that of when the message is first created.  This value will be
+    // nullopt when there has been no queue() since the last explicit
+    // send() (as opposed to the implicit send() called by queue()
+    // itself if it wants to get a chunk on its way), and will be set to
+    // the current lamport clock when that first queue() after each
+    // explicit send() happens.
+    opt_lamport_t message_lamport;
 #else
     std::queue<std::string> messagequeue;
 #endif
@@ -173,17 +186,38 @@ public:
     void queue(const void *data, size_t len, lamport_t lamport) {
         dataqueue.append((const char *)data, len);
 
+#ifdef SEND_LAMPORT_CLOCKS
+        // If this is the first queue() since the last explicit send(),
+        // which we'll know because message_lamport will be nullopt, set
+        // message_lamport to the current Lamport clock.  Note that the
+        // boolean test tests whether message_lamport is nullopt, not
+        // whether its value is zero.
+        if (!message_lamport) {
+            message_lamport = lamport;
+        }
+#endif
+
         // If we already have some full packets worth of data, may as
         // well send it.
         if (dataqueue.size() > 28800) {
-            send(lamport);
+            send(true);
         }
     }
 
-    void send(lamport_t lamport) {
+    void send(bool implicit_send = false) {
         size_t thissize = dataqueue.size();
-        // Ignore spurious calls to send()
-        if (thissize == 0) return;
+        // Ignore spurious calls to send(), except for resetting
+        // message_lamport if this was an explicit send().
+        if (thissize == 0) {
+#ifdef SEND_LAMPORT_CLOCKS
+            // If this was an explicit send(), reset the message_lamport so
+            // that it gets updated at the next queue().
+            if (!implicit_send) {
+                message_lamport.reset();
+            }
+#endif
+            return;
+        }
 
 #ifdef RECORD_IOTRACE
         iotrace.push_back(thissize);
@@ -193,7 +227,13 @@ public:
         // Move the current message to send into the message queue (this
         // moves a pointer to the data, not copying the data itself)
 #ifdef SEND_LAMPORT_CLOCKS
-        messagequeue.emplace(std::move(dataqueue), lamport);
+        messagequeue.emplace(std::move(dataqueue),
+            message_lamport.value());
+        // If this was an explicit send(), reset the message_lamport so
+        // that it gets updated at the next queue().
+        if (!implicit_send) {
+            message_lamport.reset();
+        }
 #else
         messagequeue.emplace(std::move(dataqueue));
 #endif
@@ -205,7 +245,7 @@ public:
         messagequeuelock.unlock();
     }
 
-    size_t recv(void *data, size_t len, atomic_lamport_t &lamport) {
+    size_t recv(void *data, size_t len, lamport_t &lamport) {
 #ifdef SEND_LAMPORT_CLOCKS
         char *cdata = (char *)data;
         size_t res = 0;
@@ -218,25 +258,10 @@ public:
                 boost::asio::read(sock, boost::asio::buffer(hdr, sizeof(hdr)));
                 memmove(&datalen, hdr, sizeof(datalen));
                 memmove(&recv_lamport, hdr+sizeof(datalen), sizeof(lamport_t));
-                // Update our Lamport time to be max of recv_lamport+1
-                // and what we thought it was before.  We use this
-                // compare_exchange construction in order to atomically
-                // do the comparison, computation, and replacement
-                lamport_t old_lamport = lamport;
                 lamport_t new_lamport = recv_lamport + 1;
-                do {
-                    if (new_lamport < old_lamport) {
-                        new_lamport = old_lamport;
-                    }
-                // The next line atomically checks if lamport still has
-                // the value old_lamport; if so, it changes its value to
-                // new_lamport and returns true (ending the loop).  If
-                // not, it sets old_lamport to the current value of
-                // lamport, and returns false (continuing the loop so
-                // that new_lamport can be recomputed based on this new
-                // value).
-                } while (!lamport.compare_exchange_weak(
-                    old_lamport, new_lamport));
+                if (lamport < new_lamport) {
+                    lamport = new_lamport;
+                }
                 if (datalen > 0) {
                     recvdata.resize(datalen, '\0');
                     boost::asio::read(sock, boost::asio::buffer(recvdata));
@@ -373,25 +398,63 @@ struct MPCServerIO : public MPCIO {
 
 class MPCTIO {
     int thread_num;
+    lamport_t thread_lamport;
     MPCIO &mpcio;
 
 public:
     MPCTIO(MPCIO &mpcio, int thread_num):
-        thread_num(thread_num), mpcio(mpcio) {}
+        thread_num(thread_num), thread_lamport(mpcio.lamport),
+        mpcio(mpcio) {}
+
+    // Sync our per-thread lamport clock with the master one in the
+    // mpcio.  You only need to call this explicitly if your MPCTIO
+    // outlives your thread (in which case call it after the join), or
+    // if your threads do interthread communication amongst themselves
+    // (in which case call it in the sending thread before the send, and
+    // call it in the receiving thread after the receive).
+    void sync_lamport() {
+        // Update the mpcio Lamport time to be max of the thread Lamport
+        // time and what we thought it was before.  We use this
+        // compare_exchange construction in order to atomically
+        // do the comparison, computation, and replacement
+        lamport_t old_lamport = mpcio.lamport;
+        lamport_t new_lamport = thread_lamport;
+        do {
+            if (new_lamport < old_lamport) {
+                new_lamport = old_lamport;
+            }
+        // The next line atomically checks if lamport still has
+        // the value old_lamport; if so, it changes its value to
+        // new_lamport and returns true (ending the loop).  If
+        // not, it sets old_lamport to the current value of
+        // lamport, and returns false (continuing the loop so
+        // that new_lamport can be recomputed based on this new
+        // value).
+        } while (!mpcio.lamport.compare_exchange_weak(
+            old_lamport, new_lamport));
+        thread_lamport = new_lamport;
+    }
+
+    // The normal case, where the MPCIO is created inside the thread,
+    // and so destructed when the thread ends, is handles automatically
+    // here.
+    ~MPCTIO() {
+        sync_lamport();
+    }
 
     // Queue up data to the peer or to the server
 
     void queue_peer(const void *data, size_t len) {
         if (mpcio.player < 2) {
             MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
-            mpcpio.peerios[thread_num].queue(data, len, mpcio.lamport);
+            mpcpio.peerios[thread_num].queue(data, len, thread_lamport);
         }
     }
 
     void queue_server(const void *data, size_t len) {
         if (mpcio.player < 2) {
             MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
-            mpcpio.serverios[thread_num].queue(data, len, mpcio.lamport);
+            mpcpio.serverios[thread_num].queue(data, len, thread_lamport);
         }
     }
 
@@ -400,7 +463,7 @@ public:
     size_t recv_peer(void *data, size_t len) {
         if (mpcio.player < 2) {
             MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
-            return mpcpio.peerios[thread_num].recv(data, len, mpcio.lamport);
+            return mpcpio.peerios[thread_num].recv(data, len, thread_lamport);
         }
         return 0;
     }
@@ -408,7 +471,7 @@ public:
     size_t recv_server(void *data, size_t len) {
         if (mpcio.player < 2) {
             MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
-            return mpcpio.serverios[thread_num].recv(data, len, mpcio.lamport);
+            return mpcpio.serverios[thread_num].recv(data, len, thread_lamport);
         }
         return 0;
     }
@@ -418,14 +481,14 @@ public:
     void queue_p0(const void *data, size_t len) {
         if (mpcio.player == 2) {
             MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
-            mpcsrvio.p0ios[thread_num].queue(data, len, mpcio.lamport);
+            mpcsrvio.p0ios[thread_num].queue(data, len, thread_lamport);
         }
     }
 
     void queue_p1(const void *data, size_t len) {
         if (mpcio.player == 2) {
             MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
-            mpcsrvio.p1ios[thread_num].queue(data, len, mpcio.lamport);
+            mpcsrvio.p1ios[thread_num].queue(data, len, thread_lamport);
         }
     }
 
@@ -434,7 +497,7 @@ public:
     size_t recv_p0(void *data, size_t len) {
         if (mpcio.player == 2) {
             MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
-            return mpcsrvio.p0ios[thread_num].recv(data, len, mpcio.lamport);
+            return mpcsrvio.p0ios[thread_num].recv(data, len, thread_lamport);
         }
         return 0;
     }
@@ -442,7 +505,7 @@ public:
     size_t recv_p1(void *data, size_t len) {
         if (mpcio.player == 2) {
             MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
-            return mpcsrvio.p1ios[thread_num].recv(data, len, mpcio.lamport);
+            return mpcsrvio.p1ios[thread_num].recv(data, len, thread_lamport);
         }
         return 0;
     }
@@ -451,12 +514,12 @@ public:
     void send() {
         if (mpcio.player < 2) {
             MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
-            mpcpio.peerios[thread_num].send(mpcio.lamport);
-            mpcpio.serverios[thread_num].send(mpcio.lamport);
+            mpcpio.peerios[thread_num].send();
+            mpcpio.serverios[thread_num].send();
         } else {
             MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
-            mpcsrvio.p0ios[thread_num].send(mpcio.lamport);
-            mpcsrvio.p1ios[thread_num].send(mpcio.lamport);
+            mpcsrvio.p0ios[thread_num].send();
+            mpcsrvio.p1ios[thread_num].send();
         }
     }