Explorar o código

dpf: add precomputation for mpdpf

Lennart Braun %!s(int64=2) %!d(string=hai) anos
pai
achega
43a833559b
Modificáronse 1 ficheiros con 117 adicións e 34 borrados
  1. 117 34
      dpf/src/mpdpf.rs

+ 117 - 34
dpf/src/mpdpf.rs

@@ -23,6 +23,7 @@ pub trait MultiPointDpf {
     fn new(domain_size: usize, number_points: usize) -> Self;
     fn get_domain_size(&self) -> usize;
     fn get_number_points(&self) -> usize;
+    fn precompute(&mut self) {}
     fn generate_keys(&self, alphas: &[u64], betas: &[Self::Value]) -> (Self::Key, Self::Key);
     fn evaluate_at(&self, key: &Self::Key, index: u64) -> Self::Value;
     fn evaluate_domain(&self, key: &Self::Key) -> Vec<Self::Value> {
@@ -227,6 +228,15 @@ where
     }
 }
 
+struct SmartMpDpfPrecomputationData<H: HashFunction<u32>> {
+    pub cuckoo_parameters: CuckooParameters<H, u32>,
+    pub hasher: CuckooHasher<H, u32>,
+    pub hashes: [Vec<u32>; CUCKOO_NUMBER_HASH_FUNCTIONS],
+    pub simple_htable: Vec<Vec<u64>>,
+    pub bucket_sizes: Vec<usize>,
+    pub position_map_lookup_table: Vec<[(usize, usize); 3]>,
+}
+
 pub struct SmartMpDpf<V, SPDPF, H>
 where
     V: Add<Output = V> + AddAssign + Copy + Debug + Eq + Zero,
@@ -235,11 +245,52 @@ where
 {
     domain_size: usize,
     number_points: usize,
+    precomputation_data: Option<SmartMpDpfPrecomputationData<H>>,
     phantom_v: PhantomData<V>,
     phantom_s: PhantomData<SPDPF>,
     phantom_h: PhantomData<H>,
 }
 
+impl<V, SPDPF, H> SmartMpDpf<V, SPDPF, H>
+where
+    V: Add<Output = V> + AddAssign + Copy + Debug + Eq + Zero,
+    SPDPF: SinglePointDpf<Value = V>,
+    H: HashFunction<u32>,
+{
+    fn precompute_hashes(
+        domain_size: usize,
+        number_points: usize,
+    ) -> SmartMpDpfPrecomputationData<H> {
+        let seed: [u8; 32] = [
+            42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42,
+            42, 42, 42, 42, 42, 42, 42, 42, 42, 42,
+        ];
+        let cuckoo_parameters = CuckooParameters::from_seed(number_points, seed);
+        let hasher = CuckooHasher::<H, u32>::new(cuckoo_parameters);
+        let hashes = hasher.hash_domain(domain_size as u64);
+        let simple_htable =
+            hasher.hash_domain_into_buckets_given_hashes(domain_size as u64, &hashes);
+        let bucket_sizes = CuckooHasher::<H, u32>::compute_bucket_sizes(&simple_htable);
+        let position_map_lookup_table =
+            CuckooHasher::<H, u32>::compute_pos_lookup_table(domain_size as u64, &simple_htable);
+        SmartMpDpfPrecomputationData {
+            cuckoo_parameters,
+            hasher,
+            hashes,
+            simple_htable,
+            bucket_sizes,
+            position_map_lookup_table,
+        }
+    }
+
+    pub fn precompute(&mut self) {
+        self.precomputation_data = Some(Self::precompute_hashes(
+            self.domain_size,
+            self.number_points,
+        ));
+    }
+}
+
 impl<V, SPDPF, H> MultiPointDpf for SmartMpDpf<V, SPDPF, H>
 where
     V: Add<Output = V> + AddAssign + Copy + Debug + Eq + Zero,
@@ -254,6 +305,7 @@ where
         Self {
             domain_size,
             number_points,
+            precomputation_data: None,
             phantom_v: PhantomData,
             phantom_s: PhantomData,
             phantom_h: PhantomData,
@@ -274,40 +326,43 @@ where
         assert!(alphas.iter().all(|&alpha| alpha < self.domain_size as u64));
         let number_points = alphas.len();
 
-        let cuckoo_parameters = CuckooParameters::<H, u32>::sample(number_points);
-        let hasher = CuckooHasher::<H, u32>::new(cuckoo_parameters);
+        // if not data is precomputed, do it now
+        let mut precomputation_data_fresh: Option<SmartMpDpfPrecomputationData<H>> = None;
+        if self.precomputation_data.is_none() {
+            precomputation_data_fresh = Some(Self::precompute_hashes(
+                self.domain_size,
+                self.number_points,
+            ));
+        }
+        // select either the precomputed or the freshly computed data
+        let precomputation_data = self
+            .precomputation_data
+            .as_ref()
+            .unwrap_or_else(|| precomputation_data_fresh.as_ref().unwrap());
+        let cuckoo_parameters = &precomputation_data.cuckoo_parameters;
+        let hasher = &precomputation_data.hasher;
         let (cuckoo_table_items, cuckoo_table_indices) = hasher.cuckoo_hash_items(alphas);
-        let simple_htable = hasher.hash_domain_into_buckets(self.domain_size as u64);
-
+        let position_map_lookup_table = &precomputation_data.position_map_lookup_table;
         let pos = |bucket_i: usize, item: u64| -> u64 {
-            let idx = simple_htable[bucket_i].partition_point(|x| x < &item);
-            assert!(idx != simple_htable[bucket_i].len());
-            assert_eq!(item, simple_htable[bucket_i][idx]);
-            assert!(idx == 0 || simple_htable[bucket_i][idx - 1] != item);
-            idx as u64
+            CuckooHasher::<H, u32>::pos_lookup(position_map_lookup_table, bucket_i, item)
         };
 
         let number_buckets = hasher.get_parameters().get_number_buckets();
+        let bucket_sizes = &precomputation_data.bucket_sizes;
 
         let mut keys_0 = Vec::<Option<SPDPF::Key>>::with_capacity(number_buckets);
         let mut keys_1 = Vec::<Option<SPDPF::Key>>::with_capacity(number_buckets);
-        let mut bucket_sizes = vec![0u64; number_buckets];
 
         for bucket_i in 0..number_buckets {
-            let bucket_size = simple_htable[bucket_i].len() as u64;
-
-            // remember the bucket size
-            bucket_sizes[bucket_i] = bucket_size;
-
             // if bucket is empty, add invalid dummy keys to the arrays to make the
             // indices work
-            if bucket_size == 0 {
+            if bucket_sizes[bucket_i] == 0 {
                 keys_0.push(None);
                 keys_1.push(None);
                 continue;
             }
 
-            let sp_log_domain_size = (bucket_size as f64).log2().ceil() as u64;
+            let sp_log_domain_size = (bucket_sizes[bucket_i] as f64).log2().ceil() as u64;
 
             let (alpha, beta) =
                 if cuckoo_table_items[bucket_i] != CuckooHasher::<H, u32>::UNOCCUPIED {
@@ -328,14 +383,14 @@ where
                 domain_size: self.domain_size,
                 number_points,
                 spdpf_keys: keys_0,
-                cuckoo_parameters,
+                cuckoo_parameters: cuckoo_parameters.clone(),
             },
             SmartMpDpfKey::<SPDPF, H> {
                 party_id: 1,
                 domain_size: self.domain_size,
                 number_points,
                 spdpf_keys: keys_1,
-                cuckoo_parameters,
+                cuckoo_parameters: cuckoo_parameters.clone(),
             },
         )
     }
@@ -396,16 +451,24 @@ where
         assert_eq!(self.number_points, key.number_points);
         let domain_size = self.domain_size as u64;
 
-        let hasher = CuckooHasher::<H, u32>::new(key.cuckoo_parameters);
-        let hashes = hasher.hash_domain(domain_size);
-        let simple_htable = hasher.hash_domain_into_buckets(domain_size);
-
+        // if not data is precomputed, do it now
+        let mut precomputation_data_fresh: Option<SmartMpDpfPrecomputationData<H>> = None;
+        if self.precomputation_data.is_none() {
+            precomputation_data_fresh = Some(Self::precompute_hashes(
+                self.domain_size,
+                self.number_points,
+            ));
+        }
+        // select either the precomputed or the freshly computed data
+        let precomputation_data = self
+            .precomputation_data
+            .as_ref()
+            .unwrap_or_else(|| precomputation_data_fresh.as_ref().unwrap());
+        let hashes = &precomputation_data.hashes;
+        let simple_htable = &precomputation_data.simple_htable;
+        let position_map_lookup_table = &precomputation_data.position_map_lookup_table;
         let pos = |bucket_i: usize, item: u64| -> u64 {
-            let idx = simple_htable[bucket_i].partition_point(|x| x < &item);
-            assert!(idx != simple_htable[bucket_i].len());
-            assert_eq!(item, simple_htable[bucket_i][idx]);
-            assert!(idx == 0 || simple_htable[bucket_i][idx - 1] != item);
-            idx as u64
+            CuckooHasher::<H, u32>::pos_lookup(position_map_lookup_table, bucket_i, item)
         };
 
         let mut outputs = Vec::<Self::Value>::with_capacity(domain_size as usize);
@@ -455,8 +518,11 @@ mod tests {
     use rand::{thread_rng, Rng};
     use std::num::Wrapping;
 
-    fn test_mpdpf_with_param<MPDPF: MultiPointDpf>(log_domain_size: u32, number_points: usize)
-    where
+    fn test_mpdpf_with_param<MPDPF: MultiPointDpf>(
+        log_domain_size: u32,
+        number_points: usize,
+        precomputation: bool,
+    ) where
         Standard: Distribution<MPDPF::Value>,
     {
         let domain_size = (1 << log_domain_size) as u64;
@@ -473,7 +539,10 @@ mod tests {
             alphas
         };
         let betas: Vec<MPDPF::Value> = (0..number_points).map(|_| thread_rng().gen()).collect();
-        let mpdpf = MPDPF::new(domain_size as usize, number_points);
+        let mut mpdpf = MPDPF::new(domain_size as usize, number_points);
+        if precomputation {
+            mpdpf.precompute();
+        }
         let (key_0, key_1) = mpdpf.generate_keys(&alphas, &betas);
 
         let out_0 = mpdpf.evaluate_domain(&key_0);
@@ -490,16 +559,30 @@ mod tests {
     }
 
     #[test]
-    fn test_mpdpf() {
+    fn test_dummy_mpdpf() {
         type Value = Wrapping<u64>;
         for log_domain_size in 5..10 {
             for log_number_points in 0..5 {
-                test_mpdpf_with_param::<DummyMpDpf<Value>>(log_domain_size, 1 << log_number_points);
-                test_mpdpf_with_param::<SmartMpDpf<Value, DummySpDpf<Value>, AesHashFunction<u32>>>(
+                test_mpdpf_with_param::<DummyMpDpf<Value>>(
                     log_domain_size,
                     1 << log_number_points,
+                    false,
                 );
             }
         }
     }
+
+    #[test]
+    fn test_smart_mpdpf() {
+        type Value = Wrapping<u64>;
+        for log_domain_size in 5..7 {
+            for log_number_points in 0..5 {
+                for precomputation in [false, true] {
+                    test_mpdpf_with_param::<
+                        SmartMpDpf<Value, DummySpDpf<Value>, AesHashFunction<u32>>,
+                    >(log_domain_size, 1 << log_number_points, precomputation);
+                }
+            }
+        }
+    }
 }