Browse Source

add support for weighted dists on message timings

Justin Tracey 1 year ago
parent
commit
edf8a376f2
1 changed files with 25 additions and 10 deletions
  1. 25 10
      src/bin/messenger/dists.rs

+ 25 - 10
src/bin/messenger/dists.rs

@@ -7,7 +7,7 @@ use rand_distr::{
 };
 use rand_xoshiro::Xoshiro256PlusPlus;
 use serde::Deserialize;
-use std::{num::ParseIntError, str::FromStr};
+use std::str::FromStr;
 use tokio::time::Duration;
 
 /// The set of Distributions we currently support for message sizes (in padding blocks).
@@ -47,6 +47,7 @@ pub enum TimingDistribution {
     Uniform(Uniform<f64>),
     Exp(Exp<f64>),
     Pareto(Pareto<f64>),
+    Weighted(WeightedAliasIndex<u32>, Vec<f64>),
 }
 
 impl Distribution<f64> for TimingDistribution {
@@ -57,6 +58,7 @@ impl Distribution<f64> for TimingDistribution {
             Self::Uniform(d) => d.sample(rng),
             Self::Exp(d) => d.sample(rng),
             Self::Pareto(d) => d.sample(rng),
+            Self::Weighted(d, v) => v[d.sample(rng)],
         };
         ret.max(0.0)
     }
@@ -125,6 +127,7 @@ enum ConfigTimingDistribution {
     Uniform { low: f64, high: f64 },
     Exp { lambda: f64 },
     Pareto { scale: f64, shape: f64 },
+    Weighted { weights_file: String },
 }
 
 #[derive(Debug)]
@@ -154,7 +157,7 @@ impl std::error::Error for DistParameterError {}
 pub enum WeightedParseError {
     WeightedError(WeightedError),
     Io(std::io::Error),
-    ParseIntError(ParseIntError),
+    ParseNumError,
 }
 
 impl std::fmt::Display for WeightedParseError {
@@ -165,22 +168,29 @@ impl std::fmt::Display for WeightedParseError {
 
 impl std::error::Error for WeightedParseError {}
 
-fn parse_weights_file(
+fn parse_weights_file<T: FromStr>(
     path: String,
-) -> Result<(WeightedAliasIndex<u32>, Vec<u32>), WeightedParseError> {
+) -> Result<(WeightedAliasIndex<u32>, Vec<T>), WeightedParseError> {
     let weights_file = std::fs::read_to_string(path).map_err(WeightedParseError::Io)?;
     let mut weights_lines = weights_file.lines();
-    let weights = weights_lines.next().unwrap()
+    let weights = weights_lines
+        .next()
+        .unwrap()
         .split(',')
         .map(u32::from_str)
         .collect::<Result<Vec<_>, _>>()
-        .map_err(WeightedParseError::ParseIntError)?;
-    let vals = weights_lines.next().expect("Weights file only has one line")
+        .or(Err(WeightedParseError::ParseNumError))?;
+    let vals = weights_lines
+        .next()
+        .expect("Weights file only has one line")
         .split(',')
-        .map(u32::from_str)
+        .map(T::from_str)
         .collect::<Result<Vec<_>, _>>()
-        .map_err(WeightedParseError::ParseIntError)?;
-    assert!(weights.len() == vals.len(), "Weights file doesn't have the same number of weights and values.");
+        .or(Err(WeightedParseError::ParseNumError))?;
+    assert!(
+        weights.len() == vals.len(),
+        "Weights file doesn't have the same number of weights and values."
+    );
     let dist =
         WeightedAliasIndex::<u32>::new(weights).map_err(WeightedParseError::WeightedError)?;
     Ok((dist, vals))
@@ -241,6 +251,11 @@ impl TryFrom<ConfigTimingDistribution> for TimingDistribution {
             ConfigTimingDistribution::Pareto { scale, shape } => TimingDistribution::Pareto(
                 Pareto::new(scale, shape).map_err(DistParameterError::Pareto)?,
             ),
+            ConfigTimingDistribution::Weighted { weights_file } => {
+                let (dist, vals) = parse_weights_file(weights_file)
+                    .map_err(DistParameterError::WeightedParseError)?;
+                TimingDistribution::Weighted(dist, vals)
+            }
         };
         Ok(dist)
     }