// mgen-server.rs
  1. use mgen::{log, updater::Updater, Handshake, MessageBody, MessageHeaderRef, SerializedMessage};
  2. use std::collections::HashMap;
  3. use std::error::Error;
  4. use std::io::BufReader;
  5. use std::result::Result;
  6. use std::sync::Arc;
  7. use std::time::Duration;
  8. use tokio::io::{split, AsyncWriteExt, ReadHalf, WriteHalf};
  9. use tokio::net::{TcpSocket, TcpStream};
  10. use tokio::sync::{mpsc, Notify, RwLock};
  11. use tokio::time::{timeout_at, Instant};
  12. use tokio_rustls::{
  13. rustls::{KeyLogFile, PrivateKey},
  14. server::TlsStream,
  15. TlsAcceptor,
  16. };
  17. // FIXME: identifiers should be interned
  18. type ID = String;
  19. type ReaderToSender = mpsc::UnboundedSender<Arc<SerializedMessage>>;
  20. type WriterDb = HashMap<Handshake, Updater<(WriteHalf<TlsStream<TcpStream>>, Arc<Notify>)>>;
  21. type SndDb = HashMap<ID, Arc<RwLock<HashMap<ID, ReaderToSender>>>>;
  22. #[cfg(feature = "tracing")]
  23. async fn tracing(metrics_monitor: tokio_metrics::TaskMonitor) {
  24. console_subscriber::init();
  25. let handle = tokio::runtime::Handle::current();
  26. let runtime_monitor = tokio_metrics::RuntimeMonitor::new(&handle);
  27. for intervals in std::iter::zip(metrics_monitor.intervals(), runtime_monitor.intervals()) {
  28. log!("{:?}", intervals.0);
  29. log!("{:?}", intervals.1);
  30. tokio::time::sleep(Duration::from_secs(5)).await;
  31. }
  32. }
  33. fn main() -> Result<(), Box<dyn std::error::Error>> {
  34. tokio::runtime::Builder::new_multi_thread()
  35. .worker_threads(10)
  36. .enable_all()
  37. .disable_lifo_slot()
  38. .build()
  39. .unwrap()
  40. .block_on(main_worker())
  41. }
  42. async fn main_worker() -> Result<(), Box<dyn Error>> {
  43. #[cfg(feature = "tracing")]
  44. let metrics_monitor = {
  45. let metrics_monitor = tokio_metrics::TaskMonitor::new();
  46. tokio::spawn(tracing(metrics_monitor.clone()));
  47. metrics_monitor
  48. };
  49. let mut args = std::env::args();
  50. let _arg0 = args.next().unwrap();
  51. let cert_filename = args
  52. .next()
  53. .unwrap_or_else(|| panic!("no cert file provided"));
  54. let key_filename = args
  55. .next()
  56. .unwrap_or_else(|| panic!("no key file provided"));
  57. let listen_addr = args.next().unwrap_or("127.0.0.1:6397".to_string());
  58. let reg_time = args.next().unwrap_or("30".to_string()).parse()?;
  59. let reg_time = Instant::now() + Duration::from_secs(reg_time);
  60. let certfile = std::fs::File::open(cert_filename).expect("cannot open certificate file");
  61. let mut reader = BufReader::new(certfile);
  62. let certs: Vec<tokio_rustls::rustls::Certificate> = rustls_pemfile::certs(&mut reader)
  63. .unwrap()
  64. .iter()
  65. .map(|v| tokio_rustls::rustls::Certificate(v.clone()))
  66. .collect();
  67. let key = load_private_key(&key_filename);
  68. let key_log = Arc::new(KeyLogFile::new());
  69. let mut config = tokio_rustls::rustls::ServerConfig::builder()
  70. .with_safe_default_cipher_suites()
  71. .with_safe_default_kx_groups()
  72. .with_safe_default_protocol_versions()
  73. .unwrap()
  74. .with_no_client_auth()
  75. .with_single_cert(certs, key)?;
  76. config.key_log = key_log;
  77. let acceptor = TlsAcceptor::from(Arc::new(config));
  78. let addr = listen_addr.parse().unwrap();
  79. let socket = TcpSocket::new_v4()?;
  80. socket.set_nodelay(true)?;
  81. socket.bind(addr)?;
  82. let listener = socket.listen(4096)?;
  83. log!("listening,{}", listen_addr);
  84. // Maps the (sender, group) pair to the socket updater.
  85. let writer_db = Arc::new(RwLock::new(WriterDb::new()));
  86. // Maps group name to the table of message channels.
  87. let snd_db = Arc::new(RwLock::new(SndDb::new()));
  88. // Notifies listener threads when registration phase is over.
  89. let phase_notify = Arc::new(Notify::new());
  90. // Allow registering or reconnecting during the registration time.
  91. while let Ok(accepted) = timeout_at(reg_time, listener.accept()).await {
  92. let stream = match accepted {
  93. Ok((stream, _)) => stream,
  94. Err(e) => {
  95. log!("failed,accept,{}", e.kind());
  96. continue;
  97. }
  98. };
  99. let acceptor = acceptor.clone();
  100. let writer_db = writer_db.clone();
  101. let snd_db = snd_db.clone();
  102. let phase_notify = phase_notify.clone();
  103. #[cfg(feature = "tracing")]
  104. tokio::spawn(metrics_monitor.instrument(async move {
  105. handle_handshake::</*REGISTRATION_PHASE=*/ true>(
  106. stream,
  107. acceptor,
  108. writer_db,
  109. snd_db,
  110. phase_notify,
  111. )
  112. .await
  113. }));
  114. #[cfg(not(feature = "tracing"))]
  115. tokio::spawn(async move {
  116. handle_handshake::</*REGISTRATION_PHASE=*/ true>(
  117. stream,
  118. acceptor,
  119. writer_db,
  120. snd_db,
  121. phase_notify,
  122. )
  123. .await
  124. });
  125. }
  126. log!("registration phase complete");
  127. // Notify all the listener threads that registration is over.
  128. phase_notify.notify_waiters();
  129. // Now registration phase is over, only allow reconnecting.
  130. loop {
  131. let stream = match listener.accept().await {
  132. Ok((stream, _)) => stream,
  133. Err(e) => {
  134. log!("failed,accept,{}", e.kind());
  135. continue;
  136. }
  137. };
  138. let acceptor = acceptor.clone();
  139. let writer_db = writer_db.clone();
  140. let snd_db = snd_db.clone();
  141. let phase_notify = phase_notify.clone();
  142. tokio::spawn(async move {
  143. handle_handshake::</*REGISTRATION_PHASE=*/ false>(
  144. stream,
  145. acceptor,
  146. writer_db,
  147. snd_db,
  148. phase_notify,
  149. )
  150. .await
  151. });
  152. }
  153. }
  154. /*
  155. An informal proof that the main thread + handshake threads will not deadlock.
  156. (The rest of the code mainly uses channels so is a lot simpler.)
  157. locks:
  158. - writer_db (WDB)
  159. - snd_db (SDB)
  160. - group_snds (GSS)
== CFG ==
WDB.R() |-> Some(socket_updater) -> SDB.R(); drop(WDB.R, SDB.R)
        |-> None -> drop(WDB.R); SDB.R() |-> Some(group_snds) -> drop(SDB.R); GSS.W(); drop(GSS.W)
                                         |-> None -> drop(SDB.R); SDB.W(); GSS.W(); drop(GSS.W, SDB.W)
=========
  166. The program deadlocks iff lock A can't drop until it gets lock B, while lock B can't drop until it
  167. gets lock A, or a transitive equivalent.
  168. We have three potential locks that can deadlock: WDB, SDB, and GSS.
  169. Can WDB ever deadlock?
  170. It only ever locks in one place: at the start, when the thread holds no other locks.
  171. None case: Drops immediately, never takes any other locks, no opportunity to deadlock.
  172. Some case: Get SDB.R. Can locked SDB ever be waiting for WDB? No, SDB only
  173. locks either after it already has the WDB.R (in another copy of this branch), or the WDB isn't
  174. locked (in the other branch).
  175. This covers all branches, therefore, WDB can never deadlock.
  176. Can GSS ever deadlock?
  177. GSS locks in three places (one of which is not shown in the CFG, it's in get_messages() as
  178. global_db, and is extra irrelevant because it doesn't even read lock until all write lock threads
  179. have terminated). In all three places, it immediately drops the lock without doing any other locking
  180. operations. Therefore, GSS can never deadlock.
  181. Can SDB ever deadlock?
  182. SDB locks in three places: a read lock in the top None (1), a write lock in the bottom None (2), and
  183. a read lock in the top Some (3).
  184. The read lock in (1) drops before doing any locking operations in either option of the next branch,
  185. and therefore has no chance to deadlock.
  186. The read lock in (3) also does no locking operations before dropping, so has no chance to deadlock.
  187. The write lock in (2) can't deadlock with the GSS write lock, since we already proved GSS never
  188. deadlocks. The only remaining operation is then dropping (2).
  189. Therefore, SDB can never deadlock.
  190. */
  191. async fn handle_handshake<const REGISTRATION_PHASE: bool>(
  192. stream: TcpStream,
  193. acceptor: TlsAcceptor,
  194. writer_db: Arc<RwLock<WriterDb>>,
  195. snd_db: Arc<RwLock<SndDb>>,
  196. phase_notify: Arc<Notify>,
  197. ) {
  198. log!("accepted {}", stream.peer_addr().unwrap());
  199. let stream = match acceptor.accept(stream).await {
  200. Ok(stream) => stream,
  201. Err(e) => {
  202. log!("failed,tls,{}", e.kind());
  203. return;
  204. }
  205. };
  206. let (mut rd, wr) = split(stream);
  207. let handshake = match mgen::get_handshake(&mut rd).await {
  208. Ok(handshake) => handshake,
  209. Err(mgen::Error::Io(e)) => {
  210. log!("failed,handshake,{}", e.kind());
  211. return;
  212. }
  213. Err(mgen::Error::Utf8Error(e)) => panic!("{:?}", e),
  214. Err(mgen::Error::MalformedSerialization(_, _)) => panic!(),
  215. };
  216. log!("accept,{},{}", handshake.sender, handshake.group);
  217. let read_writer_db = writer_db.read().await;
  218. if let Some(socket_updater) = read_writer_db.get(&handshake) {
  219. // we've seen this client before
  220. // start the new reader thread with a new notify
  221. // (we can't use the existing notify channel, else we get race conditions where
  222. // the reader thread terminates and spawns again before the sender thread
  223. // notices and activates its existing notify channel)
  224. let socket_notify = Arc::new(Notify::new());
  225. let db = snd_db.read().await[&handshake.group].clone();
  226. spawn_message_receiver(
  227. handshake.sender,
  228. handshake.group,
  229. rd,
  230. db,
  231. phase_notify,
  232. socket_notify.clone(),
  233. );
  234. // give the writer thread the new write half of the socket and notify
  235. socket_updater.send((wr, socket_notify));
  236. } else {
  237. drop(read_writer_db);
  238. // newly-registered client
  239. log!("register,{},{}", handshake.sender, handshake.group);
  240. if REGISTRATION_PHASE {
  241. // message channel, for sending messages between threads
  242. let (msg_snd, msg_rcv) = mpsc::unbounded_channel::<Arc<SerializedMessage>>();
  243. let group_snds = {
  244. let read_snd_db = snd_db.read().await;
  245. let group_snds = read_snd_db.get(&handshake.group);
  246. if let Some(group_snds) = group_snds {
  247. let group_snds = group_snds.clone();
  248. drop(read_snd_db);
  249. group_snds
  250. .write()
  251. .await
  252. .insert(handshake.sender.clone(), msg_snd);
  253. group_snds
  254. } else {
  255. drop(read_snd_db);
  256. let mut write_snd_db = snd_db.write().await;
  257. let group_snds = write_snd_db
  258. .entry(handshake.group.clone())
  259. .or_insert_with(|| Arc::new(RwLock::new(HashMap::new())));
  260. group_snds
  261. .write()
  262. .await
  263. .insert(handshake.sender.clone(), msg_snd);
  264. group_snds.clone()
  265. }
  266. };
  267. // socket notify, for terminating the socket if the sender encounters an error
  268. let socket_notify = Arc::new(Notify::new());
  269. // socket updater, for giving the sender thread a new socket + notify channel
  270. let socket_updater_snd = Updater::new();
  271. let socket_updater_rcv = socket_updater_snd.clone();
  272. socket_updater_snd.send((wr, socket_notify.clone()));
  273. spawn_message_receiver(
  274. handshake.sender.clone(),
  275. handshake.group.clone(),
  276. rd,
  277. group_snds,
  278. phase_notify,
  279. socket_notify,
  280. );
  281. let sender = handshake.sender.clone();
  282. let group = handshake.group.clone();
  283. tokio::spawn(async move {
  284. send_messages(sender, group, msg_rcv, socket_updater_rcv).await;
  285. });
  286. writer_db
  287. .write()
  288. .await
  289. .insert(handshake, socket_updater_snd);
  290. } else {
  291. panic!(
  292. "late registration: {},{}",
  293. handshake.sender, handshake.group
  294. );
  295. };
  296. }
  297. }
  298. fn spawn_message_receiver(
  299. sender: String,
  300. group: String,
  301. rd: ReadHalf<TlsStream<TcpStream>>,
  302. db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
  303. phase_notify: Arc<Notify>,
  304. socket_notify: Arc<Notify>,
  305. ) {
  306. tokio::spawn(async move {
  307. tokio::select! {
  308. // n.b.: get_message is not cancellation safe,
  309. // but this is one of the cases where that's expected
  310. // (we only cancel when something is wrong with the stream anyway)
  311. ret = get_messages(&sender, &group, rd, phase_notify, db) => {
  312. match ret {
  313. Err(mgen::Error::Io(e)) => log!("failed,receive,{}", e.kind()),
  314. Err(mgen::Error::Utf8Error(e)) => panic!("{:?}", e),
  315. Err(mgen::Error::MalformedSerialization(v, b)) => panic!(
  316. "Malformed Serialization: {:?}\n{:?})", v, b),
  317. Ok(()) => panic!("Message receiver returned OK"),
  318. }
  319. }
  320. _ = socket_notify.notified() => {
  321. log!("terminated,{},{}", sender, group);
  322. // should cause get_messages to terminate, dropping the socket
  323. }
  324. }
  325. });
  326. }
  327. /// Loop for receiving messages on the socket, figuring out who to deliver them to,
  328. /// and forwarding them locally to the respective channel.
  329. async fn get_messages<T: tokio::io::AsyncRead>(
  330. sender: &str,
  331. group: &str,
  332. mut socket: ReadHalf<T>,
  333. phase_notify: Arc<Notify>,
  334. global_db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
  335. ) -> Result<(), mgen::Error> {
  336. // Wait for the registration phase to end before updating our local copy of the DB
  337. phase_notify.notified().await;
  338. let db = global_db.read().await.clone();
  339. let message_channels: Vec<_> = db
  340. .iter()
  341. .filter_map(|(k, v)| if *k != sender { Some(v) } else { None })
  342. .collect();
  343. loop {
  344. let buf = mgen::get_message_bytes::<false, _>(&mut socket).await?;
  345. let message = MessageHeaderRef::deserialize(&buf[4..])?;
  346. assert!(message.sender == sender);
  347. match message.body {
  348. MessageBody::Size(_) => {
  349. assert!(message.group == group);
  350. log!("received,{},{},{}", sender, group, message.id);
  351. let body = message.body;
  352. let m = Arc::new(SerializedMessage { header: buf, body });
  353. for recipient in message_channels.iter() {
  354. recipient.send(m.clone()).unwrap();
  355. }
  356. }
  357. MessageBody::Receipt => {
  358. log!(
  359. "receipt,{},{},{},{}",
  360. sender,
  361. group,
  362. message.group,
  363. message.id
  364. );
  365. let recipient = &db[message.group];
  366. let body = message.body;
  367. let m = Arc::new(SerializedMessage { header: buf, body });
  368. recipient.send(m).unwrap();
  369. }
  370. }
  371. }
  372. }
  373. /// Loop for receiving messages on the mpsc channel for this recipient,
  374. /// and sending them out on the associated socket.
  375. async fn send_messages<T: Send + Sync + tokio::io::AsyncWrite>(
  376. recipient: ID,
  377. group: ID,
  378. mut msg_rcv: mpsc::UnboundedReceiver<Arc<SerializedMessage>>,
  379. mut socket_updater: Updater<(WriteHalf<T>, Arc<Notify>)>,
  380. ) {
  381. let (mut current_socket, mut current_watch) = socket_updater.recv().await;
  382. let mut message_cache = None;
  383. loop {
  384. let message = if let Some(message) = message_cache {
  385. message
  386. } else {
  387. msg_rcv.recv().await.expect("message channel closed")
  388. };
  389. if message
  390. .write_all_to::<false, _>(&mut current_socket)
  391. .await
  392. .is_err()
  393. || current_socket.flush().await.is_err()
  394. {
  395. message_cache = Some(message);
  396. log!("terminating,{},{}", recipient, group);
  397. // socket is presumably closed, clean up and notify the listening end to close
  398. // (all best-effort, we can ignore errors because it presumably means it's done)
  399. current_watch.notify_one();
  400. let _ = current_socket.shutdown().await;
  401. // wait for the new socket
  402. (current_socket, current_watch) = socket_updater.recv().await;
  403. } else {
  404. log!("sent,{},{}", recipient, group);
  405. message_cache = None;
  406. }
  407. }
  408. }
  409. fn load_private_key(filename: &str) -> PrivateKey {
  410. let keyfile = std::fs::File::open(filename).expect("cannot open private key file");
  411. let mut reader = BufReader::new(keyfile);
  412. loop {
  413. match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") {
  414. Some(rustls_pemfile::Item::RSAKey(key)) => return PrivateKey(key),
  415. Some(rustls_pemfile::Item::PKCS8Key(key)) => return PrivateKey(key),
  416. Some(rustls_pemfile::Item::ECKey(key)) => return PrivateKey(key),
  417. None => break,
  418. _ => {}
  419. }
  420. }
  421. panic!(
  422. "no keys found in {:?} (encrypted keys not supported)",
  423. filename
  424. );
  425. }