mgen-peer.rs 16 KB


  1. // Code specific to the peer in the p2p mode.
  2. use mgen::{log, updater::Updater, MessageHeader, SerializedMessage};
  3. use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256PlusPlus};
  4. use serde::Deserialize;
  5. use std::collections::HashMap;
  6. use std::result::Result;
  7. use std::sync::Arc;
  8. use tokio::io::AsyncWriteExt;
  9. use tokio::net::{
  10. tcp::{OwnedReadHalf, OwnedWriteHalf},
  11. TcpListener,
  12. };
  13. use tokio::sync::mpsc;
  14. use tokio::task::JoinHandle;
  15. use tokio::time::Duration;
  16. mod messenger;
  17. use crate::messenger::dists::{ConfigDistributions, Distributions};
  18. use crate::messenger::error::{FatalError, MessengerError};
  19. use crate::messenger::state::{
  20. manage_active_conversation, manage_idle_conversation, StateFromReader, StateMachine,
  21. StateToWriter,
  22. };
  23. use crate::messenger::tcp::{connect, SocksParams};
  24. /// Type for sending messages from the reader thread to the state thread.
  25. type ReaderToState = mpsc::UnboundedSender<MessageHeader>;
  26. /// Type for getting messages from the state thread in the writer thread.
  27. type WriterFromState = mpsc::UnboundedReceiver<Arc<SerializedMessage>>;
  28. /// Type for sending messages from the state thread to the writer thread.
  29. type MessageHolder = Arc<SerializedMessage>;
  30. /// Type for sending the updated read half of the socket.
  31. type ReadSocketUpdaterIn = Updater<OwnedReadHalf>;
  32. /// Type for getting the updated read half of the socket.
  33. type ReadSocketUpdaterOut = Updater<OwnedReadHalf>;
  34. /// Type for sending the updated write half of the socket.
  35. type WriteSocketUpdaterIn = Updater<OwnedWriteHalf>;
  36. /// Type for getting the updated write half of the socket.
  37. type WriteSocketUpdaterOut = Updater<OwnedWriteHalf>;
  38. /// The conversation (state) thread tracks the conversation state
  39. /// (i.e., whether the user is active or idle, and when to send messages).
  40. /// One state thread per conversation.
  41. async fn manage_conversation(
  42. user: String,
  43. group: String,
  44. distributions: Distributions,
  45. bootstrap: f64,
  46. mut state_from_reader: StateFromReader,
  47. mut state_to_writers: HashMap<String, StateToWriter<MessageHolder>>,
  48. ) {
  49. let mut rng = Xoshiro256PlusPlus::from_entropy();
  50. let user = &user;
  51. let group = &group;
  52. let mut state_machine = StateMachine::start(distributions, &mut rng);
  53. tokio::time::sleep(Duration::from_secs_f64(bootstrap)).await;
  54. loop {
  55. state_machine = match state_machine {
  56. StateMachine::Idle(conversation) => {
  57. manage_idle_conversation::<true, _, _, _>(
  58. conversation,
  59. &mut state_from_reader,
  60. &mut state_to_writers,
  61. user,
  62. group,
  63. &mut rng,
  64. )
  65. .await
  66. }
  67. StateMachine::Active(conversation) => {
  68. manage_active_conversation(
  69. conversation,
  70. &mut state_from_reader,
  71. &mut state_to_writers,
  72. user,
  73. group,
  74. true,
  75. &mut rng,
  76. )
  77. .await
  78. }
  79. };
  80. }
  81. }
  82. /// The listener thread listens for inbound connections on the given address.
  83. /// It breaks those connections into reader and writer halves,
  84. /// and gives them to the correct reader and writer threads.
  85. /// One listener thread per user.
  86. async fn listener(
  87. address: String,
  88. name_to_io_threads: HashMap<String, (ReadSocketUpdaterIn, WriteSocketUpdaterIn)>,
  89. ) -> Result<(), FatalError> {
  90. let listener = TcpListener::bind(&address).await?;
  91. log!("listening on {}", &address);
  92. async fn error_collector(
  93. address: &str,
  94. listener: &TcpListener,
  95. name_to_io_threads: &HashMap<String, (ReadSocketUpdaterIn, WriteSocketUpdaterIn)>,
  96. ) -> Result<(), MessengerError> {
  97. let (stream, _) = listener.accept().await?;
  98. let (mut rd, wr) = stream.into_split();
  99. let from = mgen::parse_identifier(&mut rd).await?;
  100. let (channel_to_reader, channel_to_writer) = name_to_io_threads
  101. .get(&from)
  102. .unwrap_or_else(|| panic!("{} got connection from unknown contact: {}", address, from));
  103. channel_to_reader.send(rd);
  104. channel_to_writer.send(wr);
  105. Ok(())
  106. }
  107. loop {
  108. if let Err(MessengerError::Fatal(e)) =
  109. error_collector(&address, &listener, &name_to_io_threads).await
  110. {
  111. return Err(e);
  112. }
  113. }
  114. }
  115. /// The reader thread reads messages from the socket it has been given,
  116. /// and sends them to the correct state thread.
  117. /// One reader thread per (user, recipient) pair.
  118. async fn reader(
  119. mut connection_channel: ReadSocketUpdaterOut,
  120. group_to_conversation_thread: HashMap<String, ReaderToState>,
  121. ) {
  122. loop {
  123. // wait for listener or writer thread to give us a stream to read from
  124. let mut stream = connection_channel.recv().await;
  125. loop {
  126. let Ok(msg) = mgen::get_message(&mut stream).await else {
  127. // Unlike the client-server case, we can assume that if there
  128. // were a message someone was trying to send us, they'd make
  129. // sure to re-establish the connection; so when the socket
  130. // breaks, don't bother trying to reform it until we need to
  131. // send a message or the peer reaches out to us.
  132. break;
  133. };
  134. let channel_to_conversation = group_to_conversation_thread
  135. .get(&msg.group)
  136. .unwrap_or_else(|| panic!("Unknown group: {}", msg.group));
  137. channel_to_conversation
  138. .send(msg)
  139. .expect("reader: Channel to group closed");
  140. }
  141. }
  142. }
  143. /// The writer thread takes in messages from state threads,
  144. /// and sends it to the recipient associated with this thread.
  145. /// If it doesn't have a socket from the listener thread,
  146. /// it'll create its own and give the read half to the reader thread.
  147. /// One writer thread per (user, recipient) pair.
  148. async fn writer<'a>(
  149. mut messages_to_send: WriterFromState,
  150. mut write_socket_updater: WriteSocketUpdaterOut,
  151. read_socket_updater: ReadSocketUpdaterIn,
  152. socks_params: SocksParams,
  153. retry: Duration,
  154. ) -> Result<(), FatalError> {
  155. // make sure this is the first step to avoid connections until there's
  156. // something to send
  157. let mut msg = messages_to_send
  158. .recv()
  159. .await
  160. .expect("writer: Channel from conversations closed");
  161. let mut stream = establish_connection(
  162. &mut write_socket_updater,
  163. &read_socket_updater,
  164. &socks_params,
  165. retry,
  166. )
  167. .await
  168. .expect("Fatal error establishing connection");
  169. loop {
  170. while msg.write_all_to(&mut stream).await.is_err() {
  171. stream = establish_connection(
  172. &mut write_socket_updater,
  173. &read_socket_updater,
  174. &socks_params,
  175. retry,
  176. )
  177. .await
  178. .expect("Fatal error establishing connection");
  179. }
  180. msg = messages_to_send
  181. .recv()
  182. .await
  183. .expect("writer: Channel from conversations closed");
  184. }
  185. // helper functions
  186. /// Attempt to get a connection to the peer,
  187. /// whether by getting an existing connection from the listener,
  188. /// or by establishing a new connection.
  189. async fn establish_connection<'a>(
  190. write_socket_updater: &mut WriteSocketUpdaterOut,
  191. read_socket_updater: &ReadSocketUpdaterIn,
  192. socks_params: &SocksParams,
  193. retry: Duration,
  194. ) -> Result<OwnedWriteHalf, FatalError> {
  195. // first check if the listener thread already has a socket
  196. if let Some(wr) = write_socket_updater.maybe_recv() {
  197. return Ok(wr);
  198. }
  199. // immediately try to connect to the peer
  200. tokio::select! {
  201. connection_attempt = connect(socks_params) => {
  202. if let Ok(mut stream) = connection_attempt {
  203. log!(
  204. "connection attempt success from {} to {} on {}",
  205. &socks_params.user,
  206. &socks_params.recipient,
  207. &socks_params.target
  208. );
  209. stream
  210. .write_all(&mgen::serialize_str(&socks_params.user))
  211. .await?;
  212. let (rd, wr) = stream.into_split();
  213. read_socket_updater.send(rd);
  214. return Ok(wr);
  215. } else if let Err(MessengerError::Fatal(e)) = connection_attempt {
  216. return Err(e);
  217. }
  218. }
  219. stream = write_socket_updater.recv() => {return Ok(stream);},
  220. }
  221. // Usually we'll have returned by now, but sometimes we'll fail to
  222. // connect for whatever reason. Initiate a loop of waiting Duration,
  223. // then trying to connect again, allowing it to be inerrupted by
  224. // the listener thread.
  225. loop {
  226. match error_collector(
  227. write_socket_updater,
  228. read_socket_updater,
  229. socks_params,
  230. retry,
  231. )
  232. .await
  233. {
  234. Ok(wr) => return Ok(wr),
  235. Err(MessengerError::Recoverable(_)) => continue,
  236. Err(MessengerError::Fatal(e)) => return Err(e),
  237. }
  238. }
  239. async fn error_collector<'a>(
  240. write_socket_updater: &mut WriteSocketUpdaterOut,
  241. read_socket_updater: &ReadSocketUpdaterIn,
  242. socks_params: &SocksParams,
  243. retry: Duration,
  244. ) -> Result<OwnedWriteHalf, MessengerError> {
  245. tokio::select! {
  246. () = tokio::time::sleep(retry) => {
  247. let mut stream = connect(socks_params)
  248. .await?;
  249. stream.write_all(&mgen::serialize_str(&socks_params.user)).await?;
  250. let (rd, wr) = stream.into_split();
  251. read_socket_updater.send(rd);
  252. Ok(wr)
  253. },
  254. stream = write_socket_updater.recv() => Ok(stream),
  255. }
  256. }
  257. }
  258. }
  /// Parses hosts-file style text into a map from host name to address.
  ///
  /// Each line has the form `ADDRESS NAME [NAME...]`, separated by ASCII
  /// whitespace. Everything from a `#` to the end of the line is a comment
  /// and is ignored, as are blank lines and lines with an address but no
  /// name. If a name occurs more than once, the last occurrence wins.
  ///
  /// The returned map borrows its keys and values from `file_contents`.
  fn parse_hosts_file(file_contents: &str) -> HashMap<&str, &str> {
      let mut ret = HashMap::new();
      for line in file_contents.lines() {
          // Drop a trailing comment so its words aren't taken for host names.
          let entry = line.split_once('#').map_or(line, |(before, _)| before);
          let mut words = entry.split_ascii_whitespace();
          let Some(addr) = words.next() else {
              continue;
          };
          ret.extend(words.map(|name| (name, addr)));
      }
      ret
  }
  271. #[derive(Debug, Deserialize)]
  272. struct ConversationConfig {
  273. group: String,
  274. recipients: Vec<String>,
  275. bootstrap: Option<f64>,
  276. retry: Option<f64>,
  277. distributions: Option<ConfigDistributions>,
  278. }
  279. #[derive(Debug, Deserialize)]
  280. struct Config {
  281. user: String,
  282. socks: Option<String>,
  283. listen: Option<String>,
  284. bootstrap: f64,
  285. retry: f64,
  286. distributions: ConfigDistributions,
  287. conversations: Vec<ConversationConfig>,
  288. }
  289. fn process_config(
  290. config: Config,
  291. hosts_map: &HashMap<&str, &str>,
  292. handles: &mut Vec<JoinHandle<Result<(), FatalError>>>,
  293. ) -> Result<(), Box<dyn std::error::Error>> {
  294. struct ForIoThreads {
  295. state_to_writer: mpsc::UnboundedSender<MessageHolder>,
  296. writer_from_state: WriterFromState,
  297. reader_to_states: HashMap<String, ReaderToState>,
  298. str_params: SocksParams,
  299. retry: f64,
  300. }
  301. // map from `recipient` to things the (user, recipient) reader/writer threads will need
  302. let mut recipient_map = HashMap::<String, ForIoThreads>::new();
  303. for conversation in config.conversations.into_iter() {
  304. let (reader_to_state, state_from_reader) = mpsc::unbounded_channel();
  305. let mut conversation_recipient_map =
  306. HashMap::<String, StateToWriter<MessageHolder>>::with_capacity(
  307. conversation.recipients.len(),
  308. );
  309. for recipient in conversation.recipients.iter() {
  310. let for_io = recipient_map
  311. .entry(recipient.to_string())
  312. .and_modify(|e| {
  313. e.reader_to_states
  314. .entry(conversation.group.clone())
  315. .or_insert_with(|| reader_to_state.clone());
  316. })
  317. .or_insert_with(|| {
  318. let (state_to_writer, writer_from_state) = mpsc::unbounded_channel();
  319. let mut reader_to_states = HashMap::new();
  320. reader_to_states.insert(conversation.group.clone(), reader_to_state.clone());
  321. let address = hosts_map
  322. .get(recipient.as_str())
  323. .unwrap_or_else(|| panic!("recipient not in hosts file: {}", recipient));
  324. let str_params = SocksParams {
  325. socks: config.socks.clone(),
  326. target: address.to_string(),
  327. user: config.user.clone(),
  328. recipient: recipient.clone(),
  329. };
  330. let retry = conversation.retry.unwrap_or(config.retry);
  331. ForIoThreads {
  332. state_to_writer,
  333. writer_from_state,
  334. reader_to_states,
  335. str_params,
  336. retry,
  337. }
  338. });
  339. let state_to_writer = for_io.state_to_writer.clone();
  340. conversation_recipient_map.insert(
  341. recipient.clone(),
  342. StateToWriter {
  343. channel: state_to_writer,
  344. },
  345. );
  346. }
  347. let distributions: Distributions = conversation
  348. .distributions
  349. .unwrap_or_else(|| config.distributions.clone())
  350. .try_into()?;
  351. let bootstrap = conversation.bootstrap.unwrap_or(config.bootstrap);
  352. tokio::spawn(manage_conversation(
  353. config.user.clone(),
  354. conversation.group,
  355. distributions,
  356. bootstrap,
  357. state_from_reader,
  358. conversation_recipient_map,
  359. ));
  360. }
  361. let mut name_to_io_threads: HashMap<String, (ReadSocketUpdaterIn, WriteSocketUpdaterIn)> =
  362. HashMap::new();
  363. for (recipient, for_io) in recipient_map.drain() {
  364. let listener_writer_to_reader = Updater::new();
  365. let reader_from_listener_writer = listener_writer_to_reader.clone();
  366. let listener_to_writer = Updater::new();
  367. let writer_from_listener = listener_to_writer.clone();
  368. name_to_io_threads.insert(
  369. recipient.to_string(),
  370. (listener_writer_to_reader.clone(), listener_to_writer),
  371. );
  372. tokio::spawn(reader(reader_from_listener_writer, for_io.reader_to_states));
  373. let retry = Duration::from_secs_f64(for_io.retry);
  374. let handle: JoinHandle<Result<(), FatalError>> = tokio::spawn(writer(
  375. for_io.writer_from_state,
  376. writer_from_listener,
  377. listener_writer_to_reader,
  378. for_io.str_params,
  379. retry,
  380. ));
  381. handles.push(handle);
  382. }
  383. let address = if let Some(address) = config.listen {
  384. address
  385. } else {
  386. hosts_map
  387. .get(config.user.as_str())
  388. .unwrap_or_else(|| panic!("user not found in hosts file: {}", config.user))
  389. .to_string()
  390. };
  391. let handle: JoinHandle<Result<(), FatalError>> =
  392. tokio::spawn(listener(address, name_to_io_threads));
  393. handles.push(handle);
  394. Ok(())
  395. }
  396. #[tokio::main]
  397. async fn main() -> Result<(), Box<dyn std::error::Error>> {
  398. let mut args = std::env::args();
  399. let _ = args.next();
  400. let hosts_file = std::fs::read_to_string(args.next().unwrap())?;
  401. let hosts_map = parse_hosts_file(&hosts_file);
  402. let mut handles = vec![];
  403. for config_file in args.flat_map(|a| glob::glob(a.as_str()).unwrap()) {
  404. let yaml_s = std::fs::read_to_string(config_file?)?;
  405. let config: Config = serde_yaml::from_str(&yaml_s)?;
  406. process_config(config, &hosts_map, &mut handles)?;
  407. }
  408. let handles: futures::stream::FuturesUnordered<_> = handles.into_iter().collect();
  409. for handle in handles {
  410. handle.await??;
  411. }
  412. Ok(())
  413. }