spdpf.rs 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. use core::fmt::Debug;
  2. use core::marker::PhantomData;
  3. use core::ops::{Add, Sub};
  4. use num::traits::Zero;
  5. use rand::{thread_rng, Rng};
  6. use utils::bit_decompose::bit_decompose;
  7. use utils::fixed_key_aes::FixedKeyAes;
  8. use utils::pseudorandom_conversion::{PRConvertTo, PRConverter};
  9. pub trait SinglePointDpfKey: Clone + Debug {
  10. fn get_party_id(&self) -> usize;
  11. fn get_log_domain_size(&self) -> u64;
  12. }
  13. pub trait SinglePointDpf {
  14. type Key: SinglePointDpfKey;
  15. type Value: Add<Output = Self::Value> + Copy + Debug + Eq + Zero;
  16. fn generate_keys(log_domain_size: u64, alpha: u64, beta: Self::Value)
  17. -> (Self::Key, Self::Key);
  18. fn evaluate_at(key: &Self::Key, index: u64) -> Self::Value;
  19. fn evaluate_domain(key: &Self::Key) -> Vec<Self::Value> {
  20. (0..(1 << key.get_log_domain_size()))
  21. .map(|x| Self::evaluate_at(&key, x))
  22. .collect()
  23. }
  24. }
  25. #[derive(Clone, Copy, Debug)]
  26. pub struct DummySpDpfKey<V: Copy + Debug> {
  27. party_id: usize,
  28. log_domain_size: u64,
  29. alpha: u64,
  30. beta: V,
  31. }
  32. impl<V> SinglePointDpfKey for DummySpDpfKey<V>
  33. where
  34. V: Copy + Debug,
  35. {
  36. fn get_party_id(&self) -> usize {
  37. self.party_id
  38. }
  39. fn get_log_domain_size(&self) -> u64 {
  40. self.log_domain_size
  41. }
  42. }
  43. pub struct DummySpDpf<V>
  44. where
  45. V: Add<Output = V> + Copy + Debug + Eq + Zero,
  46. {
  47. phantom: PhantomData<V>,
  48. }
  49. impl<V> SinglePointDpf for DummySpDpf<V>
  50. where
  51. V: Add<Output = V> + Copy + Debug + Eq + Zero,
  52. {
  53. type Key = DummySpDpfKey<V>;
  54. type Value = V;
  55. fn generate_keys(log_domain_size: u64, alpha: u64, beta: V) -> (Self::Key, Self::Key) {
  56. assert!(alpha < (1 << log_domain_size));
  57. (
  58. DummySpDpfKey {
  59. party_id: 0,
  60. log_domain_size,
  61. alpha,
  62. beta,
  63. },
  64. DummySpDpfKey {
  65. party_id: 1,
  66. log_domain_size,
  67. alpha,
  68. beta,
  69. },
  70. )
  71. }
  72. fn evaluate_at(key: &Self::Key, index: u64) -> V {
  73. if key.get_party_id() == 0 && index == key.alpha {
  74. key.beta
  75. } else {
  76. V::zero()
  77. }
  78. }
  79. }
  80. /// Implementation of the Half-Tree DPF scheme from Guo et al. (ePrint 2022/1431, Figure 8)
  81. #[derive(Clone, Debug)]
  82. pub struct HalfTreeSpDpfKey<V: Copy + Debug> {
  83. /// party id b
  84. party_id: usize,
  85. /// n where domain size is N := 2^n
  86. log_domain_size: u64,
  87. /// (s_b^0 || t_b^0) and t_b^0 is the LSB
  88. party_seed: u128,
  89. /// vector of length n: CW_1, ..., CW_(n-1)
  90. correction_words: Vec<u128>,
  91. /// high part of CW_n = (HCW, [LCW[0], LCW[1]])
  92. hcw: u128,
  93. /// low parts of CW_n = (HCW, [LCW[0], LCW[1]])
  94. lcw: [bool; 2],
  95. /// CW_(n+1)
  96. correction_word_np1: V,
  97. }
  98. impl<V> SinglePointDpfKey for HalfTreeSpDpfKey<V>
  99. where
  100. V: Copy + Debug,
  101. {
  102. fn get_party_id(&self) -> usize {
  103. self.party_id
  104. }
  105. fn get_log_domain_size(&self) -> u64 {
  106. self.log_domain_size
  107. }
  108. }
  109. pub struct HalfTreeSpDpf<V>
  110. where
  111. V: Add<Output = V> + Copy + Debug + Eq + Zero,
  112. {
  113. phantom: PhantomData<V>,
  114. }
  115. impl<V> HalfTreeSpDpf<V>
  116. where
  117. V: Add<Output = V> + Sub<Output = V> + Copy + Debug + Eq + Zero,
  118. {
  119. const FIXED_KEY_AES_KEY: [u8; 16] =
  120. 0xdead_beef_1337_4247_dead_beef_1337_4247_u128.to_le_bytes();
  121. const HASH_KEY: u128 = 0xc000ffee_c0ffffee_c0ffeeee_c00ffeee_u128;
  122. }
  123. impl<V> SinglePointDpf for HalfTreeSpDpf<V>
  124. where
  125. V: Add<Output = V> + Sub<Output = V> + Copy + Debug + Eq + Zero,
  126. PRConverter: PRConvertTo<V>,
  127. {
  128. type Key = HalfTreeSpDpfKey<V>;
  129. type Value = V;
  130. fn generate_keys(log_domain_size: u64, alpha: u64, beta: V) -> (Self::Key, Self::Key) {
  131. assert!(alpha < (1 << log_domain_size));
  132. let fkaes = FixedKeyAes::new(Self::FIXED_KEY_AES_KEY);
  133. let hash = |x: u128| fkaes.hash_ccr(Self::HASH_KEY ^ x);
  134. let convert = |x: u128| -> V { PRConverter::convert(x) };
  135. let mut rng = thread_rng();
  136. let n = log_domain_size as usize;
  137. let alpha_bits: Vec<bool> = bit_decompose(alpha, n);
  138. let delta = rng.gen::<u128>() | 1u128;
  139. let mut correction_words = Vec::<u128>::with_capacity(n - 1);
  140. let mut st_0 = rng.gen::<u128>();
  141. let mut st_1 = st_0 ^ delta;
  142. let party_seeds = (st_0, st_1);
  143. for i in 0..(n - 1) as usize {
  144. let cw_i = hash(st_0) ^ hash(st_1) ^ (1 - alpha_bits[i] as u128) * delta;
  145. st_0 = hash(st_0) ^ alpha_bits[i] as u128 * (st_0) ^ (st_0 & 1) * cw_i;
  146. st_1 = hash(st_1) ^ alpha_bits[i] as u128 * (st_1) ^ (st_1 & 1) * cw_i;
  147. correction_words.push(cw_i);
  148. }
  149. let high_low = [[hash(st_0), hash(st_0 ^ 1)], [hash(st_1), hash(st_1 ^ 1)]];
  150. const HIGH_MASK: u128 = u128::MAX - 1;
  151. const LOW_MASK: u128 = 1u128;
  152. let a_n = alpha_bits[n - 1];
  153. let hcw = (high_low[0][1 - a_n as usize] ^ high_low[1][1 - a_n as usize]) & HIGH_MASK;
  154. let lcw = [
  155. ((high_low[0][0] ^ high_low[1][0] ^ (1 - a_n as u128)) & LOW_MASK) != 0,
  156. ((high_low[0][1] ^ high_low[1][1] ^ a_n as u128) & LOW_MASK) != 0,
  157. ];
  158. st_0 = high_low[0][a_n as usize] ^ (st_0 & 1) * (hcw | lcw[a_n as usize] as u128);
  159. st_1 = high_low[1][a_n as usize] ^ (st_1 & 1) * (hcw | lcw[a_n as usize] as u128);
  160. let correction_word_np1: V = match (st_0 & 1).wrapping_sub(st_1 & 1) {
  161. u128::MAX => convert(st_0 >> 1) - convert(st_1 >> 1) - beta,
  162. 0 => V::zero(),
  163. 1 => convert(st_1 >> 1) - convert(st_0 >> 1) + beta,
  164. _ => panic!("should not happend, since matching a difference of two bits"),
  165. };
  166. (
  167. HalfTreeSpDpfKey {
  168. party_id: 0,
  169. log_domain_size,
  170. party_seed: party_seeds.0,
  171. correction_words: correction_words.clone(),
  172. hcw,
  173. lcw,
  174. correction_word_np1,
  175. },
  176. HalfTreeSpDpfKey {
  177. party_id: 1,
  178. log_domain_size,
  179. party_seed: party_seeds.1,
  180. correction_words,
  181. hcw,
  182. lcw,
  183. correction_word_np1,
  184. },
  185. )
  186. }
  187. fn evaluate_at(key: &Self::Key, index: u64) -> V {
  188. assert!(index < (1 << key.log_domain_size));
  189. let fkaes = FixedKeyAes::new(Self::FIXED_KEY_AES_KEY);
  190. let hash = |x: u128| fkaes.hash_ccr(Self::HASH_KEY ^ x);
  191. let convert = |x: u128| -> V { PRConverter::convert(x) };
  192. let n = key.log_domain_size as usize;
  193. let index_bits: Vec<bool> = bit_decompose(index, n);
  194. let mut st_b = key.party_seed;
  195. for i in 0..n - 1 {
  196. st_b = hash(st_b) ^ index_bits[i] as u128 * st_b ^ (st_b & 1) * key.correction_words[i];
  197. }
  198. let x_n = index_bits[n - 1];
  199. let high_low_b_xn = hash(st_b ^ x_n as u128);
  200. st_b = high_low_b_xn ^ (st_b & 1) * (key.hcw | key.lcw[x_n as usize] as u128);
  201. let value = convert(st_b >> 1)
  202. + if st_b & 1 == 0 {
  203. V::zero()
  204. } else {
  205. key.correction_word_np1
  206. };
  207. if key.party_id == 0 {
  208. value
  209. } else {
  210. V::zero() - value
  211. }
  212. }
  213. }
  214. #[cfg(test)]
  215. mod tests {
  216. use super::*;
  217. use core::num::Wrapping;
  218. use rand::distributions::{Distribution, Standard};
  219. use rand::{thread_rng, Rng};
  220. fn test_spdpf_with_param<SPDPF: SinglePointDpf>(log_domain_size: u64)
  221. where
  222. Standard: Distribution<SPDPF::Value>,
  223. {
  224. let domain_size = 1 << log_domain_size;
  225. let alpha = thread_rng().gen_range(0..domain_size);
  226. let beta = thread_rng().gen();
  227. let (key_0, key_1) = SPDPF::generate_keys(log_domain_size, alpha, beta);
  228. let out_0 = SPDPF::evaluate_domain(&key_0);
  229. let out_1 = SPDPF::evaluate_domain(&key_1);
  230. for i in 0..domain_size {
  231. let value = SPDPF::evaluate_at(&key_0, i) + SPDPF::evaluate_at(&key_1, i);
  232. assert_eq!(value, out_0[i as usize] + out_1[i as usize]);
  233. if i == alpha {
  234. assert_eq!(value, beta);
  235. } else {
  236. assert_eq!(value, SPDPF::Value::zero());
  237. }
  238. }
  239. }
  240. #[test]
  241. fn test_spdpf() {
  242. for log_domain_size in 5..10 {
  243. test_spdpf_with_param::<DummySpDpf<u64>>(log_domain_size);
  244. test_spdpf_with_param::<HalfTreeSpDpf<Wrapping<u64>>>(log_domain_size);
  245. }
  246. }
  247. }