mgen-server.rs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. use mgen::{log, updater::Updater, Handshake, MessageBody, MessageHeaderRef, SerializedMessage};
  2. use std::collections::HashMap;
  3. use std::error::Error;
  4. use std::io::BufReader;
  5. use std::result::Result;
  6. use std::sync::Arc;
  7. use std::time::Duration;
  8. use tokio::io::{split, AsyncWriteExt, ReadHalf, WriteHalf};
  9. use tokio::net::{TcpSocket, TcpStream};
  10. use tokio::sync::{mpsc, Notify, RwLock};
  11. use tokio::time::{timeout_at, Instant};
  12. use tokio_rustls::{rustls::PrivateKey, server::TlsStream, TlsAcceptor};
  13. // FIXME: identifiers should be interned
  14. type ID = String;
  15. type ReaderToSender = mpsc::UnboundedSender<Arc<SerializedMessage>>;
  16. type WriterDb = HashMap<Handshake, Updater<(WriteHalf<TlsStream<TcpStream>>, Arc<Notify>)>>;
  17. type SndDb = HashMap<ID, Arc<RwLock<HashMap<ID, ReaderToSender>>>>;
  18. #[tokio::main(flavor = "multi_thread", worker_threads = 10)]
  19. async fn main() -> Result<(), Box<dyn Error>> {
  20. let mut args = std::env::args();
  21. let _arg0 = args.next().unwrap();
  22. let cert_filename = args
  23. .next()
  24. .unwrap_or_else(|| panic!("no cert file provided"));
  25. let key_filename = args
  26. .next()
  27. .unwrap_or_else(|| panic!("no key file provided"));
  28. let listen_addr = args.next().unwrap_or("127.0.0.1:6397".to_string());
  29. let reg_time = args.next().unwrap_or("5".to_string()).parse()?;
  30. let reg_time = Instant::now() + Duration::from_secs(reg_time);
  31. let certfile = std::fs::File::open(cert_filename).expect("cannot open certificate file");
  32. let mut reader = BufReader::new(certfile);
  33. let certs: Vec<tokio_rustls::rustls::Certificate> = rustls_pemfile::certs(&mut reader)
  34. .unwrap()
  35. .iter()
  36. .map(|v| tokio_rustls::rustls::Certificate(v.clone()))
  37. .collect();
  38. let key = load_private_key(&key_filename);
  39. let config = tokio_rustls::rustls::ServerConfig::builder()
  40. .with_safe_default_cipher_suites()
  41. .with_safe_default_kx_groups()
  42. .with_safe_default_protocol_versions()
  43. .unwrap()
  44. .with_no_client_auth()
  45. .with_single_cert(certs, key)?;
  46. let acceptor = TlsAcceptor::from(Arc::new(config));
  47. let addr = listen_addr.parse().unwrap();
  48. let socket = TcpSocket::new_v4()?;
  49. socket.set_nodelay(true)?;
  50. socket.bind(addr)?;
  51. let listener = socket.listen(4096)?;
  52. log!("listening,{}", listen_addr);
  53. // Maps the (sender, group) pair to the socket updater.
  54. let writer_db = Arc::new(RwLock::new(WriterDb::new()));
  55. // Maps group name to the table of message channels.
  56. let snd_db = Arc::new(RwLock::new(SndDb::new()));
  57. // Notifies listener threads when registration phase is over.
  58. let phase_notify = Arc::new(Notify::new());
  59. // Allow registering or reconnecting during the registration time.
  60. while let Ok(accepted) = timeout_at(reg_time, listener.accept()).await {
  61. let stream = match accepted {
  62. Ok((stream, _)) => stream,
  63. Err(e) => {
  64. log!("failed,accept,{}", e.kind());
  65. continue;
  66. }
  67. };
  68. let acceptor = acceptor.clone();
  69. let writer_db = writer_db.clone();
  70. let snd_db = snd_db.clone();
  71. let phase_notify = phase_notify.clone();
  72. tokio::spawn(async move {
  73. handle_handshake::</*REGISTRATION_PHASE=*/ true>(
  74. stream,
  75. acceptor,
  76. writer_db,
  77. snd_db,
  78. phase_notify,
  79. )
  80. .await
  81. });
  82. }
  83. log!("registration phase complete");
  84. // Notify all the listener threads that registration is over.
  85. phase_notify.notify_waiters();
  86. // Now registration phase is over, only allow reconnecting.
  87. loop {
  88. let stream = match listener.accept().await {
  89. Ok((stream, _)) => stream,
  90. Err(e) => {
  91. log!("failed,accept,{}", e.kind());
  92. continue;
  93. }
  94. };
  95. let acceptor = acceptor.clone();
  96. let writer_db = writer_db.clone();
  97. let snd_db = snd_db.clone();
  98. let phase_notify = phase_notify.clone();
  99. tokio::spawn(async move {
  100. handle_handshake::</*REGISTRATION_PHASE=*/ false>(
  101. stream,
  102. acceptor,
  103. writer_db,
  104. snd_db,
  105. phase_notify,
  106. )
  107. .await
  108. });
  109. }
  110. }
  111. async fn handle_handshake<const REGISTRATION_PHASE: bool>(
  112. stream: TcpStream,
  113. acceptor: TlsAcceptor,
  114. writer_db: Arc<RwLock<WriterDb>>,
  115. snd_db: Arc<RwLock<SndDb>>,
  116. phase_notify: Arc<Notify>,
  117. ) {
  118. log!("accepted {}", stream.peer_addr().unwrap());
  119. let stream = match acceptor.accept(stream).await {
  120. Ok(stream) => stream,
  121. Err(e) => {
  122. log!("failed,tls,{}", e.kind());
  123. return;
  124. }
  125. };
  126. let (mut rd, wr) = split(stream);
  127. let handshake = match mgen::get_handshake(&mut rd).await {
  128. Ok(handshake) => handshake,
  129. Err(mgen::Error::Io(e)) => {
  130. log!("failed,handshake,{}", e.kind());
  131. return;
  132. }
  133. Err(mgen::Error::Utf8Error(e)) => panic!("{:?}", e),
  134. Err(mgen::Error::MalformedSerialization(_, _)) => panic!(),
  135. };
  136. log!("accept,{},{}", handshake.sender, handshake.group);
  137. let read_writer_db = writer_db.read().await;
  138. if let Some(socket_updater) = read_writer_db.get(&handshake) {
  139. // we've seen this client before
  140. // start the new reader thread with a new notify
  141. // (we can't use the existing notify channel, else we get race conditions where
  142. // the reader thread terminates and spawns again before the sender thread
  143. // notices and activates its existing notify channel)
  144. let socket_notify = Arc::new(Notify::new());
  145. let db = snd_db.read().await[&handshake.group].clone();
  146. spawn_message_receiver(
  147. handshake.sender,
  148. handshake.group,
  149. rd,
  150. db,
  151. phase_notify,
  152. socket_notify.clone(),
  153. );
  154. // give the writer thread the new write half of the socket and notify
  155. socket_updater.send((wr, socket_notify));
  156. } else {
  157. drop(read_writer_db);
  158. // newly-registered client
  159. log!("register,{},{}", handshake.sender, handshake.group);
  160. if REGISTRATION_PHASE {
  161. // message channel, for sending messages between threads
  162. let (msg_snd, msg_rcv) = mpsc::unbounded_channel::<Arc<SerializedMessage>>();
  163. let group_snds = {
  164. let read_snd_db = snd_db.read().await;
  165. let group_snds = read_snd_db.get(&handshake.group);
  166. if let Some(group_snds) = group_snds {
  167. let group_snds = group_snds.clone();
  168. drop(read_snd_db);
  169. group_snds
  170. .write()
  171. .await
  172. .insert(handshake.sender.clone(), msg_snd);
  173. group_snds
  174. } else {
  175. drop(read_snd_db);
  176. let mut write_snd_db = snd_db.write().await;
  177. let group_snds = write_snd_db
  178. .entry(handshake.group.clone())
  179. .or_insert_with(|| Arc::new(RwLock::new(HashMap::new())));
  180. group_snds
  181. .write()
  182. .await
  183. .insert(handshake.sender.clone(), msg_snd);
  184. group_snds.clone()
  185. }
  186. };
  187. // socket notify, for terminating the socket if the sender encounters an error
  188. let socket_notify = Arc::new(Notify::new());
  189. // socket updater, for giving the sender thread a new socket + notify channel
  190. let socket_updater_snd = Updater::new();
  191. let socket_updater_rcv = socket_updater_snd.clone();
  192. socket_updater_snd.send((wr, socket_notify.clone()));
  193. spawn_message_receiver(
  194. handshake.sender.clone(),
  195. handshake.group.clone(),
  196. rd,
  197. group_snds,
  198. phase_notify,
  199. socket_notify,
  200. );
  201. let sender = handshake.sender.clone();
  202. let group = handshake.group.clone();
  203. tokio::spawn(async move {
  204. send_messages(sender, group, msg_rcv, socket_updater_rcv).await;
  205. });
  206. writer_db
  207. .write()
  208. .await
  209. .insert(handshake, socket_updater_snd);
  210. } else {
  211. panic!(
  212. "late registration: {},{}",
  213. handshake.sender, handshake.group
  214. );
  215. };
  216. }
  217. }
  218. fn spawn_message_receiver(
  219. sender: String,
  220. group: String,
  221. rd: ReadHalf<TlsStream<TcpStream>>,
  222. db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
  223. phase_notify: Arc<Notify>,
  224. socket_notify: Arc<Notify>,
  225. ) {
  226. tokio::spawn(async move {
  227. tokio::select! {
  228. // n.b.: get_message is not cancellation safe,
  229. // but this is one of the cases where that's expected
  230. // (we only cancel when something is wrong with the stream anyway)
  231. ret = get_messages(&sender, &group, rd, phase_notify, db) => {
  232. match ret {
  233. Err(mgen::Error::Io(e)) => log!("failed,receive,{}", e.kind()),
  234. Err(mgen::Error::Utf8Error(e)) => panic!("{:?}", e),
  235. Err(mgen::Error::MalformedSerialization(v, b)) => panic!(
  236. "Malformed Serialization: {:?}\n{:?})", v, b),
  237. Ok(()) => panic!("Message receiver returned OK"),
  238. }
  239. }
  240. _ = socket_notify.notified() => {
  241. log!("terminated,{},{}", sender, group);
  242. // should cause get_messages to terminate, dropping the socket
  243. }
  244. }
  245. });
  246. }
  247. /// Loop for receiving messages on the socket, figuring out who to deliver them to,
  248. /// and forwarding them locally to the respective channel.
  249. async fn get_messages<T: tokio::io::AsyncRead>(
  250. sender: &str,
  251. group: &str,
  252. mut socket: ReadHalf<T>,
  253. phase_notify: Arc<Notify>,
  254. global_db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
  255. ) -> Result<(), mgen::Error> {
  256. // Wait for the registration phase to end before updating our local copy of the DB
  257. phase_notify.notified().await;
  258. let db = global_db.read().await.clone();
  259. let message_channels: Vec<_> = db
  260. .iter()
  261. .filter_map(|(k, v)| if *k != sender { Some(v) } else { None })
  262. .collect();
  263. loop {
  264. let buf = mgen::get_message_bytes(&mut socket).await?;
  265. let message = MessageHeaderRef::deserialize(&buf[4..])?;
  266. assert!(message.sender == sender);
  267. match message.body {
  268. MessageBody::Size(_) => {
  269. assert!(message.group == group);
  270. log!("received,{},{},{}", sender, group, message.id);
  271. let body = message.body;
  272. let m = Arc::new(SerializedMessage { header: buf, body });
  273. for recipient in message_channels.iter() {
  274. recipient.send(m.clone()).unwrap();
  275. }
  276. }
  277. MessageBody::Receipt => {
  278. log!(
  279. "receipt,{},{},{},{}",
  280. sender,
  281. group,
  282. message.group,
  283. message.id
  284. );
  285. let recipient = &db[message.group];
  286. let body = message.body;
  287. let m = Arc::new(SerializedMessage { header: buf, body });
  288. recipient.send(m).unwrap();
  289. }
  290. }
  291. }
  292. }
  293. /// Loop for receiving messages on the mpsc channel for this recipient,
  294. /// and sending them out on the associated socket.
  295. async fn send_messages<T: Send + Sync + tokio::io::AsyncWrite>(
  296. recipient: ID,
  297. group: ID,
  298. mut msg_rcv: mpsc::UnboundedReceiver<Arc<SerializedMessage>>,
  299. mut socket_updater: Updater<(WriteHalf<T>, Arc<Notify>)>,
  300. ) {
  301. let (mut current_socket, mut current_watch) = socket_updater.recv().await;
  302. let mut message_cache = None;
  303. loop {
  304. let message = if let Some(message) = message_cache {
  305. message
  306. } else {
  307. msg_rcv.recv().await.expect("message channel closed")
  308. };
  309. if message.write_all_to(&mut current_socket).await.is_err()
  310. || current_socket.flush().await.is_err()
  311. {
  312. message_cache = Some(message);
  313. log!("terminating,{},{}", recipient, group);
  314. // socket is presumably closed, clean up and notify the listening end to close
  315. // (all best-effort, we can ignore errors because it presumably means it's done)
  316. current_watch.notify_one();
  317. let _ = current_socket.shutdown().await;
  318. // wait for the new socket
  319. (current_socket, current_watch) = socket_updater.recv().await;
  320. } else {
  321. log!("sent,{},{}", recipient, group);
  322. message_cache = None;
  323. }
  324. }
  325. }
  326. fn load_private_key(filename: &str) -> PrivateKey {
  327. let keyfile = std::fs::File::open(filename).expect("cannot open private key file");
  328. let mut reader = BufReader::new(keyfile);
  329. loop {
  330. match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") {
  331. Some(rustls_pemfile::Item::RSAKey(key)) => return PrivateKey(key),
  332. Some(rustls_pemfile::Item::PKCS8Key(key)) => return PrivateKey(key),
  333. Some(rustls_pemfile::Item::ECKey(key)) => return PrivateKey(key),
  334. None => break,
  335. _ => {}
  336. }
  337. }
  338. panic!(
  339. "no keys found in {:?} (encrypted keys not supported)",
  340. filename
  341. );
  342. }