Ian Goldberg 1 gadu atpakaļ
vecāks
revīzija
61fe90b6d6
5 mainītis faili ar 38 papildinājumiem un 7 dzēšanām
  1. 2 1
      App/net.cpp
  2. 2 0
      App/net.hpp
  3. 11 3
      App/start.cpp
  4. 1 1
      App/teems.cpp
  5. 22 2
      Enclave/route.cpp

+ 2 - 1
App/net.cpp

@@ -171,7 +171,8 @@ void NodeIO::recv_commands(
 }
 
 NetIO::NetIO(boost::asio::io_context &io_context, const Config &config)
-    : conf(config), myconf(config.nodes[config.my_node_num])
+    : context(io_context), conf(config),
+      myconf(config.nodes[config.my_node_num])
 {
     num_nodes = nodenum_t(conf.nodes.size());
     nodeios.resize(num_nodes);

+ 2 - 0
App/net.hpp

@@ -125,6 +125,7 @@ public:
 };
 
 class NetIO {
+    boost::asio::io_context &context;
     const Config &conf;
     const NodeConfig &myconf;
     std::deque<std::optional<NodeIO>> nodeios;
@@ -139,6 +140,7 @@ public:
         return nodeios[node_num].value();
     }
     const Config &config() { return conf; }
+    boost::asio::io_context &io_context() { return context; }
     // Call recv_commands with these arguments on each of the nodes (not
     // including ourselves)
     void recv_commands(

+ 11 - 3
App/start.cpp

@@ -35,6 +35,7 @@ static void route_test(NetIO &netio, char **args)
 
     const Config &config = netio.config();
     uint16_t msg_size = config.msg_size;
+    nodenum_t my_node_num = config.my_node_num;
 
     uint8_t *msgs = new uint8_t[tot_tokens * msg_size];
     uint8_t *nextmsg = msgs;
@@ -47,7 +48,8 @@ static void route_test(NetIO &netio, char **args)
             if (r < num_tokens[j]) {
                 // Use a token from node j
                 *((uint32_t*)nextmsg) =
-                    (j << DEST_UID_BITS) + (r & dest_uid_mask);
+                    (j << DEST_UID_BITS) +
+                        (((r<<8)+(my_node_num&0xff)) & dest_uid_mask);
                 // Put a bunch of copies of r as the message body
                 for (uint16_t i=1;i<msg_size/4;++i) {
                     ((uint32_t*)nextmsg)[i] = r;
@@ -80,9 +82,15 @@ static void route_test(NetIO &netio, char **args)
         return;
     }
 
-    ecall_routing_proceed([&](uint32_t round_num){
+    ecall_routing_proceed([&](uint32_t round_num) {
         printf("Round %u complete\n", round_num);
-        //netio.close();
+        if (round_num == 1) {
+            boost::asio::post(netio.io_context(), []{
+                ecall_routing_proceed([&](uint32_t round_num2) {
+                    printf("Round %u complete\n", round_num2);
+                });
+            });
+        }
     });
 }
 

+ 1 - 1
App/teems.cpp

@@ -260,7 +260,7 @@ int main(int argc, char **argv)
             node.recv_commands(
                 // error_cb
                 [](boost::system::error_code ec) {
-                    printf("Error %d\n", ec.value());
+                    printf("Error %s\n", ec.message().c_str());
                 },
                 // epoch_cb
                 [](uint32_t epoch) {

+ 22 - 2
Enclave/route.cpp

@@ -433,7 +433,6 @@ static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
 void ecall_routing_proceed(void *cbpointer)
 {
     if (route_state.step == ROUTE_NOT_STARTED) {
-
         route_state.cbpointer = cbpointer;
         MsgBuffer &round1 = route_state.round1;
 
@@ -452,9 +451,30 @@ void ecall_routing_proceed(void *cbpointer)
         sort_mtobliv(g_teems_config.nthreads, round1.buf,
             g_teems_config.msg_size, round1.inserted,
             route_state.tot_msg_per_ing, send_round1_msgs);
-        round1.reset();
 #ifdef PROFILE_ROUTING
         printf_with_rtclock_diff(start, "end oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
 #endif
+        round1.reset();
+        pthread_mutex_unlock(&round1.mutex);
+    } else if (route_state.step == ROUTE_ROUND_1) {
+        route_state.cbpointer = cbpointer;
+        MsgBuffer &round2 = route_state.round2;
+
+        pthread_mutex_lock(&round2.mutex);
+        // Ensure there are no pending messages currently being inserted
+        // into the buffer
+        while (round2.reserved != round2.inserted) {
+            pthread_mutex_unlock(&round2.mutex);
+            pthread_mutex_lock(&round2.mutex);
+        }
+
+        uint32_t msg_size = g_teems_config.msg_size;
+        for(uint32_t i=0;i<round2.inserted;++i) {
+            uint32_t destaddr = *(uint32_t*)(round2.buf+i*msg_size);
+            printf("%08x\n", destaddr);
+        }
+
+        round2.reset();
+        pthread_mutex_unlock(&round2.mutex);
     }
 }