Browse Source

sample message sizes from a distribution

Justin Tracey 1 year ago
parent
commit
a4272b771e
4 changed files with 168 additions and 95 deletions
  1. 0 1
      Cargo.toml
  2. 4 1
      README.md
  3. 156 89
      src/bin/client.rs
  4. 8 4
      src/lib.rs

+ 0 - 1
Cargo.toml

@@ -5,7 +5,6 @@ edition = "2021"
 
 [dependencies]
 chrono = "0.4.24"
-enum_dispatch = "0.3.11"
 rand = "0.8.5"
 rand_distr = { version = "0.4.3", features = ["serde1"] }
 rand_xoshiro = "0.6.0"

+ 4 - 1
README.md

@@ -52,6 +52,9 @@ server = "insert.ip.or.onion:6397"
 s = 0.5
 r = 0.1
 
+# The distribution of message sizes, as measured in padding blocks.
+m = {distribution = "Poisson", lambda = 1.0}
+
 # Distribution I, the amount of time Idle before sending a message.
 i = {distribution = "Normal", mean = 30.0, std_dev = 100.0}
 
@@ -66,7 +69,7 @@ a_r = {distribution = "Pareto", scale = 1.0, shape = 3.0}
 
 ```
 
-The client currently supports five probability distributions: Normal and LogNormal, Uniform, Exp(onential), and Pareto.
+The client currently supports five probability distributions for message timings: Normal and LogNormal, Uniform, Exp(onential), and Pareto.
 The parameter names can be found in the example above.
 The distributions are sampled to return a double-precision floating point number of seconds.
 The particular distributions and parameters used in the example are for demonstration purposes only, they have no relationship to empirical conversation behaviors.

+ 156 - 89
src/bin/client.rs

@@ -1,8 +1,8 @@
-use enum_dispatch::enum_dispatch;
 use mgen::{log, SerializedMessage};
 use rand_distr::{
-    Bernoulli, BernoulliError, Distribution, Exp, ExpError, LogNormal, Normal, NormalError, Pareto,
-    ParetoError, Uniform,
+    Bernoulli, BernoulliError, Binomial, BinomialError, Distribution, Exp, ExpError, GeoError,
+    Geometric, HyperGeoError, Hypergeometric, LogNormal, Normal, NormalError, Pareto, ParetoError,
+    Poisson, PoissonError, Uniform,
 };
 use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256PlusPlus};
 use serde::Deserialize;
@@ -71,12 +71,36 @@ struct Conversation<S: State> {
     state: S,
 }
 
+/// The set of Distributions we currently support for message sizes (in padding blocks).
+/// To modify the code to add support for more, one approach is to first add them here,
+/// then fix all the compiler errors and warnings that arise as a result.
 #[derive(Debug)]
-#[enum_dispatch(Distribution)]
-/// The set of Distributions we currently support.
+enum MessageDistribution {
+    // Poisson is only defined for floats for technical reasons.
+    // https://rust-random.github.io/book/guide-dist.html#integers
+    Poisson(Poisson<f64>),
+    Binomial(Binomial),
+    Geometric(Geometric),
+    Hypergeometric(Hypergeometric),
+}
+
+impl Distribution<u32> for MessageDistribution {
+    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> u32 {
+        let ret = match self {
+            Self::Poisson(d) => d.sample(rng) as u64,
+            Self::Binomial(d) => d.sample(rng),
+            Self::Geometric(d) => d.sample(rng),
+            Self::Hypergeometric(d) => d.sample(rng),
+        };
+        std::cmp::min(ret, mgen::MAX_BLOCKS_IN_BODY.into()) as u32
+    }
+}
+
+/// The set of Distributions we currently support for timings.
 /// To modify the code to add support for more, one approach is to first add them here,
-/// then fix all the compiler errors that arise as a result.
-enum SupportedDistribution {
+/// then fix all the compiler errors and warnings that arise as a result.
+#[derive(Debug)]
+enum TimingDistribution {
     Normal(Normal<f64>),
     LogNormal(LogNormal<f64>),
     Uniform(Uniform<f64>),
@@ -84,13 +108,27 @@ enum SupportedDistribution {
     Pareto(Pareto<f64>),
 }
 
+impl Distribution<f64> for TimingDistribution {
+    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
+        let ret = match self {
+            Self::Normal(d) => d.sample(rng),
+            Self::LogNormal(d) => d.sample(rng),
+            Self::Uniform(d) => d.sample(rng),
+            Self::Exp(d) => d.sample(rng),
+            Self::Pareto(d) => d.sample(rng),
+        };
+        ret.max(0.0)
+    }
+}
+
 /// The set of distributions necessary to represent the actions of the state machine.
 #[derive(Debug)]
 struct Distributions {
-    i: SupportedDistribution,
-    w: SupportedDistribution,
-    a_s: SupportedDistribution,
-    a_r: SupportedDistribution,
+    m: MessageDistribution,
+    i: TimingDistribution,
+    w: TimingDistribution,
+    a_s: TimingDistribution,
+    a_r: TimingDistribution,
     s: Bernoulli,
     r: Bernoulli,
 }
@@ -255,10 +293,17 @@ async fn manage_idle_conversation(
 
     match action {
         IdleActions::Send => {
-            log!("sending message from {} to {:?}", our_id, recipients);
+            let size = conversation.dists.m.sample(rng);
+            log!(
+                "sending message from {} to {:?} of size {}",
+                our_id,
+                recipients,
+                size
+            );
             let m = construct_message(
                 our_id.to_string(),
                 recipients.iter().map(|s| s.to_string()).collect(),
+                size,
             );
             m.write_all_to(stream).await?;
             stream.flush().await?;
@@ -312,10 +357,17 @@ async fn manage_active_conversation(
 
     match action {
         ActiveActions::Send => {
-            log!("sending message from {} to {:?}", our_id, recipients);
+            let size = conversation.dists.m.sample(rng);
+            log!(
+                "sending message from {} to {:?} of size {}",
+                our_id,
+                recipients,
+                size
+            );
             let m = construct_message(
                 our_id.to_string(),
                 recipients.iter().map(|s| s.to_string()).collect(),
+                size,
             );
             m.write_all_to(stream).await?;
             stream.flush().await?;
@@ -327,14 +379,20 @@ async fn manage_active_conversation(
                 stream.read_exact(&mut header_size[n..]).await?;
             }
             let (msg, _) = mgen::get_message_with_header_size(stream, header_size).await?;
-            if msg.body != mgen::MessageBody::Receipt {
-                log!("{:?} got message from {}", msg.recipients, msg.sender);
-                let m = construct_receipt(our_id.to_string(), msg.sender);
-                m.write_all_to(stream).await?;
-                stream.flush().await?;
-                Ok(StateMachine::Active(conversation.received(rng)))
-            } else {
-                Ok(StateMachine::Active(conversation))
+            match msg.body {
+                mgen::MessageBody::Size(size) => {
+                    log!(
+                        "{:?} got message from {} of size {}",
+                        msg.recipients,
+                        msg.sender,
+                        size
+                    );
+                    let m = construct_receipt(our_id.to_string(), msg.sender);
+                    m.write_all_to(stream).await?;
+                    stream.flush().await?;
+                    Ok(StateMachine::Active(conversation.received(rng)))
+                }
+                mgen::MessageBody::Receipt => Ok(StateMachine::Active(conversation)),
             }
         }
         ActiveActions::Idle => Ok(StateMachine::Idle(conversation.waited(rng))),
@@ -386,62 +444,19 @@ async fn manage_conversation(config: Config) -> Result<(), ClientError> {
     }
 }
 
-/// A wrapper for the Distribution trait that specifies the RNG to allow (fake) dynamic dispatch.
-#[enum_dispatch(SupportedDistribution)]
-trait Dist {
-    fn sample(&self, rng: &mut Xoshiro256PlusPlus) -> f64;
-}
-
-/*
-// This would be easier, but we run into https://github.com/rust-lang/rust/issues/48869
-impl<T, D> Dist<T> for D
-where
-    D: Distribution<T> + Send + Sync,
-{
-    fn sample(&self, rng: &mut Xoshiro256PlusPlus) -> T {
-        self.sample(rng)
-    }
-}
- */
-
-macro_rules! dist_impl {
-    ($dist:ident) => {
-        impl Dist for $dist<f64> {
-            fn sample(&self, rng: &mut Xoshiro256PlusPlus) -> f64 {
-                Distribution::sample(self, rng)
-            }
-        }
-    };
-}
-
-dist_impl!(Exp);
-dist_impl!(Normal);
-dist_impl!(LogNormal);
-dist_impl!(Pareto);
-dist_impl!(Uniform);
-
-impl SupportedDistribution {
-    // FIXME: there's probably a better way to do this integrated with the crate
-    fn clamped_sample(&self, rng: &mut Xoshiro256PlusPlus) -> f64 {
-        let sample = self.sample(rng);
-        if sample >= 0.0 {
-            sample
-        } else {
-            0.0
-        }
-    }
-
+impl TimingDistribution {
     fn sample_secs(&self, rng: &mut Xoshiro256PlusPlus) -> Duration {
-        Duration::from_secs_f64(self.clamped_sample(rng))
+        Duration::from_secs_f64(self.sample(rng))
     }
 }
 
-fn construct_message(sender: String, recipients: Vec<String>) -> SerializedMessage {
-    // FIXME: sample size from distribution
+/// Construct and serialize a message from the sender to the recipients with the given number of blocks.
+fn construct_message(sender: String, recipients: Vec<String>, blocks: u32) -> SerializedMessage {
+    let size = std::cmp::max(blocks, 1) * mgen::PADDING_BLOCK_SIZE;
     let m = mgen::MessageHeader {
         sender,
         recipients,
-        body: mgen::MessageBody::Size(NonZeroU32::new(1024).unwrap()),
+        body: mgen::MessageBody::Size(NonZeroU32::new(size).unwrap()),
     };
     m.serialize()
 }
@@ -458,18 +473,40 @@ fn construct_receipt(sender: String, recipient: String) -> SerializedMessage {
 /// The same as Distributions, but designed for easier deserialization.
 #[derive(Debug, Deserialize)]
 struct ConfigDistributions {
-    i: ConfigSupportedDistribution,
-    w: ConfigSupportedDistribution,
-    a_s: ConfigSupportedDistribution,
-    a_r: ConfigSupportedDistribution,
+    m: ConfigMessageDistribution,
+    i: ConfigTimingDistribution,
+    w: ConfigTimingDistribution,
+    a_s: ConfigTimingDistribution,
+    a_r: ConfigTimingDistribution,
     s: f64,
     r: f64,
 }
 
-/// The same as SupportedDistributions, but designed for easier deserialization.
+/// The same as MessageDistribution, but designed for easier deserialization.
+#[derive(Debug, Deserialize)]
+#[serde(tag = "distribution")]
+enum ConfigMessageDistribution {
+    Poisson {
+        lambda: f64,
+    },
+    Binomial {
+        n: u64,
+        p: f64,
+    },
+    Geometric {
+        p: f64,
+    },
+    Hypergeometric {
+        total_population_size: u64,
+        population_with_feature: u64,
+        sample_size: u64,
+    },
+}
+
+/// The same as TimingDistribution, but designed for easier deserialization.
 #[derive(Debug, Deserialize)]
 #[serde(tag = "distribution")]
-enum ConfigSupportedDistribution {
+enum ConfigTimingDistribution {
     Normal { mean: f64, std_dev: f64 },
     LogNormal { mean: f64, std_dev: f64 },
     Uniform { low: f64, high: f64 },
@@ -479,6 +516,10 @@ enum ConfigSupportedDistribution {
 
 #[derive(Debug)]
 enum DistParameterError {
+    Poisson(PoissonError),
+    Binomial(BinomialError),
+    Geometric(GeoError),
+    Hypergeometric(HyperGeoError),
     Bernoulli(BernoulliError),
     Normal(NormalError),
     LogNormal(NormalError),
@@ -487,29 +528,54 @@ enum DistParameterError {
     Pareto(ParetoError),
 }
 
-impl TryFrom<ConfigSupportedDistribution> for SupportedDistribution {
+impl TryFrom<ConfigMessageDistribution> for MessageDistribution {
     type Error = DistParameterError;
 
-    fn try_from(dist: ConfigSupportedDistribution) -> Result<Self, DistParameterError> {
+    fn try_from(dist: ConfigMessageDistribution) -> Result<Self, DistParameterError> {
         let dist = match dist {
-            ConfigSupportedDistribution::Normal { mean, std_dev } => SupportedDistribution::Normal(
+            ConfigMessageDistribution::Poisson { lambda } => MessageDistribution::Poisson(
+                Poisson::new(lambda).map_err(DistParameterError::Poisson)?,
+            ),
+            ConfigMessageDistribution::Binomial { n, p } => MessageDistribution::Binomial(
+                Binomial::new(n, p).map_err(DistParameterError::Binomial)?,
+            ),
+            ConfigMessageDistribution::Geometric { p } => MessageDistribution::Geometric(
+                Geometric::new(p).map_err(DistParameterError::Geometric)?,
+            ),
+            ConfigMessageDistribution::Hypergeometric {
+                total_population_size,
+                population_with_feature,
+                sample_size,
+            } => MessageDistribution::Hypergeometric(
+                Hypergeometric::new(total_population_size, population_with_feature, sample_size)
+                    .map_err(DistParameterError::Hypergeometric)?,
+            ),
+        };
+        Ok(dist)
+    }
+}
+
+impl TryFrom<ConfigTimingDistribution> for TimingDistribution {
+    type Error = DistParameterError;
+
+    fn try_from(dist: ConfigTimingDistribution) -> Result<Self, DistParameterError> {
+        let dist = match dist {
+            ConfigTimingDistribution::Normal { mean, std_dev } => TimingDistribution::Normal(
                 Normal::new(mean, std_dev).map_err(DistParameterError::Normal)?,
             ),
-            ConfigSupportedDistribution::LogNormal { mean, std_dev } => {
-                SupportedDistribution::LogNormal(
-                    LogNormal::new(mean, std_dev).map_err(DistParameterError::LogNormal)?,
-                )
-            }
-            ConfigSupportedDistribution::Uniform { low, high } => {
+            ConfigTimingDistribution::LogNormal { mean, std_dev } => TimingDistribution::LogNormal(
+                LogNormal::new(mean, std_dev).map_err(DistParameterError::LogNormal)?,
+            ),
+            ConfigTimingDistribution::Uniform { low, high } => {
                 if low >= high {
                     return Err(DistParameterError::Uniform);
                 }
-                SupportedDistribution::Uniform(Uniform::new(low, high))
+                TimingDistribution::Uniform(Uniform::new(low, high))
             }
-            ConfigSupportedDistribution::Exp { lambda } => {
-                SupportedDistribution::Exp(Exp::new(lambda).map_err(DistParameterError::Exp)?)
+            ConfigTimingDistribution::Exp { lambda } => {
+                TimingDistribution::Exp(Exp::new(lambda).map_err(DistParameterError::Exp)?)
             }
-            ConfigSupportedDistribution::Pareto { scale, shape } => SupportedDistribution::Pareto(
+            ConfigTimingDistribution::Pareto { scale, shape } => TimingDistribution::Pareto(
                 Pareto::new(scale, shape).map_err(DistParameterError::Pareto)?,
             ),
         };
@@ -522,6 +588,7 @@ impl TryFrom<ConfigDistributions> for Distributions {
 
     fn try_from(config: ConfigDistributions) -> Result<Self, DistParameterError> {
         Ok(Distributions {
+            m: config.m.try_into()?,
             i: config.i.try_into()?,
             w: config.w.try_into()?,
             a_s: config.a_s.try_into()?,

+ 8 - 4
src/lib.rs

@@ -2,9 +2,13 @@ use std::mem::size_of;
 use std::num::NonZeroU32;
 use tokio::io::{copy, sink, AsyncReadExt, AsyncWriteExt};
 
-/// The minimum message size.
-/// All messages bodies less than this size (notably, receipts) will be padded to this length.
-const MIN_MESSAGE_SIZE: u32 = 256; // FIXME: double check what this should be
+/// The padding interval. All message bodies are a size of some multiple of this.
+/// All messages bodies are a minimum  of this size.
+// FIXME: double check what this should be
+pub const PADDING_BLOCK_SIZE: u32 = 256;
+/// The most blocks a message body can contain.
+// from https://github.com/signalapp/Signal-Android/blob/36a8c4d8ba9fdb62905ecb9a20e3eeba4d2f9022/app/src/main/java/org/thoughtcrime/securesms/mms/PushMediaConstraints.java
+pub const MAX_BLOCKS_IN_BODY: u32 = (100 * 1024 * 1024) / PADDING_BLOCK_SIZE;
 
 #[macro_export]
 macro_rules! log {
@@ -60,7 +64,7 @@ pub enum MessageBody {
 impl MessageBody {
     fn size(&self) -> u32 {
         match self {
-            MessageBody::Receipt => MIN_MESSAGE_SIZE,
+            MessageBody::Receipt => PADDING_BLOCK_SIZE,
             MessageBody::Size(size) => size.get(),
         }
     }