|
|
@@ -8,7 +8,7 @@ use tokio::net::{
|
|
|
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
|
|
TcpListener,
|
|
|
};
|
|
|
-use tokio::sync::{mpsc, watch};
|
|
|
+use tokio::sync::{mpsc, Notify};
|
|
|
|
|
|
// FIXME: identifiers should be interned
|
|
|
type ID = String;
|
|
|
@@ -25,7 +25,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
mpsc::UnboundedSender<Arc<SerializedMessage>>,
|
|
|
>::new()));
|
|
|
|
|
|
- let mut writer_db = HashMap::<ID, Updater<(OwnedWriteHalf, watch::Sender<bool>)>>::new();
|
|
|
+ let mut writer_db = HashMap::<ID, Updater<(OwnedWriteHalf, Arc<Notify>)>>::new();
|
|
|
|
|
|
loop {
|
|
|
let (socket, _) = listener.accept().await?;
|
|
|
@@ -36,12 +36,12 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
if let Some(socket_updater) = writer_db.get(&id) {
|
|
|
// we've seen this client before
|
|
|
|
|
|
- // start the new reader thread with a new watch
|
|
|
- let (watch_snd, watch_rcv) = watch::channel(false);
|
|
|
- spawn_message_receiver(rd, snd_db.clone(), watch_rcv);
|
|
|
+ // start the new reader thread with a new notify
|
|
|
+ let notify = Arc::new(Notify::new());
|
|
|
+ spawn_message_receiver(rd, snd_db.clone(), notify.clone());
|
|
|
|
|
|
- // give the writer thread the new write half of the socket and watch
|
|
|
- socket_updater.send((wr, watch_snd)).await;
|
|
|
+ // give the writer thread the new write half of the socket and notify
|
|
|
+ socket_updater.send((wr, notify)).await;
|
|
|
} else {
|
|
|
// newly-registered client
|
|
|
log!("New client");
|
|
|
@@ -55,16 +55,19 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
locked_db.insert(id.clone(), msg_snd);
|
|
|
}
|
|
|
|
|
|
- // socket watch, used to terminate the socket if the sender encounters an error
|
|
|
- let (watch_snd, watch_rcv) = watch::channel(false);
|
|
|
+ // socket notify, used to terminate the socket if the sender encounters an error
|
|
|
+ let notify = Arc::new(Notify::new());
|
|
|
|
|
|
- // socket updater, used to give the sender thread a new socket and watch
|
|
|
+ // socket updater, used to give the sender thread a new socket + notify channel
|
|
|
+ // (we can't use the existing notify channel, else we get race conditions where
|
|
|
+ // the reader thread terminates and spawns again before the sender thread
|
|
|
+ // notices and activates its existing notify channel)
|
|
|
let (socket_updater_snd, socket_updater_rcv) = Updater::channel();
|
|
|
- socket_updater_snd.send((wr, watch_snd)).await;
|
|
|
+ socket_updater_snd.send((wr, notify.clone())).await;
|
|
|
|
|
|
writer_db.insert(id.clone(), socket_updater_snd);
|
|
|
|
|
|
- spawn_message_receiver(rd, snd_db, watch_rcv);
|
|
|
+ spawn_message_receiver(rd, snd_db, notify);
|
|
|
tokio::spawn(async move {
|
|
|
send_messages(msg_rcv, socket_updater_rcv).await;
|
|
|
});
|
|
|
@@ -75,9 +78,8 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
fn spawn_message_receiver(
|
|
|
rd: OwnedReadHalf,
|
|
|
db: Arc<Mutex<HashMap<ID, mpsc::UnboundedSender<Arc<SerializedMessage>>>>>,
|
|
|
- mut watch_rcv: watch::Receiver<bool>,
|
|
|
+ notify: Arc<Notify>,
|
|
|
) {
|
|
|
- watch_rcv.borrow_and_update();
|
|
|
tokio::spawn(async move {
|
|
|
tokio::select! {
|
|
|
ret = get_messages(rd, db) => {
|
|
|
@@ -85,9 +87,9 @@ fn spawn_message_receiver(
|
|
|
log!("message receiver failed: {:?}", e);
|
|
|
}
|
|
|
}
|
|
|
- _ = watch_rcv.changed() => {
|
|
|
+ _ = notify.notified() => {
|
|
|
log!("receiver terminated");
|
|
|
- // should cause the other thread to terminate, dropping the socket
|
|
|
+ // should cause get_messages to terminate, dropping the socket
|
|
|
}
|
|
|
}
|
|
|
});
|
|
|
@@ -147,7 +149,7 @@ async fn get_messages(
|
|
|
/// and sending them out on the associated socket.
|
|
|
async fn send_messages(
|
|
|
mut msg_rcv: mpsc::UnboundedReceiver<Arc<SerializedMessage>>,
|
|
|
- mut socket_updater: Updater<(OwnedWriteHalf, watch::Sender<bool>)>,
|
|
|
+ mut socket_updater: Updater<(OwnedWriteHalf, Arc<Notify>)>,
|
|
|
) {
|
|
|
let (mut current_socket, mut current_watch) =
|
|
|
socket_updater.recv().await.expect("socket updater closed");
|
|
|
@@ -166,7 +168,7 @@ async fn send_messages(
|
|
|
log!("terminating connection");
|
|
|
// socket is presumably closed, clean up and notify the listening end to close
|
|
|
// (all best-effort, we can ignore errors because it presumably means it's done)
|
|
|
- let _ = current_watch.send(true);
|
|
|
+ current_watch.notify_one();
|
|
|
let _ = current_socket.shutdown().await;
|
|
|
|
|
|
// wait for the new socket
|