Переглянути джерело

dpf: change mpdpf interface

arbitrary domain sizes + state
Lennart Braun 2 роки тому
батько
коміт
df33239024
1 змінених файлів з 93 додано та 47 видалено
  1. 93 47
      dpf/src/mpdpf.rs

+ 93 - 47
dpf/src/mpdpf.rs

@@ -12,7 +12,7 @@ use cuckoo::{
 
 pub trait MultiPointDpfKey: Clone + Debug {
     fn get_party_id(&self) -> usize;
-    fn get_log_domain_size(&self) -> u32;
+    fn get_domain_size(&self) -> usize;
     fn get_number_points(&self) -> usize;
 }
 
@@ -20,15 +20,14 @@ pub trait MultiPointDpf {
     type Key: MultiPointDpfKey;
     type Value: Add<Output = Self::Value> + Copy + Debug + Eq + Zero;
 
-    fn generate_keys(
-        log_domain_size: u32,
-        alphas: &[u64],
-        betas: &[Self::Value],
-    ) -> (Self::Key, Self::Key);
-    fn evaluate_at(key: &Self::Key, index: u64) -> Self::Value;
-    fn evaluate_domain(key: &Self::Key) -> Vec<Self::Value> {
-        (0..(1 << key.get_log_domain_size()))
-            .map(|x| Self::evaluate_at(&key, x))
+    fn new(domain_size: usize, number_points: usize) -> Self;
+    fn get_domain_size(&self) -> usize;
+    fn get_number_points(&self) -> usize;
+    fn generate_keys(&self, alphas: &[u64], betas: &[Self::Value]) -> (Self::Key, Self::Key);
+    fn evaluate_at(&self, key: &Self::Key, index: u64) -> Self::Value;
+    fn evaluate_domain(&self, key: &Self::Key) -> Vec<Self::Value> {
+        (0..key.get_domain_size())
+            .map(|x| self.evaluate_at(&key, x as u64))
             .collect()
     }
 }
@@ -36,7 +35,7 @@ pub trait MultiPointDpf {
 #[derive(Clone, Debug)]
 pub struct DummyMpDpfKey<V: Copy + Debug> {
     party_id: usize,
-    log_domain_size: u32,
+    domain_size: usize,
     number_points: usize,
     alphas: Vec<u64>,
     betas: Vec<V>,
@@ -49,8 +48,8 @@ where
     fn get_party_id(&self) -> usize {
         self.party_id
     }
-    fn get_log_domain_size(&self) -> u32 {
-        self.log_domain_size
+    fn get_domain_size(&self) -> usize {
+        self.domain_size
     }
     fn get_number_points(&self) -> usize {
         self.number_points
@@ -61,6 +60,8 @@ pub struct DummyMpDpf<V>
 where
     V: Add<Output = V> + Copy + Debug + Eq + Zero,
 {
+    domain_size: usize,
+    number_points: usize,
     phantom: PhantomData<V>,
 }
 
@@ -71,37 +72,61 @@ where
     type Key = DummyMpDpfKey<V>;
     type Value = V;
 
-    fn generate_keys(log_domain_size: u32, alphas: &[u64], betas: &[V]) -> (Self::Key, Self::Key) {
+    fn new(domain_size: usize, number_points: usize) -> Self {
+        Self {
+            domain_size,
+            number_points,
+            phantom: PhantomData,
+        }
+    }
+
+    fn get_domain_size(&self) -> usize {
+        self.domain_size
+    }
+
+    fn get_number_points(&self) -> usize {
+        self.number_points
+    }
+
+    fn generate_keys(&self, alphas: &[u64], betas: &[V]) -> (Self::Key, Self::Key) {
+        assert_eq!(
+            alphas.len(),
+            self.number_points,
+            "number of points does not match constructor argument"
+        );
         assert_eq!(
             alphas.len(),
             betas.len(),
             "alphas and betas must be the same size"
         );
         assert!(
-            alphas.iter().all(|alpha| alpha < &(1 << log_domain_size)),
+            alphas
+                .iter()
+                .all(|&alpha| alpha < (self.domain_size as u64)),
             "all alphas must be in the domain"
         );
         assert!(alphas.windows(2).all(|w| w[0] <= w[1]));
-        let number_points = alphas.len();
         (
             DummyMpDpfKey {
                 party_id: 0,
-                log_domain_size,
-                number_points,
+                domain_size: self.domain_size,
+                number_points: self.number_points,
                 alphas: alphas.iter().copied().collect(),
                 betas: betas.iter().copied().collect(),
             },
             DummyMpDpfKey {
                 party_id: 1,
-                log_domain_size,
-                number_points,
+                domain_size: self.domain_size,
+                number_points: self.number_points,
                 alphas: alphas.iter().copied().collect(),
                 betas: betas.iter().copied().collect(),
             },
         )
     }
 
-    fn evaluate_at(key: &Self::Key, index: u64) -> V {
+    fn evaluate_at(&self, key: &Self::Key, index: u64) -> V {
+        assert_eq!(self.domain_size, key.domain_size);
+        assert_eq!(self.number_points, key.number_points);
         if key.get_party_id() == 0 {
             match key.alphas.binary_search(&index) {
                 Ok(i) => key.betas[i],
@@ -119,7 +144,7 @@ where
     H: HashFunction<u32>,
 {
     party_id: usize,
-    log_domain_size: u32,
+    domain_size: usize,
     number_points: usize,
     spdpf_keys: Vec<Option<SPDPF::Key>>,
     cuckoo_parameters: CuckooParameters<H, u32>,
@@ -144,8 +169,8 @@ where
         )?;
         write!(
             f,
-            "{}log_domain_size: {:?},{}",
-            indentation, self.log_domain_size, newline
+            "{}domain_size: {:?},{}",
+            indentation, self.domain_size, newline
         )?;
         write!(
             f,
@@ -178,7 +203,7 @@ where
     fn clone(&self) -> Self {
         Self {
             party_id: self.party_id,
-            log_domain_size: self.log_domain_size,
+            domain_size: self.domain_size,
             number_points: self.number_points,
             spdpf_keys: self.spdpf_keys.clone(),
             cuckoo_parameters: self.cuckoo_parameters.clone(),
@@ -194,8 +219,8 @@ where
     fn get_party_id(&self) -> usize {
         self.party_id
     }
-    fn get_log_domain_size(&self) -> u32 {
-        self.log_domain_size
+    fn get_domain_size(&self) -> usize {
+        self.domain_size
     }
     fn get_number_points(&self) -> usize {
         self.number_points
@@ -208,6 +233,8 @@ where
     SPDPF: SinglePointDpf<Value = V>,
     H: HashFunction<u32>,
 {
+    domain_size: usize,
+    number_points: usize,
     phantom_v: PhantomData<V>,
     phantom_s: PhantomData<SPDPF>,
     phantom_h: PhantomData<H>,
@@ -222,21 +249,35 @@ where
     type Key = SmartMpDpfKey<SPDPF, H>;
     type Value = V;
 
-    fn generate_keys(
-        log_domain_size: u32,
-        alphas: &[u64],
-        betas: &[Self::Value],
-    ) -> (Self::Key, Self::Key) {
-        assert!(log_domain_size < u32::BITS);
+    fn new(domain_size: usize, number_points: usize) -> Self {
+        assert!(domain_size < (1 << u32::BITS));
+        Self {
+            domain_size,
+            number_points,
+            phantom_v: PhantomData,
+            phantom_s: PhantomData,
+            phantom_h: PhantomData,
+        }
+    }
+
+    fn get_domain_size(&self) -> usize {
+        self.domain_size
+    }
+
+    fn get_number_points(&self) -> usize {
+        self.domain_size
+    }
+
+    fn generate_keys(&self, alphas: &[u64], betas: &[Self::Value]) -> (Self::Key, Self::Key) {
         assert_eq!(alphas.len(), betas.len());
         assert!(alphas.windows(2).all(|w| w[0] < w[1]));
-        assert!(alphas.iter().all(|&alpha| alpha < (1 << log_domain_size)));
+        assert!(alphas.iter().all(|&alpha| alpha < self.domain_size as u64));
         let number_points = alphas.len();
 
         let cuckoo_parameters = CuckooParameters::<H, u32>::sample(number_points);
         let hasher = CuckooHasher::<H, u32>::new(cuckoo_parameters);
         let (cuckoo_table_items, cuckoo_table_indices) = hasher.cuckoo_hash_items(alphas);
-        let simple_htable = hasher.hash_domain_into_buckets(1 << log_domain_size);
+        let simple_htable = hasher.hash_domain_into_buckets(self.domain_size as u64);
 
         let pos = |bucket_i: usize, item: u64| -> u64 {
             let idx = simple_htable[bucket_i].partition_point(|x| x < &item);
@@ -284,14 +325,14 @@ where
         (
             SmartMpDpfKey::<SPDPF, H> {
                 party_id: 0,
-                log_domain_size,
+                domain_size: self.domain_size,
                 number_points,
                 spdpf_keys: keys_0,
                 cuckoo_parameters,
             },
             SmartMpDpfKey::<SPDPF, H> {
                 party_id: 1,
-                log_domain_size,
+                domain_size: self.domain_size,
                 number_points,
                 spdpf_keys: keys_1,
                 cuckoo_parameters,
@@ -299,14 +340,16 @@ where
         )
     }
 
-    fn evaluate_at(key: &Self::Key, index: u64) -> Self::Value {
-        let domain_size = 1 << key.log_domain_size;
-        assert!(index < domain_size);
+    fn evaluate_at(&self, key: &Self::Key, index: u64) -> Self::Value {
+        assert_eq!(self.domain_size, key.domain_size);
+        assert_eq!(self.number_points, key.number_points);
+        assert_eq!(key.domain_size, self.domain_size);
+        assert!(index < self.domain_size as u64);
 
         let hasher = CuckooHasher::<H, u32>::new(key.cuckoo_parameters);
 
         let hashes = hasher.hash_items(&[index]);
-        let simple_htable = hasher.hash_domain_into_buckets(domain_size);
+        let simple_htable = hasher.hash_domain_into_buckets(self.domain_size as u64);
 
         let pos = |bucket_i: usize, item: u64| -> u64 {
             let idx = simple_htable[bucket_i].partition_point(|x| x < &item);
@@ -348,8 +391,10 @@ where
         output
     }
 
-    fn evaluate_domain(key: &Self::Key) -> Vec<Self::Value> {
-        let domain_size = 1 << key.log_domain_size;
+    fn evaluate_domain(&self, key: &Self::Key) -> Vec<Self::Value> {
+        assert_eq!(self.domain_size, key.domain_size);
+        assert_eq!(self.number_points, key.number_points);
+        let domain_size = self.domain_size as u64;
 
         let hasher = CuckooHasher::<H, u32>::new(key.cuckoo_parameters);
         let hashes = hasher.hash_domain(domain_size);
@@ -428,12 +473,13 @@ mod tests {
             alphas
         };
         let betas: Vec<MPDPF::Value> = (0..number_points).map(|_| thread_rng().gen()).collect();
-        let (key_0, key_1) = MPDPF::generate_keys(log_domain_size, &alphas, &betas);
+        let mpdpf = MPDPF::new(domain_size as usize, number_points);
+        let (key_0, key_1) = mpdpf.generate_keys(&alphas, &betas);
 
-        let out_0 = MPDPF::evaluate_domain(&key_0);
-        let out_1 = MPDPF::evaluate_domain(&key_1);
+        let out_0 = mpdpf.evaluate_domain(&key_0);
+        let out_1 = mpdpf.evaluate_domain(&key_1);
         for i in 0..domain_size {
-            let value = MPDPF::evaluate_at(&key_0, i) + MPDPF::evaluate_at(&key_1, i);
+            let value = mpdpf.evaluate_at(&key_0, i) + mpdpf.evaluate_at(&key_1, i);
             assert_eq!(value, out_0[i as usize] + out_1[i as usize]);
             let expected_result = match alphas.binary_search(&i) {
                 Ok(i) => betas[i],