Przeglądaj źródła

terminate client/peer on error immediately

Justin Tracey 1 rok temu
rodzic
commit
506bf0282f
3 zmienionych plików z 102 dodań i 93 usunięć
  1. 1 0
      Cargo.toml
  2. 2 1
      src/bin/mgen-client.rs
  3. 99 92
      src/bin/mgen-peer.rs

+ 1 - 0
Cargo.toml

@@ -5,6 +5,7 @@ edition = "2021"
 
 [dependencies]
 chrono = "0.4.24"
+futures = "0.3.28"
 glob = "0.3.1"
 rand = "0.8.5"
 rand_distr = { version = "0.4.3", features = ["serde1"] }

+ 2 - 1
src/bin/mgen-client.rs

@@ -238,7 +238,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
             handles.push(handle);
         }
     }
-    handles.shrink_to_fit();
+
+    let handles: futures::stream::FuturesUnordered<_> = handles.into_iter().collect();
     for handle in handles {
         handle.await??;
     }

+ 99 - 92
src/bin/mgen-peer.rs

@@ -11,7 +11,7 @@ use tokio::net::{
     TcpListener,
 };
 use tokio::sync::mpsc;
-use tokio::task;
+use tokio::task::JoinHandle;
 use tokio::time::Duration;
 
 mod messenger;
@@ -309,11 +309,10 @@ struct Config {
     conversations: Vec<ConversationConfig>,
 }
 
-#[tokio::main]
-async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    let mut args = std::env::args();
-    let _ = args.next();
-
+fn process_config(
+    config: Config,
+    handles: &mut Vec<JoinHandle<Result<(), FatalError>>>,
+) -> Result<(), Box<dyn std::error::Error>> {
     struct ForIoThreads {
         state_to_writer: mpsc::UnboundedSender<MessageHolder>,
         writer_from_state: WriterFromState,
@@ -322,103 +321,111 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
         retry: f64,
     }
 
-    let mut handles = vec![];
+    // map from `recipient` to things the (user, recipient) reader/writer threads will need
+    let mut recipient_map = HashMap::<String, ForIoThreads>::new();
+    for conversation in config.conversations.into_iter() {
+        let (reader_to_state, state_from_reader) = mpsc::unbounded_channel();
 
-    for config_file in args.flat_map(|a| glob::glob(a.as_str()).unwrap()) {
-        let yaml_s = std::fs::read_to_string(config_file?)?;
-        let config: Config = serde_yaml::from_str(&yaml_s)?;
+        let mut conversation_recipient_map =
+            HashMap::<String, StateToWriter<MessageHolder>>::with_capacity(
+                conversation.recipients.len(),
+            );
 
-        // map from `recipient` to things the (user, recipient) reader/writer threads will need
-        let mut recipient_map = HashMap::<String, ForIoThreads>::new();
-        for conversation in config.conversations.into_iter() {
-            let (reader_to_state, state_from_reader) = mpsc::unbounded_channel();
-
-            let mut conversation_recipient_map =
-                HashMap::<String, StateToWriter<MessageHolder>>::with_capacity(
-                    conversation.recipients.len(),
-                );
-
-            for recipient in conversation.recipients.iter() {
-                let state_to_writer = if !recipient_map.contains_key(&recipient.name) {
-                    let (state_to_writer, writer_from_state) = mpsc::unbounded_channel();
-                    let mut reader_to_states = HashMap::new();
-                    reader_to_states.insert(conversation.group.clone(), reader_to_state.clone());
-                    let str_params = SocksParams {
-                        socks: config.socks.clone(),
-                        target: recipient.address.clone(),
-                        user: config.user.name.clone(),
-                        recipient: recipient.name.clone(),
-                    };
-                    let for_io = ForIoThreads {
-                        state_to_writer: state_to_writer.clone(),
-                        writer_from_state,
-                        reader_to_states,
-                        str_params,
-                        retry: conversation.retry,
-                    };
-                    recipient_map.insert(recipient.name.clone(), for_io);
-                    state_to_writer
-                } else {
-                    let for_io = recipient_map.get_mut(&recipient.name).unwrap();
-                    if !for_io.reader_to_states.contains_key(&conversation.group) {
-                        for_io
-                            .reader_to_states
-                            .insert(conversation.group.clone(), reader_to_state.clone());
-                    }
-                    for_io.state_to_writer.clone()
+        for recipient in conversation.recipients.iter() {
+            let state_to_writer = if !recipient_map.contains_key(&recipient.name) {
+                let (state_to_writer, writer_from_state) = mpsc::unbounded_channel();
+                let mut reader_to_states = HashMap::new();
+                reader_to_states.insert(conversation.group.clone(), reader_to_state.clone());
+                let str_params = SocksParams {
+                    socks: config.socks.clone(),
+                    target: recipient.address.clone(),
+                    user: config.user.name.clone(),
+                    recipient: recipient.name.clone(),
                 };
-                conversation_recipient_map.insert(
-                    recipient.name.clone(),
-                    StateToWriter {
-                        channel: state_to_writer,
-                    },
-                );
-            }
+                let for_io = ForIoThreads {
+                    state_to_writer: state_to_writer.clone(),
+                    writer_from_state,
+                    reader_to_states,
+                    str_params,
+                    retry: conversation.retry,
+                };
+                recipient_map.insert(recipient.name.clone(), for_io);
+                state_to_writer
+            } else {
+                let for_io = recipient_map.get_mut(&recipient.name).unwrap();
+                if !for_io.reader_to_states.contains_key(&conversation.group) {
+                    for_io
+                        .reader_to_states
+                        .insert(conversation.group.clone(), reader_to_state.clone());
+                }
+                for_io.state_to_writer.clone()
+            };
+            conversation_recipient_map.insert(
+                recipient.name.clone(),
+                StateToWriter {
+                    channel: state_to_writer,
+                },
+            );
+        }
 
-            let distributions: Distributions = conversation.distributions.try_into()?;
+        let distributions: Distributions = conversation.distributions.try_into()?;
 
-            tokio::spawn(manage_conversation(
-                config.user.name.clone(),
-                conversation.group,
-                distributions,
-                conversation.bootstrap,
-                state_from_reader,
-                conversation_recipient_map,
-            ));
-        }
+        tokio::spawn(manage_conversation(
+            config.user.name.clone(),
+            conversation.group,
+            distributions,
+            conversation.bootstrap,
+            state_from_reader,
+            conversation_recipient_map,
+        ));
+    }
 
-        let mut name_to_io_threads: HashMap<String, (ReadSocketUpdaterIn, WriteSocketUpdaterIn)> =
-            HashMap::new();
-
-        for (recipient, for_io) in recipient_map.drain() {
-            let listener_writer_to_reader = Updater::new();
-            let reader_from_listener_writer = listener_writer_to_reader.clone();
-            let listener_to_writer = Updater::new();
-            let writer_from_listener = listener_to_writer.clone();
-            name_to_io_threads.insert(
-                recipient.to_string(),
-                (listener_writer_to_reader.clone(), listener_to_writer),
-            );
+    let mut name_to_io_threads: HashMap<String, (ReadSocketUpdaterIn, WriteSocketUpdaterIn)> =
+        HashMap::new();
+
+    for (recipient, for_io) in recipient_map.drain() {
+        let listener_writer_to_reader = Updater::new();
+        let reader_from_listener_writer = listener_writer_to_reader.clone();
+        let listener_to_writer = Updater::new();
+        let writer_from_listener = listener_to_writer.clone();
+        name_to_io_threads.insert(
+            recipient.to_string(),
+            (listener_writer_to_reader.clone(), listener_to_writer),
+        );
+
+        tokio::spawn(reader(reader_from_listener_writer, for_io.reader_to_states));
+
+        let retry = Duration::from_secs_f64(for_io.retry);
+        let handle: JoinHandle<Result<(), FatalError>> = tokio::spawn(writer(
+            for_io.writer_from_state,
+            writer_from_listener,
+            listener_writer_to_reader,
+            for_io.str_params,
+            retry,
+        ));
+        handles.push(handle);
+    }
 
-            tokio::spawn(reader(reader_from_listener_writer, for_io.reader_to_states));
+    let handle: JoinHandle<Result<(), FatalError>> =
+        tokio::spawn(listener(config.user.address, name_to_io_threads));
+    handles.push(handle);
+    Ok(())
+}
 
-            let retry = Duration::from_secs_f64(for_io.retry);
-            let handle: task::JoinHandle<Result<(), FatalError>> = tokio::spawn(writer(
-                for_io.writer_from_state,
-                writer_from_listener,
-                listener_writer_to_reader,
-                for_io.str_params,
-                retry,
-            ));
-            handles.push(handle);
-        }
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let mut args = std::env::args();
+    let _ = args.next();
 
-        let handle: task::JoinHandle<Result<(), FatalError>> =
-            tokio::spawn(listener(config.user.address, name_to_io_threads));
-        handles.push(handle);
+    let mut handles = vec![];
+
+    for config_file in args.flat_map(|a| glob::glob(a.as_str()).unwrap()) {
+        let yaml_s = std::fs::read_to_string(config_file?)?;
+        let config: Config = serde_yaml::from_str(&yaml_s)?;
+        process_config(config, &mut handles)?;
     }
 
-    handles.shrink_to_fit();
+    let handles: futures::stream::FuturesUnordered<_> = handles.into_iter().collect();
     for handle in handles {
         handle.await??;
     }