mgen-peer.rs 16 KB


  1. // Code specific to the peer in the p2p mode.
  2. use mgen::{log, updater::Updater, MessageHeader, SerializedMessage};
  3. use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256PlusPlus};
  4. use serde::Deserialize;
  5. use std::collections::HashMap;
  6. use std::result::Result;
  7. use std::sync::Arc;
  8. use tokio::io::AsyncWriteExt;
  9. use tokio::net::{
  10. tcp::{OwnedReadHalf, OwnedWriteHalf},
  11. TcpListener,
  12. };
  13. use tokio::sync::mpsc;
  14. use tokio::task::JoinHandle;
  15. use tokio::time::Duration;
  16. mod messenger;
  17. use crate::messenger::dists::{ConfigDistributions, Distributions};
  18. use crate::messenger::error::{FatalError, MessengerError};
  19. use crate::messenger::state::{
  20. manage_active_conversation, manage_idle_conversation, StateFromReader, StateMachine,
  21. StateToWriter,
  22. };
  23. use crate::messenger::tcp::{connect, SocksParams};
  24. /// Type for sending messages from the reader thread to the state thread.
  25. type ReaderToState = mpsc::UnboundedSender<MessageHeader>;
  26. /// Type for getting messages from the state thread in the writer thread.
  27. type WriterFromState = mpsc::UnboundedReceiver<Arc<SerializedMessage>>;
  28. /// Type for sending messages from the state thread to the writer thread.
  29. type MessageHolder = Arc<SerializedMessage>;
  30. /// Type for sending the updated read half of the socket.
  31. type ReadSocketUpdaterIn = Updater<OwnedReadHalf>;
  32. /// Type for getting the updated read half of the socket.
  33. type ReadSocketUpdaterOut = Updater<OwnedReadHalf>;
  34. /// Type for sending the updated write half of the socket.
  35. type WriteSocketUpdaterIn = Updater<OwnedWriteHalf>;
  36. /// Type for getting the updated write half of the socket.
  37. type WriteSocketUpdaterOut = Updater<OwnedWriteHalf>;
  38. /// The conversation (state) thread tracks the conversation state
  39. /// (i.e., whether the user is active or idle, and when to send messages).
  40. /// One state thread per conversation.
  41. async fn manage_conversation(
  42. user: String,
  43. group: String,
  44. distributions: Distributions,
  45. bootstrap: f64,
  46. mut state_from_reader: StateFromReader,
  47. mut state_to_writers: HashMap<String, StateToWriter<MessageHolder>>,
  48. ) {
  49. let mut rng = Xoshiro256PlusPlus::from_entropy();
  50. let user = &user;
  51. let group = &group;
  52. let mut state_machine = StateMachine::start(distributions, &mut rng);
  53. tokio::time::sleep(Duration::from_secs_f64(bootstrap)).await;
  54. loop {
  55. state_machine = match state_machine {
  56. StateMachine::Idle(conversation) => {
  57. manage_idle_conversation(
  58. conversation,
  59. &mut state_from_reader,
  60. &mut state_to_writers,
  61. user,
  62. group,
  63. true,
  64. &mut rng,
  65. )
  66. .await
  67. }
  68. StateMachine::Active(conversation) => {
  69. manage_active_conversation(
  70. conversation,
  71. &mut state_from_reader,
  72. &mut state_to_writers,
  73. user,
  74. group,
  75. true,
  76. &mut rng,
  77. )
  78. .await
  79. }
  80. };
  81. }
  82. }
  83. /// The listener thread listens for inbound connections on the given address.
  84. /// It breaks those connections into reader and writer halves,
  85. /// and gives them to the correct reader and writer threads.
  86. /// One listener thread per user.
  87. async fn listener(
  88. address: String,
  89. name_to_io_threads: HashMap<String, (ReadSocketUpdaterIn, WriteSocketUpdaterIn)>,
  90. ) -> Result<(), FatalError> {
  91. let listener = TcpListener::bind(&address).await?;
  92. log!("listening on {}", &address);
  93. async fn error_collector(
  94. address: &str,
  95. listener: &TcpListener,
  96. name_to_io_threads: &HashMap<String, (ReadSocketUpdaterIn, WriteSocketUpdaterIn)>,
  97. ) -> Result<(), MessengerError> {
  98. let (stream, _) = listener.accept().await?;
  99. let (mut rd, wr) = stream.into_split();
  100. let from = mgen::parse_identifier(&mut rd).await?;
  101. let (channel_to_reader, channel_to_writer) = name_to_io_threads
  102. .get(&from)
  103. .unwrap_or_else(|| panic!("{} got connection from unknown contact: {}", address, from));
  104. channel_to_reader.send(rd);
  105. channel_to_writer.send(wr);
  106. Ok(())
  107. }
  108. loop {
  109. if let Err(MessengerError::Fatal(e)) =
  110. error_collector(&address, &listener, &name_to_io_threads).await
  111. {
  112. return Err(e);
  113. }
  114. }
  115. }
  116. /// The reader thread reads messages from the socket it has been given,
  117. /// and sends them to the correct state thread.
  118. /// One reader thread per (user, recipient) pair.
  119. async fn reader(
  120. mut connection_channel: ReadSocketUpdaterOut,
  121. group_to_conversation_thread: HashMap<String, ReaderToState>,
  122. ) {
  123. loop {
  124. // wait for listener or writer thread to give us a stream to read from
  125. let mut stream = connection_channel.recv().await;
  126. loop {
  127. let msg = if let Ok(msg) = mgen::get_message(&mut stream).await {
  128. msg
  129. } else {
  130. // Unlike the client-server case, we can assume that if there
  131. // were a message someone was trying to send us, they'd make
  132. // sure to re-establish the connection; so when the socket
  133. // breaks, don't bother trying to reform it until we need to
  134. // send a message or the peer reaches out to us.
  135. break;
  136. };
  137. let channel_to_conversation = group_to_conversation_thread
  138. .get(&msg.group)
  139. .unwrap_or_else(|| panic!("Unknown group: {}", msg.group));
  140. channel_to_conversation
  141. .send(msg)
  142. .expect("reader: Channel to group closed");
  143. }
  144. }
  145. }
  146. /// The writer thread takes in messages from state threads,
  147. /// and sends it to the recipient associated with this thread.
  148. /// If it doesn't have a socket from the listener thread,
  149. /// it'll create its own and give the read half to the reader thread.
  150. /// One writer thread per (user, recipient) pair.
  151. async fn writer<'a>(
  152. mut messages_to_send: WriterFromState,
  153. mut write_socket_updater: WriteSocketUpdaterOut,
  154. read_socket_updater: ReadSocketUpdaterIn,
  155. socks_params: SocksParams,
  156. retry: Duration,
  157. ) -> Result<(), FatalError> {
  158. // make sure this is the first step to avoid connections until there's
  159. // something to send
  160. let mut msg = messages_to_send
  161. .recv()
  162. .await
  163. .expect("writer: Channel from conversations closed");
  164. let mut stream = establish_connection(
  165. &mut write_socket_updater,
  166. &read_socket_updater,
  167. &socks_params,
  168. retry,
  169. )
  170. .await
  171. .expect("Fatal error establishing connection");
  172. loop {
  173. while msg.write_all_to(&mut stream).await.is_err() {
  174. stream = establish_connection(
  175. &mut write_socket_updater,
  176. &read_socket_updater,
  177. &socks_params,
  178. retry,
  179. )
  180. .await
  181. .expect("Fatal error establishing connection");
  182. }
  183. msg = messages_to_send
  184. .recv()
  185. .await
  186. .expect("writer: Channel from conversations closed");
  187. }
  188. // helper functions
  189. /// Attempt to get a connection to the peer,
  190. /// whether by getting an existing connection from the listener,
  191. /// or by establishing a new connection.
  192. async fn establish_connection<'a>(
  193. write_socket_updater: &mut WriteSocketUpdaterOut,
  194. read_socket_updater: &ReadSocketUpdaterIn,
  195. socks_params: &SocksParams,
  196. retry: Duration,
  197. ) -> Result<OwnedWriteHalf, FatalError> {
  198. // first check if the listener thread already has a socket
  199. if let Some(wr) = write_socket_updater.maybe_recv() {
  200. return Ok(wr);
  201. }
  202. // immediately try to connect to the peer
  203. tokio::select! {
  204. connection_attempt = connect(socks_params) => {
  205. if let Ok(mut stream) = connection_attempt {
  206. log!(
  207. "connection attempt success from {} to {} on {}",
  208. &socks_params.user,
  209. &socks_params.recipient,
  210. &socks_params.target
  211. );
  212. stream
  213. .write_all(&mgen::serialize_str(&socks_params.user))
  214. .await?;
  215. let (rd, wr) = stream.into_split();
  216. read_socket_updater.send(rd);
  217. return Ok(wr);
  218. } else if let Err(MessengerError::Fatal(e)) = connection_attempt {
  219. return Err(e);
  220. }
  221. }
  222. stream = write_socket_updater.recv() => {return Ok(stream);},
  223. }
  224. // Usually we'll have returned by now, but sometimes we'll fail to
  225. // connect for whatever reason. Initiate a loop of waiting Duration,
  226. // then trying to connect again, allowing it to be inerrupted by
  227. // the listener thread.
  228. loop {
  229. match error_collector(
  230. write_socket_updater,
  231. read_socket_updater,
  232. socks_params,
  233. retry,
  234. )
  235. .await
  236. {
  237. Ok(wr) => return Ok(wr),
  238. Err(MessengerError::Recoverable(_)) => continue,
  239. Err(MessengerError::Fatal(e)) => return Err(e),
  240. }
  241. }
  242. async fn error_collector<'a>(
  243. write_socket_updater: &mut WriteSocketUpdaterOut,
  244. read_socket_updater: &ReadSocketUpdaterIn,
  245. socks_params: &SocksParams,
  246. retry: Duration,
  247. ) -> Result<OwnedWriteHalf, MessengerError> {
  248. tokio::select! {
  249. () = tokio::time::sleep(retry) => {
  250. let mut stream = connect(socks_params)
  251. .await?;
  252. stream.write_all(&mgen::serialize_str(&socks_params.user)).await?;
  253. let (rd, wr) = stream.into_split();
  254. read_socket_updater.send(rd);
  255. Ok(wr)
  256. },
  257. stream = write_socket_updater.recv() => Ok(stream),
  258. }
  259. }
  260. }
  261. }
  /// Parse a hosts-style file into a map from name to address.
  ///
  /// Each non-empty line is `ADDRESS NAME [NAME...]`, separated by ASCII whitespace.
  /// Every name on the line maps to that line's address. A line holding only an
  /// address contributes nothing. If a name appears more than once, the last line
  /// wins. As in hosts(5), everything from `#` to the end of the line is a comment
  /// and is ignored.
  ///
  /// The returned map borrows its keys and values from `file_contents`.
  fn parse_hosts_file(file_contents: &str) -> HashMap<&str, &str> {
      let mut ret = HashMap::new();
      for line in file_contents.lines() {
          // Strip comments so `#` and commented-out words are never taken as names.
          let content = line.split_once('#').map_or(line, |(before, _comment)| before);
          let mut words = content.split_ascii_whitespace();
          if let Some(addr) = words.next() {
              for name in words {
                  ret.insert(name, addr);
              }
          }
      }
      ret
  }
  274. #[derive(Debug, Deserialize)]
  275. struct ConversationConfig {
  276. group: String,
  277. recipients: Vec<String>,
  278. bootstrap: f64,
  279. retry: f64,
  280. distributions: ConfigDistributions,
  281. }
  282. #[derive(Debug, Deserialize)]
  283. struct Config {
  284. user: String,
  285. socks: Option<String>,
  286. listen: Option<String>,
  287. conversations: Vec<ConversationConfig>,
  288. }
  289. fn process_config(
  290. config: Config,
  291. hosts_map: &HashMap<&str, &str>,
  292. handles: &mut Vec<JoinHandle<Result<(), FatalError>>>,
  293. ) -> Result<(), Box<dyn std::error::Error>> {
  294. struct ForIoThreads {
  295. state_to_writer: mpsc::UnboundedSender<MessageHolder>,
  296. writer_from_state: WriterFromState,
  297. reader_to_states: HashMap<String, ReaderToState>,
  298. str_params: SocksParams,
  299. retry: f64,
  300. }
  301. // map from `recipient` to things the (user, recipient) reader/writer threads will need
  302. let mut recipient_map = HashMap::<String, ForIoThreads>::new();
  303. for conversation in config.conversations.into_iter() {
  304. let (reader_to_state, state_from_reader) = mpsc::unbounded_channel();
  305. let mut conversation_recipient_map =
  306. HashMap::<String, StateToWriter<MessageHolder>>::with_capacity(
  307. conversation.recipients.len(),
  308. );
  309. for recipient in conversation.recipients.iter() {
  310. let state_to_writer = if !recipient_map.contains_key(recipient) {
  311. let (state_to_writer, writer_from_state) = mpsc::unbounded_channel();
  312. let mut reader_to_states = HashMap::new();
  313. reader_to_states.insert(conversation.group.clone(), reader_to_state.clone());
  314. let address = hosts_map
  315. .get(recipient.as_str())
  316. .unwrap_or_else(|| panic!("recipient not in hosts file: {}", recipient));
  317. let str_params = SocksParams {
  318. socks: config.socks.clone(),
  319. target: address.to_string(),
  320. user: config.user.clone(),
  321. recipient: recipient.clone(),
  322. };
  323. let for_io = ForIoThreads {
  324. state_to_writer: state_to_writer.clone(),
  325. writer_from_state,
  326. reader_to_states,
  327. str_params,
  328. retry: conversation.retry,
  329. };
  330. recipient_map.insert(recipient.clone(), for_io);
  331. state_to_writer
  332. } else {
  333. let for_io = recipient_map.get_mut(recipient).unwrap();
  334. if !for_io.reader_to_states.contains_key(&conversation.group) {
  335. for_io
  336. .reader_to_states
  337. .insert(conversation.group.clone(), reader_to_state.clone());
  338. }
  339. for_io.state_to_writer.clone()
  340. };
  341. conversation_recipient_map.insert(
  342. recipient.clone(),
  343. StateToWriter {
  344. channel: state_to_writer,
  345. },
  346. );
  347. }
  348. let distributions: Distributions = conversation.distributions.try_into()?;
  349. tokio::spawn(manage_conversation(
  350. config.user.clone(),
  351. conversation.group,
  352. distributions,
  353. conversation.bootstrap,
  354. state_from_reader,
  355. conversation_recipient_map,
  356. ));
  357. }
  358. let mut name_to_io_threads: HashMap<String, (ReadSocketUpdaterIn, WriteSocketUpdaterIn)> =
  359. HashMap::new();
  360. for (recipient, for_io) in recipient_map.drain() {
  361. let listener_writer_to_reader = Updater::new();
  362. let reader_from_listener_writer = listener_writer_to_reader.clone();
  363. let listener_to_writer = Updater::new();
  364. let writer_from_listener = listener_to_writer.clone();
  365. name_to_io_threads.insert(
  366. recipient.to_string(),
  367. (listener_writer_to_reader.clone(), listener_to_writer),
  368. );
  369. tokio::spawn(reader(reader_from_listener_writer, for_io.reader_to_states));
  370. let retry = Duration::from_secs_f64(for_io.retry);
  371. let handle: JoinHandle<Result<(), FatalError>> = tokio::spawn(writer(
  372. for_io.writer_from_state,
  373. writer_from_listener,
  374. listener_writer_to_reader,
  375. for_io.str_params,
  376. retry,
  377. ));
  378. handles.push(handle);
  379. }
  380. let address = if let Some(address) = config.listen {
  381. address
  382. } else {
  383. hosts_map
  384. .get(config.user.as_str())
  385. .unwrap_or_else(|| panic!("user not found in hosts file: {}", config.user))
  386. .to_string()
  387. };
  388. let handle: JoinHandle<Result<(), FatalError>> =
  389. tokio::spawn(listener(address, name_to_io_threads));
  390. handles.push(handle);
  391. Ok(())
  392. }
  393. #[tokio::main]
  394. async fn main() -> Result<(), Box<dyn std::error::Error>> {
  395. let mut args = std::env::args();
  396. let _ = args.next();
  397. let hosts_file = std::fs::read_to_string(args.next().unwrap())?;
  398. let hosts_map = parse_hosts_file(&hosts_file);
  399. println!("{:?}", hosts_map);
  400. let mut handles = vec![];
  401. for config_file in args.flat_map(|a| glob::glob(a.as_str()).unwrap()) {
  402. let yaml_s = std::fs::read_to_string(config_file?)?;
  403. let config: Config = serde_yaml::from_str(&yaml_s)?;
  404. process_config(config, &hosts_map, &mut handles)?;
  405. }
  406. let handles: futures::stream::FuturesUnordered<_> = handles.into_iter().collect();
  407. for handle in handles {
  408. handle.await??;
  409. }
  410. Ok(())
  411. }