|
@@ -3,10 +3,11 @@
|
|
|
use rand_distr::{
|
|
use rand_distr::{
|
|
|
Bernoulli, BernoulliError, Binomial, BinomialError, Distribution, Exp, ExpError, GeoError,
|
|
Bernoulli, BernoulliError, Binomial, BinomialError, Distribution, Exp, ExpError, GeoError,
|
|
|
Geometric, HyperGeoError, Hypergeometric, LogNormal, Normal, NormalError, Pareto, ParetoError,
|
|
Geometric, HyperGeoError, Hypergeometric, LogNormal, Normal, NormalError, Pareto, ParetoError,
|
|
|
- Poisson, PoissonError, Uniform,
|
|
|
|
|
|
|
+ Poisson, PoissonError, Uniform, WeightedAliasIndex, WeightedError,
|
|
|
};
|
|
};
|
|
|
use rand_xoshiro::Xoshiro256PlusPlus;
|
|
use rand_xoshiro::Xoshiro256PlusPlus;
|
|
|
use serde::Deserialize;
|
|
use serde::Deserialize;
|
|
|
|
|
+use std::{num::ParseIntError, str::FromStr};
|
|
|
use tokio::time::Duration;
|
|
use tokio::time::Duration;
|
|
|
|
|
|
|
|
/// The set of Distributions we currently support for message sizes (in padding blocks).
|
|
/// The set of Distributions we currently support for message sizes (in padding blocks).
|
|
@@ -20,6 +21,7 @@ pub enum MessageDistribution {
|
|
|
Binomial(Binomial),
|
|
Binomial(Binomial),
|
|
|
Geometric(Geometric),
|
|
Geometric(Geometric),
|
|
|
Hypergeometric(Hypergeometric),
|
|
Hypergeometric(Hypergeometric),
|
|
|
|
|
+ Weighted(WeightedAliasIndex<u32>, Vec<u32>),
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
impl Distribution<u32> for MessageDistribution {
|
|
impl Distribution<u32> for MessageDistribution {
|
|
@@ -29,6 +31,7 @@ impl Distribution<u32> for MessageDistribution {
|
|
|
Self::Binomial(d) => d.sample(rng),
|
|
Self::Binomial(d) => d.sample(rng),
|
|
|
Self::Geometric(d) => d.sample(rng),
|
|
Self::Geometric(d) => d.sample(rng),
|
|
|
Self::Hypergeometric(d) => d.sample(rng),
|
|
Self::Hypergeometric(d) => d.sample(rng),
|
|
|
|
|
+ Self::Weighted(d, v) => v[d.sample(rng)].into(),
|
|
|
};
|
|
};
|
|
|
std::cmp::min(ret, mgen::MAX_BLOCKS_IN_BODY.into()) as u32
|
|
std::cmp::min(ret, mgen::MAX_BLOCKS_IN_BODY.into()) as u32
|
|
|
}
|
|
}
|
|
@@ -108,6 +111,9 @@ enum ConfigMessageDistribution {
|
|
|
population_with_feature: u64,
|
|
population_with_feature: u64,
|
|
|
sample_size: u64,
|
|
sample_size: u64,
|
|
|
},
|
|
},
|
|
|
|
|
+ Weighted {
|
|
|
|
|
+ weights_file: String,
|
|
|
|
|
+ },
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
/// The same as TimingDistribution, but designed for easier deserialization.
|
|
/// The same as TimingDistribution, but designed for easier deserialization.
|
|
@@ -133,7 +139,9 @@ pub enum DistParameterError {
|
|
|
Uniform, // Uniform::new doesn't return an error, it just panics
|
|
Uniform, // Uniform::new doesn't return an error, it just panics
|
|
|
Exp(ExpError),
|
|
Exp(ExpError),
|
|
|
Pareto(ParetoError),
|
|
Pareto(ParetoError),
|
|
|
|
|
+ WeightedParseError(WeightedParseError),
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
impl std::fmt::Display for DistParameterError {
|
|
impl std::fmt::Display for DistParameterError {
|
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
|
write!(f, "{:?}", self)
|
|
write!(f, "{:?}", self)
|
|
@@ -142,6 +150,42 @@ impl std::fmt::Display for DistParameterError {
|
|
|
|
|
|
|
|
impl std::error::Error for DistParameterError {}
|
|
impl std::error::Error for DistParameterError {}
|
|
|
|
|
|
|
|
|
|
+/// Errors that can arise while loading a weighted distribution from a weights file.
+///
+/// Variants are listed in the order the corresponding failures can occur during loading.
+#[derive(Debug)]
+pub enum WeightedParseError {
+    /// An I/O-level failure while loading the weights file.
+    Io(std::io::Error),
+    /// An entry in the weights file could not be parsed as an integer.
+    ParseIntError(ParseIntError),
+    /// The parsed weights were rejected when constructing the weighted distribution.
+    WeightedError(WeightedError),
+}
|
|
|
|
|
+
|
|
|
|
|
+/// Displays the error using its `Debug` representation.
+impl std::fmt::Display for WeightedParseError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Fresh format arguments are built on purpose: any flags carried by `f`
+        // are not forwarded to the `Debug` rendering.
+        f.write_fmt(format_args!("{:?}", self))
+    }
+}
+
+// No methods are overridden here, so the `std::error::Error` defaults apply.
+impl std::error::Error for WeightedParseError {}
|
|
|
|
|
+
|
|
|
|
|
+/// Loads a weighted distribution and its associated values from the file at `path`.
+///
+/// The file must contain at least two lines of comma-separated unsigned integers:
+/// the first line holds the weights, the second holds the value paired with each
+/// weight (same order, same count). Whitespace around an entry is ignored, and
+/// any lines after the second are not read.
+///
+/// Returns the alias-method sampler built from the weights together with the
+/// values, so that a sampled index can be used to look up its value.
+///
+/// # Errors
+///
+/// * [`WeightedParseError::Io`] if the file cannot be read, or (with kind
+///   [`std::io::ErrorKind::InvalidData`]) if a required line is missing or the
+///   two lines do not have the same number of entries.
+/// * [`WeightedParseError::ParseIntError`] if an entry is not a valid `u32`.
+/// * [`WeightedParseError::WeightedError`] if `WeightedAliasIndex::new` rejects
+///   the weights.
+fn parse_weights_file(
+    path: String,
+) -> Result<(WeightedAliasIndex<u32>, Vec<u32>), WeightedParseError> {
+    /// Wraps a structural problem with the file contents as an `InvalidData` I/O error.
+    fn invalid_data(msg: &'static str) -> WeightedParseError {
+        WeightedParseError::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, msg))
+    }
+
+    /// Parses one comma-separated line of `u32` entries.
+    fn parse_line(line: &str) -> Result<Vec<u32>, WeightedParseError> {
+        line.split(',')
+            .map(|entry| u32::from_str(entry.trim()))
+            .collect::<Result<Vec<_>, _>>()
+            .map_err(WeightedParseError::ParseIntError)
+    }
+
+    let weights_file = std::fs::read_to_string(path).map_err(WeightedParseError::Io)?;
+    let mut weights_lines = weights_file.lines();
+
+    // The file is user-supplied configuration, so malformed contents are
+    // reported as errors instead of panicking.
+    let weights = weights_lines
+        .next()
+        .ok_or_else(|| invalid_data("Weights file is empty"))
+        .and_then(parse_line)?;
+    let vals = weights_lines
+        .next()
+        .ok_or_else(|| invalid_data("Weights file only has one line"))
+        .and_then(parse_line)?;
+
+    // Sampled indices are used to index `vals`, so the lengths must agree.
+    if weights.len() != vals.len() {
+        return Err(invalid_data(
+            "Weights file doesn't have the same number of weights and values.",
+        ));
+    }
+
+    let dist =
+        WeightedAliasIndex::<u32>::new(weights).map_err(WeightedParseError::WeightedError)?;
+    Ok((dist, vals))
+}
|
|
|
|
|
+
|
|
|
impl TryFrom<ConfigMessageDistribution> for MessageDistribution {
|
|
impl TryFrom<ConfigMessageDistribution> for MessageDistribution {
|
|
|
type Error = DistParameterError;
|
|
type Error = DistParameterError;
|
|
|
|
|
|
|
@@ -164,6 +208,11 @@ impl TryFrom<ConfigMessageDistribution> for MessageDistribution {
|
|
|
Hypergeometric::new(total_population_size, population_with_feature, sample_size)
|
|
Hypergeometric::new(total_population_size, population_with_feature, sample_size)
|
|
|
.map_err(DistParameterError::Hypergeometric)?,
|
|
.map_err(DistParameterError::Hypergeometric)?,
|
|
|
),
|
|
),
|
|
|
|
|
+ ConfigMessageDistribution::Weighted { weights_file } => {
|
|
|
|
|
+ let (dist, vals) = parse_weights_file(weights_file)
|
|
|
|
|
+ .map_err(DistParameterError::WeightedParseError)?;
|
|
|
|
|
+ MessageDistribution::Weighted(dist, vals)
|
|
|
|
|
+ }
|
|
|
};
|
|
};
|
|
|
Ok(dist)
|
|
Ok(dist)
|
|
|
}
|
|
}
|