mgen-client.rs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526
  1. // Code specific to the client in the client-server mode.
  2. use mgen::updater::Updater;
  3. use mgen::{log, HandshakeRef, MessageHeader, SerializedMessage};
  4. use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256PlusPlus};
  5. use serde::Deserialize;
  6. use std::result::Result;
  7. use std::sync::Arc;
  8. use tokio::io::{split, AsyncWriteExt, ReadHalf, WriteHalf};
  9. use tokio::net::TcpStream;
  10. use tokio::sync::mpsc;
  11. use tokio::task;
  12. use tokio::time::Duration;
  13. use tokio_rustls::{client::TlsStream, TlsConnector};
  14. mod messenger;
  15. use crate::messenger::dists::{ConfigDistributions, Distributions};
  16. use crate::messenger::error::{FatalError, MessengerError};
  17. use crate::messenger::state::{
  18. manage_active_conversation, manage_idle_conversation, StateMachine, StateToWriter,
  19. };
  20. use crate::messenger::tcp::{connect, SocksParams};
  21. /// Type for sending messages from the reader thread to the state thread.
  22. type ReaderToState = mpsc::UnboundedSender<MessageHeader>;
  23. /// Type of messages sent to the writer thread.
  24. type MessageHolder = Box<SerializedMessage>;
  25. /// Type for getting messages from the state thread in the writer thread.
  26. type WriterFromState = mpsc::UnboundedReceiver<MessageHolder>;
  27. /// Type for sending the updated read half of the socket.
  28. type ReadSocketUpdaterIn = Updater<ReadHalf<TlsStream<TcpStream>>>;
  29. /// Type for getting the updated read half of the socket.
  30. type ReadSocketUpdaterOut = Updater<ReadHalf<TlsStream<TcpStream>>>;
  31. /// Type for sending the updated write half of the socket.
  32. type WriteSocketUpdaterIn = Updater<WriteHalf<TlsStream<TcpStream>>>;
  33. /// Type for getting the updated write half of the socket.
  34. type WriteSocketUpdaterOut = Updater<WriteHalf<TlsStream<TcpStream>>>;
  35. /// Type for sending errors to other threads.
  36. type ErrorChannelIn = mpsc::UnboundedSender<MessengerError>;
  37. /// Type for getting errors from other threads.
  38. type ErrorChannelOut = mpsc::UnboundedReceiver<MessengerError>;
  39. /// Type for sending sizes to the attachment sender thread.
  40. type SizeChannelIn = mpsc::UnboundedSender<usize>;
  41. /// Type for getting sizes from other threads.
  42. type SizeChannelOut = mpsc::UnboundedReceiver<usize>;
  43. // we gain a (very) tiny performance win by not bothering to validate the cert
  44. struct NoCertificateVerification {}
  45. impl tokio_rustls::rustls::client::ServerCertVerifier for NoCertificateVerification {
  46. fn verify_server_cert(
  47. &self,
  48. _end_entity: &tokio_rustls::rustls::Certificate,
  49. _intermediates: &[tokio_rustls::rustls::Certificate],
  50. _server_name: &tokio_rustls::rustls::ServerName,
  51. _scts: &mut dyn Iterator<Item = &[u8]>,
  52. _ocsp: &[u8],
  53. _now: std::time::SystemTime,
  54. ) -> Result<tokio_rustls::rustls::client::ServerCertVerified, tokio_rustls::rustls::Error> {
  55. Ok(tokio_rustls::rustls::client::ServerCertVerified::assertion())
  56. }
  57. }
  58. /// Create a URL the web server can use to accept or produce traffic.
  59. /// `target` is the IP or host name of the web server,
  60. /// `size` is the number of bytes to download or upload,
  61. /// `user` is to let the server log the user making the request.
  62. /// Panics if the arguments do not produce a valid URI.
  63. fn web_url(target: &str, size: usize, user: &str) -> hyper::Uri {
  64. let formatted = format!("https://{}/?size={}&user={}", target, size, user);
  65. formatted
  66. .parse()
  67. .unwrap_or_else(|_| panic!("Invalid URI: {}", formatted))
  68. }
  69. /// The thread responsible for getting incoming messages,
  70. /// checking for any network errors while doing so,
  71. /// and giving messages to the state thread.
  72. async fn reader(
  73. web_params: SocksParams,
  74. retry: Duration,
  75. tls_config: tokio_rustls::rustls::ClientConfig,
  76. message_channel: ReaderToState,
  77. socket_updater: ReadSocketUpdaterOut,
  78. error_channel: ErrorChannelIn,
  79. ) {
  80. let https = hyper_rustls::HttpsConnectorBuilder::new()
  81. .with_tls_config(tls_config)
  82. .https_only()
  83. .enable_http1()
  84. .build();
  85. match web_params.socks {
  86. Some(proxy) => {
  87. let auth = hyper_socks2::Auth {
  88. username: web_params.user.clone(),
  89. password: web_params.recipient,
  90. };
  91. let socks = hyper_socks2::SocksConnector {
  92. proxy_addr: proxy.parse().expect("Invalid proxy URI"),
  93. auth: Some(auth),
  94. connector: https,
  95. };
  96. let client: hyper::Client<_, hyper::Body> = hyper::Client::builder().build(socks);
  97. worker(
  98. web_params.target,
  99. web_params.user,
  100. retry,
  101. client,
  102. message_channel,
  103. socket_updater,
  104. error_channel,
  105. )
  106. .await
  107. }
  108. None => {
  109. let client = hyper::Client::builder().build(https);
  110. worker(
  111. web_params.target,
  112. web_params.user,
  113. retry,
  114. client,
  115. message_channel,
  116. socket_updater,
  117. error_channel,
  118. )
  119. .await
  120. }
  121. }
  122. async fn worker<C>(
  123. target: String,
  124. user: String,
  125. retry: Duration,
  126. client: hyper::Client<C, hyper::Body>,
  127. message_channel: ReaderToState,
  128. mut socket_updater: ReadSocketUpdaterOut,
  129. error_channel: ErrorChannelIn,
  130. ) where
  131. C: hyper::client::connect::Connect + Clone + Send + Sync + 'static,
  132. {
  133. loop {
  134. let mut message_stream = socket_updater.recv().await;
  135. loop {
  136. let msg = match mgen::get_message(&mut message_stream).await {
  137. Ok(msg) => msg,
  138. Err(e) => {
  139. error_channel.send(e.into()).expect("Error channel closed");
  140. break;
  141. }
  142. };
  143. if msg.body.has_attachment() {
  144. let url = web_url(&target, msg.body.total_size(), &user);
  145. let client = client.clone();
  146. tokio::spawn(async move {
  147. let mut res = client.get(url.clone()).await;
  148. while res.is_err() {
  149. log!("Error fetching: {}", res.unwrap_err());
  150. tokio::time::sleep(retry).await;
  151. res = client.get(url.clone()).await;
  152. }
  153. });
  154. }
  155. message_channel
  156. .send(msg)
  157. .expect("Reader message channel closed");
  158. }
  159. }
  160. }
  161. }
  162. async fn uploader(
  163. web_params: SocksParams,
  164. retry: Duration,
  165. tls_config: tokio_rustls::rustls::ClientConfig,
  166. size_channel: SizeChannelOut,
  167. ) {
  168. let https = hyper_rustls::HttpsConnectorBuilder::new()
  169. .with_tls_config(tls_config)
  170. .https_only()
  171. .enable_http1()
  172. .build();
  173. match web_params.socks {
  174. Some(proxy) => {
  175. let auth = hyper_socks2::Auth {
  176. username: web_params.user.clone(),
  177. password: web_params.recipient,
  178. };
  179. let socks = hyper_socks2::SocksConnector {
  180. proxy_addr: proxy.parse().expect("Invalid proxy URI"),
  181. auth: Some(auth),
  182. connector: https,
  183. };
  184. let client = hyper::Client::builder().build(socks);
  185. worker(
  186. web_params.target,
  187. web_params.user,
  188. retry,
  189. client,
  190. size_channel,
  191. )
  192. .await
  193. }
  194. None => {
  195. let client = hyper::Client::builder().build(https);
  196. worker(
  197. web_params.target,
  198. web_params.user,
  199. retry,
  200. client,
  201. size_channel,
  202. )
  203. .await
  204. }
  205. }
  206. async fn worker<C>(
  207. target: String,
  208. user: String,
  209. retry: Duration,
  210. client: hyper::Client<C, hyper::Body>,
  211. mut size_channel: SizeChannelOut,
  212. ) where
  213. C: hyper::client::connect::Connect + Clone + Send + Sync + 'static,
  214. {
  215. loop {
  216. let size = size_channel.recv().await.expect("Size channel closed");
  217. let client = client.clone();
  218. let url = web_url(&target, size, &user);
  219. let request = hyper::Request::put(url.clone())
  220. .body(hyper::Body::empty())
  221. .expect("Invalid HTTP request attempted to construct");
  222. let mut res = client.request(request).await;
  223. while res.is_err() {
  224. log!("Error uploading: {}", res.unwrap_err());
  225. tokio::time::sleep(retry).await;
  226. res = client.get(url.clone()).await;
  227. }
  228. }
  229. }
  230. }
  231. /// The thread responsible for sending messages from the state thread,
  232. /// and checking for any network errors while doing so.
  233. async fn writer(
  234. mut message_channel: WriterFromState,
  235. attachment_channel: SizeChannelIn,
  236. mut socket_updater: WriteSocketUpdaterOut,
  237. error_channel: ErrorChannelIn,
  238. ) {
  239. loop {
  240. let mut stream = socket_updater.recv().await;
  241. loop {
  242. let msg = message_channel
  243. .recv()
  244. .await
  245. .expect("Writer message channel closed");
  246. if msg.body.has_attachment() {
  247. attachment_channel
  248. .send(msg.body.total_size())
  249. .expect("Attachment channel closed");
  250. }
  251. if let Err(e) = msg.write_all_to(&mut stream).await {
  252. error_channel.send(e.into()).expect("Error channel closed");
  253. break;
  254. }
  255. }
  256. }
  257. }
  258. /// The thread responsible for (re-)establishing connections to the server,
  259. /// and determining how to handle errors this or other threads receive.
  260. async fn socket_updater(
  261. str_params: SocksParams,
  262. retry: Duration,
  263. tls_config: tokio_rustls::rustls::ClientConfig,
  264. mut error_channel: ErrorChannelOut,
  265. reader_channel: ReadSocketUpdaterIn,
  266. writer_channel: WriteSocketUpdaterIn,
  267. ) -> FatalError {
  268. let connector = TlsConnector::from(Arc::new(tls_config));
  269. // unwrap is safe, split always returns at least one element
  270. let tls_server_str = str_params.target.split(':').next().unwrap();
  271. let tls_server_name =
  272. tokio_rustls::rustls::ServerName::try_from(tls_server_str).expect("invalid server name");
  273. loop {
  274. let stream: TcpStream = match connect(&str_params).await {
  275. Ok(stream) => stream,
  276. Err(MessengerError::Recoverable(_)) => {
  277. tokio::time::sleep(retry).await;
  278. continue;
  279. }
  280. Err(MessengerError::Fatal(e)) => return e,
  281. };
  282. let mut stream = match connector.connect(tls_server_name.clone(), stream).await {
  283. Ok(stream) => stream,
  284. Err(_) => {
  285. tokio::time::sleep(retry).await;
  286. continue;
  287. }
  288. };
  289. let handshake = HandshakeRef {
  290. sender: &str_params.user,
  291. group: &str_params.recipient,
  292. };
  293. if stream.write_all(&handshake.serialize()).await.is_err() {
  294. continue;
  295. }
  296. let (rd, wr) = split(stream);
  297. reader_channel.send(rd);
  298. writer_channel.send(wr);
  299. let res = error_channel.recv().await.expect("Error channel closed");
  300. if let MessengerError::Fatal(e) = res {
  301. return e;
  302. }
  303. }
  304. }
  305. /// The thread responsible for handling the conversation state
  306. /// (i.e., whether the user is active or idle, and when to send messages).
  307. /// Spawns all other threads for this conversation.
  308. async fn manage_conversation(config: FullConfig) -> Result<(), MessengerError> {
  309. let mut rng = Xoshiro256PlusPlus::from_entropy();
  310. let message_server_params = SocksParams {
  311. socks: config.socks.clone(),
  312. target: config.message_server,
  313. user: config.user.clone(),
  314. recipient: config.group.clone(),
  315. };
  316. let web_server_params = SocksParams {
  317. socks: config.socks,
  318. target: config.web_server,
  319. user: config.user.clone(),
  320. recipient: config.group.clone(),
  321. };
  322. let mut state_machine = StateMachine::start(config.distributions, &mut rng);
  323. let (reader_to_state, mut state_from_reader) = mpsc::unbounded_channel();
  324. let (state_to_writer, writer_from_state) = mpsc::unbounded_channel();
  325. let read_socket_updater_in = Updater::new();
  326. let read_socket_updater_out = read_socket_updater_in.clone();
  327. let write_socket_updater_in = Updater::new();
  328. let write_socket_updater_out = write_socket_updater_in.clone();
  329. let (errs_in, errs_out) = mpsc::unbounded_channel();
  330. let (writer_to_uploader, uploader_from_writer) = mpsc::unbounded_channel();
  331. let retry = Duration::from_secs_f64(config.retry);
  332. let tls_config = tokio_rustls::rustls::ClientConfig::builder()
  333. .with_safe_defaults()
  334. .with_custom_certificate_verifier(Arc::new(NoCertificateVerification {}))
  335. .with_no_client_auth();
  336. tokio::spawn(reader(
  337. web_server_params.clone(),
  338. retry,
  339. tls_config.clone(),
  340. reader_to_state,
  341. read_socket_updater_out,
  342. errs_in.clone(),
  343. ));
  344. tokio::spawn(writer(
  345. writer_from_state,
  346. writer_to_uploader,
  347. write_socket_updater_out,
  348. errs_in,
  349. ));
  350. tokio::spawn(uploader(
  351. web_server_params,
  352. retry,
  353. tls_config.clone(),
  354. uploader_from_writer,
  355. ));
  356. tokio::spawn(socket_updater(
  357. message_server_params,
  358. retry,
  359. tls_config,
  360. errs_out,
  361. read_socket_updater_in,
  362. write_socket_updater_in,
  363. ));
  364. tokio::time::sleep(Duration::from_secs_f64(config.bootstrap)).await;
  365. let mut state_to_writer = StateToWriter {
  366. channel: state_to_writer,
  367. };
  368. loop {
  369. state_machine = match state_machine {
  370. StateMachine::Idle(conversation) => {
  371. manage_idle_conversation::<false, _, _, _>(
  372. conversation,
  373. &mut state_from_reader,
  374. &mut state_to_writer,
  375. &config.user,
  376. &config.group,
  377. &mut rng,
  378. )
  379. .await
  380. }
  381. StateMachine::Active(conversation) => {
  382. manage_active_conversation(
  383. conversation,
  384. &mut state_from_reader,
  385. &mut state_to_writer,
  386. &config.user,
  387. &config.group,
  388. false,
  389. &mut rng,
  390. )
  391. .await
  392. }
  393. };
  394. }
  395. }
  396. struct FullConfig {
  397. user: String,
  398. group: String,
  399. socks: Option<String>,
  400. message_server: String,
  401. web_server: String,
  402. bootstrap: f64,
  403. retry: f64,
  404. distributions: Distributions,
  405. }
  406. #[derive(Debug, Deserialize)]
  407. struct ConversationConfig {
  408. group: String,
  409. message_server: Option<String>,
  410. web_server: Option<String>,
  411. bootstrap: Option<f64>,
  412. retry: Option<f64>,
  413. distributions: Option<ConfigDistributions>,
  414. }
  415. #[derive(Debug, Deserialize)]
  416. struct Config {
  417. user: String,
  418. socks: Option<String>,
  419. message_server: String,
  420. web_server: String,
  421. bootstrap: f64,
  422. retry: f64,
  423. distributions: ConfigDistributions,
  424. conversations: Vec<ConversationConfig>,
  425. }
  426. #[tokio::main]
  427. async fn main() -> Result<(), Box<dyn std::error::Error>> {
  428. let mut args = std::env::args();
  429. let _ = args.next();
  430. let mut handles = vec![];
  431. for config_file in args.flat_map(|a| glob::glob(a.as_str()).unwrap()) {
  432. let yaml_s = std::fs::read_to_string(config_file?)?;
  433. let config: Config = serde_yaml::from_str(&yaml_s)?;
  434. let default_dists: Distributions = config.distributions.try_into()?;
  435. for conversation in config.conversations.into_iter() {
  436. let distributions: Distributions = match conversation.distributions {
  437. Some(dists) => dists.try_into()?,
  438. None => default_dists.clone(),
  439. };
  440. let filled_conversation = FullConfig {
  441. user: config.user.clone(),
  442. group: conversation.group,
  443. socks: config.socks.clone(),
  444. message_server: conversation
  445. .message_server
  446. .unwrap_or_else(|| config.message_server.clone()),
  447. web_server: conversation
  448. .web_server
  449. .unwrap_or_else(|| config.web_server.clone()),
  450. bootstrap: conversation.bootstrap.unwrap_or(config.bootstrap),
  451. retry: conversation.retry.unwrap_or(config.retry),
  452. distributions,
  453. };
  454. let handle: task::JoinHandle<Result<(), MessengerError>> =
  455. tokio::spawn(manage_conversation(filled_conversation));
  456. handles.push(handle);
  457. }
  458. }
  459. let handles: futures::stream::FuturesUnordered<_> = handles.into_iter().collect();
  460. for handle in handles {
  461. handle.await??;
  462. }
  463. Ok(())
  464. }
  #[cfg(test)]
  mod tests {
      use super::*;

      /// Every case must yield a valid URI; `web_url` panics otherwise.
      #[test]
      fn test_uri_generation() {
          let cases: [(&str, usize, &str); 5] = [
              ("192.0.2.1", 0, "Alice"),
              ("hostname", 65536, "Bob"),
              ("web0", 4294967295, "Carol"),
              ("web1", 1, ""),
              ("foo.bar.baz", 1, "Dave"),
          ];
          for (target, size, user) in cases {
              web_url(target, size, user);
          }
          // Known-unsupported inputs, kept for reference:
          // - a bare IPv6 address is not valid in a URI (it would need brackets):
          //web_url("2001:0db8:85a3:0000:0000:8a2e:0370:7334", 1, "1");
          // - hyper does not automatically convert to punycode:
          //web_url("web2", 1, "🦀");
          //web_url("🦀", 1, "Ferris");
      }
  }