|
@@ -4,17 +4,21 @@ use std::error::Error;
|
|
|
use std::io::BufReader;
|
|
|
use std::result::Result;
|
|
|
use std::sync::Arc;
|
|
|
+use std::time::Duration;
|
|
|
use tokio::io::{split, AsyncWriteExt, ReadHalf, WriteHalf};
|
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
|
use tokio::sync::{mpsc, Notify, RwLock};
|
|
|
-use tokio_rustls::{rustls::PrivateKey, server::TlsStream};
|
|
|
+use tokio::time::{timeout_at, Instant};
|
|
|
+use tokio_rustls::{rustls::PrivateKey, server::TlsStream, TlsAcceptor};
|
|
|
|
|
|
// FIXME: identifiers should be interned
|
|
|
type ID = String;
|
|
|
|
|
|
type ReaderToSender = mpsc::UnboundedSender<Arc<SerializedMessage>>;
|
|
|
+type WriterDb = HashMap<Handshake, Updater<(WriteHalf<TlsStream<TcpStream>>, Arc<Notify>)>>;
|
|
|
+type SndDb = HashMap<ID, Arc<RwLock<HashMap<ID, ReaderToSender>>>>;
|
|
|
|
|
|
-#[tokio::main]
|
|
|
+#[tokio::main(flavor = "multi_thread", worker_threads = 10)]
|
|
|
async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
let mut args = std::env::args();
|
|
|
let _arg0 = args.next().unwrap();
|
|
@@ -28,6 +32,9 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
|
|
|
let listen_addr = args.next().unwrap_or("127.0.0.1:6397".to_string());
|
|
|
|
|
|
+ let reg_time = args.next().unwrap_or("5".to_string()).parse()?;
|
|
|
+ let reg_time = Instant::now() + Duration::from_secs(reg_time);
|
|
|
+
|
|
|
let certfile = std::fs::File::open(cert_filename).expect("cannot open certificate file");
|
|
|
let mut reader = BufReader::new(certfile);
|
|
|
let certs: Vec<tokio_rustls::rustls::Certificate> = rustls_pemfile::certs(&mut reader)
|
|
@@ -44,64 +51,115 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
.unwrap()
|
|
|
.with_no_client_auth()
|
|
|
.with_single_cert(certs, key)?;
|
|
|
- let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
|
|
|
+ let acceptor = TlsAcceptor::from(Arc::new(config));
|
|
|
|
|
|
let listener = TcpListener::bind(&listen_addr).await?;
|
|
|
log!("listening,{}", listen_addr);
|
|
|
|
|
|
// Maps group name to the table of message channels.
|
|
|
- let mut snd_db = HashMap::<ID, Arc<RwLock<HashMap<ID, ReaderToSender>>>>::new();
|
|
|
+ let mut snd_db = SndDb::new();
|
|
|
// Maps the (sender, group) pair to the socket updater.
|
|
|
- let mut writer_db =
|
|
|
- HashMap::<Handshake, Updater<(WriteHalf<TlsStream<TcpStream>>, Arc<Notify>)>>::new();
|
|
|
-
|
|
|
- loop {
|
|
|
- let stream = match listener.accept().await {
|
|
|
+ let mut writer_db = WriterDb::new();
|
|
|
+ // Notifies listener threads when registration phase is over.
|
|
|
+ let phase_notify = Arc::new(Notify::new());
|
|
|
+
|
|
|
+ // Allow registering or reconnecting during the registration time.
|
|
|
+ while let Ok(accepted) = timeout_at(reg_time, listener.accept()).await {
|
|
|
+ log!("foo");
|
|
|
+ let stream = match accepted {
|
|
|
Ok((stream, _)) => stream,
|
|
|
Err(e) => {
|
|
|
log!("failed,accept,{}", e.kind());
|
|
|
continue;
|
|
|
}
|
|
|
};
|
|
|
- let acceptor = acceptor.clone();
|
|
|
- let stream = match acceptor.accept(stream).await {
|
|
|
- Ok(stream) => stream,
|
|
|
- Err(e) => {
|
|
|
- log!("failed,tls,{}", e.kind());
|
|
|
- continue;
|
|
|
- }
|
|
|
- };
|
|
|
+ log!("accepted {}", stream.peer_addr()?);
|
|
|
+ handle_handshake::</*REGISTRATION_PHASE=*/ true>(
|
|
|
+ stream,
|
|
|
+ acceptor.clone(),
|
|
|
+ &mut writer_db,
|
|
|
+ &mut snd_db,
|
|
|
+ phase_notify.clone(),
|
|
|
+ )
|
|
|
+ .await;
|
|
|
+ }
|
|
|
|
|
|
- let (mut rd, wr) = split(stream);
|
|
|
+ // Notify all the listener threads that registration is over.
|
|
|
+ phase_notify.notify_waiters();
|
|
|
|
|
|
- let handshake = match mgen::get_handshake(&mut rd).await {
|
|
|
- Ok(handshake) => handshake,
|
|
|
- Err(mgen::Error::Io(e)) => {
|
|
|
- log!("failed,handshake,{}", e.kind());
|
|
|
+ // Now registration phase is over, only allow reconnecting.
|
|
|
+ loop {
|
|
|
+ let stream = match listener.accept().await {
|
|
|
+ Ok((stream, _)) => stream,
|
|
|
+ Err(e) => {
|
|
|
+ log!("failed,accept,{}", e.kind());
|
|
|
continue;
|
|
|
}
|
|
|
- Err(mgen::Error::Utf8Error(e)) => panic!("{:?}", e),
|
|
|
- Err(mgen::Error::MalformedSerialization(_, _)) => panic!(),
|
|
|
};
|
|
|
- log!("accept,{},{}", handshake.sender, handshake.group);
|
|
|
-
|
|
|
- if let Some(socket_updater) = writer_db.get(&handshake) {
|
|
|
- // we've seen this client before
|
|
|
+ handle_handshake::</*REGISTRATION_PHASE=*/ false>(
|
|
|
+ stream,
|
|
|
+ acceptor.clone(),
|
|
|
+ &mut writer_db,
|
|
|
+ &mut snd_db,
|
|
|
+ phase_notify.clone(),
|
|
|
+ )
|
|
|
+ .await;
|
|
|
+ }
|
|
|
+}
|
|
|
|
|
|
- // start the new reader thread with a new notify
|
|
|
- // (we can't use the existing notify channel, else we get race conditions where
|
|
|
- // the reader thread terminates and spawns again before the sender thread
|
|
|
- // notices and activates its existing notify channel)
|
|
|
- let notify = Arc::new(Notify::new());
|
|
|
- let db = snd_db[&handshake.group].clone();
|
|
|
- spawn_message_receiver(handshake.sender, handshake.group, rd, db, notify.clone());
|
|
|
+async fn handle_handshake<const REGISTRATION_PHASE: bool>(
|
|
|
+ stream: TcpStream,
|
|
|
+ acceptor: TlsAcceptor,
|
|
|
+ writer_db: &mut WriterDb,
|
|
|
+ snd_db: &mut SndDb,
|
|
|
+ phase_notify: Arc<Notify>,
|
|
|
+) {
|
|
|
+ let stream = match acceptor.accept(stream).await {
|
|
|
+ Ok(stream) => stream,
|
|
|
+ Err(e) => {
|
|
|
+ log!("failed,tls,{}", e.kind());
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ };
|
|
|
|
|
|
- // give the writer thread the new write half of the socket and notify
|
|
|
- socket_updater.send((wr, notify));
|
|
|
- } else {
|
|
|
- // newly-registered client
|
|
|
- log!("register,{},{}", handshake.sender, handshake.group);
|
|
|
+ let (mut rd, wr) = split(stream);
|
|
|
|
|
|
+ let handshake = match mgen::get_handshake(&mut rd).await {
|
|
|
+ Ok(handshake) => handshake,
|
|
|
+ Err(mgen::Error::Io(e)) => {
|
|
|
+ log!("failed,handshake,{}", e.kind());
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ Err(mgen::Error::Utf8Error(e)) => panic!("{:?}", e),
|
|
|
+ Err(mgen::Error::MalformedSerialization(_, _)) => panic!(),
|
|
|
+ };
|
|
|
+ log!("accept,{},{}", handshake.sender, handshake.group);
|
|
|
+
|
|
|
+ if let Some(socket_updater) = writer_db.get(&handshake) {
|
|
|
+ // we've seen this client before
|
|
|
+
|
|
|
+ // start the new reader thread with a new notify
|
|
|
+ // (we can't use the existing notify channel, else we get race conditions where
|
|
|
+ // the reader thread terminates and spawns again before the sender thread
|
|
|
+ // notices and activates its existing notify channel)
|
|
|
+ let socket_notify = Arc::new(Notify::new());
|
|
|
+ let db = snd_db[&handshake.group].clone();
|
|
|
+ spawn_message_receiver(
|
|
|
+ handshake.sender,
|
|
|
+ handshake.group,
|
|
|
+ rd,
|
|
|
+ db,
|
|
|
+ phase_notify,
|
|
|
+ socket_notify.clone(),
|
|
|
+ );
|
|
|
+
|
|
|
+ // give the writer thread the new write half of the socket and notify
|
|
|
+ socket_updater.send((wr, socket_notify));
|
|
|
+ } else {
|
|
|
+ // newly-registered client
|
|
|
+ log!("register,{},{}", handshake.sender, handshake.group);
|
|
|
+
|
|
|
+ if REGISTRATION_PHASE {
|
|
|
// message channel, for sending messages between threads
|
|
|
let (msg_snd, msg_rcv) = mpsc::unbounded_channel::<Arc<SerializedMessage>>();
|
|
|
|
|
@@ -114,19 +172,20 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
.insert(handshake.sender.clone(), msg_snd);
|
|
|
|
|
|
// socket notify, for terminating the socket if the sender encounters an error
|
|
|
- let notify = Arc::new(Notify::new());
|
|
|
+ let socket_notify = Arc::new(Notify::new());
|
|
|
|
|
|
// socket updater, for giving the sender thread a new socket + notify channel
|
|
|
let socket_updater_snd = Updater::new();
|
|
|
let socket_updater_rcv = socket_updater_snd.clone();
|
|
|
- socket_updater_snd.send((wr, notify.clone()));
|
|
|
+ socket_updater_snd.send((wr, socket_notify.clone()));
|
|
|
|
|
|
spawn_message_receiver(
|
|
|
handshake.sender.clone(),
|
|
|
handshake.group.clone(),
|
|
|
rd,
|
|
|
group_snds.clone(),
|
|
|
- notify,
|
|
|
+ phase_notify,
|
|
|
+ socket_notify,
|
|
|
);
|
|
|
|
|
|
let sender = handshake.sender.clone();
|
|
@@ -136,6 +195,11 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
});
|
|
|
|
|
|
writer_db.insert(handshake, socket_updater_snd);
|
|
|
+ } else {
|
|
|
+ panic!(
|
|
|
+ "late registration: {},{}",
|
|
|
+ handshake.sender, handshake.group
|
|
|
+ );
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -145,14 +209,15 @@ fn spawn_message_receiver(
|
|
|
group: String,
|
|
|
rd: ReadHalf<TlsStream<TcpStream>>,
|
|
|
db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
|
|
|
- notify: Arc<Notify>,
|
|
|
+ phase_notify: Arc<Notify>,
|
|
|
+ socket_notify: Arc<Notify>,
|
|
|
) {
|
|
|
tokio::spawn(async move {
|
|
|
tokio::select! {
|
|
|
// n.b.: get_message is not cancellation safe,
|
|
|
// but this is one of the cases where that's expected
|
|
|
// (we only cancel when something is wrong with the stream anyway)
|
|
|
- ret = get_messages(&sender, &group, rd, db) => {
|
|
|
+ ret = get_messages(&sender, &group, rd, phase_notify, db) => {
|
|
|
match ret {
|
|
|
Err(mgen::Error::Io(e)) => log!("failed,receive,{}", e.kind()),
|
|
|
Err(mgen::Error::Utf8Error(e)) => panic!("{:?}", e),
|
|
@@ -161,7 +226,7 @@ fn spawn_message_receiver(
|
|
|
Ok(()) => panic!("Message receiver returned OK"),
|
|
|
}
|
|
|
}
|
|
|
- _ = notify.notified() => {
|
|
|
+ _ = socket_notify.notified() => {
|
|
|
log!("terminated,{},{}", sender, group);
|
|
|
// should cause get_messages to terminate, dropping the socket
|
|
|
}
|
|
@@ -175,11 +240,11 @@ async fn get_messages<T: tokio::io::AsyncRead>(
|
|
|
sender: &str,
|
|
|
group: &str,
|
|
|
mut socket: ReadHalf<T>,
|
|
|
+ phase_notify: Arc<Notify>,
|
|
|
global_db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
|
|
|
) -> Result<(), mgen::Error> {
|
|
|
- // Wait for the next message to be received before populating our local copy of the db,
|
|
|
- // that way other clients have time to register.
|
|
|
- let buf = mgen::get_message_bytes(&mut socket).await?;
|
|
|
+ // Wait for the registration phase to end before updating our local copy of the DB
|
|
|
+ phase_notify.notified().await;
|
|
|
|
|
|
let db = global_db.read().await.clone();
|
|
|
let message_channels: Vec<_> = db
|
|
@@ -187,35 +252,6 @@ async fn get_messages<T: tokio::io::AsyncRead>(
|
|
|
.filter_map(|(k, v)| if *k != sender { Some(v) } else { None })
|
|
|
.collect();
|
|
|
|
|
|
- let message = MessageHeaderRef::deserialize(&buf[4..])?;
|
|
|
- assert!(message.sender == sender);
|
|
|
-
|
|
|
- match message.body {
|
|
|
- MessageBody::Size(_) => {
|
|
|
- assert!(message.group == group);
|
|
|
- log!("received,{},{},{}", sender, group, message.id);
|
|
|
- let body = message.body;
|
|
|
- let m = Arc::new(SerializedMessage { header: buf, body });
|
|
|
- for recipient in message_channels.iter() {
|
|
|
- recipient.send(m.clone()).unwrap();
|
|
|
- }
|
|
|
- }
|
|
|
- MessageBody::Receipt => {
|
|
|
- log!(
|
|
|
- "receipt,{},{},{},{}",
|
|
|
- sender,
|
|
|
- group,
|
|
|
- message.group,
|
|
|
- message.id
|
|
|
- );
|
|
|
- let recipient = &db[message.group];
|
|
|
- let body = message.body;
|
|
|
- let m = Arc::new(SerializedMessage { header: buf, body });
|
|
|
- recipient.send(m).unwrap();
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- // we never have to update the DB again, so repeat the above, skipping that step
|
|
|
loop {
|
|
|
let buf = mgen::get_message_bytes(&mut socket).await?;
|
|
|
let message = MessageHeaderRef::deserialize(&buf[4..])?;
|