Bladeren bron

add support for p2p clients (built as "peer")

Justin Tracey 1 jaar geleden
bovenliggende
commit
a77581b348
8 gewijzigde bestanden met toevoegingen van 410 en 222 verwijderingen
  1. 35 4
      README.md
  2. 189 93
      src/bin/client.rs
  3. 7 0
      src/bin/messenger/dists.rs
  4. 23 2
      src/bin/messenger/error.rs
  5. 8 1
      src/bin/messenger/message.rs
  6. 123 108
      src/bin/messenger/state.rs
  7. 3 12
      src/bin/server.rs
  8. 22 2
      src/lib.rs

+ 35 - 4
README.md

@@ -7,6 +7,8 @@ Notably, this allows for studying network traffic properties of messenger apps i
 Like TGen, MGen can create message flows built around Markov models.
 Unlike TGen, these models are expressly designed with user activity in messenger clients in mind.
 These messages can be relayed through a central server, which can handle group messages (i.e., traffic that originates from one sender, but gets forwarded to multiple recipients).
+Alternatively, a peer-to-peer client can be used.
+
 Clients also generate received receipts (small messages used to indicate to someone who sent a message that the recipient device has received it).
 These receipts can make up to half of all traffic.
 (Read receipts, however, are not supported.)
@@ -16,20 +18,20 @@ These receipts can make up to half of all traffic.
 MGen is written entirely in Rust, and is built like most pure Rust projects.
 If you have a working Rust install with Cargo, you can build the client and server with `cargo build`.
 Normal cargo features apply---e.g., use the `--release` flag to enable a larger set of compiler optimizations.
-The server can be built and executed with `cargo run --bin server`, and the client with `cargo run --bin client [config.toml]...`.
+The server can be built and executed with `cargo run --bin server`, the client (for use with client-server mode) with `cargo run --bin client [config.toml]...`, and the peer (for use with peer-to-peer mode) with `cargo run --bin peer [config.toml]...`.
 Alternatively, you can run the executables directly from the respective target directory (e.g., `./target/release/server`).
 
-# Client Configuration
+### Client Configuration
 
 Clients are designed to simulate one conversation per configuration file.
 Part of this configuration is the user sending messages in this conversation---similar to techniques used in TGen, a single client instance can simulate traffic of many individual users.
 The following example configuration with explanatory comments should be enough to understand almost everything you need:
 
 ```TOML
-# conversation.toml
+# client-conversation.toml
 
 # A name used for logs and to create unique circuits for each user on a client.
-sender = "Alice"
+user = "Alice"
 
 # A name used for logs and to create unique circuits for each conversation,
 # even when two chats share the same participants.
@@ -85,3 +87,32 @@ The particular distributions and parameters used in the example are for demonstr
 When sampling, values below zero are clamped to 0---e.g., the `i` distribution above will have an outsize probability of yielding 0.0 seconds, instead of redistributing weight.
 Any distribution in the [rand_distr](https://docs.rs/rand_distr/latest/rand_distr/index.html) crate would be simple to add support for.
 Distributions not in that crate can also be supported, but would require implementing.
+
+### Peer configuration
+
+Running in peer-to-peer mode is very similar to running a client.
+The only differences are that users and recipients consist of a name and address each, and there is no server.
+Here is an example peer conversation configuration (again, all values are for demonstration purposes only):
+
+```TOML
+# peer-conversation.toml
+
+user = {name = "Alice", address = "127.0.0.1:6397"}
+group = "group1"
+recipients = [{name = "Bob", address = "insert.ip.or.onion:6397"}]
+socks = "127.0.0.1:9050"
+bootstrap = 5.0
+retry = 5.0
+
+[distributions]
+s = 0.5
+r = 0.1
+m = {distribution = "Poisson", lambda = 1.0}
+i = {distribution = "Normal", mean = 30.0, std_dev = 100.0}
+w = {distribution = "Normal", mean = 30.0, std_dev = 30.0}
+a_s = {distribution = "Normal", mean = 10.0, std_dev = 5.0}
+a_r = {distribution = "Normal", mean = 10.0, std_dev = 5.0}
+```
+
+In the likely case that these peers are connecting via onion addresses, you must configure each torrc file to match with each peer configuration (in the above example, Alice's HiddenService lines in the torrc must have a `HiddenServicePort` line that forwards to `127.0.0.1:6397`, and Bob's torrc must have a `HiddenServicePort` line that listens on `6397`).
+Multiple users can share an onion address by using different ports (different cirtuits will be used).

+ 189 - 93
src/bin/client.rs

@@ -1,10 +1,14 @@
 // Code specific to the client in the client-server mode.
 
+use mgen::{MessageHeader, SerializedMessage};
 use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256PlusPlus};
 use serde::Deserialize;
 use std::env;
 use std::result::Result;
+use std::sync::Arc;
 use tokio::io::AsyncWriteExt;
+use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
+use tokio::sync::mpsc;
 use tokio::task;
 use tokio::time::Duration;
 use tokio_socks::tcp::Socks5Stream;
@@ -12,128 +16,220 @@ use tokio_socks::tcp::Socks5Stream;
 mod messenger;
 
 use crate::messenger::dists::{ConfigDistributions, Distributions};
-use crate::messenger::error::MessengerError;
+use crate::messenger::error::{FatalError, MessengerError};
 use crate::messenger::state::{manage_active_conversation, manage_idle_conversation, StateMachine};
 
-async fn manage_conversation(config: Config) -> Result<(), MessengerError> {
-    let mut rng = Xoshiro256PlusPlus::from_entropy();
-    let distributions: Distributions = config.distributions.try_into()?;
+/// Type for sending messages from the reader thread to the state thread.
+type ReaderToState = mpsc::UnboundedSender<MessageHeader>;
+// implicit
+//type StateFromReader = mpsc::UnboundedReceiver<MessageHeader>;
+/// Type for getting messages from the state thread in the writer thread.
+type WriterFromState = mpsc::UnboundedReceiver<Arc<SerializedMessage>>;
+// implicit
+//type StateToWriter = mpsc::UnboundedSender<Arc<SerializedMessage>>;
+/// Type for sending the updated read half of the socket.
+type ReadSocketUpdaterIn = mpsc::UnboundedSender<OwnedReadHalf>;
+/// Type for getting the updated read half of the socket.
+type ReadSocketUpdaterOut = mpsc::UnboundedReceiver<OwnedReadHalf>;
+/// Type for sending the updated write half of the socket.
+type WriteSocketUpdaterIn = mpsc::UnboundedSender<OwnedWriteHalf>;
+/// Type for getting the updated write half of the socket.
+type WriteSocketUpdaterOut = mpsc::UnboundedReceiver<OwnedWriteHalf>;
+/// Type for sending errors to other threads.
+type ErrorChannelIn = mpsc::UnboundedSender<MessengerError>;
+/// Type for getting errors from other threads.
+type ErrorChannelOut = mpsc::UnboundedReceiver<MessengerError>;
+
+/// The thread responsible for getting incoming messages,
+/// checking for any network errors while doing so,
+/// and giving messages to the state thread.
+async fn reader(
+    message_channel: ReaderToState,
+    mut socket_updater: ReadSocketUpdaterOut,
+    error_channel: ErrorChannelIn,
+) {
+    loop {
+        let mut stream = socket_updater
+            .recv()
+            .await
+            .expect("Reader socket updater closed");
+        loop {
+            let (msg, _) = match mgen::get_message(&mut stream).await {
+                Ok(msg) => msg,
+                Err(e) => {
+                    error_channel.send(e.into()).expect("Error channel closed");
+                    break;
+                }
+            };
 
-    struct StrParams<'a> {
-        socks: &'a str,
-        server: &'a str,
-        sender: &'a str,
-        group: &'a str,
+            if msg.body != mgen::MessageBody::Receipt {
+                message_channel
+                    .send(msg)
+                    .expect("Reader message channel closed");
+            }
+        }
     }
-    let str_params = StrParams {
-        socks: &config.socks,
-        server: &config.server,
-        sender: &config.sender,
-        group: &config.group,
-    };
+}
 
-    let mut state_machine = StateMachine::start(distributions, &mut rng);
-    let recipients: Vec<&str> = config.recipients.iter().map(String::as_str).collect();
+/// The thread responsible for sending messages from the state thread,
+/// and checking for any network errors while doing so.
+async fn writer(
+    mut message_channel: WriterFromState,
+    mut socket_updater: WriteSocketUpdaterOut,
+    error_channel: ErrorChannelIn,
+) {
+    loop {
+        let mut stream = socket_updater
+            .recv()
+            .await
+            .expect("Writer socket updater closed");
+        loop {
+            let msg = message_channel
+                .recv()
+                .await
+                .expect("Writer message channel closed");
+            if let Err(e) = msg.write_all_to(&mut stream).await {
+                error_channel.send(e.into()).expect("Error channel closed");
+                break;
+            }
+        }
+    }
+}
 
-    async fn error_collector<'a>(
-        bootstrap: Option<Duration>,
-        str_params: &'a StrParams<'a>,
-        rng: &mut Xoshiro256PlusPlus,
-        mut state_machine: StateMachine,
-        recipients: Vec<&str>,
-    ) -> Result<(), (StateMachine, MessengerError)> {
-        let mut stream = match Socks5Stream::connect_with_password(
-            str_params.socks,
-            str_params.server,
-            str_params.sender,
-            str_params.group,
-        )
-        .await
-        {
+/// Parameters used in establishing the connection through the socks proxy.
+/// (Members may be useful elsewhere as well, but that's the primary purpose.)
+struct SocksParams {
+    socks: String,
+    server: String,
+    user: String,
+    group: String,
+}
+
+/// The thread responsible for (re-)establishing connections to the server,
+/// and determining how to handle errors this or other threads receive.
+async fn socket_updater(
+    str_params: SocksParams,
+    bootstrap: f64,
+    retry: f64,
+    mut error_channel: ErrorChannelOut,
+    reader_channel: ReadSocketUpdaterIn,
+    writer_channel: WriteSocketUpdaterIn,
+) -> FatalError {
+    let mut first_run: bool = true;
+    let retry = Duration::from_secs_f64(retry);
+    loop {
+        let socks_connection: Result<Socks5Stream<_>, MessengerError> =
+            Socks5Stream::connect_with_password(
+                str_params.socks.as_str(),
+                str_params.server.as_str(),
+                &str_params.user,
+                &str_params.group,
+            )
+            .await
+            .map_err(|e| e.into());
+        let mut stream = match socks_connection {
             Ok(stream) => stream,
-            Err(e) => return Err((state_machine, e.into())),
+            Err(MessengerError::Recoverable(_)) => {
+                tokio::time::sleep(retry).await;
+                continue;
+            }
+            Err(MessengerError::Fatal(e)) => return e,
         };
 
-        if let Err(e) = stream
-            .write_all(&mgen::serialize_str(str_params.sender))
+        if stream
+            .write_all(&mgen::serialize_str(&str_params.user))
             .await
+            .is_err()
         {
-            return Err((state_machine, e.into()));
+            continue;
         }
 
-        if let Some(bootstrap) = bootstrap {
-            tokio::time::sleep(bootstrap).await;
+        if first_run {
+            tokio::time::sleep(Duration::from_secs_f64(bootstrap)).await;
+            first_run = false;
         }
 
-        loop {
-            state_machine = match state_machine {
-                StateMachine::Idle(conversation) => {
-                    manage_idle_conversation(
-                        conversation,
-                        &mut stream,
-                        str_params.sender,
-                        recipients.clone(),
-                        rng,
-                    )
-                    .await?
-                }
-                StateMachine::Active(conversation) => {
-                    manage_active_conversation(
-                        conversation,
-                        &mut stream,
-                        str_params.sender,
-                        recipients.clone(),
-                        rng,
-                    )
-                    .await?
-                }
-            };
+        let (rd, wr) = stream.into_inner().into_split();
+        reader_channel.send(rd).expect("Reader channel closed");
+        writer_channel.send(wr).expect("Writer channel closed");
+
+        let res = error_channel.recv().await.expect("Error channel closed");
+        if let MessengerError::Fatal(e) = res {
+            return e;
         }
     }
+}
 
-    let retry = config.retry;
-    let retry = Duration::from_secs_f64(retry);
+/// The thread responsible for handling the conversation state
+/// (i.e., whether the user is active or idle, and when to send messages).
+/// Spawns all other threads for this conversation.
+async fn manage_conversation(config: Config) -> Result<(), MessengerError> {
+    let mut rng = Xoshiro256PlusPlus::from_entropy();
+    let distributions: Distributions = config.distributions.try_into()?;
 
-    match error_collector(
-        Some(Duration::from_secs_f64(config.bootstrap)),
-        &str_params,
-        &mut rng,
-        state_machine,
-        recipients.clone(),
-    )
-    .await
-    .expect_err("Inner loop returned Ok?")
-    {
-        (sm, MessengerError::Recoverable(_)) => {
-            state_machine = sm;
-            tokio::time::sleep(retry).await;
-        }
-        (_, e) => return Err(e),
+    let str_params = SocksParams {
+        socks: config.socks,
+        server: config.server,
+        user: config.user.clone(),
+        group: config.group.clone(),
     };
 
+    let mut state_machine = StateMachine::start(distributions, &mut rng);
+    let recipients: Vec<&str> = config.recipients.iter().map(String::as_str).collect();
+
+    let (reader_to_state, mut state_from_reader) = mpsc::unbounded_channel();
+    let (mut state_to_writer, writer_from_state) = mpsc::unbounded_channel();
+    let (read_socket_updater_in, read_socket_updater_out) = mpsc::unbounded_channel();
+    let (write_socket_updater_in, write_socket_updater_out) = mpsc::unbounded_channel();
+    let (errs_in, errs_out) = mpsc::unbounded_channel();
+    tokio::spawn(reader(
+        reader_to_state,
+        read_socket_updater_out,
+        errs_in.clone(),
+    ));
+    tokio::spawn(writer(writer_from_state, write_socket_updater_out, errs_in));
+    tokio::spawn(socket_updater(
+        str_params,
+        config.bootstrap,
+        config.retry,
+        errs_out,
+        read_socket_updater_in,
+        write_socket_updater_in,
+    ));
+
     loop {
-        match error_collector(
-            None,
-            &str_params,
-            &mut rng,
-            state_machine,
-            recipients.clone(),
-        )
-        .await
-        .expect_err("Inner loop returned Ok?")
-        {
-            (sm, MessengerError::Recoverable(_)) => {
-                state_machine = sm;
-                tokio::time::sleep(retry).await;
+        state_machine = match state_machine {
+            StateMachine::Idle(conversation) => {
+                manage_idle_conversation(
+                    conversation,
+                    &mut state_from_reader,
+                    &mut state_to_writer,
+                    &config.user,
+                    &config.group,
+                    recipients.clone(),
+                    &mut rng,
+                )
+                .await
+            }
+            StateMachine::Active(conversation) => {
+                manage_active_conversation(
+                    conversation,
+                    &mut state_from_reader,
+                    &mut state_to_writer,
+                    &config.user,
+                    &config.group,
+                    recipients.clone(),
+                    &mut rng,
+                )
+                .await
             }
-            (_, e) => return Err(e),
         };
     }
 }
 
 #[derive(Debug, Deserialize)]
 struct Config {
-    sender: String,
+    user: String,
     group: String,
     recipients: Vec<String>,
     socks: String,

+ 7 - 0
src/bin/messenger/dists.rs

@@ -134,6 +134,13 @@ pub enum DistParameterError {
     Exp(ExpError),
     Pareto(ParetoError),
 }
+impl std::fmt::Display for DistParameterError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{:?}", self)
+    }
+}
+
+impl std::error::Error for DistParameterError {}
 
 impl TryFrom<ConfigMessageDistribution> for MessageDistribution {
     type Error = DistParameterError;

+ 23 - 2
src/bin/messenger/error.rs

@@ -2,6 +2,8 @@
 
 use crate::messenger::dists::DistParameterError;
 
+/// Errors encountered by the client.
+/// Note that I/O errors are Recoverable by default.
 #[derive(Debug)]
 pub enum MessengerError {
     Recoverable(RecoverableError),
@@ -14,8 +16,6 @@ pub enum RecoverableError {
     /// Recoverable errors from the socks connection.
     Socks(tokio_socks::Error),
     /// Network I/O errors.
-    // Note that all I/O handled by MessengerError should be recoverable;
-    // if you need fatal I/O errors, use a different error type.
     Io(std::io::Error),
 }
 
@@ -30,6 +30,8 @@ pub enum FatalError {
     Utf8Error(std::str::Utf8Error),
     /// A message failed to deserialize.
     MalformedSerialization(Vec<u8>, std::backtrace::Backtrace),
+    /// Fatal network I/O errors.
+    Io(std::io::Error),
 }
 
 impl std::fmt::Display for MessengerError {
@@ -38,7 +40,14 @@ impl std::fmt::Display for MessengerError {
     }
 }
 
+impl std::fmt::Display for FatalError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{:?}", self)
+    }
+}
+
 impl std::error::Error for MessengerError {}
+impl std::error::Error for FatalError {}
 
 impl From<mgen::Error> for MessengerError {
     fn from(e: mgen::Error) -> Self {
@@ -58,6 +67,12 @@ impl From<DistParameterError> for MessengerError {
     }
 }
 
+impl From<DistParameterError> for FatalError {
+    fn from(e: DistParameterError) -> Self {
+        Self::Parameter(e)
+    }
+}
+
 impl From<tokio_socks::Error> for MessengerError {
     fn from(e: tokio_socks::Error) -> Self {
         match e {
@@ -76,3 +91,9 @@ impl From<std::io::Error> for MessengerError {
         Self::Recoverable(RecoverableError::Io(e))
     }
 }
+
+impl From<std::io::Error> for FatalError {
+    fn from(e: std::io::Error) -> Self {
+        Self::Io(e)
+    }
+}

+ 8 - 1
src/bin/messenger/message.rs

@@ -4,21 +4,28 @@
 /// Construct and serialize a message from the sender to the recipients with the given number of blocks.
 pub fn construct_message(
     sender: String,
+    group: String,
     recipients: Vec<String>,
     blocks: u32,
 ) -> mgen::SerializedMessage {
     let size = std::cmp::max(blocks, 1) * mgen::PADDING_BLOCK_SIZE;
     let m = mgen::MessageHeader {
         sender,
+        group,
         recipients,
         body: mgen::MessageBody::Size(std::num::NonZeroU32::new(size).unwrap()),
     };
     m.serialize()
 }
 
-pub fn construct_receipt(sender: String, recipient: String) -> mgen::SerializedMessage {
+pub fn construct_receipt(
+    sender: String,
+    group: String,
+    recipient: String,
+) -> mgen::SerializedMessage {
     let m = mgen::MessageHeader {
         sender,
+        group,
         recipients: vec![recipient],
         body: mgen::MessageBody::Receipt,
     };

+ 123 - 108
src/bin/messenger/state.rs

@@ -2,15 +2,15 @@
 // This includes inducing transitions and actions taken during transitions,
 // so a lot of the messenger network code is here.
 
-use mgen::log;
+use mgen::{log, MessageHeader, SerializedMessage};
 use rand_distr::Distribution;
 use rand_xoshiro::Xoshiro256PlusPlus;
-use tokio::io::{AsyncReadExt, AsyncWriteExt};
-use tokio::net::TcpStream;
+use std::collections::HashMap;
+use std::sync::Arc;
+use tokio::sync::mpsc;
 use tokio::time::Instant;
 
 use crate::messenger::dists::Distributions;
-use crate::messenger::error::MessengerError;
 use crate::messenger::message::{construct_message, construct_receipt};
 
 /// All possible Conversation state machine states
@@ -147,48 +147,59 @@ impl Conversation<Active> {
         }
     }
 
-    async fn sleep(delay: Instant, wait: Instant) -> ActiveActions {
+    async fn sleep(delay: Instant, wait: Instant) -> ActiveGroupActions {
         if delay < wait {
             log!("delaying for {:?}", delay - Instant::now());
             tokio::time::sleep_until(delay).await;
-            ActiveActions::Send
+            ActiveGroupActions::Send
         } else {
             log!("waiting for {:?}", wait - Instant::now());
             tokio::time::sleep_until(wait).await;
-            ActiveActions::Idle
+            ActiveGroupActions::Idle
         }
     }
 }
 
-/// Attempt to read some portion of the header size from the stream.
-/// The number of bytes written is returned in the Ok case.
-/// The caller must read any remaining bytes less than 4.
-// N.B.: This must be written cancellation safe!
-// https://docs.rs/tokio/1.26.0/tokio/macro.select.html#cancellation-safety
-async fn read_header_size(
-    stream: &mut TcpStream,
-    header_size: &mut [u8; 4],
-) -> Result<usize, MessengerError> {
-    let read = stream.read(header_size).await?;
-
-    if read == 0 {
-        Err(tokio::io::Error::new(
-            tokio::io::ErrorKind::WriteZero,
-            "failed to read any bytes from message with bytes remaining",
-        )
-        .into())
-    } else {
-        Ok(read)
+/// Type for getting messages from the reader thread in the state thread.
+type StateFromReader = mpsc::UnboundedReceiver<MessageHeader>;
+/// Type for sending messages from the state thread to the writer thread.
+type StateToWriter = mpsc::UnboundedSender<Arc<SerializedMessage>>;
+
+pub trait StreamMap<'a, I: Iterator<Item = &'a mut StateToWriter>> {
+    fn channel_for(&self, name: &str) -> &StateToWriter;
+    fn values(&'a mut self) -> I;
+}
+
+impl<'a> StreamMap<'a, std::collections::hash_map::ValuesMut<'a, String, StateToWriter>>
+    for HashMap<String, StateToWriter>
+{
+    fn channel_for(&self, name: &str) -> &StateToWriter {
+        &self[name]
+    }
+
+    fn values(&'a mut self) -> std::collections::hash_map::ValuesMut<'a, String, StateToWriter> {
+        self.values_mut()
     }
 }
 
-async fn send_action<T: State>(
+impl<'a> StreamMap<'a, std::iter::Once<&'a mut StateToWriter>> for StateToWriter {
+    fn channel_for(&self, _name: &str) -> &StateToWriter {
+        self
+    }
+
+    fn values(&'a mut self) -> std::iter::Once<&'a mut StateToWriter> {
+        std::iter::once(self)
+    }
+}
+
+async fn send_action<'a, T: State, I: Iterator<Item = &'a mut StateToWriter>>(
     conversation: Conversation<T>,
-    stream: &mut TcpStream,
+    streams: I,
     our_id: &str,
+    group: &str,
     recipients: Vec<&str>,
     rng: &mut Xoshiro256PlusPlus,
-) -> Result<StateMachine, (StateMachine, MessengerError)> {
+) -> StateMachine {
     let size = conversation.dists.m.sample(rng);
     log!(
         "sending message from {} to {:?} of size {}",
@@ -196,42 +207,35 @@ async fn send_action<T: State>(
         recipients,
         size
     );
-    let m = construct_message(
+    let m = Arc::new(construct_message(
         our_id.to_string(),
+        group.to_string(),
         recipients.iter().map(|s| s.to_string()).collect(),
         size,
-    );
+    ));
 
-    if let Err(e) = m.write_all_to(stream).await {
-        return Err((T::to_machine(conversation), e.into()));
-    }
-    if let Err(e) = stream.flush().await {
-        return Err((T::to_machine(conversation), e.into()));
+    for stream in streams {
+        stream
+            .send(m.clone())
+            .expect("Internal stream closed with messages still being sent");
     }
 
-    Ok(T::sent(conversation, rng))
+    T::sent(conversation, rng)
 }
 
-async fn receive_action<T: State>(
-    n: usize,
-    mut header_size: [u8; 4],
+async fn receive_action<
+    'a,
+    T: State,
+    I: std::iter::Iterator<Item = &'a mut StateToWriter>,
+    M: StreamMap<'a, I>,
+>(
+    msg: MessageHeader,
     conversation: Conversation<T>,
-    stream: &mut TcpStream,
+    stream_map: &mut M,
     our_id: &str,
+    group: &str,
     rng: &mut Xoshiro256PlusPlus,
-) -> Result<StateMachine, (StateMachine, MessengerError)> {
-    if n < 4 {
-        // we didn't get the whole size, but we can use read_exact now
-        if let Err(e) = stream.read_exact(&mut header_size[n..]).await {
-            return Err((T::to_machine(conversation), e.into()));
-        }
-    }
-
-    let msg = match mgen::get_message_with_header_size(stream, header_size).await {
-        Ok((msg, _)) => msg,
-        Err(e) => return Err((T::to_machine(conversation), e.into())),
-    };
-
+) -> StateMachine {
     match msg.body {
         mgen::MessageBody::Size(size) => {
             log!(
@@ -240,94 +244,105 @@ async fn receive_action<T: State>(
                 msg.sender,
                 size
             );
-            let m = construct_receipt(our_id.to_string(), msg.sender);
-            if let Err(e) = m.write_all_to(stream).await {
-                return Err((T::to_machine(conversation), e.into()));
-            }
-            if let Err(e) = stream.flush().await {
-                return Err((T::to_machine(conversation), e.into()));
-            }
-            Ok(T::received(conversation, rng))
+            let stream = stream_map.channel_for(&msg.sender);
+            let m = construct_receipt(our_id.to_string(), group.to_string(), msg.sender);
+            stream
+                .send(Arc::new(m))
+                .expect("channel from receive_action to sender closed");
+            T::received(conversation, rng)
         }
-        mgen::MessageBody::Receipt => Ok(T::to_machine(conversation)),
+        mgen::MessageBody::Receipt => T::to_machine(conversation),
     }
 }
 
-enum IdleActions {
+enum IdleGroupActions {
     Send,
-    Receive(usize),
+    Receive(MessageHeader),
 }
 
-pub async fn manage_idle_conversation(
+/// Handle a state transition from Idle, including I/O, for a multi-connection conversation.
+/// Used for Idle group p2p conversations.
+pub async fn manage_idle_conversation<
+    'a,
+    I: std::iter::Iterator<Item = &'a mut StateToWriter>,
+    M: StreamMap<'a, I> + 'a,
+>(
     conversation: Conversation<Idle>,
-    stream: &mut TcpStream,
+    inbound: &mut StateFromReader,
+    stream_map: &'a mut M,
     our_id: &str,
+    group: &str,
     recipients: Vec<&str>,
     rng: &mut Xoshiro256PlusPlus,
-) -> Result<StateMachine, (StateMachine, MessengerError)> {
+) -> StateMachine {
     log!("delaying for {:?}", conversation.delay - Instant::now());
-    let mut header_size = [0; 4];
     let action = tokio::select! {
-        () = tokio::time::sleep_until(conversation.delay) => {
-            Ok(IdleActions::Send)
-        }
+        () = tokio::time::sleep_until(conversation.delay) => IdleGroupActions::Send,
 
-        res = read_header_size(stream, &mut header_size) => {
-            match res {
-                Ok(n) => Ok(IdleActions::Receive(n)),
-                Err(e) => Err(e),
-            }
-        }
-    };
-    let action = match action {
-        Ok(action) => action,
-        Err(e) => return Err((StateMachine::Idle(conversation), e)),
+        res = inbound.recv() =>
+            IdleGroupActions::Receive(res.expect("inbound channel closed")),
     };
 
     match action {
-        IdleActions::Send => send_action(conversation, stream, our_id, recipients, rng).await,
-        IdleActions::Receive(n) => {
-            receive_action(n, header_size, conversation, stream, our_id, rng).await
+        IdleGroupActions::Send => {
+            send_action(
+                conversation,
+                stream_map.values(),
+                our_id,
+                group,
+                recipients,
+                rng,
+            )
+            .await
+        }
+        IdleGroupActions::Receive(msg) => {
+            receive_action(msg, conversation, stream_map, our_id, group, rng).await
         }
     }
 }
 
-enum ActiveActions {
+enum ActiveGroupActions {
     Send,
-    Receive(usize),
+    Receive(MessageHeader),
     Idle,
 }
 
-pub async fn manage_active_conversation(
+/// Handle a state transition from Active.
+pub async fn manage_active_conversation<
+    'a,
+    I: std::iter::Iterator<Item = &'a mut StateToWriter>,
+    M: StreamMap<'a, I> + 'a,
+>(
     conversation: Conversation<Active>,
-    stream: &mut TcpStream,
+    inbound: &mut StateFromReader,
+    stream_map: &'a mut M,
     our_id: &str,
+    group: &str,
     recipients: Vec<&str>,
     rng: &mut Xoshiro256PlusPlus,
-) -> Result<StateMachine, (StateMachine, MessengerError)> {
-    let mut header_size = [0; 4];
+) -> StateMachine {
     let action = tokio::select! {
-        action = Conversation::<Active>::sleep(conversation.delay, conversation.state.wait) => {
-            Ok(action)
-        }
+        action = Conversation::<Active>::sleep(conversation.delay, conversation.state.wait) => action,
 
-        res = read_header_size(stream, &mut header_size) => {
-            match res {
-                Ok(n) => Ok(ActiveActions::Receive(n)),
-                Err(e) => Err(e),
-            }
-        }
-    };
-    let action = match action {
-        Ok(action) => action,
-        Err(e) => return Err((StateMachine::Active(conversation), e)),
+        res = inbound.recv() =>
+            ActiveGroupActions::Receive(res.expect("inbound channel closed")),
     };
 
     match action {
-        ActiveActions::Send => send_action(conversation, stream, our_id, recipients, rng).await,
-        ActiveActions::Receive(n) => {
-            receive_action(n, header_size, conversation, stream, our_id, rng).await
+        ActiveGroupActions::Send => {
+            send_action(
+                conversation,
+                stream_map.values(),
+                our_id,
+                group,
+                recipients,
+                rng,
+            )
+            .await
+        }
+        ActiveGroupActions::Receive(msg) => {
+            receive_action(msg, conversation, stream_map, our_id, group, rng).await
         }
-        ActiveActions::Idle => Ok(StateMachine::Idle(conversation.waited(rng))),
+        ActiveGroupActions::Idle => StateMachine::Idle(conversation.waited(rng)),
     }
 }

+ 3 - 12
src/bin/server.rs

@@ -1,9 +1,9 @@
-use mgen::{log, SerializedMessage};
+use mgen::{log, parse_identifier, SerializedMessage};
 use std::collections::HashMap;
 use std::error::Error;
 use std::result::Result;
 use std::sync::{Arc, Mutex};
-use tokio::io::{AsyncReadExt, AsyncWriteExt};
+use tokio::io::AsyncWriteExt;
 use tokio::net::{
     tcp::{OwnedReadHalf, OwnedWriteHalf},
     TcpListener,
@@ -24,6 +24,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
         ID,
         mpsc::UnboundedSender<Arc<SerializedMessage>>,
     >::new()));
+    // FIXME: this should probably be a Notify + Mutex
     let mut writer_db = HashMap::<ID, mpsc::Sender<(OwnedWriteHalf, watch::Sender<bool>)>>::new();
 
     loop {
@@ -92,16 +93,6 @@ fn spawn_message_receiver(
     });
 }
 
-/// Parse the identifier from the start of the TcpStream.
-async fn parse_identifier(stream: &mut OwnedReadHalf) -> Result<ID, Box<dyn Error>> {
-    // this should maybe be buffered
-    let strlen = stream.read_u32().await?;
-    let mut buf = vec![0u8; strlen as usize];
-    stream.read_exact(&mut buf).await?;
-    let s = std::str::from_utf8(&buf)?;
-    Ok(s.to_string())
-}
-
 /// Loop for receiving messages on the socket, figuring out who to deliver them to,
 /// and forwarding them locally to the respective channel.
 async fn get_messages(

+ 22 - 2
src/lib.rs

@@ -70,16 +70,18 @@ impl MessageBody {
 #[derive(Debug)]
 pub struct MessageHeader {
     pub sender: String,
+    pub group: String,
     pub recipients: Vec<String>,
     pub body: MessageBody,
 }
 
 impl MessageHeader {
-    /// Generate a consise serialization of the Message.
+    /// Generate a concise serialization of the Message.
     pub fn serialize(&self) -> SerializedMessage {
         // serialized message header: {
         //   header_len: u32,
         //   sender: {u32, utf-8}
+        //   group: {u32, utf-8}
         //   num_recipients: u32,
         //   recipients: [{u32, utf-8}],
         //   body_type: MessageBody (i.e., u32)
@@ -91,8 +93,9 @@ impl MessageHeader {
             MessageBody::Size(s) => s.get(),
         };
 
-        let header_len = (1 + 1 + 1 + num_recipients + 1) * size_of::<u32>()
+        let header_len = (1 + 1 + 1 + 1 + num_recipients + 1) * size_of::<u32>()
             + self.sender.len()
+            + self.group.len()
             + self.recipients.iter().map(String::len).sum::<usize>();
 
         let mut header: Vec<u8> = Vec::with_capacity(header_len);
@@ -103,6 +106,7 @@ impl MessageHeader {
         header.extend(header_len.to_be_bytes());
 
         serialize_str_to(&self.sender, &mut header);
+        serialize_str_to(&self.group, &mut header);
 
         header.extend(num_recipients.to_be_bytes());
         for recipient in self.recipients.iter() {
@@ -124,6 +128,9 @@ impl MessageHeader {
         let (sender, buf) = deserialize_str(buf)?;
         let sender = sender.to_string();
 
+        let (group, buf) = deserialize_str(buf)?;
+        let group = group.to_string();
+
         let (num_recipients, buf) = deserialize_u32(buf)?;
         debug_assert!(num_recipients != 0);
 
@@ -144,12 +151,25 @@ impl MessageHeader {
         };
         Ok(Self {
             sender,
+            group,
             recipients,
             body,
         })
     }
 }
 
+/// Parse the identifier from the start of the TcpStream.
+pub async fn parse_identifier<T: AsyncReadExt + std::marker::Unpin>(
+    stream: &mut T,
+) -> Result<String, Error> {
+    // this should maybe be buffered
+    let strlen = stream.read_u32().await?;
+    let mut buf = vec![0u8; strlen as usize];
+    stream.read_exact(&mut buf).await?;
+    let s = std::str::from_utf8(&buf)?;
+    Ok(s.to_string())
+}
+
 /// Gets a message from the stream and constructs a MessageHeader object, also returning the raw byte buffer
 pub async fn get_message<T: AsyncReadExt + std::marker::Unpin>(
     stream: &mut T,