Browse Source

don't send recipients in each message

With this change, clients don't keep track of who is in each group. Instead,
clients self-register with the server, and the server maintains tables of
which socket corresponds to which user in the group. Peers still need to keep
track of who is in the group, since they are essentially their own servers,
but this is still done at config-time (i.e., sending a message to a peer that
hasn't been configured as being part of that group, or that hasn't been told
the sender is part of it, is liable to cause problems).

In the process, I also found and fixed some serious bugs in the server. The
way it keeps track of and associates reader/writer ends of sockets with groups
has been overhauled (the old way allowed messages to be sent on connections to
the correct client, but on the wrong TCP socket).
Justin Tracey 1 year ago
parent
commit
24c5aa5219
5 changed files with 152 additions and 144 deletions
  1. 7 6
      src/bin/client.rs
  2. 3 14
      src/bin/messenger/message.rs
  3. 5 7
      src/bin/messenger/state.rs
  4. 82 73
      src/bin/server.rs
  5. 55 44
      src/lib.rs

+ 7 - 6
src/bin/client.rs

@@ -1,7 +1,7 @@
 // Code specific to the client in the client-server mode.
 
 use mgen::updater::Updater;
-use mgen::{MessageHeader, SerializedMessage};
+use mgen::{HandshakeRef, MessageHeader, SerializedMessage};
 use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256PlusPlus};
 use serde::Deserialize;
 use std::result::Result;
@@ -129,11 +129,12 @@ async fn socket_updater(
             Err(MessengerError::Fatal(e)) => return e,
         };
 
-        if stream
-            .write_all(&mgen::serialize_str(&str_params.user))
-            .await
-            .is_err()
-        {
+        let handshake = HandshakeRef {
+            sender: &str_params.user,
+            group: &str_params.group,
+        };
+
+        if stream.write_all(&handshake.serialize()).await.is_err() {
             continue;
         }
 

+ 3 - 14
src/bin/messenger/message.rs

@@ -2,31 +2,20 @@
 // (Most message functionality is in the library, not this module.)
 
 /// Construct and serialize a message from the sender to the recipients with the given number of blocks.
-pub fn construct_message(
-    sender: String,
-    group: String,
-    recipients: Vec<String>,
-    blocks: u32,
-) -> mgen::SerializedMessage {
+pub fn construct_message(sender: String, group: String, blocks: u32) -> mgen::SerializedMessage {
     let size = std::cmp::max(blocks, 1) * mgen::PADDING_BLOCK_SIZE;
     let m = mgen::MessageHeader {
         sender,
         group,
-        recipients,
         body: mgen::MessageBody::Size(std::num::NonZeroU32::new(size).unwrap()),
     };
     m.serialize()
 }
 
-pub fn construct_receipt(
-    sender: String,
-    group: String,
-    recipient: String,
-) -> mgen::SerializedMessage {
+pub fn construct_receipt(sender: String, recipient: String) -> mgen::SerializedMessage {
     let m = mgen::MessageHeader {
         sender,
-        group,
-        recipients: vec![recipient],
+        group: recipient,
         body: mgen::MessageBody::Receipt,
     };
     m.serialize()

+ 5 - 7
src/bin/messenger/state.rs

@@ -247,7 +247,6 @@ async fn send_action<
     let m = S::new(construct_message(
         our_id.to_string(),
         group.to_string(),
-        recipients.iter().map(|s| s.to_string()).collect(),
         size,
     ));
 
@@ -281,19 +280,18 @@ async fn receive_action<
     conversation: Conversation<T>,
     stream_map: &mut M,
     our_id: &str,
-    group: &str,
     rng: &mut Xoshiro256PlusPlus,
 ) -> StateMachine {
     match msg.body {
         mgen::MessageBody::Size(size) => {
             log!(
-                "{:?} got message from {} of size {}",
-                msg.recipients,
+                "{} got message from {} of size {}",
+                msg.group,
                 msg.sender,
                 size
             );
             let stream = stream_map.channel_for(&msg.sender);
-            let m = construct_receipt(our_id.to_string(), group.to_string(), msg.sender);
+            let m = construct_receipt(our_id.to_string(), msg.sender);
             stream
                 .channel
                 .send(S::new(m))
@@ -346,7 +344,7 @@ pub async fn manage_idle_conversation<
             .await
         }
         IdleGroupActions::Receive(msg) => {
-            receive_action(msg, conversation, stream_map, our_id, group, rng).await
+            receive_action(msg, conversation, stream_map, our_id, rng).await
         }
     }
 }
@@ -392,7 +390,7 @@ pub async fn manage_active_conversation<
             .await
         }
         ActiveGroupActions::Receive(msg) => {
-            receive_action(msg, conversation, stream_map, our_id, group, rng).await
+            receive_action(msg, conversation, stream_map, our_id, rng).await
         }
         ActiveGroupActions::Idle => StateMachine::Idle(conversation.waited(rng)),
     }

+ 82 - 73
src/bin/server.rs

@@ -1,18 +1,20 @@
-use mgen::{log, parse_identifier, updater::Updater, MessageHeaderRef, SerializedMessage};
+use mgen::{log, updater::Updater, Handshake, MessageBody, MessageHeaderRef, SerializedMessage};
 use std::collections::HashMap;
 use std::error::Error;
 use std::result::Result;
-use std::sync::{Arc, RwLock};
+use std::sync::Arc;
 use tokio::io::AsyncWriteExt;
 use tokio::net::{
     tcp::{OwnedReadHalf, OwnedWriteHalf},
     TcpListener,
 };
-use tokio::sync::{mpsc, Notify};
+use tokio::sync::{mpsc, Notify, RwLock};
 
 // FIXME: identifiers should be interned
 type ID = String;
 
+type ReaderToSender = mpsc::UnboundedSender<Arc<SerializedMessage>>;
+
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn Error>> {
     let args: Vec<String> = std::env::args().collect();
@@ -25,12 +27,10 @@ async fn main() -> Result<(), Box<dyn Error>> {
 
     log!("Listening on {}", listen_addr);
 
-    let snd_db = Arc::new(RwLock::new(HashMap::<
-        ID,
-        mpsc::UnboundedSender<Arc<SerializedMessage>>,
-    >::new()));
-
-    let mut writer_db = HashMap::<ID, Updater<(OwnedWriteHalf, Arc<Notify>)>>::new();
+    // Maps group name to the table of message channels.
+    let mut snd_db = HashMap::<ID, Arc<RwLock<HashMap<ID, ReaderToSender>>>>::new();
+    // Maps the (sender, group) pair to the socket updater.
+    let mut writer_db = HashMap::<Handshake, Updater<(OwnedWriteHalf, Arc<Notify>)>>::new();
 
     loop {
         let socket = match listener.accept().await {
@@ -42,21 +42,23 @@ async fn main() -> Result<(), Box<dyn Error>> {
         };
         let (mut rd, wr) = socket.into_split();
 
-        let id = match parse_identifier(&mut rd).await {
-            Ok(id) => id,
-            Err(mgen::Error::Utf8Error(e)) => {
-                let err: Box<dyn Error> = Box::new(e);
-                return Err(err);
-            }
-            Err(_) => continue,
-        };
-        log!("Accepting \"{id}\"");
-        if let Some(socket_updater) = writer_db.get(&id) {
+        let handshake = mgen::get_handshake(&mut rd).await?;
+        log!(
+            "Accepting channel from {} to {}",
+            handshake.sender,
+            handshake.group
+        );
+
+        if let Some(socket_updater) = writer_db.get(&handshake) {
             // we've seen this client before
 
             // start the new reader thread with a new notify
+            // (we can't use the existing notify channel, else we get race conditions where
+            // the reader thread terminates and spawns again before the sender thread
+            // notices and activates its existing notify channel)
             let notify = Arc::new(Notify::new());
-            spawn_message_receiver(id, rd, snd_db.clone(), notify.clone());
+            let db = snd_db[&handshake.group].clone();
+            spawn_message_receiver(handshake.sender, handshake.group, rd, db, notify.clone());
 
             // give the writer thread the new write half of the socket and notify
             socket_updater.send((wr, notify));
@@ -64,29 +66,38 @@ async fn main() -> Result<(), Box<dyn Error>> {
             // newly-registered client
             log!("New client");
 
-            // message channel, used to send messages between threads
+            // message channel, for sending messages between threads
             let (msg_snd, msg_rcv) = mpsc::unbounded_channel::<Arc<SerializedMessage>>();
 
-            let snd_db = snd_db.clone();
-            {
-                let mut locked_db = snd_db.write().unwrap();
-                locked_db.insert(id.clone(), msg_snd);
-            }
-
-            // socket notify, used to terminate the socket if the sender encounters an error
+            let group_snds = if let Some(table) = snd_db.get(&handshake.group) {
+                table
+            } else {
+                let table = Arc::new(RwLock::new(HashMap::new()));
+                snd_db.insert(handshake.group.clone(), table);
+                &snd_db[&handshake.group]
+            };
+            group_snds
+                .write()
+                .await
+                .insert(handshake.sender.clone(), msg_snd);
+
+            // socket notify, for terminating the socket if the sender encounters an error
             let notify = Arc::new(Notify::new());
 
-            // socket updater, used to give the sender thread a new socket + notify channel
-            // (we can't use the existing notify channel, else we get race conditions where
-            // the reader thread terminates and spawns again before the sender thread
-            // notices and activates its existing notify channel)
+            // socket updater, for giving the sender thread a new socket + notify channel
             let socket_updater_snd = Updater::new();
             let socket_updater_rcv = socket_updater_snd.clone();
             socket_updater_snd.send((wr, notify.clone()));
 
-            writer_db.insert(id.clone(), socket_updater_snd);
+            spawn_message_receiver(
+                handshake.sender.clone(),
+                handshake.group.clone(),
+                rd,
+                group_snds.clone(),
+                notify,
+            );
+            writer_db.insert(handshake, socket_updater_snd);
 
-            spawn_message_receiver(id, rd, snd_db, notify);
             tokio::spawn(async move {
                 send_messages(msg_rcv, socket_updater_rcv).await;
             });
@@ -96,13 +107,17 @@ async fn main() -> Result<(), Box<dyn Error>> {
 
 fn spawn_message_receiver(
     sender: String,
+    group: String,
     rd: OwnedReadHalf,
-    db: Arc<RwLock<HashMap<ID, mpsc::UnboundedSender<Arc<SerializedMessage>>>>>,
+    db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
     notify: Arc<Notify>,
 ) {
     tokio::spawn(async move {
         tokio::select! {
-            ret = get_messages(&sender, rd, db) => {
+            // n.b.: get_message is not cancellation safe,
+            // but this is one of the cases where that's expected
+            // (we only cancel when something is wrong with the stream anyway)
+            ret = get_messages(sender, group, rd, db) => {
                 if let Err(e) = ret {
                     log!("message receiver failed: {:?}", e);
                 }
@@ -118,50 +133,44 @@ fn spawn_message_receiver(
 /// Loop for receiving messages on the socket, figuring out who to deliver them to,
 /// and forwarding them locally to the respective channel.
 async fn get_messages(
-    sender: &str,
+    sender: String,
+    group: String,
     mut socket: OwnedReadHalf,
-    db: Arc<RwLock<HashMap<ID, mpsc::UnboundedSender<Arc<SerializedMessage>>>>>,
+    global_db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
 ) -> Result<(), Box<dyn Error>> {
-    // stores snd's for contacts this client has already sent messages to, to reduce contention on the main db
-    // if memory ends up being more of a constraint, could be worth getting rid of this
-    let mut localdb = HashMap::<ID, mpsc::UnboundedSender<Arc<SerializedMessage>>>::new();
+    // Wait for the next message to be received before populating our local copy of the db,
+    // that way other clients have time to register.
+    let mut discard = vec![0u8];
+    while socket.peek(&mut discard).await? == 0 {}
+
+    let db = global_db.read().await.clone();
+    let message_channels: Vec<_> = db
+        .iter()
+        .filter_map(|(k, v)| if *k != sender { Some(v) } else { None })
+        .collect();
+
     loop {
-        log!("waiting for message from {}", sender);
+        log!("waiting for message from {} to {}", sender, group);
         let buf = mgen::get_message_bytes(&mut socket).await?;
         let message = MessageHeaderRef::deserialize(&buf[4..])?;
-        log!(
-            "got message from {} for {:?}",
-            message.sender,
-            message.recipients
-        );
-
-        let missing = message
-            .recipients
-            .iter()
-            .filter(|&&r| !localdb.contains_key(r))
-            .collect::<Vec<&&str>>();
-
-        {
-            let locked_db = db.read().unwrap();
-            for m in missing {
-                if let Some(snd) = locked_db.get(*m) {
-                    localdb.insert(m.to_string(), snd.clone());
+        assert!(message.sender == sender);
+
+        match message.body {
+            MessageBody::Size(_) => {
+                assert!(message.group == group);
+                log!("got message from {} for {}", sender, group);
+                let body = message.body;
+                let m = Arc::new(SerializedMessage { header: buf, body });
+                for recipient in message_channels.iter() {
+                    recipient.send(m.clone()).unwrap();
                 }
             }
-        }
-
-        let body = message.body;
-
-        let m = Arc::new(SerializedMessage { header: buf, body });
-
-        // FIXME: we could avoid this second deserialization if we stored recipients by group instead of in the message itself
-        let message = MessageHeaderRef::deserialize(&m.header[4..])?;
-
-        for recipient in message.recipients.iter() {
-            let recipient_sender = &localdb[*recipient];
-            recipient_sender
-                .send(m.clone())
-                .expect("Recipient closed channel with messages still being sent");
+            MessageBody::Receipt => {
+                let recipient = &db[message.group];
+                let body = message.body;
+                let m = Arc::new(SerializedMessage { header: buf, body });
+                recipient.send(m.clone()).unwrap();
+            }
         }
     }
 }

+ 55 - 44
src/lib.rs

@@ -73,7 +73,6 @@ impl MessageBody {
 pub struct MessageHeader {
     pub sender: String,
     pub group: String,
-    pub recipients: Vec<String>,
     pub body: MessageBody,
 }
 
@@ -84,37 +83,24 @@ impl MessageHeader {
         //   header_len: u32,
         //   sender: {u32, utf-8}
         //   group: {u32, utf-8}
-        //   num_recipients: u32,
-        //   recipients: [{u32, utf-8}],
         //   body_type: MessageBody (i.e., u32)
         // }
 
-        let num_recipients = self.recipients.len();
         let body_type = match self.body {
             MessageBody::Receipt => 0,
             MessageBody::Size(s) => s.get(),
         };
 
-        let header_len = (1 + 1 + 1 + 1 + num_recipients + 1) * size_of::<u32>()
-            + self.sender.len()
-            + self.group.len()
-            + self.recipients.iter().map(String::len).sum::<usize>();
+        let header_len = (1 + 1 + 1 + 1) * size_of::<u32>() + self.sender.len() + self.group.len();
 
         let mut header: Vec<u8> = Vec::with_capacity(header_len);
 
         let header_len = header_len as u32;
-        let num_recipients = num_recipients as u32;
-
         header.extend(header_len.to_be_bytes());
 
         serialize_str_to(&self.sender, &mut header);
         serialize_str_to(&self.group, &mut header);
 
-        header.extend(num_recipients.to_be_bytes());
-        for recipient in self.recipients.iter() {
-            serialize_str_to(recipient, &mut header);
-        }
-
         header.extend(body_type.to_be_bytes());
 
         assert!(header.len() == header_len as usize);
@@ -133,18 +119,6 @@ impl MessageHeader {
         let (group, buf) = deserialize_str(buf)?;
         let group = group.to_string();
 
-        let (num_recipients, buf) = deserialize_u32(buf)?;
-        debug_assert!(num_recipients != 0);
-
-        let mut recipients = Vec::with_capacity(num_recipients as usize);
-        let mut recipient;
-        let mut buf = buf;
-        for _ in 0..num_recipients {
-            (recipient, buf) = deserialize_str(buf)?;
-            let recipient = recipient.to_string();
-            recipients.push(recipient);
-        }
-
         let (body, _) = deserialize_u32(buf)?;
         let body = if let Some(size) = NonZeroU32::new(body) {
             MessageBody::Size(size)
@@ -154,7 +128,6 @@ impl MessageHeader {
         Ok(Self {
             sender,
             group,
-            recipients,
             body,
         })
     }
@@ -167,7 +140,6 @@ impl MessageHeader {
 pub struct MessageHeaderRef<'a> {
     pub sender: &'a str,
     pub group: &'a str,
-    pub recipients: Vec<&'a str>,
     pub body: MessageBody,
 }
 
@@ -181,18 +153,6 @@ impl<'a> MessageHeaderRef<'a> {
         let (group, buf) = deserialize_str(buf)?;
         let group = group;
 
-        let (num_recipients, buf) = deserialize_u32(buf)?;
-        debug_assert!(num_recipients != 0);
-
-        let mut recipients = Vec::with_capacity(num_recipients as usize);
-        let mut recipient;
-        let mut buf = buf;
-        for _ in 0..num_recipients {
-            (recipient, buf) = deserialize_str(buf)?;
-            let recipient = recipient;
-            recipients.push(recipient);
-        }
-
         let (body, _) = deserialize_u32(buf)?;
         let body = if let Some(size) = NonZeroU32::new(body) {
             MessageBody::Size(size)
@@ -202,7 +162,6 @@ impl<'a> MessageHeaderRef<'a> {
         Ok(Self {
             sender,
             group,
-            recipients,
             body,
         })
     }
@@ -247,9 +206,9 @@ async fn get_message_with_header_size<T: AsyncReadExt + std::marker::Unpin>(
     stream.read_exact(&mut header_buf[4..]).await?;
     let header = MessageHeader::deserialize(&header_buf[4..])?;
     log!(
-        "got header from {} to {:?}, about to read {} bytes",
+        "got header from {} to {}, about to read {} bytes",
         header.sender,
-        header.recipients,
+        header.group,
         header.body.size()
     );
     let header_size_buf = &mut header_buf[..4];
@@ -336,3 +295,55 @@ impl SerializedMessage {
         writer.write_all(body).await
     }
 }
+
+/// Handshake between client and server (peers do not use).
+#[derive(Eq, Hash, PartialEq)]
+pub struct Handshake {
+    pub sender: String,
+    pub group: String,
+}
+
+impl Handshake {
+    /// Generate a serialized handshake message.
+    pub fn serialize(&self) -> Vec<u8> {
+        serialize_handshake(&self.sender, &self.group)
+    }
+}
+
+/// Gets a handshake from the stream and constructs a Handshake object
+pub async fn get_handshake<T: AsyncReadExt + std::marker::Unpin>(
+    stream: &mut T,
+) -> Result<Handshake, Error> {
+    let sender = parse_identifier(stream).await?;
+    let group = parse_identifier(stream).await?;
+    Ok(Handshake { sender, group })
+}
+
+pub struct HandshakeRef<'a> {
+    pub sender: &'a str,
+    pub group: &'a str,
+}
+
+impl HandshakeRef<'_> {
+    /// Generate a serialized handshake message.
+    pub fn serialize(&self) -> Vec<u8> {
+        serialize_handshake(self.sender, self.group)
+    }
+}
+
+fn serialize_handshake(sender: &str, group: &str) -> Vec<u8> {
+    // serialized handshake: {
+    //   sender: {u32, utf-8}
+    //   group: {u32, utf-8}
+    // }
+
+    let handshake_len = (1 + 1) * size_of::<u32>() + sender.len() + group.len();
+
+    let mut handshake: Vec<u8> = Vec::with_capacity(handshake_len);
+
+    serialize_str_to(sender, &mut handshake);
+    serialize_str_to(group, &mut handshake);
+
+    debug_assert!(handshake.len() == handshake_len);
+    handshake
+}