Pārlūkot izejas kodu

In round1b, send an extra 5 bytes at the end of the message

These 5 bytes are the receiver id of the first message we _didn't_ send,
and the count of how many messages we have (in the first pub_in we
didn't send) with that receiver id.

TODO: we don't currently actually (obliviously) count the messages, so
when pub_in > 1, the count may be incorrect.
Ian Goldberg 1 gadu atpakaļ
vecāks
revīzija
66ed3d04d4
1 mainītis faili ar 42 papildinājumiem un 6 dzēšanām
  1. 42 6
      Enclave/route.cpp

+ 42 - 6
Enclave/route.cpp

@@ -294,9 +294,7 @@ static uint8_t* msgbuffer_get_buf(MsgBuffer &msgbuf,
     uint32_t plaintext_bytes = tot_enc_chunk_size -
         SGX_AESGCM_MAC_SIZE * min_num_chunks;
 
-    assert ((plaintext_bytes % uint32_t(msg_size)) == 0);
-
-    uint32_t num_msgs = plaintext_bytes/uint32_t(msg_size);
+    uint32_t num_msgs = CEILDIV(plaintext_bytes, uint32_t(msg_size));
 
     pthread_mutex_lock(&msgbuf.mutex);
     uint32_t start = msgbuf.reserved;
@@ -428,12 +426,17 @@ static void round1b_next_received(NodeCommState &nodest,
     uint8_t *data, uint32_t plaintext_len, uint32_t)
 {
     uint16_t msg_size = g_teems_config.msg_size;
-    assert((plaintext_len % uint32_t(msg_size)) == 0);
+    // There are an extra 5 bytes at the end of this message, containing
+    // the next receiver id (4 bytes) and the count of messages with
+    // that receiver id (1 byte)
+    assert((plaintext_len % uint32_t(msg_size)) == 5);
     uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
     uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
     uint8_t their_roles = g_teems_config.roles[nodest.node_num];
     pthread_mutex_lock(&route_state.round1b_next.mutex);
-    route_state.round1b_next.inserted += num_msgs;
+    // Add an extra 1 for the message space taken up by the above 5
+    // bytes
+    route_state.round1b_next.inserted += num_msgs + 1;
     route_state.round1b_next.nodes_received += 1;
     nodenum_t nodes_received = route_state.round1b_next.nodes_received;
     bool completed_prev_round = route_state.round1b_next.completed_prev_round;
@@ -815,8 +818,28 @@ static uint32_t send_round1b_msgs(const uint8_t *msgs, uint32_t N) {
         nodenum_t prev_node = g_teems_config.routing_nodes[num_routing_nodes-1];
         NodeCommState &nodecom = g_commstates[prev_node];
         uint32_t num_msgs = min(N, route_state.max_round1b_msgs_to_adj_rtr);
-        nodecom.message_start(num_msgs * msg_size);
+        // There are an extra 5 bytes at the end of this message: 4
+        // bytes for the receiver id in the next message we _didn't_
+        // send, and 1 byte for the number of messages we have at the
+        // beginning of the buffer of messages we didn't send (max
+        // pub_in) with the same receiver id
+        nodecom.message_start(num_msgs * msg_size + 5);
         nodecom.message_data(msgs, num_msgs * msg_size);
+        uint32_t next_receiver_id = 0xffffffff;
+        uint8_t next_rid_count = 0;
+        // num_msgs and N are not private, but the contents of the
+        // buffer are.
+        if (num_msgs < N) {
+            next_receiver_id = *(const uint32_t *)(msgs +
+                num_msgs * msg_size);
+            next_rid_count = 1;
+        }
+        // TODO: If pub_in > 1, obliviously scan messages num_msgs+1 ..
+        // num_msgs+(pub_in-1) and as long as they have the same
+        // receiver id as next_receiver_id, add 1 to next_rid_count (but
+        // don't go past message N of course)
+        nodecom.message_data((const unsigned char *)&next_receiver_id, 4);
+        nodecom.message_data(&next_rid_count, 1);
         return num_msgs;
     }
     return 0;
@@ -1072,6 +1095,19 @@ static void round1c_processing(void *cbpointer) {
             pthread_mutex_unlock(&round1b_next.mutex);
             pthread_mutex_lock(&round1b_next.mutex);
         }
+
+        uint32_t next_receiver_id = 0xffffffff;
+        uint8_t next_rid_count = 0;
+
+        // Extract the trailing 5 bytes if we received a round1b message
+        if (round1b_next.inserted >= 1) {
+            next_receiver_id = *(uint32_t *)(round1b_next.buf +
+                (round1b_next.inserted-1)*msg_size);
+            next_rid_count = *(round1b_next.buf +
+                (round1b_next.inserted-1)*msg_size + 4);
+            round1b_next.inserted -= 1;
+            round1b_next.reserved -= 1;
+        }
         pthread_mutex_lock(&round1a.mutex);
         pthread_mutex_lock(&round1a_sorted.mutex);