Browse Source

sample message sizes from a distribution

Justin Tracey 2 years ago
parent
commit
a4272b771e
4 changed files with 168 additions and 95 deletions
  1. 0 1
      Cargo.toml
  2. 4 1
      README.md
  3. 156 89
      src/bin/client.rs
  4. 8 4
      src/lib.rs

+ 0 - 1
Cargo.toml

@@ -5,7 +5,6 @@ edition = "2021"
 
 
 [dependencies]
 [dependencies]
 chrono = "0.4.24"
 chrono = "0.4.24"
-enum_dispatch = "0.3.11"
 rand = "0.8.5"
 rand = "0.8.5"
 rand_distr = { version = "0.4.3", features = ["serde1"] }
 rand_distr = { version = "0.4.3", features = ["serde1"] }
 rand_xoshiro = "0.6.0"
 rand_xoshiro = "0.6.0"

+ 4 - 1
README.md

@@ -52,6 +52,9 @@ server = "insert.ip.or.onion:6397"
 s = 0.5
 s = 0.5
 r = 0.1
 r = 0.1
 
 
+# The distribution of message sizes, as measured in padding blocks.
+m = {distribution = "Poisson", lambda = 1.0}
+
 # Distribution I, the amount of time Idle before sending a message.
 # Distribution I, the amount of time Idle before sending a message.
 i = {distribution = "Normal", mean = 30.0, std_dev = 100.0}
 i = {distribution = "Normal", mean = 30.0, std_dev = 100.0}
 
 
@@ -66,7 +69,7 @@ a_r = {distribution = "Pareto", scale = 1.0, shape = 3.0}
 
 
 ```
 ```
 
 
-The client currently supports five probability distributions: Normal and LogNormal, Uniform, Exp(onential), and Pareto.
+The client currently supports five probability distributions for message timings: Normal and LogNormal, Uniform, Exp(onential), and Pareto.
 The parameter names can be found in the example above.
 The parameter names can be found in the example above.
 The distributions are sampled to return a double-precision floating point number of seconds.
 The distributions are sampled to return a double-precision floating point number of seconds.
 The particular distributions and parameters used in the example are for demonstration purposes only; they have no relationship to empirical conversation behaviors.
 The particular distributions and parameters used in the example are for demonstration purposes only; they have no relationship to empirical conversation behaviors.

+ 156 - 89
src/bin/client.rs

@@ -1,8 +1,8 @@
-use enum_dispatch::enum_dispatch;
 use mgen::{log, SerializedMessage};
 use mgen::{log, SerializedMessage};
 use rand_distr::{
 use rand_distr::{
-    Bernoulli, BernoulliError, Distribution, Exp, ExpError, LogNormal, Normal, NormalError, Pareto,
-    ParetoError, Uniform,
+    Bernoulli, BernoulliError, Binomial, BinomialError, Distribution, Exp, ExpError, GeoError,
+    Geometric, HyperGeoError, Hypergeometric, LogNormal, Normal, NormalError, Pareto, ParetoError,
+    Poisson, PoissonError, Uniform,
 };
 };
 use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256PlusPlus};
 use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256PlusPlus};
 use serde::Deserialize;
 use serde::Deserialize;
@@ -71,12 +71,36 @@ struct Conversation<S: State> {
     state: S,
     state: S,
 }
 }
 
 
+/// The set of Distributions we currently support for message sizes (in padding blocks).
+/// To modify the code to add support for more, one approach is to first add them here,
+/// then fix all the compiler errors and warnings that arise as a result.
 #[derive(Debug)]
 #[derive(Debug)]
-#[enum_dispatch(Distribution)]
-/// The set of Distributions we currently support.
+enum MessageDistribution {
+    // Poisson is only defined for floats for technical reasons.
+    // https://rust-random.github.io/book/guide-dist.html#integers
+    Poisson(Poisson<f64>),
+    Binomial(Binomial),
+    Geometric(Geometric),
+    Hypergeometric(Hypergeometric),
+}
+
+impl Distribution<u32> for MessageDistribution {
+    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> u32 {
+        let ret = match self {
+            Self::Poisson(d) => d.sample(rng) as u64,
+            Self::Binomial(d) => d.sample(rng),
+            Self::Geometric(d) => d.sample(rng),
+            Self::Hypergeometric(d) => d.sample(rng),
+        };
+        std::cmp::min(ret, mgen::MAX_BLOCKS_IN_BODY.into()) as u32
+    }
+}
+
+/// The set of Distributions we currently support for timings.
 /// To modify the code to add support for more, one approach is to first add them here,
 /// To modify the code to add support for more, one approach is to first add them here,
-/// then fix all the compiler errors that arise as a result.
-enum SupportedDistribution {
+/// then fix all the compiler errors and warnings that arise as a result.
+#[derive(Debug)]
+enum TimingDistribution {
     Normal(Normal<f64>),
     Normal(Normal<f64>),
     LogNormal(LogNormal<f64>),
     LogNormal(LogNormal<f64>),
     Uniform(Uniform<f64>),
     Uniform(Uniform<f64>),
@@ -84,13 +108,27 @@ enum SupportedDistribution {
     Pareto(Pareto<f64>),
     Pareto(Pareto<f64>),
 }
 }
 
 
+impl Distribution<f64> for TimingDistribution {
+    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
+        let ret = match self {
+            Self::Normal(d) => d.sample(rng),
+            Self::LogNormal(d) => d.sample(rng),
+            Self::Uniform(d) => d.sample(rng),
+            Self::Exp(d) => d.sample(rng),
+            Self::Pareto(d) => d.sample(rng),
+        };
+        ret.max(0.0)
+    }
+}
+
 /// The set of distributions necessary to represent the actions of the state machine.
 /// The set of distributions necessary to represent the actions of the state machine.
 #[derive(Debug)]
 #[derive(Debug)]
 struct Distributions {
 struct Distributions {
-    i: SupportedDistribution,
-    w: SupportedDistribution,
-    a_s: SupportedDistribution,
-    a_r: SupportedDistribution,
+    m: MessageDistribution,
+    i: TimingDistribution,
+    w: TimingDistribution,
+    a_s: TimingDistribution,
+    a_r: TimingDistribution,
     s: Bernoulli,
     s: Bernoulli,
     r: Bernoulli,
     r: Bernoulli,
 }
 }
@@ -255,10 +293,17 @@ async fn manage_idle_conversation(
 
 
     match action {
     match action {
         IdleActions::Send => {
         IdleActions::Send => {
-            log!("sending message from {} to {:?}", our_id, recipients);
+            let size = conversation.dists.m.sample(rng);
+            log!(
+                "sending message from {} to {:?} of size {}",
+                our_id,
+                recipients,
+                size
+            );
             let m = construct_message(
             let m = construct_message(
                 our_id.to_string(),
                 our_id.to_string(),
                 recipients.iter().map(|s| s.to_string()).collect(),
                 recipients.iter().map(|s| s.to_string()).collect(),
+                size,
             );
             );
             m.write_all_to(stream).await?;
             m.write_all_to(stream).await?;
             stream.flush().await?;
             stream.flush().await?;
@@ -312,10 +357,17 @@ async fn manage_active_conversation(
 
 
     match action {
     match action {
         ActiveActions::Send => {
         ActiveActions::Send => {
-            log!("sending message from {} to {:?}", our_id, recipients);
+            let size = conversation.dists.m.sample(rng);
+            log!(
+                "sending message from {} to {:?} of size {}",
+                our_id,
+                recipients,
+                size
+            );
             let m = construct_message(
             let m = construct_message(
                 our_id.to_string(),
                 our_id.to_string(),
                 recipients.iter().map(|s| s.to_string()).collect(),
                 recipients.iter().map(|s| s.to_string()).collect(),
+                size,
             );
             );
             m.write_all_to(stream).await?;
             m.write_all_to(stream).await?;
             stream.flush().await?;
             stream.flush().await?;
@@ -327,14 +379,20 @@ async fn manage_active_conversation(
                 stream.read_exact(&mut header_size[n..]).await?;
                 stream.read_exact(&mut header_size[n..]).await?;
             }
             }
             let (msg, _) = mgen::get_message_with_header_size(stream, header_size).await?;
             let (msg, _) = mgen::get_message_with_header_size(stream, header_size).await?;
-            if msg.body != mgen::MessageBody::Receipt {
-                log!("{:?} got message from {}", msg.recipients, msg.sender);
-                let m = construct_receipt(our_id.to_string(), msg.sender);
-                m.write_all_to(stream).await?;
-                stream.flush().await?;
-                Ok(StateMachine::Active(conversation.received(rng)))
-            } else {
-                Ok(StateMachine::Active(conversation))
+            match msg.body {
+                mgen::MessageBody::Size(size) => {
+                    log!(
+                        "{:?} got message from {} of size {}",
+                        msg.recipients,
+                        msg.sender,
+                        size
+                    );
+                    let m = construct_receipt(our_id.to_string(), msg.sender);
+                    m.write_all_to(stream).await?;
+                    stream.flush().await?;
+                    Ok(StateMachine::Active(conversation.received(rng)))
+                }
+                mgen::MessageBody::Receipt => Ok(StateMachine::Active(conversation)),
             }
             }
         }
         }
         ActiveActions::Idle => Ok(StateMachine::Idle(conversation.waited(rng))),
         ActiveActions::Idle => Ok(StateMachine::Idle(conversation.waited(rng))),
@@ -386,62 +444,19 @@ async fn manage_conversation(config: Config) -> Result<(), ClientError> {
     }
     }
 }
 }
 
 
-/// A wrapper for the Distribution trait that specifies the RNG to allow (fake) dynamic dispatch.
-#[enum_dispatch(SupportedDistribution)]
-trait Dist {
-    fn sample(&self, rng: &mut Xoshiro256PlusPlus) -> f64;
-}
-
-/*
-// This would be easier, but we run into https://github.com/rust-lang/rust/issues/48869
-impl<T, D> Dist<T> for D
-where
-    D: Distribution<T> + Send + Sync,
-{
-    fn sample(&self, rng: &mut Xoshiro256PlusPlus) -> T {
-        self.sample(rng)
-    }
-}
- */
-
-macro_rules! dist_impl {
-    ($dist:ident) => {
-        impl Dist for $dist<f64> {
-            fn sample(&self, rng: &mut Xoshiro256PlusPlus) -> f64 {
-                Distribution::sample(self, rng)
-            }
-        }
-    };
-}
-
-dist_impl!(Exp);
-dist_impl!(Normal);
-dist_impl!(LogNormal);
-dist_impl!(Pareto);
-dist_impl!(Uniform);
-
-impl SupportedDistribution {
-    // FIXME: there's probably a better way to do this integrated with the crate
-    fn clamped_sample(&self, rng: &mut Xoshiro256PlusPlus) -> f64 {
-        let sample = self.sample(rng);
-        if sample >= 0.0 {
-            sample
-        } else {
-            0.0
-        }
-    }
-
+impl TimingDistribution {
     fn sample_secs(&self, rng: &mut Xoshiro256PlusPlus) -> Duration {
     fn sample_secs(&self, rng: &mut Xoshiro256PlusPlus) -> Duration {
-        Duration::from_secs_f64(self.clamped_sample(rng))
+        Duration::from_secs_f64(self.sample(rng))
     }
     }
 }
 }
 
 
-fn construct_message(sender: String, recipients: Vec<String>) -> SerializedMessage {
-    // FIXME: sample size from distribution
+/// Construct and serialize a message from the sender to the recipients with the given number of blocks.
+fn construct_message(sender: String, recipients: Vec<String>, blocks: u32) -> SerializedMessage {
+    let size = std::cmp::max(blocks, 1) * mgen::PADDING_BLOCK_SIZE;
     let m = mgen::MessageHeader {
     let m = mgen::MessageHeader {
         sender,
         sender,
         recipients,
         recipients,
-        body: mgen::MessageBody::Size(NonZeroU32::new(1024).unwrap()),
+        body: mgen::MessageBody::Size(NonZeroU32::new(size).unwrap()),
     };
     };
     m.serialize()
     m.serialize()
 }
 }
@@ -458,18 +473,40 @@ fn construct_receipt(sender: String, recipient: String) -> SerializedMessage {
 /// The same as Distributions, but designed for easier deserialization.
 /// The same as Distributions, but designed for easier deserialization.
 #[derive(Debug, Deserialize)]
 #[derive(Debug, Deserialize)]
 struct ConfigDistributions {
 struct ConfigDistributions {
-    i: ConfigSupportedDistribution,
-    w: ConfigSupportedDistribution,
-    a_s: ConfigSupportedDistribution,
-    a_r: ConfigSupportedDistribution,
+    m: ConfigMessageDistribution,
+    i: ConfigTimingDistribution,
+    w: ConfigTimingDistribution,
+    a_s: ConfigTimingDistribution,
+    a_r: ConfigTimingDistribution,
     s: f64,
     s: f64,
     r: f64,
     r: f64,
 }
 }
 
 
-/// The same as SupportedDistributions, but designed for easier deserialization.
+/// The same as MessageDistribution, but designed for easier deserialization.
+#[derive(Debug, Deserialize)]
+#[serde(tag = "distribution")]
+enum ConfigMessageDistribution {
+    Poisson {
+        lambda: f64,
+    },
+    Binomial {
+        n: u64,
+        p: f64,
+    },
+    Geometric {
+        p: f64,
+    },
+    Hypergeometric {
+        total_population_size: u64,
+        population_with_feature: u64,
+        sample_size: u64,
+    },
+}
+
+/// The same as TimingDistribution, but designed for easier deserialization.
 #[derive(Debug, Deserialize)]
 #[derive(Debug, Deserialize)]
 #[serde(tag = "distribution")]
 #[serde(tag = "distribution")]
-enum ConfigSupportedDistribution {
+enum ConfigTimingDistribution {
     Normal { mean: f64, std_dev: f64 },
     Normal { mean: f64, std_dev: f64 },
     LogNormal { mean: f64, std_dev: f64 },
     LogNormal { mean: f64, std_dev: f64 },
     Uniform { low: f64, high: f64 },
     Uniform { low: f64, high: f64 },
@@ -479,6 +516,10 @@ enum ConfigSupportedDistribution {
 
 
 #[derive(Debug)]
 #[derive(Debug)]
 enum DistParameterError {
 enum DistParameterError {
+    Poisson(PoissonError),
+    Binomial(BinomialError),
+    Geometric(GeoError),
+    Hypergeometric(HyperGeoError),
     Bernoulli(BernoulliError),
     Bernoulli(BernoulliError),
     Normal(NormalError),
     Normal(NormalError),
     LogNormal(NormalError),
     LogNormal(NormalError),
@@ -487,29 +528,54 @@ enum DistParameterError {
     Pareto(ParetoError),
     Pareto(ParetoError),
 }
 }
 
 
-impl TryFrom<ConfigSupportedDistribution> for SupportedDistribution {
+impl TryFrom<ConfigMessageDistribution> for MessageDistribution {
     type Error = DistParameterError;
     type Error = DistParameterError;
 
 
-    fn try_from(dist: ConfigSupportedDistribution) -> Result<Self, DistParameterError> {
+    fn try_from(dist: ConfigMessageDistribution) -> Result<Self, DistParameterError> {
         let dist = match dist {
         let dist = match dist {
-            ConfigSupportedDistribution::Normal { mean, std_dev } => SupportedDistribution::Normal(
+            ConfigMessageDistribution::Poisson { lambda } => MessageDistribution::Poisson(
+                Poisson::new(lambda).map_err(DistParameterError::Poisson)?,
+            ),
+            ConfigMessageDistribution::Binomial { n, p } => MessageDistribution::Binomial(
+                Binomial::new(n, p).map_err(DistParameterError::Binomial)?,
+            ),
+            ConfigMessageDistribution::Geometric { p } => MessageDistribution::Geometric(
+                Geometric::new(p).map_err(DistParameterError::Geometric)?,
+            ),
+            ConfigMessageDistribution::Hypergeometric {
+                total_population_size,
+                population_with_feature,
+                sample_size,
+            } => MessageDistribution::Hypergeometric(
+                Hypergeometric::new(total_population_size, population_with_feature, sample_size)
+                    .map_err(DistParameterError::Hypergeometric)?,
+            ),
+        };
+        Ok(dist)
+    }
+}
+
+impl TryFrom<ConfigTimingDistribution> for TimingDistribution {
+    type Error = DistParameterError;
+
+    fn try_from(dist: ConfigTimingDistribution) -> Result<Self, DistParameterError> {
+        let dist = match dist {
+            ConfigTimingDistribution::Normal { mean, std_dev } => TimingDistribution::Normal(
                 Normal::new(mean, std_dev).map_err(DistParameterError::Normal)?,
                 Normal::new(mean, std_dev).map_err(DistParameterError::Normal)?,
             ),
             ),
-            ConfigSupportedDistribution::LogNormal { mean, std_dev } => {
-                SupportedDistribution::LogNormal(
-                    LogNormal::new(mean, std_dev).map_err(DistParameterError::LogNormal)?,
-                )
-            }
-            ConfigSupportedDistribution::Uniform { low, high } => {
+            ConfigTimingDistribution::LogNormal { mean, std_dev } => TimingDistribution::LogNormal(
+                LogNormal::new(mean, std_dev).map_err(DistParameterError::LogNormal)?,
+            ),
+            ConfigTimingDistribution::Uniform { low, high } => {
                 if low >= high {
                 if low >= high {
                     return Err(DistParameterError::Uniform);
                     return Err(DistParameterError::Uniform);
                 }
                 }
-                SupportedDistribution::Uniform(Uniform::new(low, high))
+                TimingDistribution::Uniform(Uniform::new(low, high))
             }
             }
-            ConfigSupportedDistribution::Exp { lambda } => {
-                SupportedDistribution::Exp(Exp::new(lambda).map_err(DistParameterError::Exp)?)
+            ConfigTimingDistribution::Exp { lambda } => {
+                TimingDistribution::Exp(Exp::new(lambda).map_err(DistParameterError::Exp)?)
             }
             }
-            ConfigSupportedDistribution::Pareto { scale, shape } => SupportedDistribution::Pareto(
+            ConfigTimingDistribution::Pareto { scale, shape } => TimingDistribution::Pareto(
                 Pareto::new(scale, shape).map_err(DistParameterError::Pareto)?,
                 Pareto::new(scale, shape).map_err(DistParameterError::Pareto)?,
             ),
             ),
         };
         };
@@ -522,6 +588,7 @@ impl TryFrom<ConfigDistributions> for Distributions {
 
 
     fn try_from(config: ConfigDistributions) -> Result<Self, DistParameterError> {
     fn try_from(config: ConfigDistributions) -> Result<Self, DistParameterError> {
         Ok(Distributions {
         Ok(Distributions {
+            m: config.m.try_into()?,
             i: config.i.try_into()?,
             i: config.i.try_into()?,
             w: config.w.try_into()?,
             w: config.w.try_into()?,
             a_s: config.a_s.try_into()?,
             a_s: config.a_s.try_into()?,

+ 8 - 4
src/lib.rs

@@ -2,9 +2,13 @@ use std::mem::size_of;
 use std::num::NonZeroU32;
 use std::num::NonZeroU32;
 use tokio::io::{copy, sink, AsyncReadExt, AsyncWriteExt};
 use tokio::io::{copy, sink, AsyncReadExt, AsyncWriteExt};
 
 
-/// The minimum message size.
-/// All messages bodies less than this size (notably, receipts) will be padded to this length.
-const MIN_MESSAGE_SIZE: u32 = 256; // FIXME: double check what this should be
+/// The padding interval. The size of every message body is a multiple of this.
+/// All message bodies are at least this size.
+// FIXME: double check what this should be
+pub const PADDING_BLOCK_SIZE: u32 = 256;
+/// The most blocks a message body can contain.
+// from https://github.com/signalapp/Signal-Android/blob/36a8c4d8ba9fdb62905ecb9a20e3eeba4d2f9022/app/src/main/java/org/thoughtcrime/securesms/mms/PushMediaConstraints.java
+pub const MAX_BLOCKS_IN_BODY: u32 = (100 * 1024 * 1024) / PADDING_BLOCK_SIZE;
 
 
 #[macro_export]
 #[macro_export]
 macro_rules! log {
 macro_rules! log {
@@ -60,7 +64,7 @@ pub enum MessageBody {
 impl MessageBody {
 impl MessageBody {
     fn size(&self) -> u32 {
     fn size(&self) -> u32 {
         match self {
         match self {
-            MessageBody::Receipt => MIN_MESSAGE_SIZE,
+            MessageBody::Receipt => PADDING_BLOCK_SIZE,
             MessageBody::Size(size) => size.get(),
             MessageBody::Size(size) => size.get(),
         }
         }
     }
     }