Browse Source

use Updater in more places

Justin Tracey 1 year ago
parent
commit
91268cdd44
4 changed files with 54 additions and 68 deletions
  1. 11 16
      src/bin/client.rs
  2. 15 24
      src/bin/peer.rs
  3. 5 8
      src/bin/server.rs
  4. 23 20
      src/updater.rs

+ 11 - 16
src/bin/client.rs

@@ -1,5 +1,6 @@
 // Code specific to the client in the client-server mode.
 
+use mgen::updater::Updater;
 use mgen::{MessageHeader, SerializedMessage};
 use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256PlusPlus};
 use serde::Deserialize;
@@ -26,13 +27,13 @@ type MessageHolder = Box<SerializedMessage>;
 /// Type for getting messages from the state thread in the writer thread.
 type WriterFromState = mpsc::UnboundedReceiver<MessageHolder>;
 /// Type for sending the updated read half of the socket.
-type ReadSocketUpdaterIn = mpsc::UnboundedSender<OwnedReadHalf>;
+type ReadSocketUpdaterIn = Updater<OwnedReadHalf>;
 /// Type for getting the updated read half of the socket.
-type ReadSocketUpdaterOut = mpsc::UnboundedReceiver<OwnedReadHalf>;
+type ReadSocketUpdaterOut = Updater<OwnedReadHalf>;
 /// Type for sending the updated write half of the socket.
-type WriteSocketUpdaterIn = mpsc::UnboundedSender<OwnedWriteHalf>;
+type WriteSocketUpdaterIn = Updater<OwnedWriteHalf>;
 /// Type for getting the updated write half of the socket.
-type WriteSocketUpdaterOut = mpsc::UnboundedReceiver<OwnedWriteHalf>;
+type WriteSocketUpdaterOut = Updater<OwnedWriteHalf>;
 /// Type for sending errors to other threads.
 type ErrorChannelIn = mpsc::UnboundedSender<MessengerError>;
 /// Type for getting errors from other threads.
@@ -47,10 +48,7 @@ async fn reader(
     error_channel: ErrorChannelIn,
 ) {
     loop {
-        let mut stream = socket_updater
-            .recv()
-            .await
-            .expect("Reader socket updater closed");
+        let mut stream = socket_updater.recv().await;
         loop {
             let msg = match mgen::get_message(&mut stream).await {
                 Ok(msg) => msg,
@@ -77,10 +75,7 @@ async fn writer(
     error_channel: ErrorChannelIn,
 ) {
     loop {
-        let mut stream = socket_updater
-            .recv()
-            .await
-            .expect("Writer socket updater closed");
+        let mut stream = socket_updater.recv().await;
         loop {
             let msg = message_channel
                 .recv()
@@ -148,8 +143,8 @@ async fn socket_updater(
         }
 
         let (rd, wr) = stream.into_inner().into_split();
-        reader_channel.send(rd).expect("Reader channel closed");
-        writer_channel.send(wr).expect("Writer channel closed");
+        reader_channel.send(rd);
+        writer_channel.send(wr);
 
         let res = error_channel.recv().await.expect("Error channel closed");
         if let MessengerError::Fatal(e) = res {
@@ -177,8 +172,8 @@ async fn manage_conversation(config: Config) -> Result<(), MessengerError> {
 
     let (reader_to_state, mut state_from_reader) = mpsc::unbounded_channel();
     let (state_to_writer, writer_from_state) = mpsc::unbounded_channel();
-    let (read_socket_updater_in, read_socket_updater_out) = mpsc::unbounded_channel();
-    let (write_socket_updater_in, write_socket_updater_out) = mpsc::unbounded_channel();
+    let (read_socket_updater_in, read_socket_updater_out) = Updater::channel();
+    let (write_socket_updater_in, write_socket_updater_out) = Updater::channel();
     let (errs_in, errs_out) = mpsc::unbounded_channel();
     tokio::spawn(reader(
         reader_to_state,

+ 15 - 24
src/bin/peer.rs

@@ -1,5 +1,5 @@
 // Code specific to the peer in the p2p mode.
-use mgen::{log, MessageHeader, SerializedMessage};
+use mgen::{log, updater::Updater, MessageHeader, SerializedMessage};
 use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256PlusPlus};
 use serde::Deserialize;
 use std::collections::HashMap;
@@ -31,13 +31,13 @@ type WriterFromState = mpsc::UnboundedReceiver<Arc<SerializedMessage>>;
 /// Type for sending messages from the state thread to the writer thread.
 type MessageHolder = Arc<SerializedMessage>;
 /// Type for sending the updated read half of the socket.
-type ReadSocketUpdaterIn = mpsc::UnboundedSender<OwnedReadHalf>;
+type ReadSocketUpdaterIn = Arc<Updater<OwnedReadHalf>>;
 /// Type for getting the updated read half of the socket.
-type ReadSocketUpdaterOut = mpsc::UnboundedReceiver<OwnedReadHalf>;
+type ReadSocketUpdaterOut = Updater<OwnedReadHalf>;
 /// Type for sending the updated write half of the socket.
-type WriteSocketUpdaterIn = mpsc::UnboundedSender<OwnedWriteHalf>;
+type WriteSocketUpdaterIn = Updater<OwnedWriteHalf>;
 /// Type for getting the updated write half of the socket.
-type WriteSocketUpdaterOut = mpsc::UnboundedReceiver<OwnedWriteHalf>;
+type WriteSocketUpdaterOut = Updater<OwnedWriteHalf>;
 
 /// The conversation (state) thread tracks the conversation state
 /// (i.e., whether the user is active or idle, and when to send messages).
@@ -114,12 +114,8 @@ async fn listener(
         let (channel_to_reader, channel_to_writer) = name_to_io_threads
             .get(&from)
             .unwrap_or_else(|| panic!("{} got connection from unknown contact: {}", address, from));
-        channel_to_reader
-            .send(rd)
-            .expect("listener: Channel to reader closed");
-        channel_to_writer
-            .send(wr)
-            .expect("listener: Channel to writer closed");
+        channel_to_reader.send(rd);
+        channel_to_writer.send(wr);
         Ok(())
     }
 
@@ -141,10 +137,7 @@ async fn reader(
 ) {
     loop {
         // wait for listener or writer thread to give us a stream to read from
-        let mut stream = connection_channel
-            .recv()
-            .await
-            .expect("reader: Channel to reader closed");
+        let mut stream = connection_channel.recv().await;
         loop {
             let msg = if let Ok(msg) = mgen::get_message(&mut stream).await {
                 msg
@@ -228,7 +221,7 @@ async fn writer<'a>(
         retry: Duration,
     ) -> Result<OwnedWriteHalf, FatalError> {
         // first check if the listener thread already has a socket
-        if let Ok(wr) = write_socket_updater.try_recv() {
+        if let Some(wr) = write_socket_updater.maybe_recv() {
             return Ok(wr);
         }
 
@@ -251,9 +244,7 @@ async fn writer<'a>(
                 .write_all(&mgen::serialize_str(&socks_params.user))
                 .await?;
             let (rd, wr) = stream.into_inner().into_split();
-            read_socket_updater
-                .send(rd)
-                .expect("writer: Channel to reader closed");
+            read_socket_updater.send(rd);
             return Ok(wr);
         } else if let Err(e) = connection_attempt {
             let e: MessengerError = e.into();
@@ -300,10 +291,10 @@ async fn writer<'a>(
                     stream.write_all(&mgen::serialize_str(&socks_params.user)).await?;
 
                     let (rd, wr) = stream.into_inner().into_split();
-                    read_socket_updater.send(rd).expect("writer: Channel to reader closed");
+                    read_socket_updater.send(rd);
                     Ok(wr)
                 },
-                stream = write_socket_updater.recv() => Ok(stream.expect("writer: Channel from listener closed")),
+                stream = write_socket_updater.recv() => Ok(stream),
             }
         }
     }
@@ -421,9 +412,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
             HashMap::new();
 
         for (recipient, for_io) in recipient_map.drain() {
-            let (listener_writer_to_reader, reader_from_listener_writer) =
-                mpsc::unbounded_channel();
-            let (listener_to_writer, writer_from_listener) = mpsc::unbounded_channel();
+            let (listener_writer_to_reader, reader_from_listener_writer) = Updater::channel();
+            let listener_writer_to_reader = Arc::new(listener_writer_to_reader);
+            let (listener_to_writer, writer_from_listener) = Updater::channel();
             name_to_io_threads.insert(
                 recipient.to_string(),
                 (listener_writer_to_reader.clone(), listener_to_writer),

+ 5 - 8
src/bin/server.rs

@@ -1,5 +1,4 @@
-use mgen::MessageHeaderRef;
-use mgen::{log, parse_identifier, updater::Updater, SerializedMessage};
+use mgen::{log, parse_identifier, updater::Updater, MessageHeaderRef, SerializedMessage};
 use std::collections::HashMap;
 use std::error::Error;
 use std::result::Result;
@@ -60,7 +59,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
             spawn_message_receiver(id, rd, snd_db.clone(), notify.clone());
 
             // give the writer thread the new write half of the socket and notify
-            socket_updater.send((wr, notify)).await;
+            socket_updater.send((wr, notify));
         } else {
             // newly-registered client
             log!("New client");
@@ -82,7 +81,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
             // the reader thread terminates and spawns again before the sender thread
             // notices and activates its existing notify channel)
             let (socket_updater_snd, socket_updater_rcv) = Updater::channel();
-            socket_updater_snd.send((wr, notify.clone())).await;
+            socket_updater_snd.send((wr, notify.clone()));
 
             writer_db.insert(id.clone(), socket_updater_snd);
 
@@ -172,8 +171,7 @@ async fn send_messages(
     mut msg_rcv: mpsc::UnboundedReceiver<Arc<SerializedMessage>>,
     mut socket_updater: Updater<(OwnedWriteHalf, Arc<Notify>)>,
 ) {
-    let (mut current_socket, mut current_watch) =
-        socket_updater.recv().await.expect("socket updater closed");
+    let (mut current_socket, mut current_watch) = socket_updater.recv().await;
     let mut message_cache = None;
     loop {
         let message = if message_cache.is_none() {
@@ -193,8 +191,7 @@ async fn send_messages(
             let _ = current_socket.shutdown().await;
 
             // wait for the new socket
-            (current_socket, current_watch) =
-                socket_updater.recv().await.expect("socket updater closed");
+            (current_socket, current_watch) = socket_updater.recv().await;
             log!("socket updated");
         } else {
             message_cache = None;

+ 23 - 20
src/updater.rs

@@ -1,39 +1,42 @@
 use std::sync::{Arc, Mutex};
 use tokio::sync::Notify;
 
+/// A multi/single-producer, single-consumer channel for updating an object.
+/// Unlike a mpsc, there is no queue of objects, only the most recent can be obtained.
+/// Unlike a watch, the receiver owns the object received.
 pub struct Updater<T>(Arc<(Mutex<Option<T>>, Notify)>);
 
 impl<T> Updater<T> {
-    pub async fn send(&self, value: T) {
+    /// Send an object T to the receiver end, repacing any currently queued object.
+    pub fn send(&self, value: T) {
         let mut locked_object = self.0 .0.lock().expect("send failed to lock mutex");
         *locked_object = Some(value);
         self.0 .1.notify_one();
     }
 
-    pub async fn recv(&mut self) -> Option<T> {
-        self.0 .1.notified().await;
-        {
-            let mut locked_object = self.0 .0.lock().unwrap();
-            if locked_object.is_some() {
-                return locked_object.take();
+    /// Get the object most recently sent by the sender end.
+    pub async fn recv(&mut self) -> T {
+        // According to a dev on GH, tokio's Notify is allowed false notifications.
+        loop {
+            self.0 .1.notified().await;
+            {
+                let mut locked_object = self.0 .0.lock().unwrap();
+                if locked_object.is_some() {
+                    return locked_object.take().unwrap();
+                }
             }
         }
+    }
 
-        // We must have gotten the last value from a stale notification:
-        //       ...
-        //                send.object.update
-        //                send.notify
-        // recv.notified
-        //                send.object.update
-        //                send.notify
-        // recv.object.update
-        //       ...
-        // recv.notified <- notified but no new object
-        // Waiting one more time should do the trick.
-        self.0 .1.notified().await;
-        self.0 .0.lock().unwrap().take()
+    /// Get the object most recently sent by the sender end, if one is already available.
+    pub fn maybe_recv(&mut self) -> Option<T> {
+        let mut locked_object = self.0 .0.lock().unwrap();
+        locked_object.take()
     }
 
+    /// Create an Updater channel.
+    /// Currently there is no distinction between sender and receiver updaters,
+    /// it's the caller's job to decide which is which.
     pub fn channel() -> (Self, Self) {
         let body = Arc::new((Mutex::new(None), Notify::new()));
         (Updater(body.clone()), Updater(body))