Explorar o código

Add noise when necessary to build distribution

Vecna hai 2 semanas
pai
achega
b50f40fe8a
Modificáronse 1 ficheiros con 32 adicións e 3 borrados
  1. 32 3
      src/analysis.rs

+ 32 - 3
src/analysis.rs

@@ -1,6 +1,7 @@
 use crate::{BridgeInfo, BridgeInfoType};
 use lox_library::proto::trust_promotion::UNTRUSTED_INTERVAL;
-use nalgebra::DVector;
+use nalgebra::{Cholesky, DMatrix, DVector};
+use rand::Rng;
 use statrs::distribution::{Continuous, MultivariateNormal, Normal};
 use std::{
     cmp::min,
@@ -212,7 +213,9 @@ impl NormalAnalyzer {
     }
 
     // Returns the mean vector, vector of individual standard deviations, and
-    // covariance matrix
+    // covariance matrix. If the standard deviation for a variable is 0 and/or
+    // the covariance matrix is not positive definite, add some noise to the
+    // data and recompute.
     fn stats(data: &[&[u32]]) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
         let n = data.len();
 
@@ -267,7 +270,33 @@ impl NormalAnalyzer {
             cov_mat
         };
 
-        (mean_vec, sd_vec, cov_mat)
+        // If any standard deviation is 0 or the covariance matrix is not
+        // positive definite, add some noise and recompute.
+        let mut recompute = false;
+        for sd in &sd_vec {
+            if *sd <= 0.0 {
+                recompute = true;
+            }
+        }
+        if Cholesky::new(DMatrix::from_vec(n, n, cov_mat.clone())).is_none() {
+            recompute = true;
+        }
+
+        if !recompute {
+            (mean_vec, sd_vec, cov_mat)
+        } else {
+            // Add random noise and recompute
+            let mut new_data = vec![vec![0; data[0].len()]; n];
+            let mut rng = rand::thread_rng();
+            for i in 0..n {
+                for j in 0..data[i].len() {
+                    // Add 1 to some randomly selected values
+                    new_data[i][j] = data[i][j] + rng.gen_range(0..=1);
+                }
+            }
+            // Compute stats on modified data
+            Self::stats(&new_data.iter().map(Vec::as_slice).collect::<Vec<&[u32]>>())
+        }
     }
 }