Browse Source

Nodes can now asynchronously receive data using callbacks

Ian Goldberg 1 year ago
parent
commit
3a506cf620
3 changed files with 137 additions and 59 deletions
  1. 72 45
      App/net.cpp
  2. 31 12
      App/net.hpp
  3. 34 2
      App/start.cpp

+ 72 - 45
App/net.cpp

@@ -45,30 +45,39 @@ void NodeIO::return_frame(uint8_t *frame)
 
 void NodeIO::send_header_data(uint64_t header, uint8_t *data, size_t len)
 {
-    std::vector<boost::asio::const_buffer> tosend;
+    commands_deque_lock.lock();
+    commands_inflight.push_back({header, data, len});
+    if (commands_inflight.size() == 1) {
+        async_send_commands();
+    }
+    commands_deque_lock.unlock();
+}
 
-    // Put the header into the deque so it's in memory at a stable
-    // address during the async write
-    header_deque_lock.lock();
-    headers_inflight.push_back(header);
+void NodeIO::async_send_commands()
+{
+    std::vector<boost::asio::const_buffer> tosend;
 
-    uint64_t *headerp = &(headers_inflight.back());
-    header_deque_lock.unlock();
-    tosend.push_back(boost::asio::buffer(headerp, 5));
-    if (data != NULL && len > 0) {
-        tosend.push_back(boost::asio::buffer(data, len));
+    CommandTuple *commandp = &(commands_inflight.front());
+    tosend.push_back(boost::asio::buffer(&(std::get<0>(*commandp)), 5));
+    if (std::get<1>(*commandp) != NULL && std::get<2>(*commandp) > 0) {
+        tosend.push_back(boost::asio::buffer(std::get<1>(*commandp),
+            std::get<2>(*commandp)));
     }
     boost::asio::async_write(sock, tosend,
-        [this, headerp, data](boost::system::error_code, std::size_t){
-            // When the write completes, pop the header from the deque
+        [this, commandp](boost::system::error_code, std::size_t){
+            // When the write completes, pop the command from the deque
             // (which should now be in the front)
-            header_deque_lock.lock();
-            assert(!headers_inflight.empty() &&
-                &(headers_inflight.front()) == headerp);
-            headers_inflight.pop_front();
-            header_deque_lock.unlock();
+            commands_deque_lock.lock();
+            assert(!commands_inflight.empty() &&
+                &(commands_inflight.front()) == commandp);
+            uint8_t *data = std::get<1>(*commandp);
+            commands_inflight.pop_front();
+            if (commands_inflight.size() > 0) {
+                async_send_commands();
+            }
             // And return the frame
             return_frame(data);
+            commands_deque_lock.unlock();
         });
 }
 
@@ -98,35 +107,53 @@ void NodeIO::send_chunk(uint8_t *data, uint32_t chunk_len)
     assert(chunksize_inflight <= msgsize_inflight);
 }
 
-bool NodeIO::recv_header(uint64_t &header)
-{
-    header = 0;
-    try {
-        boost::asio::read(sock, boost::asio::buffer(&header, 5));
-    } catch (...) {
-        return false;
-    }
-    return true;
-}
-
-bool NodeIO::recv_chunk(uint64_t header, uint8_t *&data, size_t &len)
+void NodeIO::recv_commands(
+        std::function<void(boost::system::error_code)> error_cb,
+    std::function<void(uint32_t)> epoch_cb,
+    std::function<void(uint32_t)> message_cb,
+    std::function<void(uint8_t*,uint32_t)> chunk_cb)
 {
-    len = 0;
-    data = NULL;
-    assert((header & 0xff) == 0x02);
-    size_t datalen = header >> 8;
-    if (datalen > MAXCHUNKSIZE) {
-        return false;
-    }
-    try {
-        boost::asio::read(sock,
-            boost::asio::buffer(receive_frame, datalen));
-    } catch (...) {
-        return false;
-    }
-    data = receive_frame;
-    len = datalen;
-    return true;
+    // Asynchronously read the header
+    receive_header = 0;
+    boost::asio::async_read(sock, boost::asio::buffer(&receive_header, 5),
+        [this, error_cb, epoch_cb, message_cb, chunk_cb]
+        (boost::system::error_code ec, std::size_t) {
+            if (ec) {
+                error_cb(ec);
+                return;
+            }
+            if ((receive_header & 0xff) == 0x00) {
+                epoch_cb(uint32_t(receive_header >> 8));
+                recv_commands(error_cb, epoch_cb, message_cb, chunk_cb);
+            } else if ((receive_header & 0xff) == 0x01) {
+                assert(recv_msgsize_inflight == recv_chunksize_inflight);
+                recv_msgsize_inflight = uint32_t(receive_header >> 8);
+                recv_chunksize_inflight = 0;
+                message_cb(recv_msgsize_inflight);
+                recv_commands(error_cb, epoch_cb, message_cb, chunk_cb);
+            } else if ((receive_header & 0xff) == 0x02) {
+                uint32_t this_chunk_size = uint32_t(receive_header >> 8);
+                assert(recv_chunksize_inflight + this_chunk_size <=
+                    recv_msgsize_inflight);
+                recv_chunksize_inflight += this_chunk_size;
+                boost::asio::async_read(sock, boost::asio::buffer(
+                    receive_frame, this_chunk_size),
+                    [this, error_cb, epoch_cb, message_cb, chunk_cb,
+                        this_chunk_size]
+                    (boost::system::error_code ecc, std::size_t) {
+                        if (ecc) {
+                            error_cb(ecc);
+                            return;
+                        }
+                        chunk_cb(receive_frame, this_chunk_size);
+                        recv_commands(error_cb, epoch_cb,
+                            message_cb, chunk_cb);
+                    });
+            } else {
+                error_cb(boost::system::errc::make_error_code(
+                    boost::system::errc::errc_t::invalid_argument));
+            }
+        });
 }
 
 NetIO::NetIO(boost::asio::io_context &io_context, const Config &config)

+ 31 - 12
App/net.hpp

@@ -4,6 +4,8 @@
 #include <vector>
 #include <deque>
 #include <optional>
+#include <functional>
+#include <tuple>
 #include <boost/asio.hpp>
 #include <boost/thread.hpp>
 
@@ -67,22 +69,35 @@ using boost::asio::ip::tcp;
 
 class NodeIO {
     tcp::socket sock;
-    std::deque<uint64_t> headers_inflight;
+    using CommandTuple = std::tuple<uint64_t,uint8_t*,size_t>;
+    std::deque<CommandTuple> commands_inflight;
     std::deque<uint8_t *> frames_available;
-    // The frames and headers are used and returned by different
+    // The frames and commands are used and returned by different
     // threads, so we protect them with a mutex each
-    boost::mutex frame_deque_lock, header_deque_lock;
+    boost::mutex frame_deque_lock, commands_deque_lock;
 
     // The claimed size of the message currently being sent in chunks
     uint32_t msgsize_inflight;
     // The total size of the chunks so far we've sent for this message
     uint32_t chunksize_inflight;
 
-    // The static frame used to _receive_ data
+    // As above, but for incoming messages and chunks
+    uint32_t recv_msgsize_inflight;
+    uint32_t recv_chunksize_inflight;
+
+    // The static uint64_t used to receive a header
+    uint64_t receive_header;
+    // The static frame used to receive a chunk
     uint8_t receive_frame[MAXCHUNKSIZE];
 
     void send_header_data(uint64_t header, uint8_t *data, size_t len);
 
+    // Asynchronously send the first message from the command queue.
+    // * The command_deque_lock must be held when this is called! *
+    // This method may be called from either thread (the work thread or
+    // the async_write handler thread).
+    void async_send_commands();
+
 public:
     NodeIO(tcp::socket &&socket);
 
@@ -93,14 +108,18 @@ public:
     void send_message_header(uint32_t tot_message_len);
     void send_chunk(uint8_t *data, uint32_t chunk_len);
 
-    // These functions return true for success, false for failure
-    bool recv_header(uint64_t &header);
-    // This function puts the received data into a _static_ frame that's
-    // only used for receiving.  Be sure to do whatever you need to do
-    // with the contents (typically, pass it to the enclave) before
-    // calling this function again.  Pass *in* the header you got from
-    // recv_header.
-    bool recv_chunk(uint64_t header, uint8_t *&data, size_t &len);
+    // Asynchronously receive commands from this socket.  Depending on
+    // what they are, one of the three callbacks will be called.  The
+    // callbacks may be called from a different thread.  The data
+    // pointer in chunk_cb is to a _static_ frame that's only used for
+    // receiving.  Be sure to do whatever you need to do with the
+    // contents (typically, pass it to the enclave) before calling this
+    // function again.
+    void recv_commands(
+        std::function<void(boost::system::error_code)> error_cb,
+        std::function<void(uint32_t)> epoch_cb,
+        std::function<void(uint32_t)> message_cb,
+        std::function<void(uint8_t*,uint32_t)> chunk_cb);
 };
 
 class NetIO {

+ 34 - 2
App/start.cpp

@@ -6,8 +6,9 @@
 // to do on the command line
 void start(NetIO &netio, int argc, char **argv)
 {
-    srand48(1);
+    srand48(netio.me);
     // Send a bunch of data to all peers
+    for(int j=0;j<3;++j)
     for (size_t node_num = 0; node_num < netio.num_nodes; ++node_num) {
         if (node_num == netio.me) continue;
         NodeIO &node = netio.node(node_num);
@@ -16,6 +17,7 @@ void start(NetIO &netio, int argc, char **argv)
         node.send_message_header(msgsize);
 
         uint8_t c = 0;
+        uint32_t cl = 0;
         while (msgsize > 0) {
             uint8_t* frame = node.request_frame();
             uint32_t chunk_size = (lrand48() % (MAXCHUNKSIZE-1)) + 1;
@@ -23,11 +25,41 @@ void start(NetIO &netio, int argc, char **argv)
                 chunk_size = msgsize;
             }
             memset(frame, ++c, chunk_size);
+            ++cl;
+            memmove(frame, &cl, sizeof(cl));
             node.send_chunk(frame, chunk_size);
             msgsize -= chunk_size;
         }
     }
 
     printf("Sleeping\n");
-    sleep(10);
+    sleep(3);
+
+    printf("Reading\n");
+    for (size_t node_num = 0; node_num < netio.num_nodes; ++node_num) {
+        if (node_num == netio.me) continue;
+        NodeIO &node = netio.node(node_num);
+        node.recv_commands(
+            // error_cb
+            [](boost::system::error_code) {
+                printf("Error\n");
+            },
+            // epoch_cb
+            [](uint32_t epoch) {
+                printf("Epoch %u\n", epoch);
+            },
+            // message_cb
+            [](uint32_t msg_len) {
+                printf("Message len %u\n", msg_len);
+            },
+            // chunk_cb
+            [](uint8_t *data, uint32_t chunk_len) {
+                printf("Chunk len %u: ", chunk_len);
+                for (size_t i=0;i<chunk_len && i<10; ++i) {
+                    printf("%02x", data[i]);
+                }
+                printf("\n");
+            });
+    }
+
 }