Browse Source

make server use zero-copy strings when handling messages

Justin Tracey 1 year ago
parent
commit
67225eb8b5
4 changed files with 85 additions and 26 deletions
  1. 1 1
      src/bin/client.rs
  2. 1 1
      src/bin/peer.rs
  3. 19 16
      src/bin/server.rs
  4. 64 8
      src/lib.rs

+ 1 - 1
src/bin/client.rs

@@ -53,7 +53,7 @@ async fn reader(
             .await
             .expect("Reader socket updater closed");
         loop {
-            let (msg, _) = match mgen::get_message(&mut stream).await {
+            let msg = match mgen::get_message(&mut stream).await {
                 Ok(msg) => msg,
                 Err(e) => {
                     error_channel.send(e.into()).expect("Error channel closed");

+ 1 - 1
src/bin/peer.rs

@@ -147,7 +147,7 @@ async fn reader(
             .await
             .expect("reader: Channel to reader closed");
         loop {
-            let (msg, _) = if let Ok(msg) = mgen::get_message(&mut stream).await {
+            let msg = if let Ok(msg) = mgen::get_message(&mut stream).await {
                 msg
             } else {
                 // Unlike the client-server case, we can assume that if there

+ 19 - 16
src/bin/server.rs

@@ -1,3 +1,4 @@
+use mgen::MessageHeaderRef;
 use mgen::{log, parse_identifier, updater::Updater, SerializedMessage};
 use std::collections::HashMap;
 use std::error::Error;
@@ -50,7 +51,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
 
             // start the new reader thread with a new notify
             let notify = Arc::new(Notify::new());
-            spawn_message_receiver(rd, snd_db.clone(), notify.clone());
+            spawn_message_receiver(id, rd, snd_db.clone(), notify.clone());
 
             // give the writer thread the new write half of the socket and notify
             socket_updater.send((wr, notify)).await;
@@ -79,7 +80,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
 
             writer_db.insert(id.clone(), socket_updater_snd);
 
-            spawn_message_receiver(rd, snd_db, notify);
+            spawn_message_receiver(id, rd, snd_db, notify);
             tokio::spawn(async move {
                 send_messages(msg_rcv, socket_updater_rcv).await;
             });
@@ -88,13 +89,14 @@ async fn main() -> Result<(), Box<dyn Error>> {
 }
 
 fn spawn_message_receiver(
+    sender: String,
     rd: OwnedReadHalf,
     db: Arc<RwLock<HashMap<ID, mpsc::UnboundedSender<Arc<SerializedMessage>>>>>,
     notify: Arc<Notify>,
 ) {
     tokio::spawn(async move {
         tokio::select! {
-            ret = get_messages(rd, db) => {
+            ret = get_messages(&sender, rd, db) => {
                 if let Err(e) = ret {
                     log!("message receiver failed: {:?}", e);
                 }
@@ -110,46 +112,47 @@ fn spawn_message_receiver(
 /// Loop for receiving messages on the socket, figuring out who to deliver them to,
 /// and forwarding them locally to the respective channel.
 async fn get_messages(
+    sender: &str,
     mut socket: OwnedReadHalf,
     db: Arc<RwLock<HashMap<ID, mpsc::UnboundedSender<Arc<SerializedMessage>>>>>,
 ) -> Result<(), Box<dyn Error>> {
     // stores snd's for contacts this client has already sent messages to, to reduce contention on the main db
     // if memory ends up being more of a constraint, could be worth getting rid of this
     let mut localdb = HashMap::<ID, mpsc::UnboundedSender<Arc<SerializedMessage>>>::new();
-    let mut sender = "".to_string();
     loop {
         log!("waiting for message from {}", sender);
-        let (message, buf) = mgen::get_message(&mut socket).await?;
+        let buf = mgen::get_message_bytes(&mut socket).await?;
+        let message = MessageHeaderRef::deserialize(&buf[4..])?;
         log!(
             "got message from {} for {:?}",
             message.sender,
             message.recipients
         );
-        sender = message.sender;
+
         let missing = message
             .recipients
             .iter()
-            .filter(|&r| !localdb.contains_key(r))
-            .collect::<Vec<&ID>>();
+            .filter(|&&r| !localdb.contains_key(r))
+            .collect::<Vec<&&str>>();
 
         {
             let locked_db = db.read().unwrap();
             for m in missing {
-                if let Some(snd) = locked_db.get(m) {
+                if let Some(snd) = locked_db.get(*m) {
                     localdb.insert(m.to_string(), snd.clone());
                 }
             }
         }
 
-        let m = Arc::new(SerializedMessage {
-            header: buf,
-            body: message.body,
-        });
+        let body = message.body;
+
+        let m = Arc::new(SerializedMessage { header: buf, body });
+
+        // FIXME: we could avoid this second deserialization if we stored recipients by group instead of in the message itself
+        let message = MessageHeaderRef::deserialize(&m.header[4..])?;
 
         for recipient in message.recipients.iter() {
-            let recipient_sender = localdb
-                .get(recipient)
-                .unwrap_or_else(|| panic!("Unknown sender: {}", recipient));
+            let recipient_sender = &localdb[*recipient];
             recipient_sender
                 .send(m.clone())
                 .expect("Recipient closed channel with messages still being sent");

+ 64 - 8
src/lib.rs

@@ -68,7 +68,7 @@ impl MessageBody {
 /// Message metadata.
 ///
 /// This has everything needed to reconstruct a message.
-// FIXME: every String should be &str
+// FIXME: we should try to replace MessageHeader with MessageHeaderRef
 #[derive(Debug)]
 pub struct MessageHeader {
     pub sender: String,
@@ -160,6 +160,54 @@ impl MessageHeader {
     }
 }
 
+/// Borrowed message metadata.
+///
+/// This has everything needed to reconstruct a message, like `MessageHeader`, but holds `&str` slices instead of owned `String`s.
+#[derive(Debug)]
+pub struct MessageHeaderRef<'a> {
+    pub sender: &'a str,
+    pub group: &'a str,
+    pub recipients: Vec<&'a str>,
+    pub body: MessageBody,
+}
+
+impl<'a> MessageHeaderRef<'a> {
+    /// Creates a MessageHeaderRef from bytes created via serialization,
+    /// but with the size already parsed out.
+    ///
+    /// The returned strings are borrowed from `buf` rather than copied.
+    pub fn deserialize(buf: &'a [u8]) -> Result<Self, Error> {
+        let (sender, buf) = deserialize_str(buf)?;
+        let (group, buf) = deserialize_str(buf)?;
+
+        let (num_recipients, mut buf) = deserialize_u32(buf)?;
+        debug_assert!(num_recipients != 0);
+
+        // The count is read off the wire: bound the pre-allocation by the
+        // bytes remaining so a bogus count cannot force a huge up-front allocation.
+        let capacity = (num_recipients as usize).min(buf.len());
+        let mut recipients = Vec::with_capacity(capacity);
+        for _ in 0..num_recipients {
+            let (recipient, rest) = deserialize_str(buf)?;
+            recipients.push(recipient);
+            buf = rest;
+        }
+
+        let (body, _) = deserialize_u32(buf)?;
+        let body = if let Some(size) = NonZeroU32::new(body) {
+            MessageBody::Size(size)
+        } else {
+            MessageBody::Receipt
+        };
+        Ok(Self {
+            sender,
+            group,
+            recipients,
+            body,
+        })
+    }
+}
+
 /// Parse the identifier from the start of the TcpStream.
 pub async fn parse_identifier<T: AsyncReadExt + std::marker::Unpin>(
     stream: &mut T,
@@ -172,20 +220,28 @@ pub async fn parse_identifier<T: AsyncReadExt + std::marker::Unpin>(
     Ok(s.to_string())
 }
 
-/// Gets a message from the stream and constructs a MessageHeader object, also returning the raw byte buffer
-pub async fn get_message<T: AsyncReadExt + std::marker::Unpin>(
+/// Gets a message from the stream, returning the raw byte buffer
+pub async fn get_message_bytes<T: AsyncReadExt + std::marker::Unpin>(
     stream: &mut T,
-) -> Result<(MessageHeader, Vec<u8>), Error> {
+) -> Result<Vec<u8>, Error> {
     let mut header_size_bytes = [0u8; 4];
     stream.read_exact(&mut header_size_bytes).await?;
-    log!("got header size");
     get_message_with_header_size(stream, header_size_bytes).await
 }
 
-pub async fn get_message_with_header_size<T: AsyncReadExt + std::marker::Unpin>(
+/// Gets a message from the stream and constructs a MessageHeader object
+pub async fn get_message<T: AsyncReadExt + std::marker::Unpin>(
+    stream: &mut T,
+) -> Result<MessageHeader, Error> {
+    let buf = get_message_bytes(stream).await?;
+    let msg = MessageHeader::deserialize(&buf[4..])?; // skip the size prefix
+    Ok(msg)
+}
+
+async fn get_message_with_header_size<T: AsyncReadExt + std::marker::Unpin>(
     stream: &mut T,
     header_size_bytes: [u8; 4],
-) -> Result<(MessageHeader, Vec<u8>), Error> {
+) -> Result<Vec<u8>, Error> {
     let header_size = u32::from_be_bytes(header_size_bytes);
     let mut header_buf = vec![0; header_size as usize];
     stream.read_exact(&mut header_buf[4..]).await?;
@@ -199,7 +255,7 @@ pub async fn get_message_with_header_size<T: AsyncReadExt + std::marker::Unpin>(
     let header_size_buf = &mut header_buf[..4];
     header_size_buf.copy_from_slice(&header_size_bytes);
     copy(&mut stream.take(header.body.size() as u64), &mut sink()).await?;
-    Ok((header, header_buf))
+    Ok(header_buf)
 }
 
 pub fn serialize_str(s: &str) -> Vec<u8> {