mgen-peer.rs 16 KB


  1. // Code specific to the peer in the p2p mode.
  2. use mgen::{log, updater::Updater, MessageHeader, SerializedMessage};
  3. use futures::future::try_join_all;
  4. use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256PlusPlus};
  5. use serde::Deserialize;
  6. use std::collections::HashMap;
  7. use std::result::Result;
  8. use std::sync::Arc;
  9. use tokio::io::AsyncWriteExt;
  10. use tokio::net::{
  11. tcp::{OwnedReadHalf, OwnedWriteHalf},
  12. TcpListener,
  13. };
  14. use tokio::sync::mpsc;
  15. use tokio::task::JoinHandle;
  16. use tokio::time::Duration;
  17. mod messenger;
  18. use crate::messenger::dists::{ConfigDistributions, Distributions};
  19. use crate::messenger::error::{FatalError, MessengerError};
  20. use crate::messenger::state::{
  21. manage_active_conversation, manage_idle_conversation, StateFromReader, StateMachine,
  22. StateToWriter,
  23. };
  24. use crate::messenger::tcp::{connect, SocksParams};
  25. /// Type for sending messages from the reader thread to the state thread.
  26. type ReaderToState = mpsc::UnboundedSender<MessageHeader>;
  27. /// Type for getting messages from the state thread in the writer thread.
  28. type WriterFromState = mpsc::UnboundedReceiver<Arc<SerializedMessage>>;
  29. /// Type for sending messages from the state thread to the writer thread.
  30. type MessageHolder = Arc<SerializedMessage>;
  31. /// Type for sending the updated read half of the socket.
  32. type ReadSocketUpdaterIn = Updater<OwnedReadHalf>;
  33. /// Type for getting the updated read half of the socket.
  34. type ReadSocketUpdaterOut = Updater<OwnedReadHalf>;
  35. /// Type for sending the updated write half of the socket.
  36. type WriteSocketUpdaterIn = Updater<OwnedWriteHalf>;
  37. /// Type for getting the updated write half of the socket.
  38. type WriteSocketUpdaterOut = Updater<OwnedWriteHalf>;
  39. /// The conversation (state) thread tracks the conversation state
  40. /// (i.e., whether the user is active or idle, and when to send messages).
  41. /// One state thread per conversation.
  42. async fn manage_conversation(
  43. user: String,
  44. group: String,
  45. distributions: Distributions,
  46. bootstrap: f64,
  47. mut state_from_reader: StateFromReader,
  48. mut state_to_writers: HashMap<String, StateToWriter<MessageHolder>>,
  49. ) {
  50. let mut rng = Xoshiro256PlusPlus::from_entropy();
  51. let user = &user;
  52. let group = &group;
  53. let mut state_machine = StateMachine::start(distributions, &mut rng);
  54. tokio::time::sleep(Duration::from_secs_f64(bootstrap)).await;
  55. loop {
  56. state_machine = match state_machine {
  57. StateMachine::Idle(conversation) => {
  58. manage_idle_conversation::<true, _, _, _>(
  59. conversation,
  60. &mut state_from_reader,
  61. &mut state_to_writers,
  62. user,
  63. group,
  64. &mut rng,
  65. )
  66. .await
  67. }
  68. StateMachine::Active(conversation) => {
  69. manage_active_conversation(
  70. conversation,
  71. &mut state_from_reader,
  72. &mut state_to_writers,
  73. user,
  74. group,
  75. true,
  76. &mut rng,
  77. )
  78. .await
  79. }
  80. };
  81. }
  82. }
  83. /// The listener thread listens for inbound connections on the given address.
  84. /// It breaks those connections into reader and writer halves,
  85. /// and gives them to the correct reader and writer threads.
  86. /// One listener thread per user.
  87. async fn listener(
  88. address: String,
  89. name_to_io_threads: HashMap<String, (ReadSocketUpdaterIn, WriteSocketUpdaterIn)>,
  90. ) -> Result<(), FatalError> {
  91. let listener = TcpListener::bind(&address).await?;
  92. log!("listening on {}", &address);
  93. async fn error_collector(
  94. address: &str,
  95. listener: &TcpListener,
  96. name_to_io_threads: &HashMap<String, (ReadSocketUpdaterIn, WriteSocketUpdaterIn)>,
  97. ) -> Result<(), MessengerError> {
  98. let (stream, _) = listener.accept().await?;
  99. let (mut rd, wr) = stream.into_split();
  100. let from = mgen::parse_identifier(&mut rd).await?;
  101. let (channel_to_reader, channel_to_writer) = name_to_io_threads
  102. .get(&from)
  103. .unwrap_or_else(|| panic!("{} got connection from unknown contact: {}", address, from));
  104. channel_to_reader.send(rd);
  105. channel_to_writer.send(wr);
  106. Ok(())
  107. }
  108. loop {
  109. if let Err(MessengerError::Fatal(e)) =
  110. error_collector(&address, &listener, &name_to_io_threads).await
  111. {
  112. return Err(e);
  113. }
  114. }
  115. }
  116. /// The reader thread reads messages from the socket it has been given,
  117. /// and sends them to the correct state thread.
  118. /// One reader thread per (user, recipient) pair.
  119. async fn reader(
  120. mut connection_channel: ReadSocketUpdaterOut,
  121. group_to_conversation_thread: HashMap<String, ReaderToState>,
  122. ) {
  123. loop {
  124. // wait for listener or writer thread to give us a stream to read from
  125. let mut stream = connection_channel.recv().await;
  126. loop {
  127. let Ok(msg) = mgen::get_message(&mut stream).await else {
  128. // Unlike the client-server case, we can assume that if there
  129. // were a message someone was trying to send us, they'd make
  130. // sure to re-establish the connection; so when the socket
  131. // breaks, don't bother trying to reform it until we need to
  132. // send a message or the peer reaches out to us.
  133. break;
  134. };
  135. let channel_to_conversation = group_to_conversation_thread
  136. .get(&msg.group)
  137. .unwrap_or_else(|| panic!("Unknown group: {}", msg.group));
  138. channel_to_conversation
  139. .send(msg)
  140. .expect("reader: Channel to group closed");
  141. }
  142. }
  143. }
  144. /// The writer thread takes in messages from state threads,
  145. /// and sends it to the recipient associated with this thread.
  146. /// If it doesn't have a socket from the listener thread,
  147. /// it'll create its own and give the read half to the reader thread.
  148. /// One writer thread per (user, recipient) pair.
  149. async fn writer<'a>(
  150. mut messages_to_send: WriterFromState,
  151. mut write_socket_updater: WriteSocketUpdaterOut,
  152. read_socket_updater: ReadSocketUpdaterIn,
  153. socks_params: SocksParams,
  154. retry: Duration,
  155. ) -> Result<(), FatalError> {
  156. // make sure this is the first step to avoid connections until there's
  157. // something to send
  158. let mut msg = messages_to_send
  159. .recv()
  160. .await
  161. .expect("writer: Channel from conversations closed");
  162. let mut stream = establish_connection(
  163. &mut write_socket_updater,
  164. &read_socket_updater,
  165. &socks_params,
  166. retry,
  167. )
  168. .await
  169. .expect("Fatal error establishing connection");
  170. loop {
  171. while msg.write_all_to(&mut stream).await.is_err() {
  172. stream = establish_connection(
  173. &mut write_socket_updater,
  174. &read_socket_updater,
  175. &socks_params,
  176. retry,
  177. )
  178. .await
  179. .expect("Fatal error establishing connection");
  180. }
  181. msg = messages_to_send
  182. .recv()
  183. .await
  184. .expect("writer: Channel from conversations closed");
  185. }
  186. // helper functions
  187. /// Attempt to get a connection to the peer,
  188. /// whether by getting an existing connection from the listener,
  189. /// or by establishing a new connection.
  190. async fn establish_connection<'a>(
  191. write_socket_updater: &mut WriteSocketUpdaterOut,
  192. read_socket_updater: &ReadSocketUpdaterIn,
  193. socks_params: &SocksParams,
  194. retry: Duration,
  195. ) -> Result<OwnedWriteHalf, FatalError> {
  196. // first check if the listener thread already has a socket
  197. if let Some(wr) = write_socket_updater.maybe_recv() {
  198. return Ok(wr);
  199. }
  200. // immediately try to connect to the peer
  201. tokio::select! {
  202. connection_attempt = connect(socks_params) => {
  203. if let Ok(mut stream) = connection_attempt {
  204. log!(
  205. "connection attempt success from {} to {} on {}",
  206. &socks_params.user,
  207. &socks_params.recipient,
  208. &socks_params.target
  209. );
  210. stream
  211. .write_all(&mgen::serialize_str(&socks_params.user))
  212. .await?;
  213. let (rd, wr) = stream.into_split();
  214. read_socket_updater.send(rd);
  215. return Ok(wr);
  216. } else if let Err(MessengerError::Fatal(e)) = connection_attempt {
  217. return Err(e);
  218. }
  219. }
  220. stream = write_socket_updater.recv() => {return Ok(stream);},
  221. }
  222. // Usually we'll have returned by now, but sometimes we'll fail to
  223. // connect for whatever reason. Initiate a loop of waiting Duration,
  224. // then trying to connect again, allowing it to be inerrupted by
  225. // the listener thread.
  226. loop {
  227. match error_collector(
  228. write_socket_updater,
  229. read_socket_updater,
  230. socks_params,
  231. retry,
  232. )
  233. .await
  234. {
  235. Ok(wr) => return Ok(wr),
  236. Err(MessengerError::Recoverable(_)) => continue,
  237. Err(MessengerError::Fatal(e)) => return Err(e),
  238. }
  239. }
  240. async fn error_collector<'a>(
  241. write_socket_updater: &mut WriteSocketUpdaterOut,
  242. read_socket_updater: &ReadSocketUpdaterIn,
  243. socks_params: &SocksParams,
  244. retry: Duration,
  245. ) -> Result<OwnedWriteHalf, MessengerError> {
  246. tokio::select! {
  247. () = tokio::time::sleep(retry) => {
  248. let mut stream = connect(socks_params)
  249. .await?;
  250. stream.write_all(&mgen::serialize_str(&socks_params.user)).await?;
  251. let (rd, wr) = stream.into_split();
  252. read_socket_updater.send(rd);
  253. Ok(wr)
  254. },
  255. stream = write_socket_updater.recv() => Ok(stream),
  256. }
  257. }
  258. }
  259. }
  260. fn parse_hosts_file(file_contents: &str) -> HashMap<&str, &str> {
  261. let mut ret = HashMap::new();
  262. for line in file_contents.lines() {
  263. let mut words = line.split_ascii_whitespace();
  264. if let Some(addr) = words.next() {
  265. for name in words {
  266. ret.insert(name, addr);
  267. }
  268. }
  269. }
  270. ret
  271. }
/// Settings for one entry of a config file's `conversations` list.
/// Each `None` field falls back to the matching top-level `Config` value.
#[derive(Debug, Deserialize)]
struct ConversationConfig {
    /// Group name. Incoming messages are routed to this conversation's state
    /// thread by their `group` field.
    group: String,
    /// Names of the peers this conversation sends to; each must have an entry
    /// in the hosts file.
    recipients: Vec<String>,
    /// Seconds to wait before the conversation starts (default: `Config::bootstrap`).
    bootstrap: Option<f64>,
    /// Seconds between reconnection attempts (default: `Config::retry`).
    /// Retry is tracked per recipient, not per conversation: when several
    /// conversations list the same recipient, only the first one's value is used.
    retry: Option<f64>,
    /// Distributions for this conversation's state machine
    /// (default: `Config::distributions`).
    distributions: Option<ConfigDistributions>,
}
/// Contents of one YAML config file: a single user and the conversations it
/// takes part in.
#[derive(Debug, Deserialize)]
struct Config {
    /// This peer's name. It is sent as the identifier on outbound connections
    /// and, unless `listen` is set, looked up in the hosts file for the listen
    /// address.
    user: String,
    /// Passed as `SocksParams::socks` to `connect` for outbound connections
    /// (presumably a SOCKS proxy address, `None` for direct connections;
    /// confirm in `messenger::tcp`).
    socks: Option<String>,
    /// Address to bind the listener to; defaults to `user`'s hosts-file entry.
    listen: Option<String>,
    /// Default bootstrap delay in seconds for conversations that don't set one.
    bootstrap: f64,
    /// Default reconnection interval in seconds for conversations that don't set one.
    retry: f64,
    /// Default distributions for conversations that don't set their own.
    distributions: ConfigDistributions,
    /// The conversations this user participates in.
    conversations: Vec<ConversationConfig>,
}
  290. fn process_config(
  291. config: Config,
  292. hosts_map: &HashMap<&str, &str>,
  293. handles: &mut Vec<JoinHandle<Result<(), FatalError>>>,
  294. ) -> Result<(), Box<dyn std::error::Error>> {
  295. struct ForIoThreads {
  296. state_to_writer: mpsc::UnboundedSender<MessageHolder>,
  297. writer_from_state: WriterFromState,
  298. reader_to_states: HashMap<String, ReaderToState>,
  299. str_params: SocksParams,
  300. retry: f64,
  301. }
  302. let default_dists: Distributions = config.distributions.try_into()?;
  303. // map from `recipient` to things the (user, recipient) reader/writer threads will need
  304. let mut recipient_map = HashMap::<String, ForIoThreads>::new();
  305. for conversation in config.conversations.into_iter() {
  306. let (reader_to_state, state_from_reader) = mpsc::unbounded_channel();
  307. let mut conversation_recipient_map =
  308. HashMap::<String, StateToWriter<MessageHolder>>::with_capacity(
  309. conversation.recipients.len(),
  310. );
  311. for recipient in conversation.recipients.iter() {
  312. let for_io = recipient_map
  313. .entry(recipient.to_string())
  314. .and_modify(|e| {
  315. e.reader_to_states
  316. .entry(conversation.group.clone())
  317. .or_insert_with(|| reader_to_state.clone());
  318. })
  319. .or_insert_with(|| {
  320. let (state_to_writer, writer_from_state) = mpsc::unbounded_channel();
  321. let mut reader_to_states = HashMap::new();
  322. reader_to_states.insert(conversation.group.clone(), reader_to_state.clone());
  323. let address = hosts_map
  324. .get(recipient.as_str())
  325. .unwrap_or_else(|| panic!("recipient not in hosts file: {}", recipient));
  326. let str_params = SocksParams {
  327. socks: config.socks.clone(),
  328. target: address.to_string(),
  329. user: config.user.clone(),
  330. recipient: recipient.clone(),
  331. };
  332. let retry = conversation.retry.unwrap_or(config.retry);
  333. ForIoThreads {
  334. state_to_writer,
  335. writer_from_state,
  336. reader_to_states,
  337. str_params,
  338. retry,
  339. }
  340. });
  341. let state_to_writer = for_io.state_to_writer.clone();
  342. conversation_recipient_map.insert(
  343. recipient.clone(),
  344. StateToWriter {
  345. channel: state_to_writer,
  346. },
  347. );
  348. }
  349. let distributions: Distributions = match conversation.distributions {
  350. Some(dists) => dists.try_into()?,
  351. None => default_dists.clone(),
  352. };
  353. let bootstrap = conversation.bootstrap.unwrap_or(config.bootstrap);
  354. tokio::spawn(manage_conversation(
  355. config.user.clone(),
  356. conversation.group,
  357. distributions,
  358. bootstrap,
  359. state_from_reader,
  360. conversation_recipient_map,
  361. ));
  362. }
  363. let mut name_to_io_threads: HashMap<String, (ReadSocketUpdaterIn, WriteSocketUpdaterIn)> =
  364. HashMap::new();
  365. for (recipient, for_io) in recipient_map.drain() {
  366. let listener_writer_to_reader = Updater::new();
  367. let reader_from_listener_writer = listener_writer_to_reader.clone();
  368. let listener_to_writer = Updater::new();
  369. let writer_from_listener = listener_to_writer.clone();
  370. name_to_io_threads.insert(
  371. recipient.to_string(),
  372. (listener_writer_to_reader.clone(), listener_to_writer),
  373. );
  374. tokio::spawn(reader(reader_from_listener_writer, for_io.reader_to_states));
  375. let retry = Duration::from_secs_f64(for_io.retry);
  376. let handle: JoinHandle<Result<(), FatalError>> = tokio::spawn(writer(
  377. for_io.writer_from_state,
  378. writer_from_listener,
  379. listener_writer_to_reader,
  380. for_io.str_params,
  381. retry,
  382. ));
  383. handles.push(handle);
  384. }
  385. let address = if let Some(address) = config.listen {
  386. address
  387. } else {
  388. hosts_map
  389. .get(config.user.as_str())
  390. .unwrap_or_else(|| panic!("user not found in hosts file: {}", config.user))
  391. .to_string()
  392. };
  393. let handle: JoinHandle<Result<(), FatalError>> =
  394. tokio::spawn(listener(address, name_to_io_threads));
  395. handles.push(handle);
  396. Ok(())
  397. }
  398. async fn main_worker() -> Result<(), Box<dyn std::error::Error>> {
  399. #[cfg(feature = "tracing")]
  400. console_subscriber::init();
  401. let mut args = std::env::args();
  402. let _ = args.next();
  403. let hosts_file = args.next().expect("missing hosts file arg");
  404. let hosts_file = std::fs::read_to_string(hosts_file).expect("could not find hosts file");
  405. let hosts_map = parse_hosts_file(&hosts_file);
  406. let mut handles = vec![];
  407. for config_file in args.flat_map(|a| glob::glob(a.as_str()).unwrap()) {
  408. let yaml_s = std::fs::read_to_string(config_file?)?;
  409. let config: Config = serde_yaml::from_str(&yaml_s)?;
  410. process_config(config, &hosts_map, &mut handles)?;
  411. }
  412. try_join_all(handles).await?;
  413. Ok(())
  414. }
  415. fn main() -> Result<(), Box<dyn std::error::Error>> {
  416. tokio::runtime::Builder::new_multi_thread()
  417. .worker_threads(2)
  418. .enable_all()
  419. .disable_lifo_slot()
  420. .build()
  421. .unwrap()
  422. .block_on(main_worker())
  423. }