client.rs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. use enum_dispatch::enum_dispatch;
  2. use mgen::{log, SerializedMessage};
  3. use rand_distr::{
  4. Bernoulli, BernoulliError, Distribution, Exp, ExpError, LogNormal, Normal, NormalError, Pareto,
  5. ParetoError, Uniform,
  6. };
  7. use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256PlusPlus};
  8. use serde::Deserialize;
  9. use std::env;
  10. use std::num::NonZeroU32;
  11. use std::result::Result;
  12. use tokio::io::{AsyncReadExt, AsyncWriteExt};
  13. use tokio::net::TcpStream;
  14. use tokio::task;
  15. use tokio::time::{Duration, Instant};
  /// Every error that can end a conversation task in this client.
  #[derive(Debug)]
  enum ClientError {
      /// Errors from the mgen library.
      Mgen(mgen::Error),
      /// Errors from parsing the conversation files (distribution parameters that
      /// could not be turned into a distribution).
      Parameter(DistParameterError),
      /// Errors from the SOCKS connection.
      Socks(tokio_socks::Error),
      /// General I/O errors in this file.
      Io(std::io::Error),
  }
  27. impl std::fmt::Display for ClientError {
  28. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  29. write!(f, "{:?}", self)
  30. }
  31. }
  32. impl std::error::Error for ClientError {}
  33. impl From<mgen::Error> for ClientError {
  34. fn from(e: mgen::Error) -> Self {
  35. Self::Mgen(e)
  36. }
  37. }
  38. impl From<DistParameterError> for ClientError {
  39. fn from(e: DistParameterError) -> Self {
  40. Self::Parameter(e)
  41. }
  42. }
  43. impl From<tokio_socks::Error> for ClientError {
  44. fn from(e: tokio_socks::Error) -> Self {
  45. Self::Socks(e)
  46. }
  47. }
  48. impl From<std::io::Error> for ClientError {
  49. fn from(e: std::io::Error) -> Self {
  50. Self::Io(e)
  51. }
  52. }
  /// All possible Conversation state machine states.
  ///
  /// Wraps the state-parameterized `Conversation` so the driver loop can hold a
  /// conversation in either state in a single variable.
  enum StateMachine {
      /// The conversation is idle (see `Conversation<Idle>`).
      Idle(Conversation<Idle>),
      /// The conversation is active (see `Conversation<Active>`).
      Active(Conversation<Active>),
  }
  /// The state machine representing a conversation state and its transitions.
  ///
  /// The type parameter `S` records which state (`Idle` or `Active`) the conversation
  /// is in, so only that state's transition methods can be called on it.
  struct Conversation<S: State> {
      /// Distributions used to sample delays and transition decisions.
      dists: Distributions,
      /// When the next message is due to be sent if nothing else happens first.
      delay: Instant,
      /// State-specific data (the `wait` deadline while active, nothing while idle).
      state: S,
  }
  #[derive(Debug)]
  #[enum_dispatch(Distribution)]
  /// The set of Distributions we currently support.
  /// To modify the code to add support for more, one approach is to first add them here,
  /// then fix all the compiler errors that arise as a result.
  ///
  /// Each variant wraps an `f64` distribution from `rand_distr`; callers sample it
  /// through the `Dist` trait (see `clamped_sample`).
  // NOTE(review): the attribute argument above names `Distribution` (rand_distr's trait),
  // not this file's `Dist`; the link that makes `self.sample(rng)` work appears to come
  // from `#[enum_dispatch(SupportedDistribution)]` on `Dist` — confirm which is intended.
  enum SupportedDistribution {
      Normal(Normal<f64>),
      LogNormal(LogNormal<f64>),
      Uniform(Uniform<f64>),
      Exp(Exp<f64>),
      Pareto(Pareto<f64>),
  }
  /// The set of distributions necessary to represent the actions of the state machine.
  #[derive(Debug)]
  struct Distributions {
      /// Idle delay: time until the next send while idle.
      i: SupportedDistribution,
      /// Wait: time, measured from becoming active, after which the conversation
      /// returns to idle (unless a send is due first).
      w: SupportedDistribution,
      /// Active delay until the next send after we sent a message.
      a_s: SupportedDistribution,
      /// Active delay until the next send after we received a message.
      a_r: SupportedDistribution,
      /// Whether sending a message while idle moves the conversation to active.
      s: Bernoulli,
      /// Whether receiving a message while idle moves the conversation to active.
      r: Bernoulli,
  }
  /// Marker trait for the states a `Conversation` can be in.
  trait State {}
  /// The idle state; carries no extra data.
  struct Idle {}
  /// The active state.
  struct Active {
      /// Deadline after which the conversation returns to idle, if no send is due first.
      wait: Instant,
  }
  impl State for Idle {}
  impl State for Active {}
  93. impl Conversation<Idle> {
  94. fn start(dists: Distributions, rng: &mut Xoshiro256PlusPlus) -> Self {
  95. let delay = Instant::now() + dists.i.sample_secs(rng);
  96. log!("[start]");
  97. Self {
  98. dists,
  99. delay,
  100. state: Idle {},
  101. }
  102. }
  103. fn sent(self, rng: &mut Xoshiro256PlusPlus) -> StateMachine {
  104. if self.dists.s.sample(rng) {
  105. log!("Idle: [sent] tranisition to [Active]");
  106. let delay = Instant::now() + self.dists.a_s.sample_secs(rng);
  107. let wait = Instant::now() + self.dists.w.sample_secs(rng);
  108. StateMachine::Active({
  109. Conversation::<Active> {
  110. dists: self.dists,
  111. delay,
  112. state: Active { wait },
  113. }
  114. })
  115. } else {
  116. log!("Idle: [sent] tranisition to [Idle]");
  117. let delay = Instant::now() + self.dists.i.sample_secs(rng);
  118. StateMachine::Idle({
  119. Conversation::<Idle> {
  120. dists: self.dists,
  121. delay,
  122. state: Idle {},
  123. }
  124. })
  125. }
  126. }
  127. fn received(self, rng: &mut Xoshiro256PlusPlus) -> StateMachine {
  128. if self.dists.r.sample(rng) {
  129. log!("Idle: [recv'd] tranisition to [Active]");
  130. let wait = Instant::now() + self.dists.w.sample_secs(rng);
  131. let delay = Instant::now() + self.dists.a_r.sample_secs(rng);
  132. StateMachine::Active({
  133. Conversation::<Active> {
  134. dists: self.dists,
  135. delay,
  136. state: Active { wait },
  137. }
  138. })
  139. } else {
  140. log!("Idle: [recv'd] tranisition to [Idle]");
  141. StateMachine::Idle(self)
  142. }
  143. }
  144. }
  145. impl Conversation<Active> {
  146. fn sent(self, rng: &mut Xoshiro256PlusPlus) -> Conversation<Active> {
  147. log!("Active: [sent] transition to [Active]");
  148. let delay = Instant::now() + self.dists.a_s.sample_secs(rng);
  149. Conversation::<Active> {
  150. dists: self.dists,
  151. delay,
  152. state: self.state,
  153. }
  154. }
  155. fn received(self, rng: &mut Xoshiro256PlusPlus) -> Conversation<Active> {
  156. log!("Active: [recv'd] transition to [Active]");
  157. let delay = Instant::now() + self.dists.a_r.sample_secs(rng);
  158. Conversation::<Active> {
  159. dists: self.dists,
  160. delay,
  161. state: self.state,
  162. }
  163. }
  164. fn waited(self, rng: &mut Xoshiro256PlusPlus) -> Conversation<Idle> {
  165. log!("Active: [waited] tranision to [Idle]");
  166. let delay = Instant::now() + self.dists.i.sample_secs(rng);
  167. Conversation::<Idle> {
  168. dists: self.dists,
  169. delay,
  170. state: Idle {},
  171. }
  172. }
  173. async fn sleep(delay: Instant, wait: Instant) -> ActiveActions {
  174. if delay < wait {
  175. log!("delaying for {:?}", delay - Instant::now());
  176. tokio::time::sleep_until(delay).await;
  177. ActiveActions::Send
  178. } else {
  179. log!("waiting for {:?}", wait - Instant::now());
  180. tokio::time::sleep_until(wait).await;
  181. ActiveActions::Idle
  182. }
  183. }
  184. }
  185. /// Attempt to read some portion of the size of the reast of the header from the stream.
  186. /// The number of bytes written is returned in the Ok case.
  187. /// The caller must read any remaining bytes less than 4.
  188. // N.B.: This must be written cancellation safe!
  189. // https://docs.rs/tokio/1.26.0/tokio/macro.select.html#cancellation-safety
  190. async fn read_header_size(
  191. stream: &mut TcpStream,
  192. header_size: &mut [u8; 4],
  193. ) -> Result<usize, ClientError> {
  194. let read = stream.read(header_size).await?;
  195. if read == 0 {
  196. Err(tokio::io::Error::new(
  197. tokio::io::ErrorKind::WriteZero,
  198. "failed to read any bytes from message with bytes remaining",
  199. )
  200. .into())
  201. } else {
  202. Ok(read)
  203. }
  204. }
  205. enum IdleActions {
  206. Send,
  207. Receive(usize),
  208. }
  209. async fn manage_idle_conversation(
  210. conversation: Conversation<Idle>,
  211. stream: &mut TcpStream,
  212. our_id: &str,
  213. recipients: Vec<&str>,
  214. rng: &mut Xoshiro256PlusPlus,
  215. ) -> Result<StateMachine, ClientError> {
  216. log!("delaying for {:?}", conversation.delay - Instant::now());
  217. let mut header_size = [0; 4];
  218. let action = tokio::select! {
  219. () = tokio::time::sleep_until(conversation.delay) => {
  220. Ok(IdleActions::Send)
  221. }
  222. res = read_header_size(stream, &mut header_size) => {
  223. match res {
  224. Ok(n) => Ok(IdleActions::Receive(n)),
  225. Err(e) => Err(e),
  226. }
  227. }
  228. }?;
  229. match action {
  230. IdleActions::Send => {
  231. log!("sending message from {} to {:?}", our_id, recipients);
  232. let m = construct_message(
  233. our_id.to_string(),
  234. recipients.iter().map(|s| s.to_string()).collect(),
  235. );
  236. m.write_all_to(stream).await?;
  237. stream.flush().await?;
  238. Ok(conversation.sent(rng))
  239. }
  240. IdleActions::Receive(n) => {
  241. if n < 4 {
  242. // we didn't get the whole size, but we can use read_exact now
  243. stream.read_exact(&mut header_size[n..]).await?;
  244. }
  245. let (msg, _) = mgen::get_message_with_header_size(stream, header_size).await?;
  246. if msg.body != mgen::MessageBody::Receipt {
  247. log!("{:?} got message from {}", msg.recipients, msg.sender);
  248. let m = construct_receipt(our_id.to_string(), msg.sender);
  249. m.write_all_to(stream).await?;
  250. stream.flush().await?;
  251. Ok(conversation.received(rng))
  252. } else {
  253. Ok(StateMachine::Idle(conversation))
  254. }
  255. }
  256. }
  257. }
  258. enum ActiveActions {
  259. Send,
  260. Receive(usize),
  261. Idle,
  262. }
  263. async fn manage_active_conversation(
  264. conversation: Conversation<Active>,
  265. stream: &mut TcpStream,
  266. our_id: &str,
  267. recipients: Vec<&str>,
  268. rng: &mut Xoshiro256PlusPlus,
  269. ) -> Result<StateMachine, ClientError> {
  270. let mut header_size = [0; 4];
  271. let action = tokio::select! {
  272. action = Conversation::<Active>::sleep(conversation.delay, conversation.state.wait) => {
  273. Ok(action)
  274. }
  275. res = read_header_size(stream, &mut header_size) => {
  276. match res {
  277. Ok(n) => Ok(ActiveActions::Receive(n)),
  278. Err(e) => Err(e),
  279. }
  280. }
  281. }?;
  282. match action {
  283. ActiveActions::Send => {
  284. log!("sending message from {} to {:?}", our_id, recipients);
  285. let m = construct_message(
  286. our_id.to_string(),
  287. recipients.iter().map(|s| s.to_string()).collect(),
  288. );
  289. m.write_all_to(stream).await?;
  290. stream.flush().await?;
  291. Ok(StateMachine::Active(conversation.sent(rng)))
  292. }
  293. ActiveActions::Receive(n) => {
  294. if n < 4 {
  295. // we didn't get the whole size, but we can use read_exact now
  296. stream.read_exact(&mut header_size[n..]).await?;
  297. }
  298. let (msg, _) = mgen::get_message_with_header_size(stream, header_size).await?;
  299. if msg.body != mgen::MessageBody::Receipt {
  300. log!("{:?} got message from {}", msg.recipients, msg.sender);
  301. let m = construct_receipt(our_id.to_string(), msg.sender);
  302. m.write_all_to(stream).await?;
  303. stream.flush().await?;
  304. Ok(StateMachine::Active(conversation.received(rng)))
  305. } else {
  306. Ok(StateMachine::Active(conversation))
  307. }
  308. }
  309. ActiveActions::Idle => Ok(StateMachine::Idle(conversation.waited(rng))),
  310. }
  311. }
  312. async fn manage_conversation(config: Config) -> Result<(), ClientError> {
  313. let mut rng = Xoshiro256PlusPlus::from_entropy();
  314. let distributions: Distributions = config.distributions.try_into()?;
  315. let mut state_machine =
  316. StateMachine::Idle(Conversation::<Idle>::start(distributions, &mut rng));
  317. let recipients: Vec<&str> = config.recipients.iter().map(String::as_str).collect();
  318. let mut stream = tokio_socks::tcp::Socks5Stream::connect_with_password(
  319. config.socks.as_str(),
  320. config.server.as_str(),
  321. &config.sender,
  322. &config.group,
  323. )
  324. .await?;
  325. stream
  326. .write_all(&mgen::serialize_str(&config.sender))
  327. .await?;
  328. tokio::time::sleep(Duration::from_secs(5)).await;
  329. loop {
  330. state_machine = match state_machine {
  331. StateMachine::Idle(conversation) => {
  332. manage_idle_conversation(
  333. conversation,
  334. &mut stream,
  335. &config.sender,
  336. recipients.clone(),
  337. &mut rng,
  338. )
  339. .await?
  340. }
  341. StateMachine::Active(conversation) => {
  342. manage_active_conversation(
  343. conversation,
  344. &mut stream,
  345. &config.sender,
  346. recipients.clone(),
  347. &mut rng,
  348. )
  349. .await?
  350. }
  351. };
  352. }
  353. }
  /// A wrapper for the Distribution trait that specifies the RNG to allow (fake) dynamic dispatch.
  ///
  /// The attribute links this trait to `SupportedDistribution`, so calling `sample` on
  /// that enum forwards to whichever distribution it wraps. The per-distribution
  /// impls are generated by the `dist_impl!` macro.
  #[enum_dispatch(SupportedDistribution)]
  trait Dist {
      /// Draw one `f64` sample using the given RNG.
      fn sample(&self, rng: &mut Xoshiro256PlusPlus) -> f64;
  }
  /*
  // This would be easier, but we run into https://github.com/rust-lang/rust/issues/48869
  impl<T, D> Dist<T> for D
  where
  D: Distribution<T> + Send + Sync,
  {
  fn sample(&self, rng: &mut Xoshiro256PlusPlus) -> T {
  self.sample(rng)
  }
  }
  */
  370. macro_rules! dist_impl {
  371. ($dist:ident) => {
  372. impl Dist for $dist<f64> {
  373. fn sample(&self, rng: &mut Xoshiro256PlusPlus) -> f64 {
  374. Distribution::sample(self, rng)
  375. }
  376. }
  377. };
  378. }
  379. dist_impl!(Exp);
  380. dist_impl!(Normal);
  381. dist_impl!(LogNormal);
  382. dist_impl!(Pareto);
  383. dist_impl!(Uniform);
  384. impl SupportedDistribution {
  385. // FIXME: there's probably a better way to do this integrated with the crate
  386. fn clamped_sample(&self, rng: &mut Xoshiro256PlusPlus) -> f64 {
  387. let sample = self.sample(rng);
  388. if sample >= 0.0 {
  389. sample
  390. } else {
  391. 0.0
  392. }
  393. }
  394. fn sample_secs(&self, rng: &mut Xoshiro256PlusPlus) -> Duration {
  395. Duration::from_secs_f64(self.clamped_sample(rng))
  396. }
  397. }
  398. fn construct_message(sender: String, recipients: Vec<String>) -> SerializedMessage {
  399. // FIXME: sample size from distribution
  400. let m = mgen::MessageHeader {
  401. sender,
  402. recipients,
  403. body: mgen::MessageBody::Size(NonZeroU32::new(1024).unwrap()),
  404. };
  405. m.serialize()
  406. }
  407. fn construct_receipt(sender: String, recipient: String) -> SerializedMessage {
  408. let m = mgen::MessageHeader {
  409. sender,
  410. recipients: vec![recipient],
  411. body: mgen::MessageBody::Receipt,
  412. };
  413. m.serialize()
  414. }
  /// The same as Distributions, but designed for easier deserialization.
  ///
  /// Converted into `Distributions`, validating every parameter, via `TryFrom`.
  #[derive(Debug, Deserialize)]
  struct ConfigDistributions {
      /// Idle delay distribution.
      i: ConfigSupportedDistribution,
      /// Wait (active duration) distribution.
      w: ConfigSupportedDistribution,
      /// Active delay after sending.
      a_s: ConfigSupportedDistribution,
      /// Active delay after receiving.
      a_r: ConfigSupportedDistribution,
      /// Probability that sending while idle moves to active (validated by `Bernoulli::new`).
      s: f64,
      /// Probability that receiving while idle moves to active (validated by `Bernoulli::new`).
      r: f64,
  }
  /// The same as SupportedDistributions, but designed for easier deserialization.
  ///
  /// Internally tagged: the `distribution` key names the variant and the remaining
  /// keys are its parameters, e.g. `{ distribution = "Exp", lambda = 1.0 }`.
  #[derive(Debug, Deserialize)]
  #[serde(tag = "distribution")]
  enum ConfigSupportedDistribution {
      Normal { mean: f64, std_dev: f64 },
      LogNormal { mean: f64, std_dev: f64 },
      Uniform { low: f64, high: f64 },
      Exp { lambda: f64 },
      Pareto { scale: f64, shape: f64 },
  }
  /// Why a configured distribution (or transition probability) could not be constructed.
  #[derive(Debug)]
  enum DistParameterError {
      /// An `s` or `r` probability was rejected by `Bernoulli::new`.
      Bernoulli(BernoulliError),
      /// Parameters rejected by `Normal::new`.
      Normal(NormalError),
      /// Parameters rejected by `LogNormal::new`.
      LogNormal(NormalError),
      Uniform, // Uniform::new doesn't return an error, it just panics
      /// Parameters rejected by `Exp::new`.
      Exp(ExpError),
      /// Parameters rejected by `Pareto::new`.
      Pareto(ParetoError),
  }
  444. impl TryFrom<ConfigSupportedDistribution> for SupportedDistribution {
  445. type Error = DistParameterError;
  446. fn try_from(dist: ConfigSupportedDistribution) -> Result<Self, DistParameterError> {
  447. let dist = match dist {
  448. ConfigSupportedDistribution::Normal { mean, std_dev } => SupportedDistribution::Normal(
  449. Normal::new(mean, std_dev).map_err(DistParameterError::Normal)?,
  450. ),
  451. ConfigSupportedDistribution::LogNormal { mean, std_dev } => {
  452. SupportedDistribution::LogNormal(
  453. LogNormal::new(mean, std_dev).map_err(DistParameterError::LogNormal)?,
  454. )
  455. }
  456. ConfigSupportedDistribution::Uniform { low, high } => {
  457. if low >= high {
  458. return Err(DistParameterError::Uniform);
  459. }
  460. SupportedDistribution::Uniform(Uniform::new(low, high))
  461. }
  462. ConfigSupportedDistribution::Exp { lambda } => {
  463. SupportedDistribution::Exp(Exp::new(lambda).map_err(DistParameterError::Exp)?)
  464. }
  465. ConfigSupportedDistribution::Pareto { scale, shape } => SupportedDistribution::Pareto(
  466. Pareto::new(scale, shape).map_err(DistParameterError::Pareto)?,
  467. ),
  468. };
  469. Ok(dist)
  470. }
  471. }
  472. impl TryFrom<ConfigDistributions> for Distributions {
  473. type Error = DistParameterError;
  474. fn try_from(config: ConfigDistributions) -> Result<Self, DistParameterError> {
  475. Ok(Distributions {
  476. i: config.i.try_into()?,
  477. w: config.w.try_into()?,
  478. a_s: config.a_s.try_into()?,
  479. a_r: config.a_r.try_into()?,
  480. s: Bernoulli::new(config.s).map_err(DistParameterError::Bernoulli)?,
  481. r: Bernoulli::new(config.r).map_err(DistParameterError::Bernoulli)?,
  482. })
  483. }
  484. }
  /// One conversation's configuration, read from a TOML file named on the command line.
  #[derive(Debug, Deserialize)]
  struct Config {
      /// Our name: the sender of our messages and receipts, sent to the server right
      /// after connecting, and used as the SOCKS5 username.
      sender: String,
      /// Used as the SOCKS5 password.
      group: String,
      /// Every message we send is addressed to all of these.
      recipients: Vec<String>,
      /// Address of the SOCKS5 proxy.
      socks: String,
      /// Address of the message server, reached through the proxy.
      server: String,
      /// Timing and transition distributions for the state machine.
      distributions: ConfigDistributions,
  }
  494. #[tokio::main]
  495. async fn main() -> Result<(), Box<dyn std::error::Error>> {
  496. let mut args = env::args();
  497. let _ = args.next();
  498. let mut handles = vec![];
  499. for config_file in args {
  500. let toml_s = std::fs::read_to_string(config_file)?;
  501. let config = toml::from_str(&toml_s)?;
  502. let handle: task::JoinHandle<Result<(), ClientError>> =
  503. tokio::spawn(manage_conversation(config));
  504. handles.push(handle);
  505. }
  506. for handle in handles {
  507. handle.await??;
  508. }
  509. Ok(())
  510. }