inital commit

Justin Tracey 1 year ago
commit
4212301fce
5 changed files with 1097 additions and 0 deletions
  1. 15 0
      Cargo.toml
  2. 75 0
      README.md
  3. 561 0
      src/bin/client.rs
  4. 181 0
      src/bin/server.rs
  5. 265 0
      src/lib.rs

+ 15 - 0
Cargo.toml

@@ -0,0 +1,15 @@
+[package]
+name = "mgen"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+chrono = "0.4.24"
+enum_dispatch = "0.3.11"
+rand = "0.8.5"
+rand_distr = { version = "0.4.3", features = ["serde1"] }
+rand_xoshiro = "0.6.0"
+serde = { version = "1.0.158", features = ["derive"] }
+tokio = { version = "1", features = ["full"] }
+tokio-socks = "0.5.1"
+toml = "0.7.3"

+ 75 - 0
README.md

@@ -0,0 +1,75 @@
+## MGen
+
+MGen is a client, server, and library for generating simulated messenger traffic.
+It is designed for use analogous to (and likely in conjunction with) [TGen](https://github.com/shadow/tgen), but for simulating traffic generated from communications in messenger apps, such as Signal or WhatsApp, rather than web traffic or file downloads.
+Notably, this allows for studying network traffic properties of messenger apps in [Shadow](https://github.com/shadow/shadow).
+
+Like TGen, MGen can create message flows built around Markov models.
+Unlike TGen, these models are expressly designed with user activity in messenger clients in mind.
+These messages can be relayed through a central server, which can handle group messages (i.e., traffic that originates from one sender, but gets forwarded to multiple recipients).
+Clients also generate received receipts (small messages used to indicate to someone who sent a message that the recipient device has received it).
+These receipts can make up to half of all traffic.
+(Read receipts, however, are not supported.)
+
+## Usage
+
+MGen is written entirely in Rust, and is built like most pure Rust projects.
+If you have a working Rust install with Cargo, you can build the client and server with `cargo build`.
+Normal cargo features apply---e.g., use the `--release` flag to enable a larger set of compiler optimizations.
+The server can be built and executed with `cargo run --bin server`, and the client with `cargo run --bin client [config.toml]...`.
+Alternatively, you can run the executables directly from the respective target directory (e.g., `./target/release/server`).
+
+### Client Configuration
+
+Clients are designed to simulate one conversation per configuration file.
+Part of this configuration is the user sending messages in this conversation---similar to techniques used in TGen, a single client instance can simulate traffic of many individual users.
+The following example configuration with explanatory comments should be enough to understand almost everything you need:
+
+```TOML
+# conversation.toml
+
+# A name used for logs and to create unique circuits for each user on a client.
+sender = "Alice"
+
+# A name used for logs and to create unique circuits for each conversation,
+# even when two chats share the same participants.
+group = "group1"
+
+# The list of participants, except the sender.
+recipients = ["Bob", "Carol", "Dave"]
+
+# The <ip>:<port> of the socks5 proxy to connect through.
+socks = "127.0.0.1:9050"
+
+# The <address>:<port> of the message server, where <address> is an IP or onion address.
+server = "insert.ip.or.onion:6397"
+
+
+# Parameters for distributions used by the Markov model.
+[distributions]
+
+# Probabilities of Idle to Active transition after sending/receiving messages.
+s = 0.5
+r = 0.1
+
+# Distribution I, the amount of time Idle before sending a message.
+i = {distribution = "Normal", mean = 30.0, std_dev = 100.0}
+
+# Distribution W, the amount of time Active without sending or receiving
+# messages to transition to Idle.
+w = {distribution = "Uniform", low = 0.0, high = 90.0}
+
+# Distribution A_{s/r}, the amount of time Active since last sent/received
+# message until the client sends a message.
+a_s = {distribution = "Exp", lambda = 2.0}
+a_r = {distribution = "Pareto", scale = 1.0, shape = 3.0}
+
+```
+
+The client currently supports five probability distributions: Normal and LogNormal, Uniform, Exp(onential), and Pareto.
+The parameter names can be found in the example above.
+The distributions are sampled to return a double-precision floating point number of seconds.
+The particular distributions and parameters used in the example are for demonstration purposes only; they have no relationship to empirical conversation behaviors.
+When sampling, values below zero are clamped to 0---e.g., the `i` distribution above will have an outsize probability of yielding 0.0 seconds, instead of redistributing weight.
+Any distribution in the [rand_distr](https://docs.rs/rand_distr/latest/rand_distr/index.html) crate would be simple to add support for.
+Distributions not in that crate can also be supported, but would need to be implemented first.
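
The clamping rule at the end of README.md can be sketched with the standard library alone. For the example's `i` distribution (Normal with mean 30 and standard deviation 100), roughly 38% of samples fall below zero (Φ(−0.3) ≈ 0.382), and all of them become a zero-second delay. The helper names `clamp_secs` and `sample_to_duration` below are hypothetical and chosen for illustration; the raw samples are hard-coded rather than drawn from `rand_distr`.

```rust
use std::time::Duration;

/// Clamp a raw sample (in seconds) to a non-negative value.
/// Hypothetical helper; NaN also maps to 0.0 because `NaN >= 0.0` is false.
fn clamp_secs(sample: f64) -> f64 {
    if sample >= 0.0 {
        sample
    } else {
        0.0
    }
}

/// Convert a raw sample into a Duration after clamping.
/// `Duration::from_secs_f64` panics on negative input, so clamping must come first.
fn sample_to_duration(sample: f64) -> Duration {
    Duration::from_secs_f64(clamp_secs(sample))
}

fn main() {
    // Hard-coded stand-ins for values a distribution might yield.
    for raw in [-12.5, 0.0, 1.25, 90.0] {
        println!("{raw:>6} s -> {:?}", sample_to_duration(raw));
    }
}
```

Because negative mass collapses onto zero instead of being redistributed, a distribution with a large negative tail produces many back-to-back sends; a distribution with non-negative support (LogNormal, Exp, Pareto) avoids this.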

+ 561 - 0
src/bin/client.rs

@@ -0,0 +1,561 @@
+use enum_dispatch::enum_dispatch;
+use mgen::{log, SerializedMessage};
+use rand_distr::{
+    Bernoulli, BernoulliError, Distribution, Exp, ExpError, LogNormal, Normal, NormalError, Pareto,
+    ParetoError, Uniform,
+};
+use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256PlusPlus};
+use serde::Deserialize;
+use std::env;
+use std::num::NonZeroU32;
+use std::result::Result;
+use tokio::io::{AsyncReadExt, AsyncWriteExt};
+use tokio::net::TcpStream;
+use tokio::task;
+use tokio::time::{Duration, Instant};
+
+#[derive(Debug)]
+enum ClientError {
+    // errors from the library
+    Mgen(mgen::Error),
+    // errors from parsing the conversation files
+    Parameter(DistParameterError),
+    // errors from the socks connection
+    Socks(tokio_socks::Error),
+    // general I/O errors in this file
+    Io(std::io::Error),
+}
+
+impl std::fmt::Display for ClientError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{:?}", self)
+    }
+}
+
+impl std::error::Error for ClientError {}
+
+impl From<mgen::Error> for ClientError {
+    fn from(e: mgen::Error) -> Self {
+        Self::Mgen(e)
+    }
+}
+
+impl From<DistParameterError> for ClientError {
+    fn from(e: DistParameterError) -> Self {
+        Self::Parameter(e)
+    }
+}
+
+impl From<tokio_socks::Error> for ClientError {
+    fn from(e: tokio_socks::Error) -> Self {
+        Self::Socks(e)
+    }
+}
+
+impl From<std::io::Error> for ClientError {
+    fn from(e: std::io::Error) -> Self {
+        Self::Io(e)
+    }
+}
+
+/// All possible Conversation state machine states
+enum StateMachine {
+    Idle(Conversation<Idle>),
+    Active(Conversation<Active>),
+}
+
+/// The state machine representing a conversation state and its transitions.
+struct Conversation<S: State> {
+    dists: Distributions,
+    delay: Instant,
+    state: S,
+}
+
+#[derive(Debug)]
+#[enum_dispatch(Distribution)]
+/// The set of Distributions we currently support.
+/// To modify the code to add support for more, one approach is to first add them here,
+/// then fix all the compiler errors that arise as a result.
+enum SupportedDistribution {
+    Normal(Normal<f64>),
+    LogNormal(LogNormal<f64>),
+    Uniform(Uniform<f64>),
+    Exp(Exp<f64>),
+    Pareto(Pareto<f64>),
+}
+
+/// The set of distributions necessary to represent the actions of the state machine.
+#[derive(Debug)]
+struct Distributions {
+    i: SupportedDistribution,
+    w: SupportedDistribution,
+    a_s: SupportedDistribution,
+    a_r: SupportedDistribution,
+    s: Bernoulli,
+    r: Bernoulli,
+}
+
+trait State {}
+
+struct Idle {}
+struct Active {
+    wait: Instant,
+}
+
+impl State for Idle {}
+impl State for Active {}
+
+impl Conversation<Idle> {
+    fn start(dists: Distributions, rng: &mut Xoshiro256PlusPlus) -> Self {
+        let delay = Instant::now() + dists.i.sample_secs(rng);
+        log!("[start]");
+        Self {
+            dists,
+            delay,
+            state: Idle {},
+        }
+    }
+
+    fn sent(self, rng: &mut Xoshiro256PlusPlus) -> StateMachine {
+        if self.dists.s.sample(rng) {
+            log!("Idle: [sent] transition to [Active]");
+            let delay = Instant::now() + self.dists.a_s.sample_secs(rng);
+            let wait = Instant::now() + self.dists.w.sample_secs(rng);
+            StateMachine::Active({
+                Conversation::<Active> {
+                    dists: self.dists,
+                    delay,
+                    state: Active { wait },
+                }
+            })
+        } else {
+            log!("Idle: [sent] transition to [Idle]");
+            let delay = Instant::now() + self.dists.i.sample_secs(rng);
+            StateMachine::Idle({
+                Conversation::<Idle> {
+                    dists: self.dists,
+                    delay,
+                    state: Idle {},
+                }
+            })
+        }
+    }
+
+    fn received(self, rng: &mut Xoshiro256PlusPlus) -> StateMachine {
+        if self.dists.r.sample(rng) {
+            log!("Idle: [recv'd] transition to [Active]");
+            let wait = Instant::now() + self.dists.w.sample_secs(rng);
+            let delay = Instant::now() + self.dists.a_r.sample_secs(rng);
+            StateMachine::Active({
+                Conversation::<Active> {
+                    dists: self.dists,
+                    delay,
+                    state: Active { wait },
+                }
+            })
+        } else {
+            log!("Idle: [recv'd] transition to [Idle]");
+            StateMachine::Idle(self)
+        }
+    }
+}
+
+impl Conversation<Active> {
+    fn sent(self, rng: &mut Xoshiro256PlusPlus) -> Conversation<Active> {
+        log!("Active: [sent] transition to [Active]");
+        let delay = Instant::now() + self.dists.a_s.sample_secs(rng);
+        Conversation::<Active> {
+            dists: self.dists,
+            delay,
+            state: self.state,
+        }
+    }
+
+    fn received(self, rng: &mut Xoshiro256PlusPlus) -> Conversation<Active> {
+        log!("Active: [recv'd] transition to [Active]");
+        let delay = Instant::now() + self.dists.a_r.sample_secs(rng);
+        Conversation::<Active> {
+            dists: self.dists,
+            delay,
+            state: self.state,
+        }
+    }
+
+    fn waited(self, rng: &mut Xoshiro256PlusPlus) -> Conversation<Idle> {
+        log!("Active: [waited] transition to [Idle]");
+        let delay = Instant::now() + self.dists.i.sample_secs(rng);
+        Conversation::<Idle> {
+            dists: self.dists,
+            delay,
+            state: Idle {},
+        }
+    }
+
+    async fn sleep(delay: Instant, wait: Instant) -> ActiveActions {
+        if delay < wait {
+            log!("delaying for {:?}", delay - Instant::now());
+            tokio::time::sleep_until(delay).await;
+            ActiveActions::Send
+        } else {
+            log!("waiting for {:?}", wait - Instant::now());
+            tokio::time::sleep_until(wait).await;
+            ActiveActions::Idle
+        }
+    }
+}
+
+/// Attempt to read some portion of the size of the rest of the header from the stream.
+/// The number of bytes read is returned in the Ok case.
+/// If fewer than 4 bytes were read, the caller must read the remaining bytes.
+// N.B.: This must be written cancellation safe!
+// https://docs.rs/tokio/1.26.0/tokio/macro.select.html#cancellation-safety
+async fn read_header_size(
+    stream: &mut TcpStream,
+    header_size: &mut [u8; 4],
+) -> Result<usize, ClientError> {
+    let read = stream.read(header_size).await?;
+
+    if read == 0 {
+        Err(tokio::io::Error::new(
+            tokio::io::ErrorKind::UnexpectedEof,
+            "failed to read any bytes from message with bytes remaining",
+        )
+        .into())
+    } else {
+        Ok(read)
+    }
+}
+
+enum IdleActions {
+    Send,
+    Receive(usize),
+}
+
+async fn manage_idle_conversation(
+    conversation: Conversation<Idle>,
+    stream: &mut TcpStream,
+    our_id: &str,
+    recipients: Vec<&str>,
+    rng: &mut Xoshiro256PlusPlus,
+) -> Result<StateMachine, ClientError> {
+    log!("delaying for {:?}", conversation.delay - Instant::now());
+    let mut header_size = [0; 4];
+    let action = tokio::select! {
+        () = tokio::time::sleep_until(conversation.delay) => {
+            Ok(IdleActions::Send)
+        }
+
+        res = read_header_size(stream, &mut header_size) => {
+            match res {
+                Ok(n) => Ok(IdleActions::Receive(n)),
+                Err(e) => Err(e),
+            }
+        }
+    }?;
+
+    match action {
+        IdleActions::Send => {
+            log!("sending message from {} to {:?}", our_id, recipients);
+            let m = construct_message(
+                our_id.to_string(),
+                recipients.iter().map(|s| s.to_string()).collect(),
+            );
+            m.write_all_to(stream).await?;
+            stream.flush().await?;
+            Ok(conversation.sent(rng))
+        }
+        IdleActions::Receive(n) => {
+            if n < 4 {
+                // we didn't get the whole size, but we can use read_exact now
+                stream.read_exact(&mut header_size[n..]).await?;
+            }
+            let (msg, _) = mgen::get_message_with_header_size(stream, header_size).await?;
+            if msg.body != mgen::MessageBody::Receipt {
+                log!("{:?} got message from {}", msg.recipients, msg.sender);
+                let m = construct_receipt(our_id.to_string(), msg.sender);
+                m.write_all_to(stream).await?;
+                stream.flush().await?;
+                Ok(conversation.received(rng))
+            } else {
+                Ok(StateMachine::Idle(conversation))
+            }
+        }
+    }
+}
+
+enum ActiveActions {
+    Send,
+    Receive(usize),
+    Idle,
+}
+
+async fn manage_active_conversation(
+    conversation: Conversation<Active>,
+    stream: &mut TcpStream,
+    our_id: &str,
+    recipients: Vec<&str>,
+    rng: &mut Xoshiro256PlusPlus,
+) -> Result<StateMachine, ClientError> {
+    let mut header_size = [0; 4];
+    let action = tokio::select! {
+        action = Conversation::<Active>::sleep(conversation.delay, conversation.state.wait) => {
+            Ok(action)
+        }
+
+        res = read_header_size(stream, &mut header_size) => {
+            match res {
+                Ok(n) => Ok(ActiveActions::Receive(n)),
+                Err(e) => Err(e),
+            }
+        }
+    }?;
+
+    match action {
+        ActiveActions::Send => {
+            log!("sending message from {} to {:?}", our_id, recipients);
+            let m = construct_message(
+                our_id.to_string(),
+                recipients.iter().map(|s| s.to_string()).collect(),
+            );
+            m.write_all_to(stream).await?;
+            stream.flush().await?;
+            Ok(StateMachine::Active(conversation.sent(rng)))
+        }
+        ActiveActions::Receive(n) => {
+            if n < 4 {
+                // we didn't get the whole size, but we can use read_exact now
+                stream.read_exact(&mut header_size[n..]).await?;
+            }
+            let (msg, _) = mgen::get_message_with_header_size(stream, header_size).await?;
+            if msg.body != mgen::MessageBody::Receipt {
+                log!("{:?} got message from {}", msg.recipients, msg.sender);
+                let m = construct_receipt(our_id.to_string(), msg.sender);
+                m.write_all_to(stream).await?;
+                stream.flush().await?;
+                Ok(StateMachine::Active(conversation.received(rng)))
+            } else {
+                Ok(StateMachine::Active(conversation))
+            }
+        }
+        ActiveActions::Idle => Ok(StateMachine::Idle(conversation.waited(rng))),
+    }
+}
+
+async fn manage_conversation(config: Config) -> Result<(), ClientError> {
+    let mut rng = Xoshiro256PlusPlus::from_entropy();
+    let distributions: Distributions = config.distributions.try_into()?;
+    let mut state_machine =
+        StateMachine::Idle(Conversation::<Idle>::start(distributions, &mut rng));
+    let recipients: Vec<&str> = config.recipients.iter().map(String::as_str).collect();
+
+    let mut stream = tokio_socks::tcp::Socks5Stream::connect_with_password(
+        config.socks.as_str(),
+        config.server.as_str(),
+        &config.sender,
+        &config.group,
+    )
+    .await?;
+    stream
+        .write_all(&mgen::serialize_str(&config.sender))
+        .await?;
+
+    tokio::time::sleep(Duration::from_secs(5)).await;
+    loop {
+        state_machine = match state_machine {
+            StateMachine::Idle(conversation) => {
+                manage_idle_conversation(
+                    conversation,
+                    &mut stream,
+                    &config.sender,
+                    recipients.clone(),
+                    &mut rng,
+                )
+                .await?
+            }
+            StateMachine::Active(conversation) => {
+                manage_active_conversation(
+                    conversation,
+                    &mut stream,
+                    &config.sender,
+                    recipients.clone(),
+                    &mut rng,
+                )
+                .await?
+            }
+        };
+    }
+}
+
+/// A wrapper for the Distribution trait that specifies the RNG to allow (fake) dynamic dispatch.
+#[enum_dispatch(SupportedDistribution)]
+trait Dist {
+    fn sample(&self, rng: &mut Xoshiro256PlusPlus) -> f64;
+}
+
+/*
+// This would be easier, but we run into https://github.com/rust-lang/rust/issues/48869
+impl<T, D> Dist<T> for D
+where
+    D: Distribution<T> + Send + Sync,
+{
+    fn sample(&self, rng: &mut Xoshiro256PlusPlus) -> T {
+        self.sample(rng)
+    }
+}
+ */
+
+macro_rules! dist_impl {
+    ($dist:ident) => {
+        impl Dist for $dist<f64> {
+            fn sample(&self, rng: &mut Xoshiro256PlusPlus) -> f64 {
+                Distribution::sample(self, rng)
+            }
+        }
+    };
+}
+
+dist_impl!(Exp);
+dist_impl!(Normal);
+dist_impl!(LogNormal);
+dist_impl!(Pareto);
+dist_impl!(Uniform);
+
+impl SupportedDistribution {
+    // FIXME: there's probably a better way to do this integrated with the crate
+    fn clamped_sample(&self, rng: &mut Xoshiro256PlusPlus) -> f64 {
+        let sample = self.sample(rng);
+        if sample >= 0.0 {
+            sample
+        } else {
+            0.0
+        }
+    }
+
+    fn sample_secs(&self, rng: &mut Xoshiro256PlusPlus) -> Duration {
+        Duration::from_secs_f64(self.clamped_sample(rng))
+    }
+}
+
+fn construct_message(sender: String, recipients: Vec<String>) -> SerializedMessage {
+    // FIXME: sample size from distribution
+    let m = mgen::MessageHeader {
+        sender,
+        recipients,
+        body: mgen::MessageBody::Size(NonZeroU32::new(1024).unwrap()),
+    };
+    m.serialize()
+}
+
+fn construct_receipt(sender: String, recipient: String) -> SerializedMessage {
+    let m = mgen::MessageHeader {
+        sender,
+        recipients: vec![recipient],
+        body: mgen::MessageBody::Receipt,
+    };
+    m.serialize()
+}
+
+/// The same as Distributions, but designed for easier deserialization.
+#[derive(Debug, Deserialize)]
+struct ConfigDistributions {
+    i: ConfigSupportedDistribution,
+    w: ConfigSupportedDistribution,
+    a_s: ConfigSupportedDistribution,
+    a_r: ConfigSupportedDistribution,
+    s: f64,
+    r: f64,
+}
+
+/// The same as SupportedDistribution, but designed for easier deserialization.
+#[derive(Debug, Deserialize)]
+#[serde(tag = "distribution")]
+enum ConfigSupportedDistribution {
+    Normal { mean: f64, std_dev: f64 },
+    LogNormal { mean: f64, std_dev: f64 },
+    Uniform { low: f64, high: f64 },
+    Exp { lambda: f64 },
+    Pareto { scale: f64, shape: f64 },
+}
+
+#[derive(Debug)]
+enum DistParameterError {
+    Bernoulli(BernoulliError),
+    Normal(NormalError),
+    LogNormal(NormalError),
+    Uniform, // Uniform::new doesn't return an error, it just panics
+    Exp(ExpError),
+    Pareto(ParetoError),
+}
+
+impl TryFrom<ConfigSupportedDistribution> for SupportedDistribution {
+    type Error = DistParameterError;
+
+    fn try_from(dist: ConfigSupportedDistribution) -> Result<Self, DistParameterError> {
+        let dist = match dist {
+            ConfigSupportedDistribution::Normal { mean, std_dev } => SupportedDistribution::Normal(
+                Normal::new(mean, std_dev).map_err(DistParameterError::Normal)?,
+            ),
+            ConfigSupportedDistribution::LogNormal { mean, std_dev } => {
+                SupportedDistribution::LogNormal(
+                    LogNormal::new(mean, std_dev).map_err(DistParameterError::LogNormal)?,
+                )
+            }
+            ConfigSupportedDistribution::Uniform { low, high } => {
+                if low >= high {
+                    return Err(DistParameterError::Uniform);
+                }
+                SupportedDistribution::Uniform(Uniform::new(low, high))
+            }
+            ConfigSupportedDistribution::Exp { lambda } => {
+                SupportedDistribution::Exp(Exp::new(lambda).map_err(DistParameterError::Exp)?)
+            }
+            ConfigSupportedDistribution::Pareto { scale, shape } => SupportedDistribution::Pareto(
+                Pareto::new(scale, shape).map_err(DistParameterError::Pareto)?,
+            ),
+        };
+        Ok(dist)
+    }
+}
+
+impl TryFrom<ConfigDistributions> for Distributions {
+    type Error = DistParameterError;
+
+    fn try_from(config: ConfigDistributions) -> Result<Self, DistParameterError> {
+        Ok(Distributions {
+            i: config.i.try_into()?,
+            w: config.w.try_into()?,
+            a_s: config.a_s.try_into()?,
+            a_r: config.a_r.try_into()?,
+            s: Bernoulli::new(config.s).map_err(DistParameterError::Bernoulli)?,
+            r: Bernoulli::new(config.r).map_err(DistParameterError::Bernoulli)?,
+        })
+    }
+}
+
+#[derive(Debug, Deserialize)]
+struct Config {
+    sender: String,
+    group: String,
+    recipients: Vec<String>,
+    socks: String,
+    server: String,
+    distributions: ConfigDistributions,
+}
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let mut args = env::args();
+    let _ = args.next();
+    let mut handles = vec![];
+    for config_file in args {
+        let toml_s = std::fs::read_to_string(config_file)?;
+        let config = toml::from_str(&toml_s)?;
+        let handle: task::JoinHandle<Result<(), ClientError>> =
+            tokio::spawn(manage_conversation(config));
+        handles.push(handle);
+    }
+    for handle in handles {
+        handle.await??;
+    }
+    Ok(())
+}
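
The `Conversation<Idle>` / `Conversation<Active>` types in src/bin/client.rs follow the typestate pattern: each state is its own type, and each transition consumes the old value and returns the new one. The following is a minimal, standard-library-only sketch of that pattern. It is a simplification and not the client's code: the random Bernoulli draw is replaced by a `go_active` flag, the `Instant` deadlines by a plain counter, and the field names `messages_sent` and `quiet_ticks_left` are hypothetical.

```rust
// States are plain types; only the methods defined for a state can be called on it.
struct Idle;
struct Active {
    // Stand-in for the `wait` deadline held by the real Active state.
    quiet_ticks_left: u32,
}

struct Conversation<S> {
    messages_sent: u32,
    state: S,
}

// Runtime wrapper for code that must hold "a conversation in either state".
enum StateMachine {
    Idle(Conversation<Idle>),
    Active(Conversation<Active>),
}

impl Conversation<Idle> {
    fn start() -> Self {
        Conversation { messages_sent: 0, state: Idle }
    }

    /// `go_active` stands in for sampling the `s` probability.
    fn sent(self, go_active: bool) -> StateMachine {
        let messages_sent = self.messages_sent + 1;
        if go_active {
            let state = Active { quiet_ticks_left: 3 };
            StateMachine::Active(Conversation { messages_sent, state })
        } else {
            StateMachine::Idle(Conversation { messages_sent, state: Idle })
        }
    }
}

impl Conversation<Active> {
    /// Sending while Active always stays Active.
    fn sent(self) -> Conversation<Active> {
        Conversation { messages_sent: self.messages_sent + 1, state: self.state }
    }

    /// Going quiet for long enough always returns to Idle.
    fn waited(self) -> Conversation<Idle> {
        Conversation { messages_sent: self.messages_sent, state: Idle }
    }
}

fn describe(machine: &StateMachine) -> String {
    match machine {
        StateMachine::Idle(c) => format!("Idle, {} sent", c.messages_sent),
        StateMachine::Active(c) => format!(
            "Active, {} sent, idle after {} quiet ticks",
            c.messages_sent, c.state.quiet_ticks_left
        ),
    }
}

fn main() {
    let machine = Conversation::<Idle>::start().sent(true);
    println!("{}", describe(&machine));
    if let StateMachine::Active(conversation) = machine {
        let idle = conversation.sent().waited();
        println!("{}", describe(&StateMachine::Idle(idle)));
    }
}
```

One consequence of this design is that an invalid transition, such as calling `waited` on an idle conversation, is a compile error and not a runtime check.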

+ 181 - 0
src/bin/server.rs

@@ -0,0 +1,181 @@
+use mgen::{log, SerializedMessage};
+use std::collections::HashMap;
+use std::error::Error;
+use std::result::Result;
+use std::sync::{Arc, Mutex};
+use tokio::io::{AsyncReadExt, AsyncWriteExt};
+use tokio::net::{
+    tcp::{OwnedReadHalf, OwnedWriteHalf},
+    TcpListener,
+};
+use tokio::sync::{mpsc, watch};
+
+// FIXME: identifiers should be interned
+type ID = String;
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn Error>> {
+    let listener = TcpListener::bind("127.0.0.1:6397").await?;
+
+    log!("Listening");
+
+    // FIXME: should probably be a readers-writer lock
+    let snd_db = Arc::new(Mutex::new(HashMap::<
+        ID,
+        mpsc::UnboundedSender<Arc<SerializedMessage>>,
+    >::new()));
+    let mut writer_db = HashMap::<ID, mpsc::Sender<(OwnedWriteHalf, watch::Sender<bool>)>>::new();
+
+    loop {
+        let (socket, _) = listener.accept().await?;
+        let (mut rd, wr) = socket.into_split();
+
+        let id = parse_identifier(&mut rd).await?;
+        log!("Accepting \"{id}\"");
+        if let Some(socket_updater) = writer_db.get(&id) {
+            // we've seen this client before
+
+            // start the new reader thread with a new watch
+            let (watch_snd, watch_rcv) = watch::channel(false);
+            spawn_message_receiver(rd, snd_db.clone(), watch_rcv);
+
+            // give the writer thread the new write half of the socket and watch
+            socket_updater.send((wr, watch_snd)).await?;
+        } else {
+            // newly-registered client
+            log!("New client");
+
+            // message channel, used to send messages between threads
+            let (msg_snd, msg_rcv) = mpsc::unbounded_channel::<Arc<SerializedMessage>>();
+
+            let snd_db = snd_db.clone();
+            {
+                let mut locked_db = snd_db.lock().unwrap();
+                locked_db.insert(id.clone(), msg_snd);
+            }
+
+            // socket watch, used to terminate the socket if the sender encounters an error
+            let (watch_snd, watch_rcv) = watch::channel(false);
+
+            // socket updater, used to give the sender thread a new socket and watch
+            let (socket_updater_snd, socket_updater_rcv) = mpsc::channel(8);
+            socket_updater_snd.send((wr, watch_snd)).await?;
+
+            writer_db.insert(id.clone(), socket_updater_snd);
+
+            spawn_message_receiver(rd, snd_db, watch_rcv);
+            tokio::spawn(async move {
+                send_messages(msg_rcv, socket_updater_rcv).await;
+            });
+        }
+    }
+}
+
+fn spawn_message_receiver(
+    rd: OwnedReadHalf,
+    db: Arc<Mutex<HashMap<ID, mpsc::UnboundedSender<Arc<SerializedMessage>>>>>,
+    mut watch_rcv: watch::Receiver<bool>,
+) {
+    watch_rcv.borrow_and_update();
+    tokio::spawn(async move {
+        tokio::select! {
+            ret = get_messages(rd, db) => {
+                if let Err(e) = ret {
+                    log!("message receiver failed: {:?}", e);
+                }
+            }
+            _ = watch_rcv.changed() => {
+                log!("receiver terminated");
+                // should cause the other thread to terminate, dropping the socket
+            }
+        }
+    });
+}
+
+/// Parse the identifier from the start of the TcpStream.
+async fn parse_identifier(stream: &mut OwnedReadHalf) -> Result<ID, Box<dyn Error>> {
+    // this should maybe be buffered
+    let strlen = stream.read_u32().await?;
+    let mut buf = vec![0u8; strlen as usize];
+    stream.read_exact(&mut buf).await?;
+    let s = std::str::from_utf8(&buf)?;
+    Ok(s.to_string())
+}
+
+/// Loop for receiving messages on the socket, figuring out who to deliver them to,
+/// and forwarding them locally to the respective channel.
+async fn get_messages(
+    mut socket: OwnedReadHalf,
+    db: Arc<Mutex<HashMap<ID, mpsc::UnboundedSender<Arc<SerializedMessage>>>>>,
+) -> Result<(), Box<dyn Error>> {
+    // stores snd's for contacts this client has already sent messages to, to reduce contention on the main db
+    // if memory ends up being more of a constraint, could be worth getting rid of this
+    let mut localdb = HashMap::<ID, mpsc::UnboundedSender<Arc<SerializedMessage>>>::new();
+    let mut sender = "".to_string();
+    loop {
+        log!("waiting for message from {}", sender);
+        let (message, buf) = mgen::get_message(&mut socket).await?;
+        log!(
+            "got message from {} for {:?}",
+            message.sender,
+            message.recipients
+        );
+        sender = message.sender;
+        let missing = message
+            .recipients
+            .iter()
+            .filter(|&r| !localdb.contains_key(r))
+            .collect::<Vec<&ID>>();
+
+        {
+            let locked_db = db.lock().unwrap();
+            for m in missing {
+                if let Some(snd) = locked_db.get(m) {
+                    localdb.insert(m.to_string(), snd.clone());
+                }
+            }
+        }
+
+        let m = Arc::new(SerializedMessage {
+            header: buf,
+            body: message.body,
+        });
+
+        for recipient in message.recipients.iter() {
+            let recipient_sender = localdb
+                .get(recipient)
+                .unwrap_or_else(|| panic!("Unknown recipient: {}", recipient));
+            recipient_sender
+                .send(m.clone())
+                .expect("Recipient closed channel with messages still being sent");
+        }
+    }
+}
+
+/// Loop for receiving messages on the mpsc channel for this recipient,
+/// and sending them out on the associated socket.
+async fn send_messages(
+    mut msg_rcv: mpsc::UnboundedReceiver<Arc<SerializedMessage>>,
+    mut socket_updater: mpsc::Receiver<(OwnedWriteHalf, watch::Sender<bool>)>,
+) {
+    let (mut current_socket, mut current_watch) =
+        socket_updater.recv().await.expect("socket updater closed");
+    loop {
+        let message = msg_rcv.recv().await.expect("message channel closed");
+        log!("sending message");
+        if message.write_all_to(&mut current_socket).await.is_err()
+            || current_socket.flush().await.is_err()
+        {
+            log!("terminating connection");
+            // socket is presumably closed, clean up and notify the listening end to close
+            // (all best-effort, we can ignore errors because it presumably means it's done)
+            let _ = current_watch.send(true);
+            let _ = current_socket.shutdown().await;
+
+            // wait for the new socket
+            (current_socket, current_watch) =
+                socket_updater.recv().await.expect("socket updater closed");
+            log!("socket updated");
+        }
+    }
+}
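
`parse_identifier` in src/bin/server.rs reads a `u32` length (tokio's `read_u32` is big-endian) followed by that many UTF-8 bytes. The sketch below shows the same framing with blocking standard-library I/O so it can run without tokio. `encode_identifier` is a hypothetical stand-in for `mgen::serialize_str`, whose body is not part of this excerpt; the assumption is that it writes the length in the same big-endian form the server reads.

```rust
use std::io::{self, Cursor, Read};

/// Encode a string as a big-endian u32 length followed by its UTF-8 bytes.
/// Hypothetical helper; assumed to match what the server's parser expects.
fn encode_identifier(s: &str) -> Vec<u8> {
    let mut buf = Vec::with_capacity(4 + s.len());
    buf.extend_from_slice(&(s.len() as u32).to_be_bytes());
    buf.extend_from_slice(s.as_bytes());
    buf
}

/// Blocking counterpart of the length-then-bytes read: fails on short input
/// or on bytes that are not valid UTF-8.
fn decode_identifier<R: Read>(reader: &mut R) -> io::Result<String> {
    let mut len = [0u8; 4];
    reader.read_exact(&mut len)?;
    let mut buf = vec![0u8; u32::from_be_bytes(len) as usize];
    reader.read_exact(&mut buf)?;
    String::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}

fn main() {
    let wire = encode_identifier("Alice");
    println!("{:?}", wire);
    let id = decode_identifier(&mut Cursor::new(wire)).expect("well-formed identifier");
    println!("{id}");
}
```

Note that the length comes straight from the peer, so a parser like this allocates whatever size the first four bytes claim; a deployment facing untrusted clients would likely want an upper bound before allocating.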

+ 265 - 0
src/lib.rs

@@ -0,0 +1,265 @@
+use std::mem::size_of;
+use std::num::NonZeroU32;
+use tokio::io::{copy, sink, AsyncReadExt, AsyncWriteExt};
+
+/// The minimum message size.
+/// All message bodies smaller than this size (notably, receipts) will be padded to this length.
+const MIN_MESSAGE_SIZE: u32 = 256; // FIXME: double check what this should be
+
+#[macro_export]
+macro_rules! log {
+    ( $( $x:expr ),* ) => {
+        print!("{}", chrono::offset::Utc::now().format("%F %T: "));
+        println!($( $x ),*)
+    }
+}
+
+#[derive(Debug)]
+pub enum Error {
+    Io(std::io::Error),
+    Utf8Error(std::str::Utf8Error),
+    TryFromSliceError(std::array::TryFromSliceError),
+    MalformedSerialization(Vec<u8>, std::backtrace::Backtrace),
+}
+
+impl std::fmt::Display for Error {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{:?}", self)
+    }
+}
+
+impl std::error::Error for Error {}
+
+impl From<std::io::Error> for Error {
+    fn from(e: std::io::Error) -> Self {
+        Self::Io(e)
+    }
+}
+
+impl From<std::str::Utf8Error> for Error {
+    fn from(e: std::str::Utf8Error) -> Self {
+        Self::Utf8Error(e)
+    }
+}
+
+impl From<std::array::TryFromSliceError> for Error {
+    fn from(e: std::array::TryFromSliceError) -> Self {
+        Self::TryFromSliceError(e)
+    }
+}
+
+/// Metadata for the body of the message.
+///
+/// Message contents are always 0-filled buffers, so never represented.
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum MessageBody {
+    Receipt,
+    Size(NonZeroU32),
+}
+
+impl MessageBody {
+    fn size(&self) -> u32 {
+        match self {
+            MessageBody::Receipt => MIN_MESSAGE_SIZE,
+            MessageBody::Size(size) => size.get(),
+        }
+    }
+}
+
+/// Message metadata.
+///
+/// This has everything needed to reconstruct a message.
+// FIXME: every String should be &str
+#[derive(Debug)]
+pub struct MessageHeader {
+    pub sender: String,
+    pub recipients: Vec<String>,
+    pub body: MessageBody,
+}
+
+impl MessageHeader {
+    /// Generate a concise serialization of the Message.
+    pub fn serialize(&self) -> SerializedMessage {
+        // serialized message header: {
+        //   header_len: u32,
+        //   sender: {u32, utf-8}
+        //   num_recipients: u32,
+        //   recipients: [{u32, utf-8}],
+        //   body_type: MessageBody (i.e., u32)
+        // }
+
+        let num_recipients = self.recipients.len();
+        let body_type = match self.body {
+            MessageBody::Receipt => 0,
+            MessageBody::Size(s) => s.get(),
+        };
+
+        let header_len = (1 + 1 + 1 + num_recipients + 1) * size_of::<u32>()
+            + self.sender.len()
+            + self.recipients.iter().map(String::len).sum::<usize>();
+
+        let mut header: Vec<u8> = Vec::with_capacity(header_len);
+
+        let header_len = header_len as u32;
+        let num_recipients = num_recipients as u32;
+
+        header.extend(header_len.to_be_bytes());
+
+        serialize_str_to(&self.sender, &mut header);
+
+        header.extend(num_recipients.to_be_bytes());
+        for recipient in self.recipients.iter() {
+            serialize_str_to(recipient, &mut header);
+        }
+
+        header.extend(body_type.to_be_bytes());
+
+        assert!(header.len() == header_len as usize);
+        SerializedMessage {
+            header,
+            body: self.body,
+        }
+    }
+
+    /// Creates a MessageHeader from bytes created via serialization,
+    /// but with the size already parsed out.
+    fn deserialize(buf: &[u8]) -> Result<Self, Error> {
+        let (sender, buf) = deserialize_str(buf)?;
+        let sender = sender.to_string();
+
+        let (num_recipients, buf) = deserialize_u32(buf)?;
+        debug_assert!(num_recipients != 0);
+
+        let mut recipients = Vec::with_capacity(num_recipients as usize);
+        let mut recipient;
+        let mut buf = buf;
+        for _ in 0..num_recipients {
+            (recipient, buf) = deserialize_str(buf)?;
+            let recipient = recipient.to_string();
+            recipients.push(recipient);
+        }
+
+        let (body, _) = deserialize_u32(buf)?;
+        let body = if let Some(size) = NonZeroU32::new(body) {
+            MessageBody::Size(size)
+        } else {
+            MessageBody::Receipt
+        };
+        Ok(Self {
+            sender,
+            recipients,
+            body,
+        })
+    }
+}
+
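+/// Reads one message from the stream and parses its header.
+///
+/// Returns the parsed MessageHeader along with the raw serialized header bytes;
+/// the message body is read from the stream and discarded.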
+pub async fn get_message<T: AsyncReadExt + std::marker::Unpin>(
+    stream: &mut T,
+) -> Result<(MessageHeader, Vec<u8>), Error> {
+    let mut header_size_bytes = [0u8; 4];
+    stream.read_exact(&mut header_size_bytes).await?;
+    log!("got header size");
+    get_message_with_header_size(stream, header_size_bytes).await
+}
+
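+/// Like get_message, but for callers that have already read the 4-byte
+/// header length prefix from the stream.
+pub async fn get_message_with_header_size<T: AsyncReadExt + std::marker::Unpin>(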
+    stream: &mut T,
+    header_size_bytes: [u8; 4],
+) -> Result<(MessageHeader, Vec<u8>), Error> {
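+    let header_size = u32::from_be_bytes(header_size_bytes) as usize;
+    // The length prefix counts itself, so a smaller value is malformed
+    // (and would make the `[4..]` slices below panic).
+    if header_size < size_of::<u32>() {
+        return Err(Error::MalformedSerialization(
+            header_size_bytes.to_vec(),
+            std::backtrace::Backtrace::capture(),
+        ));
+    }
+    let mut header_buf = vec![0; header_size];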
+    stream.read_exact(&mut header_buf[4..]).await?;
+    let header = MessageHeader::deserialize(&header_buf[4..])?;
+    log!(
+        "got header from {} to {:?}, about to read {} bytes",
+        header.sender,
+        header.recipients,
+        header.body.size()
+    );
+    let header_size_buf = &mut header_buf[..4];
+    header_size_buf.copy_from_slice(&header_size_bytes);
+    copy(&mut stream.take(header.body.size() as u64), &mut sink()).await?;
+    Ok((header, header_buf))
+}
+
+pub fn serialize_str(s: &str) -> Vec<u8> {
+    let mut buf = Vec::with_capacity(s.len() + size_of::<u32>());
+    serialize_str_to(s, &mut buf);
+    buf
+}
+
+pub fn serialize_str_to(s: &str, buf: &mut Vec<u8>) {
+    let strlen = s.len() as u32;
+    buf.extend(strlen.to_be_bytes());
+    buf.extend(s.as_bytes());
+}
+
+fn deserialize_u32(buf: &[u8]) -> Result<(u32, &[u8]), Error> {
+    // ok_or_else defers the buffer copy and backtrace capture to the error path.
+    let bytes = buf.get(0..4).ok_or_else(|| {
+        Error::MalformedSerialization(buf.to_vec(), std::backtrace::Backtrace::capture())
+    })?;
+    Ok((u32::from_be_bytes(bytes.try_into()?), &buf[4..]))
+}
+
+fn deserialize_str(buf: &[u8]) -> Result<(&str, &[u8]), Error> {
+    let (strlen, buf) = deserialize_u32(buf)?;
+    let strlen = strlen as usize;
+    // ok_or_else defers the buffer copy and backtrace capture to the error path.
+    let strbytes = buf.get(..strlen).ok_or_else(|| {
+        Error::MalformedSerialization(buf.to_vec(), std::backtrace::Backtrace::capture())
+    })?;
+    Ok((std::str::from_utf8(strbytes)?, &buf[strlen..]))
+}
+
+/// A message almost ready for sending.
+///
+/// We represent each message in two halves: the header, and the body.
+/// This way, the server can parse out the header in its own buf,
+/// and just pass that around intact, without keeping a (possibly large)
+/// 0-filled body around.
+#[derive(Debug)]
+pub struct SerializedMessage {
+    pub header: Vec<u8>,
+    pub body: MessageBody,
+}
+
+impl SerializedMessage {
+    pub async fn write_all_to<T: AsyncWriteExt + std::marker::Unpin>(
+        &self,
+        writer: &mut T,
+    ) -> std::io::Result<()> {
+        let body_buf = vec![0; self.body.size() as usize];
+
+        // write_all_vectored is not yet stable x_x
+        // https://github.com/rust-lang/rust/issues/70436
+        let mut header: &[u8] = &self.header;
+        let mut body: &[u8] = &body_buf;
+        loop {
+            let bufs = [std::io::IoSlice::new(header), std::io::IoSlice::new(body)];
+            match writer.write_vectored(&bufs).await {
+                Ok(written) => {
+                    if written == header.len() + body.len() {
+                        return Ok(());
+                    }
+
+                    if written >= header.len() {
+                        body = &body[written - header.len()..];
+                        break;
+                    } else if written == 0 {
+                        return Err(std::io::Error::new(
+                            std::io::ErrorKind::WriteZero,
+                            "failed to write any bytes from message with bytes remaining",
+                        ));
+                    } else {
+                        header = &header[written..];
+                    }
+                }
+                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
+                Err(e) => return Err(e),
+            }
+        }
+        writer.write_all(body).await
+    }
+}
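
The header layout described in the `serialize` comment above (a big-endian `u32` total length that counts itself, a length-prefixed sender, a recipient count, length-prefixed recipients, and a `u32` body field where 0 means receipt) can be illustrated with a minimal standalone sketch. This is not the crate's code: `put_str`, `take_u32`, `take_str`, `encode_header`, and `decode_header` are hypothetical names, the sketch uses only the standard library, and it returns `Option` instead of the crate's `Error` type.

```rust
use std::convert::TryInto;

/// Hypothetical helper: append a u32 big-endian length, then the UTF-8 bytes.
fn put_str(s: &str, out: &mut Vec<u8>) {
    out.extend((s.len() as u32).to_be_bytes());
    out.extend(s.as_bytes());
}

/// Hypothetical helper: split a big-endian u32 off the front of `buf`.
fn take_u32(buf: &[u8]) -> Option<(u32, &[u8])> {
    let bytes: [u8; 4] = buf.get(..4)?.try_into().ok()?;
    Some((u32::from_be_bytes(bytes), &buf[4..]))
}

/// Hypothetical helper: split a length-prefixed UTF-8 string off the front of `buf`.
fn take_str(buf: &[u8]) -> Option<(&str, &[u8])> {
    let (len, buf) = take_u32(buf)?;
    let len = len as usize;
    let s = std::str::from_utf8(buf.get(..len)?).ok()?;
    Some((s, &buf[len..]))
}

/// Encode a header. `body` is 0 for a receipt, otherwise the body size in bytes.
fn encode_header(sender: &str, recipients: &[&str], body: u32) -> Vec<u8> {
    // Four fixed u32 fields (total length, sender length, recipient count, body)
    // plus one u32 length per recipient, plus the string bytes themselves.
    let len = (4 + recipients.len()) * 4
        + sender.len()
        + recipients.iter().map(|r| r.len()).sum::<usize>();
    let mut out = Vec::with_capacity(len);
    out.extend((len as u32).to_be_bytes());
    put_str(sender, &mut out);
    out.extend((recipients.len() as u32).to_be_bytes());
    for r in recipients {
        put_str(r, &mut out);
    }
    out.extend(body.to_be_bytes());
    out
}

/// Decode a full header, including its leading length prefix.
/// Returns None if the prefix disagrees with the buffer or a field is truncated.
fn decode_header(buf: &[u8]) -> Option<(String, Vec<String>, u32)> {
    let (len, rest) = take_u32(buf)?;
    if len as usize != buf.len() {
        return None;
    }
    let (sender, mut rest) = take_str(rest)?;
    let (count, r) = take_u32(rest)?;
    rest = r;
    let mut recipients = Vec::new();
    for _ in 0..count {
        let (s, r) = take_str(rest)?;
        recipients.push(s.to_string());
        rest = r;
    }
    let (body, _) = take_u32(rest)?;
    Some((sender.to_string(), recipients, body))
}
```

For example, `encode_header("alice", &["bob", "carol"], 1024)` produces a buffer whose first four bytes hold its own total length, and passing that buffer to `decode_header` recovers the sender, the recipients, and the body field. Unlike the crate's `deserialize`, this sketch grows the recipient vector on demand rather than preallocating from the untrusted count.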