Просмотр исходного кода

make client and server more robust against network failures

Justin Tracey 1 год назад
Родитель
Commit
78e76feb2c
4 измененных файлов с 312 добавлено и 165 удалено
  1. 9 0
      README.md
  2. 290 152
      src/bin/client.rs
  3. 9 1
      src/bin/server.rs
  4. 4 12
      src/lib.rs

+ 9 - 0
README.md

@@ -44,6 +44,15 @@ socks = "127.0.0.1:9050"
 # The <address>:<port> of the message server, where <address> is an IP or onion address.
 server = "insert.ip.or.onion:6397"
 
+# The number of seconds to wait until the client starts sending messages.
+# This should be long enough that all clients have had time to start (sending
+# messages to a client that isn't registered on the server is a fatal error),
+# but short enough that all conversations will have started by the experiment start.
+bootstrap = 5.0
+
+# The number of seconds to wait after a network failure before retrying.
+retry = 5.0
+
 
 # Parameters for distributions used by the Markov model.
 [distributions]

+ 290 - 152
src/bin/client.rs

@@ -16,16 +16,34 @@ use tokio::time::{Duration, Instant};
 
 #[derive(Debug)]
 enum ClientError {
-    // errors from the library
-    Mgen(mgen::Error),
-    // errors from parsing the conversation files
-    Parameter(DistParameterError),
-    // errors from the socks connection
+    Recoverable(RecoverableError),
+    Fatal(FatalError),
+}
+
+/// Errors where it is possible that reconnecting could resolve the problem.
+#[derive(Debug)]
+enum RecoverableError {
+    /// Recoverable errors from the socks connection.
     Socks(tokio_socks::Error),
-    // general I/O errors in this file
+    /// Network I/O errors.
+    // Note that all I/O handled by ClientError should be recoverable;
+    // if you need fatal I/O errors, use a different error type.
     Io(std::io::Error),
 }
 
+/// Errors where something is wrong enough that we should terminate.
+#[derive(Debug)]
+enum FatalError {
+    /// Fatal errors from the socks connection.
+    Socks(tokio_socks::Error),
+    /// Errors from parsing the conversation files.
+    Parameter(DistParameterError),
+    /// Error while trying to interpret bytes as a String.
+    Utf8Error(std::str::Utf8Error),
+    /// A message failed to deserialize.
+    MalformedSerialization(Vec<u8>, std::backtrace::Backtrace),
+}
+
 impl std::fmt::Display for ClientError {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(f, "{:?}", self)
@@ -36,25 +54,38 @@ impl std::error::Error for ClientError {}
 
 impl From<mgen::Error> for ClientError {
     fn from(e: mgen::Error) -> Self {
-        Self::Mgen(e)
+        match e {
+            mgen::Error::Io(e) => Self::Recoverable(RecoverableError::Io(e)),
+            mgen::Error::Utf8Error(e) => Self::Fatal(FatalError::Utf8Error(e)),
+            mgen::Error::MalformedSerialization(v, b) => {
+                Self::Fatal(FatalError::MalformedSerialization(v, b))
+            }
+        }
     }
 }
 
 impl From<DistParameterError> for ClientError {
     fn from(e: DistParameterError) -> Self {
-        Self::Parameter(e)
+        Self::Fatal(FatalError::Parameter(e))
     }
 }
 
 impl From<tokio_socks::Error> for ClientError {
     fn from(e: tokio_socks::Error) -> Self {
-        Self::Socks(e)
+        match e {
+            tokio_socks::Error::Io(_)
+            | tokio_socks::Error::ProxyServerUnreachable
+            | tokio_socks::Error::GeneralSocksServerFailure
+            | tokio_socks::Error::HostUnreachable
+            | tokio_socks::Error::TtlExpired => Self::Recoverable(RecoverableError::Socks(e)),
+            _ => Self::Fatal(FatalError::Socks(e)),
+        }
     }
 }
 
 impl From<std::io::Error> for ClientError {
     fn from(e: std::io::Error) -> Self {
-        Self::Io(e)
+        Self::Recoverable(RecoverableError::Io(e))
     }
 }
 
@@ -133,45 +164,42 @@ struct Distributions {
     r: Bernoulli,
 }
 
-trait State {}
+trait State {
+    fn sent(conversation: Conversation<Self>, rng: &mut Xoshiro256PlusPlus) -> StateMachine
+    where
+        Self: Sized;
+    fn received(conversation: Conversation<Self>, rng: &mut Xoshiro256PlusPlus) -> StateMachine
+    where
+        Self: Sized;
+    fn to_machine(conversation: Conversation<Self>) -> StateMachine
+    where
+        Self: Sized;
+}
 
 struct Idle {}
 struct Active {
     wait: Instant,
 }
 
-impl State for Idle {}
-impl State for Active {}
-
-impl Conversation<Idle> {
-    fn start(dists: Distributions, rng: &mut Xoshiro256PlusPlus) -> Self {
-        let delay = Instant::now() + dists.i.sample_secs(rng);
-        log!("[start]");
-        Self {
-            dists,
-            delay,
-            state: Idle {},
-        }
-    }
-
-    fn sent(self, rng: &mut Xoshiro256PlusPlus) -> StateMachine {
-        if self.dists.s.sample(rng) {
+impl State for Idle {
+    fn sent(conversation: Conversation<Self>, rng: &mut Xoshiro256PlusPlus) -> StateMachine {
+        if conversation.dists.s.sample(rng) {
             log!("Idle: [sent] tranisition to [Active]");
-            let delay = Instant::now() + self.dists.a_s.sample_secs(rng);
-            let wait = Instant::now() + self.dists.w.sample_secs(rng);
+            let delay = Instant::now() + conversation.dists.a_s.sample_secs(rng);
+            let wait = Instant::now() + conversation.dists.w.sample_secs(rng);
             StateMachine::Active({
                 Conversation::<Active> {
-                    dists: self.dists,
+                    dists: conversation.dists,
                     delay,
                     state: Active { wait },
                 }
             })
         } else {
             log!("Idle: [sent] tranisition to [Idle]");
-            let delay = Instant::now() + self.dists.i.sample_secs(rng);
+            let delay = Instant::now() + conversation.dists.i.sample_secs(rng);
             StateMachine::Idle({
                 Conversation::<Idle> {
-                    dists: self.dists,
+                    dists: conversation.dists,
                     delay,
                     state: Idle {},
                 }
@@ -179,46 +207,68 @@ impl Conversation<Idle> {
         }
     }
 
-    fn received(self, rng: &mut Xoshiro256PlusPlus) -> StateMachine {
-        if self.dists.r.sample(rng) {
+    fn received(conversation: Conversation<Self>, rng: &mut Xoshiro256PlusPlus) -> StateMachine {
+        if conversation.dists.r.sample(rng) {
             log!("Idle: [recv'd] tranisition to [Active]");
-            let wait = Instant::now() + self.dists.w.sample_secs(rng);
-            let delay = Instant::now() + self.dists.a_r.sample_secs(rng);
+            let wait = Instant::now() + conversation.dists.w.sample_secs(rng);
+            let delay = Instant::now() + conversation.dists.a_r.sample_secs(rng);
             StateMachine::Active({
                 Conversation::<Active> {
-                    dists: self.dists,
+                    dists: conversation.dists,
                     delay,
                     state: Active { wait },
                 }
             })
         } else {
             log!("Idle: [recv'd] tranisition to [Idle]");
-            StateMachine::Idle(self)
+            StateMachine::Idle(conversation)
         }
     }
+
+    fn to_machine(conversation: Conversation<Self>) -> StateMachine {
+        StateMachine::Idle(conversation)
+    }
 }
 
-impl Conversation<Active> {
-    fn sent(self, rng: &mut Xoshiro256PlusPlus) -> Conversation<Active> {
+impl State for Active {
+    fn sent(conversation: Conversation<Self>, rng: &mut Xoshiro256PlusPlus) -> StateMachine {
         log!("Active: [sent] transition to [Active]");
-        let delay = Instant::now() + self.dists.a_s.sample_secs(rng);
-        Conversation::<Active> {
-            dists: self.dists,
+        let delay = Instant::now() + conversation.dists.a_s.sample_secs(rng);
+        StateMachine::Active(Conversation::<Active> {
+            dists: conversation.dists,
             delay,
-            state: self.state,
-        }
+            state: conversation.state,
+        })
     }
 
-    fn received(self, rng: &mut Xoshiro256PlusPlus) -> Conversation<Active> {
+    fn received(conversation: Conversation<Self>, rng: &mut Xoshiro256PlusPlus) -> StateMachine {
         log!("Active: [recv'd] transition to [Active]");
-        let delay = Instant::now() + self.dists.a_r.sample_secs(rng);
-        Conversation::<Active> {
-            dists: self.dists,
+        let delay = Instant::now() + conversation.dists.a_r.sample_secs(rng);
+        StateMachine::Active(Conversation::<Active> {
+            dists: conversation.dists,
+            delay,
+            state: conversation.state,
+        })
+    }
+
+    fn to_machine(conversation: Conversation<Self>) -> StateMachine {
+        StateMachine::Active(conversation)
+    }
+}
+
+impl Conversation<Idle> {
+    fn start(dists: Distributions, rng: &mut Xoshiro256PlusPlus) -> Self {
+        let delay = Instant::now() + dists.i.sample_secs(rng);
+        log!("[start]");
+        Self {
+            dists,
             delay,
-            state: self.state,
+            state: Idle {},
         }
     }
+}
 
+impl Conversation<Active> {
     fn waited(self, rng: &mut Xoshiro256PlusPlus) -> Conversation<Idle> {
         log!("Active: [waited] tranision to [Idle]");
         let delay = Instant::now() + self.dists.i.sample_secs(rng);
@@ -264,6 +314,77 @@ async fn read_header_size(
     }
 }
 
+async fn send_action<T: State>(
+    conversation: Conversation<T>,
+    stream: &mut TcpStream,
+    our_id: &str,
+    recipients: Vec<&str>,
+    rng: &mut Xoshiro256PlusPlus,
+) -> Result<StateMachine, (StateMachine, ClientError)> {
+    let size = conversation.dists.m.sample(rng);
+    log!(
+        "sending message from {} to {:?} of size {}",
+        our_id,
+        recipients,
+        size
+    );
+    let m = construct_message(
+        our_id.to_string(),
+        recipients.iter().map(|s| s.to_string()).collect(),
+        size,
+    );
+
+    if let Err(e) = m.write_all_to(stream).await {
+        return Err((T::to_machine(conversation), e.into()));
+    }
+    if let Err(e) = stream.flush().await {
+        return Err((T::to_machine(conversation), e.into()));
+    }
+
+    Ok(T::sent(conversation, rng))
+}
+
+async fn receive_action<T: State>(
+    n: usize,
+    mut header_size: [u8; 4],
+    conversation: Conversation<T>,
+    stream: &mut TcpStream,
+    our_id: &str,
+    rng: &mut Xoshiro256PlusPlus,
+) -> Result<StateMachine, (StateMachine, ClientError)> {
+    if n < 4 {
+        // we didn't get the whole size, but we can use read_exact now
+        if let Err(e) = stream.read_exact(&mut header_size[n..]).await {
+            return Err((T::to_machine(conversation), e.into()));
+        }
+    }
+
+    let msg = match mgen::get_message_with_header_size(stream, header_size).await {
+        Ok((msg, _)) => msg,
+        Err(e) => return Err((T::to_machine(conversation), e.into())),
+    };
+
+    match msg.body {
+        mgen::MessageBody::Size(size) => {
+            log!(
+                "{:?} got message from {} of size {}",
+                msg.recipients,
+                msg.sender,
+                size
+            );
+            let m = construct_receipt(our_id.to_string(), msg.sender);
+            if let Err(e) = m.write_all_to(stream).await {
+                return Err((T::to_machine(conversation), e.into()));
+            }
+            if let Err(e) = stream.flush().await {
+                return Err((T::to_machine(conversation), e.into()));
+            }
+            Ok(T::received(conversation, rng))
+        }
+        mgen::MessageBody::Receipt => Ok(T::to_machine(conversation)),
+    }
+}
+
 enum IdleActions {
     Send,
     Receive(usize),
@@ -275,7 +396,7 @@ async fn manage_idle_conversation(
     our_id: &str,
     recipients: Vec<&str>,
     rng: &mut Xoshiro256PlusPlus,
-) -> Result<StateMachine, ClientError> {
+) -> Result<StateMachine, (StateMachine, ClientError)> {
     log!("delaying for {:?}", conversation.delay - Instant::now());
     let mut header_size = [0; 4];
     let action = tokio::select! {
@@ -289,41 +410,16 @@ async fn manage_idle_conversation(
                 Err(e) => Err(e),
             }
         }
-    }?;
+    };
+    let action = match action {
+        Ok(action) => action,
+        Err(e) => return Err((StateMachine::Idle(conversation), e)),
+    };
 
     match action {
-        IdleActions::Send => {
-            let size = conversation.dists.m.sample(rng);
-            log!(
-                "sending message from {} to {:?} of size {}",
-                our_id,
-                recipients,
-                size
-            );
-            let m = construct_message(
-                our_id.to_string(),
-                recipients.iter().map(|s| s.to_string()).collect(),
-                size,
-            );
-            m.write_all_to(stream).await?;
-            stream.flush().await?;
-            Ok(conversation.sent(rng))
-        }
+        IdleActions::Send => send_action(conversation, stream, our_id, recipients, rng).await,
         IdleActions::Receive(n) => {
-            if n < 4 {
-                // we didn't get the whole size, but we can use read_exact now
-                stream.read_exact(&mut header_size[n..]).await?;
-            }
-            let (msg, _) = mgen::get_message_with_header_size(stream, header_size).await?;
-            if msg.body != mgen::MessageBody::Receipt {
-                log!("{:?} got message from {}", msg.recipients, msg.sender);
-                let m = construct_receipt(our_id.to_string(), msg.sender);
-                m.write_all_to(stream).await?;
-                stream.flush().await?;
-                Ok(conversation.received(rng))
-            } else {
-                Ok(StateMachine::Idle(conversation))
-            }
+            receive_action(n, header_size, conversation, stream, our_id, rng).await
         }
     }
 }
@@ -340,7 +436,7 @@ async fn manage_active_conversation(
     our_id: &str,
     recipients: Vec<&str>,
     rng: &mut Xoshiro256PlusPlus,
-) -> Result<StateMachine, ClientError> {
+) -> Result<StateMachine, (StateMachine, ClientError)> {
     let mut header_size = [0; 4];
     let action = tokio::select! {
         action = Conversation::<Active>::sleep(conversation.delay, conversation.state.wait) => {
@@ -353,47 +449,16 @@ async fn manage_active_conversation(
                 Err(e) => Err(e),
             }
         }
-    }?;
+    };
+    let action = match action {
+        Ok(action) => action,
+        Err(e) => return Err((StateMachine::Active(conversation), e)),
+    };
 
     match action {
-        ActiveActions::Send => {
-            let size = conversation.dists.m.sample(rng);
-            log!(
-                "sending message from {} to {:?} of size {}",
-                our_id,
-                recipients,
-                size
-            );
-            let m = construct_message(
-                our_id.to_string(),
-                recipients.iter().map(|s| s.to_string()).collect(),
-                size,
-            );
-            m.write_all_to(stream).await?;
-            stream.flush().await?;
-            Ok(StateMachine::Active(conversation.sent(rng)))
-        }
+        ActiveActions::Send => send_action(conversation, stream, our_id, recipients, rng).await,
         ActiveActions::Receive(n) => {
-            if n < 4 {
-                // we didn't get the whole size, but we can use read_exact now
-                stream.read_exact(&mut header_size[n..]).await?;
-            }
-            let (msg, _) = mgen::get_message_with_header_size(stream, header_size).await?;
-            match msg.body {
-                mgen::MessageBody::Size(size) => {
-                    log!(
-                        "{:?} got message from {} of size {}",
-                        msg.recipients,
-                        msg.sender,
-                        size
-                    );
-                    let m = construct_receipt(our_id.to_string(), msg.sender);
-                    m.write_all_to(stream).await?;
-                    stream.flush().await?;
-                    Ok(StateMachine::Active(conversation.received(rng)))
-                }
-                mgen::MessageBody::Receipt => Ok(StateMachine::Active(conversation)),
-            }
+            receive_action(n, header_size, conversation, stream, our_id, rng).await
         }
         ActiveActions::Idle => Ok(StateMachine::Idle(conversation.waited(rng))),
     }
@@ -402,44 +467,115 @@ async fn manage_active_conversation(
 async fn manage_conversation(config: Config) -> Result<(), ClientError> {
     let mut rng = Xoshiro256PlusPlus::from_entropy();
     let distributions: Distributions = config.distributions.try_into()?;
+
+    struct StrParams<'a> {
+        socks: &'a str,
+        server: &'a str,
+        sender: &'a str,
+        group: &'a str,
+    }
+    let str_params = StrParams {
+        socks: &config.socks,
+        server: &config.server,
+        sender: &config.sender,
+        group: &config.group,
+    };
+
     let mut state_machine =
         StateMachine::Idle(Conversation::<Idle>::start(distributions, &mut rng));
     let recipients: Vec<&str> = config.recipients.iter().map(String::as_str).collect();
 
-    let mut stream = tokio_socks::tcp::Socks5Stream::connect_with_password(
-        config.socks.as_str(),
-        config.server.as_str(),
-        &config.sender,
-        &config.group,
+    async fn error_collector<'a>(
+        bootstrap: Option<Duration>,
+        str_params: &'a StrParams<'a>,
+        rng: &mut Xoshiro256PlusPlus,
+        mut state_machine: StateMachine,
+        recipients: Vec<&str>,
+    ) -> Result<(), (StateMachine, ClientError)> {
+        let mut stream = match tokio_socks::tcp::Socks5Stream::connect_with_password(
+            str_params.socks,
+            str_params.server,
+            str_params.sender,
+            str_params.group,
+        )
+        .await
+        {
+            Ok(stream) => stream,
+            Err(e) => return Err((state_machine, e.into())),
+        };
+        if let Err(e) = stream
+            .write_all(&mgen::serialize_str(str_params.sender))
+            .await
+        {
+            return Err((state_machine, e.into()));
+        }
+
+        if let Some(bootstrap) = bootstrap {
+            tokio::time::sleep(bootstrap).await;
+        }
+
+        loop {
+            state_machine = match state_machine {
+                StateMachine::Idle(conversation) => {
+                    manage_idle_conversation(
+                        conversation,
+                        &mut stream,
+                        str_params.sender,
+                        recipients.clone(),
+                        rng,
+                    )
+                    .await?
+                }
+                StateMachine::Active(conversation) => {
+                    manage_active_conversation(
+                        conversation,
+                        &mut stream,
+                        str_params.sender,
+                        recipients.clone(),
+                        rng,
+                    )
+                    .await?
+                }
+            };
+        }
+    }
+
+    let retry = config.retry;
+    let retry = Duration::from_secs_f64(retry);
+
+    match error_collector(
+        Some(Duration::from_secs_f64(config.bootstrap)),
+        &str_params,
+        &mut rng,
+        state_machine,
+        recipients.clone(),
     )
-    .await?;
-    stream
-        .write_all(&mgen::serialize_str(&config.sender))
-        .await?;
+    .await
+    .expect_err("Inner loop returned Ok?")
+    {
+        (sm, ClientError::Recoverable(_)) => {
+            state_machine = sm;
+            tokio::time::sleep(retry).await;
+        }
+        (_, e) => return Err(e),
+    };
 
-    tokio::time::sleep(Duration::from_secs(5)).await;
     loop {
-        state_machine = match state_machine {
-            StateMachine::Idle(conversation) => {
-                manage_idle_conversation(
-                    conversation,
-                    &mut stream,
-                    &config.sender,
-                    recipients.clone(),
-                    &mut rng,
-                )
-                .await?
-            }
-            StateMachine::Active(conversation) => {
-                manage_active_conversation(
-                    conversation,
-                    &mut stream,
-                    &config.sender,
-                    recipients.clone(),
-                    &mut rng,
-                )
-                .await?
+        match error_collector(
+            None,
+            &str_params,
+            &mut rng,
+            state_machine,
+            recipients.clone(),
+        )
+        .await
+        .expect_err("Inner loop returned Ok?")
+        {
+            (sm, ClientError::Recoverable(_)) => {
+                state_machine = sm;
+                tokio::time::sleep(retry).await;
             }
+            (_, e) => return Err(e),
         };
     }
 }
@@ -606,6 +742,8 @@ struct Config {
     recipients: Vec<String>,
     socks: String,
     server: String,
+    bootstrap: f64,
+    retry: f64,
     distributions: ConfigDistributions,
 }
 

+ 9 - 1
src/bin/server.rs

@@ -160,12 +160,18 @@ async fn send_messages(
 ) {
     let (mut current_socket, mut current_watch) =
         socket_updater.recv().await.expect("socket updater closed");
+    let mut message_cache = None;
     loop {
-        let message = msg_rcv.recv().await.expect("message channel closed");
+        let message = if message_cache.is_none() {
+            msg_rcv.recv().await.expect("message channel closed")
+        } else {
+            message_cache.unwrap()
+        };
         log!("sending message");
         if message.write_all_to(&mut current_socket).await.is_err()
             || current_socket.flush().await.is_err()
         {
+            message_cache = Some(message);
             log!("terminating connection");
             // socket is presumably closed, clean up and notify the listening end to close
             // (all best-effort, we can ignore errors because it presumably means it's done)
@@ -176,6 +182,8 @@ async fn send_messages(
             (current_socket, current_watch) =
                 socket_updater.recv().await.expect("socket updater closed");
             log!("socket updated");
+        } else {
+            message_cache = None;
         }
     }
 }

+ 4 - 12
src/lib.rs

@@ -22,7 +22,6 @@ macro_rules! log {
 pub enum Error {
     Io(std::io::Error),
     Utf8Error(std::str::Utf8Error),
-    TryFromSliceError(std::array::TryFromSliceError),
     MalformedSerialization(Vec<u8>, std::backtrace::Backtrace),
 }
 
@@ -46,12 +45,6 @@ impl From<std::str::Utf8Error> for Error {
     }
 }
 
-impl From<std::array::TryFromSliceError> for Error {
-    fn from(e: std::array::TryFromSliceError) -> Self {
-        Self::TryFromSliceError(e)
-    }
-}
-
 /// Metadata for the body of the message.
 ///
 /// Message contents are always 0-filled buffers, so never represented.
@@ -200,11 +193,10 @@ pub fn serialize_str_to(s: &str, buf: &mut Vec<u8>) {
 }
 
 fn deserialize_u32(buf: &[u8]) -> Result<(u32, &[u8]), Error> {
-    let bytes = buf.get(0..4).ok_or(Error::MalformedSerialization(
-        buf.to_vec(),
-        std::backtrace::Backtrace::capture(),
-    ))?;
-    Ok((u32::from_be_bytes(bytes.try_into()?), &buf[4..]))
+    let bytes = buf.get(0..4).ok_or_else(|| {
+        Error::MalformedSerialization(buf.to_vec(), std::backtrace::Backtrace::capture())
+    })?;
+    Ok((u32::from_be_bytes(bytes.try_into().unwrap()), &buf[4..]))
 }
 
 fn deserialize_str(buf: &[u8]) -> Result<(&str, &[u8]), Error> {