Browse Source

add mpdpf from spdpf/cuckoo implementation

Lennart Braun 2 years ago
parent
commit
0a5acaa2ec
2 changed files with 274 additions and 13 deletions
  1. 1 0
      dpf/Cargo.toml
  2. 273 13
      dpf/src/mpdpf.rs

+ 1 - 0
dpf/Cargo.toml

@@ -6,5 +6,6 @@ edition = "2021"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
+cuckoo = { path = "../cuckoo" }
 rand = "0.8.5"
 num = "0.4.0"

+ 273 - 13
dpf/src/mpdpf.rs

@@ -1,16 +1,23 @@
+use core::fmt;
 use core::fmt::Debug;
 use core::marker::PhantomData;
-use core::ops::Add;
+use core::ops::{Add, AddAssign};
 use num::traits::Zero;
 
-pub trait MultiPointDpfKey {
+use crate::spdpf::SinglePointDpf;
+use cuckoo::{
+    cuckoo::Hasher as CuckooHasher, cuckoo::Parameters as CuckooParameters,
+    cuckoo::NUMBER_HASH_FUNCTIONS as CUCKOO_NUMBER_HASH_FUNCTIONS, hash::HashFunction,
+};
+
+pub trait MultiPointDpfKey: Clone + Debug {
     fn get_party_id(&self) -> usize;
     fn get_log_domain_size(&self) -> u64;
     fn get_number_points(&self) -> usize;
 }
 
 pub trait MultiPointDpf {
-    type Key: Clone + MultiPointDpfKey;
+    type Key: MultiPointDpfKey;
     type Value: Add<Output = Self::Value> + Copy + Debug + Eq + Zero;
 
     fn generate_keys(
@@ -27,7 +34,7 @@ pub trait MultiPointDpf {
 }
 
 #[derive(Clone, Debug)]
-pub struct DummyMpDpfKey<V> {
+pub struct DummyMpDpfKey<V: Copy + Debug> {
     party_id: usize,
     log_domain_size: u64,
     number_points: usize,
@@ -35,7 +42,10 @@ pub struct DummyMpDpfKey<V> {
     betas: Vec<V>,
 }
 
-impl<V> MultiPointDpfKey for DummyMpDpfKey<V> {
+impl<V> MultiPointDpfKey for DummyMpDpfKey<V>
+where
+    V: Copy + Debug,
+{
     fn get_party_id(&self) -> usize {
         self.party_id
     }
@@ -103,23 +113,268 @@ where
     }
 }
 
+pub struct SmartMpDpfKey<SPDPF, H>
+where
+    SPDPF: SinglePointDpf,
+    H: HashFunction,
+{
+    party_id: usize,
+    log_domain_size: u64,
+    number_points: usize,
+    spdpf_keys: Vec<Option<SPDPF::Key>>,
+    cuckoo_parameters: CuckooParameters<H>,
+}
+
+impl<SPDPF, H> Debug for SmartMpDpfKey<SPDPF, H>
+where
+    SPDPF: SinglePointDpf,
+    H: HashFunction,
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
+        let (newline, indentation) = if f.alternate() {
+            ("\n", "    ")
+        } else {
+            (" ", "")
+        };
+        write!(f, "SmartMpDpfKey<SPDPF, H>{{{}", newline)?;
+        write!(
+            f,
+            "{}party_id: {:?},{}",
+            indentation, self.party_id, newline
+        )?;
+        write!(
+            f,
+            "{}log_domain_size: {:?},{}",
+            indentation, self.log_domain_size, newline
+        )?;
+        write!(
+            f,
+            "{}number_points: {:?},{}",
+            indentation, self.number_points, newline
+        )?;
+        if f.alternate() {
+            write!(f, "    spdpf_keys:\n")?;
+            for (i, k) in self.spdpf_keys.iter().enumerate() {
+                write!(f, "        spdpf_keys[{}]: {:?}\n", i, k)?;
+            }
+        } else {
+            write!(f, " spdpf_keys: {:?},", self.spdpf_keys)?;
+        }
+        write!(
+            f,
+            "{}cuckoo_parameters: {:?}{}",
+            indentation, self.cuckoo_parameters, newline
+        )?;
+        write!(f, "}}")?;
+        Ok(())
+    }
+}
+
+impl<SPDPF, H> Clone for SmartMpDpfKey<SPDPF, H>
+where
+    SPDPF: SinglePointDpf,
+    H: HashFunction,
+{
+    fn clone(&self) -> Self {
+        Self {
+            party_id: self.party_id,
+            log_domain_size: self.log_domain_size,
+            number_points: self.number_points,
+            spdpf_keys: self.spdpf_keys.clone(),
+            cuckoo_parameters: self.cuckoo_parameters.clone(),
+        }
+    }
+}
+
+impl<SPDPF, H> MultiPointDpfKey for SmartMpDpfKey<SPDPF, H>
+where
+    SPDPF: SinglePointDpf,
+    H: HashFunction,
+{
+    fn get_party_id(&self) -> usize {
+        self.party_id
+    }
+    fn get_log_domain_size(&self) -> u64 {
+        self.log_domain_size
+    }
+    fn get_number_points(&self) -> usize {
+        self.number_points
+    }
+}
+
+pub struct SmartMpDpf<V, SPDPF, H>
+where
+    V: Add<Output = V> + AddAssign + Copy + Debug + Eq + Zero,
+    SPDPF: SinglePointDpf<Value = V>,
+    H: HashFunction,
+{
+    phantom_v: PhantomData<V>,
+    phantom_s: PhantomData<SPDPF>,
+    phantom_h: PhantomData<H>,
+}
+
+impl<V, SPDPF, H> MultiPointDpf for SmartMpDpf<V, SPDPF, H>
+where
+    V: Add<Output = V> + AddAssign + Copy + Debug + Eq + Zero,
+    SPDPF: SinglePointDpf<Value = V>,
+    H: HashFunction,
+{
+    type Key = SmartMpDpfKey<SPDPF, H>;
+    type Value = V;
+
+    fn generate_keys(
+        log_domain_size: u64,
+        alphas: &[u64],
+        betas: &[Self::Value],
+    ) -> (Self::Key, Self::Key) {
+        assert_eq!(alphas.len(), betas.len());
+        assert!(alphas.windows(2).all(|w| w[0] < w[1]));
+        assert!(alphas.iter().all(|&alpha| alpha < (1 << log_domain_size)));
+        let number_points = alphas.len();
+
+        let cuckoo_parameters = CuckooParameters::<H>::sample(number_points);
+        let hasher = CuckooHasher::<H>::new(cuckoo_parameters);
+        let (cuckoo_table_items, cuckoo_table_indices) = hasher.cuckoo_hash_items(alphas);
+        let simple_htable = hasher.hash_domain_into_buckets(1 << log_domain_size);
+
+        let pos = |bucket_i: usize, item: u64| -> u64 {
+            let idx = simple_htable[bucket_i].partition_point(|x| x < &item);
+            assert!(idx != simple_htable[bucket_i].len());
+            assert_eq!(item, simple_htable[bucket_i][idx]);
+            assert!(idx == 0 || simple_htable[bucket_i][idx - 1] != item);
+            idx as u64
+        };
+
+        let number_buckets = hasher.get_parameters().get_number_buckets();
+
+        let mut keys_0 = Vec::<Option<SPDPF::Key>>::with_capacity(number_buckets);
+        let mut keys_1 = Vec::<Option<SPDPF::Key>>::with_capacity(number_buckets);
+        let mut bucket_sizes = vec![0u64; number_buckets];
+
+        for bucket_i in 0..number_buckets {
+            let bucket_size = simple_htable[bucket_i].len() as u64;
+
+            // remember the bucket size
+            bucket_sizes[bucket_i] = bucket_size;
+
+            // if bucket is empty, add invalid dummy keys to the arrays to make the
+            // indices work
+            if bucket_size == 0 {
+                keys_0.push(None);
+                keys_1.push(None);
+                continue;
+            }
+
+            let sp_log_domain_size = (bucket_size as f64).log2().ceil() as u64;
+
+            let (alpha, beta) = if cuckoo_table_items[bucket_i] != CuckooHasher::<H>::UNOCCUPIED {
+                let alpha = pos(bucket_i, cuckoo_table_items[bucket_i]);
+                let beta = betas[cuckoo_table_indices[bucket_i]];
+                (alpha, beta)
+            } else {
+                (0, V::zero())
+            };
+            let (key_0, key_1) = SPDPF::generate_keys(sp_log_domain_size, alpha, beta);
+            keys_0.push(Some(key_0));
+            keys_1.push(Some(key_1));
+        }
+
+        (
+            SmartMpDpfKey::<SPDPF, H> {
+                party_id: 0,
+                log_domain_size,
+                number_points,
+                spdpf_keys: keys_0,
+                cuckoo_parameters,
+            },
+            SmartMpDpfKey::<SPDPF, H> {
+                party_id: 1,
+                log_domain_size,
+                number_points,
+                spdpf_keys: keys_1,
+                cuckoo_parameters,
+            },
+        )
+    }
+
+    fn evaluate_at(key: &Self::Key, index: u64) -> Self::Value {
+        let domain_size = 1 << key.log_domain_size;
+        assert!(index < domain_size);
+
+        let hasher = CuckooHasher::<H>::new(key.cuckoo_parameters);
+
+        let hashes = hasher.hash_items(&[index]);
+        let simple_htable = hasher.hash_domain_into_buckets(domain_size);
+
+        let pos = |bucket_i: usize, item: u64| -> u64 {
+            let idx = simple_htable[bucket_i].partition_point(|x| x < &item);
+            assert!(idx != simple_htable[bucket_i].len());
+            assert_eq!(item, simple_htable[bucket_i][idx]);
+            assert!(idx == 0 || simple_htable[bucket_i][idx - 1] != item);
+            idx as u64
+        };
+        eprintln!(
+            "hashes({}) = ({:#?}, {:#?}, {:#?})",
+            index, hashes[0][0], hashes[1][0], hashes[2][0]
+        );
+        let mut output = {
+            let hash = hashes[0][0] as usize;
+            assert!(key.spdpf_keys[hash].is_some());
+            let sp_key = &key.spdpf_keys[hash].unwrap();
+            assert_eq!(simple_htable[hash][pos(hash, index) as usize], index);
+            SPDPF::evaluate_at(sp_key, pos(hash, index))
+        };
+
+        // prevent adding the same term multiple times when we have collisions
+        let mut hash_bit_map = [0u8; 2];
+        if hashes[0][0] != hashes[1][0] {
+            // hash_bit_map[i] |= 1;
+            hash_bit_map[0] = 1;
+        }
+        if hashes[0][0] != hashes[2][0] && hashes[1][0] != hashes[2][0] {
+            // hash_bit_map[i] |= 2;
+            hash_bit_map[1] = 1;
+        }
+
+        for j in 1..CUCKOO_NUMBER_HASH_FUNCTIONS {
+            if hash_bit_map[j - 1] == 0 {
+                continue;
+            }
+            let hash = hashes[j][0] as usize;
+            assert!(key.spdpf_keys[hash].is_some());
+            let sp_key = &key.spdpf_keys[hash].unwrap();
+            assert_eq!(simple_htable[hash][pos(hash, index) as usize], index);
+            output += SPDPF::evaluate_at(sp_key, pos(hash, index));
+        }
+
+        output
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::spdpf::DummySpDpf;
+    use cuckoo::hash::AesHashFunction;
     use rand::distributions::{Distribution, Standard};
     use rand::{thread_rng, Rng};
+    use std::num::Wrapping;
 
     fn test_mpdpf_with_param<MPDPF: MultiPointDpf>(log_domain_size: u64, number_points: usize)
     where
         Standard: Distribution<MPDPF::Value>,
     {
-        assert!(number_points <= (1 << log_domain_size));
-        let domain_size = 1 << log_domain_size;
+        let domain_size = (1 << log_domain_size) as u64;
+        assert!(number_points <= domain_size as usize);
         let alphas = {
-            let mut alphas: Vec<u64> = (0..number_points)
-                .map(|_| thread_rng().gen_range(0..domain_size))
-                .collect();
-            alphas.sort();
+            let mut alphas = Vec::<u64>::with_capacity(number_points);
+            while alphas.len() < number_points {
+                let x = thread_rng().gen_range(0..domain_size);
+                match alphas.as_slice().binary_search(&x) {
+                    Ok(_) => continue,
+                    Err(i) => alphas.insert(i, x),
+                }
+            }
             alphas
         };
         let betas: Vec<MPDPF::Value> = (0..number_points).map(|_| thread_rng().gen()).collect();
@@ -134,15 +389,20 @@ mod tests {
                 Ok(i) => betas[i],
                 Err(_) => MPDPF::Value::zero(),
             };
-            assert_eq!(value, expected_result);
+            assert_eq!(value, expected_result, "wrong value at index {}", i);
         }
     }
 
     #[test]
     fn test_mpdpf() {
+        type Value = Wrapping<u64>;
         for log_domain_size in 5..10 {
             for log_number_points in 0..5 {
-                test_mpdpf_with_param::<DummyMpDpf<u64>>(log_domain_size, 1 << log_number_points);
+                test_mpdpf_with_param::<DummyMpDpf<Value>>(log_domain_size, 1 << log_number_points);
+                test_mpdpf_with_param::<SmartMpDpf<Value, DummySpDpf<Value>, AesHashFunction>>(
+                    log_domain_size,
+                    1 << log_number_points,
+                );
             }
         }
     }