Sajin Sasy 1 рік тому
батько
коміт
1f36f457c3

+ 1 - 0
App/appconfig.cpp

@@ -149,6 +149,7 @@ bool config_parse(Config &config, const std::string configstr,
         std::cerr << "Could not find my own node entry in config\n";
         ret = false;
     }
+    config.nthreads = nthreads;
 
     if (!ret) return ret;
 

+ 1 - 0
App/appconfig.hpp

@@ -29,6 +29,7 @@ struct Config {
     uint8_t m_priv_in;
     uint8_t m_pub_out;
     uint8_t m_pub_in;
+    uint16_t nthreads;
     // config for each node
     std::vector<NodeConfig> nodes;
     // Which node is this one?

+ 2 - 2
App/net.cpp

@@ -218,7 +218,7 @@ void NetIO::handle_async_clients(std::shared_ptr<tcp::socket> csocket,
 */
 void NetIO::start_accept(size_t auth_size, size_t msgbundle_size)
 {
-    std::shared_ptr<tcp::socket> csocket(new tcp::socket(io_context_));
+    std::shared_ptr<tcp::socket> csocket(new tcp::socket(io_context()));
 #ifdef VERBOSE_NET
     std::cout << "Accepting on " << myconf.clistenhost << ":" << myconf.clistenport << "\n";
 #endif
@@ -229,7 +229,7 @@ void NetIO::start_accept(size_t auth_size, size_t msgbundle_size)
 
 
 NetIO::NetIO(boost::asio::io_context &io_context, const Config &config)
-    : io_context_(io_context), conf(config),
+    : context(io_context), conf(config),
       myconf(config.nodes[config.my_node_num])
 {
     num_nodes = nodenum_t(conf.nodes.size());

+ 2 - 1
App/net.hpp

@@ -125,7 +125,7 @@ public:
 };
 
 class NetIO {
-    boost::asio::io_context& io_context_;
+    boost::asio::io_context &context;
     const Config &conf;
     const NodeConfig &myconf;
     std::deque<std::optional<NodeIO>> nodeios;
@@ -146,6 +146,7 @@ public:
         return nodeios[node_num].value();
     }
     const Config &config() { return conf; }
+    boost::asio::io_context &io_context() { return context; }
     // Call recv_commands with these arguments on each of the nodes (not
     // including ourselves)
     void recv_commands(

+ 113 - 49
App/start.cpp

@@ -1,31 +1,56 @@
+#include <condition_variable>
+#include <mutex>
 #include <stdlib.h>
 
 #include "Untrusted.hpp"
 #include "start.hpp"
 
-static void route_test(NetIO &netio, char **args)
-{
-    // Count the number of arguments
-    size_t nargs = 0;
-    while (args[nargs]) {
-        ++nargs;
+class Epoch {
+    boost::asio::io_context &io_context;
+    uint32_t epoch_num;
+    std::mutex m;
+    std::condition_variable cv;
+    bool epoch_complete;
+
+    void round_cb(uint32_t round_num) {
+        if (round_num) {
+            printf("Round %u complete\n", round_num);
+            boost::asio::post(io_context, [this]{
+                proceed();
+            });
+        } else {
+            printf("Epoch %u complete\n", epoch_num);
+            {
+                std::lock_guard lk(m);
+                epoch_complete = true;
+            }
+            cv.notify_one();
+        }
     }
 
-    uint16_t num_nodes = netio.num_nodes;
-    size_t sq_nodes = num_nodes;
-    sq_nodes *= sq_nodes;
+public:
+    Epoch(boost::asio::io_context &context, uint32_t ep_num):
+        io_context(context), epoch_num(ep_num),
+        epoch_complete(false) {}
 
-    if (nargs != sq_nodes) {
-        printf("Expecting %lu arguments, found %lu\n", sq_nodes, nargs);
-        return;
+    void proceed() {
+        ecall_routing_proceed([this](uint32_t round_num) {
+            round_cb(round_num);
+        });
     }
 
-    // The arguments are num_nodes sets of num_nodes values.  The jth
-    // value in the ith set is the number of private routing tokens
-    // ingestion node i holds for storage node j.
+    void wait() {
+        std::unique_lock lk(m);
+        cv.wait(lk, [this]{ return epoch_complete; });
+    }
 
-    // We are node i = netio.me, so ignore the other sets of values.
+};
+
+static void epoch(NetIO &netio, char **args) {
 
+    static uint32_t epoch_num = 1;
+
+    uint16_t num_nodes = netio.num_nodes;
     uint32_t num_tokens[num_nodes];
     uint32_t tot_tokens = 0;
     for (nodenum_t j=0;j<num_nodes;++j) {
@@ -35,6 +60,7 @@ static void route_test(NetIO &netio, char **args)
 
     const Config &config = netio.config();
     uint16_t msg_size = config.msg_size;
+    nodenum_t my_node_num = config.my_node_num;
 
     uint8_t *msgs = new uint8_t[tot_tokens * msg_size];
     uint8_t *nextmsg = msgs;
@@ -47,7 +73,8 @@ static void route_test(NetIO &netio, char **args)
             if (r < num_tokens[j]) {
                 // Use a token from node j
                 *((uint32_t*)nextmsg) =
-                    (j << DEST_UID_BITS) + (r & dest_uid_mask);
+                    (j << DEST_UID_BITS) +
+                        (((r<<8)+(my_node_num&0xff)) & dest_uid_mask);
                 // Put a bunch of copies of r as the message body
                 for (uint16_t i=1;i<msg_size/4;++i) {
                     ((uint32_t*)nextmsg)[i] = r;
@@ -69,31 +96,83 @@ static void route_test(NetIO &netio, char **args)
     }
     */
 
-    // Precompute some WaksmanNetworks
+    if (!ecall_ingest_raw(msgs, tot_tokens)) {
+        printf("Ingestion failed\n");
+        return;
+    }
+
+    Epoch epoch(netio.io_context(), epoch_num);
+    epoch.proceed();
+    epoch.wait();
+    // Launch threads to refill the precomputed Waksman networks we
+    // used, but just let them run in the background.
     size_t num_sizes = ecall_precompute_sort(-1);
     for (int i=0;i<int(num_sizes);++i) {
-        ecall_precompute_sort(i);
+        boost::thread t([i] {
+            ecall_precompute_sort(i);
+        });
+        t.detach();
     }
+    ++epoch_num;
+}
 
-    netio.recv_commands(
-        // error_cb
-        [](boost::system::error_code) {
-            printf("Error\n");
-        },
-        // epoch_cb
-        [](uint32_t epoch) {
-            printf("Epoch %u\n", epoch);
-        });
+static void route_test(NetIO &netio, char **args)
+{
+    // Count the number of arguments
+    size_t nargs = 0;
+    while (args[nargs]) {
+        ++nargs;
+    }
 
-    if (!ecall_ingest_raw(msgs, tot_tokens)) {
-        printf("Ingestion failed\n");
+    uint16_t num_nodes = netio.num_nodes;
+    size_t sq_nodes = num_nodes;
+    sq_nodes *= sq_nodes;
+
+    if (nargs != sq_nodes) {
+        printf("Expecting %lu arguments, found %lu\n", sq_nodes, nargs);
         return;
     }
 
-    ecall_routing_proceed([&](uint32_t round_num){
-        printf("Round %u complete\n", round_num);
-        //netio.close();
-    });
+    // The arguments are num_nodes sets of num_nodes values.  The jth
+    // value in the ith set is the number of private routing tokens
+    // ingestion node i holds for storage node j.
+
+    // We are node i = netio.me, so ignore the other sets of values.
+
+    // Precompute some WaksmanNetworks
+    const Config &config = netio.config();
+    size_t num_sizes = ecall_precompute_sort(-1);
+    for (int i=0;i<int(num_sizes);++i) {
+        std::vector<boost::thread> ts;
+        for (int j=0; j<config.nthreads; ++j) {
+            ts.emplace_back([i] {
+                ecall_precompute_sort(i);
+            });
+        }
+        for (auto& t: ts) {
+            t.join();
+        }
+    }
+
+    // The epoch interval, in microseconds
+    uint32_t epoch_interval_us = 1000000;
+
+    // Run 10 epochs
+    for (int i=0; i<10; ++i) {
+        struct timespec tp;
+        clock_gettime(CLOCK_REALTIME_COARSE, &tp);
+        unsigned long start = tp.tv_sec * 1000000 + tp.tv_nsec/1000;
+        epoch(netio, args);
+        clock_gettime(CLOCK_REALTIME_COARSE, &tp);
+        unsigned long end = tp.tv_sec * 1000000 + tp.tv_nsec/1000;
+        unsigned long diff = end - start;
+        printf("Epoch time: %lu.%06lu s\n", diff/1000000, diff%1000000);
+        // Sleep for the rest of the epoch interval
+        if (diff < epoch_interval_us) {
+            usleep(epoch_interval_us - (useconds_t)diff);
+        }
+    }
+    netio.close();
 }
 
 // Once all the networking is set up, start doing whatever we were asked
@@ -105,19 +184,4 @@ void start(NetIO &netio, char **args)
         route_test(netio, args);
         return;
     }
-    printf("Reading\n");
-    for (nodenum_t node_num = 0; node_num < netio.num_nodes; ++node_num) {
-        if (node_num == netio.me) continue;
-        NodeIO &node = netio.node(node_num);
-        node.recv_commands(
-            // error_cb
-            [](boost::system::error_code) {
-                printf("Error\n");
-            },
-            // epoch_cb
-            [](uint32_t epoch) {
-                printf("Epoch %u\n", epoch);
-            });
-    }
-
 }

+ 23 - 2
App/teems.cpp

@@ -243,9 +243,30 @@ int main(int argc, char **argv)
 
     // Queue up the actual work
     boost::asio::post(io_context, [&]{
+
         // Start enclave-to-enclave communications
-        ecall_comms_start();
-        start(netio, argv);
+        ecall_comms_start([&]{
+            boost::asio::post(io_context, [&]{
+                // This runs when we have completed our handshakes with
+                // all other nodes
+                printf("Starting\n");
+                start(netio, argv);
+            });
+        });
+        printf("Reading\n");
+        for (nodenum_t node_num = 0; node_num < netio.num_nodes; ++node_num) {
+            if (node_num == netio.me) continue;
+            NodeIO &node = netio.node(node_num);
+            node.recv_commands(
+                // error_cb
+                [](boost::system::error_code ec) {
+                    printf("Error %s\n", ec.message().c_str());
+                },
+                // epoch_cb
+                [](uint32_t epoch) {
+                    printf("Epoch %u\n", epoch);
+                });
+        }
     });
 
     // Start another thread; one will perform the work and the other

+ 2 - 2
Enclave/Enclave.config.xml

@@ -3,8 +3,8 @@
   <ProdID>0</ProdID>
   <ISVSVN>0</ISVSVN>
   <StackMaxSize>0x40000</StackMaxSize>
-  <HeapMaxSize>0x8000000</HeapMaxSize>
-  <TCSNum>10</TCSNum>
+  <HeapMaxSize>0x10000000</HeapMaxSize>
+  <TCSNum>32</TCSNum>
   <TCSPolicy>1</TCSPolicy>
   <DisableDebug>0</DisableDebug>
   <MiscSelect>0</MiscSelect>

+ 5 - 1
Enclave/Enclave.edl

@@ -25,7 +25,8 @@ enclave {
 
         public void ecall_close();
 
-        public bool ecall_comms_start();
+        public bool ecall_comms_start(
+            [user_check]void *cbpointer);
 
         public bool ecall_message(
             nodenum_t node_num, uint32_t message_len);
@@ -66,6 +67,9 @@ enclave {
             [user_check] uint8_t *chunkdata,
             uint32_t chunklen);
 
+        void ocall_comms_ready(
+            [user_check] void *cbpointer);
+
         void ocall_routing_round_complete(
             [user_check] void *cbpointer,
             uint32_t round_num);

+ 1 - 2
Enclave/OblivAlgs/RecursiveShuffle.cpp

@@ -23,7 +23,6 @@ void MarkHalf(uint64_t N, bool *selected_list) {
   
   uint64_t left_to_mark = N/2;
   uint64_t total_left = N;
-  PRB_buffer *randpool = PRB_pool + g_thread_id;
   uint32_t coins[RS_MARKHALF_MAX_COINS];
   size_t coinsleft=0;
   
@@ -36,7 +35,7 @@ void MarkHalf(uint64_t N, bool *selected_list) {
         if (numcoins > RS_MARKHALF_MAX_COINS) {
             numcoins = RS_MARKHALF_MAX_COINS;
         }
-        randpool->getRandomBytes((unsigned char *) coins,
+        PRB_buf.getRandomBytes((unsigned char *) coins,
             sizeof(coins[0])*numcoins);
         coinsleft = numcoins;
     }

+ 5 - 6
Enclave/OblivAlgs/utils.cpp

@@ -5,7 +5,7 @@
 thread_local uint64_t OSWAP_COUNTER=0;
 #endif
 
-PRB_buffer* PRB_pool;
+thread_local PRB_buffer PRB_buf;
 thread_local uint64_t PRB_rand_bits = 0;
 thread_local uint32_t PRB_rand_bits_remaining = 0;
 
@@ -494,19 +494,16 @@ size_t packetsConsumedUptoMSN(signed long msn_no, size_t msns_with_extra_packets
 
 #ifdef USE_PRB
   void PRB_pool_init(int nthreads) {
-    PRB_pool = new PRB_buffer[nthreads];
+    // Nothing needs to be done any more
   }
 
   void PRB_pool_shutdown() {
-    delete [] PRB_pool;
+    // Nothing needs to be done any more
   }
 
   PRB_buffer::PRB_buffer() {
   }
 
-  PRB_buffer::~PRB_buffer() {
-  }
-
   sgx_status_t PRB_buffer::init_PRB_buffer(uint32_t buffer_size = PRB_BUFFER_SIZE) {
     sgx_status_t rt = SGX_SUCCESS;
     if(initialized==false) {
@@ -577,6 +574,7 @@ size_t packetsConsumedUptoMSN(signed long msn_no, size_t msns_with_extra_packets
     }
   }
 
+/*
   sgx_status_t PRB_buffer::getBulkRandomBytes(unsigned char *buffer, size_t size) {
     sgx_status_t rt = SGX_SUCCESS;
     rt = sgx_aes_ctr_encrypt(random_seed, (const uint8_t*) buffer, size,
@@ -622,6 +620,7 @@ size_t packetsConsumedUptoMSN(signed long msn_no, size_t msns_with_extra_packets
     }
     return rt;
   }
+*/
 #else
   sgx_status_t getRandomBytes(unsigned char *random_bytes, size_t size) {
     sgx_status_t rt = SGX_SUCCESS;

+ 2 - 4
Enclave/OblivAlgs/utils.hpp

@@ -161,7 +161,6 @@
 
       public:
         PRB_buffer();
-        ~PRB_buffer();
         sgx_status_t init_PRB_buffer(uint32_t buffer_size);
         /*  Intended for getting random bytes of size << PRB_BUFFER_SIZE at a time.
          Draws random bytes from the (typically) pre-filled random_bytes[PRB_BUFFER_SIZE] 
@@ -174,7 +173,7 @@
         */
         sgx_status_t getBulkRandomBytes(unsigned char *random_bytes, size_t size);
     };
-    extern PRB_buffer* PRB_pool;
+    extern thread_local PRB_buffer PRB_buf;
 
     // Spawn a PRB pool for each thread
     void PRB_pool_init(int nthreads);
@@ -183,8 +182,7 @@
     
     inline sgx_status_t getRandomBytes(unsigned char *random_bytes, size_t size) {
       FOAV_SAFE_CNTXT(PRB, size)
-      FOAV_SAFE_CNTXT(PRB, g_thread_id)
-      return((PRB_pool[g_thread_id]).getRandomBytes(random_bytes, size));
+      return(PRB_buf.getRandomBytes(random_bytes, size));
     }
 
     // Return a random bit

+ 102 - 137
Enclave/comms.cpp

@@ -1,118 +1,82 @@
 #include <vector>
 #include <functional>
 #include <cstring>
+#include <pthread.h>
 
 #include "sgx_tcrypto.h"
 #include "sgx_tseal.h"
 #include "Enclave_t.h"
 #include "utils.hpp"
 #include "config.hpp"
+#include "route.hpp"
+#include "comms.hpp"
 
 // Our public and private identity keys
 static sgx_ec256_private_t g_privkey;
 static sgx_ec256_public_t g_pubkey;
 
-// What step of the handshake are we on?
-enum HandshakeStep {
-    HANDSHAKE_NONE,
-    HANDSHAKE_C_SENT_1,
-    HANDSHAKE_S_SENT_2,
-    HANDSHAKE_COMPLETE
-};
+// The communication states for all the nodes.  There's an entry for
+// ourselves in here, but it is unused.
+std::vector<NodeCommState> g_commstates;
 
-// Communication state for a node
-struct NodeCommState {
-    sgx_ec256_public_t pubkey;
-    nodenum_t node_num;
-    HandshakeStep handshake_step;
-
-    // Our DH keypair during the handshake
-    sgx_ec256_private_t handshake_dh_privkey;
-    sgx_ec256_public_t handshake_dh_pubkey;
-
-    // The server keeps this state between handshake messages 1 and 3
-    uint8_t handshake_cli_srv_mac[16];
-
-    // The outgoing and incoming AES keys after the handshake
-    sgx_aes_gcm_128bit_key_t out_aes_key, in_aes_key;
-
-    // The outgoing and incoming IV counters
-    uint8_t out_aes_iv[SGX_AESGCM_IV_SIZE];
-    uint8_t in_aes_iv[SGX_AESGCM_IV_SIZE];
-
-    // The GCM state for incrementally building each outgoing chunk
-    sgx_aes_state_handle_t out_aes_gcm_state;
-
-    // The current outgoing frame and the current offset into it
-    uint8_t *frame;
-    uint32_t frame_offset;
-
-    // The current outgoing message ciphertext size and the offset into
-    // it of the start of the current frame
-    uint32_t msg_size;
-    uint32_t msg_frame_offset;
-
-    // The current outgoing message plaintext size, how many plaintext
-    // bytes we've already processed with message_data, and how many
-    // plaintext bytes remain for the current chunk
-    uint32_t msg_plaintext_size;
-    uint32_t msg_plaintext_processed;
-    uint32_t msg_plaintext_chunk_remain;
-
-    // The current incoming message ciphertext size and the offset into
-    // it of all previous chunks of this message
-    uint32_t in_msg_size;
-    uint32_t in_msg_offset;
-    // The current incoming message number of plaintext bytes processed
-    uint32_t in_msg_plaintext_processed;
-    // The internal buffer where we're storing the (decrypted) message
-    uint8_t *in_msg_buf;
-
-    // The function to call when a new incoming message header arrives.
-    // This function should return a pointer to enough memory to hold
-    // the (decrypted) chunks of the message.  Remember that the length
-    // passed here is the total size of the _encrypted_ chunks.  This
-    // function should not itself modify the in_msg_size, in_msg_offset,
-    // or in_msg_buf members.  This function will usually allocate an
-    // appropriate amount of memory and return the pointer to it, but
-    // may do other things, like return a pointer to the middle of a
-    // previously allocated region of memory.
-    std::function<uint8_t*(NodeCommState&,uint32_t)> in_msg_get_buf;
-
-    // The function to call after the last chunk of a message has been
-    // received.  If in_msg_get_buf allocated memory, this function
-    // should deallocate it.  in_msg_size, in_msg_offset,
-    // in_msg_plaintext_processed, and in_msg_buf will already have been
-    // reset when this function is called.  The uint32_t that is passed
-    // are the total size of the _decrypted_ data and the original total
-    // size of the _encrypted_ chunks that was passed to in_msg_get_buf.
-    std::function<void(NodeCommState&,uint8_t*,uint32_t,uint32_t)>
-        in_msg_received;
-
-    NodeCommState(const sgx_ec256_public_t* conf_pubkey, nodenum_t i) :
-            node_num(i), handshake_step(HANDSHAKE_NONE),
-            out_aes_gcm_state(NULL), frame(NULL),
-            frame_offset(0), msg_size(0), msg_frame_offset(0),
-            msg_plaintext_size(0), msg_plaintext_processed(0),
-            msg_plaintext_chunk_remain(0),
-            in_msg_size(0), in_msg_offset(0),
-            in_msg_plaintext_processed(0), in_msg_buf(NULL),
-            in_msg_get_buf(NULL), in_msg_received(NULL) {
-        memmove(&pubkey, conf_pubkey, sizeof(pubkey));
-    }
-
-    void message_start(uint32_t plaintext_len, bool encrypt=true);
-
-    void message_data(uint8_t *data, uint32_t len, bool encrypt=true);
-
-    // Start the handshake (as the client)
-    void handshake_start();
-};
+static nodenum_t tot_nodes, my_node_num;
+static class CompletedHandshakeCounter {
+    // Mutex around completed_handshakes
+    pthread_mutex_t mutex;
+    // The number of completed handshakes
+    nodenum_t completed_handshakes;
+    // The callback pointer to use when all handshakes complete
+    void *complete_handshake_cbpointer;
+
+public:
+    CompletedHandshakeCounter() {
+        pthread_mutex_init(&mutex, NULL);
+        completed_handshakes = 0;
+        complete_handshake_cbpointer = NULL;
+        if (tot_nodes == 1) {
+            // There's no one to handshake with, so we're already done
+            pthread_mutex_lock(&mutex);
+            void *cbpointer = complete_handshake_cbpointer;
+            complete_handshake_cbpointer = NULL;
+            pthread_mutex_unlock(&mutex);
+            ocall_comms_ready(cbpointer);
+        }
+    }
+
+    void reset(void *cbpointer) {
+        pthread_mutex_lock(&mutex);
+        completed_handshakes = 0;
+        complete_handshake_cbpointer = cbpointer;
+        if (tot_nodes == 1) {
+            // There's no one to handshake with, so we're already done
+            complete_handshake_cbpointer = NULL;
+            pthread_mutex_unlock(&mutex);
+            ocall_comms_ready(cbpointer);
+        } else {
+            pthread_mutex_unlock(&mutex);
+        }
+    }
+
+    void inc() {
+        pthread_mutex_lock(&mutex);
+        ++completed_handshakes;
+        nodenum_t num_completed = completed_handshakes;
+        pthread_mutex_unlock(&mutex);
+        if (num_completed == tot_nodes - 1) {
+            pthread_mutex_lock(&mutex);
+            void *cbpointer = complete_handshake_cbpointer;
+            complete_handshake_cbpointer = NULL;
+            completed_handshakes = 0;
+            pthread_mutex_unlock(&mutex);
+            ocall_comms_ready(cbpointer);
+        }
+    }
+} completed_handshake_counter;
 
 // A typical default in_msg_get_buf handler.  It computes the maximum
 // possible size of the decrypted data, allocates that much memory, and
 // returns a pointer to it.
-static uint8_t* default_in_msg_get_buf(NodeCommState &commst,
+uint8_t* default_in_msg_get_buf(NodeCommState &commst,
     uint32_t tot_enc_chunk_size)
 {
     uint32_t max_plaintext_bytes = tot_enc_chunk_size;
@@ -133,10 +97,12 @@ static uint8_t* default_in_msg_get_buf(NodeCommState &commst,
     return new uint8_t[max_plaintext_bytes];
 }
 
-static void default_in_msg_received(NodeCommState &nodest,
+// An in_msg_received handler when we don't actually expect a message
+// from a given node at a given time.
+void unknown_in_msg_received(NodeCommState &nodest,
     uint8_t *data, uint32_t plaintext_len, uint32_t)
 {
-    printf("Received message of %u bytes from node %lu:\n",
+    printf("Received unknown message of %u bytes from node %lu:\n",
         plaintext_len, nodest.node_num);
     for (uint32_t i=0;i<plaintext_len;++i) {
         printf("%02x", data[i]);
@@ -263,13 +229,13 @@ static void handshake_1_msg_received(NodeCommState &nodest,
     // Send handshake message 2
     nodest.message_start(sizeof(our_dh_pubkey) + sizeof(srv_cli_sig),
         false);
-    nodest.message_data((uint8_t*)&our_dh_pubkey, sizeof(our_dh_pubkey),
+    nodest.message_data((const uint8_t*)&our_dh_pubkey, sizeof(our_dh_pubkey),
         false);
-    nodest.message_data((uint8_t*)&srv_cli_sig, sizeof(srv_cli_sig),
+    nodest.message_data((const uint8_t*)&srv_cli_sig, sizeof(srv_cli_sig),
         false);
 }
 
-// Receive (at the client) the secong handshake message
+// Receive (at the client) the second handshake message
 static void handshake_2_msg_received(NodeCommState &nodest,
     uint8_t *data, uint32_t plaintext_len, uint32_t)
 {
@@ -382,18 +348,18 @@ static void handshake_2_msg_received(NodeCommState &nodest,
     memset(&nodest.in_aes_iv, 0, SGX_AESGCM_IV_SIZE);
     nodest.handshake_step = HANDSHAKE_COMPLETE;
     nodest.in_msg_get_buf = default_in_msg_get_buf;
-    nodest.in_msg_received = default_in_msg_received;
+    nodest.in_msg_received = unknown_in_msg_received;
 
     // Send handshake message 3
     nodest.message_start(sizeof(cli_srv_sig), false);
-    nodest.message_data((uint8_t*)&cli_srv_sig, sizeof(cli_srv_sig),
+    nodest.message_data((const uint8_t*)&cli_srv_sig, sizeof(cli_srv_sig),
         false);
 
-    // Send a test message
-    nodest.message_start(12);
-    unsigned char buf[13];
-    memmove(buf, "Hello, world", 13);
-    nodest.message_data(buf, 12);
+    // Set the received message handler for routing
+    route_init_msg_handler(nodest.node_num);
+
+    // Mark the handshake as complete
+    completed_handshake_counter.inc();
 }
 
 static void handshake_3_msg_received(NodeCommState &nodest,
@@ -436,13 +402,13 @@ static void handshake_3_msg_received(NodeCommState &nodest,
     memset(&nodest.in_aes_iv, 0, SGX_AESGCM_IV_SIZE);
     nodest.handshake_step = HANDSHAKE_COMPLETE;
     nodest.in_msg_get_buf = default_in_msg_get_buf;
-    nodest.in_msg_received = default_in_msg_received;
+    nodest.in_msg_received = unknown_in_msg_received;
+
+    // Set the received message handler for routing
+    route_init_msg_handler(nodest.node_num);
 
-    // Send a test message
-    nodest.message_start(12);
-    unsigned char buf[13];
-    memmove(buf, "Hello, world", 13);
-    nodest.message_data(buf, 12);
+    // Mark the handshake as complete
+    completed_handshake_counter.inc();
 }
 
 // Start a new outgoing message.  Pass the number of _plaintext_ bytes
@@ -484,7 +450,7 @@ void NodeCommState::message_start(uint32_t plaintext_len, bool encrypt)
 }
 
 // Process len bytes of plaintext data into the current message.
-void NodeCommState::message_data(uint8_t *data, uint32_t len, bool encrypt)
+void NodeCommState::message_data(const uint8_t *data, uint32_t len, bool encrypt)
 {
     while (len > 0) {
         if (msg_plaintext_chunk_remain == 0) {
@@ -501,7 +467,7 @@ void NodeCommState::message_data(uint8_t *data, uint32_t len, bool encrypt)
         }
         if (encrypt) {
             // Encrypt the data
-            sgx_aes_gcm128_enc_update(data, bytes_to_process,
+            sgx_aes_gcm128_enc_update((uint8_t*)data, bytes_to_process,
                 frame+frame_offset, out_aes_gcm_state);
         } else {
             // Just copy the plaintext data during the handshake
@@ -543,11 +509,6 @@ void NodeCommState::message_data(uint8_t *data, uint32_t len, bool encrypt)
     }
 }
 
-// The communication states for all the nodes.  There's an entry for
-// ourselves in here, but it is unused.
-static std::vector<NodeCommState> commstates;
-static nodenum_t tot_nodes, my_node_num;
-
 // Generate a new identity signature key.  Output the public key and the
 // sealed private key.  outsealedpriv must point to SEALEDPRIVKEY_SIZE =
 // sizeof(sgx_sealed_data_t) + sizeof(sgx_ec256_private_t) + 18 bytes of
@@ -617,9 +578,9 @@ bool comms_init_nodestate(const EnclaveAPINodeConfig *apinodeconfigs,
     sgx_ecc_state_handle_t ecc_handle;
     sgx_ecc256_open_context(&ecc_handle);
 
-    commstates.clear();
+    g_commstates.clear();
     tot_nodes = 0;
-    commstates.reserve(num_nodes);
+    g_commstates.reserve(num_nodes);
     for (nodenum_t i=0; i<num_nodes; ++i) {
         // Check that the pubkey is valid
         int valid;
@@ -627,11 +588,11 @@ bool comms_init_nodestate(const EnclaveAPINodeConfig *apinodeconfigs,
                 ecc_handle, &valid) ||
                 !valid) {
             printf("Pubkey for node %hu invalid\n", i);
-            commstates.clear();
+            g_commstates.clear();
             sgx_ecc256_close_context(ecc_handle);
             return false;
         }
-        commstates.emplace_back(&apinodeconfigs[i].pubkey, i);
+        g_commstates.emplace_back(&apinodeconfigs[i].pubkey, i);
     }
     sgx_ecc256_close_context(ecc_handle);
 
@@ -641,12 +602,12 @@ bool comms_init_nodestate(const EnclaveAPINodeConfig *apinodeconfigs,
     // reflection attacks)
     for (nodenum_t i=0; i<num_nodes; ++i) {
         if (i == my_node_num) continue;
-        if (!memcmp(&commstates[i].pubkey,
-                &commstates[my_node_num].pubkey,
-                sizeof(commstates[i].pubkey))) {
+        if (!memcmp(&g_commstates[i].pubkey,
+                &g_commstates[my_node_num].pubkey,
+                sizeof(g_commstates[i].pubkey))) {
             printf("Pubkey %hu matches our own; possible reflection attack?\n",
                 i);
-            commstates.clear();
+            g_commstates.clear();
             return false;
         }
     }
@@ -659,8 +620,8 @@ bool comms_init_nodestate(const EnclaveAPINodeConfig *apinodeconfigs,
     // to receive the first handshake message from those nodes'
     // enclaves.
     for (nodenum_t i=0; i<my_node_num; ++i) {
-        commstates[i].in_msg_get_buf = default_in_msg_get_buf;
-        commstates[i].in_msg_received = handshake_1_msg_received;
+        g_commstates[i].in_msg_get_buf = default_in_msg_get_buf;
+        g_commstates[i].in_msg_received = handshake_1_msg_received;
     }
     return true;
 }
@@ -672,7 +633,7 @@ bool ecall_message(nodenum_t node_num, uint32_t message_len)
             node_num);
         return false;
     }
-    NodeCommState &nodest = commstates[node_num];
+    NodeCommState &nodest = g_commstates[node_num];
 
     if (nodest.in_msg_size != nodest.in_msg_offset) {
         printf("Received ecall_message without completing previous message\n");
@@ -702,7 +663,7 @@ bool ecall_chunk(nodenum_t node_num, const uint8_t *chunkdata,
             node_num);
         return false;
     }
-    NodeCommState &nodest = commstates[node_num];
+    NodeCommState &nodest = g_commstates[node_num];
 
     if (nodest.in_msg_size == nodest.in_msg_offset) {
         printf("Received ecall_chunk after completing message\n");
@@ -776,15 +737,19 @@ void NodeCommState::handshake_start()
     // Send the public key as the first message
     message_start(sizeof(handshake_dh_pubkey), false);
 
-    message_data((uint8_t*)&handshake_dh_pubkey,
+    message_data((const uint8_t*)&handshake_dh_pubkey,
         sizeof(handshake_dh_pubkey), false);
 }
 
-// Start all handshakes for which we are the client
-bool ecall_comms_start()
+// Start all handshakes for which we are the client.  Call
+// ocall_comms_ready(cbpointer) when the handshakes with all other nodes
+// (for which we are client or server) are complete.
+bool ecall_comms_start(void *cbpointer)
 {
+    completed_handshake_counter.reset(cbpointer);
+
     for (nodenum_t t = my_node_num+1; t<tot_nodes; ++t) {
-        commstates[t].handshake_start();
+        g_commstates[t].handshake_start();
     }
     return true;
 }

+ 114 - 0
Enclave/comms.hpp

@@ -1,8 +1,122 @@
 #ifndef __COMMS_HPP__
 #define __COMMS_HPP__
 
+#include <vector>
+
 #include "enclave_api.h"
 
+// What step of the handshake are we on?
+enum HandshakeStep {
+    HANDSHAKE_NONE,
+    HANDSHAKE_C_SENT_1,
+    HANDSHAKE_S_SENT_2,
+    HANDSHAKE_COMPLETE
+};
+
+// Communication state for a node
+struct NodeCommState {
+    sgx_ec256_public_t pubkey;
+    nodenum_t node_num;
+    HandshakeStep handshake_step;
+
+    // Our DH keypair during the handshake
+    sgx_ec256_private_t handshake_dh_privkey;
+    sgx_ec256_public_t handshake_dh_pubkey;
+
+    // The server keeps this state between handshake messages 1 and 3
+    uint8_t handshake_cli_srv_mac[16];
+
+    // The outgoing and incoming AES keys after the handshake
+    sgx_aes_gcm_128bit_key_t out_aes_key, in_aes_key;
+
+    // The outgoing and incoming IV counters
+    uint8_t out_aes_iv[SGX_AESGCM_IV_SIZE];
+    uint8_t in_aes_iv[SGX_AESGCM_IV_SIZE];
+
+    // The GCM state for incrementally building each outgoing chunk
+    sgx_aes_state_handle_t out_aes_gcm_state;
+
+    // The current outgoing frame and the current offset into it
+    uint8_t *frame;
+    uint32_t frame_offset;
+
+    // The current outgoing message ciphertext size and the offset into
+    // it of the start of the current frame
+    uint32_t msg_size;
+    uint32_t msg_frame_offset;
+
+    // The current outgoing message plaintext size, how many plaintext
+    // bytes we've already processed with message_data, and how many
+    // plaintext bytes remain for the current chunk
+    uint32_t msg_plaintext_size;
+    uint32_t msg_plaintext_processed;
+    uint32_t msg_plaintext_chunk_remain;
+
+    // The current incoming message ciphertext size and the offset into
+    // it of all previous chunks of this message
+    uint32_t in_msg_size;
+    uint32_t in_msg_offset;
+    // The current incoming message number of plaintext bytes processed
+    uint32_t in_msg_plaintext_processed;
+    // The internal buffer where we're storing the (decrypted) message
+    uint8_t *in_msg_buf;
+
+    // The function to call when a new incoming message header arrives.
+    // This function should return a pointer to enough memory to hold
+    // the (decrypted) chunks of the message.  Remember that the length
+    // passed here is the total size of the _encrypted_ chunks.  This
+    // function should not itself modify the in_msg_size, in_msg_offset,
+    // or in_msg_buf members.  This function will usually allocate an
+    // appropriate amount of memory and return the pointer to it, but
+    // may do other things, like return a pointer to the middle of a
+    // previously allocated region of memory.
+    std::function<uint8_t*(NodeCommState&,uint32_t)> in_msg_get_buf;
+
+    // The function to call after the last chunk of a message has been
+    // received.  If in_msg_get_buf allocated memory, this function
+    // should deallocate it.  in_msg_size, in_msg_offset,
+    // in_msg_plaintext_processed, and in_msg_buf will already have been
+    // reset when this function is called.  The uint32_t that is passed
+    // are the total size of the _decrypted_ data and the original total
+    // size of the _encrypted_ chunks that was passed to in_msg_get_buf.
+    std::function<void(NodeCommState&,uint8_t*,uint32_t,uint32_t)>
+        in_msg_received;
+
+    NodeCommState(const sgx_ec256_public_t* conf_pubkey, nodenum_t i) :
+            node_num(i), handshake_step(HANDSHAKE_NONE),
+            out_aes_gcm_state(NULL), frame(NULL),
+            frame_offset(0), msg_size(0), msg_frame_offset(0),
+            msg_plaintext_size(0), msg_plaintext_processed(0),
+            msg_plaintext_chunk_remain(0),
+            in_msg_size(0), in_msg_offset(0),
+            in_msg_plaintext_processed(0), in_msg_buf(NULL),
+            in_msg_get_buf(NULL), in_msg_received(NULL) {
+        memmove(&pubkey, conf_pubkey, sizeof(pubkey));
+    }
+
+    void message_start(uint32_t plaintext_len, bool encrypt=true);
+
+    void message_data(const uint8_t *data, uint32_t len, bool encrypt=true);
+
+    // Start the handshake (as the client)
+    void handshake_start();
+};
+
+// The communication states for all the nodes.  There's an entry for
+// ourselves in here, but it is unused.
+extern std::vector<NodeCommState> g_commstates;
+
+// A typical default in_msg_get_buf handler.  It computes the maximum
+// possible size of the decrypted data, allocates that much memory, and
+// returns a pointer to it.
+uint8_t* default_in_msg_get_buf(NodeCommState &commst,
+    uint32_t tot_enc_chunk_size);
+
+// An in_msg_received handler when we don't actually expect a message
+// from a given node at a given time.
+void unknown_in_msg_received(NodeCommState &nodest,
+    uint8_t *data, uint32_t plaintext_len, uint32_t);
+
 // The enclave-to-enclave communication protocol is as follows.  It
 // probably could just be attested TLS in a production environment, but
 // we're not implementing remote attestation at this time.  This means

+ 6 - 3
Enclave/config.cpp

@@ -81,10 +81,12 @@ bool ecall_config_load(threadid_t nthreads, bool private_routing,
     g_teems_config.ingestion_nodes.clear();
     g_teems_config.routing_nodes.clear();
     g_teems_config.storage_nodes.clear();
+    g_teems_config.storage_map.clear();
     for (nodenum_t i=0; i<num_nodes; ++i) {
         NodeWeight nw;
         nw.startweight = cumul_weight;
-        nw.weight = apinodeconfigs[i].weight;
+        // Weights only matter for routing nodes
+        nw.weight = 0;
         if (apinodeconfigs[i].roles & ROLE_INGESTION) {
             g_teems_config.num_ingestion_nodes += 1;
             if (i < my_node_num) {
@@ -94,6 +96,7 @@ bool ecall_config_load(threadid_t nthreads, bool private_routing,
             }
         }
         if (apinodeconfigs[i].roles & ROLE_ROUTING) {
+            nw.weight = apinodeconfigs[i].weight;
             g_teems_config.num_routing_nodes += 1;
             if (i < my_node_num) {
                 rte_smaller.push_back(i);
@@ -108,9 +111,11 @@ bool ecall_config_load(threadid_t nthreads, bool private_routing,
             } else {
                 g_teems_config.storage_nodes.push_back(i);
             }
+            g_teems_config.storage_map.push_back(i);
         }
         cumul_weight += nw.weight;
         g_teems_config.weights.push_back(nw);
+        g_teems_config.roles.push_back(apinodeconfigs[i].roles);
         if (i == my_node_num) {
             g_teems_config.my_weight = nw.weight;
         }
@@ -136,7 +141,6 @@ bool ecall_config_load(threadid_t nthreads, bool private_routing,
 
     // Initialize the threadpool and the pseudorandom bytes pools
     threadpool_init(nthreads);
-    PRB_pool_init(nthreads);
 
     if(apinodeconfigs[my_node_num].roles & ROLE_INGESTION) {
         sgx_aes_gcm_128bit_key_t ESK, TSK;
@@ -184,6 +188,5 @@ bool ecall_config_load(threadid_t nthreads, bool private_routing,
 
 void ecall_close()
 {
-    PRB_pool_shutdown();
     threadpool_shutdown();
 }

+ 4 - 0
Enclave/config.hpp

@@ -30,11 +30,15 @@ struct Config {
     uint8_t m_pub_in;
     uint8_t my_weight;
     bool private_routing;
+    std::vector<uint8_t> roles;
     std::vector<NodeWeight> weights;
     std::vector<nodenum_t> ingestion_nodes;
     std::vector<nodenum_t> routing_nodes;
     std::vector<nodenum_t> storage_nodes;
     sgx_aes_gcm_128bit_key_t master_secret;
+    // storage_map[i] is the node number of the storage node responsible
+    // for the destination adddresses with storage node field i.
+    std::vector<nodenum_t> storage_map;
 };
 
 extern Config g_teems_config;

+ 71 - 0
Enclave/obliv.cpp

@@ -0,0 +1,71 @@
+#include "oasm_lib.h"
+#include "enclave_api.h"
+#include "obliv.hpp"
+
+// Routines for processing private data obliviously
+
+// Obliviously tally the number of messages in the given buffer destined
+// for each storage node.  Each message is of size msg_size bytes.
+// There are num_msgs messages in the buffer.  There are
+// num_storage_nodes storage nodes in total.  The destination storage
+// node of each message is determined by looking at the top
+// DEST_STORAGE_NODE_BITS bits of the (little-endian) 32-bit word at the
+// beginning of the message; this will be a number between 0 and
+// num_storage_nodes-1, which is not necessarily the node number of the
+// storage node, which may be larger if, for example, there are a bunch
+// of routing or ingestion nodes that are not also storage nodes.  The
+// return value is a vector of length num_storage_nodes containing the
+// tally.
+std::vector<uint32_t> obliv_tally_stg(const uint8_t *buf,
+    uint32_t msg_size, uint32_t num_msgs, uint32_t num_storage_nodes)
+{
+    // The _contents_ of buf are private, but everything else in the
+    // input is public.  The contents of the output tally (but not its
+    // length) are also private.
+    std::vector<uint32_t> tally(num_storage_nodes, 0);
+
+    // This part must all be oblivious except for the length checks on
+    // num_msgs and num_storage_nodes
+    while (num_msgs) {
+        uint32_t storage_node_id = (*(const uint32_t*)buf) >> DEST_UID_BITS;
+        for (uint32_t i=0; i<num_storage_nodes; ++i) {
+            tally[i] += (storage_node_id == i);
+        }
+        buf += msg_size;
+        --num_msgs;
+    }
+
+    return tally;
+}
+
+// Obliviously create padding messages destined for the various storage
+// nodes, using the (private) counts in the tally vector.  The tally
+// vector may be modified by this function.  tot_padding must be the sum
+// of the elements in tally, which need _not_ be private.
+void obliv_pad_stg(uint8_t *buf, uint32_t msg_size,
+    std::vector<uint32_t> &tally, uint32_t tot_padding)
+{
+    // A value with 0 in the top DEST_STORAGE_NODE_BITS and all 1s in
+    // the bottom DEST_UID_BITS.
+    uint32_t pad_user = (1<<DEST_UID_BITS)-1;
+
+    // This value is not oblivious
+    const uint32_t num_storage_nodes = uint32_t(tally.size());
+
+    // This part must all be oblivious except for the length checks on
+    // tot_padding and num_storage_nodes
+    while (tot_padding) {
+        bool found = false;
+        uint32_t found_node = 0;
+        for (uint32_t i=0; i<num_storage_nodes; ++i) {
+            bool found_here = (!found) & (!!tally[i]);
+            found_node = oselect_uint32_t(found_node, i, found_here);
+            found = found | found_here;
+            tally[i] -= found_here;
+        }
+        *(uint32_t*)buf = ((found_node<<DEST_UID_BITS) | pad_user);
+
+        buf += msg_size;
+        --tot_padding;
+    }
+}

+ 30 - 0
Enclave/obliv.hpp

@@ -0,0 +1,30 @@
+#ifndef __OBLIV_HPP__
+#define __OBLIV_HPP__
+
+#include <vector>
+
+// Routines for processing private data obliviously
+
+// Obliviously tally the number of messages in the given buffer destined
+// for each storage node.  Each message is of size msg_size bytes.
+// There are num_msgs messages in the buffer.  There are
+// num_storage_nodes storage nodes in total.  The destination storage
+// node of each message is determined by looking at the top
+// DEST_STORAGE_NODE_BITS bits of the (little-endian) 32-bit word at the
+// beginning of the message; this will be a number between 0 and
+// num_storage_nodes-1, which is not necessarily the node number of the
+// storage node, which may be larger if, for example, there are a bunch
+// of routing or ingestion nodes that are not also storage nodes.  The
+// return value is a vector of length num_storage_nodes containing the
+// tally.
+std::vector<uint32_t> obliv_tally_stg(const uint8_t *buf,
+    uint32_t msg_size, uint32_t num_msgs, uint32_t num_storage_nodes);
+
+// Obliviously create padding messages destined for the various storage
+// nodes, using the (private) counts in the tally vector.  The tally
+// vector may be modified by this function.  tot_padding must be the sum
+// of the elements in tally, which need _not_ be private.
+void obliv_pad_stg(uint8_t *buf, uint32_t msg_size,
+    std::vector<uint32_t> &tally, uint32_t tot_padding);
+
+#endif

+ 509 - 49
Enclave/route.cpp

@@ -2,6 +2,9 @@
 #include "config.hpp"
 #include "utils.hpp"
 #include "sort.hpp"
+#include "comms.hpp"
+#include "obliv.hpp"
+#include "storage.hpp"
 #include "route.hpp"
 
 #define PROFILE_ROUTING
@@ -12,18 +15,22 @@ enum RouteStep {
     ROUTE_ROUND_2
 };
 
-// The round1 MsgBuffer stores messages we ingest while waiting for
-// round 1 to start, which will be sorted and sent out in round 1.  The
-// round2 MsgBuffer stores messages we receive in round 1, which will be
-// padded, sorted, and sent out in round 2.
+// The ingbuf MsgBuffer stores messages an ingestion node ingests while
+// waiting for round 1 to start, which will be sorted and sent out in
+// round 1.  The round1 MsgBuffer stores messages a routing node
+// receives in round 1, which will be padded, sorted, and sent out in
+// round 2.  The round2 MsgBuffer stores messages a storage node
+// receives in round 2.
 
 static struct RouteState {
+    MsgBuffer ingbuf;
     MsgBuffer round1;
     MsgBuffer round2;
     RouteStep step;
     uint32_t tot_msg_per_ing;
-    uint32_t max_msg_to_each_str;
+    uint32_t max_msg_to_each_stg;
     uint32_t max_round2_msgs;
+    void *cbpointer;
 } route_state;
 
 // Computes ceil(x/y) where x and y are integers, x>=0, y>0.
@@ -58,25 +65,25 @@ bool route_init()
     // Compute the maximum number of messages we could send in round 2
 
     // Each storage node has at most this many users
-    uint32_t users_per_str = CEILDIV(g_teems_config.user_count,
+    uint32_t users_per_stg = CEILDIV(g_teems_config.user_count,
         g_teems_config.num_storage_nodes);
 
     // And so can receive at most this many messages
-    uint32_t tot_msg_per_str = users_per_str *
+    uint32_t tot_msg_per_stg = users_per_stg *
         g_teems_config.m_priv_in;
 
     // Which will be at most this many from us
-    uint32_t max_msg_to_each_str = CEILDIV(tot_msg_per_str,
+    uint32_t max_msg_to_each_stg = CEILDIV(tot_msg_per_stg,
         g_teems_config.tot_weight) * g_teems_config.my_weight;
 
     // But we can't send more messages to each storage server than we
     // could receive in total
-    if (max_msg_to_each_str > max_round1_msgs) {
-        max_msg_to_each_str = max_round1_msgs;
+    if (max_msg_to_each_stg > max_round1_msgs) {
+        max_msg_to_each_stg = max_round1_msgs;
     }
 
     // And the max total number of outgoing messages in round 2 is then
-    uint32_t max_round2_msgs = max_msg_to_each_str *
+    uint32_t max_round2_msgs = max_msg_to_each_stg *
         g_teems_config.num_storage_nodes;
 
     // In case we have a weird configuration where users can send more
@@ -92,17 +99,27 @@ bool route_init()
     */
 
     // Create the route state
+    uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
     try {
-        route_state.round1.alloc(tot_msg_per_ing);
-        route_state.round2.alloc(max_round2_msgs);
+        if (my_roles & ROLE_INGESTION) {
+            route_state.ingbuf.alloc(tot_msg_per_ing);
+        }
+        if (my_roles & ROLE_ROUTING) {
+            route_state.round1.alloc(max_round2_msgs);
+        }
+        if (my_roles & ROLE_STORAGE) {
+            route_state.round2.alloc(tot_msg_per_stg +
+                g_teems_config.tot_weight);
+        }
     } catch (std::bad_alloc&) {
         printf("Memory allocation failed in route_init\n");
         return false;
     }
     route_state.step = ROUTE_NOT_STARTED;
     route_state.tot_msg_per_ing = tot_msg_per_ing;
-    route_state.max_msg_to_each_str = max_msg_to_each_str;
+    route_state.max_msg_to_each_stg = max_msg_to_each_stg;
     route_state.max_round2_msgs = max_round2_msgs;
+    route_state.cbpointer = NULL;
 
     threadid_t nthreads = g_teems_config.nthreads;
 #ifdef PROFILE_ROUTING
@@ -152,80 +169,523 @@ size_t ecall_precompute_sort(int sizeidx)
     return ret;
 }
 
-// Directly ingest a buffer of num_msgs messages into the round1 buffer.
+static uint8_t* msgbuffer_get_buf(MsgBuffer &msgbuf,
+        NodeCommState &, uint32_t tot_enc_chunk_size)
+{
+    uint16_t msg_size = g_teems_config.msg_size;
+
+    // Chunks will be encrypted and have a MAC tag attached which will
+    // not correspond to plaintext bytes, so we can trim them.
+
+    // The minimum number of chunks needed to transmit this message
+    uint32_t min_num_chunks =
+        (tot_enc_chunk_size + (FRAME_SIZE-1)) / FRAME_SIZE;
+    // The number of plaintext bytes this message could contain
+    uint32_t plaintext_bytes = tot_enc_chunk_size -
+        SGX_AESGCM_MAC_SIZE * min_num_chunks;
+
+    assert ((plaintext_bytes % uint32_t(msg_size)) == 0);
+
+    uint32_t num_msgs = plaintext_bytes/uint32_t(msg_size);
+
+    pthread_mutex_lock(&msgbuf.mutex);
+    uint32_t start = msgbuf.reserved;
+    if (start + num_msgs > msgbuf.bufsize) {
+        pthread_mutex_unlock(&msgbuf.mutex);
+        printf("Max %u messages exceeded\n", msgbuf.bufsize);
+        return NULL;
+    }
+    msgbuf.reserved += num_msgs;
+    pthread_mutex_unlock(&msgbuf.mutex);
+
+    return msgbuf.buf + start * msg_size;
+}
+
+static void round2_received(NodeCommState &nodest,
+    uint8_t *data, uint32_t plaintext_len, uint32_t);
+
+// A round 1 message was received by a routing node from an ingestion
+// node; we put it into the round 2 buffer for processing in round 2
+static void round1_received(NodeCommState &nodest,
+    uint8_t *data, uint32_t plaintext_len, uint32_t)
+{
+    uint16_t msg_size = g_teems_config.msg_size;
+    assert((plaintext_len % uint32_t(msg_size)) == 0);
+    uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
+    uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
+    uint8_t their_roles = g_teems_config.roles[nodest.node_num];
+
+    pthread_mutex_lock(&route_state.round1.mutex);
+    route_state.round1.inserted += num_msgs;
+    route_state.round1.nodes_received += 1;
+    nodenum_t nodes_received = route_state.round1.nodes_received;
+    bool completed_prev_round = route_state.round1.completed_prev_round;
+    pthread_mutex_unlock(&route_state.round1.mutex);
+
+    // What is the next message we expect from this node?
+    if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
+        nodest.in_msg_get_buf = [&](NodeCommState &commst,
+                uint32_t tot_enc_chunk_size) {
+            return msgbuffer_get_buf(route_state.round2, commst,
+                tot_enc_chunk_size);
+        };
+        nodest.in_msg_received = round2_received;
+    }
+    // Otherwise, it's just the next round 1 message, so don't change
+    // the handlers.
+
+    if (nodes_received == g_teems_config.num_ingestion_nodes &&
+            completed_prev_round) {
+        route_state.step = ROUTE_ROUND_1;
+        void *cbpointer = route_state.cbpointer;
+        route_state.cbpointer = NULL;
+        ocall_routing_round_complete(cbpointer, 1);
+    }
+}
+
+// A round 2 message was received by a storage node from a routing node
+static void round2_received(NodeCommState &nodest,
+    uint8_t *data, uint32_t plaintext_len, uint32_t)
+{
+    uint16_t msg_size = g_teems_config.msg_size;
+    assert((plaintext_len % uint32_t(msg_size)) == 0);
+    uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
+    uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
+    uint8_t their_roles = g_teems_config.roles[nodest.node_num];
+
+    pthread_mutex_lock(&route_state.round2.mutex);
+    route_state.round2.inserted += num_msgs;
+    route_state.round2.nodes_received += 1;
+    nodenum_t nodes_received = route_state.round2.nodes_received;
+    bool completed_prev_round = route_state.round2.completed_prev_round;
+    pthread_mutex_unlock(&route_state.round2.mutex);
+
+    // What is the next message we expect from this node?
+    if ((our_roles & ROLE_ROUTING) && (their_roles & ROLE_INGESTION)) {
+        nodest.in_msg_get_buf = [&](NodeCommState &commst,
+                uint32_t tot_enc_chunk_size) {
+            return msgbuffer_get_buf(route_state.round1, commst,
+                tot_enc_chunk_size);
+        };
+        nodest.in_msg_received = round1_received;
+    }
+    // Otherwise, it's just the next round 2 message, so don't change
+    // the handlers.
+
+    if (nodes_received == g_teems_config.num_routing_nodes &&
+            completed_prev_round) {
+        route_state.step = ROUTE_ROUND_2;
+        void *cbpointer = route_state.cbpointer;
+        route_state.cbpointer = NULL;
+        ocall_routing_round_complete(cbpointer, 2);
+    }
+}
+
+// For a given other node, set the received message handler to the first
+// message we would expect from them, given their roles and our roles.
+void route_init_msg_handler(nodenum_t node_num)
+{
+    // Our roles and their roles
+    uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
+    uint8_t their_roles = g_teems_config.roles[node_num];
+
+    // The node communication state
+    NodeCommState &nodest = g_commstates[node_num];
+
+    // If we are a routing node (possibly among other roles) and they
+    // are an ingestion node (possibly among other roles), a round 1
+    // routing message is the first thing we expect from them.  We put
+    // these messages into the round1 buffer for processing.
+    if ((our_roles & ROLE_ROUTING) && (their_roles & ROLE_INGESTION)) {
+        nodest.in_msg_get_buf = [&](NodeCommState &commst,
+                uint32_t tot_enc_chunk_size) {
+            return msgbuffer_get_buf(route_state.round1, commst,
+                tot_enc_chunk_size);
+        };
+        nodest.in_msg_received = round1_received;
+    }
+    // Otherwise, if we are a storage node (possibly among other roles)
+    // and they are a routing node (possibly among other roles), a round
+    // 2 routing message is the first thing we expect from them
+    else if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
+        nodest.in_msg_get_buf = [&](NodeCommState &commst,
+                uint32_t tot_enc_chunk_size) {
+            return msgbuffer_get_buf(route_state.round2, commst,
+                tot_enc_chunk_size);
+        };
+        nodest.in_msg_received = round2_received;
+    }
+    // Otherwise, we don't expect a message from this node. Set the
+    // unknown message handler.
+    else {
+        nodest.in_msg_get_buf = default_in_msg_get_buf;
+        nodest.in_msg_received = unknown_in_msg_received;
+    }
+}
+
+// Directly ingest a buffer of num_msgs messages into the ingbuf buffer.
 // Return true on success, false on failure.
 bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
 {
     uint16_t msg_size = g_teems_config.msg_size;
-    MsgBuffer &round1 = route_state.round1;
+    MsgBuffer &ingbuf = route_state.ingbuf;
 
-    pthread_mutex_lock(&round1.mutex);
-    uint32_t start = round1.reserved;
+    pthread_mutex_lock(&ingbuf.mutex);
+    uint32_t start = ingbuf.reserved;
     if (start + num_msgs > route_state.tot_msg_per_ing) {
-        pthread_mutex_unlock(&round1.mutex);
+        pthread_mutex_unlock(&ingbuf.mutex);
         printf("Max %u messages exceeded\n",
             route_state.tot_msg_per_ing);
         return false;
     }
-    round1.reserved += num_msgs;
-    pthread_mutex_unlock(&round1.mutex);
+    ingbuf.reserved += num_msgs;
+    pthread_mutex_unlock(&ingbuf.mutex);
 
-    memmove(round1.buf + start * msg_size,
+    memmove(ingbuf.buf + start * msg_size,
         msgs, num_msgs * msg_size);
 
-    pthread_mutex_lock(&round1.mutex);
-    round1.inserted += num_msgs;
-    pthread_mutex_unlock(&round1.mutex);
+    pthread_mutex_lock(&ingbuf.mutex);
+    ingbuf.inserted += num_msgs;
+    pthread_mutex_unlock(&ingbuf.mutex);
 
     return true;
 }
 
-// Send the round 1 messages
+// Send the round 1 messages.  Note that N here is not private.
 static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
     uint32_t N)
 {
     uint16_t msg_size = g_teems_config.msg_size;
     uint16_t tot_weight = g_teems_config.tot_weight;
+    nodenum_t my_node_num = g_teems_config.my_node_num;
 
-    /*
-    for (uint32_t i=0;i<N;++i) {
-        const uint8_t *msg = msgs + indices[i]*msg_size;
-        for (uint16_t j=0;j<msg_size/4;++j) {
-            printf("%08x ", ((const uint32_t*)msg)[j]);
+    uint32_t full_rows = N / uint32_t(tot_weight);
+    uint32_t last_row = N % uint32_t(tot_weight);
+
+    for (auto &routing_node: g_teems_config.routing_nodes) {
+        uint8_t weight =
+            g_teems_config.weights[routing_node].weight;
+        if (weight == 0) {
+            // This shouldn't happen, but just in case
+            continue;
+        }
+        uint16_t start_weight =
+            g_teems_config.weights[routing_node].startweight;
+
+        // The number of messages headed for this routing node from the
+        // full rows
+        uint32_t num_msgs_full_rows = full_rows * uint32_t(weight);
+        // The number of messages headed for this routing node from the
+        // incomplete last row is:
+        // 0 if last_row < start_weight
+        // last_row-start_weight if start_weight <= last_row < start_weight + weight
+        // weight if start_weight + weight <= last_row
+        uint32_t num_msgs_last_row = 0;
+        if (start_weight <= last_row && last_row < start_weight + weight) {
+            num_msgs_last_row = last_row-start_weight;
+        } else if (start_weight + weight <= last_row) {
+            num_msgs_last_row = weight;
+        }
+        // The total number of messages headed for this routing node
+        uint32_t num_msgs = num_msgs_full_rows + num_msgs_last_row;
+
+        if (routing_node == my_node_num) {
+            // Special case: we're sending to ourselves; just put the
+            // messages in our own round1 buffer
+            MsgBuffer &round1 = route_state.round1;
+
+            pthread_mutex_lock(&round1.mutex);
+            uint32_t start = round1.reserved;
+            if (start + num_msgs > round1.bufsize) {
+                pthread_mutex_unlock(&round1.mutex);
+                printf("Max %u messages exceeded\n", round1.bufsize);
+                return;
+            }
+            round1.reserved += num_msgs;
+            pthread_mutex_unlock(&round1.mutex);
+            uint8_t *buf = round1.buf + start * msg_size;
+
+            for (uint32_t i=0; i<full_rows; ++i) {
+                const uint64_t *idxp = indices + i*tot_weight + start_weight;
+                for (uint32_t j=0; j<weight; ++j) {
+                    memmove(buf, msgs + idxp[j]*msg_size, msg_size);
+                    buf += msg_size;
+                }
+            }
+            const uint64_t *idxp = indices + full_rows*tot_weight + start_weight;
+            for (uint32_t j=0; j<num_msgs_last_row; ++j) {
+                memmove(buf, msgs + idxp[j]*msg_size, msg_size);
+                buf += msg_size;
+            }
+
+            pthread_mutex_lock(&round1.mutex);
+            round1.inserted += num_msgs;
+            round1.nodes_received += 1;
+            pthread_mutex_unlock(&round1.mutex);
+
+        } else {
+            NodeCommState &nodecom = g_commstates[routing_node];
+            nodecom.message_start(num_msgs * msg_size);
+            for (uint32_t i=0; i<full_rows; ++i) {
+                const uint64_t *idxp = indices + i*tot_weight + start_weight;
+                for (uint32_t j=0; j<weight; ++j) {
+                    nodecom.message_data(msgs + idxp[j]*msg_size, msg_size);
+                }
+            }
+            const uint64_t *idxp = indices + full_rows*tot_weight + start_weight;
+            for (uint32_t j=0; j<num_msgs_last_row; ++j) {
+                nodecom.message_data(msgs + idxp[j]*msg_size, msg_size);
+            }
         }
-        printf("\n");
     }
-    */
+}
+
+// Send the round 2 messages from the round 1 buffer, which are already
+// padded and shuffled, so this can be done non-obliviously.  tot_msgs
+// is the total number of messages in the input buffer, which may
+// include padding messages added by the shuffle.  Those messages are
+// not sent anywhere.  There are num_msgs_per_stg messages for each
+// storage node labelled for that node.
+static void send_round2_msgs(uint32_t tot_msgs, uint32_t num_msgs_per_stg)
+{
+    uint16_t msg_size = g_teems_config.msg_size;
+    MsgBuffer &round1 = route_state.round1;
+    const uint8_t* buf = round1.buf;
+    nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
+    nodenum_t my_node_num = g_teems_config.my_node_num;
+    uint8_t *myself_buf = NULL;
+
+    for (nodenum_t i=0; i<num_storage_nodes; ++i) {
+        nodenum_t node = g_teems_config.storage_nodes[i];
+        if (node != my_node_num) {
+            g_commstates[node].message_start(msg_size * num_msgs_per_stg);
+        } else {
+            MsgBuffer &round2 = route_state.round2;
+            pthread_mutex_lock(&round2.mutex);
+            uint32_t start = round2.reserved;
+            if (start + num_msgs_per_stg > round2.bufsize) {
+                pthread_mutex_unlock(&round2.mutex);
+                printf("Max %u messages exceeded\n", round2.bufsize);
+                return;
+            }
+            round2.reserved += num_msgs_per_stg;
+            pthread_mutex_unlock(&round2.mutex);
+            myself_buf = round2.buf + start * msg_size;
+        }
+    }
+
+    while (tot_msgs) {
+        nodenum_t storage_node_id =
+            nodenum_t((*(const uint32_t *)buf)>>DEST_UID_BITS);
+        if (storage_node_id < num_storage_nodes) {
+            nodenum_t node = g_teems_config.storage_map[storage_node_id];
+            if (node == my_node_num) {
+                memmove(myself_buf, buf, msg_size);
+                myself_buf += msg_size;
+            } else {
+                g_commstates[node].message_data(buf, msg_size);
+            }
+        }
+
+        buf += msg_size;
+        --tot_msgs;
+    }
+
+    if (myself_buf) {
+        MsgBuffer &round2 = route_state.round2;
+        pthread_mutex_lock(&round2.mutex);
+        round2.inserted += num_msgs_per_stg;
+        round2.nodes_received += 1;
+        pthread_mutex_unlock(&round2.mutex);
+    }
 }
 
 // Perform the next round of routing.  The callback pointer will be
 // passed to ocall_routing_round_complete when the round is complete.
 void ecall_routing_proceed(void *cbpointer)
 {
-    if (route_state.step == ROUTE_NOT_STARTED) {
+    uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
 
-        MsgBuffer &round1 = route_state.round1;
+    if (route_state.step == ROUTE_NOT_STARTED) {
+        if (my_roles & ROLE_INGESTION) {
+            route_state.cbpointer = cbpointer;
+            MsgBuffer &ingbuf = route_state.ingbuf;
+            MsgBuffer &round1 = route_state.round1;
+
+            pthread_mutex_lock(&ingbuf.mutex);
+            // Ensure there are no pending messages currently being inserted
+            // into the buffer
+            while (ingbuf.reserved != ingbuf.inserted) {
+                pthread_mutex_unlock(&ingbuf.mutex);
+                pthread_mutex_lock(&ingbuf.mutex);
+            }
+            // Sort the messages we've received
+#ifdef PROFILE_ROUTING
+            uint32_t inserted = ingbuf.inserted;
+            unsigned long start_round1 = printf_with_rtclock("begin round1 processing (%u)\n", inserted);
+            unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
+#endif
+            sort_mtobliv(g_teems_config.nthreads, ingbuf.buf,
+                g_teems_config.msg_size, ingbuf.inserted,
+                route_state.tot_msg_per_ing, send_round1_msgs);
+#ifdef PROFILE_ROUTING
+            printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
+            printf_with_rtclock_diff(start_round1, "end round1 processing (%u)\n", inserted);
+#endif
+            ingbuf.reset();
+            pthread_mutex_unlock(&ingbuf.mutex);
 
-        pthread_mutex_lock(&round1.mutex);
-        // Ensure there are no pending messages currently being inserted
-        // into the buffer
-        while (round1.reserved != round1.inserted) {
+            pthread_mutex_lock(&round1.mutex);
+            round1.completed_prev_round = true;
+            nodenum_t nodes_received = round1.nodes_received;
             pthread_mutex_unlock(&round1.mutex);
+
+            if (nodes_received == g_teems_config.num_ingestion_nodes) {
+                route_state.step = ROUTE_ROUND_1;
+                route_state.cbpointer = NULL;
+                ocall_routing_round_complete(cbpointer, 1);
+            }
+        } else {
+            route_state.step = ROUTE_ROUND_1;
+            ocall_routing_round_complete(cbpointer, 1);
+        }
+    } else if (route_state.step == ROUTE_ROUND_1) {
+        if (my_roles & ROLE_ROUTING) {
+            route_state.cbpointer = cbpointer;
+            MsgBuffer &round1 = route_state.round1;
+            MsgBuffer &round2 = route_state.round2;
+
             pthread_mutex_lock(&round1.mutex);
+            // Ensure there are no pending messages currently being inserted
+            // into the buffer
+            while (round1.reserved != round1.inserted) {
+                pthread_mutex_unlock(&round1.mutex);
+                pthread_mutex_lock(&round1.mutex);
+            }
+
+            // If the _total_ number of messages we received in round 1
+            // is less than the max number of messages we could send to
+            // _each_ storage node, then cap the number of messages we
+            // will send to each storage node to that number.
+            uint32_t msgs_per_stg = route_state.max_msg_to_each_stg;
+            if (round1.inserted < msgs_per_stg) {
+                msgs_per_stg = round1.inserted;
+            }
+
+            // Note: at this point, it is required that each message in
+            // the round1 buffer have a _valid_ storage node id field.
+
+            // Obliviously tally the number of messages we received in
+            // round1 destined for each storage node
+#ifdef PROFILE_ROUTING
+            uint32_t inserted = round1.inserted;
+            unsigned long start_round2 = printf_with_rtclock("begin round2 processing (%u,%u)\n", inserted, round1.bufsize);
+            unsigned long start_tally = printf_with_rtclock("begin tally (%u)\n", inserted);
+#endif
+            uint16_t msg_size = g_teems_config.msg_size;
+            nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
+            std::vector<uint32_t> tally = obliv_tally_stg(
+                round1.buf, msg_size, round1.inserted, num_storage_nodes);
+#ifdef PROFILE_ROUTING
+            printf_with_rtclock_diff(start_tally, "end tally (%u)\n", inserted);
+#endif
+
+            // Note: tally contains private values!  It's OK to
+            // non-obliviously check for an error condition, though.
+            // While we're at it, obliviously change the tally of
+            // messages received to a tally of padding messages
+            // required.
+            uint32_t tot_padding = 0;
+            for (nodenum_t i=0; i<num_storage_nodes; ++i) {
+                if (tally[i] > msgs_per_stg) {
+                    printf("Received too many messages for storage node %u\n", i);
+                    assert(tally[i] <= msgs_per_stg);
+                }
+                tally[i] = msgs_per_stg - tally[i];
+                tot_padding += tally[i];
+            }
+
+            round1.reserved += tot_padding;
+            assert(round1.reserved <= round1.bufsize);
+
+            // Obliviously add padding for each storage node according
+            // to the (private) padding tally.
+#ifdef PROFILE_ROUTING
+            unsigned long start_pad = printf_with_rtclock("begin pad (%u)\n", tot_padding);
+#endif
+            obliv_pad_stg(round1.buf + round1.inserted * msg_size,
+                msg_size, tally, tot_padding);
+#ifdef PROFILE_ROUTING
+            printf_with_rtclock_diff(start_pad, "end pad (%u)\n", tot_padding);
+#endif
+
+            round1.inserted += tot_padding;
+
+            // Obliviously shuffle the messages
+#ifdef PROFILE_ROUTING
+            unsigned long start_shuffle = printf_with_rtclock("begin shuffle (%u,%u)\n", round1.inserted, round1.bufsize);
+#endif
+            uint32_t num_shuffled = shuffle_mtobliv(g_teems_config.nthreads,
+                round1.buf, msg_size, round1.inserted, round1.bufsize);
+#ifdef PROFILE_ROUTING
+            printf_with_rtclock_diff(start_shuffle, "end shuffle (%u,%u)\n", round1.inserted, round1.bufsize);
+            printf_with_rtclock_diff(start_round2, "end round2 processing (%u,%u)\n", inserted, round1.bufsize);
+#endif
+
+            // Now we can handle the messages non-obliviously, since we
+            // know there will be exactly msgs_per_stg messages to each
+            // storage node, and the oblivious shuffle broke the
+            // connection between where each message came from and where
+            // it's going.
+            send_round2_msgs(num_shuffled, msgs_per_stg);
+
+            round1.reset();
+            pthread_mutex_unlock(&round1.mutex);
+
+            pthread_mutex_lock(&round2.mutex);
+            round2.completed_prev_round = true;
+            nodenum_t nodes_received = round2.nodes_received;
+            pthread_mutex_unlock(&round2.mutex);
+
+            if (nodes_received == g_teems_config.num_routing_nodes) {
+                route_state.step = ROUTE_ROUND_2;
+                route_state.cbpointer = NULL;
+                ocall_routing_round_complete(cbpointer, 2);
+            }
+        } else {
+            route_state.step = ROUTE_ROUND_2;
+            ocall_routing_round_complete(cbpointer, 2);
         }
-        // Sort the messages we've received
+    } else if (route_state.step == ROUTE_ROUND_2) {
+        if (my_roles & ROLE_STORAGE) {
+            MsgBuffer &round2 = route_state.round2;
+
+            pthread_mutex_lock(&round2.mutex);
+            // Ensure there are no pending messages currently being inserted
+            // into the buffer
+            while (round2.reserved != round2.inserted) {
+                pthread_mutex_unlock(&round2.mutex);
+                pthread_mutex_lock(&round2.mutex);
+            }
+
 #ifdef PROFILE_ROUTING
-        uint32_t inserted = round1.inserted;
-        unsigned long start = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
+            unsigned long start = printf_with_rtclock("begin storage processing (%u)\n", round2.inserted);
 #endif
-        sort_mtobliv(g_teems_config.nthreads, round1.buf,
-            g_teems_config.msg_size, round1.inserted,
-            route_state.tot_msg_per_ing, send_round1_msgs);
-        round1.reset();
+            storage_received(round2.buf, round2.inserted);
 #ifdef PROFILE_ROUTING
-        printf_with_rtclock_diff(start, "end oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
+            printf_with_rtclock_diff(start, "end storage processing (%u)\n", round2.inserted);
 #endif
-        route_state.step = ROUTE_ROUND_1;
-        ocall_routing_round_complete(cbpointer, 1);
+
+            round2.reset();
+            pthread_mutex_unlock(&round2.mutex);
+
+            // We're done
+            route_state.step = ROUTE_NOT_STARTED;
+            ocall_routing_round_complete(cbpointer, 0);
+        } else {
+            // We're done
+            route_state.step = ROUTE_NOT_STARTED;
+            ocall_routing_round_complete(cbpointer, 0);
+        }
     }
 }

+ 18 - 1
Enclave/route.hpp

@@ -11,8 +11,15 @@ struct MsgBuffer {
     uint32_t reserved;
     // The number of messages definitely in the buffer
     uint32_t inserted;
+    // The number of messages that can fit in buf
+    uint32_t bufsize;
+    // The number of nodes we've heard from
+    nodenum_t nodes_received;
+    // Have we completed the previous round yet?
+    bool completed_prev_round;
 
-    MsgBuffer() : buf(NULL), reserved(0), inserted(0) {
+    MsgBuffer() : buf(NULL), reserved(0), inserted(0), bufsize(0),
+            nodes_received(0), completed_prev_round(false) {
         pthread_mutex_init(&mutex, NULL);
     }
 
@@ -28,6 +35,10 @@ struct MsgBuffer {
         inserted = 0;
         // This may throw bad_alloc, but we'll catch it higher up
         buf = new uint8_t[size_t(msgs) * g_teems_config.msg_size];
+        memset(buf, 0, size_t(msgs) * g_teems_config.msg_size);
+        bufsize = msgs;
+        nodes_received = 0;
+        completed_prev_round = false;
     }
 
     // Reset the contents of the buffer
@@ -35,6 +46,8 @@ struct MsgBuffer {
         memset(buf, 0, inserted * g_teems_config.msg_size);
         reserved = 0;
         inserted = 0;
+        nodes_received = 0;
+        completed_prev_round = false;
     }
 
     // You can't copy a MsgBuffer
@@ -46,4 +59,8 @@ struct MsgBuffer {
 // comms_init_nodestate. Returns true on success, false on failure.
 bool route_init();
 
+// For a given other node, set the received message handler to the first
+// message we would expect from them, given their roles and our roles.
+void route_init_msg_handler(nodenum_t node_num);
+
 #endif

+ 22 - 9
Enclave/sort.cpp

@@ -69,20 +69,18 @@ void sort_precompute_evalplan(uint32_t N, threadid_t nthreads)
     pthread_mutex_unlock(&precomp_eps.mutex);
 }
 
-// Perform the sort using up to nthreads threads.  The items to sort are
-// byte arrays of size msg_size.  The key is the first 4 bytes of each
-// item.
-void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
-    uint32_t Nr, uint32_t Na,
-    // the arguments to the callback are items, the sorted indices, and
-    // the number of non-padding items
-    std::function<void(const uint8_t*, const uint64_t*, uint32_t Nr)> cb)
+// Shuffle Nr items at the beginning of an allocated array of Na items
+// using up to nthreads threads.  The items to shuffle are byte arrays
+// of size msg_size.  Return Nw, the size of the Waksman network we
+// used, which must satisfy Nr <= Nw <= Na.
+uint32_t shuffle_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
+    uint32_t Nr, uint32_t Na)
 {
     // Find the smallest Nw for which we have a precomputed
     // WaksmanNetwork with Nr <= Nw <= Na
     pthread_mutex_lock(&precomp_wns.mutex);
     std::optional<WaksmanNetwork> wn;
-    uint32_t Nw;
+    uint32_t Nw = 0;
     for (auto& N : precomp_wns.sized_wns) {
         if (N.first > Na) {
             printf("No precomputed WaksmanNetworks of size at most %u\n", Na);
@@ -127,6 +125,21 @@ void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
     wn.value().applyInversePermutation<OSWAP_16X>(
         items, msg_size, eval_plan);
 
+    return Nw;
+}
+
+// Perform the sort using up to nthreads threads.  The items to sort are
+// byte arrays of size msg_size.  The key is the 10-bit storage server
+// id concatenated with the 22-bit uid at the storage server.
+void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
+    uint32_t Nr, uint32_t Na,
+    // the arguments to the callback are items, the sorted indices, and
+    // the number of non-padding items
+    std::function<void(const uint8_t*, const uint64_t*, uint32_t Nr)> cb)
+{
+    // Shuffle the items
+    uint32_t Nw = shuffle_mtobliv(nthreads, items, msg_size, Nr, Na);
+
     // Create the indices
     uint64_t *idx = new uint64_t[Nr];
     uint64_t *nextidx = idx;

+ 8 - 1
Enclave/sort.hpp

@@ -34,9 +34,16 @@ size_t sort_precompute(uint32_t N);
 // background thread.
 void sort_precompute_evalplan(uint32_t N, threadid_t nthreads);
 
+// Shuffle Nr items at the beginning of an allocated array of Na items
+// using up to nthreads threads.  The items to shuffle are byte arrays
+// of size msg_size.  Return Nw, the size of the Waksman network we
+// used, which must satisfy Nr <= Nw <= Na.
+uint32_t shuffle_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
+    uint32_t Nr, uint32_t Na);
+
 // Perform the sort using up to nthreads threads.  The items to sort are
 // byte arrays of size msg_size.  The key is the 10-bit storage server
-// id contatenated with the 22-bit uid at the storage server.
+// id concatenated with the 22-bit uid at the storage server.
 void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
     uint32_t Nr, uint32_t Na,
     // the arguments to the callback are items, the sorted indices, and

+ 34 - 0
Enclave/storage.cpp

@@ -0,0 +1,34 @@
+#include "utils.hpp"
+#include "config.hpp"
+#include "storage.hpp"
+
+// Handle the messages received by a storage node
+void storage_received(const uint8_t *msgs, uint32_t num_msgs)
+{
+    // A dummy function for now that just counts how many real and
+    // padding messages arrived
+    uint16_t msg_size = g_teems_config.msg_size;
+    nodenum_t my_node_num = g_teems_config.my_node_num;
+    uint32_t real = 0, padding = 0;
+    uint32_t uid_mask = (1 << DEST_UID_BITS) - 1;
+
+    printf("Storage server received %u messages:\n", num_msgs);
+    for (uint32_t i=0; i<num_msgs; ++i) {
+        uint32_t dest_addr = *(const uint32_t*)msgs;
+        nodenum_t dest_node =
+            g_teems_config.storage_map[dest_addr >> DEST_UID_BITS];
+        if (dest_node != my_node_num) {
+            char hexbuf[2*msg_size + 1];
+            for (uint32_t j=0;j<msg_size;++j) {
+                snprintf(hexbuf+2*j, 3, "%02x", msgs[j]);
+            }
+            printf("Misrouted message: %s\n", hexbuf);
+        } else if ((dest_addr & uid_mask) == uid_mask) {
+            ++padding;
+        } else {
+            ++real;
+        }
+        msgs += msg_size;
+    }
+    printf("%u real, %u padding\n", real, padding);
+}

+ 9 - 0
Enclave/storage.hpp

@@ -0,0 +1,9 @@
+#ifndef __STORAGE_HPP__
+#define __STORAGE_HPP__
+
+#include <cstdint>
+
+// Handle the messages received by a storage node
+void storage_received(const uint8_t *msgs, uint32_t num_msgs);
+
+#endif

+ 13 - 4
Makefile

@@ -270,6 +270,11 @@ Enclave/%.o: Enclave/%.cpp
 	@echo "CXX  <=  $<"
 	@$(CXX) $(SGX_COMMON_CXXFLAGS) $(Enclave_Cpp_Flags) -c $< -o $@
 
+Enclave/asm/%.s: Enclave/%.cpp
+	@echo "CXXASM  <=  $<"
+	@mkdir -p $$(dirname $@)
+	@$(CXX) $(SGX_COMMON_CXXFLAGS) $(Enclave_Cpp_Flags) -S $< -o $@
+
 $(Enclave_Cpp_Objects): Enclave/Enclave_t.h
 
 $(Enclave_Name): Enclave/Enclave_t.o $(Enclave_Cpp_Objects)
@@ -313,22 +318,26 @@ App/appconfig.o: Untrusted/Untrusted.hpp Enclave/enclave_api.h
 App/appconfig.o: App/appconfig.hpp
 App/net.o: Untrusted/Enclave_u.h Enclave/enclave_api.h
 App/net.o: Untrusted/Untrusted.hpp App/net.hpp App/appconfig.hpp
-App/start.o: App/start.hpp App/net.hpp App/appconfig.hpp
-App/start.o: Enclave/enclave_api.h
+App/start.o: Untrusted/Untrusted.hpp Enclave/enclave_api.h App/start.hpp
+App/start.o: App/net.hpp App/appconfig.hpp
 App/teems.o: Untrusted/Untrusted.hpp Enclave/enclave_api.h App/appconfig.hpp
 App/teems.o: App/net.hpp App/start.hpp
 Untrusted/Untrusted.o: Untrusted/Untrusted.hpp Enclave/enclave_api.h
 Untrusted/Untrusted.o: Untrusted/Enclave_u.h
 
 Enclave/comms.o: Enclave/Enclave_t.h Enclave/enclave_api.h Enclave/config.hpp
-Enclave/comms.o: Enclave/enclave_api.h
+Enclave/comms.o: Enclave/enclave_api.h Enclave/route.hpp Enclave/comms.hpp
 Enclave/config.o: Enclave/Enclave_t.h Enclave/enclave_api.h Enclave/comms.hpp
 Enclave/config.o: Enclave/enclave_api.h Enclave/config.hpp Enclave/route.hpp Enclave/ingest.hpp
+Enclave/obliv.o: Enclave/enclave_api.h Enclave/obliv.hpp
 Enclave/route.o: Enclave/Enclave_t.h Enclave/enclave_api.h Enclave/config.hpp
-Enclave/route.o: Enclave/enclave_api.h Enclave/route.hpp
+Enclave/route.o: Enclave/enclave_api.h Enclave/sort.hpp Enclave/comms.hpp
+Enclave/route.o: Enclave/obliv.hpp Enclave/storage.hpp Enclave/route.hpp
 Enclave/ingest.o: Enclave/Enclave_t.h Enclave/enclave_api.h Enclave/config.hpp
 Enclave/ingest.o: Enclave/route.hpp
 Enclave/sort.o: Enclave/sort.hpp
+Enclave/storage.o: Enclave/config.hpp Enclave/enclave_api.h
+Enclave/storage.o: Enclave/storage.hpp
 Enclave/OblivAlgs/RecursiveShuffle.o: Enclave/OblivAlgs/oasm_lib.h
 Enclave/OblivAlgs/RecursiveShuffle.o: Enclave/OblivAlgs/CONFIG.h
 Enclave/OblivAlgs/RecursiveShuffle.o: Enclave/OblivAlgs/oasm_lib.tcc

+ 13 - 2
Untrusted/Untrusted.cpp

@@ -242,10 +242,12 @@ void ecall_close()
     ecall_close(global_eid);
 }
 
-bool ecall_comms_start()
+bool ecall_comms_start(std::function<void(void)> cb)
 {
+    std::function<void(void)> *p = new std::function<void(void)>;
+    *p = cb;
     bool ret;
-    ecall_comms_start(global_eid, &ret);
+    ecall_comms_start(global_eid, &ret, p);
     return ret;
 }
 
@@ -285,6 +287,15 @@ void ecall_routing_proceed(std::function<void(uint32_t)> cb)
     ecall_routing_proceed(global_eid, p);
 }
 
+void ocall_comms_ready(void *cbpointer)
+{
+    std::function<void(void)> *p =
+        (std::function<void(void)> *)cbpointer;
+    std::function<void(void)> f = *p;
+    delete p;
+    f();
+}
+
 void ocall_routing_round_complete(void *cbpointer, uint32_t round_num)
 {
     std::function<void(uint32_t)> *p =

+ 2 - 2
Untrusted/Untrusted.hpp

@@ -27,7 +27,7 @@ bool ecall_config_load(threadid_t nthreads,
 
 void ecall_close();
 
-bool ecall_comms_start();
+bool ecall_comms_start(std::function<void(void)> cb);
 
 bool ecall_message(nodenum_t node_num, uint32_t message_len);
 
@@ -38,7 +38,7 @@ size_t ecall_precompute_sort(int size);
 
 bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs);
 
-void ecall_routing_proceed(std::function<void(uint32_t)>);
+void ecall_routing_proceed(std::function<void(uint32_t)> cb);
 
 bool ecall_ingest_msgbundle(clientid_t cid, unsigned char *msgbundle,
     uint32_t num_msgs);