|
@@ -1,3 +1,8 @@
|
|
|
+use core::fmt::Debug;
|
|
|
+use core::marker::PhantomData;
|
|
|
+use core::ops::Add;
|
|
|
+use num::traits::Zero;
|
|
|
+
|
|
|
pub trait MultiPointDpfKey {
|
|
|
fn get_party_id(&self) -> usize;
|
|
|
fn get_log_domain_size(&self) -> u64;
|
|
@@ -6,11 +11,15 @@ pub trait MultiPointDpfKey {
|
|
|
|
|
|
pub trait MultiPointDpf {
|
|
|
type Key: Clone + MultiPointDpfKey;
|
|
|
+ type Value: Add<Output = Self::Value> + Copy + Debug + Eq + Zero;
|
|
|
|
|
|
- fn generate_keys(log_domain_size: u64, alphas: &[u64], betas: &[u64])
|
|
|
- -> (Self::Key, Self::Key);
|
|
|
- fn evaluate_at(key: &Self::Key, index: u64) -> u64;
|
|
|
- fn evaluate_domain(key: &Self::Key) -> Vec<u64> {
|
|
|
+ fn generate_keys(
|
|
|
+ log_domain_size: u64,
|
|
|
+ alphas: &[u64],
|
|
|
+ betas: &[Self::Value],
|
|
|
+ ) -> (Self::Key, Self::Key);
|
|
|
+ fn evaluate_at(key: &Self::Key, index: u64) -> Self::Value;
|
|
|
+ fn evaluate_domain(key: &Self::Key) -> Vec<Self::Value> {
|
|
|
(0..(1 << key.get_log_domain_size()))
|
|
|
.map(|x| Self::evaluate_at(&key, x))
|
|
|
.collect()
|
|
@@ -18,15 +27,15 @@ pub trait MultiPointDpf {
|
|
|
}
|
|
|
|
|
|
#[derive(Clone, Debug)]
|
|
|
-pub struct DummyMpDpfKey {
|
|
|
+pub struct DummyMpDpfKey<V> {
|
|
|
party_id: usize,
|
|
|
log_domain_size: u64,
|
|
|
number_points: usize,
|
|
|
alphas: Vec<u64>,
|
|
|
- betas: Vec<u64>,
|
|
|
+ betas: Vec<V>,
|
|
|
}
|
|
|
|
|
|
-impl MultiPointDpfKey for DummyMpDpfKey {
|
|
|
+impl<V> MultiPointDpfKey for DummyMpDpfKey<V> {
|
|
|
fn get_party_id(&self) -> usize {
|
|
|
self.party_id
|
|
|
}
|
|
@@ -38,16 +47,21 @@ impl MultiPointDpfKey for DummyMpDpfKey {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-pub struct DummyMpDpf {}
|
|
|
+pub struct DummyMpDpf<V>
|
|
|
+where
|
|
|
+ V: Add<Output = V> + Copy + Debug + Eq + Zero,
|
|
|
+{
|
|
|
+ phantom: PhantomData<V>,
|
|
|
+}
|
|
|
|
|
|
-impl MultiPointDpf for DummyMpDpf {
|
|
|
- type Key = DummyMpDpfKey;
|
|
|
+impl<V> MultiPointDpf for DummyMpDpf<V>
|
|
|
+where
|
|
|
+ V: Add<Output = V> + Copy + Debug + Eq + Zero,
|
|
|
+{
|
|
|
+ type Key = DummyMpDpfKey<V>;
|
|
|
+ type Value = V;
|
|
|
|
|
|
- fn generate_keys(
|
|
|
- log_domain_size: u64,
|
|
|
- alphas: &[u64],
|
|
|
- betas: &[u64],
|
|
|
- ) -> (Self::Key, Self::Key) {
|
|
|
+ fn generate_keys(log_domain_size: u64, alphas: &[u64], betas: &[V]) -> (Self::Key, Self::Key) {
|
|
|
assert_eq!(
|
|
|
alphas.len(),
|
|
|
betas.len(),
|
|
@@ -77,14 +91,14 @@ impl MultiPointDpf for DummyMpDpf {
|
|
|
)
|
|
|
}
|
|
|
|
|
|
- fn evaluate_at(key: &Self::Key, index: u64) -> u64 {
|
|
|
+ fn evaluate_at(key: &Self::Key, index: u64) -> V {
|
|
|
if key.get_party_id() == 0 {
|
|
|
match key.alphas.binary_search(&index) {
|
|
|
Ok(i) => key.betas[i],
|
|
|
- Err(_) => 0,
|
|
|
+ Err(_) => V::zero(),
|
|
|
}
|
|
|
} else {
|
|
|
- 0
|
|
|
+ V::zero()
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -92,9 +106,13 @@ impl MultiPointDpf for DummyMpDpf {
|
|
|
#[cfg(test)]
|
|
|
mod tests {
|
|
|
use super::*;
|
|
|
+ use rand::distributions::{Distribution, Standard};
|
|
|
use rand::{thread_rng, Rng};
|
|
|
|
|
|
- fn test_mpdpf_with_param<MPDPF: MultiPointDpf>(log_domain_size: u64, number_points: usize) {
|
|
|
+ fn test_mpdpf_with_param<MPDPF: MultiPointDpf>(log_domain_size: u64, number_points: usize)
|
|
|
+ where
|
|
|
+ Standard: Distribution<MPDPF::Value>,
|
|
|
+ {
|
|
|
assert!(number_points <= (1 << log_domain_size));
|
|
|
let domain_size = 1 << log_domain_size;
|
|
|
let alphas = {
|
|
@@ -104,7 +122,7 @@ mod tests {
|
|
|
alphas.sort();
|
|
|
alphas
|
|
|
};
|
|
|
- let betas: Vec<u64> = (0..number_points).map(|_| thread_rng().gen()).collect();
|
|
|
+ let betas: Vec<MPDPF::Value> = (0..number_points).map(|_| thread_rng().gen()).collect();
|
|
|
let (key_0, key_1) = MPDPF::generate_keys(log_domain_size, &alphas, &betas);
|
|
|
|
|
|
let out_0 = MPDPF::evaluate_domain(&key_0);
|
|
@@ -114,7 +132,7 @@ mod tests {
|
|
|
assert_eq!(value, out_0[i as usize] + out_1[i as usize]);
|
|
|
let expected_result = match alphas.binary_search(&i) {
|
|
|
Ok(i) => betas[i],
|
|
|
- Err(_) => 0,
|
|
|
+ Err(_) => MPDPF::Value::zero(),
|
|
|
};
|
|
|
assert_eq!(value, expected_result);
|
|
|
}
|
|
@@ -124,7 +142,7 @@ mod tests {
|
|
|
fn test_mpdpf() {
|
|
|
for log_domain_size in 5..10 {
|
|
|
for log_number_points in 0..5 {
|
|
|
- test_mpdpf_with_param::<DummyMpDpf>(log_domain_size, 1 << log_number_points);
|
|
|
+ test_mpdpf_with_param::<DummyMpDpf<u64>>(log_domain_size, 1 << log_number_points);
|
|
|
}
|
|
|
}
|
|
|
}
|