mgen-server.rs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. use mgen::{log, updater::Updater, Handshake, MessageBody, MessageHeaderRef, SerializedMessage};
  2. use std::collections::HashMap;
  3. use std::error::Error;
  4. use std::io::BufReader;
  5. use std::result::Result;
  6. use std::sync::Arc;
  7. use std::time::Duration;
  8. use tokio::io::{split, AsyncWriteExt, ReadHalf, WriteHalf};
  9. use tokio::net::{TcpListener, TcpStream};
  10. use tokio::sync::{mpsc, Notify, RwLock};
  11. use tokio::time::{timeout_at, Instant};
  12. use tokio_rustls::{rustls::PrivateKey, server::TlsStream, TlsAcceptor};
  13. // FIXME: identifiers should be interned
  14. type ID = String;
  15. type ReaderToSender = mpsc::UnboundedSender<Arc<SerializedMessage>>;
  16. type WriterDb = HashMap<Handshake, Updater<(WriteHalf<TlsStream<TcpStream>>, Arc<Notify>)>>;
  17. type SndDb = HashMap<ID, Arc<RwLock<HashMap<ID, ReaderToSender>>>>;
  18. #[tokio::main(flavor = "multi_thread", worker_threads = 10)]
  19. async fn main() -> Result<(), Box<dyn Error>> {
  20. let mut args = std::env::args();
  21. let _arg0 = args.next().unwrap();
  22. let cert_filename = args
  23. .next()
  24. .unwrap_or_else(|| panic!("no cert file provided"));
  25. let key_filename = args
  26. .next()
  27. .unwrap_or_else(|| panic!("no key file provided"));
  28. let listen_addr = args.next().unwrap_or("127.0.0.1:6397".to_string());
  29. let reg_time = args.next().unwrap_or("5".to_string()).parse()?;
  30. let reg_time = Instant::now() + Duration::from_secs(reg_time);
  31. let certfile = std::fs::File::open(cert_filename).expect("cannot open certificate file");
  32. let mut reader = BufReader::new(certfile);
  33. let certs: Vec<tokio_rustls::rustls::Certificate> = rustls_pemfile::certs(&mut reader)
  34. .unwrap()
  35. .iter()
  36. .map(|v| tokio_rustls::rustls::Certificate(v.clone()))
  37. .collect();
  38. let key = load_private_key(&key_filename);
  39. let config = tokio_rustls::rustls::ServerConfig::builder()
  40. .with_safe_default_cipher_suites()
  41. .with_safe_default_kx_groups()
  42. .with_safe_default_protocol_versions()
  43. .unwrap()
  44. .with_no_client_auth()
  45. .with_single_cert(certs, key)?;
  46. let acceptor = TlsAcceptor::from(Arc::new(config));
  47. let listener = TcpListener::bind(&listen_addr).await?;
  48. log!("listening,{}", listen_addr);
  49. // Maps group name to the table of message channels.
  50. let mut snd_db = SndDb::new();
  51. // Maps the (sender, group) pair to the socket updater.
  52. let mut writer_db = WriterDb::new();
  53. // Notifies listener threads when registration phase is over.
  54. let phase_notify = Arc::new(Notify::new());
  55. // Allow registering or reconnecting during the registration time.
  56. while let Ok(accepted) = timeout_at(reg_time, listener.accept()).await {
  57. let stream = match accepted {
  58. Ok((stream, _)) => stream,
  59. Err(e) => {
  60. log!("failed,accept,{}", e.kind());
  61. continue;
  62. }
  63. };
  64. log!("accepted {}", stream.peer_addr()?);
  65. handle_handshake::</*REGISTRATION_PHASE=*/ true>(
  66. stream,
  67. acceptor.clone(),
  68. &mut writer_db,
  69. &mut snd_db,
  70. phase_notify.clone(),
  71. )
  72. .await;
  73. }
  74. log!("registration phase complete");
  75. // Notify all the listener threads that registration is over.
  76. phase_notify.notify_waiters();
  77. // Now registration phase is over, only allow reconnecting.
  78. loop {
  79. let stream = match listener.accept().await {
  80. Ok((stream, _)) => stream,
  81. Err(e) => {
  82. log!("failed,accept,{}", e.kind());
  83. continue;
  84. }
  85. };
  86. handle_handshake::</*REGISTRATION_PHASE=*/ false>(
  87. stream,
  88. acceptor.clone(),
  89. &mut writer_db,
  90. &mut snd_db,
  91. phase_notify.clone(),
  92. )
  93. .await;
  94. }
  95. }
  96. async fn handle_handshake<const REGISTRATION_PHASE: bool>(
  97. stream: TcpStream,
  98. acceptor: TlsAcceptor,
  99. writer_db: &mut WriterDb,
  100. snd_db: &mut SndDb,
  101. phase_notify: Arc<Notify>,
  102. ) {
  103. let stream = match acceptor.accept(stream).await {
  104. Ok(stream) => stream,
  105. Err(e) => {
  106. log!("failed,tls,{}", e.kind());
  107. return;
  108. }
  109. };
  110. let (mut rd, wr) = split(stream);
  111. let handshake = match mgen::get_handshake(&mut rd).await {
  112. Ok(handshake) => handshake,
  113. Err(mgen::Error::Io(e)) => {
  114. log!("failed,handshake,{}", e.kind());
  115. return;
  116. }
  117. Err(mgen::Error::Utf8Error(e)) => panic!("{:?}", e),
  118. Err(mgen::Error::MalformedSerialization(_, _)) => panic!(),
  119. };
  120. log!("accept,{},{}", handshake.sender, handshake.group);
  121. if let Some(socket_updater) = writer_db.get(&handshake) {
  122. // we've seen this client before
  123. // start the new reader thread with a new notify
  124. // (we can't use the existing notify channel, else we get race conditions where
  125. // the reader thread terminates and spawns again before the sender thread
  126. // notices and activates its existing notify channel)
  127. let socket_notify = Arc::new(Notify::new());
  128. let db = snd_db[&handshake.group].clone();
  129. spawn_message_receiver(
  130. handshake.sender,
  131. handshake.group,
  132. rd,
  133. db,
  134. phase_notify,
  135. socket_notify.clone(),
  136. );
  137. // give the writer thread the new write half of the socket and notify
  138. socket_updater.send((wr, socket_notify));
  139. } else {
  140. // newly-registered client
  141. log!("register,{},{}", handshake.sender, handshake.group);
  142. if REGISTRATION_PHASE {
  143. // message channel, for sending messages between threads
  144. let (msg_snd, msg_rcv) = mpsc::unbounded_channel::<Arc<SerializedMessage>>();
  145. let group_snds = snd_db
  146. .entry(handshake.group.clone())
  147. .or_insert_with(|| Arc::new(RwLock::new(HashMap::new())));
  148. group_snds
  149. .write()
  150. .await
  151. .insert(handshake.sender.clone(), msg_snd);
  152. // socket notify, for terminating the socket if the sender encounters an error
  153. let socket_notify = Arc::new(Notify::new());
  154. // socket updater, for giving the sender thread a new socket + notify channel
  155. let socket_updater_snd = Updater::new();
  156. let socket_updater_rcv = socket_updater_snd.clone();
  157. socket_updater_snd.send((wr, socket_notify.clone()));
  158. spawn_message_receiver(
  159. handshake.sender.clone(),
  160. handshake.group.clone(),
  161. rd,
  162. group_snds.clone(),
  163. phase_notify,
  164. socket_notify,
  165. );
  166. let sender = handshake.sender.clone();
  167. let group = handshake.group.clone();
  168. tokio::spawn(async move {
  169. send_messages(sender, group, msg_rcv, socket_updater_rcv).await;
  170. });
  171. writer_db.insert(handshake, socket_updater_snd);
  172. } else {
  173. panic!(
  174. "late registration: {},{}",
  175. handshake.sender, handshake.group
  176. );
  177. }
  178. }
  179. }
  180. fn spawn_message_receiver(
  181. sender: String,
  182. group: String,
  183. rd: ReadHalf<TlsStream<TcpStream>>,
  184. db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
  185. phase_notify: Arc<Notify>,
  186. socket_notify: Arc<Notify>,
  187. ) {
  188. tokio::spawn(async move {
  189. tokio::select! {
  190. // n.b.: get_message is not cancellation safe,
  191. // but this is one of the cases where that's expected
  192. // (we only cancel when something is wrong with the stream anyway)
  193. ret = get_messages(&sender, &group, rd, phase_notify, db) => {
  194. match ret {
  195. Err(mgen::Error::Io(e)) => log!("failed,receive,{}", e.kind()),
  196. Err(mgen::Error::Utf8Error(e)) => panic!("{:?}", e),
  197. Err(mgen::Error::MalformedSerialization(v, b)) => panic!(
  198. "Malformed Serialization: {:?}\n{:?})", v, b),
  199. Ok(()) => panic!("Message receiver returned OK"),
  200. }
  201. }
  202. _ = socket_notify.notified() => {
  203. log!("terminated,{},{}", sender, group);
  204. // should cause get_messages to terminate, dropping the socket
  205. }
  206. }
  207. });
  208. }
  209. /// Loop for receiving messages on the socket, figuring out who to deliver them to,
  210. /// and forwarding them locally to the respective channel.
  211. async fn get_messages<T: tokio::io::AsyncRead>(
  212. sender: &str,
  213. group: &str,
  214. mut socket: ReadHalf<T>,
  215. phase_notify: Arc<Notify>,
  216. global_db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
  217. ) -> Result<(), mgen::Error> {
  218. // Wait for the registration phase to end before updating our local copy of the DB
  219. phase_notify.notified().await;
  220. let db = global_db.read().await.clone();
  221. let message_channels: Vec<_> = db
  222. .iter()
  223. .filter_map(|(k, v)| if *k != sender { Some(v) } else { None })
  224. .collect();
  225. loop {
  226. let buf = mgen::get_message_bytes(&mut socket).await?;
  227. let message = MessageHeaderRef::deserialize(&buf[4..])?;
  228. assert!(message.sender == sender);
  229. match message.body {
  230. MessageBody::Size(_) => {
  231. assert!(message.group == group);
  232. log!("received,{},{},{}", sender, group, message.id);
  233. let body = message.body;
  234. let m = Arc::new(SerializedMessage { header: buf, body });
  235. for recipient in message_channels.iter() {
  236. recipient.send(m.clone()).unwrap();
  237. }
  238. }
  239. MessageBody::Receipt => {
  240. log!(
  241. "receipt,{},{},{},{}",
  242. sender,
  243. group,
  244. message.group,
  245. message.id
  246. );
  247. let recipient = &db[message.group];
  248. let body = message.body;
  249. let m = Arc::new(SerializedMessage { header: buf, body });
  250. recipient.send(m).unwrap();
  251. }
  252. }
  253. }
  254. }
  255. /// Loop for receiving messages on the mpsc channel for this recipient,
  256. /// and sending them out on the associated socket.
  257. async fn send_messages<T: Send + Sync + tokio::io::AsyncWrite>(
  258. recipient: ID,
  259. group: ID,
  260. mut msg_rcv: mpsc::UnboundedReceiver<Arc<SerializedMessage>>,
  261. mut socket_updater: Updater<(WriteHalf<T>, Arc<Notify>)>,
  262. ) {
  263. let (mut current_socket, mut current_watch) = socket_updater.recv().await;
  264. let mut message_cache = None;
  265. loop {
  266. let message = if let Some(message) = message_cache {
  267. message
  268. } else {
  269. msg_rcv.recv().await.expect("message channel closed")
  270. };
  271. if message.write_all_to(&mut current_socket).await.is_err()
  272. || current_socket.flush().await.is_err()
  273. {
  274. message_cache = Some(message);
  275. log!("terminating,{},{}", recipient, group);
  276. // socket is presumably closed, clean up and notify the listening end to close
  277. // (all best-effort, we can ignore errors because it presumably means it's done)
  278. current_watch.notify_one();
  279. let _ = current_socket.shutdown().await;
  280. // wait for the new socket
  281. (current_socket, current_watch) = socket_updater.recv().await;
  282. } else {
  283. log!("sent,{},{}", recipient, group);
  284. message_cache = None;
  285. }
  286. }
  287. }
  288. fn load_private_key(filename: &str) -> PrivateKey {
  289. let keyfile = std::fs::File::open(filename).expect("cannot open private key file");
  290. let mut reader = BufReader::new(keyfile);
  291. loop {
  292. match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") {
  293. Some(rustls_pemfile::Item::RSAKey(key)) => return PrivateKey(key),
  294. Some(rustls_pemfile::Item::PKCS8Key(key)) => return PrivateKey(key),
  295. Some(rustls_pemfile::Item::ECKey(key)) => return PrivateKey(key),
  296. None => break,
  297. _ => {}
  298. }
  299. }
  300. panic!(
  301. "no keys found in {:?} (encrypted keys not supported)",
  302. filename
  303. );
  304. }