mgen-client.rs 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. // Code specific to the client in the client-server mode.
  2. use futures::future::try_join_all;
  3. use mgen::updater::Updater;
  4. use mgen::{log, HandshakeRef, MessageHeader, SerializedMessage};
  5. use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256PlusPlus};
  6. use serde::Deserialize;
  7. use std::result::Result;
  8. use std::sync::Arc;
  9. use tokio::io::{split, AsyncWriteExt, ReadHalf, WriteHalf};
  10. use tokio::join;
  11. use tokio::net::TcpStream;
  12. use tokio::sync::mpsc;
  13. use tokio::time::Duration;
  14. use tokio_rustls::{client::TlsStream, TlsConnector};
  15. mod messenger;
  16. use crate::messenger::dists::{ConfigDistributions, Distributions};
  17. use crate::messenger::error::{FatalError, MessengerError};
  18. use crate::messenger::state::{
  19. manage_active_conversation, manage_idle_conversation, StateFromReader, StateMachine,
  20. StateToWriter,
  21. };
  22. use crate::messenger::tcp::{connect, SocksParams};
  23. /// Type for sending messages from the reader thread to the state thread.
  24. type ReaderToState = mpsc::UnboundedSender<MessageHeader>;
  25. /// Type of messages sent to the writer thread.
  26. type MessageHolder = Box<SerializedMessage>;
  27. /// Type for getting messages from the state thread in the writer thread.
  28. type WriterFromState = mpsc::UnboundedReceiver<MessageHolder>;
  29. /// Type for sending the updated read half of the socket.
  30. type ReadSocketUpdaterIn = Updater<ReadHalf<TlsStream<TcpStream>>>;
  31. /// Type for getting the updated read half of the socket.
  32. type ReadSocketUpdaterOut = Updater<ReadHalf<TlsStream<TcpStream>>>;
  33. /// Type for sending the updated write half of the socket.
  34. type WriteSocketUpdaterIn = Updater<WriteHalf<TlsStream<TcpStream>>>;
  35. /// Type for getting the updated write half of the socket.
  36. type WriteSocketUpdaterOut = Updater<WriteHalf<TlsStream<TcpStream>>>;
  37. /// Type for sending errors to other threads.
  38. type ErrorChannelIn = mpsc::UnboundedSender<MessengerError>;
  39. /// Type for getting errors from other threads.
  40. type ErrorChannelOut = mpsc::UnboundedReceiver<MessengerError>;
  41. /// Type for sending sizes to the attachment sender thread.
  42. type SizeChannelIn = mpsc::UnboundedSender<usize>;
  43. /// Type for getting sizes from other threads.
  44. type SizeChannelOut = mpsc::UnboundedReceiver<usize>;
  45. // we gain a (very) tiny performance win by not bothering to validate the cert
  46. struct NoCertificateVerification {}
  47. impl tokio_rustls::rustls::client::ServerCertVerifier for NoCertificateVerification {
  48. fn verify_server_cert(
  49. &self,
  50. _end_entity: &tokio_rustls::rustls::Certificate,
  51. _intermediates: &[tokio_rustls::rustls::Certificate],
  52. _server_name: &tokio_rustls::rustls::ServerName,
  53. _scts: &mut dyn Iterator<Item = &[u8]>,
  54. _ocsp: &[u8],
  55. _now: std::time::SystemTime,
  56. ) -> Result<tokio_rustls::rustls::client::ServerCertVerified, tokio_rustls::rustls::Error> {
  57. Ok(tokio_rustls::rustls::client::ServerCertVerified::assertion())
  58. }
  59. }
  60. /// Create a URL the web server can use to accept or produce traffic.
  61. /// `target` is the IP or host name of the web server,
  62. /// `size` is the number of bytes to download or upload,
  63. /// `user` is to let the server log the user making the request.
  64. /// Panics if the arguments do not produce a valid URI.
  65. fn web_url(target: &str, size: usize, user: &str) -> hyper::Uri {
  66. let formatted = format!("https://{}/?size={}&user={}", target, size, user);
  67. formatted
  68. .parse()
  69. .unwrap_or_else(|_| panic!("Invalid URI: {}", formatted))
  70. }
  71. /// The thread responsible for getting incoming messages,
  72. /// checking for any network errors while doing so,
  73. /// and giving messages to the state thread.
  74. async fn reader(
  75. web_params: SocksParams,
  76. retry: Duration,
  77. tls_config: tokio_rustls::rustls::ClientConfig,
  78. message_channel: ReaderToState,
  79. socket_updater: ReadSocketUpdaterOut,
  80. error_channel: ErrorChannelIn,
  81. ) {
  82. let https = hyper_rustls::HttpsConnectorBuilder::new()
  83. .with_tls_config(tls_config)
  84. .https_only()
  85. .enable_http1()
  86. .build();
  87. match web_params.socks {
  88. Some(proxy) => {
  89. let auth = hyper_socks2::Auth {
  90. username: web_params.user.clone(),
  91. password: web_params.recipient,
  92. };
  93. let socks = hyper_socks2::SocksConnector {
  94. proxy_addr: proxy.parse().expect("Invalid proxy URI"),
  95. auth: Some(auth),
  96. connector: https,
  97. };
  98. let client: hyper::Client<_, hyper::Body> = hyper::Client::builder().build(socks);
  99. worker(
  100. web_params.target,
  101. web_params.user,
  102. retry,
  103. client,
  104. message_channel,
  105. socket_updater,
  106. error_channel,
  107. )
  108. .await
  109. }
  110. None => {
  111. let client = hyper::Client::builder().build(https);
  112. worker(
  113. web_params.target,
  114. web_params.user,
  115. retry,
  116. client,
  117. message_channel,
  118. socket_updater,
  119. error_channel,
  120. )
  121. .await
  122. }
  123. }
  124. async fn worker<C>(
  125. target: String,
  126. user: String,
  127. retry: Duration,
  128. client: hyper::Client<C, hyper::Body>,
  129. message_channel: ReaderToState,
  130. mut socket_updater: ReadSocketUpdaterOut,
  131. error_channel: ErrorChannelIn,
  132. ) where
  133. C: hyper::client::connect::Connect + Clone + Send + Sync + 'static,
  134. {
  135. loop {
  136. let mut message_stream = socket_updater.recv().await;
  137. loop {
  138. let msg = match mgen::get_message(&mut message_stream).await {
  139. Ok(msg) => msg,
  140. Err(e) => {
  141. error_channel.send(e.into()).expect("Error channel closed");
  142. break;
  143. }
  144. };
  145. if msg.body.has_attachment() {
  146. let url = web_url(&target, msg.body.total_size(), &user);
  147. let client = client.clone();
  148. tokio::spawn(async move {
  149. let mut res = client.get(url.clone()).await;
  150. while res.is_err() {
  151. log!("Error fetching: {}", res.unwrap_err());
  152. tokio::time::sleep(retry).await;
  153. res = client.get(url.clone()).await;
  154. }
  155. });
  156. }
  157. message_channel
  158. .send(msg)
  159. .expect("Reader message channel closed");
  160. }
  161. }
  162. }
  163. }
  164. async fn uploader(
  165. web_params: SocksParams,
  166. retry: Duration,
  167. tls_config: tokio_rustls::rustls::ClientConfig,
  168. size_channel: SizeChannelOut,
  169. ) {
  170. let https = hyper_rustls::HttpsConnectorBuilder::new()
  171. .with_tls_config(tls_config)
  172. .https_only()
  173. .enable_http1()
  174. .build();
  175. match web_params.socks {
  176. Some(proxy) => {
  177. let auth = hyper_socks2::Auth {
  178. username: web_params.user.clone(),
  179. password: web_params.recipient,
  180. };
  181. let socks = hyper_socks2::SocksConnector {
  182. proxy_addr: proxy.parse().expect("Invalid proxy URI"),
  183. auth: Some(auth),
  184. connector: https,
  185. };
  186. let client = hyper::Client::builder().build(socks);
  187. worker(
  188. web_params.target,
  189. web_params.user,
  190. retry,
  191. client,
  192. size_channel,
  193. )
  194. .await
  195. }
  196. None => {
  197. let client = hyper::Client::builder().build(https);
  198. worker(
  199. web_params.target,
  200. web_params.user,
  201. retry,
  202. client,
  203. size_channel,
  204. )
  205. .await
  206. }
  207. }
  208. async fn worker<C>(
  209. target: String,
  210. user: String,
  211. retry: Duration,
  212. client: hyper::Client<C, hyper::Body>,
  213. mut size_channel: SizeChannelOut,
  214. ) where
  215. C: hyper::client::connect::Connect + Clone + Send + Sync + 'static,
  216. {
  217. loop {
  218. let size = size_channel.recv().await.expect("Size channel closed");
  219. let client = client.clone();
  220. let url = web_url(&target, size, &user);
  221. let request = hyper::Request::put(url.clone())
  222. .body(hyper::Body::empty())
  223. .expect("Invalid HTTP request attempted to construct");
  224. let mut res = client.request(request).await;
  225. while res.is_err() {
  226. log!("Error uploading: {}", res.unwrap_err());
  227. tokio::time::sleep(retry).await;
  228. res = client.get(url.clone()).await;
  229. }
  230. }
  231. }
  232. }
  233. /// The thread responsible for sending messages from the state thread,
  234. /// and checking for any network errors while doing so.
  235. async fn writer(
  236. mut message_channel: WriterFromState,
  237. attachment_channel: SizeChannelIn,
  238. mut socket_updater: WriteSocketUpdaterOut,
  239. error_channel: ErrorChannelIn,
  240. ) {
  241. loop {
  242. let mut stream = socket_updater.recv().await;
  243. loop {
  244. let msg = message_channel
  245. .recv()
  246. .await
  247. .expect("Writer message channel closed");
  248. if msg.body.has_attachment() {
  249. attachment_channel
  250. .send(msg.body.total_size())
  251. .expect("Attachment channel closed");
  252. }
  253. if let Err(e) = msg.write_all_to(&mut stream).await {
  254. error_channel.send(e.into()).expect("Error channel closed");
  255. break;
  256. }
  257. }
  258. }
  259. }
  260. /// The thread responsible for (re-)establishing connections to the server,
  261. /// and determining how to handle errors this or other threads receive.
  262. async fn socket_updater(
  263. str_params: SocksParams,
  264. retry: Duration,
  265. tls_config: tokio_rustls::rustls::ClientConfig,
  266. mut error_channel: ErrorChannelOut,
  267. reader_channel: ReadSocketUpdaterIn,
  268. writer_channel: WriteSocketUpdaterIn,
  269. ) -> FatalError {
  270. let connector = TlsConnector::from(Arc::new(tls_config));
  271. // unwrap is safe, split always returns at least one element
  272. let tls_server_str = str_params.target.split(':').next().unwrap();
  273. let tls_server_name =
  274. tokio_rustls::rustls::ServerName::try_from(tls_server_str).expect("invalid server name");
  275. loop {
  276. let stream: TcpStream = match connect(&str_params).await {
  277. Ok(stream) => stream,
  278. Err(MessengerError::Recoverable(_)) => {
  279. tokio::time::sleep(retry).await;
  280. continue;
  281. }
  282. Err(MessengerError::Fatal(e)) => return e,
  283. };
  284. let mut stream = match connector.connect(tls_server_name.clone(), stream).await {
  285. Ok(stream) => stream,
  286. Err(_) => {
  287. tokio::time::sleep(retry).await;
  288. continue;
  289. }
  290. };
  291. let handshake = HandshakeRef {
  292. sender: &str_params.user,
  293. group: &str_params.recipient,
  294. };
  295. if stream.write_all(&handshake.serialize()).await.is_err() {
  296. continue;
  297. }
  298. log!("{},{},handshake", str_params.user, str_params.recipient);
  299. let (rd, wr) = split(stream);
  300. reader_channel.send(rd);
  301. writer_channel.send(wr);
  302. let res = error_channel.recv().await.expect("Error channel closed");
  303. if let MessengerError::Fatal(e) = res {
  304. return e;
  305. }
  306. }
  307. }
  308. /// The thread responsible for handling the conversation state
  309. /// (i.e., whether the user is active or idle, and when to send messages).
  310. async fn manage_conversation(
  311. config: FullConfig,
  312. mut state_from_reader: StateFromReader,
  313. mut state_to_writer: StateToWriter<MessageHolder>,
  314. ) {
  315. tokio::time::sleep(Duration::from_secs_f64(config.bootstrap)).await;
  316. log!("{},{},awake", &config.user, &config.group);
  317. let mut rng = Xoshiro256PlusPlus::from_entropy();
  318. let mut state_machine = StateMachine::start(config.distributions, &mut rng);
  319. loop {
  320. state_machine = match state_machine {
  321. StateMachine::Idle(conversation) => {
  322. manage_idle_conversation::<false, _, _, _>(
  323. conversation,
  324. &mut state_from_reader,
  325. &mut state_to_writer,
  326. &config.user,
  327. &config.group,
  328. &mut rng,
  329. )
  330. .await
  331. }
  332. StateMachine::Active(conversation) => {
  333. manage_active_conversation(
  334. conversation,
  335. &mut state_from_reader,
  336. &mut state_to_writer,
  337. &config.user,
  338. &config.group,
  339. false,
  340. &mut rng,
  341. )
  342. .await
  343. }
  344. };
  345. }
  346. }
  347. /// Spawns all other threads for this conversation.
  348. async fn spawn_threads(config: FullConfig) -> Result<(), MessengerError> {
  349. let message_server_params = SocksParams {
  350. socks: config.socks.clone(),
  351. target: config.message_server.clone(),
  352. user: config.user.clone(),
  353. recipient: config.group.clone(),
  354. };
  355. let web_server_params = SocksParams {
  356. socks: config.socks.clone(),
  357. target: config.web_server.clone(),
  358. user: config.user.clone(),
  359. recipient: config.group.clone(),
  360. };
  361. let (reader_to_state, state_from_reader) = mpsc::unbounded_channel();
  362. let (state_to_writer, writer_from_state) = mpsc::unbounded_channel();
  363. let read_socket_updater_in = Updater::new();
  364. let read_socket_updater_out = read_socket_updater_in.clone();
  365. let write_socket_updater_in = Updater::new();
  366. let write_socket_updater_out = write_socket_updater_in.clone();
  367. let (errs_in, errs_out) = mpsc::unbounded_channel();
  368. let (writer_to_uploader, uploader_from_writer) = mpsc::unbounded_channel();
  369. let state_to_writer = StateToWriter {
  370. channel: state_to_writer,
  371. };
  372. let retry = Duration::from_secs_f64(config.retry);
  373. let tls_config = tokio_rustls::rustls::ClientConfig::builder()
  374. .with_safe_defaults()
  375. .with_custom_certificate_verifier(Arc::new(NoCertificateVerification {}))
  376. .with_no_client_auth();
  377. let reader = reader(
  378. web_server_params.clone(),
  379. retry,
  380. tls_config.clone(),
  381. reader_to_state,
  382. read_socket_updater_out,
  383. errs_in.clone(),
  384. );
  385. let writer = writer(
  386. writer_from_state,
  387. writer_to_uploader,
  388. write_socket_updater_out,
  389. errs_in,
  390. );
  391. let uploader = uploader(
  392. web_server_params,
  393. retry,
  394. tls_config.clone(),
  395. uploader_from_writer,
  396. );
  397. let updater = socket_updater(
  398. message_server_params,
  399. retry,
  400. tls_config,
  401. errs_out,
  402. read_socket_updater_in,
  403. write_socket_updater_in,
  404. );
  405. let manager = manage_conversation(config, state_from_reader, state_to_writer);
  406. join!(reader, writer, uploader, updater, manager);
  407. Ok(())
  408. }
/// Settings for a single conversation. `main` builds it by filling every field a
/// conversation entry leaves unset from the config file's top-level defaults.
struct FullConfig {
    /// Local user name: sent as the handshake sender, and used in log lines and web URLs.
    user: String,
    /// Group the conversation is with: sent in the handshake and, when a SOCKS
    /// proxy is configured, also used as the proxy password.
    group: String,
    /// Optional SOCKS proxy address; the web clients parse it as a URI.
    socks: Option<String>,
    /// Message server to connect to; the text before the first `:` is used as the
    /// TLS server name.
    message_server: String,
    /// Web server used for attachment transfers (inserted as `https://{web_server}/`).
    web_server: String,
    /// Seconds to wait before the conversation starts.
    bootstrap: f64,
    /// Seconds to wait before retrying a failed connection or transfer.
    retry: f64,
    /// Distributions used to start the conversation's state machine.
    distributions: Distributions,
}
/// One entry of a config file's `conversations` list. Every optional field left
/// unset (`None`) falls back to the same-named top-level field of `Config`.
#[derive(Debug, Deserialize)]
struct ConversationConfig {
    /// Group this conversation is with (required).
    group: String,
    /// Per-conversation override of the message server.
    message_server: Option<String>,
    /// Per-conversation override of the web server.
    web_server: Option<String>,
    /// Per-conversation override of the startup delay, in seconds.
    bootstrap: Option<f64>,
    /// Per-conversation override of the retry delay, in seconds.
    retry: Option<f64>,
    /// Per-conversation override of the conversation distributions.
    distributions: Option<ConfigDistributions>,
}
/// Top-level structure of a YAML config file, as parsed by `serde_yaml` in `main`.
/// Fields other than `conversations` are defaults shared by every conversation.
#[derive(Debug, Deserialize)]
struct Config {
    /// Local user name, shared by all conversations in this file.
    user: String,
    /// Optional SOCKS proxy address; there is no per-conversation override.
    socks: Option<String>,
    /// Default message server.
    message_server: String,
    /// Default web server for attachment transfers.
    web_server: String,
    /// Default startup delay, in seconds.
    bootstrap: f64,
    /// Default retry delay, in seconds.
    retry: f64,
    /// Default distributions, converted with `try_into` before use.
    distributions: ConfigDistributions,
    /// Conversations to run; each gets its own set of threads.
    conversations: Vec<ConversationConfig>,
}
  439. #[tokio::main(flavor = "current_thread")]
  440. async fn main() -> Result<(), Box<dyn std::error::Error>> {
  441. #[cfg(feature = "tracing")]
  442. console_subscriber::init();
  443. let mut args = std::env::args();
  444. let _ = args.next();
  445. let mut handles = vec![];
  446. for config_file in args.flat_map(|a| glob::glob(a.as_str()).unwrap()) {
  447. let yaml_s = std::fs::read_to_string(config_file?)?;
  448. let config: Config = serde_yaml::from_str(&yaml_s)?;
  449. let default_dists: Distributions = config.distributions.try_into()?;
  450. for conversation in config.conversations.into_iter() {
  451. let distributions: Distributions = match conversation.distributions {
  452. Some(dists) => dists.try_into()?,
  453. None => default_dists.clone(),
  454. };
  455. let filled_conversation = FullConfig {
  456. user: config.user.clone(),
  457. group: conversation.group,
  458. socks: config.socks.clone(),
  459. message_server: conversation
  460. .message_server
  461. .unwrap_or_else(|| config.message_server.clone()),
  462. web_server: conversation
  463. .web_server
  464. .unwrap_or_else(|| config.web_server.clone()),
  465. bootstrap: conversation.bootstrap.unwrap_or(config.bootstrap),
  466. retry: conversation.retry.unwrap_or(config.retry),
  467. distributions,
  468. };
  469. let handle = spawn_threads(filled_conversation);
  470. handles.push(handle);
  471. }
  472. }
  473. try_join_all(handles.into_iter()).await?;
  474. Ok(())
  475. }
  476. #[cfg(test)]
  477. mod tests {
  478. use super::*;
  479. #[test]
  480. fn test_uri_generation() {
  481. // should panic if any of these are invalid
  482. web_url("192.0.2.1", 0, "Alice");
  483. web_url("hostname", 65536, "Bob");
  484. web_url("web0", 4294967295, "Carol");
  485. web_url("web1", 1, "");
  486. web_url("foo.bar.baz", 1, "Dave");
  487. // IPv6 is not a valid in a URI
  488. //web_url("2001:0db8:85a3:0000:0000:8a2e:0370:7334", 1, "1");
  489. // hyper does not automatically convert to punycode
  490. //web_url("web2", 1, "🦀");
  491. //web_url("🦀", 1, "Ferris");
  492. }
  493. }