123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561 |
- use enum_dispatch::enum_dispatch;
- use mgen::{log, SerializedMessage};
- use rand_distr::{
- Bernoulli, BernoulliError, Distribution, Exp, ExpError, LogNormal, Normal, NormalError, Pareto,
- ParetoError, Uniform,
- };
- use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256PlusPlus};
- use serde::Deserialize;
- use std::env;
- use std::num::NonZeroU32;
- use std::result::Result;
- use tokio::io::{AsyncReadExt, AsyncWriteExt};
- use tokio::net::TcpStream;
- use tokio::task;
- use tokio::time::{Duration, Instant};
- #[derive(Debug)]
- enum ClientError {
- // errors from the library
- Mgen(mgen::Error),
- // errors from parsing the conversation files
- Parameter(DistParameterError),
- // errors from the socks connection
- Socks(tokio_socks::Error),
- // general I/O errors in this file
- Io(std::io::Error),
- }
- impl std::fmt::Display for ClientError {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{:?}", self)
- }
- }
- impl std::error::Error for ClientError {}
- impl From<mgen::Error> for ClientError {
- fn from(e: mgen::Error) -> Self {
- Self::Mgen(e)
- }
- }
- impl From<DistParameterError> for ClientError {
- fn from(e: DistParameterError) -> Self {
- Self::Parameter(e)
- }
- }
- impl From<tokio_socks::Error> for ClientError {
- fn from(e: tokio_socks::Error) -> Self {
- Self::Socks(e)
- }
- }
- impl From<std::io::Error> for ClientError {
- fn from(e: std::io::Error) -> Self {
- Self::Io(e)
- }
- }
- /// All possible Conversation state machine states
- enum StateMachine {
- Idle(Conversation<Idle>),
- Active(Conversation<Active>),
- }
- /// The state machine representing a conversation state and its transitions.
- struct Conversation<S: State> {
- dists: Distributions,
- delay: Instant,
- state: S,
- }
- #[derive(Debug)]
- #[enum_dispatch(Distribution)]
- /// The set of Distributions we currently support.
- /// To modify the code to add support for more, one approach is to first add them here,
- /// then fix all the compiler errors that arise as a result.
- enum SupportedDistribution {
- Normal(Normal<f64>),
- LogNormal(LogNormal<f64>),
- Uniform(Uniform<f64>),
- Exp(Exp<f64>),
- Pareto(Pareto<f64>),
- }
- /// The set of distributions necessary to represent the actions of the state machine.
- #[derive(Debug)]
- struct Distributions {
- i: SupportedDistribution,
- w: SupportedDistribution,
- a_s: SupportedDistribution,
- a_r: SupportedDistribution,
- s: Bernoulli,
- r: Bernoulli,
- }
- trait State {}
- struct Idle {}
- struct Active {
- wait: Instant,
- }
- impl State for Idle {}
- impl State for Active {}
- impl Conversation<Idle> {
- fn start(dists: Distributions, rng: &mut Xoshiro256PlusPlus) -> Self {
- let delay = Instant::now() + dists.i.sample_secs(rng);
- log!("[start]");
- Self {
- dists,
- delay,
- state: Idle {},
- }
- }
- fn sent(self, rng: &mut Xoshiro256PlusPlus) -> StateMachine {
- if self.dists.s.sample(rng) {
- log!("Idle: [sent] tranisition to [Active]");
- let delay = Instant::now() + self.dists.a_s.sample_secs(rng);
- let wait = Instant::now() + self.dists.w.sample_secs(rng);
- StateMachine::Active({
- Conversation::<Active> {
- dists: self.dists,
- delay,
- state: Active { wait },
- }
- })
- } else {
- log!("Idle: [sent] tranisition to [Idle]");
- let delay = Instant::now() + self.dists.i.sample_secs(rng);
- StateMachine::Idle({
- Conversation::<Idle> {
- dists: self.dists,
- delay,
- state: Idle {},
- }
- })
- }
- }
- fn received(self, rng: &mut Xoshiro256PlusPlus) -> StateMachine {
- if self.dists.r.sample(rng) {
- log!("Idle: [recv'd] tranisition to [Active]");
- let wait = Instant::now() + self.dists.w.sample_secs(rng);
- let delay = Instant::now() + self.dists.a_r.sample_secs(rng);
- StateMachine::Active({
- Conversation::<Active> {
- dists: self.dists,
- delay,
- state: Active { wait },
- }
- })
- } else {
- log!("Idle: [recv'd] tranisition to [Idle]");
- StateMachine::Idle(self)
- }
- }
- }
- impl Conversation<Active> {
- fn sent(self, rng: &mut Xoshiro256PlusPlus) -> Conversation<Active> {
- log!("Active: [sent] transition to [Active]");
- let delay = Instant::now() + self.dists.a_s.sample_secs(rng);
- Conversation::<Active> {
- dists: self.dists,
- delay,
- state: self.state,
- }
- }
- fn received(self, rng: &mut Xoshiro256PlusPlus) -> Conversation<Active> {
- log!("Active: [recv'd] transition to [Active]");
- let delay = Instant::now() + self.dists.a_r.sample_secs(rng);
- Conversation::<Active> {
- dists: self.dists,
- delay,
- state: self.state,
- }
- }
- fn waited(self, rng: &mut Xoshiro256PlusPlus) -> Conversation<Idle> {
- log!("Active: [waited] tranision to [Idle]");
- let delay = Instant::now() + self.dists.i.sample_secs(rng);
- Conversation::<Idle> {
- dists: self.dists,
- delay,
- state: Idle {},
- }
- }
- async fn sleep(delay: Instant, wait: Instant) -> ActiveActions {
- if delay < wait {
- log!("delaying for {:?}", delay - Instant::now());
- tokio::time::sleep_until(delay).await;
- ActiveActions::Send
- } else {
- log!("waiting for {:?}", wait - Instant::now());
- tokio::time::sleep_until(wait).await;
- ActiveActions::Idle
- }
- }
- }
- /// Attempt to read some portion of the size of the reast of the header from the stream.
- /// The number of bytes written is returned in the Ok case.
- /// The caller must read any remaining bytes less than 4.
- // N.B.: This must be written cancellation safe!
- // https://docs.rs/tokio/1.26.0/tokio/macro.select.html#cancellation-safety
- async fn read_header_size(
- stream: &mut TcpStream,
- header_size: &mut [u8; 4],
- ) -> Result<usize, ClientError> {
- let read = stream.read(header_size).await?;
- if read == 0 {
- Err(tokio::io::Error::new(
- tokio::io::ErrorKind::WriteZero,
- "failed to read any bytes from message with bytes remaining",
- )
- .into())
- } else {
- Ok(read)
- }
- }
/// What an Idle conversation should do next, as decided by the `select!` in
/// `manage_idle_conversation`.
enum IdleActions {
    /// The send time was reached: send a message.
    Send,
    /// Data arrived: the payload is how many bytes of the 4-byte size prefix
    /// have been read so far.
    Receive(usize),
}
- async fn manage_idle_conversation(
- conversation: Conversation<Idle>,
- stream: &mut TcpStream,
- our_id: &str,
- recipients: Vec<&str>,
- rng: &mut Xoshiro256PlusPlus,
- ) -> Result<StateMachine, ClientError> {
- log!("delaying for {:?}", conversation.delay - Instant::now());
- let mut header_size = [0; 4];
- let action = tokio::select! {
- () = tokio::time::sleep_until(conversation.delay) => {
- Ok(IdleActions::Send)
- }
- res = read_header_size(stream, &mut header_size) => {
- match res {
- Ok(n) => Ok(IdleActions::Receive(n)),
- Err(e) => Err(e),
- }
- }
- }?;
- match action {
- IdleActions::Send => {
- log!("sending message from {} to {:?}", our_id, recipients);
- let m = construct_message(
- our_id.to_string(),
- recipients.iter().map(|s| s.to_string()).collect(),
- );
- m.write_all_to(stream).await?;
- stream.flush().await?;
- Ok(conversation.sent(rng))
- }
- IdleActions::Receive(n) => {
- if n < 4 {
- // we didn't get the whole size, but we can use read_exact now
- stream.read_exact(&mut header_size[n..]).await?;
- }
- let (msg, _) = mgen::get_message_with_header_size(stream, header_size).await?;
- if msg.body != mgen::MessageBody::Receipt {
- log!("{:?} got message from {}", msg.recipients, msg.sender);
- let m = construct_receipt(our_id.to_string(), msg.sender);
- m.write_all_to(stream).await?;
- stream.flush().await?;
- Ok(conversation.received(rng))
- } else {
- Ok(StateMachine::Idle(conversation))
- }
- }
- }
- }
/// What an Active conversation should do next, as decided by the `select!` in
/// `manage_active_conversation`.
enum ActiveActions {
    /// The send time was reached: send a message.
    Send,
    /// Data arrived: the payload is how many bytes of the 4-byte size prefix
    /// have been read so far.
    Receive(usize),
    /// The Active timeout elapsed: go back to Idle.
    Idle,
}
- async fn manage_active_conversation(
- conversation: Conversation<Active>,
- stream: &mut TcpStream,
- our_id: &str,
- recipients: Vec<&str>,
- rng: &mut Xoshiro256PlusPlus,
- ) -> Result<StateMachine, ClientError> {
- let mut header_size = [0; 4];
- let action = tokio::select! {
- action = Conversation::<Active>::sleep(conversation.delay, conversation.state.wait) => {
- Ok(action)
- }
- res = read_header_size(stream, &mut header_size) => {
- match res {
- Ok(n) => Ok(ActiveActions::Receive(n)),
- Err(e) => Err(e),
- }
- }
- }?;
- match action {
- ActiveActions::Send => {
- log!("sending message from {} to {:?}", our_id, recipients);
- let m = construct_message(
- our_id.to_string(),
- recipients.iter().map(|s| s.to_string()).collect(),
- );
- m.write_all_to(stream).await?;
- stream.flush().await?;
- Ok(StateMachine::Active(conversation.sent(rng)))
- }
- ActiveActions::Receive(n) => {
- if n < 4 {
- // we didn't get the whole size, but we can use read_exact now
- stream.read_exact(&mut header_size[n..]).await?;
- }
- let (msg, _) = mgen::get_message_with_header_size(stream, header_size).await?;
- if msg.body != mgen::MessageBody::Receipt {
- log!("{:?} got message from {}", msg.recipients, msg.sender);
- let m = construct_receipt(our_id.to_string(), msg.sender);
- m.write_all_to(stream).await?;
- stream.flush().await?;
- Ok(StateMachine::Active(conversation.received(rng)))
- } else {
- Ok(StateMachine::Active(conversation))
- }
- }
- ActiveActions::Idle => Ok(StateMachine::Idle(conversation.waited(rng))),
- }
- }
- async fn manage_conversation(config: Config) -> Result<(), ClientError> {
- let mut rng = Xoshiro256PlusPlus::from_entropy();
- let distributions: Distributions = config.distributions.try_into()?;
- let mut state_machine =
- StateMachine::Idle(Conversation::<Idle>::start(distributions, &mut rng));
- let recipients: Vec<&str> = config.recipients.iter().map(String::as_str).collect();
- let mut stream = tokio_socks::tcp::Socks5Stream::connect_with_password(
- config.socks.as_str(),
- config.server.as_str(),
- &config.sender,
- &config.group,
- )
- .await?;
- stream
- .write_all(&mgen::serialize_str(&config.sender))
- .await?;
- tokio::time::sleep(Duration::from_secs(5)).await;
- loop {
- state_machine = match state_machine {
- StateMachine::Idle(conversation) => {
- manage_idle_conversation(
- conversation,
- &mut stream,
- &config.sender,
- recipients.clone(),
- &mut rng,
- )
- .await?
- }
- StateMachine::Active(conversation) => {
- manage_active_conversation(
- conversation,
- &mut stream,
- &config.sender,
- recipients.clone(),
- &mut rng,
- )
- .await?
- }
- };
- }
- }
- /// A wrapper for the Distribution trait that specifies the RNG to allow (fake) dynamic dispatch.
- #[enum_dispatch(SupportedDistribution)]
- trait Dist {
- fn sample(&self, rng: &mut Xoshiro256PlusPlus) -> f64;
- }
- /*
- // This would be easier, but we run into https://github.com/rust-lang/rust/issues/48869
- impl<T, D> Dist<T> for D
- where
- D: Distribution<T> + Send + Sync,
- {
- fn sample(&self, rng: &mut Xoshiro256PlusPlus) -> T {
- self.sample(rng)
- }
- }
- */
- macro_rules! dist_impl {
- ($dist:ident) => {
- impl Dist for $dist<f64> {
- fn sample(&self, rng: &mut Xoshiro256PlusPlus) -> f64 {
- Distribution::sample(self, rng)
- }
- }
- };
- }
- dist_impl!(Exp);
- dist_impl!(Normal);
- dist_impl!(LogNormal);
- dist_impl!(Pareto);
- dist_impl!(Uniform);
- impl SupportedDistribution {
- // FIXME: there's probably a better way to do this integrated with the crate
- fn clamped_sample(&self, rng: &mut Xoshiro256PlusPlus) -> f64 {
- let sample = self.sample(rng);
- if sample >= 0.0 {
- sample
- } else {
- 0.0
- }
- }
- fn sample_secs(&self, rng: &mut Xoshiro256PlusPlus) -> Duration {
- Duration::from_secs_f64(self.clamped_sample(rng))
- }
- }
- fn construct_message(sender: String, recipients: Vec<String>) -> SerializedMessage {
- // FIXME: sample size from distribution
- let m = mgen::MessageHeader {
- sender,
- recipients,
- body: mgen::MessageBody::Size(NonZeroU32::new(1024).unwrap()),
- };
- m.serialize()
- }
- fn construct_receipt(sender: String, recipient: String) -> SerializedMessage {
- let m = mgen::MessageHeader {
- sender,
- recipients: vec![recipient],
- body: mgen::MessageBody::Receipt,
- };
- m.serialize()
- }
- /// The same as Distributions, but designed for easier deserialization.
- #[derive(Debug, Deserialize)]
- struct ConfigDistributions {
- i: ConfigSupportedDistribution,
- w: ConfigSupportedDistribution,
- a_s: ConfigSupportedDistribution,
- a_r: ConfigSupportedDistribution,
- s: f64,
- r: f64,
- }
- /// The same as SupportedDistributions, but designed for easier deserialization.
- #[derive(Debug, Deserialize)]
- #[serde(tag = "distribution")]
- enum ConfigSupportedDistribution {
- Normal { mean: f64, std_dev: f64 },
- LogNormal { mean: f64, std_dev: f64 },
- Uniform { low: f64, high: f64 },
- Exp { lambda: f64 },
- Pareto { scale: f64, shape: f64 },
- }
- #[derive(Debug)]
- enum DistParameterError {
- Bernoulli(BernoulliError),
- Normal(NormalError),
- LogNormal(NormalError),
- Uniform, // Uniform::new doesn't return an error, it just panics
- Exp(ExpError),
- Pareto(ParetoError),
- }
- impl TryFrom<ConfigSupportedDistribution> for SupportedDistribution {
- type Error = DistParameterError;
- fn try_from(dist: ConfigSupportedDistribution) -> Result<Self, DistParameterError> {
- let dist = match dist {
- ConfigSupportedDistribution::Normal { mean, std_dev } => SupportedDistribution::Normal(
- Normal::new(mean, std_dev).map_err(DistParameterError::Normal)?,
- ),
- ConfigSupportedDistribution::LogNormal { mean, std_dev } => {
- SupportedDistribution::LogNormal(
- LogNormal::new(mean, std_dev).map_err(DistParameterError::LogNormal)?,
- )
- }
- ConfigSupportedDistribution::Uniform { low, high } => {
- if low >= high {
- return Err(DistParameterError::Uniform);
- }
- SupportedDistribution::Uniform(Uniform::new(low, high))
- }
- ConfigSupportedDistribution::Exp { lambda } => {
- SupportedDistribution::Exp(Exp::new(lambda).map_err(DistParameterError::Exp)?)
- }
- ConfigSupportedDistribution::Pareto { scale, shape } => SupportedDistribution::Pareto(
- Pareto::new(scale, shape).map_err(DistParameterError::Pareto)?,
- ),
- };
- Ok(dist)
- }
- }
- impl TryFrom<ConfigDistributions> for Distributions {
- type Error = DistParameterError;
- fn try_from(config: ConfigDistributions) -> Result<Self, DistParameterError> {
- Ok(Distributions {
- i: config.i.try_into()?,
- w: config.w.try_into()?,
- a_s: config.a_s.try_into()?,
- a_r: config.a_r.try_into()?,
- s: Bernoulli::new(config.s).map_err(DistParameterError::Bernoulli)?,
- r: Bernoulli::new(config.r).map_err(DistParameterError::Bernoulli)?,
- })
- }
- }
- #[derive(Debug, Deserialize)]
- struct Config {
- sender: String,
- group: String,
- recipients: Vec<String>,
- socks: String,
- server: String,
- distributions: ConfigDistributions,
- }
- #[tokio::main]
- async fn main() -> Result<(), Box<dyn std::error::Error>> {
- let mut args = env::args();
- let _ = args.next();
- let mut handles = vec![];
- for config_file in args {
- let toml_s = std::fs::read_to_string(config_file)?;
- let config = toml::from_str(&toml_s)?;
- let handle: task::JoinHandle<Result<(), ClientError>> =
- tokio::spawn(manage_conversation(config));
- handles.push(handle);
- }
- for handle in handles {
- handle.await??;
- }
- Ok(())
- }
|