|
@@ -1,14 +1,13 @@
|
|
|
use mgen::{log, updater::Updater, Handshake, MessageBody, MessageHeaderRef, SerializedMessage};
|
|
|
use std::collections::HashMap;
|
|
|
use std::error::Error;
|
|
|
+use std::io::BufReader;
|
|
|
use std::result::Result;
|
|
|
use std::sync::Arc;
|
|
|
-use tokio::io::AsyncWriteExt;
|
|
|
-use tokio::net::{
|
|
|
- tcp::{OwnedReadHalf, OwnedWriteHalf},
|
|
|
- TcpListener,
|
|
|
-};
|
|
|
+use tokio::io::{split, AsyncWriteExt, ReadHalf, WriteHalf};
|
|
|
+use tokio::net::{TcpListener, TcpStream};
|
|
|
use tokio::sync::{mpsc, Notify, RwLock};
|
|
|
+use tokio_rustls::{rustls::PrivateKey, server::TlsStream};
|
|
|
|
|
|
// FIXME: identifiers should be interned
|
|
|
type ID = String;
|
|
@@ -17,30 +16,63 @@ type ReaderToSender = mpsc::UnboundedSender<Arc<SerializedMessage>>;
|
|
|
|
|
|
#[tokio::main]
|
|
|
async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
- let args: Vec<String> = std::env::args().collect();
|
|
|
- let listen_addr = if args.len() > 1 {
|
|
|
- &args[1]
|
|
|
- } else {
|
|
|
- "127.0.0.1:6397"
|
|
|
- };
|
|
|
- let listener = TcpListener::bind(listen_addr).await?;
|
|
|
+ let mut args = std::env::args();
|
|
|
+ let _arg0 = args.next().unwrap();
|
|
|
|
|
|
+ let cert_filename = args
|
|
|
+ .next()
|
|
|
+ .unwrap_or_else(|| panic!("no cert file provided"));
|
|
|
+ let key_filename = args
|
|
|
+ .next()
|
|
|
+ .unwrap_or_else(|| panic!("no key file provided"));
|
|
|
+
|
|
|
+ let listen_addr = args.next().unwrap_or("127.0.0.1:6397".to_string());
|
|
|
+
|
|
|
+ let certfile = std::fs::File::open(cert_filename).expect("cannot open certificate file");
|
|
|
+ let mut reader = BufReader::new(certfile);
|
|
|
+ let certs: Vec<tokio_rustls::rustls::Certificate> = rustls_pemfile::certs(&mut reader)
|
|
|
+ .unwrap()
|
|
|
+ .iter()
|
|
|
+ .map(|v| tokio_rustls::rustls::Certificate(v.clone()))
|
|
|
+ .collect();
|
|
|
+ let key = load_private_key(&key_filename);
|
|
|
+
|
|
|
+ let config = tokio_rustls::rustls::ServerConfig::builder()
|
|
|
+ .with_safe_default_cipher_suites()
|
|
|
+ .with_safe_default_kx_groups()
|
|
|
+ .with_safe_default_protocol_versions()
|
|
|
+ .unwrap()
|
|
|
+ .with_no_client_auth()
|
|
|
+ .with_single_cert(certs, key)?;
|
|
|
+ let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
|
|
|
+
|
|
|
+ let listener = TcpListener::bind(&listen_addr).await?;
|
|
|
log!("listening,{}", listen_addr);
|
|
|
|
|
|
// Maps group name to the table of message channels.
|
|
|
let mut snd_db = HashMap::<ID, Arc<RwLock<HashMap<ID, ReaderToSender>>>>::new();
|
|
|
// Maps the (sender, group) pair to the socket updater.
|
|
|
- let mut writer_db = HashMap::<Handshake, Updater<(OwnedWriteHalf, Arc<Notify>)>>::new();
|
|
|
+ let mut writer_db =
|
|
|
+ HashMap::<Handshake, Updater<(WriteHalf<TlsStream<TcpStream>>, Arc<Notify>)>>::new();
|
|
|
|
|
|
loop {
|
|
|
- let socket = match listener.accept().await {
|
|
|
- Ok((socket, _)) => socket,
|
|
|
+ let stream = match listener.accept().await {
|
|
|
+ Ok((stream, _)) => stream,
|
|
|
Err(e) => {
|
|
|
log!("failed,accept,{}", e.kind());
|
|
|
continue;
|
|
|
}
|
|
|
};
|
|
|
- let (mut rd, wr) = socket.into_split();
|
|
|
+ let acceptor = acceptor.clone();
|
|
|
+ let stream = match acceptor.accept(stream).await {
|
|
|
+ Ok(stream) => stream,
|
|
|
+ Err(e) => {
|
|
|
+ log!("failed,tls,{}", e.kind());
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ };
|
|
|
+
|
|
|
+ let (mut rd, wr) = split(stream);
|
|
|
|
|
|
let handshake = match mgen::get_handshake(&mut rd).await {
|
|
|
Ok(handshake) => handshake,
|
|
@@ -111,7 +143,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
fn spawn_message_receiver(
|
|
|
sender: String,
|
|
|
group: String,
|
|
|
- rd: OwnedReadHalf,
|
|
|
+ rd: ReadHalf<TlsStream<TcpStream>>,
|
|
|
db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
|
|
|
notify: Arc<Notify>,
|
|
|
) {
|
|
@@ -139,10 +171,10 @@ fn spawn_message_receiver(
|
|
|
|
|
|
/// Loop for receiving messages on the socket, figuring out who to deliver them to,
|
|
|
/// and forwarding them locally to the respective channel.
|
|
|
-async fn get_messages(
|
|
|
+async fn get_messages<T: tokio::io::AsyncRead>(
|
|
|
sender: &str,
|
|
|
group: &str,
|
|
|
- mut socket: OwnedReadHalf,
|
|
|
+ mut socket: ReadHalf<T>,
|
|
|
global_db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
|
|
|
) -> Result<(), mgen::Error> {
|
|
|
// Wait for the next message to be received before populating our local copy of the db,
|
|
@@ -218,11 +250,11 @@ async fn get_messages(
|
|
|
|
|
|
/// Loop for receiving messages on the mpsc channel for this recipient,
|
|
|
/// and sending them out on the associated socket.
|
|
|
-async fn send_messages(
|
|
|
+async fn send_messages<T: Send + Sync + tokio::io::AsyncWrite>(
|
|
|
recipient: ID,
|
|
|
group: ID,
|
|
|
mut msg_rcv: mpsc::UnboundedReceiver<Arc<SerializedMessage>>,
|
|
|
- mut socket_updater: Updater<(OwnedWriteHalf, Arc<Notify>)>,
|
|
|
+ mut socket_updater: Updater<(WriteHalf<T>, Arc<Notify>)>,
|
|
|
) {
|
|
|
let (mut current_socket, mut current_watch) = socket_updater.recv().await;
|
|
|
let mut message_cache = None;
|
|
@@ -250,3 +282,23 @@ async fn send_messages(
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+fn load_private_key(filename: &str) -> PrivateKey {
|
|
|
+ let keyfile = std::fs::File::open(filename).expect("cannot open private key file");
|
|
|
+ let mut reader = BufReader::new(keyfile);
|
|
|
+
|
|
|
+ loop {
|
|
|
+ match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") {
|
|
|
+ Some(rustls_pemfile::Item::RSAKey(key)) => return PrivateKey(key),
|
|
|
+ Some(rustls_pemfile::Item::PKCS8Key(key)) => return PrivateKey(key),
|
|
|
+ Some(rustls_pemfile::Item::ECKey(key)) => return PrivateKey(key),
|
|
|
+ None => break,
|
|
|
+ _ => {}
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ panic!(
|
|
|
+ "no keys found in {:?} (encrypted keys not supported)",
|
|
|
+ filename
|
|
|
+ );
|
|
|
+}
|