123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304 |
- use mgen::{log, updater::Updater, Handshake, MessageBody, MessageHeaderRef, SerializedMessage};
- use std::collections::HashMap;
- use std::error::Error;
- use std::io::BufReader;
- use std::result::Result;
- use std::sync::Arc;
- use tokio::io::{split, AsyncWriteExt, ReadHalf, WriteHalf};
- use tokio::net::{TcpListener, TcpStream};
- use tokio::sync::{mpsc, Notify, RwLock};
- use tokio_rustls::{rustls::PrivateKey, server::TlsStream};
- // FIXME: identifiers should be interned
- type ID = String;
- type ReaderToSender = mpsc::UnboundedSender<Arc<SerializedMessage>>;
- #[tokio::main]
- async fn main() -> Result<(), Box<dyn Error>> {
- let mut args = std::env::args();
- let _arg0 = args.next().unwrap();
- let cert_filename = args
- .next()
- .unwrap_or_else(|| panic!("no cert file provided"));
- let key_filename = args
- .next()
- .unwrap_or_else(|| panic!("no key file provided"));
- let listen_addr = args.next().unwrap_or("127.0.0.1:6397".to_string());
- let certfile = std::fs::File::open(cert_filename).expect("cannot open certificate file");
- let mut reader = BufReader::new(certfile);
- let certs: Vec<tokio_rustls::rustls::Certificate> = rustls_pemfile::certs(&mut reader)
- .unwrap()
- .iter()
- .map(|v| tokio_rustls::rustls::Certificate(v.clone()))
- .collect();
- let key = load_private_key(&key_filename);
- let config = tokio_rustls::rustls::ServerConfig::builder()
- .with_safe_default_cipher_suites()
- .with_safe_default_kx_groups()
- .with_safe_default_protocol_versions()
- .unwrap()
- .with_no_client_auth()
- .with_single_cert(certs, key)?;
- let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
- let listener = TcpListener::bind(&listen_addr).await?;
- log!("listening,{}", listen_addr);
- // Maps group name to the table of message channels.
- let mut snd_db = HashMap::<ID, Arc<RwLock<HashMap<ID, ReaderToSender>>>>::new();
- // Maps the (sender, group) pair to the socket updater.
- let mut writer_db =
- HashMap::<Handshake, Updater<(WriteHalf<TlsStream<TcpStream>>, Arc<Notify>)>>::new();
- loop {
- let stream = match listener.accept().await {
- Ok((stream, _)) => stream,
- Err(e) => {
- log!("failed,accept,{}", e.kind());
- continue;
- }
- };
- let acceptor = acceptor.clone();
- let stream = match acceptor.accept(stream).await {
- Ok(stream) => stream,
- Err(e) => {
- log!("failed,tls,{}", e.kind());
- continue;
- }
- };
- let (mut rd, wr) = split(stream);
- let handshake = match mgen::get_handshake(&mut rd).await {
- Ok(handshake) => handshake,
- Err(mgen::Error::Io(e)) => {
- log!("failed,handshake,{}", e.kind());
- continue;
- }
- Err(mgen::Error::Utf8Error(e)) => panic!("{:?}", e),
- Err(mgen::Error::MalformedSerialization(_, _)) => panic!(),
- };
- log!("accept,{},{}", handshake.sender, handshake.group);
- if let Some(socket_updater) = writer_db.get(&handshake) {
- // we've seen this client before
- // start the new reader thread with a new notify
- // (we can't use the existing notify channel, else we get race conditions where
- // the reader thread terminates and spawns again before the sender thread
- // notices and activates its existing notify channel)
- let notify = Arc::new(Notify::new());
- let db = snd_db[&handshake.group].clone();
- spawn_message_receiver(handshake.sender, handshake.group, rd, db, notify.clone());
- // give the writer thread the new write half of the socket and notify
- socket_updater.send((wr, notify));
- } else {
- // newly-registered client
- log!("register,{},{}", handshake.sender, handshake.group);
- // message channel, for sending messages between threads
- let (msg_snd, msg_rcv) = mpsc::unbounded_channel::<Arc<SerializedMessage>>();
- let group_snds = snd_db
- .entry(handshake.group.clone())
- .or_insert_with(|| Arc::new(RwLock::new(HashMap::new())));
- group_snds
- .write()
- .await
- .insert(handshake.sender.clone(), msg_snd);
- // socket notify, for terminating the socket if the sender encounters an error
- let notify = Arc::new(Notify::new());
- // socket updater, for giving the sender thread a new socket + notify channel
- let socket_updater_snd = Updater::new();
- let socket_updater_rcv = socket_updater_snd.clone();
- socket_updater_snd.send((wr, notify.clone()));
- spawn_message_receiver(
- handshake.sender.clone(),
- handshake.group.clone(),
- rd,
- group_snds.clone(),
- notify,
- );
- let sender = handshake.sender.clone();
- let group = handshake.group.clone();
- tokio::spawn(async move {
- send_messages(sender, group, msg_rcv, socket_updater_rcv).await;
- });
- writer_db.insert(handshake, socket_updater_snd);
- }
- }
- }
- fn spawn_message_receiver(
- sender: String,
- group: String,
- rd: ReadHalf<TlsStream<TcpStream>>,
- db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
- notify: Arc<Notify>,
- ) {
- tokio::spawn(async move {
- tokio::select! {
- // n.b.: get_message is not cancellation safe,
- // but this is one of the cases where that's expected
- // (we only cancel when something is wrong with the stream anyway)
- ret = get_messages(&sender, &group, rd, db) => {
- match ret {
- Err(mgen::Error::Io(e)) => log!("failed,receive,{}", e.kind()),
- Err(mgen::Error::Utf8Error(e)) => panic!("{:?}", e),
- Err(mgen::Error::MalformedSerialization(v, b)) => panic!(
- "Malformed Serialization: {:?}\n{:?})", v, b),
- Ok(()) => panic!("Message receiver returned OK"),
- }
- }
- _ = notify.notified() => {
- log!("terminated,{},{}", sender, group);
- // should cause get_messages to terminate, dropping the socket
- }
- }
- });
- }
- /// Loop for receiving messages on the socket, figuring out who to deliver them to,
- /// and forwarding them locally to the respective channel.
- async fn get_messages<T: tokio::io::AsyncRead>(
- sender: &str,
- group: &str,
- mut socket: ReadHalf<T>,
- global_db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
- ) -> Result<(), mgen::Error> {
- // Wait for the next message to be received before populating our local copy of the db,
- // that way other clients have time to register.
- let buf = mgen::get_message_bytes(&mut socket).await?;
- let db = global_db.read().await.clone();
- let message_channels: Vec<_> = db
- .iter()
- .filter_map(|(k, v)| if *k != sender { Some(v) } else { None })
- .collect();
- let message = MessageHeaderRef::deserialize(&buf[4..])?;
- assert!(message.sender == sender);
- match message.body {
- MessageBody::Size(_) => {
- assert!(message.group == group);
- log!("received,{},{},{}", sender, group, message.id);
- let body = message.body;
- let m = Arc::new(SerializedMessage { header: buf, body });
- for recipient in message_channels.iter() {
- recipient.send(m.clone()).unwrap();
- }
- }
- MessageBody::Receipt => {
- log!(
- "receipt,{},{},{},{}",
- sender,
- group,
- message.group,
- message.id
- );
- let recipient = &db[message.group];
- let body = message.body;
- let m = Arc::new(SerializedMessage { header: buf, body });
- recipient.send(m).unwrap();
- }
- }
- // we never have to update the DB again, so repeat the above, skipping that step
- loop {
- let buf = mgen::get_message_bytes(&mut socket).await?;
- let message = MessageHeaderRef::deserialize(&buf[4..])?;
- assert!(message.sender == sender);
- match message.body {
- MessageBody::Size(_) => {
- assert!(message.group == group);
- log!("received,{},{},{}", sender, group, message.id);
- let body = message.body;
- let m = Arc::new(SerializedMessage { header: buf, body });
- for recipient in message_channels.iter() {
- recipient.send(m.clone()).unwrap();
- }
- }
- MessageBody::Receipt => {
- log!(
- "receipt,{},{},{},{}",
- sender,
- group,
- message.group,
- message.id
- );
- let recipient = &db[message.group];
- let body = message.body;
- let m = Arc::new(SerializedMessage { header: buf, body });
- recipient.send(m).unwrap();
- }
- }
- }
- }
- /// Loop for receiving messages on the mpsc channel for this recipient,
- /// and sending them out on the associated socket.
- async fn send_messages<T: Send + Sync + tokio::io::AsyncWrite>(
- recipient: ID,
- group: ID,
- mut msg_rcv: mpsc::UnboundedReceiver<Arc<SerializedMessage>>,
- mut socket_updater: Updater<(WriteHalf<T>, Arc<Notify>)>,
- ) {
- let (mut current_socket, mut current_watch) = socket_updater.recv().await;
- let mut message_cache = None;
- loop {
- let message = if message_cache.is_none() {
- msg_rcv.recv().await.expect("message channel closed")
- } else {
- message_cache.unwrap()
- };
- if message.write_all_to(&mut current_socket).await.is_err()
- || current_socket.flush().await.is_err()
- {
- message_cache = Some(message);
- log!("terminating,{},{}", recipient, group);
- // socket is presumably closed, clean up and notify the listening end to close
- // (all best-effort, we can ignore errors because it presumably means it's done)
- current_watch.notify_one();
- let _ = current_socket.shutdown().await;
- // wait for the new socket
- (current_socket, current_watch) = socket_updater.recv().await;
- } else {
- log!("sent,{},{}", recipient, group);
- message_cache = None;
- }
- }
- }
- fn load_private_key(filename: &str) -> PrivateKey {
- let keyfile = std::fs::File::open(filename).expect("cannot open private key file");
- let mut reader = BufReader::new(keyfile);
- loop {
- match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") {
- Some(rustls_pemfile::Item::RSAKey(key)) => return PrivateKey(key),
- Some(rustls_pemfile::Item::PKCS8Key(key)) => return PrivateKey(key),
- Some(rustls_pemfile::Item::ECKey(key)) => return PrivateKey(key),
- None => break,
- _ => {}
- }
- }
- panic!(
- "no keys found in {:?} (encrypted keys not supported)",
- filename
- );
- }
|