spdpf.rs 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. use core::fmt::Debug;
  2. use core::marker::PhantomData;
  3. use core::ops::Add;
  4. use num::traits::Zero;
  5. pub trait SinglePointDpfKey: Copy + Debug {
  6. fn get_party_id(&self) -> usize;
  7. fn get_log_domain_size(&self) -> u64;
  8. }
  9. pub trait SinglePointDpf {
  10. type Key: SinglePointDpfKey;
  11. type Value: Add<Output = Self::Value> + Copy + Debug + Eq + Zero;
  12. fn generate_keys(log_domain_size: u64, alpha: u64, beta: Self::Value)
  13. -> (Self::Key, Self::Key);
  14. fn evaluate_at(key: &Self::Key, index: u64) -> Self::Value;
  15. fn evaluate_domain(key: &Self::Key) -> Vec<Self::Value> {
  16. (0..(1 << key.get_log_domain_size()))
  17. .map(|x| Self::evaluate_at(&key, x))
  18. .collect()
  19. }
  20. }
  21. #[derive(Clone, Copy, Debug)]
  22. pub struct DummySpDpfKey<V: Copy + Debug> {
  23. party_id: usize,
  24. log_domain_size: u64,
  25. alpha: u64,
  26. beta: V,
  27. }
  28. impl<V> SinglePointDpfKey for DummySpDpfKey<V>
  29. where
  30. V: Copy + Debug,
  31. {
  32. fn get_party_id(&self) -> usize {
  33. self.party_id
  34. }
  35. fn get_log_domain_size(&self) -> u64 {
  36. self.log_domain_size
  37. }
  38. }
  39. pub struct DummySpDpf<V>
  40. where
  41. V: Add<Output = V> + Copy + Debug + Eq + Zero,
  42. {
  43. phantom: PhantomData<V>,
  44. }
  45. impl<V> SinglePointDpf for DummySpDpf<V>
  46. where
  47. V: Add<Output = V> + Copy + Debug + Eq + Zero,
  48. {
  49. type Key = DummySpDpfKey<V>;
  50. type Value = V;
  51. fn generate_keys(log_domain_size: u64, alpha: u64, beta: V) -> (Self::Key, Self::Key) {
  52. assert!(alpha < (1 << log_domain_size));
  53. (
  54. DummySpDpfKey {
  55. party_id: 0,
  56. log_domain_size,
  57. alpha,
  58. beta,
  59. },
  60. DummySpDpfKey {
  61. party_id: 1,
  62. log_domain_size,
  63. alpha,
  64. beta,
  65. },
  66. )
  67. }
  68. fn evaluate_at(key: &Self::Key, index: u64) -> V {
  69. if key.get_party_id() == 0 && index == key.alpha {
  70. key.beta
  71. } else {
  72. V::zero()
  73. }
  74. }
  75. }
  76. #[cfg(test)]
  77. mod tests {
  78. use super::*;
  79. use rand::distributions::{Distribution, Standard};
  80. use rand::{thread_rng, Rng};
  81. fn test_spdpf_with_param<SPDPF: SinglePointDpf>(log_domain_size: u64)
  82. where
  83. Standard: Distribution<SPDPF::Value>,
  84. {
  85. let domain_size = 1 << log_domain_size;
  86. let alpha = thread_rng().gen_range(0..domain_size);
  87. let beta = thread_rng().gen();
  88. let (key_0, key_1) = SPDPF::generate_keys(log_domain_size, alpha, beta);
  89. let out_0 = SPDPF::evaluate_domain(&key_0);
  90. let out_1 = SPDPF::evaluate_domain(&key_1);
  91. for i in 0..domain_size {
  92. let value = SPDPF::evaluate_at(&key_0, i) + SPDPF::evaluate_at(&key_1, i);
  93. assert_eq!(value, out_0[i as usize] + out_1[i as usize]);
  94. if i == alpha {
  95. assert_eq!(value, beta);
  96. } else {
  97. assert_eq!(value, SPDPF::Value::zero());
  98. }
  99. }
  100. }
  101. #[test]
  102. fn test_spdpf() {
  103. for log_domain_size in 5..10 {
  104. test_spdpf_with_param::<DummySpDpf<u64>>(log_domain_size);
  105. }
  106. }
  107. }