Browse Source

Forward received data into the enclave

Ian Goldberg 1 year ago
parent
commit
6b3139b84c
4 changed files with 29 additions and 37 deletions
  1. 18 17
      App/net.cpp
  2. 6 8
      App/net.hpp
  3. 0 12
      App/start.cpp
  4. 5 0
      Untrusted/Untrusted.hpp

+ 18 - 17
App/net.cpp

@@ -1,6 +1,7 @@
 #include <iostream>
 
 #include "Enclave_u.h"
+#include "Untrusted.hpp"
 #include "net.hpp"
 
 // The command type byte values
@@ -11,7 +12,8 @@
 
 NetIO *g_netio = NULL;
 
-NodeIO::NodeIO(tcp::socket &&socket) : sock(std::move(socket))
+NodeIO::NodeIO(tcp::socket &&socket, nodenum_t nodenum) :
+    sock(std::move(socket)), node_num(nodenum)
 {
 }
 
@@ -118,14 +120,12 @@ bool NodeIO::send_chunk(uint8_t *data, uint32_t chunk_len)
 
 void NodeIO::recv_commands(
         std::function<void(boost::system::error_code)> error_cb,
-    std::function<void(uint32_t)> epoch_cb,
-    std::function<void(uint32_t)> message_cb,
-    std::function<void(uint8_t*,uint32_t)> chunk_cb)
+    std::function<void(uint32_t)> epoch_cb)
 {
     // Asynchronously read the header
     receive_header = 0;
     boost::asio::async_read(sock, boost::asio::buffer(&receive_header, 5),
-        [this, error_cb, epoch_cb, message_cb, chunk_cb]
+        [this, error_cb, epoch_cb]
         (boost::system::error_code ec, std::size_t) {
             if (ec) {
                 error_cb(ec);
@@ -133,13 +133,14 @@ void NodeIO::recv_commands(
             }
             if ((receive_header & 0xff) == COMMAND_EPOCH) {
                 epoch_cb(uint32_t(receive_header >> 8));
-                recv_commands(error_cb, epoch_cb, message_cb, chunk_cb);
+                recv_commands(error_cb, epoch_cb);
             } else if ((receive_header & 0xff) == COMMAND_MESSAGE) {
                 assert(recv_msgsize_inflight == recv_chunksize_inflight);
                 recv_msgsize_inflight = uint32_t(receive_header >> 8);
                 recv_chunksize_inflight = 0;
-                message_cb(recv_msgsize_inflight);
-                recv_commands(error_cb, epoch_cb, message_cb, chunk_cb);
+                if (ecall_message(node_num, recv_msgsize_inflight)) {
+                    recv_commands(error_cb, epoch_cb);
+                }
             } else if ((receive_header & 0xff) == COMMAND_CHUNK) {
                 uint32_t this_chunk_size = uint32_t(receive_header >> 8);
                 assert(recv_chunksize_inflight + this_chunk_size <=
@@ -147,16 +148,16 @@ void NodeIO::recv_commands(
                 recv_chunksize_inflight += this_chunk_size;
                 boost::asio::async_read(sock, boost::asio::buffer(
                     receive_frame, this_chunk_size),
-                    [this, error_cb, epoch_cb, message_cb, chunk_cb,
-                        this_chunk_size]
+                    [this, error_cb, epoch_cb, this_chunk_size]
                     (boost::system::error_code ecc, std::size_t) {
                         if (ecc) {
                             error_cb(ecc);
                             return;
                         }
-                        chunk_cb(receive_frame, this_chunk_size);
-                        recv_commands(error_cb, epoch_cb,
-                            message_cb, chunk_cb);
+                        if (ecall_chunk(node_num, receive_frame,
+                                this_chunk_size)) {
+                            recv_commands(error_cb, epoch_cb);
+                        }
                     });
             } else {
                 error_cb(boost::system::errc::make_error_code(
@@ -168,7 +169,7 @@ void NodeIO::recv_commands(
 NetIO::NetIO(boost::asio::io_context &io_context, const Config &config)
     : conf(config), myconf(config.nodes[config.my_node_num])
 {
-    num_nodes = conf.nodes.size();
+    num_nodes = nodenum_t(conf.nodes.size());
     nodeios.resize(num_nodes);
     me = conf.my_node_num;
 
@@ -200,7 +201,7 @@ NetIO::NetIO(boost::asio::io_context &io_context, const Config &config)
         if (node_num >= num_nodes) {
             std::cerr << "Received bad node number\n";
         } else {
-            nodeios[node_num].emplace(std::move(nodesock));
+            nodeios[node_num].emplace(std::move(nodesock), node_num);
 #ifdef VERBOSE_NET
             std::cerr << "Received connection from " <<
                 config.nodes[node_num].name << "\n";
@@ -225,10 +226,10 @@ NetIO::NetIO(boost::asio::io_context &io_context, const Config &config)
         }
         // Write 2 bytes to the socket to tell the peer node our node
         // number
-        unsigned short node_num = (unsigned short)me;
+        nodenum_t node_num = (nodenum_t)me;
         boost::asio::write(nodesock,
             boost::asio::buffer(&node_num, sizeof(node_num)));
-        nodeios[i].emplace(std::move(nodesock));
+        nodeios[i].emplace(std::move(nodesock), node_num);
 #ifdef VERBOSE_NET
         std::cerr << "Connected to " << config.nodes[i].name << "\n";
 #endif

+ 6 - 8
App/net.hpp

@@ -69,6 +69,7 @@ using boost::asio::ip::tcp;
 
 class NodeIO {
     tcp::socket sock;
+    nodenum_t node_num;
     using CommandTuple = std::tuple<uint64_t,uint8_t*,size_t>;
     std::deque<CommandTuple> commands_inflight;
     std::deque<uint8_t *> frames_available;
@@ -99,8 +100,7 @@ class NodeIO {
     void async_send_commands();
 
 public:
-    NodeIO(tcp::socket &&socket);
-
+    NodeIO(tcp::socket &&socket, nodenum_t node_num);
     uint8_t *request_frame();
     void return_frame(uint8_t* frame);
 
@@ -119,9 +119,7 @@ public:
     // function again.
     void recv_commands(
         std::function<void(boost::system::error_code)> error_cb,
-        std::function<void(uint32_t)> epoch_cb,
-        std::function<void(uint32_t)> message_cb,
-        std::function<void(uint8_t*,uint32_t)> chunk_cb);
+        std::function<void(uint32_t)> epoch_cb);
 };
 
 class NetIO {
@@ -132,9 +130,9 @@ class NetIO {
 public:
     NetIO(boost::asio::io_context &io_context, const Config &config);
 
-    size_t num_nodes;
-    size_t me;
-    NodeIO &node(size_t node_num) {
+    nodenum_t num_nodes;
+    nodenum_t me;
+    NodeIO &node(nodenum_t node_num) {
         assert(node_num < num_nodes);
         return nodeios[node_num].value();
     }

+ 0 - 12
App/start.cpp

@@ -47,18 +47,6 @@ void start(NetIO &netio, int argc, char **argv)
             // epoch_cb
             [](uint32_t epoch) {
                 printf("Epoch %u\n", epoch);
-            },
-            // message_cb
-            [](uint32_t msg_len) {
-                printf("Message len %u\n", msg_len);
-            },
-            // chunk_cb
-            [](uint8_t *data, uint32_t chunk_len) {
-                printf("Chunk len %u: ", chunk_len);
-                for (size_t i=0;i<chunk_len && i<10; ++i) {
-                    printf("%02x", data[i]);
-                }
-                printf("\n");
             });
     }
 

+ 5 - 0
Untrusted/Untrusted.hpp

@@ -22,4 +22,9 @@ bool ecall_config_load(struct EnclaveAPIParams *apiparams,
     struct EnclaveAPINodeConfig *apinodeconfigs,
     nodenum_t num_nodes, nodenum_t my_node_num);
 
+bool ecall_message(nodenum_t node_num, uint32_t message_len);
+
+bool ecall_chunk(nodenum_t node_num, const uint8_t *chunkdata,
+    uint32_t chunklen);
+
 #endif