Browse Source

make dpf traits and dummy generic wrt. output type

Lennart Braun 2 years ago
parent
commit
5b086ad023
3 changed files with 79 additions and 37 deletions
  1. 1 0
      dpf/Cargo.toml
  2. 40 22
      dpf/src/mpdpf.rs
  3. 38 15
      dpf/src/spdpf.rs

+ 1 - 0
dpf/Cargo.toml

@@ -7,3 +7,4 @@ edition = "2021"
 
 [dependencies]
 rand = "0.8.5"
+num = "0.4.0"

+ 40 - 22
dpf/src/mpdpf.rs

@@ -1,3 +1,8 @@
+use core::fmt::Debug;
+use core::marker::PhantomData;
+use core::ops::Add;
+use num::traits::Zero;
+
 pub trait MultiPointDpfKey {
     fn get_party_id(&self) -> usize;
     fn get_log_domain_size(&self) -> u64;
@@ -6,11 +11,15 @@ pub trait MultiPointDpfKey {
 
 pub trait MultiPointDpf {
     type Key: Clone + MultiPointDpfKey;
+    type Value: Add<Output = Self::Value> + Copy + Debug + Eq + Zero;
 
-    fn generate_keys(log_domain_size: u64, alphas: &[u64], betas: &[u64])
-        -> (Self::Key, Self::Key);
-    fn evaluate_at(key: &Self::Key, index: u64) -> u64;
-    fn evaluate_domain(key: &Self::Key) -> Vec<u64> {
+    fn generate_keys(
+        log_domain_size: u64,
+        alphas: &[u64],
+        betas: &[Self::Value],
+    ) -> (Self::Key, Self::Key);
+    fn evaluate_at(key: &Self::Key, index: u64) -> Self::Value;
+    fn evaluate_domain(key: &Self::Key) -> Vec<Self::Value> {
         (0..(1 << key.get_log_domain_size()))
             .map(|x| Self::evaluate_at(&key, x))
             .collect()
@@ -18,15 +27,15 @@ pub trait MultiPointDpf {
 }
 
 #[derive(Clone, Debug)]
-pub struct DummyMpDpfKey {
+pub struct DummyMpDpfKey<V> {
     party_id: usize,
     log_domain_size: u64,
     number_points: usize,
     alphas: Vec<u64>,
-    betas: Vec<u64>,
+    betas: Vec<V>,
 }
 
-impl MultiPointDpfKey for DummyMpDpfKey {
+impl<V> MultiPointDpfKey for DummyMpDpfKey<V> {
     fn get_party_id(&self) -> usize {
         self.party_id
     }
@@ -38,16 +47,21 @@ impl MultiPointDpfKey for DummyMpDpfKey {
     }
 }
 
-pub struct DummyMpDpf {}
+pub struct DummyMpDpf<V>
+where
+    V: Add<Output = V> + Copy + Debug + Eq + Zero,
+{
+    phantom: PhantomData<V>,
+}
 
-impl MultiPointDpf for DummyMpDpf {
-    type Key = DummyMpDpfKey;
+impl<V> MultiPointDpf for DummyMpDpf<V>
+where
+    V: Add<Output = V> + Copy + Debug + Eq + Zero,
+{
+    type Key = DummyMpDpfKey<V>;
+    type Value = V;
 
-    fn generate_keys(
-        log_domain_size: u64,
-        alphas: &[u64],
-        betas: &[u64],
-    ) -> (Self::Key, Self::Key) {
+    fn generate_keys(log_domain_size: u64, alphas: &[u64], betas: &[V]) -> (Self::Key, Self::Key) {
         assert_eq!(
             alphas.len(),
             betas.len(),
@@ -77,14 +91,14 @@ impl MultiPointDpf for DummyMpDpf {
         )
     }
 
-    fn evaluate_at(key: &Self::Key, index: u64) -> u64 {
+    fn evaluate_at(key: &Self::Key, index: u64) -> V {
         if key.get_party_id() == 0 {
             match key.alphas.binary_search(&index) {
                 Ok(i) => key.betas[i],
-                Err(_) => 0,
+                Err(_) => V::zero(),
             }
         } else {
-            0
+            V::zero()
         }
     }
 }
@@ -92,9 +106,13 @@ impl MultiPointDpf for DummyMpDpf {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use rand::distributions::{Distribution, Standard};
     use rand::{thread_rng, Rng};
 
-    fn test_mpdpf_with_param<MPDPF: MultiPointDpf>(log_domain_size: u64, number_points: usize) {
+    fn test_mpdpf_with_param<MPDPF: MultiPointDpf>(log_domain_size: u64, number_points: usize)
+    where
+        Standard: Distribution<MPDPF::Value>,
+    {
         assert!(number_points <= (1 << log_domain_size));
         let domain_size = 1 << log_domain_size;
         let alphas = {
@@ -104,7 +122,7 @@ mod tests {
             alphas.sort();
             alphas
         };
-        let betas: Vec<u64> = (0..number_points).map(|_| thread_rng().gen()).collect();
+        let betas: Vec<MPDPF::Value> = (0..number_points).map(|_| thread_rng().gen()).collect();
         let (key_0, key_1) = MPDPF::generate_keys(log_domain_size, &alphas, &betas);
 
         let out_0 = MPDPF::evaluate_domain(&key_0);
@@ -114,7 +132,7 @@ mod tests {
             assert_eq!(value, out_0[i as usize] + out_1[i as usize]);
             let expected_result = match alphas.binary_search(&i) {
                 Ok(i) => betas[i],
-                Err(_) => 0,
+                Err(_) => MPDPF::Value::zero(),
             };
             assert_eq!(value, expected_result);
         }
@@ -124,7 +142,7 @@ mod tests {
     fn test_mpdpf() {
         for log_domain_size in 5..10 {
             for log_number_points in 0..5 {
-                test_mpdpf_with_param::<DummyMpDpf>(log_domain_size, 1 << log_number_points);
+                test_mpdpf_with_param::<DummyMpDpf<u64>>(log_domain_size, 1 << log_number_points);
             }
         }
     }

+ 38 - 15
dpf/src/spdpf.rs

@@ -1,3 +1,8 @@
+use core::fmt::Debug;
+use core::marker::PhantomData;
+use core::ops::Add;
+use num::traits::Zero;
+
 pub trait SinglePointDpfKey {
     fn get_party_id(&self) -> usize;
     fn get_log_domain_size(&self) -> u64;
@@ -5,10 +10,12 @@ pub trait SinglePointDpfKey {
 
 pub trait SinglePointDpf {
     type Key: Copy + SinglePointDpfKey;
+    type Value: Add<Output = Self::Value> + Copy + Debug + Eq + Zero;
 
-    fn generate_keys(log_domain_size: u64, alpha: u64, beta: u64) -> (Self::Key, Self::Key);
-    fn evaluate_at(key: &Self::Key, index: u64) -> u64;
-    fn evaluate_domain(key: &Self::Key) -> Vec<u64> {
+    fn generate_keys(log_domain_size: u64, alpha: u64, beta: Self::Value)
+        -> (Self::Key, Self::Key);
+    fn evaluate_at(key: &Self::Key, index: u64) -> Self::Value;
+    fn evaluate_domain(key: &Self::Key) -> Vec<Self::Value> {
         (0..(1 << key.get_log_domain_size()))
             .map(|x| Self::evaluate_at(&key, x))
             .collect()
@@ -16,14 +23,17 @@ pub trait SinglePointDpf {
 }
 
 #[derive(Clone, Copy, Debug)]
-pub struct DummySpDpfKey {
+pub struct DummySpDpfKey<V: Copy> {
     party_id: usize,
     log_domain_size: u64,
     alpha: u64,
-    beta: u64,
+    beta: V,
 }
 
-impl SinglePointDpfKey for DummySpDpfKey {
+impl<V> SinglePointDpfKey for DummySpDpfKey<V>
+where
+    V: Copy,
+{
     fn get_party_id(&self) -> usize {
         self.party_id
     }
@@ -32,12 +42,21 @@ impl SinglePointDpfKey for DummySpDpfKey {
     }
 }
 
-pub struct DummySpDpf {}
+pub struct DummySpDpf<V>
+where
+    V: Add<Output = V> + Copy + Debug + Eq + Zero,
+{
+    phantom: PhantomData<V>,
+}
 
-impl SinglePointDpf for DummySpDpf {
-    type Key = DummySpDpfKey;
+impl<V> SinglePointDpf for DummySpDpf<V>
+where
+    V: Add<Output = V> + Copy + Debug + Eq + Zero,
+{
+    type Key = DummySpDpfKey<V>;
+    type Value = V;
 
-    fn generate_keys(log_domain_size: u64, alpha: u64, beta: u64) -> (Self::Key, Self::Key) {
+    fn generate_keys(log_domain_size: u64, alpha: u64, beta: V) -> (Self::Key, Self::Key) {
         assert!(alpha < (1 << log_domain_size));
         (
             DummySpDpfKey {
@@ -55,11 +74,11 @@ impl SinglePointDpf for DummySpDpf {
         )
     }
 
-    fn evaluate_at(key: &Self::Key, index: u64) -> u64 {
+    fn evaluate_at(key: &Self::Key, index: u64) -> V {
         if key.get_party_id() == 0 && index == key.alpha {
             key.beta
         } else {
-            0
+            V::zero()
         }
     }
 }
@@ -67,9 +86,13 @@ impl SinglePointDpf for DummySpDpf {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use rand::distributions::{Distribution, Standard};
     use rand::{thread_rng, Rng};
 
-    fn test_spdpf_with_param<SPDPF: SinglePointDpf>(log_domain_size: u64) {
+    fn test_spdpf_with_param<SPDPF: SinglePointDpf>(log_domain_size: u64)
+    where
+        Standard: Distribution<SPDPF::Value>,
+    {
         let domain_size = 1 << log_domain_size;
         let alpha = thread_rng().gen_range(0..domain_size);
         let beta = thread_rng().gen();
@@ -83,7 +106,7 @@ mod tests {
             if i == alpha {
                 assert_eq!(value, beta);
             } else {
-                assert_eq!(value, 0);
+                assert_eq!(value, SPDPF::Value::zero());
             }
         }
     }
@@ -91,7 +114,7 @@ mod tests {
     #[test]
     fn test_spdpf() {
         for log_domain_size in 5..10 {
-            test_spdpf_with_param::<DummySpDpf>(log_domain_size);
+            test_spdpf_with_param::<DummySpDpf<u64>>(log_domain_size);
         }
     }
 }