mgen-client.rs 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. // Code specific to the client in the client-server mode.
  2. use mgen::updater::Updater;
  3. use mgen::{HandshakeRef, MessageHeader, SerializedMessage};
  4. use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256PlusPlus};
  5. use serde::Deserialize;
  6. use std::result::Result;
  7. use std::sync::Arc;
  8. use tokio::io::{split, AsyncWriteExt, ReadHalf, WriteHalf};
  9. use tokio::net::TcpStream;
  10. use tokio::sync::mpsc;
  11. use tokio::task;
  12. use tokio::time::Duration;
  13. use tokio_rustls::{client::TlsStream, TlsConnector};
  14. mod messenger;
  15. use crate::messenger::dists::{ConfigDistributions, Distributions};
  16. use crate::messenger::error::{FatalError, MessengerError};
  17. use crate::messenger::state::{
  18. manage_active_conversation, manage_idle_conversation, StateMachine, StateToWriter,
  19. };
  20. use crate::messenger::tcp::{connect, SocksParams};
  21. /// Type for sending messages from the reader thread to the state thread.
  22. type ReaderToState = mpsc::UnboundedSender<MessageHeader>;
  23. /// Type of messages sent to the writer thread.
  24. type MessageHolder = Box<SerializedMessage>;
  25. /// Type for getting messages from the state thread in the writer thread.
  26. type WriterFromState = mpsc::UnboundedReceiver<MessageHolder>;
  27. /// Type for sending the updated read half of the socket.
  28. type ReadSocketUpdaterIn = Updater<ReadHalf<TlsStream<TcpStream>>>;
  29. /// Type for getting the updated read half of the socket.
  30. type ReadSocketUpdaterOut = Updater<ReadHalf<TlsStream<TcpStream>>>;
  31. /// Type for sending the updated write half of the socket.
  32. type WriteSocketUpdaterIn = Updater<WriteHalf<TlsStream<TcpStream>>>;
  33. /// Type for getting the updated write half of the socket.
  34. type WriteSocketUpdaterOut = Updater<WriteHalf<TlsStream<TcpStream>>>;
  35. /// Type for sending errors to other threads.
  36. type ErrorChannelIn = mpsc::UnboundedSender<MessengerError>;
  37. /// Type for getting errors from other threads.
  38. type ErrorChannelOut = mpsc::UnboundedReceiver<MessengerError>;
  39. // we gain a (very) tiny performance win by not bothering to validate the cert
  40. pub struct NoCertificateVerification {}
  41. impl tokio_rustls::rustls::client::ServerCertVerifier for NoCertificateVerification {
  42. fn verify_server_cert(
  43. &self,
  44. _end_entity: &tokio_rustls::rustls::Certificate,
  45. _intermediates: &[tokio_rustls::rustls::Certificate],
  46. _server_name: &tokio_rustls::rustls::ServerName,
  47. _scts: &mut dyn Iterator<Item = &[u8]>,
  48. _ocsp: &[u8],
  49. _now: std::time::SystemTime,
  50. ) -> Result<tokio_rustls::rustls::client::ServerCertVerified, tokio_rustls::rustls::Error> {
  51. Ok(tokio_rustls::rustls::client::ServerCertVerified::assertion())
  52. }
  53. }
  54. /// The thread responsible for getting incoming messages,
  55. /// checking for any network errors while doing so,
  56. /// and giving messages to the state thread.
  57. async fn reader(
  58. message_channel: ReaderToState,
  59. mut socket_updater: ReadSocketUpdaterOut,
  60. error_channel: ErrorChannelIn,
  61. ) {
  62. loop {
  63. let mut stream = socket_updater.recv().await;
  64. loop {
  65. let msg = match mgen::get_message(&mut stream).await {
  66. Ok(msg) => msg,
  67. Err(e) => {
  68. error_channel.send(e.into()).expect("Error channel closed");
  69. break;
  70. }
  71. };
  72. message_channel
  73. .send(msg)
  74. .expect("Reader message channel closed");
  75. }
  76. }
  77. }
  78. /// The thread responsible for sending messages from the state thread,
  79. /// and checking for any network errors while doing so.
  80. async fn writer(
  81. mut message_channel: WriterFromState,
  82. mut socket_updater: WriteSocketUpdaterOut,
  83. error_channel: ErrorChannelIn,
  84. ) {
  85. loop {
  86. let mut stream = socket_updater.recv().await;
  87. loop {
  88. let msg = message_channel
  89. .recv()
  90. .await
  91. .expect("Writer message channel closed");
  92. if let Err(e) = msg.write_all_to(&mut stream).await {
  93. error_channel.send(e.into()).expect("Error channel closed");
  94. break;
  95. }
  96. }
  97. }
  98. }
  99. /// The thread responsible for (re-)establishing connections to the server,
  100. /// and determining how to handle errors this or other threads receive.
  101. async fn socket_updater(
  102. str_params: SocksParams,
  103. retry: f64,
  104. mut error_channel: ErrorChannelOut,
  105. reader_channel: ReadSocketUpdaterIn,
  106. writer_channel: WriteSocketUpdaterIn,
  107. ) -> FatalError {
  108. let retry = Duration::from_secs_f64(retry);
  109. let tls_config = tokio_rustls::rustls::ClientConfig::builder()
  110. .with_safe_defaults()
  111. .with_custom_certificate_verifier(Arc::new(NoCertificateVerification {}))
  112. .with_no_client_auth();
  113. let connector = TlsConnector::from(Arc::new(tls_config));
  114. // unwrap is safe, split always returns at least one element
  115. let tls_server_str = str_params.target.split(':').next().unwrap();
  116. let tls_server_name =
  117. tokio_rustls::rustls::ServerName::try_from(tls_server_str).expect("invalid server name");
  118. loop {
  119. let stream: TcpStream = match connect(&str_params).await {
  120. Ok(stream) => stream,
  121. Err(MessengerError::Recoverable(_)) => {
  122. tokio::time::sleep(retry).await;
  123. continue;
  124. }
  125. Err(MessengerError::Fatal(e)) => return e,
  126. };
  127. let mut stream = match connector.connect(tls_server_name.clone(), stream).await {
  128. Ok(stream) => stream,
  129. Err(_) => {
  130. tokio::time::sleep(retry).await;
  131. continue;
  132. }
  133. };
  134. let handshake = HandshakeRef {
  135. sender: &str_params.user,
  136. group: &str_params.recipient,
  137. };
  138. if stream.write_all(&handshake.serialize()).await.is_err() {
  139. continue;
  140. }
  141. let (rd, wr) = split(stream);
  142. reader_channel.send(rd);
  143. writer_channel.send(wr);
  144. let res = error_channel.recv().await.expect("Error channel closed");
  145. if let MessengerError::Fatal(e) = res {
  146. return e;
  147. }
  148. }
  149. }
  150. /// The thread responsible for handling the conversation state
  151. /// (i.e., whether the user is active or idle, and when to send messages).
  152. /// Spawns all other threads for this conversation.
  153. async fn manage_conversation(
  154. user: String,
  155. socks: Option<String>,
  156. config: ConversationConfig,
  157. ) -> Result<(), MessengerError> {
  158. let mut rng = Xoshiro256PlusPlus::from_entropy();
  159. let distributions: Distributions = config.distributions.try_into()?;
  160. let str_params = SocksParams {
  161. socks,
  162. target: config.server,
  163. user: user.clone(),
  164. recipient: config.group.clone(),
  165. };
  166. let mut state_machine = StateMachine::start(distributions, &mut rng);
  167. let (reader_to_state, mut state_from_reader) = mpsc::unbounded_channel();
  168. let (state_to_writer, writer_from_state) = mpsc::unbounded_channel();
  169. let read_socket_updater_in = Updater::new();
  170. let read_socket_updater_out = read_socket_updater_in.clone();
  171. let write_socket_updater_in = Updater::new();
  172. let write_socket_updater_out = write_socket_updater_in.clone();
  173. let (errs_in, errs_out) = mpsc::unbounded_channel();
  174. tokio::spawn(reader(
  175. reader_to_state,
  176. read_socket_updater_out,
  177. errs_in.clone(),
  178. ));
  179. tokio::spawn(writer(writer_from_state, write_socket_updater_out, errs_in));
  180. tokio::spawn(socket_updater(
  181. str_params,
  182. config.retry,
  183. errs_out,
  184. read_socket_updater_in,
  185. write_socket_updater_in,
  186. ));
  187. tokio::time::sleep(Duration::from_secs_f64(config.bootstrap)).await;
  188. let mut state_to_writer = StateToWriter {
  189. channel: state_to_writer,
  190. };
  191. loop {
  192. state_machine = match state_machine {
  193. StateMachine::Idle(conversation) => {
  194. manage_idle_conversation::<false, _, _, _>(
  195. conversation,
  196. &mut state_from_reader,
  197. &mut state_to_writer,
  198. &user,
  199. &config.group,
  200. &mut rng,
  201. )
  202. .await
  203. }
  204. StateMachine::Active(conversation) => {
  205. manage_active_conversation(
  206. conversation,
  207. &mut state_from_reader,
  208. &mut state_to_writer,
  209. &user,
  210. &config.group,
  211. false,
  212. &mut rng,
  213. )
  214. .await
  215. }
  216. };
  217. }
  218. }
  219. #[derive(Debug, Deserialize)]
  220. struct ConversationConfig {
  221. group: String,
  222. server: String,
  223. bootstrap: f64,
  224. retry: f64,
  225. distributions: ConfigDistributions,
  226. }
  227. #[derive(Debug, Deserialize)]
  228. struct Config {
  229. user: String,
  230. socks: Option<String>,
  231. conversations: Vec<ConversationConfig>,
  232. }
  233. #[tokio::main]
  234. async fn main() -> Result<(), Box<dyn std::error::Error>> {
  235. let mut args = std::env::args();
  236. let _ = args.next();
  237. let mut handles = vec![];
  238. for config_file in args.flat_map(|a| glob::glob(a.as_str()).unwrap()) {
  239. let yaml_s = std::fs::read_to_string(config_file?)?;
  240. let config: Config = serde_yaml::from_str(&yaml_s)?;
  241. for conversation in config.conversations.into_iter() {
  242. let handle: task::JoinHandle<Result<(), MessengerError>> = tokio::spawn(
  243. manage_conversation(config.user.clone(), config.socks.clone(), conversation),
  244. );
  245. handles.push(handle);
  246. }
  247. }
  248. let handles: futures::stream::FuturesUnordered<_> = handles.into_iter().collect();
  249. for handle in handles {
  250. handle.await??;
  251. }
  252. Ok(())
  253. }