mgen-client.rs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. // Code specific to the client in the client-server mode.
  2. use mgen::updater::Updater;
  3. use mgen::{log, HandshakeRef, MessageHeader, SerializedMessage};
  4. use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256PlusPlus};
  5. use serde::Deserialize;
  6. use std::result::Result;
  7. use std::sync::Arc;
  8. use tokio::io::{split, AsyncWriteExt, ReadHalf, WriteHalf};
  9. use tokio::net::TcpStream;
  10. use tokio::sync::mpsc;
  11. use tokio::task;
  12. use tokio::time::Duration;
  13. use tokio_rustls::{client::TlsStream, TlsConnector};
  14. mod messenger;
  15. use crate::messenger::dists::{ConfigDistributions, Distributions};
  16. use crate::messenger::error::{FatalError, MessengerError};
  17. use crate::messenger::state::{
  18. manage_active_conversation, manage_idle_conversation, StateMachine, StateToWriter,
  19. };
  20. use crate::messenger::tcp::{connect, SocksParams};
  21. /// Type for sending messages from the reader thread to the state thread.
  22. type ReaderToState = mpsc::UnboundedSender<MessageHeader>;
  23. /// Type of messages sent to the writer thread.
  24. type MessageHolder = Box<SerializedMessage>;
  25. /// Type for getting messages from the state thread in the writer thread.
  26. type WriterFromState = mpsc::UnboundedReceiver<MessageHolder>;
  27. /// Type for sending the updated read half of the socket.
  28. type ReadSocketUpdaterIn = Updater<ReadHalf<TlsStream<TcpStream>>>;
  29. /// Type for getting the updated read half of the socket.
  30. type ReadSocketUpdaterOut = Updater<ReadHalf<TlsStream<TcpStream>>>;
  31. /// Type for sending the updated write half of the socket.
  32. type WriteSocketUpdaterIn = Updater<WriteHalf<TlsStream<TcpStream>>>;
  33. /// Type for getting the updated write half of the socket.
  34. type WriteSocketUpdaterOut = Updater<WriteHalf<TlsStream<TcpStream>>>;
  35. /// Type for sending errors to other threads.
  36. type ErrorChannelIn = mpsc::UnboundedSender<MessengerError>;
  37. /// Type for getting errors from other threads.
  38. type ErrorChannelOut = mpsc::UnboundedReceiver<MessengerError>;
  39. /// Type for sending sizes to the attachment sender thread.
  40. type SizeChannelIn = mpsc::UnboundedSender<usize>;
  41. /// Type for getting sizes from other threads.
  42. type SizeChannelOut = mpsc::UnboundedReceiver<usize>;
  43. // we gain a (very) tiny performance win by not bothering to validate the cert
  44. struct NoCertificateVerification {}
  45. impl tokio_rustls::rustls::client::ServerCertVerifier for NoCertificateVerification {
  46. fn verify_server_cert(
  47. &self,
  48. _end_entity: &tokio_rustls::rustls::Certificate,
  49. _intermediates: &[tokio_rustls::rustls::Certificate],
  50. _server_name: &tokio_rustls::rustls::ServerName,
  51. _scts: &mut dyn Iterator<Item = &[u8]>,
  52. _ocsp: &[u8],
  53. _now: std::time::SystemTime,
  54. ) -> Result<tokio_rustls::rustls::client::ServerCertVerified, tokio_rustls::rustls::Error> {
  55. Ok(tokio_rustls::rustls::client::ServerCertVerified::assertion())
  56. }
  57. }
  58. /// The thread responsible for getting incoming messages,
  59. /// checking for any network errors while doing so,
  60. /// and giving messages to the state thread.
  61. async fn reader(
  62. web_params: SocksParams,
  63. retry: Duration,
  64. tls_config: tokio_rustls::rustls::ClientConfig,
  65. message_channel: ReaderToState,
  66. socket_updater: ReadSocketUpdaterOut,
  67. error_channel: ErrorChannelIn,
  68. ) {
  69. let https = hyper_rustls::HttpsConnectorBuilder::new()
  70. .with_tls_config(tls_config)
  71. .https_only()
  72. .enable_http1()
  73. .build();
  74. match web_params.socks {
  75. Some(proxy) => {
  76. let auth = hyper_socks2::Auth {
  77. username: web_params.user.clone(),
  78. password: web_params.recipient,
  79. };
  80. let socks = hyper_socks2::SocksConnector {
  81. proxy_addr: proxy.parse().expect("Invalid proxy URI"),
  82. auth: Some(auth),
  83. connector: https,
  84. };
  85. let client: hyper::Client<_, hyper::Body> = hyper::Client::builder().build(socks);
  86. worker(
  87. web_params.target,
  88. web_params.user,
  89. retry,
  90. client,
  91. message_channel,
  92. socket_updater,
  93. error_channel,
  94. )
  95. .await
  96. }
  97. None => {
  98. let client = hyper::Client::builder().build(https);
  99. worker(
  100. web_params.target,
  101. web_params.user,
  102. retry,
  103. client,
  104. message_channel,
  105. socket_updater,
  106. error_channel,
  107. )
  108. .await
  109. }
  110. }
  111. async fn worker<C>(
  112. target: String,
  113. user: String,
  114. retry: Duration,
  115. client: hyper::Client<C, hyper::Body>,
  116. message_channel: ReaderToState,
  117. mut socket_updater: ReadSocketUpdaterOut,
  118. error_channel: ErrorChannelIn,
  119. ) where
  120. C: hyper::client::connect::Connect + Clone + Send + Sync + 'static,
  121. {
  122. loop {
  123. let mut message_stream = socket_updater.recv().await;
  124. loop {
  125. let msg = match mgen::get_message(&mut message_stream).await {
  126. Ok(msg) => msg,
  127. Err(e) => {
  128. error_channel.send(e.into()).expect("Error channel closed");
  129. break;
  130. }
  131. };
  132. if msg.body.has_attachment() {
  133. let url: hyper::Uri =
  134. format!("{}/?size={}&user={}", target, msg.body.total_size(), user)
  135. .parse()
  136. .expect("Invalid URI");
  137. let client = client.clone();
  138. tokio::spawn(async move {
  139. let mut res = client.get(url.clone()).await;
  140. while res.is_err() {
  141. log!("Error fetching: {}", res.unwrap_err());
  142. tokio::time::sleep(retry).await;
  143. res = client.get(url.clone()).await;
  144. }
  145. });
  146. }
  147. message_channel
  148. .send(msg)
  149. .expect("Reader message channel closed");
  150. }
  151. }
  152. }
  153. }
  154. async fn uploader(
  155. web_params: SocksParams,
  156. retry: Duration,
  157. tls_config: tokio_rustls::rustls::ClientConfig,
  158. size_channel: SizeChannelOut,
  159. ) {
  160. let https = hyper_rustls::HttpsConnectorBuilder::new()
  161. .with_tls_config(tls_config)
  162. .https_only()
  163. .enable_http1()
  164. .build();
  165. match web_params.socks {
  166. Some(proxy) => {
  167. let auth = hyper_socks2::Auth {
  168. username: web_params.user.clone(),
  169. password: web_params.recipient,
  170. };
  171. let socks = hyper_socks2::SocksConnector {
  172. proxy_addr: proxy.parse().expect("Invalid proxy URI"),
  173. auth: Some(auth),
  174. connector: https,
  175. };
  176. let client = hyper::Client::builder().build(socks);
  177. worker(
  178. web_params.target,
  179. web_params.user,
  180. retry,
  181. client,
  182. size_channel,
  183. )
  184. .await
  185. }
  186. None => {
  187. let client = hyper::Client::builder().build(https);
  188. worker(
  189. web_params.target,
  190. web_params.user,
  191. retry,
  192. client,
  193. size_channel,
  194. )
  195. .await
  196. }
  197. }
  198. async fn worker<C>(
  199. target: String,
  200. user: String,
  201. retry: Duration,
  202. client: hyper::Client<C, hyper::Body>,
  203. mut size_channel: SizeChannelOut,
  204. ) where
  205. C: hyper::client::connect::Connect + Clone + Send + Sync + 'static,
  206. {
  207. loop {
  208. let size = size_channel.recv().await.expect("Size channel closed");
  209. let client = client.clone();
  210. let url: hyper::Uri = format!("{}/?size={}&user={}", target, size, user)
  211. .parse()
  212. .expect("Invalid URI");
  213. let request = hyper::Request::put(url.clone())
  214. .body(hyper::Body::empty())
  215. .expect("Invalid HTTP request attempted to construct");
  216. let mut res = client.request(request).await;
  217. while res.is_err() {
  218. log!("Error uploading: {}", res.unwrap_err());
  219. tokio::time::sleep(retry).await;
  220. res = client.get(url.clone()).await;
  221. }
  222. }
  223. }
  224. }
  225. /// The thread responsible for sending messages from the state thread,
  226. /// and checking for any network errors while doing so.
  227. async fn writer(
  228. mut message_channel: WriterFromState,
  229. attachment_channel: SizeChannelIn,
  230. mut socket_updater: WriteSocketUpdaterOut,
  231. error_channel: ErrorChannelIn,
  232. ) {
  233. loop {
  234. let mut stream = socket_updater.recv().await;
  235. loop {
  236. let msg = message_channel
  237. .recv()
  238. .await
  239. .expect("Writer message channel closed");
  240. if msg.body.has_attachment() {
  241. attachment_channel
  242. .send(msg.body.total_size())
  243. .expect("Attachment channel closed");
  244. }
  245. if let Err(e) = msg.write_all_to(&mut stream).await {
  246. error_channel.send(e.into()).expect("Error channel closed");
  247. break;
  248. }
  249. }
  250. }
  251. }
  252. /// The thread responsible for (re-)establishing connections to the server,
  253. /// and determining how to handle errors this or other threads receive.
  254. async fn socket_updater(
  255. str_params: SocksParams,
  256. retry: Duration,
  257. tls_config: tokio_rustls::rustls::ClientConfig,
  258. mut error_channel: ErrorChannelOut,
  259. reader_channel: ReadSocketUpdaterIn,
  260. writer_channel: WriteSocketUpdaterIn,
  261. ) -> FatalError {
  262. let connector = TlsConnector::from(Arc::new(tls_config));
  263. // unwrap is safe, split always returns at least one element
  264. let tls_server_str = str_params.target.split(':').next().unwrap();
  265. let tls_server_name =
  266. tokio_rustls::rustls::ServerName::try_from(tls_server_str).expect("invalid server name");
  267. loop {
  268. let stream: TcpStream = match connect(&str_params).await {
  269. Ok(stream) => stream,
  270. Err(MessengerError::Recoverable(_)) => {
  271. tokio::time::sleep(retry).await;
  272. continue;
  273. }
  274. Err(MessengerError::Fatal(e)) => return e,
  275. };
  276. let mut stream = match connector.connect(tls_server_name.clone(), stream).await {
  277. Ok(stream) => stream,
  278. Err(_) => {
  279. tokio::time::sleep(retry).await;
  280. continue;
  281. }
  282. };
  283. let handshake = HandshakeRef {
  284. sender: &str_params.user,
  285. group: &str_params.recipient,
  286. };
  287. if stream.write_all(&handshake.serialize()).await.is_err() {
  288. continue;
  289. }
  290. let (rd, wr) = split(stream);
  291. reader_channel.send(rd);
  292. writer_channel.send(wr);
  293. let res = error_channel.recv().await.expect("Error channel closed");
  294. if let MessengerError::Fatal(e) = res {
  295. return e;
  296. }
  297. }
  298. }
  299. /// The thread responsible for handling the conversation state
  300. /// (i.e., whether the user is active or idle, and when to send messages).
  301. /// Spawns all other threads for this conversation.
  302. async fn manage_conversation(config: FullConfig) -> Result<(), MessengerError> {
  303. let mut rng = Xoshiro256PlusPlus::from_entropy();
  304. let distributions: Distributions = config.distributions.try_into()?;
  305. let message_server_params = SocksParams {
  306. socks: config.socks.clone(),
  307. target: config.message_server,
  308. user: config.user.clone(),
  309. recipient: config.group.clone(),
  310. };
  311. let web_server_params = SocksParams {
  312. socks: config.socks,
  313. target: config.web_server,
  314. user: config.user.clone(),
  315. recipient: config.group.clone(),
  316. };
  317. let mut state_machine = StateMachine::start(distributions, &mut rng);
  318. let (reader_to_state, mut state_from_reader) = mpsc::unbounded_channel();
  319. let (state_to_writer, writer_from_state) = mpsc::unbounded_channel();
  320. let read_socket_updater_in = Updater::new();
  321. let read_socket_updater_out = read_socket_updater_in.clone();
  322. let write_socket_updater_in = Updater::new();
  323. let write_socket_updater_out = write_socket_updater_in.clone();
  324. let (errs_in, errs_out) = mpsc::unbounded_channel();
  325. let (writer_to_uploader, uploader_from_writer) = mpsc::unbounded_channel();
  326. let retry = Duration::from_secs_f64(config.retry);
  327. let tls_config = tokio_rustls::rustls::ClientConfig::builder()
  328. .with_safe_defaults()
  329. .with_custom_certificate_verifier(Arc::new(NoCertificateVerification {}))
  330. .with_no_client_auth();
  331. tokio::spawn(reader(
  332. web_server_params.clone(),
  333. retry,
  334. tls_config.clone(),
  335. reader_to_state,
  336. read_socket_updater_out,
  337. errs_in.clone(),
  338. ));
  339. tokio::spawn(writer(
  340. writer_from_state,
  341. writer_to_uploader,
  342. write_socket_updater_out,
  343. errs_in,
  344. ));
  345. tokio::spawn(uploader(
  346. web_server_params,
  347. retry,
  348. tls_config.clone(),
  349. uploader_from_writer,
  350. ));
  351. tokio::spawn(socket_updater(
  352. message_server_params,
  353. retry,
  354. tls_config,
  355. errs_out,
  356. read_socket_updater_in,
  357. write_socket_updater_in,
  358. ));
  359. tokio::time::sleep(Duration::from_secs_f64(config.bootstrap)).await;
  360. let mut state_to_writer = StateToWriter {
  361. channel: state_to_writer,
  362. };
  363. loop {
  364. state_machine = match state_machine {
  365. StateMachine::Idle(conversation) => {
  366. manage_idle_conversation::<false, _, _, _>(
  367. conversation,
  368. &mut state_from_reader,
  369. &mut state_to_writer,
  370. &config.user,
  371. &config.group,
  372. &mut rng,
  373. )
  374. .await
  375. }
  376. StateMachine::Active(conversation) => {
  377. manage_active_conversation(
  378. conversation,
  379. &mut state_from_reader,
  380. &mut state_to_writer,
  381. &config.user,
  382. &config.group,
  383. false,
  384. &mut rng,
  385. )
  386. .await
  387. }
  388. };
  389. }
  390. }
  391. struct FullConfig {
  392. user: String,
  393. group: String,
  394. socks: Option<String>,
  395. message_server: String,
  396. web_server: String,
  397. bootstrap: f64,
  398. retry: f64,
  399. distributions: ConfigDistributions,
  400. }
  401. #[derive(Debug, Deserialize)]
  402. struct ConversationConfig {
  403. group: String,
  404. message_server: Option<String>,
  405. web_server: Option<String>,
  406. bootstrap: Option<f64>,
  407. retry: Option<f64>,
  408. distributions: Option<ConfigDistributions>,
  409. }
  410. #[derive(Debug, Deserialize)]
  411. struct Config {
  412. user: String,
  413. socks: Option<String>,
  414. message_server: String,
  415. web_server: String,
  416. bootstrap: f64,
  417. retry: f64,
  418. distributions: ConfigDistributions,
  419. conversations: Vec<ConversationConfig>,
  420. }
  421. #[tokio::main]
  422. async fn main() -> Result<(), Box<dyn std::error::Error>> {
  423. let mut args = std::env::args();
  424. let _ = args.next();
  425. let mut handles = vec![];
  426. for config_file in args.flat_map(|a| glob::glob(a.as_str()).unwrap()) {
  427. let yaml_s = std::fs::read_to_string(config_file?)?;
  428. let config: Config = serde_yaml::from_str(&yaml_s)?;
  429. for conversation in config.conversations.into_iter() {
  430. let filled_conversation = FullConfig {
  431. user: config.user.clone(),
  432. group: conversation.group,
  433. socks: config.socks.clone(),
  434. message_server: conversation
  435. .message_server
  436. .unwrap_or_else(|| config.message_server.clone()),
  437. web_server: conversation
  438. .web_server
  439. .unwrap_or_else(|| config.web_server.clone()),
  440. bootstrap: conversation.bootstrap.unwrap_or(config.bootstrap),
  441. retry: conversation.retry.unwrap_or(config.retry),
  442. distributions: conversation
  443. .distributions
  444. .unwrap_or_else(|| config.distributions.clone()),
  445. };
  446. let handle: task::JoinHandle<Result<(), MessengerError>> =
  447. tokio::spawn(manage_conversation(filled_conversation));
  448. handles.push(handle);
  449. }
  450. }
  451. let handles: futures::stream::FuturesUnordered<_> = handles.into_iter().collect();
  452. for handle in handles {
  453. handle.await??;
  454. }
  455. Ok(())
  456. }