Browse Source

dpf: parallelize mpdpf evaluate_domain

Lennart Braun 2 years ago
parent
commit
7643524406
2 changed files with 47 additions and 36 deletions
  1. 1 0
      dpf/Cargo.toml
  2. 46 36
      dpf/src/mpdpf.rs

+ 1 - 0
dpf/Cargo.toml

@@ -11,6 +11,7 @@ utils = { path = "../utils" }
 bincode = "2.0.0-rc.2"
 num = "0.4.0"
 rand = "0.8.5"
+rayon = "1.6.1"
 
 [dev-dependencies]
 criterion = "0.4.0"

+ 46 - 36
dpf/src/mpdpf.rs

@@ -4,6 +4,7 @@ use core::fmt::Debug;
 use core::marker::PhantomData;
 use core::ops::{Add, AddAssign};
 use num::traits::Zero;
+use rayon::prelude::*;
 
 use crate::spdpf::SinglePointDpf;
 use cuckoo::{
@@ -297,9 +298,10 @@ where
 
 impl<V, SPDPF, H> SmartMpDpf<V, SPDPF, H>
 where
-    V: Add<Output = V> + AddAssign + Copy + Debug + Eq + Zero,
+    V: Add<Output = V> + AddAssign + Copy + Debug + Eq + Zero + Send + Sync,
     SPDPF: SinglePointDpf<Value = V>,
     H: HashFunction<u16>,
+    H::Description: Sync,
 {
     fn precompute_hashes(
         domain_size: usize,
@@ -334,9 +336,11 @@ where
 
 impl<V, SPDPF, H> MultiPointDpf for SmartMpDpf<V, SPDPF, H>
 where
-    V: Add<Output = V> + AddAssign + Copy + Debug + Eq + Zero,
+    V: Add<Output = V> + AddAssign + Copy + Debug + Eq + Zero + Send + Sync,
     SPDPF: SinglePointDpf<Value = V>,
+    SPDPF::Key: Sync,
     H: HashFunction<u16>,
+    H::Description: Sync,
 {
     type Key = SmartMpDpfKey<SPDPF, H>;
     type Value = V;
@@ -362,10 +366,12 @@ where
     }
 
     fn precompute(&mut self) {
-        self.precomputation_data = Some(Self::precompute_hashes(
-            self.domain_size,
-            self.number_points,
-        ));
+        if self.precomputation_data.is_none() {
+            self.precomputation_data = Some(Self::precompute_hashes(
+                self.domain_size,
+                self.number_points,
+            ));
+        }
     }
 
     fn generate_keys(&self, alphas: &[u64], betas: &[Self::Value]) -> (Self::Key, Self::Key) {
@@ -375,6 +381,7 @@ where
         let number_points = alphas.len();
 
         // if not data is precomputed, do it now
+        // (&self is not mut, so we cannot store the new data here nor call precompute() ...)
         let mut precomputation_data_fresh: Option<SmartMpDpfPrecomputationData<H>> = None;
         if self.precomputation_data.is_none() {
             precomputation_data_fresh = Some(Self::precompute_hashes(
@@ -498,6 +505,7 @@ where
         let domain_size = self.domain_size as u64;
 
         // if not data is precomputed, do it now
+        // (&self is not mut, so we cannot store the new data here nor call precompute() ...)
         let mut precomputation_data_fresh: Option<SmartMpDpfPrecomputationData<H>> = None;
         if self.precomputation_data.is_none() {
             precomputation_data_fresh = Some(Self::precompute_hashes(
@@ -517,11 +525,9 @@ where
             CuckooHasher::<H, u16>::pos_lookup(position_map_lookup_table, bucket_i, item)
         };
 
-        let mut outputs = Vec::<Self::Value>::with_capacity(domain_size as usize);
-
         let sp_dpf_full_domain_evaluations: Vec<Vec<V>> = key
             .spdpf_keys
-            .iter()
+            .par_iter()
             .map(|sp_key_opt| {
                 sp_key_opt
                     .as_ref()
@@ -532,35 +538,39 @@ where
         let spdpf_evaluate_at =
             |hash: usize, index| sp_dpf_full_domain_evaluations[hash][pos(hash, index) as usize];
 
-        for index in 0..domain_size {
-            outputs.push({
-                let hash = H::hash_value_as_usize(hashes[0][index as usize]);
-                debug_assert!(key.spdpf_keys[hash].is_some());
-                debug_assert_eq!(simple_htable[hash][pos(hash, index) as usize], index);
-                spdpf_evaluate_at(hash, index)
-            });
-
-            // prevent adding the same term multiple times when we have collisions
-            let mut hash_bit_map = [0u8; 2];
-            if hashes[0][index as usize] != hashes[1][index as usize] {
-                hash_bit_map[0] = 1;
-            }
-            if hashes[0][index as usize] != hashes[2][index as usize]
-                && hashes[1][index as usize] != hashes[2][index as usize]
-            {
-                hash_bit_map[1] = 1;
-            }
+        let outputs: Vec<_> = (0..domain_size)
+            .into_par_iter()
+            .map(|index| {
+                let mut output = {
+                    let hash = H::hash_value_as_usize(hashes[0][index as usize]);
+                    debug_assert!(key.spdpf_keys[hash].is_some());
+                    debug_assert_eq!(simple_htable[hash][pos(hash, index) as usize], index);
+                    spdpf_evaluate_at(hash, index)
+                };
 
-            for j in 1..CUCKOO_NUMBER_HASH_FUNCTIONS {
-                if hash_bit_map[j - 1] == 0 {
-                    continue;
+                // prevent adding the same term multiple times when we have collisions
+                let mut hash_bit_map = [0u8; 2];
+                if hashes[0][index as usize] != hashes[1][index as usize] {
+                    hash_bit_map[0] = 1;
                 }
-                let hash = H::hash_value_as_usize(hashes[j][index as usize]);
-                debug_assert!(key.spdpf_keys[hash].is_some());
-                debug_assert_eq!(simple_htable[hash][pos(hash, index) as usize], index);
-                outputs[index as usize] += spdpf_evaluate_at(hash, index)
-            }
-        }
+                if hashes[0][index as usize] != hashes[2][index as usize]
+                    && hashes[1][index as usize] != hashes[2][index as usize]
+                {
+                    hash_bit_map[1] = 1;
+                }
+
+                for j in 1..CUCKOO_NUMBER_HASH_FUNCTIONS {
+                    if hash_bit_map[j - 1] == 0 {
+                        continue;
+                    }
+                    let hash = H::hash_value_as_usize(hashes[j][index as usize]);
+                    debug_assert!(key.spdpf_keys[hash].is_some());
+                    debug_assert_eq!(simple_htable[hash][pos(hash, index) as usize], index);
+                    output += spdpf_evaluate_at(hash, index);
+                }
+                output
+            })
+            .collect();
 
         outputs
     }