Explorar el Código

Improve analysis

We seem to get better results when we scale bridge IP counts down by a factor of 8 (so they are in multiples of 1 instead of multiples of 8) and only mark a bridge as blocked if today's values differ from the mean in the 'bad' direction by at least one standard deviation.
Vecna hace 1 mes
padre
commit
36395181d3
Se han modificado 2 ficheros con 167 adiciones y 33 borrados
  1. 62 32
      src/analysis.rs
  2. 105 1
      src/tests.rs

+ 62 - 32
src/analysis.rs

@@ -7,6 +7,8 @@ use std::{
     collections::{BTreeMap, HashSet},
 };
 
+const SCALE_BRIDGE_IPS: u32 = 8;
+
 /// Provides a function for predicting which countries block this bridge
 pub trait Analyzer {
     /// Evaluate open-entry bridge. Returns true if blocked, false otherwise.
@@ -69,15 +71,15 @@ pub fn blocked_in(
                 None => &new_map_binding,
             };
             let bridge_ips_today = match today_info.get(&BridgeInfoType::BridgeIps) {
-                Some(v) => *v,
+                Some(&v) => v / SCALE_BRIDGE_IPS,
                 None => 0,
             };
             let negative_reports_today = match today_info.get(&BridgeInfoType::NegativeReports) {
-                Some(v) => *v,
+                Some(&v) => v,
                 None => 0,
             };
             let positive_reports_today = match today_info.get(&BridgeInfoType::PositiveReports) {
-                Some(v) => *v,
+                Some(&v) => v,
                 None => 0,
             };
 
@@ -96,7 +98,7 @@ pub fn blocked_in(
                     None => &new_map_binding,
                 };
                 bridge_ips[i as usize] = match day_info.get(&BridgeInfoType::BridgeIps) {
-                    Some(&v) => v,
+                    Some(&v) => v / SCALE_BRIDGE_IPS,
                     None => 0,
                 };
                 negative_reports[i as usize] = match day_info.get(&BridgeInfoType::NegativeReports)
@@ -213,22 +215,33 @@ impl NormalAnalyzer {
         }
     }
 
-    fn mean_vector_and_covariance_matrix(data: &[&[u32]]) -> (Vec<f64>, Vec<f64>) {
+    // Returns the mean vector, vector of individual standard deviations, and
+    // covariance matrix
+    fn stats(data: &[&[u32]]) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
         let n = data.len();
 
-        // Compute mean vector
-        let mean_vec = {
+        // Compute mean and standard deviation vectors
+        let (mean_vec, sd_vec) = {
             let mut mean_vec = Vec::<f64>::new();
+            let mut sd_vec = Vec::<f64>::new();
             for var in data {
-                mean_vec.push({
-                    let mut sum = 0.0;
-                    for count in *var {
-                        sum += *count as f64;
-                    }
-                    sum / var.len() as f64
-                });
+                // Compute mean
+                let mut sum = 0.0;
+                for count in *var {
+                    sum += *count as f64;
+                }
+                let mean = sum / var.len() as f64;
+
+                // Compute standard deviation
+                let mut sum = 0.0;
+                for count in *var {
+                    sum += (*count as f64 - mean).powi(2);
+                }
+                let sd = (sum / var.len() as f64).sqrt();
+                mean_vec.push(mean);
+                sd_vec.push(sd);
             }
-            mean_vec
+            (mean_vec, sd_vec)
         };
 
         // Compute covariance matrix
@@ -258,7 +271,7 @@ impl NormalAnalyzer {
             cov_mat
         };
 
-        (mean_vec, cov_mat)
+        (mean_vec, sd_vec, cov_mat)
     }
 }
 
@@ -267,13 +280,14 @@ impl Analyzer for NormalAnalyzer {
     fn stage_one(
         &self,
         _confidence: f64,
-        _bridge_ips: &[u32],
+        bridge_ips: &[u32],
         bridge_ips_today: u32,
         _negative_reports: &[u32],
         negative_reports_today: u32,
     ) -> bool {
         negative_reports_today > self.max_threshold
-            || f64::from(negative_reports_today) > self.scaling_factor * f64::from(bridge_ips_today)
+            || f64::from(negative_reports_today)
+                > self.scaling_factor * f64::from(bridge_ips_today) * SCALE_BRIDGE_IPS as f64
     }
 
     /// Evaluate invite-only bridge based on last 30 days
@@ -288,19 +302,33 @@ impl Analyzer for NormalAnalyzer {
         assert!(bridge_ips.len() >= UNTRUSTED_INTERVAL as usize);
         assert_eq!(bridge_ips.len(), negative_reports.len());
 
-        let (mean_vec, cov_mat) =
-            Self::mean_vector_and_covariance_matrix(&[bridge_ips, negative_reports]);
+        let alpha = 1.0 - confidence;
+
+        let (mean_vec, sd_vec, cov_mat) = Self::stats(&[bridge_ips, negative_reports]);
         let bridge_ips_mean = mean_vec[0];
         let negative_reports_mean = mean_vec[1];
+        let bridge_ips_sd = sd_vec[0];
+        let negative_reports_sd = sd_vec[1];
 
         let mvn = MultivariateNormal::new(mean_vec, cov_mat).unwrap();
+        println!(
+            "evaluate mvn.pdf of [{},{}]",
+            bridge_ips_today as f64, negative_reports_today as f64
+        );
+        println!(
+            "{}",
+            mvn.pdf(&DVector::from_vec(vec![
+                bridge_ips_today as f64,
+                negative_reports_today as f64
+            ]))
+        );
         if mvn.pdf(&DVector::from_vec(vec![
             bridge_ips_today as f64,
             negative_reports_today as f64,
-        ])) < confidence
+        ])) < alpha
         {
-            (negative_reports_today as f64) > negative_reports_mean
-                || (bridge_ips_today as f64) < bridge_ips_mean
+            (negative_reports_today as f64) > negative_reports_mean + negative_reports_sd
+                || (bridge_ips_today as f64) < bridge_ips_mean - bridge_ips_sd
         } else {
             false
         }
@@ -321,25 +349,27 @@ impl Analyzer for NormalAnalyzer {
         assert_eq!(bridge_ips.len(), negative_reports.len());
         assert_eq!(bridge_ips.len(), positive_reports.len());
 
-        let (mean_vec, cov_mat) = Self::mean_vector_and_covariance_matrix(&[
-            bridge_ips,
-            negative_reports,
-            positive_reports,
-        ]);
+        let alpha = 1.0 - confidence;
+
+        let (mean_vec, sd_vec, cov_mat) =
+            Self::stats(&[bridge_ips, negative_reports, positive_reports]);
         let bridge_ips_mean = mean_vec[0];
         let negative_reports_mean = mean_vec[1];
         let positive_reports_mean = mean_vec[2];
+        let bridge_ips_sd = sd_vec[0];
+        let negative_reports_sd = sd_vec[1];
+        let positive_reports_sd = sd_vec[2];
 
         let mvn = MultivariateNormal::new(mean_vec, cov_mat).unwrap();
         if mvn.pdf(&DVector::from_vec(vec![
             bridge_ips_today as f64,
             negative_reports_today as f64,
             positive_reports_today as f64,
-        ])) < confidence
+        ])) < alpha
         {
-            (negative_reports_today as f64) > negative_reports_mean
-                || (bridge_ips_today as f64) < bridge_ips_mean
-                || (positive_reports_today as f64) < positive_reports_mean
+            (negative_reports_today as f64) > negative_reports_mean + negative_reports_sd
+                || (bridge_ips_today as f64) < bridge_ips_mean - bridge_ips_sd
+                || (positive_reports_today as f64) < positive_reports_mean - positive_reports_sd
         } else {
             false
         }

+ 105 - 1
src/tests.rs

@@ -1015,7 +1015,111 @@ fn test_analysis() {
         );
     }
 
-    // TODO: Test stage 2 analysis
+    // Test stage 2 analysis
+
+    {
+        let mut date = get_date();
+
+        // New bridge info
+        let mut bridge_info = BridgeInfo::new([0; 20], &String::default());
+
+        bridge_info
+            .info_by_country
+            .insert("ru".to_string(), BridgeCountryInfo::new());
+        let analyzer = analysis::NormalAnalyzer::new(5, 0.25);
+        let confidence = 0.95;
+
+        let mut blocking_countries = HashSet::<String>::new();
+
+        // No data today
+        assert_eq!(
+            blocked_in(&analyzer, &bridge_info, confidence, date),
+            blocking_countries
+        );
+
+        for i in 1..30 {
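+            // Bridge IP counts of 16, 24, or 32 (presumably 9-32 real connections,
+            // if counts are rounded up to multiples of 8), 0-3 negative reports each day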
+            date += 1;
+            bridge_info.info_by_country.get_mut("ru").unwrap().add_info(
+                BridgeInfoType::BridgeIps,
+                date,
+                8 * (i % 3 + 2),
+            );
+            bridge_info.info_by_country.get_mut("ru").unwrap().add_info(
+                BridgeInfoType::NegativeReports,
+                date,
+                i % 4,
+            );
+            assert_eq!(
+                blocked_in(&analyzer, &bridge_info, confidence, date),
+                blocking_countries
+            );
+        }
+
+        // Data similar to previous days:
+        // 24 connections, 2 negative reports
+        date += 1;
+        bridge_info.info_by_country.get_mut("ru").unwrap().add_info(
+            BridgeInfoType::BridgeIps,
+            date,
+            24,
+        );
+        bridge_info.info_by_country.get_mut("ru").unwrap().add_info(
+            BridgeInfoType::NegativeReports,
+            date,
+            2,
+        );
+
+        // Should not be blocked because we have similar data.
+        assert_eq!(
+            blocked_in(&analyzer, &bridge_info, confidence, date),
+            blocking_countries
+        );
+
+        // Data different from previous days:
+        // 104 connections, 1 negative report
+        date += 1;
+        bridge_info.info_by_country.get_mut("ru").unwrap().add_info(
+            BridgeInfoType::BridgeIps,
+            date,
+            104,
+        );
+        bridge_info.info_by_country.get_mut("ru").unwrap().add_info(
+            BridgeInfoType::NegativeReports,
+            date,
+            1,
+        );
+
+        // This should not be blocked even though it's very different because
+        // it's different in the good direction.
+        assert_eq!(
+            blocked_in(&analyzer, &bridge_info, confidence, date),
+            blocking_countries
+        );
+
+        // Data different from previous days:
+        // 40 connections, 12 negative reports
+        date += 1;
+        bridge_info.info_by_country.get_mut("ru").unwrap().add_info(
+            BridgeInfoType::BridgeIps,
+            date,
+            40,
+        );
+        bridge_info.info_by_country.get_mut("ru").unwrap().add_info(
+            BridgeInfoType::NegativeReports,
+            date,
+            12,
+        );
+        blocking_countries.insert("ru".to_string());
+
+        // This should be blocked because it's different in the bad direction.
+        assert_eq!(
+            blocked_in(&analyzer, &bridge_info, confidence, date),
+            blocking_countries
+        );
+    }
+
+    // TODO: More tests
 
     // TODO: Test stage 3 analysis
 }