Selaa lähdekoodia

dpf: fix half tree with weird parameters

Lennart Braun 2 vuotta sitten
vanhempi
commit
d5e83d3ea6
1 muutettua tiedostoa jossa 42 lisäystä ja 14 poistoa
  1. 42 14
      dpf/src/spdpf.rs

+ 42 - 14
dpf/src/spdpf.rs

@@ -261,7 +261,8 @@ where
 
         let mut st_b = key.party_seed;
         for i in 0..tree_height - 1 {
-            st_b = hash(st_b) ^ index_bits[i] as u128 * st_b ^ (st_b & 1) * key.correction_words[i];
+            st_b =
+                hash(st_b) ^ (index_bits[i] as u128 * st_b) ^ (st_b & 1) * key.correction_words[i];
         }
         let x_n = index_bits[tree_height - 1];
         let high_low_b_xn = hash(st_b ^ x_n as u128);
@@ -282,15 +283,16 @@ where
 
     fn evaluate_domain(key: &Self::Key) -> Vec<V> {
         assert!(key.domain_size > 0);
-        let fkaes = FixedKeyAes::new(Self::FIXED_KEY_AES_KEY);
-        let hash = |x: u128| fkaes.hash_ccr(Self::HASH_KEY ^ x);
-        let convert = |x: u128| -> V { PRConverter::convert(x) };
 
         if key.domain_size == 1 {
             // beta is simply secret-shared
             return vec![key.correction_word_np1];
         }
 
+        let fkaes = FixedKeyAes::new(Self::FIXED_KEY_AES_KEY);
+        let hash = |x: u128| fkaes.hash_ccr(Self::HASH_KEY ^ x);
+        let convert = |x: u128| -> V { PRConverter::convert(x) };
+
         let tree_height = (key.domain_size as f64).log2().ceil() as usize;
         let last_index = key.domain_size - 1;
 
@@ -318,7 +320,7 @@ where
         // expand last layer
         {
             // handle the last expansion separately, since we might not need both outputs
-            let j = (key.domain_size >> 1) - 1;
+            let j = last_index >> 1;
             let st = seeds[j];
             let st_0 = hash(st) ^ (st & 1) * (key.hcw | key.lcw[0] as u128);
             seeds[2 * j] = st_0;
@@ -329,7 +331,7 @@ where
             }
 
             // handle the other expansions as usual
-            for j in (0..(key.domain_size >> 1) - 1).rev() {
+            for j in (0..(last_index >> 1)).rev() {
                 let st = seeds[j];
                 let st_0 = hash(st) ^ (st & 1) * (key.hcw | key.lcw[0] as u128);
                 let st_1 = hash(st ^ 1 as u128) ^ (st & 1) * (key.hcw | key.lcw[1] as u128);
@@ -372,23 +374,40 @@ mod tests {
     use rand::distributions::{Distribution, Standard};
     use rand::{thread_rng, Rng};
 
-    fn test_spdpf_with_param<SPDPF: SinglePointDpf>(domain_size: usize)
+    fn test_spdpf_with_param<SPDPF: SinglePointDpf>(domain_size: usize, alpha: Option<u64>)
     where
         Standard: Distribution<SPDPF::Value>,
     {
-        let alpha = thread_rng().gen_range(0..domain_size as u64);
+        let alpha = if alpha.is_some() {
+            alpha.unwrap()
+        } else {
+            thread_rng().gen_range(0..domain_size as u64)
+        };
         let beta = thread_rng().gen();
         let (key_0, key_1) = SPDPF::generate_keys(domain_size, alpha, beta);
 
         let out_0 = SPDPF::evaluate_domain(&key_0);
         let out_1 = SPDPF::evaluate_domain(&key_1);
+        assert_eq!(out_0.len(), domain_size);
+        assert_eq!(out_1.len(), domain_size);
         for i in 0..domain_size as u64 {
             let value = SPDPF::evaluate_at(&key_0, i) + SPDPF::evaluate_at(&key_1, i);
-            assert_eq!(value, out_0[i as usize] + out_1[i as usize]);
+            assert_eq!(
+                value,
+                out_0[i as usize] + out_1[i as usize],
+                "evaluate_at/domain mismatch at position {i}"
+            );
             if i == alpha {
-                assert_eq!(value, beta);
+                assert_eq!(
+                    value, beta,
+                    "incorrect value != beta at position alpha = {i}"
+                );
             } else {
-                assert_eq!(value, SPDPF::Value::zero());
+                assert_eq!(
+                    value,
+                    SPDPF::Value::zero(),
+                    "incorrect value != 0 at position {i}"
+                );
             }
         }
     }
@@ -396,14 +415,14 @@ mod tests {
     #[test]
     fn test_spdpf_dummy() {
         for log_domain_size in 0..10 {
-            test_spdpf_with_param::<DummySpDpf<u64>>(1 << log_domain_size);
+            test_spdpf_with_param::<DummySpDpf<u64>>(1 << log_domain_size, None);
         }
     }
 
     #[test]
     fn test_spdpf_half_tree_power_of_two_domain() {
         for log_domain_size in 0..10 {
-            test_spdpf_with_param::<HalfTreeSpDpf<Wrapping<u64>>>(1 << log_domain_size);
+            test_spdpf_with_param::<HalfTreeSpDpf<Wrapping<u64>>>(1 << log_domain_size, None);
         }
     }
 
@@ -411,7 +430,16 @@ mod tests {
     fn test_spdpf_half_tree_random_domain() {
         for _ in 0..10 {
             let domain_size = thread_rng().gen_range(1..(1 << 10));
-            test_spdpf_with_param::<HalfTreeSpDpf<Wrapping<u64>>>(domain_size);
+            test_spdpf_with_param::<HalfTreeSpDpf<Wrapping<u64>>>(domain_size, None);
+        }
+    }
+
+    #[test]
+    fn test_spdpf_half_tree_exhaustive_params() {
+        for domain_size in 1..=32 {
+            for alpha in 0..domain_size as u64 {
+                test_spdpf_with_param::<HalfTreeSpDpf<Wrapping<u64>>>(domain_size, Some(alpha));
+            }
         }
     }
 }