mgen-client.rs 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. // Code specific to the client in the client-server mode.
  2. use mgen::{log, updater::Updater, HandshakeRef, MessageHeader, SerializedMessage};
  3. use futures::future::try_join_all;
  4. use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256PlusPlus};
  5. use serde::Deserialize;
  6. use std::result::Result;
  7. use std::sync::Arc;
  8. use tokio::io::{split, AsyncWriteExt, ReadHalf, WriteHalf};
  9. use tokio::net::TcpStream;
  10. use tokio::spawn;
  11. use tokio::sync::mpsc;
  12. use tokio::time::Duration;
  13. use tokio_rustls::{client::TlsStream, TlsConnector};
  14. mod messenger;
  15. use crate::messenger::dists::{ConfigDistributions, Distributions};
  16. use crate::messenger::error::{FatalError, MessengerError};
  17. use crate::messenger::state::{
  18. manage_active_conversation, manage_idle_conversation, StateFromReader, StateMachine,
  19. StateToWriter,
  20. };
  21. use crate::messenger::tcp::{connect, SocksParams};
  22. /// Type for sending messages from the reader thread to the state thread.
  23. type ReaderToState = mpsc::UnboundedSender<MessageHeader>;
  24. /// Type of messages sent to the writer thread.
  25. type MessageHolder = Box<SerializedMessage>;
  26. /// Type for getting messages from the state thread in the writer thread.
  27. type WriterFromState = mpsc::UnboundedReceiver<MessageHolder>;
  28. /// Type for sending the updated read half of the socket.
  29. type ReadSocketUpdaterIn = Updater<ReadHalf<TlsStream<TcpStream>>>;
  30. /// Type for getting the updated read half of the socket.
  31. type ReadSocketUpdaterOut = Updater<ReadHalf<TlsStream<TcpStream>>>;
  32. /// Type for sending the updated write half of the socket.
  33. type WriteSocketUpdaterIn = Updater<WriteHalf<TlsStream<TcpStream>>>;
  34. /// Type for getting the updated write half of the socket.
  35. type WriteSocketUpdaterOut = Updater<WriteHalf<TlsStream<TcpStream>>>;
  36. /// Type for sending errors to other threads.
  37. type ErrorChannelIn = mpsc::UnboundedSender<MessengerError>;
  38. /// Type for getting errors from other threads.
  39. type ErrorChannelOut = mpsc::UnboundedReceiver<MessengerError>;
  40. /// Type for sending sizes to the attachment sender thread.
  41. type SizeChannelIn = mpsc::UnboundedSender<usize>;
  42. /// Type for getting sizes from other threads.
  43. type SizeChannelOut = mpsc::UnboundedReceiver<usize>;
  44. // we gain a (very) tiny performance win by not bothering to validate the cert
  45. struct NoCertificateVerification {}
  46. impl tokio_rustls::rustls::client::ServerCertVerifier for NoCertificateVerification {
  47. fn verify_server_cert(
  48. &self,
  49. _end_entity: &tokio_rustls::rustls::Certificate,
  50. _intermediates: &[tokio_rustls::rustls::Certificate],
  51. _server_name: &tokio_rustls::rustls::ServerName,
  52. _scts: &mut dyn Iterator<Item = &[u8]>,
  53. _ocsp: &[u8],
  54. _now: std::time::SystemTime,
  55. ) -> Result<tokio_rustls::rustls::client::ServerCertVerified, tokio_rustls::rustls::Error> {
  56. Ok(tokio_rustls::rustls::client::ServerCertVerified::assertion())
  57. }
  58. }
  59. /// Create a URL the web server can use to accept or produce traffic.
  60. /// `target` is the IP or host name of the web server,
  61. /// `size` is the number of bytes to download or upload,
  62. /// `user` is to let the server log the user making the request.
  63. /// Panics if the arguments do not produce a valid URI.
  64. fn web_url(target: &str, size: usize, user: &str) -> hyper::Uri {
  65. let formatted = format!("https://{}/?size={}&user={}", target, size, user);
  66. formatted
  67. .parse()
  68. .unwrap_or_else(|_| panic!("Invalid URI: {}", formatted))
  69. }
  70. /// The thread responsible for getting incoming messages,
  71. /// checking for any network errors while doing so,
  72. /// and giving messages to the state thread.
  73. async fn reader(
  74. web_params: SocksParams,
  75. retry: Duration,
  76. tls_config: tokio_rustls::rustls::ClientConfig,
  77. message_channel: ReaderToState,
  78. socket_updater: ReadSocketUpdaterOut,
  79. error_channel: ErrorChannelIn,
  80. ) {
  81. let https = hyper_rustls::HttpsConnectorBuilder::new()
  82. .with_tls_config(tls_config)
  83. .https_only()
  84. .enable_http1()
  85. .build();
  86. match web_params.socks {
  87. Some(proxy) => {
  88. let auth = hyper_socks2::Auth {
  89. username: web_params.user.clone(),
  90. password: web_params.recipient,
  91. };
  92. let socks = hyper_socks2::SocksConnector {
  93. proxy_addr: proxy.parse().expect("Invalid proxy URI"),
  94. auth: Some(auth),
  95. connector: https,
  96. };
  97. let client: hyper::Client<_, hyper::Body> = hyper::Client::builder().build(socks);
  98. worker(
  99. web_params.target,
  100. web_params.user,
  101. retry,
  102. client,
  103. message_channel,
  104. socket_updater,
  105. error_channel,
  106. )
  107. .await
  108. }
  109. None => {
  110. let client = hyper::Client::builder().build(https);
  111. worker(
  112. web_params.target,
  113. web_params.user,
  114. retry,
  115. client,
  116. message_channel,
  117. socket_updater,
  118. error_channel,
  119. )
  120. .await
  121. }
  122. }
  123. async fn worker<C>(
  124. target: String,
  125. user: String,
  126. retry: Duration,
  127. client: hyper::Client<C, hyper::Body>,
  128. message_channel: ReaderToState,
  129. mut socket_updater: ReadSocketUpdaterOut,
  130. error_channel: ErrorChannelIn,
  131. ) where
  132. C: hyper::client::connect::Connect + Clone + Send + Sync + 'static,
  133. {
  134. loop {
  135. let mut message_stream = socket_updater.recv().await;
  136. loop {
  137. let msg = match mgen::get_message(&mut message_stream).await {
  138. Ok(msg) => msg,
  139. Err(e) => {
  140. error_channel.send(e.into()).expect("Error channel closed");
  141. break;
  142. }
  143. };
  144. if msg.body.has_attachment() {
  145. let url = web_url(&target, msg.body.total_size(), &user);
  146. let client = client.clone();
  147. spawn(async move {
  148. let mut res = client.get(url.clone()).await;
  149. while res.is_err() {
  150. log!("Error fetching: {}", res.unwrap_err());
  151. tokio::time::sleep(retry).await;
  152. res = client.get(url.clone()).await;
  153. }
  154. });
  155. }
  156. message_channel
  157. .send(msg)
  158. .expect("Reader message channel closed");
  159. }
  160. }
  161. }
  162. }
  163. async fn uploader(
  164. web_params: SocksParams,
  165. retry: Duration,
  166. tls_config: tokio_rustls::rustls::ClientConfig,
  167. size_channel: SizeChannelOut,
  168. ) {
  169. let https = hyper_rustls::HttpsConnectorBuilder::new()
  170. .with_tls_config(tls_config)
  171. .https_only()
  172. .enable_http1()
  173. .build();
  174. match web_params.socks {
  175. Some(proxy) => {
  176. let auth = hyper_socks2::Auth {
  177. username: web_params.user.clone(),
  178. password: web_params.recipient,
  179. };
  180. let socks = hyper_socks2::SocksConnector {
  181. proxy_addr: proxy.parse().expect("Invalid proxy URI"),
  182. auth: Some(auth),
  183. connector: https,
  184. };
  185. let client = hyper::Client::builder().build(socks);
  186. worker(
  187. web_params.target,
  188. web_params.user,
  189. retry,
  190. client,
  191. size_channel,
  192. )
  193. .await
  194. }
  195. None => {
  196. let client = hyper::Client::builder().build(https);
  197. worker(
  198. web_params.target,
  199. web_params.user,
  200. retry,
  201. client,
  202. size_channel,
  203. )
  204. .await
  205. }
  206. }
  207. async fn worker<C>(
  208. target: String,
  209. user: String,
  210. retry: Duration,
  211. client: hyper::Client<C, hyper::Body>,
  212. mut size_channel: SizeChannelOut,
  213. ) where
  214. C: hyper::client::connect::Connect + Clone + Send + Sync + 'static,
  215. {
  216. loop {
  217. let size = size_channel.recv().await.expect("Size channel closed");
  218. let client = client.clone();
  219. let url = web_url(&target, size, &user);
  220. let request = hyper::Request::put(url.clone())
  221. .body(hyper::Body::empty())
  222. .expect("Invalid HTTP request attempted to construct");
  223. let mut res = client.request(request).await;
  224. while res.is_err() {
  225. log!("Error uploading: {}", res.unwrap_err());
  226. tokio::time::sleep(retry).await;
  227. res = client.get(url.clone()).await;
  228. }
  229. }
  230. }
  231. }
  232. /// The thread responsible for sending messages from the state thread,
  233. /// and checking for any network errors while doing so.
  234. async fn writer(
  235. mut message_channel: WriterFromState,
  236. attachment_channel: SizeChannelIn,
  237. mut socket_updater: WriteSocketUpdaterOut,
  238. error_channel: ErrorChannelIn,
  239. ) {
  240. loop {
  241. let mut stream = socket_updater.recv().await;
  242. loop {
  243. let msg = message_channel
  244. .recv()
  245. .await
  246. .expect("Writer message channel closed");
  247. if msg.body.has_attachment() {
  248. attachment_channel
  249. .send(msg.body.total_size())
  250. .expect("Attachment channel closed");
  251. }
  252. if let Err(e) = msg.write_all_to(&mut stream).await {
  253. error_channel.send(e.into()).expect("Error channel closed");
  254. break;
  255. }
  256. }
  257. }
  258. }
  259. /// The thread responsible for (re-)establishing connections to the server,
  260. /// and determining how to handle errors this or other threads receive.
  261. async fn socket_updater(
  262. str_params: SocksParams,
  263. retry: Duration,
  264. tls_config: tokio_rustls::rustls::ClientConfig,
  265. mut error_channel: ErrorChannelOut,
  266. reader_channel: ReadSocketUpdaterIn,
  267. writer_channel: WriteSocketUpdaterIn,
  268. ) -> FatalError {
  269. let connector = TlsConnector::from(Arc::new(tls_config));
  270. // unwrap is safe, split always returns at least one element
  271. let tls_server_str = str_params.target.split(':').next().unwrap();
  272. let tls_server_name =
  273. tokio_rustls::rustls::ServerName::try_from(tls_server_str).expect("invalid server name");
  274. loop {
  275. let stream: TcpStream = match connect(&str_params).await {
  276. Ok(stream) => stream,
  277. Err(MessengerError::Recoverable(e)) => {
  278. log!(
  279. "{},{},error,TCP,{:?}",
  280. str_params.user,
  281. str_params.recipient,
  282. e
  283. );
  284. tokio::time::sleep(retry).await;
  285. continue;
  286. }
  287. Err(MessengerError::Fatal(e)) => return e,
  288. };
  289. let mut stream = match connector.connect(tls_server_name.clone(), stream).await {
  290. Ok(stream) => stream,
  291. Err(e) => {
  292. log!(
  293. "{},{},error,TLS,{:?}",
  294. str_params.user,
  295. str_params.recipient,
  296. e
  297. );
  298. tokio::time::sleep(retry).await;
  299. continue;
  300. }
  301. };
  302. let handshake = HandshakeRef {
  303. sender: &str_params.user,
  304. group: &str_params.recipient,
  305. };
  306. if stream.write_all(&handshake.serialize()).await.is_err() {
  307. continue;
  308. }
  309. log!("{},{},handshake", str_params.user, str_params.recipient);
  310. let (rd, wr) = split(stream);
  311. reader_channel.send(rd);
  312. writer_channel.send(wr);
  313. let res = error_channel.recv().await.expect("Error channel closed");
  314. if let MessengerError::Fatal(e) = res {
  315. return e;
  316. } else {
  317. log!(
  318. "{},{},error,{:?}",
  319. str_params.user,
  320. str_params.recipient,
  321. res
  322. );
  323. }
  324. }
  325. }
  326. /// The thread responsible for handling the conversation state
  327. /// (i.e., whether the user is active or idle, and when to send messages).
  328. async fn manage_conversation(
  329. config: FullConfig,
  330. mut state_from_reader: StateFromReader,
  331. mut state_to_writer: StateToWriter<MessageHolder>,
  332. ) {
  333. tokio::time::sleep(Duration::from_secs_f64(config.bootstrap)).await;
  334. log!("{},{},awake", &config.user, &config.group);
  335. let mut rng = Xoshiro256PlusPlus::from_entropy();
  336. let mut state_machine = StateMachine::start(config.distributions, &mut rng);
  337. loop {
  338. state_machine = match state_machine {
  339. StateMachine::Idle(conversation) => {
  340. manage_idle_conversation::<false, _, _, _>(
  341. conversation,
  342. &mut state_from_reader,
  343. &mut state_to_writer,
  344. &config.user,
  345. &config.group,
  346. &mut rng,
  347. )
  348. .await
  349. }
  350. StateMachine::Active(conversation) => {
  351. manage_active_conversation(
  352. conversation,
  353. &mut state_from_reader,
  354. &mut state_to_writer,
  355. &config.user,
  356. &config.group,
  357. false,
  358. &mut rng,
  359. )
  360. .await
  361. }
  362. };
  363. }
  364. }
  365. /// Spawns all other threads for this conversation.
  366. async fn spawn_threads(config: FullConfig) -> Result<(), MessengerError> {
  367. let message_server_params = SocksParams {
  368. socks: config.socks.clone(),
  369. target: config.message_server.clone(),
  370. user: config.user.clone(),
  371. recipient: config.group.clone(),
  372. };
  373. let web_server_params = SocksParams {
  374. socks: config.socks.clone(),
  375. target: config.web_server.clone(),
  376. user: config.user.clone(),
  377. recipient: config.group.clone(),
  378. };
  379. let (reader_to_state, state_from_reader) = mpsc::unbounded_channel();
  380. let (state_to_writer, writer_from_state) = mpsc::unbounded_channel();
  381. let read_socket_updater_in = Updater::new();
  382. let read_socket_updater_out = read_socket_updater_in.clone();
  383. let write_socket_updater_in = Updater::new();
  384. let write_socket_updater_out = write_socket_updater_in.clone();
  385. let (errs_in, errs_out) = mpsc::unbounded_channel();
  386. let (writer_to_uploader, uploader_from_writer) = mpsc::unbounded_channel();
  387. let state_to_writer = StateToWriter {
  388. channel: state_to_writer,
  389. };
  390. let retry = Duration::from_secs_f64(config.retry);
  391. let tls_config = tokio_rustls::rustls::ClientConfig::builder()
  392. .with_safe_defaults()
  393. .with_custom_certificate_verifier(Arc::new(NoCertificateVerification {}))
  394. .with_no_client_auth();
  395. spawn(reader(
  396. web_server_params.clone(),
  397. retry,
  398. tls_config.clone(),
  399. reader_to_state,
  400. read_socket_updater_out,
  401. errs_in.clone(),
  402. ));
  403. spawn(writer(
  404. writer_from_state,
  405. writer_to_uploader,
  406. write_socket_updater_out,
  407. errs_in,
  408. ));
  409. spawn(uploader(
  410. web_server_params,
  411. retry,
  412. tls_config.clone(),
  413. uploader_from_writer,
  414. ));
  415. spawn(manage_conversation(
  416. config,
  417. state_from_reader,
  418. state_to_writer,
  419. ));
  420. Err(MessengerError::Fatal(
  421. socket_updater(
  422. message_server_params,
  423. retry,
  424. tls_config,
  425. errs_out,
  426. read_socket_updater_in,
  427. write_socket_updater_in,
  428. )
  429. .await,
  430. ))
  431. }
  432. struct FullConfig {
  433. user: String,
  434. group: String,
  435. socks: Option<String>,
  436. message_server: String,
  437. web_server: String,
  438. bootstrap: f64,
  439. retry: f64,
  440. distributions: Distributions,
  441. }
/// One entry of a config file's `conversations` list.
///
/// Every optional field left unset (`None`) falls back to the value of the same name
/// in the file's top-level `Config`.
#[derive(Debug, Deserialize)]
struct ConversationConfig {
    /// Group this conversation takes place in (required).
    group: String,
    /// Overrides the top-level `message_server`.
    message_server: Option<String>,
    /// Overrides the top-level `web_server`.
    web_server: Option<String>,
    /// Overrides the top-level `bootstrap` (seconds before the conversation starts).
    bootstrap: Option<f64>,
    /// Overrides the top-level `retry` (seconds between retries).
    retry: Option<f64>,
    /// Overrides the top-level `distributions`.
    distributions: Option<ConfigDistributions>,
}
/// Top-level structure of one YAML config file.
///
/// One user runs one conversation per entry in `conversations`. The other fields are
/// defaults that each conversation may override.
#[derive(Debug, Deserialize)]
struct Config {
    /// The user name shared by every conversation in this file.
    user: String,
    /// Optional SOCKS proxy, shared by every conversation in this file.
    socks: Option<String>,
    /// Default message server address.
    message_server: String,
    /// Default web server address (attachment traffic).
    web_server: String,
    /// Default delay, in seconds, before a conversation starts.
    bootstrap: f64,
    /// Default delay, in seconds, between retries after a network failure.
    retry: f64,
    /// Default distributions, converted to `Distributions` when the file is loaded.
    distributions: ConfigDistributions,
    /// The conversations to run.
    conversations: Vec<ConversationConfig>,
}
  462. async fn main_worker() -> Result<(), Box<dyn std::error::Error>> {
  463. #[cfg(feature = "tracing")]
  464. console_subscriber::init();
  465. let mut args = std::env::args();
  466. let _ = args.next();
  467. let mut handles = vec![];
  468. for config_file in args.flat_map(|a| glob::glob(a.as_str()).unwrap()) {
  469. let yaml_s = std::fs::read_to_string(config_file?)?;
  470. let config: Config = serde_yaml::from_str(&yaml_s)?;
  471. let default_dists: Distributions = config.distributions.try_into()?;
  472. for conversation in config.conversations.into_iter() {
  473. let distributions: Distributions = match conversation.distributions {
  474. Some(dists) => dists.try_into()?,
  475. None => default_dists.clone(),
  476. };
  477. let filled_conversation = FullConfig {
  478. user: config.user.clone(),
  479. group: conversation.group,
  480. socks: config.socks.clone(),
  481. message_server: conversation
  482. .message_server
  483. .unwrap_or_else(|| config.message_server.clone()),
  484. web_server: conversation
  485. .web_server
  486. .unwrap_or_else(|| config.web_server.clone()),
  487. bootstrap: conversation.bootstrap.unwrap_or(config.bootstrap),
  488. retry: conversation.retry.unwrap_or(config.retry),
  489. distributions,
  490. };
  491. let handle = spawn_threads(filled_conversation);
  492. handles.push(handle);
  493. }
  494. }
  495. try_join_all(handles).await?;
  496. Ok(())
  497. }
  498. fn main() -> Result<(), Box<dyn std::error::Error>> {
  499. tokio::runtime::Builder::new_multi_thread()
  500. .worker_threads(2)
  501. .enable_all()
  502. .disable_lifo_slot()
  503. .build()
  504. .unwrap()
  505. .block_on(main_worker())
  506. }
  507. #[cfg(test)]
  508. mod tests {
  509. use super::*;
  510. #[test]
  511. fn test_uri_generation() {
  512. // should panic if any of these are invalid
  513. web_url("192.0.2.1", 0, "Alice");
  514. web_url("hostname", 65536, "Bob");
  515. web_url("web0", 4294967295, "Carol");
  516. web_url("web1", 1, "");
  517. web_url("foo.bar.baz", 1, "Dave");
  518. // IPv6 is not a valid in a URI
  519. //web_url("2001:0db8:85a3:0000:0000:8a2e:0370:7334", 1, "1");
  520. // hyper does not automatically convert to punycode
  521. //web_url("web2", 1, "🦀");
  522. //web_url("🦀", 1, "Ferris");
  523. }
  524. }