Переглянути джерело

make client-server connections use TLS

Justin Tracey 10 місяців тому
батько
коміт
e263a26eac

+ 2 - 0
Cargo.toml

@@ -10,9 +10,11 @@ glob = "0.3.1"
 rand = "0.8.5"
 rand_distr = { version = "0.4.3", features = ["serde1"] }
 rand_xoshiro = "0.6.0"
+rustls-pemfile = "1.0.3"
 serde = { version = "1.0.158", features = ["derive"] }
 serde_yaml = "0.9.21"
 tokio = { version = "1", features = ["full"] }
+tokio-rustls = { version = "0.24.1", features = ["dangerous_configuration"] }
 tokio-socks = "0.5.1"
 
 [profile.release]

+ 21 - 0
shadow/client/shadow.data.template/hosts/server/server.crt

@@ -0,0 +1,21 @@
+-----BEGIN CERTIFICATE-----
+MIIDazCCAlOgAwIBAgIUZ++oU4ax9bOfOI5c+TfysW8P0UowDQYJKoZIhvcNAQEL
+BQAwRTELMAkGA1UEBhMCVVMxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
+GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMzA3MTUyMjM5NTdaFw0yMzA4
+MTQyMjM5NTdaMEUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw
+HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB
+AQUAA4IBDwAwggEKAoIBAQDGE1VtF1/YTjm/yxeRcWdy/Tt0zybPUJn7VAn78T0u
+Xckr2s8L71yC513dA9pVcTUGcSqaL7JUQe1tpq1vIdK6EHKWGhz8Uzbn+RSd8gfV
+NM5MH28jb9dnKBfY7y0AcNBnWkSrJaHx4OHbgNbMaYqJmVN9YOBSrqYBXL9prtGt
+e1TpPPGAWHu5K8nb6WCd7/8g3ih6LQc2FAjcMdm0AWIAm1gGmxqu1i4swjyB4ECF
+VE7H1QUA4qL1rNLYz9boLqSEXMjIaF9nSMZ1y43yWPk8+6gYvvmvqIFj/t1ArY0C
+l2dePm+2C/lyv/XmyRNNpaZxbR4q1RRT2PdTU2vgV1SxAgMBAAGjUzBRMB0GA1Ud
+DgQWBBQ9ZhuVtY4lwGK+TuROxVVddW19rTAfBgNVHSMEGDAWgBQ9ZhuVtY4lwGK+
+TuROxVVddW19rTAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQA4
++D344kBAGa2lirPjXKMk8AbIcGOI12/g47juTJ2njj0IOEBjJKj6Sd+DonBF0ZFf
+ENiHjXXLAoQuhEbtgfUcqfOFqHUcA4rNIb9FwVVkElSNcl173JuBguv9oxO4cLsE
+a/8xMKH3pEBM/jONXSj899X7Psf9XEnOX6SOwIzvcP+9zlCHZ8I17EK1AXJnLRap
+uJy2WZONkUcEtCi7mij3Y7JCkFHYMKM6R2IEJnktfczyC/EQ4pTFJwsLPyqyb1q8
+R7I8Ea5fN95tzuB8Et6ke9Zz/UwmwPVGwhXg3ieEz5rSAYVVwrDJeYEGzADErGpf
+ZS5uF8f3OfmcADWdPRuD
+-----END CERTIFICATE-----

+ 28 - 0
shadow/client/shadow.data.template/hosts/server/server.key

@@ -0,0 +1,28 @@
+-----BEGIN PRIVATE KEY-----
+MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDGE1VtF1/YTjm/
+yxeRcWdy/Tt0zybPUJn7VAn78T0uXckr2s8L71yC513dA9pVcTUGcSqaL7JUQe1t
+pq1vIdK6EHKWGhz8Uzbn+RSd8gfVNM5MH28jb9dnKBfY7y0AcNBnWkSrJaHx4OHb
+gNbMaYqJmVN9YOBSrqYBXL9prtGte1TpPPGAWHu5K8nb6WCd7/8g3ih6LQc2FAjc
+Mdm0AWIAm1gGmxqu1i4swjyB4ECFVE7H1QUA4qL1rNLYz9boLqSEXMjIaF9nSMZ1
+y43yWPk8+6gYvvmvqIFj/t1ArY0Cl2dePm+2C/lyv/XmyRNNpaZxbR4q1RRT2PdT
+U2vgV1SxAgMBAAECggEAA7+CqLBuKr2LM/UDvoew85D1Zq/TTg26RjJYSIVPeTDC
+4WKv84u9WkhGw0uC/oYogNVUHywLIbNIKwCiDEXtcwIj6vF2TjOEaNYSpOz7JzaL
+N09KdvcTMkNk1SDsfvNDjEsd3Me25WjyGSlaVy6hlZo6RVd3kzT1FPZEdHtfghr4
+cVv+fm18AxF5LF+wKJ9XnKKt1N8j+yoIJGBjn8CfUC9Za5Vi7tMvtixadg0F1UDC
+lG6vSdeW+uXsxiMBh+c42Qhhk8B321FDVdF3jw7Yn8AHpFLnOcm1UAT5fDMcpQ4a
+TnA9sTepEuXNtKZdY3u2c7FNHXzG79rvYyi34ox5QQKBgQDbkWtVZN0a6bKK+DUT
+0Fcc4RgTQRYBdEZ9aiTPm+scLGSOe5hvuvkabB/1OoVB/vOI6/mmC84B3TTzm2zo
+itMteFXXuCnOpxq+5kpjsMBmKXPorXqXOtpcVXCPziQhG/KYNNJE7qbqpoYApBH6
+WGjCD7SPUewbxI3qhd2Hms/goQKBgQDm8PpH+iXZ0dMl5CvrgpXQ/a6BR3FQk7BS
+uMXYs7yQpfc8ifDVfDz064hEJoKtwOhhfmRybwdnHegb6u/c6goNvY/cKSLAlWyF
+T89TEqB1bv+7bH8T/7iT6jDJ1m3lJQMgBccSqS7cd9+FIG7j78W3ruZL2j+O+AgX
+Lp2yrE8qEQKBgCG+f5hoH/L6542kB8Q7yKePkHulDRS8Ifk0TuP5OnDiAbJEHHFP
+cuk0pNSzYbd6z0LDwWJbfhWbQYAO6vXyH/JlBAxbKVGxLNMZ4WTgzTDmPgIMZ0LG
+sLhwCRSQwcy01tu9gnNFmjGF1iJTFNA8thzc/QrptDewRX89g4ZLrJcBAoGBAI2c
+uSyH1MwDoWF7z/7DfZDA7k/x+ic52QZwrUlbtcZRLxEdWOPgIhThlRaNMtbPEvAt
+q/SL5tMxgJIV923UycNxORT82IWVWw1ISk6bfm9kWEaamjYuOgXhtnceGRdJIehy
+AoeL3ONuUk70+2qkLe6bvjZHJ3BI4dUtTaAxjv2xAoGBALkRxCE9Pw3oaSSGmJtc
+/JnXjE8mwDQxk+wKYFBPy08F8vltTAISZeqKx2X+RZeQ+zMa8+vT7QMMeYFw1NEE
+bzyo8KWv4Vx1vD2WL5dDnY6cAPdcjzTkRletvqBno2E0kyLeexkv1hnPhGarZNsM
+eaGu3kM8xsKbjqJ7D6xFVWz0
+-----END PRIVATE KEY-----

+ 1 - 1
shadow/client/shadow.yaml

@@ -13,7 +13,7 @@ hosts:
     ip_addr: 100.0.0.1
     processes:
     - path: mgen-server
-      args: 100.0.0.1:6397
+      args: [server.crt, server.key, 100.0.0.1:6397]
       start_time: 3s
       expected_final_state: running
 

+ 46 - 11
src/bin/mgen-client.rs

@@ -5,14 +5,13 @@ use mgen::{HandshakeRef, MessageHeader, SerializedMessage};
 use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256PlusPlus};
 use serde::Deserialize;
 use std::result::Result;
-use tokio::io::AsyncWriteExt;
-use tokio::net::{
-    tcp::{OwnedReadHalf, OwnedWriteHalf},
-    TcpStream,
-};
+use std::sync::Arc;
+use tokio::io::{split, AsyncWriteExt, ReadHalf, WriteHalf};
+use tokio::net::TcpStream;
 use tokio::sync::mpsc;
 use tokio::task;
 use tokio::time::Duration;
+use tokio_rustls::{client::TlsStream, TlsConnector};
 
 mod messenger;
 
@@ -30,18 +29,35 @@ type MessageHolder = Box<SerializedMessage>;
 /// Type for getting messages from the state thread in the writer thread.
 type WriterFromState = mpsc::UnboundedReceiver<MessageHolder>;
 /// Type for sending the updated read half of the socket.
-type ReadSocketUpdaterIn = Updater<OwnedReadHalf>;
+type ReadSocketUpdaterIn = Updater<ReadHalf<TlsStream<TcpStream>>>;
 /// Type for getting the updated read half of the socket.
-type ReadSocketUpdaterOut = Updater<OwnedReadHalf>;
+type ReadSocketUpdaterOut = Updater<ReadHalf<TlsStream<TcpStream>>>;
 /// Type for sending the updated write half of the socket.
-type WriteSocketUpdaterIn = Updater<OwnedWriteHalf>;
+type WriteSocketUpdaterIn = Updater<WriteHalf<TlsStream<TcpStream>>>;
 /// Type for getting the updated write half of the socket.
-type WriteSocketUpdaterOut = Updater<OwnedWriteHalf>;
+type WriteSocketUpdaterOut = Updater<WriteHalf<TlsStream<TcpStream>>>;
 /// Type for sending errors to other threads.
 type ErrorChannelIn = mpsc::UnboundedSender<MessengerError>;
 /// Type for getting errors from other threads.
 type ErrorChannelOut = mpsc::UnboundedReceiver<MessengerError>;
 
+// we gain a (very) tiny performance win by not bothering to validate the cert
+pub struct NoCertificateVerification {}
+
+impl tokio_rustls::rustls::client::ServerCertVerifier for NoCertificateVerification {
+    fn verify_server_cert(
+        &self,
+        _end_entity: &tokio_rustls::rustls::Certificate,
+        _intermediates: &[tokio_rustls::rustls::Certificate],
+        _server_name: &tokio_rustls::rustls::ServerName,
+        _scts: &mut dyn Iterator<Item = &[u8]>,
+        _ocsp: &[u8],
+        _now: std::time::SystemTime,
+    ) -> Result<tokio_rustls::rustls::client::ServerCertVerified, tokio_rustls::rustls::Error> {
+        Ok(tokio_rustls::rustls::client::ServerCertVerified::assertion())
+    }
+}
+
 /// The thread responsible for getting incoming messages,
 /// checking for any network errors while doing so,
 /// and giving messages to the state thread.
@@ -100,8 +116,19 @@ async fn socket_updater(
     writer_channel: WriteSocketUpdaterIn,
 ) -> FatalError {
     let retry = Duration::from_secs_f64(retry);
+
+    let tls_config = tokio_rustls::rustls::ClientConfig::builder()
+        .with_safe_defaults()
+        .with_custom_certificate_verifier(Arc::new(NoCertificateVerification {}))
+        .with_no_client_auth();
+    let connector = TlsConnector::from(Arc::new(tls_config));
+
+    // unwrap is safe, split always returns at least one element
+    let tls_server_str = str_params.target.split(':').next().unwrap();
+    let tls_server_name =
+        tokio_rustls::rustls::ServerName::try_from(tls_server_str).expect("invalid server name");
     loop {
-        let mut stream: TcpStream = match connect(&str_params).await {
+        let stream: TcpStream = match connect(&str_params).await {
             Ok(stream) => stream,
             Err(MessengerError::Recoverable(_)) => {
                 tokio::time::sleep(retry).await;
@@ -110,6 +137,14 @@ async fn socket_updater(
             Err(MessengerError::Fatal(e)) => return e,
         };
 
+        let mut stream = match connector.connect(tls_server_name.clone(), stream).await {
+            Ok(stream) => stream,
+            Err(_) => {
+                tokio::time::sleep(retry).await;
+                continue;
+            }
+        };
+
         let handshake = HandshakeRef {
             sender: &str_params.user,
             group: &str_params.recipient,
@@ -119,7 +154,7 @@ async fn socket_updater(
             continue;
         }
 
-        let (rd, wr) = stream.into_split();
+        let (rd, wr) = split(stream);
         reader_channel.send(rd);
         writer_channel.send(wr);
 

+ 73 - 21
src/bin/mgen-server.rs

@@ -1,14 +1,13 @@
 use mgen::{log, updater::Updater, Handshake, MessageBody, MessageHeaderRef, SerializedMessage};
 use std::collections::HashMap;
 use std::error::Error;
+use std::io::BufReader;
 use std::result::Result;
 use std::sync::Arc;
-use tokio::io::AsyncWriteExt;
-use tokio::net::{
-    tcp::{OwnedReadHalf, OwnedWriteHalf},
-    TcpListener,
-};
+use tokio::io::{split, AsyncWriteExt, ReadHalf, WriteHalf};
+use tokio::net::{TcpListener, TcpStream};
 use tokio::sync::{mpsc, Notify, RwLock};
+use tokio_rustls::{rustls::PrivateKey, server::TlsStream};
 
 // FIXME: identifiers should be interned
 type ID = String;
@@ -17,30 +16,63 @@ type ReaderToSender = mpsc::UnboundedSender<Arc<SerializedMessage>>;
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn Error>> {
-    let args: Vec<String> = std::env::args().collect();
-    let listen_addr = if args.len() > 1 {
-        &args[1]
-    } else {
-        "127.0.0.1:6397"
-    };
-    let listener = TcpListener::bind(listen_addr).await?;
+    let mut args = std::env::args();
+    let _arg0 = args.next().unwrap();
 
+    let cert_filename = args
+        .next()
+        .unwrap_or_else(|| panic!("no cert file provided"));
+    let key_filename = args
+        .next()
+        .unwrap_or_else(|| panic!("no key file provided"));
+
+    let listen_addr = args.next().unwrap_or("127.0.0.1:6397".to_string());
+
+    let certfile = std::fs::File::open(cert_filename).expect("cannot open certificate file");
+    let mut reader = BufReader::new(certfile);
+    let certs: Vec<tokio_rustls::rustls::Certificate> = rustls_pemfile::certs(&mut reader)
+        .unwrap()
+        .iter()
+        .map(|v| tokio_rustls::rustls::Certificate(v.clone()))
+        .collect();
+    let key = load_private_key(&key_filename);
+
+    let config = tokio_rustls::rustls::ServerConfig::builder()
+        .with_safe_default_cipher_suites()
+        .with_safe_default_kx_groups()
+        .with_safe_default_protocol_versions()
+        .unwrap()
+        .with_no_client_auth()
+        .with_single_cert(certs, key)?;
+    let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
+
+    let listener = TcpListener::bind(&listen_addr).await?;
     log!("listening,{}", listen_addr);
 
     // Maps group name to the table of message channels.
     let mut snd_db = HashMap::<ID, Arc<RwLock<HashMap<ID, ReaderToSender>>>>::new();
     // Maps the (sender, group) pair to the socket updater.
-    let mut writer_db = HashMap::<Handshake, Updater<(OwnedWriteHalf, Arc<Notify>)>>::new();
+    let mut writer_db =
+        HashMap::<Handshake, Updater<(WriteHalf<TlsStream<TcpStream>>, Arc<Notify>)>>::new();
 
     loop {
-        let socket = match listener.accept().await {
-            Ok((socket, _)) => socket,
+        let stream = match listener.accept().await {
+            Ok((stream, _)) => stream,
             Err(e) => {
                 log!("failed,accept,{}", e.kind());
                 continue;
             }
         };
-        let (mut rd, wr) = socket.into_split();
+        let acceptor = acceptor.clone();
+        let stream = match acceptor.accept(stream).await {
+            Ok(stream) => stream,
+            Err(e) => {
+                log!("failed,tls,{}", e.kind());
+                continue;
+            }
+        };
+
+        let (mut rd, wr) = split(stream);
 
         let handshake = match mgen::get_handshake(&mut rd).await {
             Ok(handshake) => handshake,
@@ -111,7 +143,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
 fn spawn_message_receiver(
     sender: String,
     group: String,
-    rd: OwnedReadHalf,
+    rd: ReadHalf<TlsStream<TcpStream>>,
     db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
     notify: Arc<Notify>,
 ) {
@@ -139,10 +171,10 @@ fn spawn_message_receiver(
 
 /// Loop for receiving messages on the socket, figuring out who to deliver them to,
 /// and forwarding them locally to the respective channel.
-async fn get_messages(
+async fn get_messages<T: tokio::io::AsyncRead>(
     sender: &str,
     group: &str,
-    mut socket: OwnedReadHalf,
+    mut socket: ReadHalf<T>,
     global_db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
 ) -> Result<(), mgen::Error> {
     // Wait for the next message to be received before populating our local copy of the db,
@@ -218,11 +250,11 @@ async fn get_messages(
 
 /// Loop for receiving messages on the mpsc channel for this recipient,
 /// and sending them out on the associated socket.
-async fn send_messages(
+async fn send_messages<T: Send + Sync + tokio::io::AsyncWrite>(
     recipient: ID,
     group: ID,
     mut msg_rcv: mpsc::UnboundedReceiver<Arc<SerializedMessage>>,
-    mut socket_updater: Updater<(OwnedWriteHalf, Arc<Notify>)>,
+    mut socket_updater: Updater<(WriteHalf<T>, Arc<Notify>)>,
 ) {
     let (mut current_socket, mut current_watch) = socket_updater.recv().await;
     let mut message_cache = None;
@@ -250,3 +282,23 @@ async fn send_messages(
         }
     }
 }
+
+fn load_private_key(filename: &str) -> PrivateKey {
+    let keyfile = std::fs::File::open(filename).expect("cannot open private key file");
+    let mut reader = BufReader::new(keyfile);
+
+    loop {
+        match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") {
+            Some(rustls_pemfile::Item::RSAKey(key)) => return PrivateKey(key),
+            Some(rustls_pemfile::Item::PKCS8Key(key)) => return PrivateKey(key),
+            Some(rustls_pemfile::Item::ECKey(key)) => return PrivateKey(key),
+            None => break,
+            _ => {}
+        }
+    }
+
+    panic!(
+        "no keys found in {:?} (encrypted keys not supported)",
+        filename
+    );
+}