mgen-client.rs 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602
  1. // Code specific to the client in the client-server mode.
  2. use mgen::{log, updater::Updater, HandshakeRef, MessageHeader, SerializedMessage};
  3. use futures::future::try_join_all;
  4. use hyper::client::HttpConnector;
  5. use hyper_rustls::HttpsConnector;
  6. use hyper_socks2::SocksConnector;
  7. use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256PlusPlus};
  8. use serde::Deserialize;
  9. use std::hash::{Hash, Hasher};
  10. use std::result::Result;
  11. use std::sync::Arc;
  12. use tokio::io::{split, AsyncWriteExt, ReadHalf, WriteHalf};
  13. use tokio::net::TcpStream;
  14. use tokio::spawn;
  15. use tokio::sync::mpsc;
  16. use tokio::time::{sleep, Duration};
  17. use tokio_rustls::{client::TlsStream, TlsConnector};
  18. mod messenger;
  19. use crate::messenger::dists::{ConfigDistributions, Distributions};
  20. use crate::messenger::error::{FatalError, MessengerError};
  21. use crate::messenger::state::{
  22. manage_active_conversation, manage_idle_conversation, StateFromReader, StateMachine,
  23. StateToWriter,
  24. };
  25. use crate::messenger::tcp::{connect, SocksParams};
  26. /// Type for sending messages from the reader thread to the state thread.
  27. type ReaderToState = mpsc::UnboundedSender<MessageHeader>;
  28. /// Type of messages sent to the writer thread.
  29. type MessageHolder = Box<SerializedMessage>;
  30. /// Type for getting messages from the state thread in the writer thread.
  31. type WriterFromState = mpsc::UnboundedReceiver<MessageHolder>;
  32. /// Type for sending the updated read half of the socket.
  33. type ReadSocketUpdaterIn = Updater<ReadHalf<TlsStream<TcpStream>>>;
  34. /// Type for getting the updated read half of the socket.
  35. type ReadSocketUpdaterOut = Updater<ReadHalf<TlsStream<TcpStream>>>;
  36. /// Type for sending the updated write half of the socket.
  37. type WriteSocketUpdaterIn = Updater<WriteHalf<TlsStream<TcpStream>>>;
  38. /// Type for getting the updated write half of the socket.
  39. type WriteSocketUpdaterOut = Updater<WriteHalf<TlsStream<TcpStream>>>;
  40. /// Type for sending errors to other threads.
  41. type ErrorChannelIn = mpsc::UnboundedSender<MessengerError>;
  42. /// Type for getting errors from other threads.
  43. type ErrorChannelOut = mpsc::UnboundedReceiver<MessengerError>;
  44. /// Type for sending sizes to the attachment sender thread.
  45. type SizeChannelIn = mpsc::UnboundedSender<usize>;
  46. /// Type for getting sizes from other threads.
  47. type SizeChannelOut = mpsc::UnboundedReceiver<usize>;
  48. // we gain a (very) tiny performance win by not bothering to validate the cert
  49. struct NoCertificateVerification {}
  50. impl tokio_rustls::rustls::client::ServerCertVerifier for NoCertificateVerification {
  51. fn verify_server_cert(
  52. &self,
  53. _end_entity: &tokio_rustls::rustls::Certificate,
  54. _intermediates: &[tokio_rustls::rustls::Certificate],
  55. _server_name: &tokio_rustls::rustls::ServerName,
  56. _scts: &mut dyn Iterator<Item = &[u8]>,
  57. _ocsp: &[u8],
  58. _now: std::time::SystemTime,
  59. ) -> Result<tokio_rustls::rustls::client::ServerCertVerified, tokio_rustls::rustls::Error> {
  60. Ok(tokio_rustls::rustls::client::ServerCertVerified::assertion())
  61. }
  62. }
  63. /// Create a URL the web server can use to accept or produce traffic.
  64. /// `target` is the IP or host name of the web server,
  65. /// `size` is the number of bytes to download or upload,
  66. /// `user` is to let the server log the user making the request.
  67. /// Panics if the arguments do not produce a valid URI.
  68. fn web_url(target: &str, size: usize, user: &str) -> hyper::Uri {
  69. let formatted = format!("https://{}/?size={}&user={}", target, size, user);
  70. formatted
  71. .parse()
  72. .unwrap_or_else(|_| panic!("Invalid URI: {}", formatted))
  73. }
  74. fn get_plain_https_client(
  75. tls_config: tokio_rustls::rustls::ClientConfig,
  76. ) -> hyper::client::Client<HttpsConnector<HttpConnector>> {
  77. let https = hyper_rustls::HttpsConnectorBuilder::new()
  78. .with_tls_config(tls_config)
  79. .https_or_http()
  80. .enable_http1()
  81. .build();
  82. hyper::Client::builder().build(https)
  83. }
  84. fn get_socks_https_client(
  85. tls_config: tokio_rustls::rustls::ClientConfig,
  86. username: String,
  87. password: String,
  88. proxy: String,
  89. ) -> hyper::client::Client<HttpsConnector<SocksConnector<HttpConnector>>> {
  90. let mut http = hyper::client::HttpConnector::new();
  91. http.enforce_http(false);
  92. let auth = hyper_socks2::Auth { username, password };
  93. let socks = hyper_socks2::SocksConnector {
  94. proxy_addr: format!("socks5://{}", proxy)
  95. .parse()
  96. .expect("Invalid proxy URI"),
  97. auth: Some(auth),
  98. connector: http,
  99. };
  100. let https = hyper_rustls::HttpsConnectorBuilder::new()
  101. .with_tls_config(tls_config)
  102. .https_or_http()
  103. .enable_http1()
  104. .wrap_connector(socks);
  105. hyper::Client::builder().build(https)
  106. }
  107. /// The thread responsible for getting incoming messages,
  108. /// checking for any network errors while doing so,
  109. /// and giving messages to the state thread.
  110. async fn reader(
  111. web_params: SocksParams,
  112. retry: Duration,
  113. tls_config: tokio_rustls::rustls::ClientConfig,
  114. message_channel: ReaderToState,
  115. socket_updater: ReadSocketUpdaterOut,
  116. error_channel: ErrorChannelIn,
  117. ) {
  118. match web_params.socks {
  119. Some(proxy) => {
  120. let client = get_socks_https_client(
  121. tls_config,
  122. web_params.user.clone(),
  123. web_params.target.clone(),
  124. proxy,
  125. );
  126. worker(
  127. web_params.target,
  128. web_params.user,
  129. retry,
  130. client,
  131. message_channel,
  132. socket_updater,
  133. error_channel,
  134. )
  135. .await
  136. }
  137. None => {
  138. let client = get_plain_https_client(tls_config);
  139. worker(
  140. web_params.target,
  141. web_params.user,
  142. retry,
  143. client,
  144. message_channel,
  145. socket_updater,
  146. error_channel,
  147. )
  148. .await
  149. }
  150. };
  151. async fn worker<C>(
  152. target: String,
  153. user: String,
  154. retry: Duration,
  155. client: hyper::Client<C, hyper::Body>,
  156. message_channel: ReaderToState,
  157. mut socket_updater: ReadSocketUpdaterOut,
  158. error_channel: ErrorChannelIn,
  159. ) where
  160. C: hyper::client::connect::Connect + Clone + Send + Sync + 'static,
  161. {
  162. loop {
  163. let mut message_stream = socket_updater.recv().await;
  164. loop {
  165. let msg = match mgen::get_message(&mut message_stream).await {
  166. Ok(msg) => msg,
  167. Err(e) => {
  168. error_channel.send(e.into()).expect("Error channel closed");
  169. break;
  170. }
  171. };
  172. if msg.body.has_attachment() {
  173. let url = web_url(&target, msg.body.total_size(), &user);
  174. let client = client.clone();
  175. spawn(async move {
  176. let mut res = client.get(url.clone()).await;
  177. while res.is_err() {
  178. log!("Error fetching: {}", res.unwrap_err());
  179. sleep(retry).await;
  180. res = client.get(url.clone()).await;
  181. }
  182. });
  183. }
  184. message_channel
  185. .send(msg)
  186. .expect("Reader message channel closed");
  187. }
  188. }
  189. }
  190. }
  191. async fn uploader(
  192. web_params: SocksParams,
  193. retry: Duration,
  194. tls_config: tokio_rustls::rustls::ClientConfig,
  195. size_channel: SizeChannelOut,
  196. ) {
  197. match web_params.socks {
  198. Some(proxy) => {
  199. let client = get_socks_https_client(
  200. tls_config,
  201. web_params.user.clone(),
  202. web_params.target.clone(),
  203. proxy,
  204. );
  205. worker(
  206. web_params.target,
  207. web_params.user,
  208. retry,
  209. client,
  210. size_channel,
  211. )
  212. .await
  213. }
  214. None => {
  215. let client = get_plain_https_client(tls_config);
  216. worker(
  217. web_params.target,
  218. web_params.user,
  219. retry,
  220. client,
  221. size_channel,
  222. )
  223. .await
  224. }
  225. }
  226. async fn worker<C>(
  227. target: String,
  228. user: String,
  229. retry: Duration,
  230. client: hyper::Client<C, hyper::Body>,
  231. mut size_channel: SizeChannelOut,
  232. ) where
  233. C: hyper::client::connect::Connect + Clone + Send + Sync + 'static,
  234. {
  235. loop {
  236. let size = size_channel.recv().await.expect("Size channel closed");
  237. let client = client.clone();
  238. let url = web_url(&target, size, &user);
  239. let request = hyper::Request::put(url.clone())
  240. .body(hyper::Body::empty())
  241. .expect("Invalid HTTP request attempted to construct");
  242. let mut res = client.request(request).await;
  243. while res.is_err() {
  244. log!("{},{},Error uploading: {}", user, url, res.unwrap_err());
  245. sleep(retry).await;
  246. res = client.get(url.clone()).await;
  247. }
  248. }
  249. }
  250. }
  251. /// The thread responsible for sending messages from the state thread,
  252. /// and checking for any network errors while doing so.
  253. async fn writer(
  254. mut message_channel: WriterFromState,
  255. attachment_channel: SizeChannelIn,
  256. mut socket_updater: WriteSocketUpdaterOut,
  257. error_channel: ErrorChannelIn,
  258. ) {
  259. loop {
  260. let mut stream = socket_updater.recv().await;
  261. loop {
  262. let msg = message_channel
  263. .recv()
  264. .await
  265. .expect("Writer message channel closed");
  266. if msg.body.has_attachment() {
  267. attachment_channel
  268. .send(msg.body.total_size())
  269. .expect("Attachment channel closed");
  270. }
  271. if let Err(e) = msg.write_all_to(&mut stream).await {
  272. error_channel.send(e.into()).expect("Error channel closed");
  273. break;
  274. }
  275. }
  276. }
  277. }
  278. /// The thread responsible for (re-)establishing connections to the server,
  279. /// and determining how to handle errors this or other threads receive.
  280. async fn socket_updater(
  281. str_params: SocksParams,
  282. retry: Duration,
  283. tls_config: tokio_rustls::rustls::ClientConfig,
  284. mut error_channel: ErrorChannelOut,
  285. reader_channel: ReadSocketUpdaterIn,
  286. writer_channel: WriteSocketUpdaterIn,
  287. ) -> FatalError {
  288. let connector = TlsConnector::from(Arc::new(tls_config));
  289. // unwrap is safe, split always returns at least one element
  290. let tls_server_str = str_params.target.split(':').next().unwrap();
  291. let tls_server_name =
  292. tokio_rustls::rustls::ServerName::try_from(tls_server_str).expect("invalid server name");
  293. loop {
  294. let stream: TcpStream = match connect(&str_params).await {
  295. Ok(stream) => stream,
  296. Err(MessengerError::Recoverable(e)) => {
  297. log!(
  298. "{},{},error,TCP,{:?}",
  299. str_params.user,
  300. str_params.recipient,
  301. e
  302. );
  303. sleep(retry).await;
  304. continue;
  305. }
  306. Err(MessengerError::Fatal(e)) => return e,
  307. };
  308. let mut stream = match connector.connect(tls_server_name.clone(), stream).await {
  309. Ok(stream) => stream,
  310. Err(e) => {
  311. log!(
  312. "{},{},error,TLS,{:?}",
  313. str_params.user,
  314. str_params.recipient,
  315. e
  316. );
  317. sleep(retry).await;
  318. continue;
  319. }
  320. };
  321. let handshake = HandshakeRef {
  322. sender: &str_params.user,
  323. group: &str_params.recipient,
  324. };
  325. if stream.write_all(&handshake.serialize()).await.is_err() {
  326. continue;
  327. }
  328. log!("{},{},handshake", str_params.user, str_params.recipient);
  329. let (rd, wr) = split(stream);
  330. reader_channel.send(rd);
  331. writer_channel.send(wr);
  332. let res = error_channel.recv().await.expect("Error channel closed");
  333. if let MessengerError::Fatal(e) = res {
  334. return e;
  335. } else {
  336. log!(
  337. "{},{},error,{:?}",
  338. str_params.user,
  339. str_params.recipient,
  340. res
  341. );
  342. }
  343. }
  344. }
  345. /// The thread responsible for handling the conversation state
  346. /// (i.e., whether the user is active or idle, and when to send messages).
  347. async fn manage_conversation(
  348. config: FullConfig,
  349. mut state_from_reader: StateFromReader,
  350. mut state_to_writer: StateToWriter<MessageHolder>,
  351. ) {
  352. sleep(Duration::from_secs_f64(config.bootstrap)).await;
  353. log!("{},{},awake", &config.user, &config.group);
  354. let mut rng = Xoshiro256PlusPlus::from_entropy();
  355. let mut state_machine = StateMachine::start(config.distributions, &mut rng);
  356. loop {
  357. state_machine = match state_machine {
  358. StateMachine::Idle(conversation) => {
  359. manage_idle_conversation::<false, _, _, _>(
  360. conversation,
  361. &mut state_from_reader,
  362. &mut state_to_writer,
  363. &config.user,
  364. &config.group,
  365. &mut rng,
  366. )
  367. .await
  368. }
  369. StateMachine::Active(conversation) => {
  370. manage_active_conversation(
  371. conversation,
  372. &mut state_from_reader,
  373. &mut state_to_writer,
  374. &config.user,
  375. &config.group,
  376. false,
  377. &mut rng,
  378. )
  379. .await
  380. }
  381. };
  382. }
  383. }
  384. /// Spawns all other threads for this conversation.
  385. async fn spawn_threads(config: FullConfig) -> Result<(), MessengerError> {
  386. // without noise during Shadow's bootstrap period, we can overload the SOMAXCONN of the server,
  387. // so we wait a small(ish) pseudorandom amount of time to spread things out
  388. let mut hasher = rustc_hash::FxHasher::default();
  389. config.user.hash(&mut hasher);
  390. config.group.hash(&mut hasher);
  391. let hash = hasher.finish() % 10_000;
  392. log!("{},{},waiting,{}", config.user, config.group, hash);
  393. sleep(Duration::from_millis(hash)).await;
  394. let message_server_params = SocksParams {
  395. socks: config.socks.clone(),
  396. target: config.message_server.clone(),
  397. user: config.user.clone(),
  398. recipient: config.group.clone(),
  399. };
  400. let web_server_params = SocksParams {
  401. socks: config.socks.clone(),
  402. target: config.web_server.clone(),
  403. user: config.user.clone(),
  404. recipient: config.group.clone(),
  405. };
  406. let (reader_to_state, state_from_reader) = mpsc::unbounded_channel();
  407. let (state_to_writer, writer_from_state) = mpsc::unbounded_channel();
  408. let read_socket_updater_in = Updater::new();
  409. let read_socket_updater_out = read_socket_updater_in.clone();
  410. let write_socket_updater_in = Updater::new();
  411. let write_socket_updater_out = write_socket_updater_in.clone();
  412. let (errs_in, errs_out) = mpsc::unbounded_channel();
  413. let (writer_to_uploader, uploader_from_writer) = mpsc::unbounded_channel();
  414. let state_to_writer = StateToWriter {
  415. channel: state_to_writer,
  416. };
  417. let retry = Duration::from_secs_f64(config.retry);
  418. let tls_config = tokio_rustls::rustls::ClientConfig::builder()
  419. .with_safe_defaults()
  420. .with_custom_certificate_verifier(Arc::new(NoCertificateVerification {}))
  421. .with_no_client_auth();
  422. spawn(reader(
  423. web_server_params.clone(),
  424. retry,
  425. tls_config.clone(),
  426. reader_to_state,
  427. read_socket_updater_out,
  428. errs_in.clone(),
  429. ));
  430. spawn(writer(
  431. writer_from_state,
  432. writer_to_uploader,
  433. write_socket_updater_out,
  434. errs_in,
  435. ));
  436. spawn(uploader(
  437. web_server_params,
  438. retry,
  439. tls_config.clone(),
  440. uploader_from_writer,
  441. ));
  442. spawn(manage_conversation(
  443. config,
  444. state_from_reader,
  445. state_to_writer,
  446. ));
  447. Err(MessengerError::Fatal(
  448. socket_updater(
  449. message_server_params,
  450. retry,
  451. tls_config,
  452. errs_out,
  453. read_socket_updater_in,
  454. write_socket_updater_in,
  455. )
  456. .await,
  457. ))
  458. }
  459. struct FullConfig {
  460. user: String,
  461. group: String,
  462. socks: Option<String>,
  463. message_server: String,
  464. web_server: String,
  465. bootstrap: f64,
  466. retry: f64,
  467. distributions: Distributions,
  468. }
  469. #[derive(Debug, Deserialize)]
  470. struct ConversationConfig {
  471. group: String,
  472. message_server: Option<String>,
  473. web_server: Option<String>,
  474. bootstrap: Option<f64>,
  475. retry: Option<f64>,
  476. distributions: Option<ConfigDistributions>,
  477. }
  478. #[derive(Debug, Deserialize)]
  479. struct Config {
  480. user: String,
  481. socks: Option<String>,
  482. message_server: String,
  483. web_server: String,
  484. bootstrap: f64,
  485. retry: f64,
  486. distributions: ConfigDistributions,
  487. conversations: Vec<ConversationConfig>,
  488. }
  489. async fn main_worker() -> Result<(), Box<dyn std::error::Error>> {
  490. #[cfg(feature = "tracing")]
  491. console_subscriber::init();
  492. let mut args = std::env::args();
  493. let _ = args.next();
  494. let mut handles = vec![];
  495. for config_file in args.flat_map(|a| glob::glob(a.as_str()).unwrap()) {
  496. let yaml_s = std::fs::read_to_string(config_file?)?;
  497. let config: Config = serde_yaml::from_str(&yaml_s)?;
  498. let default_dists: Distributions = config.distributions.try_into()?;
  499. for conversation in config.conversations.into_iter() {
  500. let distributions: Distributions = match conversation.distributions {
  501. Some(dists) => dists.try_into()?,
  502. None => default_dists.clone(),
  503. };
  504. let filled_conversation = FullConfig {
  505. user: config.user.clone(),
  506. group: conversation.group,
  507. socks: config.socks.clone(),
  508. message_server: conversation
  509. .message_server
  510. .unwrap_or_else(|| config.message_server.clone()),
  511. web_server: conversation
  512. .web_server
  513. .unwrap_or_else(|| config.web_server.clone()),
  514. bootstrap: conversation.bootstrap.unwrap_or(config.bootstrap),
  515. retry: conversation.retry.unwrap_or(config.retry),
  516. distributions,
  517. };
  518. let handle = spawn_threads(filled_conversation);
  519. handles.push(handle);
  520. }
  521. }
  522. try_join_all(handles).await?;
  523. Ok(())
  524. }
  525. fn main() -> Result<(), Box<dyn std::error::Error>> {
  526. tokio::runtime::Builder::new_multi_thread()
  527. .worker_threads(2)
  528. .enable_all()
  529. .disable_lifo_slot()
  530. .build()
  531. .unwrap()
  532. .block_on(main_worker())
  533. }
  534. #[cfg(test)]
  535. mod tests {
  536. use super::*;
  537. #[test]
  538. fn test_uri_generation() {
  539. // should panic if any of these are invalid
  540. web_url("192.0.2.1", 0, "Alice");
  541. web_url("hostname", 65536, "Bob");
  542. web_url("web0", 4294967295, "Carol");
  543. web_url("web1", 1, "");
  544. web_url("foo.bar.baz", 1, "Dave");
  545. // IPv6 is not a valid in a URI
  546. //web_url("2001:0db8:85a3:0000:0000:8a2e:0370:7334", 1, "1");
  547. // hyper does not automatically convert to punycode
  548. //web_url("web2", 1, "🦀");
  549. //web_url("🦀", 1, "Ferris");
  550. }
  551. }