Просмотр исходного кода

use Notify instead of watch for terminating sever reader thread

Justin Tracey 2 лет назад
Родитель
Сommit
4874416e92
1 измененных файлов с 20 добавлено и 18 удалено
  1. 20 18
      src/bin/server.rs

+ 20 - 18
src/bin/server.rs

@@ -8,7 +8,7 @@ use tokio::net::{
     tcp::{OwnedReadHalf, OwnedWriteHalf},
     TcpListener,
 };
-use tokio::sync::{mpsc, watch};
+use tokio::sync::{mpsc, Notify};
 
 // FIXME: identifiers should be interned
 type ID = String;
@@ -25,7 +25,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
         mpsc::UnboundedSender<Arc<SerializedMessage>>,
     >::new()));
 
-    let mut writer_db = HashMap::<ID, Updater<(OwnedWriteHalf, watch::Sender<bool>)>>::new();
+    let mut writer_db = HashMap::<ID, Updater<(OwnedWriteHalf, Arc<Notify>)>>::new();
 
     loop {
         let (socket, _) = listener.accept().await?;
@@ -36,12 +36,12 @@ async fn main() -> Result<(), Box<dyn Error>> {
         if let Some(socket_updater) = writer_db.get(&id) {
             // we've seen this client before
 
-            // start the new reader thread with a new watch
-            let (watch_snd, watch_rcv) = watch::channel(false);
-            spawn_message_receiver(rd, snd_db.clone(), watch_rcv);
+            // start the new reader thread with a new notify
+            let notify = Arc::new(Notify::new());
+            spawn_message_receiver(rd, snd_db.clone(), notify.clone());
 
-            // give the writer thread the new write half of the socket and watch
-            socket_updater.send((wr, watch_snd)).await;
+            // give the writer thread the new write half of the socket and notify
+            socket_updater.send((wr, notify)).await;
         } else {
             // newly-registered client
             log!("New client");
@@ -55,16 +55,19 @@ async fn main() -> Result<(), Box<dyn Error>> {
                 locked_db.insert(id.clone(), msg_snd);
             }
 
-            // socket watch, used to terminate the socket if the sender encounters an error
-            let (watch_snd, watch_rcv) = watch::channel(false);
+            // socket notify, used to terminate the socket if the sender encounters an error
+            let notify = Arc::new(Notify::new());
 
-            // socket updater, used to give the sender thread a new socket and watch
+            // socket updater, used to give the sender thread a new socket + notify channel
+            // (we can't use the existing notify channel, else we get race conditions where
+            // the reader thread terminates and spawns again before the sender thread
+            // notices and activates its existing notify channel)
             let (socket_updater_snd, socket_updater_rcv) = Updater::channel();
-            socket_updater_snd.send((wr, watch_snd)).await;
+            socket_updater_snd.send((wr, notify.clone())).await;
 
             writer_db.insert(id.clone(), socket_updater_snd);
 
-            spawn_message_receiver(rd, snd_db, watch_rcv);
+            spawn_message_receiver(rd, snd_db, notify);
             tokio::spawn(async move {
                 send_messages(msg_rcv, socket_updater_rcv).await;
             });
@@ -75,9 +78,8 @@ async fn main() -> Result<(), Box<dyn Error>> {
 fn spawn_message_receiver(
     rd: OwnedReadHalf,
     db: Arc<Mutex<HashMap<ID, mpsc::UnboundedSender<Arc<SerializedMessage>>>>>,
-    mut watch_rcv: watch::Receiver<bool>,
+    notify: Arc<Notify>,
 ) {
-    watch_rcv.borrow_and_update();
     tokio::spawn(async move {
         tokio::select! {
             ret = get_messages(rd, db) => {
@@ -85,9 +87,9 @@ fn spawn_message_receiver(
                     log!("message receiver failed: {:?}", e);
                 }
             }
-            _ = watch_rcv.changed() => {
+            _ = notify.notified() => {
                 log!("receiver terminated");
-                // should cause the other thread to terminate, dropping the socket
+                // should cause get_messages to terminate, dropping the socket
             }
         }
     });
@@ -147,7 +149,7 @@ async fn get_messages(
 /// and sending them out on the associated socket.
 async fn send_messages(
     mut msg_rcv: mpsc::UnboundedReceiver<Arc<SerializedMessage>>,
-    mut socket_updater: Updater<(OwnedWriteHalf, watch::Sender<bool>)>,
+    mut socket_updater: Updater<(OwnedWriteHalf, Arc<Notify>)>,
 ) {
     let (mut current_socket, mut current_watch) =
         socket_updater.recv().await.expect("socket updater closed");
@@ -166,7 +168,7 @@ async fn send_messages(
             log!("terminating connection");
             // socket is presumably closed, clean up and notify the listening end to close
             // (all best-effort, we can ignore errors because it presumably means it's done)
-            let _ = current_watch.send(true);
+            current_watch.notify_one();
             let _ = current_socket.shutdown().await;
 
             // wait for the new socket