Browse Source

Merge branch 'main' into Clients

Sajin Sasy 1 year ago
parent
commit
635e49fc71
4 changed files with 59 additions and 37 deletions
  1. 36 33
      App/start.cpp
  2. 2 2
      App/teems.cpp
  3. 11 0
      Enclave/comms.cpp
  4. 10 2
      Enclave/route.cpp

+ 36 - 33
App/start.cpp

@@ -61,44 +61,47 @@ static void epoch(NetIO &netio, char **args) {
     const Config &config = netio.config();
     uint16_t msg_size = config.msg_size;
     nodenum_t my_node_num = config.my_node_num;
-
-    uint8_t *msgs = new uint8_t[tot_tokens * msg_size];
-    uint8_t *nextmsg = msgs;
-    uint32_t dest_uid_mask = (1 << DEST_UID_BITS) - 1;
-    uint32_t rem_tokens = tot_tokens;
-    while (rem_tokens > 0) {
-        // Pick a random remaining token
-        uint32_t r = uint32_t(lrand48()) % rem_tokens;
-        for (nodenum_t j=0;j<num_nodes;++j) {
-            if (r < num_tokens[j]) {
-                // Use a token from node j
-                *((uint32_t*)nextmsg) =
-                    (j << DEST_UID_BITS) +
-                        (((r<<8)+(my_node_num&0xff)) & dest_uid_mask);
-                // Put a bunch of copies of r as the message body
-                for (uint16_t i=1;i<msg_size/4;++i) {
-                    ((uint32_t*)nextmsg)[i] = r;
+    uint8_t my_roles = config.nodes[my_node_num].roles;
+
+    if (my_roles & ROLE_INGESTION) {
+        uint8_t *msgs = new uint8_t[tot_tokens * msg_size];
+        uint8_t *nextmsg = msgs;
+        uint32_t dest_uid_mask = (1 << DEST_UID_BITS) - 1;
+        uint32_t rem_tokens = tot_tokens;
+        while (rem_tokens > 0) {
+            // Pick a random remaining token
+            uint32_t r = uint32_t(lrand48()) % rem_tokens;
+            for (nodenum_t j=0;j<num_nodes;++j) {
+                if (r < num_tokens[j]) {
+                    // Use a token from node j
+                    *((uint32_t*)nextmsg) =
+                        (j << DEST_UID_BITS) +
+                            (((r<<8)+(my_node_num&0xff)) & dest_uid_mask);
+                    // Put a bunch of copies of r as the message body
+                    for (uint16_t i=1;i<msg_size/4;++i) {
+                        ((uint32_t*)nextmsg)[i] = r;
+                    }
+                    num_tokens[j] -= 1;
+                    rem_tokens -= 1;
+                    nextmsg += msg_size;
+                } else {
+                    r -= num_tokens[j];
                 }
-                num_tokens[j] -= 1;
-                rem_tokens -= 1;
-                nextmsg += msg_size;
-            } else {
-                r -= num_tokens[j];
             }
         }
-    }
-    /*
-    for (uint32_t i=0;i<tot_tokens;++i) {
-        for(uint16_t j=0;j<msg_size/4;++j) {
-            printf("%08x ", ((uint32_t*)msgs)[i*msg_size/4+j]);
+        /*
+        for (uint32_t i=0;i<tot_tokens;++i) {
+            for(uint16_t j=0;j<msg_size/4;++j) {
+                printf("%08x ", ((uint32_t*)msgs)[i*msg_size/4+j]);
+            }
+            printf("\n");
         }
-        printf("\n");
-    }
-    */
+        */
 
-    if (!ecall_ingest_raw(msgs, tot_tokens)) {
-        printf("Ingestion failed\n");
-        return;
+        if (!ecall_ingest_raw(msgs, tot_tokens)) {
+            printf("Ingestion failed\n");
+            return;
+        }
     }
 
     Epoch epoch(netio.io_context(), epoch_num);

+ 2 - 2
App/teems.cpp

@@ -259,8 +259,8 @@ int main(int argc, char **argv)
             NodeIO &node = netio.node(node_num);
             node.recv_commands(
                 // error_cb
-                [](boost::system::error_code ec) {
-                    printf("Error %s\n", ec.message().c_str());
+                [node_num](boost::system::error_code ec) {
+                    printf("Error %s from %d\n", ec.message().c_str(), node_num);
                 },
                 // epoch_cb
                 [](uint32_t epoch) {

+ 11 - 0
Enclave/comms.cpp

@@ -652,6 +652,17 @@ bool ecall_message(nodenum_t node_num, uint32_t message_len)
     nodest.in_msg_offset = 0;
     nodest.in_msg_plaintext_processed = 0;
     nodest.in_msg_buf = buf;
+    // Just in case message_len == 0
+    if (nodest.in_msg_offset == nodest.in_msg_size) {
+        // This was the last chunk; handle the received message
+        uint32_t plaintext_processed = nodest.in_msg_plaintext_processed;
+        uint32_t msg_size = nodest.in_msg_size;
+        nodest.in_msg_buf = NULL;
+        nodest.in_msg_size = 0;
+        nodest.in_msg_offset = 0;
+        nodest.in_msg_plaintext_processed = 0;
+        nodest.in_msg_received(nodest, buf, plaintext_processed, msg_size);
+    }
     return true;
 }
 

+ 10 - 2
Enclave/route.cpp

@@ -510,7 +510,6 @@ void ecall_routing_proceed(void *cbpointer)
         if (my_roles & ROLE_INGESTION) {
             route_state.cbpointer = cbpointer;
             MsgBuffer &ingbuf = route_state.ingbuf;
-            MsgBuffer &round1 = route_state.round1;
 
             pthread_mutex_lock(&ingbuf.mutex);
             // Ensure there are no pending messages currently being inserted
@@ -534,6 +533,9 @@ void ecall_routing_proceed(void *cbpointer)
 #endif
             ingbuf.reset();
             pthread_mutex_unlock(&ingbuf.mutex);
+        }
+        if (my_roles & ROLE_ROUTING) {
+            MsgBuffer &round1 = route_state.round1;
 
             pthread_mutex_lock(&round1.mutex);
             round1.completed_prev_round = true;
@@ -547,13 +549,13 @@ void ecall_routing_proceed(void *cbpointer)
             }
         } else {
             route_state.step = ROUTE_ROUND_1;
+            route_state.round1.completed_prev_round = true;
             ocall_routing_round_complete(cbpointer, 1);
         }
     } else if (route_state.step == ROUTE_ROUND_1) {
         if (my_roles & ROLE_ROUTING) {
             route_state.cbpointer = cbpointer;
             MsgBuffer &round1 = route_state.round1;
-            MsgBuffer &round2 = route_state.round2;
 
             pthread_mutex_lock(&round1.mutex);
             // Ensure there are no pending messages currently being inserted
@@ -642,6 +644,11 @@ void ecall_routing_proceed(void *cbpointer)
             round1.reset();
             pthread_mutex_unlock(&round1.mutex);
 
+        }
+        if (my_roles & ROLE_STORAGE) {
+            route_state.cbpointer = cbpointer;
+            MsgBuffer &round2 = route_state.round2;
+
             pthread_mutex_lock(&round2.mutex);
             round2.completed_prev_round = true;
             nodenum_t nodes_received = round2.nodes_received;
@@ -654,6 +661,7 @@ void ecall_routing_proceed(void *cbpointer)
             }
         } else {
             route_state.step = ROUTE_ROUND_2;
+            route_state.round2.completed_prev_round = true;
             ocall_routing_round_complete(cbpointer, 2);
         }
     } else if (route_state.step == ROUTE_ROUND_2) {