瀏覽代碼

More towards private routing

Ian Goldberg 1 年之前
父節點
當前提交
68c67d8621
共有 9 個文件被更改,包括 138 次插入9 次删除
  1. 21 1
      App/net.cpp
  2. 11 0
      App/net.hpp
  3. 16 0
      App/start.cpp
  4. 7 0
      Enclave/Enclave.edl
  5. 57 0
      Enclave/route.cpp
  6. 4 5
      Enclave/sort.cpp
  7. 3 3
      Enclave/sort.hpp
  8. 16 0
      Untrusted/Untrusted.cpp
  9. 3 0
      Untrusted/Untrusted.hpp

+ 21 - 1
App/net.cpp

@@ -119,7 +119,7 @@ bool NodeIO::send_chunk(uint8_t *data, uint32_t chunk_len)
 }
 
 void NodeIO::recv_commands(
-        std::function<void(boost::system::error_code)> error_cb,
+    std::function<void(boost::system::error_code)> error_cb,
     std::function<void(uint32_t)> epoch_cb)
 {
     // Asynchronously read the header
@@ -240,6 +240,26 @@ NetIO::NetIO(boost::asio::io_context &io_context, const Config &config)
     }
 }
 
+void NetIO::recv_commands(
+    std::function<void(boost::system::error_code)> error_cb,
+    std::function<void(uint32_t)> epoch_cb)
+{
+    for (nodenum_t node_num = 0; node_num < num_nodes; ++node_num) {
+        if (node_num == me) continue;
+        NodeIO &n = node(node_num);
+        n.recv_commands(error_cb, epoch_cb);
+    }
+}
+
+void NetIO::close()
+{
+    for (nodenum_t node_num = 0; node_num < num_nodes; ++node_num) {
+        if (node_num == me) continue;
+        NodeIO &n = node(node_num);
+        n.close();
+    }
+}
+
 /* The enclave calls this to inform the untrusted app that there's a new
  * messaage to send. The return value is the frame the enclave should
  * use to store the first (encrypted) chunk of this message. */

+ 11 - 0
App/net.hpp

@@ -119,6 +119,9 @@ public:
     void recv_commands(
         std::function<void(boost::system::error_code)> error_cb,
         std::function<void(uint32_t)> epoch_cb);
+
+    // Close the socket
+    void close() { sock.close(); }
 };
 
 class NetIO {
@@ -136,6 +139,14 @@ public:
         return nodeios[node_num].value();
     }
     const Config &config() { return conf; }
+    // Call recv_commands with these arguments on each of the nodes (not
+    // including ourselves)
+    void recv_commands(
+        std::function<void(boost::system::error_code)> error_cb,
+        std::function<void(uint32_t)> epoch_cb);
+
+    // Close all the sockets
+    void close();
 };
 
 extern NetIO *g_netio;

+ 16 - 0
App/start.cpp

@@ -74,10 +74,26 @@ static void route_test(NetIO &netio, char **args)
     for (int i=0;i<int(num_sizes);++i) {
         ecall_precompute_sort(i);
     }
+
+    netio.recv_commands(
+        // error_cb
+        [](boost::system::error_code) {
+            printf("Error\n");
+        },
+        // epoch_cb
+        [](uint32_t epoch) {
+            printf("Epoch %u\n", epoch);
+        });
+
     if (!ecall_ingest_raw(msgs, tot_tokens)) {
         printf("Ingestion failed\n");
         return;
     }
+
+    ecall_routing_proceed([&](uint32_t round_num){
+        printf("Round %u complete\n", round_num);
+        //netio.close();
+    });
 }
 
 // Once all the networking is set up, start doing whatever we were asked

+ 7 - 0
Enclave/Enclave.edl

@@ -40,6 +40,9 @@ enclave {
         public bool ecall_ingest_raw(
             [user_check] uint8_t *msgs,
             uint32_t num_msgs);
+
+        public void ecall_routing_proceed(
+            [user_check]void *cbpointer);
     };
 
     untrusted {
@@ -57,5 +60,9 @@ enclave {
             nodenum_t node_num,
             [user_check] uint8_t *chunkdata,
             uint32_t chunklen);
+
+        void ocall_routing_round_complete(
+            [user_check] void *cbpointer,
+            uint32_t round_num);
     };
 };

+ 57 - 0
Enclave/route.cpp

@@ -34,6 +34,13 @@ struct MsgBuffer {
         buf = new uint8_t[size_t(msgs) * g_teems_config.msg_size];
     }
 
+    // Reset the contents of the buffer
+    void reset() {
+        memset(buf, 0, inserted * g_teems_config.msg_size);
+        reserved = 0;
+        inserted = 0;
+    }
+
     // You can't copy a MsgBuffer
     MsgBuffer(const MsgBuffer&) = delete;
     MsgBuffer &operator=(const MsgBuffer&) = delete;
@@ -212,3 +219,53 @@ bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
 
     return true;
 }
+
+// Send the round 1 messages
+static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
+    uint32_t N)
+{
+    uint16_t msg_size = g_teems_config.msg_size;
+    uint16_t tot_weight = g_teems_config.tot_weight;
+
+    /*
+    for (uint32_t i=0;i<N;++i) {
+        const uint8_t *msg = msgs + indices[i]*msg_size;
+        for (uint16_t j=0;j<msg_size/4;++j) {
+            printf("%08x ", ((const uint32_t*)msg)[j]);
+        }
+        printf("\n");
+    }
+    */
+}
+
+// Perform the next round of routing.  The callback pointer will be
+// passed to ocall_routing_round_complete when the round is complete.
+void ecall_routing_proceed(void *cbpointer)
+{
+    if (route_state.step == ROUTE_NOT_STARTED) {
+
+        MsgBuffer &round1 = route_state.round1;
+
+        pthread_mutex_lock(&round1.mutex);
+        // Ensure there are no pending messages currently being inserted
+        // into the buffer
+        while (round1.reserved != round1.inserted) {
+            pthread_mutex_unlock(&round1.mutex);
+            pthread_mutex_lock(&round1.mutex);
+        }
+        // Sort the messages we've received
+#ifdef PROFILE_ROUTING
+        uint32_t inserted = round1.inserted;
+        unsigned long start = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
+#endif
+        sort_mtobliv(g_teems_config.nthreads, round1.buf,
+            g_teems_config.msg_size, round1.inserted,
+            route_state.tot_msg_per_ing, send_round1_msgs);
+        round1.reset();
+#ifdef PROFILE_ROUTING
+        printf_with_rtclock_diff(start, "end oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
+#endif
+        route_state.step = ROUTE_ROUND_1;
+        ocall_routing_round_complete(cbpointer, 1);
+    }
+}

+ 4 - 5
Enclave/sort.cpp

@@ -74,10 +74,9 @@ void sort_precompute_evalplan(uint32_t N, threadid_t nthreads)
 // item.
 void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
     uint32_t Nr, uint32_t Na,
-    // the arguments to the callback are nthreads, items, the sorted
-    // indices, and the number of non-padding items
-    std::function<void(threadid_t, const uint8_t*, const uint64_t*,
-        uint32_t Nr)> cb)
+    // the arguments to the callback are items, the sorted indices, and
+    // the number of non-padding items
+    std::function<void(const uint8_t*, const uint64_t*, uint32_t Nr)> cb)
 {
     // Find the smallest Nw for which we have a precomputed
     // WaksmanNetwork with Nr <= Nw <= Na
@@ -150,7 +149,7 @@ void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
     for (uint32_t i=0; i<Nr; ++i) {
         sortedidx[i] &= uint64_t(0xffffffff);
     }
-    cb(nthreads, items, sortedidx, Nr);
+    cb(items, sortedidx, Nr);
 
     delete[] idx;
     delete[] backingidx;

+ 3 - 3
Enclave/sort.hpp

@@ -39,8 +39,8 @@ void sort_precompute_evalplan(uint32_t N, threadid_t nthreads);
 // id contatenated with the 22-bit uid at the storage server.
 void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
     uint32_t Nr, uint32_t Na,
-    // the arguments to the callback are nthreads, items, the sorted
-    // indices, and the number of non-padding items
-    std::function<void(threadid_t, const uint8_t*, const uint64_t*, uint32_t Nr)>);
+    // the arguments to the callback are items, the sorted indices, and
+    // the number of non-padding items
+    std::function<void(const uint8_t*, const uint64_t*, uint32_t Nr)>);
 
 #endif

+ 16 - 0
Untrusted/Untrusted.cpp

@@ -277,3 +277,19 @@ bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
     ecall_ingest_raw(global_eid, &ret, msgs, num_msgs);
     return ret;
 }
+
+void ecall_routing_proceed(std::function<void(uint32_t)> cb)
+{
+    std::function<void(uint32_t)> *p = new std::function<void(uint32_t)>;
+    *p = cb;
+    ecall_routing_proceed(global_eid, p);
+}
+
+void ocall_routing_round_complete(void *cbpointer, uint32_t round_num)
+{
+    std::function<void(uint32_t)> *p =
+        (std::function<void(uint32_t)> *)cbpointer;
+    std::function<void(uint32_t)> f = *p;
+    delete p;
+    f(round_num);
+}

+ 3 - 0
Untrusted/Untrusted.hpp

@@ -2,6 +2,7 @@
 #define __UNTRUSTED_HPP__
 
 #include <cstddef>
+#include <functional>
 
 #include "sgx_eid.h"
 #include "sgx_tseal.h"
@@ -37,4 +38,6 @@ size_t ecall_precompute_sort(int size);
 
 bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs);
 
+void ecall_routing_proceed(std::function<void(uint32_t)>);
+
 #endif