|
@@ -60,10 +60,10 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
let listener = socket.listen(4096)?;
|
|
|
log!("listening,{}", listen_addr);
|
|
|
|
|
|
- // Maps group name to the table of message channels.
|
|
|
- let mut snd_db = SndDb::new();
|
|
|
// Maps the (sender, group) pair to the socket updater.
|
|
|
- let mut writer_db = WriterDb::new();
|
|
|
+ let writer_db = Arc::new(RwLock::new(WriterDb::new()));
|
|
|
+ // Maps group name to the table of message channels.
|
|
|
+ let snd_db = Arc::new(RwLock::new(SndDb::new()));
|
|
|
// Notifies listener threads when registration phase is over.
|
|
|
let phase_notify = Arc::new(Notify::new());
|
|
|
|
|
@@ -77,14 +77,20 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
}
|
|
|
};
|
|
|
log!("accepted {}", stream.peer_addr()?);
|
|
|
- handle_handshake::</*REGISTRATION_PHASE=*/ true>(
|
|
|
- stream,
|
|
|
- acceptor.clone(),
|
|
|
- &mut writer_db,
|
|
|
- &mut snd_db,
|
|
|
- phase_notify.clone(),
|
|
|
- )
|
|
|
- .await;
|
|
|
+ let acceptor = acceptor.clone();
|
|
|
+ let writer_db = writer_db.clone();
|
|
|
+ let snd_db = snd_db.clone();
|
|
|
+ let phase_notify = phase_notify.clone();
|
|
|
+ tokio::spawn(async move {
|
|
|
+ handle_handshake::</*REGISTRATION_PHASE=*/ true>(
|
|
|
+ stream,
|
|
|
+ acceptor,
|
|
|
+ writer_db,
|
|
|
+ snd_db,
|
|
|
+ phase_notify,
|
|
|
+ )
|
|
|
+ .await
|
|
|
+ });
|
|
|
}
|
|
|
|
|
|
log!("registration phase complete");
|
|
@@ -100,22 +106,28 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
continue;
|
|
|
}
|
|
|
};
|
|
|
- handle_handshake::</*REGISTRATION_PHASE=*/ false>(
|
|
|
- stream,
|
|
|
- acceptor.clone(),
|
|
|
- &mut writer_db,
|
|
|
- &mut snd_db,
|
|
|
- phase_notify.clone(),
|
|
|
- )
|
|
|
- .await;
|
|
|
+ let acceptor = acceptor.clone();
|
|
|
+ let writer_db = writer_db.clone();
|
|
|
+ let snd_db = snd_db.clone();
|
|
|
+ let phase_notify = phase_notify.clone();
|
|
|
+ tokio::spawn(async move {
|
|
|
+ handle_handshake::</*REGISTRATION_PHASE=*/ false>(
|
|
|
+ stream,
|
|
|
+ acceptor,
|
|
|
+ writer_db,
|
|
|
+ snd_db,
|
|
|
+ phase_notify,
|
|
|
+ )
|
|
|
+ .await
|
|
|
+ });
|
|
|
}
|
|
|
}
|
|
|
|
|
|
async fn handle_handshake<const REGISTRATION_PHASE: bool>(
|
|
|
stream: TcpStream,
|
|
|
acceptor: TlsAcceptor,
|
|
|
- writer_db: &mut WriterDb,
|
|
|
- snd_db: &mut SndDb,
|
|
|
+ writer_db: Arc<RwLock<WriterDb>>,
|
|
|
+ snd_db: Arc<RwLock<SndDb>>,
|
|
|
phase_notify: Arc<Notify>,
|
|
|
) {
|
|
|
let stream = match acceptor.accept(stream).await {
|
|
@@ -139,7 +151,8 @@ async fn handle_handshake<const REGISTRATION_PHASE: bool>(
|
|
|
};
|
|
|
log!("accept,{},{}", handshake.sender, handshake.group);
|
|
|
|
|
|
- if let Some(socket_updater) = writer_db.get(&handshake) {
|
|
|
+ let read_writer_db = writer_db.read().await;
|
|
|
+ if let Some(socket_updater) = read_writer_db.get(&handshake) {
|
|
|
// we've seen this client before
|
|
|
|
|
|
// start the new reader thread with a new notify
|
|
@@ -147,7 +160,7 @@ async fn handle_handshake<const REGISTRATION_PHASE: bool>(
|
|
|
// the reader thread terminates and spawns again before the sender thread
|
|
|
// notices and activates its existing notify channel)
|
|
|
let socket_notify = Arc::new(Notify::new());
|
|
|
- let db = snd_db[&handshake.group].clone();
|
|
|
+ let db = snd_db.read().await[&handshake.group].clone();
|
|
|
spawn_message_receiver(
|
|
|
handshake.sender,
|
|
|
handshake.group,
|
|
@@ -160,6 +173,7 @@ async fn handle_handshake<const REGISTRATION_PHASE: bool>(
|
|
|
// give the writer thread the new write half of the socket and notify
|
|
|
socket_updater.send((wr, socket_notify));
|
|
|
} else {
|
|
|
+ drop(read_writer_db);
|
|
|
// newly-registered client
|
|
|
log!("register,{},{}", handshake.sender, handshake.group);
|
|
|
|
|
@@ -167,13 +181,30 @@ async fn handle_handshake<const REGISTRATION_PHASE: bool>(
|
|
|
// message channel, for sending messages between threads
|
|
|
let (msg_snd, msg_rcv) = mpsc::unbounded_channel::<Arc<SerializedMessage>>();
|
|
|
|
|
|
- let group_snds = snd_db
|
|
|
- .entry(handshake.group.clone())
|
|
|
- .or_insert_with(|| Arc::new(RwLock::new(HashMap::new())));
|
|
|
- group_snds
|
|
|
- .write()
|
|
|
- .await
|
|
|
- .insert(handshake.sender.clone(), msg_snd);
|
|
|
+ let group_snds = {
|
|
|
+ let read_snd_db = snd_db.read().await;
|
|
|
+ let group_snds = read_snd_db.get(&handshake.group);
|
|
|
+ if let Some(group_snds) = group_snds {
|
|
|
+ let group_snds = group_snds.clone();
|
|
|
+ drop(read_snd_db);
|
|
|
+ group_snds
|
|
|
+ .write()
|
|
|
+ .await
|
|
|
+ .insert(handshake.sender.clone(), msg_snd);
|
|
|
+ group_snds
|
|
|
+ } else {
|
|
|
+ drop(read_snd_db);
|
|
|
+ let mut write_snd_db = snd_db.write().await;
|
|
|
+ let group_snds = write_snd_db
|
|
|
+ .entry(handshake.group.clone())
|
|
|
+ .or_insert_with(|| Arc::new(RwLock::new(HashMap::new())));
|
|
|
+ group_snds
|
|
|
+ .write()
|
|
|
+ .await
|
|
|
+ .insert(handshake.sender.clone(), msg_snd);
|
|
|
+ group_snds.clone()
|
|
|
+ }
|
|
|
+ };
|
|
|
|
|
|
// socket notify, for terminating the socket if the sender encounters an error
|
|
|
let socket_notify = Arc::new(Notify::new());
|
|
@@ -187,7 +218,7 @@ async fn handle_handshake<const REGISTRATION_PHASE: bool>(
|
|
|
handshake.sender.clone(),
|
|
|
handshake.group.clone(),
|
|
|
rd,
|
|
|
- group_snds.clone(),
|
|
|
+ group_snds,
|
|
|
phase_notify,
|
|
|
socket_notify,
|
|
|
);
|
|
@@ -198,13 +229,16 @@ async fn handle_handshake<const REGISTRATION_PHASE: bool>(
|
|
|
send_messages(sender, group, msg_rcv, socket_updater_rcv).await;
|
|
|
});
|
|
|
|
|
|
- writer_db.insert(handshake, socket_updater_snd);
|
|
|
+ writer_db
|
|
|
+ .write()
|
|
|
+ .await
|
|
|
+ .insert(handshake, socket_updater_snd);
|
|
|
} else {
|
|
|
panic!(
|
|
|
"late registration: {},{}",
|
|
|
handshake.sender, handshake.group
|
|
|
);
|
|
|
- }
|
|
|
+ };
|
|
|
}
|
|
|
}
|
|
|
|