dists.rs 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. // Representations of Distributions for sampling timing and message sizes.
  2. use rand_distr::{
  3. Bernoulli, BernoulliError, Binomial, BinomialError, Distribution, Exp, ExpError, GeoError,
  4. Geometric, HyperGeoError, Hypergeometric, LogNormal, Normal, NormalError, Pareto, ParetoError,
  5. Poisson, PoissonError, Uniform, WeightedAliasIndex, WeightedError,
  6. };
  7. use rand_xoshiro::Xoshiro256PlusPlus;
  8. use serde::Deserialize;
  9. use std::str::FromStr;
  10. use tokio::time::Duration;
  11. /// The set of Distributions we currently support for message sizes (in padding blocks).
  12. /// To modify the code to add support for more, one approach is to first add them here,
  13. /// then fix all the compiler errors and warnings that arise as a result.
  14. #[derive(Debug)]
  15. pub enum MessageDistribution {
  16. // Poisson is only defined for floats for technical reasons.
  17. // https://rust-random.github.io/book/guide-dist.html#integers
  18. Poisson(Poisson<f64>),
  19. Binomial(Binomial),
  20. Geometric(Geometric),
  21. Hypergeometric(Hypergeometric),
  22. Weighted(WeightedAliasIndex<u32>, Vec<u32>),
  23. }
  24. impl Distribution<u32> for MessageDistribution {
  25. fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> u32 {
  26. let ret = match self {
  27. Self::Poisson(d) => d.sample(rng) as u64,
  28. Self::Binomial(d) => d.sample(rng),
  29. Self::Geometric(d) => d.sample(rng),
  30. Self::Hypergeometric(d) => d.sample(rng),
  31. Self::Weighted(d, v) => v[d.sample(rng)].into(),
  32. };
  33. std::cmp::min(ret, mgen::MAX_BLOCKS_IN_BODY.into()) as u32
  34. }
  35. }
  36. /// The set of Distributions we currently support for timings.
  37. /// To modify the code to add support for more, one approach is to first add them here,
  38. /// then fix all the compiler errors and warnings that arise as a result.
  39. #[derive(Debug)]
  40. pub enum TimingDistribution {
  41. Normal(Normal<f64>),
  42. LogNormal(LogNormal<f64>),
  43. Uniform(Uniform<f64>),
  44. Exp(Exp<f64>),
  45. Pareto(Pareto<f64>),
  46. Weighted(WeightedAliasIndex<u32>, Vec<f64>),
  47. }
  48. impl Distribution<f64> for TimingDistribution {
  49. fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
  50. let ret = match self {
  51. Self::Normal(d) => d.sample(rng),
  52. Self::LogNormal(d) => d.sample(rng),
  53. Self::Uniform(d) => d.sample(rng),
  54. Self::Exp(d) => d.sample(rng),
  55. Self::Pareto(d) => d.sample(rng),
  56. Self::Weighted(d, v) => v[d.sample(rng)],
  57. };
  58. ret.max(0.0)
  59. }
  60. }
  61. /// The set of distributions necessary to represent the actions of the state machine.
  62. #[derive(Debug)]
  63. pub struct Distributions {
  64. pub m: MessageDistribution,
  65. pub i: TimingDistribution,
  66. pub w: TimingDistribution,
  67. pub a_s: TimingDistribution,
  68. pub a_r: TimingDistribution,
  69. pub s: Bernoulli,
  70. pub r: Bernoulli,
  71. }
  72. impl TimingDistribution {
  73. pub fn sample_secs(&self, rng: &mut Xoshiro256PlusPlus) -> Duration {
  74. Duration::from_secs_f64(self.sample(rng))
  75. }
  76. }
  77. /// The same as Distributions, but designed for easier deserialization.
  78. #[derive(Clone, Debug, Deserialize)]
  79. pub struct ConfigDistributions {
  80. m: ConfigMessageDistribution,
  81. i: ConfigTimingDistribution,
  82. w: ConfigTimingDistribution,
  83. a_s: ConfigTimingDistribution,
  84. a_r: ConfigTimingDistribution,
  85. s: f64,
  86. r: f64,
  87. }
  88. /// The same as MessageDistribution, but designed for easier deserialization.
  89. #[derive(Clone, Debug, Deserialize)]
  90. #[serde(tag = "distribution")]
  91. enum ConfigMessageDistribution {
  92. Poisson {
  93. lambda: f64,
  94. },
  95. Binomial {
  96. n: u64,
  97. p: f64,
  98. },
  99. Geometric {
  100. p: f64,
  101. },
  102. Hypergeometric {
  103. total_population_size: u64,
  104. population_with_feature: u64,
  105. sample_size: u64,
  106. },
  107. Weighted {
  108. weights_file: String,
  109. },
  110. }
  111. /// The same as TimingDistribution, but designed for easier deserialization.
  112. #[derive(Clone, Debug, Deserialize)]
  113. #[serde(tag = "distribution")]
  114. enum ConfigTimingDistribution {
  115. Normal { mean: f64, std_dev: f64 },
  116. LogNormal { mean: f64, std_dev: f64 },
  117. Uniform { low: f64, high: f64 },
  118. Exp { lambda: f64 },
  119. Pareto { scale: f64, shape: f64 },
  120. Weighted { weights_file: String },
  121. }
  122. #[derive(Debug)]
  123. pub enum DistParameterError {
  124. Poisson(PoissonError),
  125. Binomial(BinomialError),
  126. Geometric(GeoError),
  127. Hypergeometric(HyperGeoError),
  128. Bernoulli(BernoulliError),
  129. Normal(NormalError),
  130. LogNormal(NormalError),
  131. Uniform, // Uniform::new doesn't return an error, it just panics
  132. Exp(ExpError),
  133. Pareto(ParetoError),
  134. WeightedParseError(WeightedParseError),
  135. }
  136. impl std::fmt::Display for DistParameterError {
  137. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  138. write!(f, "{:?}", self)
  139. }
  140. }
  141. impl std::error::Error for DistParameterError {}
  142. #[derive(Debug)]
  143. pub enum WeightedParseError {
  144. WeightedError(WeightedError),
  145. Io(std::io::Error),
  146. ParseNumError,
  147. }
  148. impl std::fmt::Display for WeightedParseError {
  149. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  150. write!(f, "{:?}", self)
  151. }
  152. }
  153. impl std::error::Error for WeightedParseError {}
  154. fn parse_weights_file<T: FromStr>(
  155. path: String,
  156. ) -> Result<(WeightedAliasIndex<u32>, Vec<T>), WeightedParseError> {
  157. let weights_file = std::fs::read_to_string(path).map_err(WeightedParseError::Io)?;
  158. let mut weights_lines = weights_file.lines();
  159. let weights = weights_lines
  160. .next()
  161. .unwrap()
  162. .split(',')
  163. .map(u32::from_str)
  164. .collect::<Result<Vec<_>, _>>()
  165. .or(Err(WeightedParseError::ParseNumError))?;
  166. let vals = weights_lines
  167. .next()
  168. .expect("Weights file only has one line")
  169. .split(',')
  170. .map(T::from_str)
  171. .collect::<Result<Vec<_>, _>>()
  172. .or(Err(WeightedParseError::ParseNumError))?;
  173. assert!(
  174. weights.len() == vals.len(),
  175. "Weights file doesn't have the same number of weights and values."
  176. );
  177. let dist =
  178. WeightedAliasIndex::<u32>::new(weights).map_err(WeightedParseError::WeightedError)?;
  179. Ok((dist, vals))
  180. }
  181. impl TryFrom<ConfigMessageDistribution> for MessageDistribution {
  182. type Error = DistParameterError;
  183. fn try_from(dist: ConfigMessageDistribution) -> Result<Self, DistParameterError> {
  184. let dist = match dist {
  185. ConfigMessageDistribution::Poisson { lambda } => MessageDistribution::Poisson(
  186. Poisson::new(lambda).map_err(DistParameterError::Poisson)?,
  187. ),
  188. ConfigMessageDistribution::Binomial { n, p } => MessageDistribution::Binomial(
  189. Binomial::new(n, p).map_err(DistParameterError::Binomial)?,
  190. ),
  191. ConfigMessageDistribution::Geometric { p } => MessageDistribution::Geometric(
  192. Geometric::new(p).map_err(DistParameterError::Geometric)?,
  193. ),
  194. ConfigMessageDistribution::Hypergeometric {
  195. total_population_size,
  196. population_with_feature,
  197. sample_size,
  198. } => MessageDistribution::Hypergeometric(
  199. Hypergeometric::new(total_population_size, population_with_feature, sample_size)
  200. .map_err(DistParameterError::Hypergeometric)?,
  201. ),
  202. ConfigMessageDistribution::Weighted { weights_file } => {
  203. let (dist, vals) = parse_weights_file(weights_file)
  204. .map_err(DistParameterError::WeightedParseError)?;
  205. MessageDistribution::Weighted(dist, vals)
  206. }
  207. };
  208. Ok(dist)
  209. }
  210. }
  211. impl TryFrom<ConfigTimingDistribution> for TimingDistribution {
  212. type Error = DistParameterError;
  213. fn try_from(dist: ConfigTimingDistribution) -> Result<Self, DistParameterError> {
  214. let dist = match dist {
  215. ConfigTimingDistribution::Normal { mean, std_dev } => TimingDistribution::Normal(
  216. Normal::new(mean, std_dev).map_err(DistParameterError::Normal)?,
  217. ),
  218. ConfigTimingDistribution::LogNormal { mean, std_dev } => TimingDistribution::LogNormal(
  219. LogNormal::new(mean, std_dev).map_err(DistParameterError::LogNormal)?,
  220. ),
  221. ConfigTimingDistribution::Uniform { low, high } => {
  222. if low >= high {
  223. return Err(DistParameterError::Uniform);
  224. }
  225. TimingDistribution::Uniform(Uniform::new(low, high))
  226. }
  227. ConfigTimingDistribution::Exp { lambda } => {
  228. TimingDistribution::Exp(Exp::new(lambda).map_err(DistParameterError::Exp)?)
  229. }
  230. ConfigTimingDistribution::Pareto { scale, shape } => TimingDistribution::Pareto(
  231. Pareto::new(scale, shape).map_err(DistParameterError::Pareto)?,
  232. ),
  233. ConfigTimingDistribution::Weighted { weights_file } => {
  234. let (dist, vals) = parse_weights_file(weights_file)
  235. .map_err(DistParameterError::WeightedParseError)?;
  236. TimingDistribution::Weighted(dist, vals)
  237. }
  238. };
  239. Ok(dist)
  240. }
  241. }
  242. impl TryFrom<ConfigDistributions> for Distributions {
  243. type Error = DistParameterError;
  244. fn try_from(config: ConfigDistributions) -> Result<Self, DistParameterError> {
  245. Ok(Distributions {
  246. m: config.m.try_into()?,
  247. i: config.i.try_into()?,
  248. w: config.w.try_into()?,
  249. a_s: config.a_s.try_into()?,
  250. a_r: config.a_r.try_into()?,
  251. s: Bernoulli::new(config.s).map_err(DistParameterError::Bernoulli)?,
  252. r: Bernoulli::new(config.r).map_err(DistParameterError::Bernoulli)?,
  253. })
  254. }
  255. }