Quellcode durchsuchen

server: spawn new threads to handle handshakes

Previously, the server was doing a tiny bit of work to process the handshakes
(e.g., TLS) on the main thread. While the real workhorses were spawned
quickly, this little bit of traffic could quickly add up with a lot of
concurrent incoming connections. Hopefully this alleviates some of that.
Justin Tracey vor 3 Monaten
Ursprung
Commit
aa6412b4d3
1 geänderte Dateien mit 67 neuen und 33 gelöschten Zeilen
  1. 67 33
      src/bin/mgen-server.rs

+ 67 - 33
src/bin/mgen-server.rs

@@ -60,10 +60,10 @@ async fn main() -> Result<(), Box<dyn Error>> {
     let listener = socket.listen(4096)?;
     log!("listening,{}", listen_addr);
 
-    // Maps group name to the table of message channels.
-    let mut snd_db = SndDb::new();
     // Maps the (sender, group) pair to the socket updater.
-    let mut writer_db = WriterDb::new();
+    let writer_db = Arc::new(RwLock::new(WriterDb::new()));
+    // Maps group name to the table of message channels.
+    let snd_db = Arc::new(RwLock::new(SndDb::new()));
     // Notifies listener threads when registration phase is over.
     let phase_notify = Arc::new(Notify::new());
 
@@ -77,14 +77,20 @@ async fn main() -> Result<(), Box<dyn Error>> {
             }
         };
         log!("accepted {}", stream.peer_addr()?);
-        handle_handshake::</*REGISTRATION_PHASE=*/ true>(
-            stream,
-            acceptor.clone(),
-            &mut writer_db,
-            &mut snd_db,
-            phase_notify.clone(),
-        )
-        .await;
+        let acceptor = acceptor.clone();
+        let writer_db = writer_db.clone();
+        let snd_db = snd_db.clone();
+        let phase_notify = phase_notify.clone();
+        tokio::spawn(async move {
+            handle_handshake::</*REGISTRATION_PHASE=*/ true>(
+                stream,
+                acceptor,
+                writer_db,
+                snd_db,
+                phase_notify,
+            )
+            .await
+        });
     }
 
     log!("registration phase complete");
@@ -100,22 +106,28 @@ async fn main() -> Result<(), Box<dyn Error>> {
                 continue;
             }
         };
-        handle_handshake::</*REGISTRATION_PHASE=*/ false>(
-            stream,
-            acceptor.clone(),
-            &mut writer_db,
-            &mut snd_db,
-            phase_notify.clone(),
-        )
-        .await;
+        let acceptor = acceptor.clone();
+        let writer_db = writer_db.clone();
+        let snd_db = snd_db.clone();
+        let phase_notify = phase_notify.clone();
+        tokio::spawn(async move {
+            handle_handshake::</*REGISTRATION_PHASE=*/ false>(
+                stream,
+                acceptor,
+                writer_db,
+                snd_db,
+                phase_notify,
+            )
+            .await
+        });
     }
 }
 
 async fn handle_handshake<const REGISTRATION_PHASE: bool>(
     stream: TcpStream,
     acceptor: TlsAcceptor,
-    writer_db: &mut WriterDb,
-    snd_db: &mut SndDb,
+    writer_db: Arc<RwLock<WriterDb>>,
+    snd_db: Arc<RwLock<SndDb>>,
     phase_notify: Arc<Notify>,
 ) {
     let stream = match acceptor.accept(stream).await {
@@ -139,7 +151,8 @@ async fn handle_handshake<const REGISTRATION_PHASE: bool>(
     };
     log!("accept,{},{}", handshake.sender, handshake.group);
 
-    if let Some(socket_updater) = writer_db.get(&handshake) {
+    let read_writer_db = writer_db.read().await;
+    if let Some(socket_updater) = read_writer_db.get(&handshake) {
         // we've seen this client before
 
         // start the new reader thread with a new notify
@@ -147,7 +160,7 @@ async fn handle_handshake<const REGISTRATION_PHASE: bool>(
         // the reader thread terminates and spawns again before the sender thread
         // notices and activates its existing notify channel)
         let socket_notify = Arc::new(Notify::new());
-        let db = snd_db[&handshake.group].clone();
+        let db = snd_db.read().await[&handshake.group].clone();
         spawn_message_receiver(
             handshake.sender,
             handshake.group,
@@ -160,6 +173,7 @@ async fn handle_handshake<const REGISTRATION_PHASE: bool>(
         // give the writer thread the new write half of the socket and notify
         socket_updater.send((wr, socket_notify));
     } else {
+        drop(read_writer_db);
         // newly-registered client
         log!("register,{},{}", handshake.sender, handshake.group);
 
@@ -167,13 +181,30 @@ async fn handle_handshake<const REGISTRATION_PHASE: bool>(
             // message channel, for sending messages between threads
             let (msg_snd, msg_rcv) = mpsc::unbounded_channel::<Arc<SerializedMessage>>();
 
-            let group_snds = snd_db
-                .entry(handshake.group.clone())
-                .or_insert_with(|| Arc::new(RwLock::new(HashMap::new())));
-            group_snds
-                .write()
-                .await
-                .insert(handshake.sender.clone(), msg_snd);
+            let group_snds = {
+                let read_snd_db = snd_db.read().await;
+                let group_snds = read_snd_db.get(&handshake.group);
+                if let Some(group_snds) = group_snds {
+                    let group_snds = group_snds.clone();
+                    drop(read_snd_db);
+                    group_snds
+                        .write()
+                        .await
+                        .insert(handshake.sender.clone(), msg_snd);
+                    group_snds
+                } else {
+                    drop(read_snd_db);
+                    let mut write_snd_db = snd_db.write().await;
+                    let group_snds = write_snd_db
+                        .entry(handshake.group.clone())
+                        .or_insert_with(|| Arc::new(RwLock::new(HashMap::new())));
+                    group_snds
+                        .write()
+                        .await
+                        .insert(handshake.sender.clone(), msg_snd);
+                    group_snds.clone()
+                }
+            };
 
             // socket notify, for terminating the socket if the sender encounters an error
             let socket_notify = Arc::new(Notify::new());
@@ -187,7 +218,7 @@ async fn handle_handshake<const REGISTRATION_PHASE: bool>(
                 handshake.sender.clone(),
                 handshake.group.clone(),
                 rd,
-                group_snds.clone(),
+                group_snds,
                 phase_notify,
                 socket_notify,
             );
@@ -198,13 +229,16 @@ async fn handle_handshake<const REGISTRATION_PHASE: bool>(
                 send_messages(sender, group, msg_rcv, socket_updater_rcv).await;
             });
 
-            writer_db.insert(handshake, socket_updater_snd);
+            writer_db
+                .write()
+                .await
+                .insert(handshake, socket_updater_snd);
         } else {
             panic!(
                 "late registration: {},{}",
                 handshake.sender, handshake.group
             );
-        }
+        };
     }
 }