mgen-server.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. use mgen::{log, updater::Updater, Handshake, MessageBody, MessageHeaderRef, SerializedMessage};
  2. use std::collections::HashMap;
  3. use std::error::Error;
  4. use std::io::BufReader;
  5. use std::result::Result;
  6. use std::sync::Arc;
  7. use tokio::io::{split, AsyncWriteExt, ReadHalf, WriteHalf};
  8. use tokio::net::{TcpListener, TcpStream};
  9. use tokio::sync::{mpsc, Notify, RwLock};
  10. use tokio_rustls::{rustls::PrivateKey, server::TlsStream};
  11. // FIXME: identifiers should be interned
  12. type ID = String; // name of a client or of a group; used as the key type of the lookup tables built in main
  13. type ReaderToSender = mpsc::UnboundedSender<Arc<SerializedMessage>>; // sending half of a client's message queue: get_messages pushes into it, send_messages drains the matching receiver
  14. #[tokio::main]
  15. async fn main() -> Result<(), Box<dyn Error>> {
  16. let mut args = std::env::args();
  17. let _arg0 = args.next().unwrap();
  18. let cert_filename = args
  19. .next()
  20. .unwrap_or_else(|| panic!("no cert file provided"));
  21. let key_filename = args
  22. .next()
  23. .unwrap_or_else(|| panic!("no key file provided"));
  24. let listen_addr = args.next().unwrap_or("127.0.0.1:6397".to_string());
  25. let certfile = std::fs::File::open(cert_filename).expect("cannot open certificate file");
  26. let mut reader = BufReader::new(certfile);
  27. let certs: Vec<tokio_rustls::rustls::Certificate> = rustls_pemfile::certs(&mut reader)
  28. .unwrap()
  29. .iter()
  30. .map(|v| tokio_rustls::rustls::Certificate(v.clone()))
  31. .collect();
  32. let key = load_private_key(&key_filename);
  33. let config = tokio_rustls::rustls::ServerConfig::builder()
  34. .with_safe_default_cipher_suites()
  35. .with_safe_default_kx_groups()
  36. .with_safe_default_protocol_versions()
  37. .unwrap()
  38. .with_no_client_auth()
  39. .with_single_cert(certs, key)?;
  40. let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
  41. let listener = TcpListener::bind(&listen_addr).await?;
  42. log!("listening,{}", listen_addr);
  43. // Maps group name to the table of message channels.
  44. let mut snd_db = HashMap::<ID, Arc<RwLock<HashMap<ID, ReaderToSender>>>>::new();
  45. // Maps the (sender, group) pair to the socket updater.
  46. let mut writer_db =
  47. HashMap::<Handshake, Updater<(WriteHalf<TlsStream<TcpStream>>, Arc<Notify>)>>::new();
  48. loop {
  49. let stream = match listener.accept().await {
  50. Ok((stream, _)) => stream,
  51. Err(e) => {
  52. log!("failed,accept,{}", e.kind());
  53. continue;
  54. }
  55. };
  56. let acceptor = acceptor.clone();
  57. let stream = match acceptor.accept(stream).await {
  58. Ok(stream) => stream,
  59. Err(e) => {
  60. log!("failed,tls,{}", e.kind());
  61. continue;
  62. }
  63. };
  64. let (mut rd, wr) = split(stream);
  65. let handshake = match mgen::get_handshake(&mut rd).await {
  66. Ok(handshake) => handshake,
  67. Err(mgen::Error::Io(e)) => {
  68. log!("failed,handshake,{}", e.kind());
  69. continue;
  70. }
  71. Err(mgen::Error::Utf8Error(e)) => panic!("{:?}", e),
  72. Err(mgen::Error::MalformedSerialization(_, _)) => panic!(),
  73. };
  74. log!("accept,{},{}", handshake.sender, handshake.group);
  75. if let Some(socket_updater) = writer_db.get(&handshake) {
  76. // we've seen this client before
  77. // start the new reader thread with a new notify
  78. // (we can't use the existing notify channel, else we get race conditions where
  79. // the reader thread terminates and spawns again before the sender thread
  80. // notices and activates its existing notify channel)
  81. let notify = Arc::new(Notify::new());
  82. let db = snd_db[&handshake.group].clone();
  83. spawn_message_receiver(handshake.sender, handshake.group, rd, db, notify.clone());
  84. // give the writer thread the new write half of the socket and notify
  85. socket_updater.send((wr, notify));
  86. } else {
  87. // newly-registered client
  88. log!("register,{},{}", handshake.sender, handshake.group);
  89. // message channel, for sending messages between threads
  90. let (msg_snd, msg_rcv) = mpsc::unbounded_channel::<Arc<SerializedMessage>>();
  91. let group_snds = snd_db
  92. .entry(handshake.group.clone())
  93. .or_insert_with(|| Arc::new(RwLock::new(HashMap::new())));
  94. group_snds
  95. .write()
  96. .await
  97. .insert(handshake.sender.clone(), msg_snd);
  98. // socket notify, for terminating the socket if the sender encounters an error
  99. let notify = Arc::new(Notify::new());
  100. // socket updater, for giving the sender thread a new socket + notify channel
  101. let socket_updater_snd = Updater::new();
  102. let socket_updater_rcv = socket_updater_snd.clone();
  103. socket_updater_snd.send((wr, notify.clone()));
  104. spawn_message_receiver(
  105. handshake.sender.clone(),
  106. handshake.group.clone(),
  107. rd,
  108. group_snds.clone(),
  109. notify,
  110. );
  111. let sender = handshake.sender.clone();
  112. let group = handshake.group.clone();
  113. tokio::spawn(async move {
  114. send_messages(sender, group, msg_rcv, socket_updater_rcv).await;
  115. });
  116. writer_db.insert(handshake, socket_updater_snd);
  117. }
  118. }
  119. }
  120. fn spawn_message_receiver(
  121. sender: String,
  122. group: String,
  123. rd: ReadHalf<TlsStream<TcpStream>>,
  124. db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
  125. notify: Arc<Notify>,
  126. ) {
  127. tokio::spawn(async move {
  128. tokio::select! {
  129. // n.b.: get_message is not cancellation safe,
  130. // but this is one of the cases where that's expected
  131. // (we only cancel when something is wrong with the stream anyway)
  132. ret = get_messages(&sender, &group, rd, db) => {
  133. match ret {
  134. Err(mgen::Error::Io(e)) => log!("failed,receive,{}", e.kind()),
  135. Err(mgen::Error::Utf8Error(e)) => panic!("{:?}", e),
  136. Err(mgen::Error::MalformedSerialization(v, b)) => panic!(
  137. "Malformed Serialization: {:?}\n{:?})", v, b),
  138. Ok(()) => panic!("Message receiver returned OK"),
  139. }
  140. }
  141. _ = notify.notified() => {
  142. log!("terminated,{},{}", sender, group);
  143. // should cause get_messages to terminate, dropping the socket
  144. }
  145. }
  146. });
  147. }
  148. /// Loop for receiving messages on the socket, figuring out who to deliver them to,
  149. /// and forwarding them locally to the respective channel.
  150. async fn get_messages<T: tokio::io::AsyncRead>(
  151. sender: &str,
  152. group: &str,
  153. mut socket: ReadHalf<T>,
  154. global_db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
  155. ) -> Result<(), mgen::Error> {
  156. // Wait for the next message to be received before populating our local copy of the db,
  157. // that way other clients have time to register.
  158. let buf = mgen::get_message_bytes(&mut socket).await?;
  159. let db = global_db.read().await.clone();
  160. let message_channels: Vec<_> = db
  161. .iter()
  162. .filter_map(|(k, v)| if *k != sender { Some(v) } else { None })
  163. .collect();
  164. let message = MessageHeaderRef::deserialize(&buf[4..])?;
  165. assert!(message.sender == sender);
  166. match message.body {
  167. MessageBody::Size(_) => {
  168. assert!(message.group == group);
  169. log!("received,{},{},{}", sender, group, message.id);
  170. let body = message.body;
  171. let m = Arc::new(SerializedMessage { header: buf, body });
  172. for recipient in message_channels.iter() {
  173. recipient.send(m.clone()).unwrap();
  174. }
  175. }
  176. MessageBody::Receipt => {
  177. log!(
  178. "receipt,{},{},{},{}",
  179. sender,
  180. group,
  181. message.group,
  182. message.id
  183. );
  184. let recipient = &db[message.group];
  185. let body = message.body;
  186. let m = Arc::new(SerializedMessage { header: buf, body });
  187. recipient.send(m).unwrap();
  188. }
  189. }
  190. // we never have to update the DB again, so repeat the above, skipping that step
  191. loop {
  192. let buf = mgen::get_message_bytes(&mut socket).await?;
  193. let message = MessageHeaderRef::deserialize(&buf[4..])?;
  194. assert!(message.sender == sender);
  195. match message.body {
  196. MessageBody::Size(_) => {
  197. assert!(message.group == group);
  198. log!("received,{},{},{}", sender, group, message.id);
  199. let body = message.body;
  200. let m = Arc::new(SerializedMessage { header: buf, body });
  201. for recipient in message_channels.iter() {
  202. recipient.send(m.clone()).unwrap();
  203. }
  204. }
  205. MessageBody::Receipt => {
  206. log!(
  207. "receipt,{},{},{},{}",
  208. sender,
  209. group,
  210. message.group,
  211. message.id
  212. );
  213. let recipient = &db[message.group];
  214. let body = message.body;
  215. let m = Arc::new(SerializedMessage { header: buf, body });
  216. recipient.send(m).unwrap();
  217. }
  218. }
  219. }
  220. }
  221. /// Loop for receiving messages on the mpsc channel for this recipient,
  222. /// and sending them out on the associated socket.
  223. async fn send_messages<T: Send + Sync + tokio::io::AsyncWrite>(
  224. recipient: ID,
  225. group: ID,
  226. mut msg_rcv: mpsc::UnboundedReceiver<Arc<SerializedMessage>>,
  227. mut socket_updater: Updater<(WriteHalf<T>, Arc<Notify>)>,
  228. ) {
  229. let (mut current_socket, mut current_watch) = socket_updater.recv().await;
  230. let mut message_cache = None;
  231. loop {
  232. let message = if message_cache.is_none() {
  233. msg_rcv.recv().await.expect("message channel closed")
  234. } else {
  235. message_cache.unwrap()
  236. };
  237. if message.write_all_to(&mut current_socket).await.is_err()
  238. || current_socket.flush().await.is_err()
  239. {
  240. message_cache = Some(message);
  241. log!("terminating,{},{}", recipient, group);
  242. // socket is presumably closed, clean up and notify the listening end to close
  243. // (all best-effort, we can ignore errors because it presumably means it's done)
  244. current_watch.notify_one();
  245. let _ = current_socket.shutdown().await;
  246. // wait for the new socket
  247. (current_socket, current_watch) = socket_updater.recv().await;
  248. } else {
  249. log!("sent,{},{}", recipient, group);
  250. message_cache = None;
  251. }
  252. }
  253. }
  254. fn load_private_key(filename: &str) -> PrivateKey {
  255. let keyfile = std::fs::File::open(filename).expect("cannot open private key file");
  256. let mut reader = BufReader::new(keyfile);
  257. loop {
  258. match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") {
  259. Some(rustls_pemfile::Item::RSAKey(key)) => return PrivateKey(key),
  260. Some(rustls_pemfile::Item::PKCS8Key(key)) => return PrivateKey(key),
  261. Some(rustls_pemfile::Item::ECKey(key)) => return PrivateKey(key),
  262. None => break,
  263. _ => {}
  264. }
  265. }
  266. panic!(
  267. "no keys found in {:?} (encrypted keys not supported)",
  268. filename
  269. );
  270. }