Przeglądaj źródła

add support for weighted dists on message sizes

Justin Tracey 2 lat temu
rodzic
commit
ec9d2eb921
1 zmienionych plików z 50 dodań i 1 usunięć
  1. 50 1
      src/bin/messenger/dists.rs

+ 50 - 1
src/bin/messenger/dists.rs

@@ -3,10 +3,11 @@
 use rand_distr::{
 use rand_distr::{
     Bernoulli, BernoulliError, Binomial, BinomialError, Distribution, Exp, ExpError, GeoError,
     Bernoulli, BernoulliError, Binomial, BinomialError, Distribution, Exp, ExpError, GeoError,
     Geometric, HyperGeoError, Hypergeometric, LogNormal, Normal, NormalError, Pareto, ParetoError,
     Geometric, HyperGeoError, Hypergeometric, LogNormal, Normal, NormalError, Pareto, ParetoError,
-    Poisson, PoissonError, Uniform,
+    Poisson, PoissonError, Uniform, WeightedAliasIndex, WeightedError,
 };
 };
 use rand_xoshiro::Xoshiro256PlusPlus;
 use rand_xoshiro::Xoshiro256PlusPlus;
 use serde::Deserialize;
 use serde::Deserialize;
+use std::{num::ParseIntError, str::FromStr};
 use tokio::time::Duration;
 use tokio::time::Duration;
 
 
 /// The set of Distributions we currently support for message sizes (in padding blocks).
 /// The set of Distributions we currently support for message sizes (in padding blocks).
@@ -20,6 +21,7 @@ pub enum MessageDistribution {
     Binomial(Binomial),
     Binomial(Binomial),
     Geometric(Geometric),
     Geometric(Geometric),
     Hypergeometric(Hypergeometric),
     Hypergeometric(Hypergeometric),
+    Weighted(WeightedAliasIndex<u32>, Vec<u32>),
 }
 }
 
 
 impl Distribution<u32> for MessageDistribution {
 impl Distribution<u32> for MessageDistribution {
@@ -29,6 +31,7 @@ impl Distribution<u32> for MessageDistribution {
             Self::Binomial(d) => d.sample(rng),
             Self::Binomial(d) => d.sample(rng),
             Self::Geometric(d) => d.sample(rng),
             Self::Geometric(d) => d.sample(rng),
             Self::Hypergeometric(d) => d.sample(rng),
             Self::Hypergeometric(d) => d.sample(rng),
+            Self::Weighted(d, v) => v[d.sample(rng)].into(),
         };
         };
         std::cmp::min(ret, mgen::MAX_BLOCKS_IN_BODY.into()) as u32
         std::cmp::min(ret, mgen::MAX_BLOCKS_IN_BODY.into()) as u32
     }
     }
@@ -108,6 +111,9 @@ enum ConfigMessageDistribution {
         population_with_feature: u64,
         population_with_feature: u64,
         sample_size: u64,
         sample_size: u64,
     },
     },
+    Weighted {
+        weights_file: String,
+    },
 }
 }
 
 
 /// The same as TimingDistribution, but designed for easier deserialization.
 /// The same as TimingDistribution, but designed for easier deserialization.
@@ -133,7 +139,9 @@ pub enum DistParameterError {
     Uniform, // Uniform::new doesn't return an error, it just panics
     Uniform, // Uniform::new doesn't return an error, it just panics
     Exp(ExpError),
     Exp(ExpError),
     Pareto(ParetoError),
     Pareto(ParetoError),
+    WeightedParseError(WeightedParseError),
 }
 }
+
 impl std::fmt::Display for DistParameterError {
 impl std::fmt::Display for DistParameterError {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(f, "{:?}", self)
         write!(f, "{:?}", self)
@@ -142,6 +150,42 @@ impl std::fmt::Display for DistParameterError {
 
 
 impl std::error::Error for DistParameterError {}
 impl std::error::Error for DistParameterError {}
 
 
+#[derive(Debug)]
+pub enum WeightedParseError {
+    WeightedError(WeightedError),
+    Io(std::io::Error),
+    ParseIntError(ParseIntError),
+}
+
+impl std::fmt::Display for WeightedParseError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{:?}", self)
+    }
+}
+
+impl std::error::Error for WeightedParseError {}
+
+fn parse_weights_file(
+    path: String,
+) -> Result<(WeightedAliasIndex<u32>, Vec<u32>), WeightedParseError> {
+    let weights_file = std::fs::read_to_string(path).map_err(WeightedParseError::Io)?;
+    let mut weights_lines = weights_file.lines();
+    let weights = weights_lines.next().unwrap()
+        .split(',')
+        .map(u32::from_str)
+        .collect::<Result<Vec<_>, _>>()
+        .map_err(WeightedParseError::ParseIntError)?;
+    let vals = weights_lines.next().expect("Weights file only has one line")
+        .split(',')
+        .map(u32::from_str)
+        .collect::<Result<Vec<_>, _>>()
+        .map_err(WeightedParseError::ParseIntError)?;
+    assert!(weights.len() == vals.len(), "Weights file doesn't have the same number of weights and values.");
+    let dist =
+        WeightedAliasIndex::<u32>::new(weights).map_err(WeightedParseError::WeightedError)?;
+    Ok((dist, vals))
+}
+
 impl TryFrom<ConfigMessageDistribution> for MessageDistribution {
 impl TryFrom<ConfigMessageDistribution> for MessageDistribution {
     type Error = DistParameterError;
     type Error = DistParameterError;
 
 
@@ -164,6 +208,11 @@ impl TryFrom<ConfigMessageDistribution> for MessageDistribution {
                 Hypergeometric::new(total_population_size, population_with_feature, sample_size)
                 Hypergeometric::new(total_population_size, population_with_feature, sample_size)
                     .map_err(DistParameterError::Hypergeometric)?,
                     .map_err(DistParameterError::Hypergeometric)?,
             ),
             ),
+            ConfigMessageDistribution::Weighted { weights_file } => {
+                let (dist, vals) = parse_weights_file(weights_file)
+                    .map_err(DistParameterError::WeightedParseError)?;
+                MessageDistribution::Weighted(dist, vals)
+            }
         };
         };
         Ok(dist)
         Ok(dist)
     }
     }