Explorar o código

dpf: more efficient half-tree full domain evaluation

Lennart Braun %!s(int64=2) %!d(string=hai) anos
pai
achega
09b20251bd
Modificáronse 1 ficheiros con 93 adicións e 3 borrados
  1. 93 3
      dpf/src/spdpf.rs

+ 93 - 3
dpf/src/spdpf.rs

@@ -1,6 +1,6 @@
 use core::fmt::Debug;
 use core::marker::PhantomData;
-use core::ops::{Add, Sub};
+use core::ops::{Add, Neg, Sub};
 use num::traits::Zero;
 use rand::{thread_rng, Rng};
 
@@ -137,7 +137,7 @@ where
 
 impl<V> SinglePointDpf for HalfTreeSpDpf<V>
 where
-    V: Add<Output = V> + Sub<Output = V> + Copy + Debug + Eq + Zero,
+    V: Add<Output = V> + Sub<Output = V> + Neg<Output = V> + Copy + Debug + Eq + Zero,
     PRConverter: PRConvertTo<V>,
 {
     type Key = HalfTreeSpDpfKey<V>;
@@ -271,6 +271,90 @@ where
             V::zero() - value
         }
     }
+
+    fn evaluate_domain(key: &Self::Key) -> Vec<V> {
+        let fkaes = FixedKeyAes::new(Self::FIXED_KEY_AES_KEY);
+        let hash = |x: u128| fkaes.hash_ccr(Self::HASH_KEY ^ x);
+        let convert = |x: u128| -> V { PRConverter::convert(x) };
+
+        if key.log_domain_size == 0 {
+            // beta is simply secret-shared
+            return vec![key.correction_word_np1];
+        }
+
+        let output_size = 1 << key.log_domain_size;
+        let tree_height = key.log_domain_size as usize;
+        let last_index = output_size - 1;
+
+        let mut seeds = vec![0u128; output_size];
+        seeds[0] = key.party_seed;
+
+        // since the last layer is handled separately, we only need the following block if we have
+        // more than one layer
+        if tree_height > 1 {
+            // iterate over the tree layer by layer
+            for i in 0..(tree_height - 1) {
+                // expand each node in this layer;
+                // we need to iterate from right to left, since we reuse the same buffer
+                for j in (0..(last_index >> (tree_height - i)) + 1).rev() {
+                    // for j in (0..(1 << i)).rev() {
+                    let st = seeds[j];
+                    let st_0 = hash(st) ^ (st & 1) * key.correction_words[i];
+                    let st_1 = hash(st) ^ st ^ (st & 1) * key.correction_words[i];
+                    seeds[2 * j] = st_0;
+                    seeds[2 * j + 1] = st_1;
+                }
+            }
+        }
+
+        // expand last layer
+        {
+            // handle the last expansion separately, since we might not need both outputs
+            let j = (output_size >> 1) - 1;
+            let st = seeds[j];
+            let st_0 = hash(st) ^ (st & 1) * (key.hcw | key.lcw[0] as u128);
+            seeds[2 * j] = st_0;
+            // check if we need both outputs
+            if output_size & 1 == 0 {
+                let st_1 = hash(st ^ 1 as u128) ^ (st & 1) * (key.hcw | key.lcw[1] as u128);
+                seeds[2 * j + 1] = st_1;
+            }
+
+            // handle the other expansions as usual
+            for j in (0..(output_size >> 1) - 1).rev() {
+                let st = seeds[j];
+                let st_0 = hash(st) ^ (st & 1) * (key.hcw | key.lcw[0] as u128);
+                let st_1 = hash(st ^ 1 as u128) ^ (st & 1) * (key.hcw | key.lcw[1] as u128);
+                seeds[2 * j] = st_0;
+                seeds[2 * j + 1] = st_1;
+            }
+        }
+
+        // convert leaves into V elements
+        if key.party_id == 0 {
+            seeds
+                .iter()
+                .map(|st_b| {
+                    let mut tmp = convert(st_b >> 1);
+                    if st_b & 1 == 1 {
+                        tmp = tmp + key.correction_word_np1;
+                    }
+                    tmp
+                })
+                .collect()
+        } else {
+            seeds
+                .iter()
+                .map(|st_b| {
+                    let mut tmp = convert(st_b >> 1);
+                    if st_b & 1 == 1 {
+                        tmp = tmp + key.correction_word_np1;
+                    }
+                    -tmp
+                })
+                .collect()
+        }
+    }
 }
 
 #[cfg(test)]
@@ -303,9 +387,15 @@ mod tests {
     }
 
     #[test]
-    fn test_spdpf() {
+    fn test_spdpf_dummy() {
         for log_domain_size in 0..10 {
             test_spdpf_with_param::<DummySpDpf<u64>>(log_domain_size);
+        }
+    }
+
+    #[test]
+    fn test_spdpf_half_tree() {
+        for log_domain_size in 0..10 {
             test_spdpf_with_param::<HalfTreeSpDpf<Wrapping<u64>>>(log_domain_size);
         }
     }