Browse Source

fix timing bug (sort of)

This problem seems to be fundamentally a problem with relying on an async
runtime in shadow. It boils down to the fact that "don't rely on async
timing" extends to timing between hosts when run in shadow, so tokio can just
decide not to run the TLS handshake until arbitrarily long after the TCP
handshake. For now, these changes (notably, increasing the number of threads
to 2) look like they might work around it.
Justin Tracey 4 months ago
parent
commit
493ce34383
2 changed files with 116 additions and 79 deletions
  1. 2 1
      src/bin/mgen-client.rs
  2. 114 78
      src/bin/mgen-server.rs

+ 2 - 1
src/bin/mgen-client.rs

@@ -317,6 +317,7 @@ async fn socket_updater(
         if stream.write_all(&handshake.serialize()).await.is_err() {
             continue;
         }
+        log!("{},{},handshake", str_params.user, str_params.recipient);
 
         let (rd, wr) = split(stream);
         reader_channel.send(rd);
@@ -462,7 +463,7 @@ struct Config {
     conversations: Vec<ConversationConfig>,
 }
 
-#[tokio::main]
+#[tokio::main(flavor = "multi_thread", worker_threads = 2)]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let mut args = std::env::args();
     let _ = args.next();

+ 114 - 78
src/bin/mgen-server.rs

@@ -4,17 +4,21 @@ use std::error::Error;
 use std::io::BufReader;
 use std::result::Result;
 use std::sync::Arc;
+use std::time::Duration;
 use tokio::io::{split, AsyncWriteExt, ReadHalf, WriteHalf};
 use tokio::net::{TcpListener, TcpStream};
 use tokio::sync::{mpsc, Notify, RwLock};
-use tokio_rustls::{rustls::PrivateKey, server::TlsStream};
+use tokio::time::{timeout_at, Instant};
+use tokio_rustls::{rustls::PrivateKey, server::TlsStream, TlsAcceptor};
 
 // FIXME: identifiers should be interned
 type ID = String;
 
 type ReaderToSender = mpsc::UnboundedSender<Arc<SerializedMessage>>;
+type WriterDb = HashMap<Handshake, Updater<(WriteHalf<TlsStream<TcpStream>>, Arc<Notify>)>>;
+type SndDb = HashMap<ID, Arc<RwLock<HashMap<ID, ReaderToSender>>>>;
 
-#[tokio::main]
+#[tokio::main(flavor = "multi_thread", worker_threads = 10)]
 async fn main() -> Result<(), Box<dyn Error>> {
     let mut args = std::env::args();
     let _arg0 = args.next().unwrap();
@@ -28,6 +32,9 @@ async fn main() -> Result<(), Box<dyn Error>> {
 
     let listen_addr = args.next().unwrap_or("127.0.0.1:6397".to_string());
 
+    let reg_time = args.next().unwrap_or("5".to_string()).parse()?;
+    let reg_time = Instant::now() + Duration::from_secs(reg_time);
+
     let certfile = std::fs::File::open(cert_filename).expect("cannot open certificate file");
     let mut reader = BufReader::new(certfile);
     let certs: Vec<tokio_rustls::rustls::Certificate> = rustls_pemfile::certs(&mut reader)
@@ -44,64 +51,115 @@ async fn main() -> Result<(), Box<dyn Error>> {
         .unwrap()
         .with_no_client_auth()
         .with_single_cert(certs, key)?;
-    let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
+    let acceptor = TlsAcceptor::from(Arc::new(config));
 
     let listener = TcpListener::bind(&listen_addr).await?;
     log!("listening,{}", listen_addr);
 
     // Maps group name to the table of message channels.
-    let mut snd_db = HashMap::<ID, Arc<RwLock<HashMap<ID, ReaderToSender>>>>::new();
+    let mut snd_db = SndDb::new();
     // Maps the (sender, group) pair to the socket updater.
-    let mut writer_db =
-        HashMap::<Handshake, Updater<(WriteHalf<TlsStream<TcpStream>>, Arc<Notify>)>>::new();
-
-    loop {
-        let stream = match listener.accept().await {
+    let mut writer_db = WriterDb::new();
+    // Notifies listener threads when registration phase is over.
+    let phase_notify = Arc::new(Notify::new());
+
+    // Allow registering or reconnecting during the registration time.
+    while let Ok(accepted) = timeout_at(reg_time, listener.accept()).await {
+        log!("foo");
+        let stream = match accepted {
             Ok((stream, _)) => stream,
             Err(e) => {
                 log!("failed,accept,{}", e.kind());
                 continue;
             }
         };
-        let acceptor = acceptor.clone();
-        let stream = match acceptor.accept(stream).await {
-            Ok(stream) => stream,
-            Err(e) => {
-                log!("failed,tls,{}", e.kind());
-                continue;
-            }
-        };
+        log!("accepted {}", stream.peer_addr()?);
+        handle_handshake::</*REGISTRATION_PHASE=*/ true>(
+            stream,
+            acceptor.clone(),
+            &mut writer_db,
+            &mut snd_db,
+            phase_notify.clone(),
+        )
+        .await;
+    }
 
-        let (mut rd, wr) = split(stream);
+    // Notify all the listener threads that registration is over.
+    phase_notify.notify_waiters();
 
-        let handshake = match mgen::get_handshake(&mut rd).await {
-            Ok(handshake) => handshake,
-            Err(mgen::Error::Io(e)) => {
-                log!("failed,handshake,{}", e.kind());
+    // Now registration phase is over, only allow reconnecting.
+    loop {
+        let stream = match listener.accept().await {
+            Ok((stream, _)) => stream,
+            Err(e) => {
+                log!("failed,accept,{}", e.kind());
                 continue;
             }
-            Err(mgen::Error::Utf8Error(e)) => panic!("{:?}", e),
-            Err(mgen::Error::MalformedSerialization(_, _)) => panic!(),
         };
-        log!("accept,{},{}", handshake.sender, handshake.group);
-
-        if let Some(socket_updater) = writer_db.get(&handshake) {
-            // we've seen this client before
+        handle_handshake::</*REGISTRATION_PHASE=*/ false>(
+            stream,
+            acceptor.clone(),
+            &mut writer_db,
+            &mut snd_db,
+            phase_notify.clone(),
+        )
+        .await;
+    }
+}
 
-            // start the new reader thread with a new notify
-            // (we can't use the existing notify channel, else we get race conditions where
-            // the reader thread terminates and spawns again before the sender thread
-            // notices and activates its existing notify channel)
-            let notify = Arc::new(Notify::new());
-            let db = snd_db[&handshake.group].clone();
-            spawn_message_receiver(handshake.sender, handshake.group, rd, db, notify.clone());
+async fn handle_handshake<const REGISTRATION_PHASE: bool>(
+    stream: TcpStream,
+    acceptor: TlsAcceptor,
+    writer_db: &mut WriterDb,
+    snd_db: &mut SndDb,
+    phase_notify: Arc<Notify>,
+) {
+    let stream = match acceptor.accept(stream).await {
+        Ok(stream) => stream,
+        Err(e) => {
+            log!("failed,tls,{}", e.kind());
+            return;
+        }
+    };
 
-            // give the writer thread the new write half of the socket and notify
-            socket_updater.send((wr, notify));
-        } else {
-            // newly-registered client
-            log!("register,{},{}", handshake.sender, handshake.group);
+    let (mut rd, wr) = split(stream);
 
+    let handshake = match mgen::get_handshake(&mut rd).await {
+        Ok(handshake) => handshake,
+        Err(mgen::Error::Io(e)) => {
+            log!("failed,handshake,{}", e.kind());
+            return;
+        }
+        Err(mgen::Error::Utf8Error(e)) => panic!("{:?}", e),
+        Err(mgen::Error::MalformedSerialization(_, _)) => panic!(),
+    };
+    log!("accept,{},{}", handshake.sender, handshake.group);
+
+    if let Some(socket_updater) = writer_db.get(&handshake) {
+        // we've seen this client before
+
+        // start the new reader thread with a new notify
+        // (we can't use the existing notify channel, else we get race conditions where
+        // the reader thread terminates and spawns again before the sender thread
+        // notices and activates its existing notify channel)
+        let socket_notify = Arc::new(Notify::new());
+        let db = snd_db[&handshake.group].clone();
+        spawn_message_receiver(
+            handshake.sender,
+            handshake.group,
+            rd,
+            db,
+            phase_notify,
+            socket_notify.clone(),
+        );
+
+        // give the writer thread the new write half of the socket and notify
+        socket_updater.send((wr, socket_notify));
+    } else {
+        // newly-registered client
+        log!("register,{},{}", handshake.sender, handshake.group);
+
+        if REGISTRATION_PHASE {
             // message channel, for sending messages between threads
             let (msg_snd, msg_rcv) = mpsc::unbounded_channel::<Arc<SerializedMessage>>();
 
@@ -114,19 +172,20 @@ async fn main() -> Result<(), Box<dyn Error>> {
                 .insert(handshake.sender.clone(), msg_snd);
 
             // socket notify, for terminating the socket if the sender encounters an error
-            let notify = Arc::new(Notify::new());
+            let socket_notify = Arc::new(Notify::new());
 
             // socket updater, for giving the sender thread a new socket + notify channel
             let socket_updater_snd = Updater::new();
             let socket_updater_rcv = socket_updater_snd.clone();
-            socket_updater_snd.send((wr, notify.clone()));
+            socket_updater_snd.send((wr, socket_notify.clone()));
 
             spawn_message_receiver(
                 handshake.sender.clone(),
                 handshake.group.clone(),
                 rd,
                 group_snds.clone(),
-                notify,
+                phase_notify,
+                socket_notify,
             );
 
             let sender = handshake.sender.clone();
@@ -136,6 +195,11 @@ async fn main() -> Result<(), Box<dyn Error>> {
             });
 
             writer_db.insert(handshake, socket_updater_snd);
+        } else {
+            panic!(
+                "late registration: {},{}",
+                handshake.sender, handshake.group
+            );
         }
     }
 }
@@ -145,14 +209,15 @@ fn spawn_message_receiver(
     group: String,
     rd: ReadHalf<TlsStream<TcpStream>>,
     db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
-    notify: Arc<Notify>,
+    phase_notify: Arc<Notify>,
+    socket_notify: Arc<Notify>,
 ) {
     tokio::spawn(async move {
         tokio::select! {
             // n.b.: get_message is not cancellation safe,
             // but this is one of the cases where that's expected
             // (we only cancel when something is wrong with the stream anyway)
-            ret = get_messages(&sender, &group, rd, db) => {
+            ret = get_messages(&sender, &group, rd, phase_notify, db) => {
                 match ret {
                     Err(mgen::Error::Io(e)) => log!("failed,receive,{}", e.kind()),
                     Err(mgen::Error::Utf8Error(e)) => panic!("{:?}", e),
@@ -161,7 +226,7 @@ fn spawn_message_receiver(
                     Ok(()) => panic!("Message receiver returned OK"),
                 }
             }
-            _ = notify.notified() => {
+            _ = socket_notify.notified() => {
                 log!("terminated,{},{}", sender, group);
                 // should cause get_messages to terminate, dropping the socket
             }
@@ -175,11 +240,11 @@ async fn get_messages<T: tokio::io::AsyncRead>(
     sender: &str,
     group: &str,
     mut socket: ReadHalf<T>,
+    phase_notify: Arc<Notify>,
     global_db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
 ) -> Result<(), mgen::Error> {
-    // Wait for the next message to be received before populating our local copy of the db,
-    // that way other clients have time to register.
-    let buf = mgen::get_message_bytes(&mut socket).await?;
+    // Wait for the registration phase to end before updating our local copy of the DB
+    phase_notify.notified().await;
 
     let db = global_db.read().await.clone();
     let message_channels: Vec<_> = db
@@ -187,35 +252,6 @@ async fn get_messages<T: tokio::io::AsyncRead>(
         .filter_map(|(k, v)| if *k != sender { Some(v) } else { None })
         .collect();
 
-    let message = MessageHeaderRef::deserialize(&buf[4..])?;
-    assert!(message.sender == sender);
-
-    match message.body {
-        MessageBody::Size(_) => {
-            assert!(message.group == group);
-            log!("received,{},{},{}", sender, group, message.id);
-            let body = message.body;
-            let m = Arc::new(SerializedMessage { header: buf, body });
-            for recipient in message_channels.iter() {
-                recipient.send(m.clone()).unwrap();
-            }
-        }
-        MessageBody::Receipt => {
-            log!(
-                "receipt,{},{},{},{}",
-                sender,
-                group,
-                message.group,
-                message.id
-            );
-            let recipient = &db[message.group];
-            let body = message.body;
-            let m = Arc::new(SerializedMessage { header: buf, body });
-            recipient.send(m).unwrap();
-        }
-    }
-
-    // we never have to update the DB again, so repeat the above, skipping that step
     loop {
         let buf = mgen::get_message_bytes(&mut socket).await?;
         let message = MessageHeaderRef::deserialize(&buf[4..])?;