mpdpf.rs 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695
  1. //! Trait definitions and implementations of multi-point distributed point functions (MP-DPFs).
  2. use crate::spdpf::SinglePointDpf;
  3. use bincode;
  4. use core::fmt;
  5. use core::fmt::Debug;
  6. use core::marker::PhantomData;
  7. use core::ops::{Add, AddAssign};
  8. use num::traits::Zero;
  9. use rayon::prelude::*;
  10. use utils::cuckoo::{
  11. Hasher as CuckooHasher, Parameters as CuckooParameters,
  12. NUMBER_HASH_FUNCTIONS as CUCKOO_NUMBER_HASH_FUNCTIONS,
  13. };
  14. use utils::hash::HashFunction;
  15. /// Trait for the keys of a multi-point DPF scheme.
  16. pub trait MultiPointDpfKey: Clone + Debug {
  17. /// Return the party ID, 0 or 1, corresponding to this key.
  18. fn get_party_id(&self) -> usize;
  19. /// Return the domain size of the shared function.
  20. fn get_domain_size(&self) -> usize;
  21. /// Return the number of (possibly) non-zero points of the shared function.
  22. fn get_number_points(&self) -> usize;
  23. }
  24. /// Trait for a single-point DPF scheme.
  25. pub trait MultiPointDpf {
  26. /// The key type of the scheme.
  27. type Key: MultiPointDpfKey;
  28. /// The value type of the scheme.
  29. type Value: Add<Output = Self::Value> + Copy + Debug + Eq + Zero;
  30. /// Constructor for the MP-DPF scheme with a given domain size and number of points.
  31. ///
  32. /// Having a stateful scheme, allows for reusable precomputation.
  33. fn new(domain_size: usize, number_points: usize) -> Self;
  34. /// Return the domain size.
  35. fn get_domain_size(&self) -> usize;
  36. /// Return the number of (possibly) non-zero points.
  37. fn get_number_points(&self) -> usize;
  38. /// Run a possible precomputation phase.
  39. fn precompute(&mut self) {}
  40. /// Key generation for a given `domain_size`, an index `alpha` and a value `beta`.
  41. ///
  42. /// The shared point function is `f: {0, ..., domain_size - 1} -> Self::Value` such that
  43. /// `f(alpha_i) = beta_i` and `f(x) = 0` for `x` is not one of the `alpha_i`.
  44. fn generate_keys(&self, alphas: &[u64], betas: &[Self::Value]) -> (Self::Key, Self::Key);
  45. /// Evaluation using a DPF key on a single `index` from `{0, ..., domain_size - 1}`.
  46. fn evaluate_at(&self, key: &Self::Key, index: u64) -> Self::Value;
  47. /// Evaluation using a DPF key on the whole domain.
  48. ///
  49. /// This might be implemented more efficiently than just repeatedly calling
  50. /// [`Self::evaluate_at`].
  51. fn evaluate_domain(&self, key: &Self::Key) -> Vec<Self::Value> {
  52. (0..key.get_domain_size())
  53. .map(|x| self.evaluate_at(key, x as u64))
  54. .collect()
  55. }
  56. }
  57. /// Key type for the insecure [DummyMpDpf] scheme, which trivially contains the defining parameters
  58. /// `alpha_i` and `beta_i`.
  59. #[derive(Clone, Debug, bincode::Encode, bincode::Decode)]
  60. pub struct DummyMpDpfKey<V: Copy + Debug> {
  61. party_id: usize,
  62. domain_size: usize,
  63. number_points: usize,
  64. alphas: Vec<u64>,
  65. betas: Vec<V>,
  66. }
  67. impl<V> MultiPointDpfKey for DummyMpDpfKey<V>
  68. where
  69. V: Copy + Debug,
  70. {
  71. fn get_party_id(&self) -> usize {
  72. self.party_id
  73. }
  74. fn get_domain_size(&self) -> usize {
  75. self.domain_size
  76. }
  77. fn get_number_points(&self) -> usize {
  78. self.number_points
  79. }
  80. }
  81. /// Insecure MP-DPF scheme for testing purposes.
  82. pub struct DummyMpDpf<V>
  83. where
  84. V: Add<Output = V> + Copy + Debug + Eq + Zero,
  85. {
  86. domain_size: usize,
  87. number_points: usize,
  88. phantom: PhantomData<V>,
  89. }
  90. impl<V> MultiPointDpf for DummyMpDpf<V>
  91. where
  92. V: Add<Output = V> + Copy + Debug + Eq + Zero,
  93. {
  94. type Key = DummyMpDpfKey<V>;
  95. type Value = V;
  96. fn new(domain_size: usize, number_points: usize) -> Self {
  97. Self {
  98. domain_size,
  99. number_points,
  100. phantom: PhantomData,
  101. }
  102. }
  103. fn get_domain_size(&self) -> usize {
  104. self.domain_size
  105. }
  106. fn get_number_points(&self) -> usize {
  107. self.number_points
  108. }
  109. fn generate_keys(&self, alphas: &[u64], betas: &[V]) -> (Self::Key, Self::Key) {
  110. assert_eq!(
  111. alphas.len(),
  112. self.number_points,
  113. "number of points does not match constructor argument"
  114. );
  115. assert_eq!(
  116. alphas.len(),
  117. betas.len(),
  118. "alphas and betas must be the same size"
  119. );
  120. assert!(
  121. alphas
  122. .iter()
  123. .all(|&alpha| alpha < (self.domain_size as u64)),
  124. "all alphas must be in the domain"
  125. );
  126. assert!(
  127. alphas.windows(2).all(|w| w[0] < w[1]),
  128. "alphas must be sorted"
  129. );
  130. (
  131. DummyMpDpfKey {
  132. party_id: 0,
  133. domain_size: self.domain_size,
  134. number_points: self.number_points,
  135. alphas: alphas.to_vec(),
  136. betas: betas.to_vec(),
  137. },
  138. DummyMpDpfKey {
  139. party_id: 1,
  140. domain_size: self.domain_size,
  141. number_points: self.number_points,
  142. alphas: alphas.to_vec(),
  143. betas: betas.to_vec(),
  144. },
  145. )
  146. }
  147. fn evaluate_at(&self, key: &Self::Key, index: u64) -> V {
  148. assert_eq!(self.domain_size, key.domain_size);
  149. assert_eq!(self.number_points, key.number_points);
  150. if key.get_party_id() == 0 {
  151. match key.alphas.binary_search(&index) {
  152. Ok(i) => key.betas[i],
  153. Err(_) => V::zero(),
  154. }
  155. } else {
  156. V::zero()
  157. }
  158. }
  159. }
  160. /// Key type for the [SmartMpDpf] scheme.
  161. pub struct SmartMpDpfKey<SPDPF, H>
  162. where
  163. SPDPF: SinglePointDpf,
  164. H: HashFunction<u16>,
  165. {
  166. party_id: usize,
  167. domain_size: usize,
  168. number_points: usize,
  169. spdpf_keys: Vec<Option<SPDPF::Key>>,
  170. cuckoo_parameters: CuckooParameters<H, u16>,
  171. }
  172. impl<SPDPF, H> Debug for SmartMpDpfKey<SPDPF, H>
  173. where
  174. SPDPF: SinglePointDpf,
  175. H: HashFunction<u16>,
  176. {
  177. fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
  178. let (newline, indentation) = if f.alternate() {
  179. ("\n", " ")
  180. } else {
  181. (" ", "")
  182. };
  183. write!(f, "SmartMpDpfKey<SPDPF, H>{{{newline}")?;
  184. write!(
  185. f,
  186. "{}party_id: {:?},{}",
  187. indentation, self.party_id, newline
  188. )?;
  189. write!(
  190. f,
  191. "{}domain_size: {:?},{}",
  192. indentation, self.domain_size, newline
  193. )?;
  194. write!(
  195. f,
  196. "{}number_points: {:?},{}",
  197. indentation, self.number_points, newline
  198. )?;
  199. if f.alternate() {
  200. writeln!(f, " spdpf_keys:")?;
  201. for (i, k) in self.spdpf_keys.iter().enumerate() {
  202. writeln!(f, " spdpf_keys[{i}]: {k:?}")?;
  203. }
  204. } else {
  205. write!(f, " spdpf_keys: {:?},", self.spdpf_keys)?;
  206. }
  207. write!(
  208. f,
  209. "{}cuckoo_parameters: {:?}{}",
  210. indentation, self.cuckoo_parameters, newline
  211. )?;
  212. write!(f, "}}")?;
  213. Ok(())
  214. }
  215. }
  216. impl<SPDPF, H> Clone for SmartMpDpfKey<SPDPF, H>
  217. where
  218. SPDPF: SinglePointDpf,
  219. H: HashFunction<u16>,
  220. {
  221. fn clone(&self) -> Self {
  222. Self {
  223. party_id: self.party_id,
  224. domain_size: self.domain_size,
  225. number_points: self.number_points,
  226. spdpf_keys: self.spdpf_keys.clone(),
  227. cuckoo_parameters: self.cuckoo_parameters,
  228. }
  229. }
  230. }
  231. impl<SPDPF, H> bincode::Encode for SmartMpDpfKey<SPDPF, H>
  232. where
  233. SPDPF: SinglePointDpf,
  234. SPDPF::Key: bincode::Encode,
  235. H: HashFunction<u16>,
  236. CuckooParameters<H, u16>: bincode::Encode,
  237. {
  238. fn encode<E: bincode::enc::Encoder>(
  239. &self,
  240. encoder: &mut E,
  241. ) -> core::result::Result<(), bincode::error::EncodeError> {
  242. bincode::Encode::encode(&self.party_id, encoder)?;
  243. bincode::Encode::encode(&self.domain_size, encoder)?;
  244. bincode::Encode::encode(&self.number_points, encoder)?;
  245. bincode::Encode::encode(&self.spdpf_keys, encoder)?;
  246. bincode::Encode::encode(&self.cuckoo_parameters, encoder)?;
  247. Ok(())
  248. }
  249. }
  250. impl<SPDPF, H> bincode::Decode for SmartMpDpfKey<SPDPF, H>
  251. where
  252. SPDPF: SinglePointDpf,
  253. SPDPF::Key: bincode::Decode,
  254. H: HashFunction<u16>,
  255. CuckooParameters<H, u16>: bincode::Decode,
  256. {
  257. fn decode<D: bincode::de::Decoder>(
  258. decoder: &mut D,
  259. ) -> core::result::Result<Self, bincode::error::DecodeError> {
  260. Ok(Self {
  261. party_id: bincode::Decode::decode(decoder)?,
  262. domain_size: bincode::Decode::decode(decoder)?,
  263. number_points: bincode::Decode::decode(decoder)?,
  264. spdpf_keys: bincode::Decode::decode(decoder)?,
  265. cuckoo_parameters: bincode::Decode::decode(decoder)?,
  266. })
  267. }
  268. }
  269. impl<SPDPF, H> MultiPointDpfKey for SmartMpDpfKey<SPDPF, H>
  270. where
  271. SPDPF: SinglePointDpf,
  272. H: HashFunction<u16>,
  273. {
  274. fn get_party_id(&self) -> usize {
  275. self.party_id
  276. }
  277. fn get_domain_size(&self) -> usize {
  278. self.domain_size
  279. }
  280. fn get_number_points(&self) -> usize {
  281. self.number_points
  282. }
  283. }
  284. /// Precomputed state for [SmartMpDpf].
  285. struct SmartMpDpfPrecomputationData<H: HashFunction<u16>> {
  286. pub cuckoo_parameters: CuckooParameters<H, u16>,
  287. pub hasher: CuckooHasher<H, u16>,
  288. pub hashes: [Vec<u16>; CUCKOO_NUMBER_HASH_FUNCTIONS],
  289. pub simple_htable: Vec<Vec<u64>>,
  290. pub bucket_sizes: Vec<usize>,
  291. pub position_map_lookup_table: Vec<[(usize, usize); 3]>,
  292. }
  293. /// MP-DPF construction using SP-DPFs and Cuckoo hashing from [Schoppmann et al. (CCS'19), Section
  294. /// 5](https://eprint.iacr.org/2019/1084.pdf#page=7).
  295. pub struct SmartMpDpf<V, SPDPF, H>
  296. where
  297. V: Add<Output = V> + AddAssign + Copy + Debug + Eq + Zero,
  298. SPDPF: SinglePointDpf<Value = V>,
  299. H: HashFunction<u16>,
  300. {
  301. domain_size: usize,
  302. number_points: usize,
  303. precomputation_data: Option<SmartMpDpfPrecomputationData<H>>,
  304. phantom_v: PhantomData<V>,
  305. phantom_s: PhantomData<SPDPF>,
  306. phantom_h: PhantomData<H>,
  307. }
  308. impl<V, SPDPF, H> SmartMpDpf<V, SPDPF, H>
  309. where
  310. V: Add<Output = V> + AddAssign + Copy + Debug + Eq + Zero + Send + Sync,
  311. SPDPF: SinglePointDpf<Value = V>,
  312. H: HashFunction<u16>,
  313. H::Description: Sync,
  314. {
  315. fn precompute_hashes(
  316. domain_size: usize,
  317. number_points: usize,
  318. ) -> SmartMpDpfPrecomputationData<H> {
  319. let seed: [u8; 32] = [
  320. 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42,
  321. 42, 42, 42, 42, 42, 42, 42, 42, 42, 42,
  322. ];
  323. let cuckoo_parameters = CuckooParameters::from_seed(number_points, seed);
  324. assert!(
  325. cuckoo_parameters.get_number_buckets() < (1 << u16::BITS),
  326. "too many buckets, use larger type for hash values"
  327. );
  328. let hasher = CuckooHasher::<H, u16>::new(cuckoo_parameters);
  329. let hashes = hasher.hash_domain(domain_size as u64);
  330. let simple_htable =
  331. hasher.hash_domain_into_buckets_given_hashes(domain_size as u64, &hashes);
  332. let bucket_sizes = CuckooHasher::<H, u16>::compute_bucket_sizes(&simple_htable);
  333. let position_map_lookup_table =
  334. CuckooHasher::<H, u16>::compute_pos_lookup_table(domain_size as u64, &simple_htable);
  335. SmartMpDpfPrecomputationData {
  336. cuckoo_parameters,
  337. hasher,
  338. hashes,
  339. simple_htable,
  340. bucket_sizes,
  341. position_map_lookup_table,
  342. }
  343. }
  344. }
  345. impl<V, SPDPF, H> MultiPointDpf for SmartMpDpf<V, SPDPF, H>
  346. where
  347. V: Add<Output = V> + AddAssign + Copy + Debug + Eq + Zero + Send + Sync,
  348. SPDPF: SinglePointDpf<Value = V>,
  349. SPDPF::Key: Sync,
  350. H: HashFunction<u16>,
  351. H::Description: Sync,
  352. {
  353. type Key = SmartMpDpfKey<SPDPF, H>;
  354. type Value = V;
  355. fn new(domain_size: usize, number_points: usize) -> Self {
  356. assert!(domain_size < (1 << u32::BITS));
  357. Self {
  358. domain_size,
  359. number_points,
  360. precomputation_data: None,
  361. phantom_v: PhantomData,
  362. phantom_s: PhantomData,
  363. phantom_h: PhantomData,
  364. }
  365. }
  366. fn get_domain_size(&self) -> usize {
  367. self.domain_size
  368. }
  369. fn get_number_points(&self) -> usize {
  370. self.domain_size
  371. }
  372. fn precompute(&mut self) {
  373. if self.precomputation_data.is_none() {
  374. self.precomputation_data = Some(Self::precompute_hashes(
  375. self.domain_size,
  376. self.number_points,
  377. ));
  378. }
  379. }
  380. fn generate_keys(&self, alphas: &[u64], betas: &[Self::Value]) -> (Self::Key, Self::Key) {
  381. assert_eq!(alphas.len(), betas.len());
  382. debug_assert!(alphas.windows(2).all(|w| w[0] < w[1]));
  383. debug_assert!(alphas.iter().all(|&alpha| alpha < self.domain_size as u64));
  384. let number_points = alphas.len();
  385. // if not data is precomputed, do it now
  386. // (&self is not mut, so we cannot store the new data here nor call precompute() ...)
  387. let mut precomputation_data_fresh: Option<SmartMpDpfPrecomputationData<H>> = None;
  388. if self.precomputation_data.is_none() {
  389. precomputation_data_fresh = Some(Self::precompute_hashes(
  390. self.domain_size,
  391. self.number_points,
  392. ));
  393. }
  394. // select either the precomputed or the freshly computed data
  395. let precomputation_data = self
  396. .precomputation_data
  397. .as_ref()
  398. .unwrap_or_else(|| precomputation_data_fresh.as_ref().unwrap());
  399. let cuckoo_parameters = &precomputation_data.cuckoo_parameters;
  400. let hasher = &precomputation_data.hasher;
  401. let (cuckoo_table_items, cuckoo_table_indices) = hasher.cuckoo_hash_items(alphas);
  402. let position_map_lookup_table = &precomputation_data.position_map_lookup_table;
  403. let pos = |bucket_i: usize, item: u64| -> u64 {
  404. CuckooHasher::<H, u16>::pos_lookup(position_map_lookup_table, bucket_i, item)
  405. };
  406. let number_buckets = hasher.get_parameters().get_number_buckets();
  407. let bucket_sizes = &precomputation_data.bucket_sizes;
  408. let mut keys_0 = Vec::<Option<SPDPF::Key>>::with_capacity(number_buckets);
  409. let mut keys_1 = Vec::<Option<SPDPF::Key>>::with_capacity(number_buckets);
  410. for bucket_i in 0..number_buckets {
  411. // if bucket is empty, add invalid dummy keys to the arrays to make the
  412. // indices work
  413. if bucket_sizes[bucket_i] == 0 {
  414. keys_0.push(None);
  415. keys_1.push(None);
  416. continue;
  417. }
  418. let (alpha, beta) =
  419. if cuckoo_table_items[bucket_i] != CuckooHasher::<H, u16>::UNOCCUPIED {
  420. let alpha = pos(bucket_i, cuckoo_table_items[bucket_i]);
  421. let beta = betas[cuckoo_table_indices[bucket_i]];
  422. (alpha, beta)
  423. } else {
  424. (0, V::zero())
  425. };
  426. let (key_0, key_1) = SPDPF::generate_keys(bucket_sizes[bucket_i], alpha, beta);
  427. keys_0.push(Some(key_0));
  428. keys_1.push(Some(key_1));
  429. }
  430. (
  431. SmartMpDpfKey::<SPDPF, H> {
  432. party_id: 0,
  433. domain_size: self.domain_size,
  434. number_points,
  435. spdpf_keys: keys_0,
  436. cuckoo_parameters: *cuckoo_parameters,
  437. },
  438. SmartMpDpfKey::<SPDPF, H> {
  439. party_id: 1,
  440. domain_size: self.domain_size,
  441. number_points,
  442. spdpf_keys: keys_1,
  443. cuckoo_parameters: *cuckoo_parameters,
  444. },
  445. )
  446. }
  447. fn evaluate_at(&self, key: &Self::Key, index: u64) -> Self::Value {
  448. assert_eq!(self.domain_size, key.domain_size);
  449. assert_eq!(self.number_points, key.number_points);
  450. assert_eq!(key.domain_size, self.domain_size);
  451. assert!(index < self.domain_size as u64);
  452. let hasher = CuckooHasher::<H, u16>::new(key.cuckoo_parameters);
  453. let hashes = hasher.hash_items(&[index]);
  454. let simple_htable = hasher.hash_domain_into_buckets(self.domain_size as u64);
  455. let pos = |bucket_i: usize, item: u64| -> u64 {
  456. let idx = simple_htable[bucket_i].partition_point(|x| x < &item);
  457. debug_assert!(idx != simple_htable[bucket_i].len());
  458. debug_assert_eq!(item, simple_htable[bucket_i][idx]);
  459. debug_assert!(idx == 0 || simple_htable[bucket_i][idx - 1] != item);
  460. idx as u64
  461. };
  462. let mut output = {
  463. let hash = H::hash_value_as_usize(hashes[0][0]);
  464. debug_assert!(key.spdpf_keys[hash].is_some());
  465. let sp_key = key.spdpf_keys[hash].as_ref().unwrap();
  466. debug_assert_eq!(simple_htable[hash][pos(hash, index) as usize], index);
  467. SPDPF::evaluate_at(sp_key, pos(hash, index))
  468. };
  469. // prevent adding the same term multiple times when we have collisions
  470. let mut hash_bit_map = [0u8; 2];
  471. if hashes[0][0] != hashes[1][0] {
  472. // hash_bit_map[i] |= 1;
  473. hash_bit_map[0] = 1;
  474. }
  475. if hashes[0][0] != hashes[2][0] && hashes[1][0] != hashes[2][0] {
  476. // hash_bit_map[i] |= 2;
  477. hash_bit_map[1] = 1;
  478. }
  479. for j in 1..CUCKOO_NUMBER_HASH_FUNCTIONS {
  480. if hash_bit_map[j - 1] == 0 {
  481. continue;
  482. }
  483. let hash = H::hash_value_as_usize(hashes[j][0]);
  484. debug_assert!(key.spdpf_keys[hash].is_some());
  485. let sp_key = key.spdpf_keys[hash].as_ref().unwrap();
  486. debug_assert_eq!(simple_htable[hash][pos(hash, index) as usize], index);
  487. output += SPDPF::evaluate_at(sp_key, pos(hash, index));
  488. }
  489. output
  490. }
  491. fn evaluate_domain(&self, key: &Self::Key) -> Vec<Self::Value> {
  492. assert_eq!(self.domain_size, key.domain_size);
  493. assert_eq!(self.number_points, key.number_points);
  494. let domain_size = self.domain_size as u64;
  495. // if not data is precomputed, do it now
  496. // (&self is not mut, so we cannot store the new data here nor call precompute() ...)
  497. let mut precomputation_data_fresh: Option<SmartMpDpfPrecomputationData<H>> = None;
  498. if self.precomputation_data.is_none() {
  499. precomputation_data_fresh = Some(Self::precompute_hashes(
  500. self.domain_size,
  501. self.number_points,
  502. ));
  503. }
  504. // select either the precomputed or the freshly computed data
  505. let precomputation_data = self
  506. .precomputation_data
  507. .as_ref()
  508. .unwrap_or_else(|| precomputation_data_fresh.as_ref().unwrap());
  509. let hashes = &precomputation_data.hashes;
  510. let simple_htable = &precomputation_data.simple_htable;
  511. let position_map_lookup_table = &precomputation_data.position_map_lookup_table;
  512. let pos = |bucket_i: usize, item: u64| -> u64 {
  513. CuckooHasher::<H, u16>::pos_lookup(position_map_lookup_table, bucket_i, item)
  514. };
  515. let sp_dpf_full_domain_evaluations: Vec<Vec<V>> = key
  516. .spdpf_keys
  517. .par_iter()
  518. .map(|sp_key_opt| {
  519. sp_key_opt
  520. .as_ref()
  521. .map_or(vec![], |sp_key| SPDPF::evaluate_domain(sp_key))
  522. })
  523. .collect();
  524. let spdpf_evaluate_at =
  525. |hash: usize, index| sp_dpf_full_domain_evaluations[hash][pos(hash, index) as usize];
  526. let outputs: Vec<_> = (0..domain_size)
  527. .into_par_iter()
  528. .map(|index| {
  529. let mut output = {
  530. let hash = H::hash_value_as_usize(hashes[0][index as usize]);
  531. debug_assert!(key.spdpf_keys[hash].is_some());
  532. debug_assert_eq!(simple_htable[hash][pos(hash, index) as usize], index);
  533. spdpf_evaluate_at(hash, index)
  534. };
  535. // prevent adding the same term multiple times when we have collisions
  536. let mut hash_bit_map = [0u8; 2];
  537. if hashes[0][index as usize] != hashes[1][index as usize] {
  538. hash_bit_map[0] = 1;
  539. }
  540. if hashes[0][index as usize] != hashes[2][index as usize]
  541. && hashes[1][index as usize] != hashes[2][index as usize]
  542. {
  543. hash_bit_map[1] = 1;
  544. }
  545. for j in 1..CUCKOO_NUMBER_HASH_FUNCTIONS {
  546. if hash_bit_map[j - 1] == 0 {
  547. continue;
  548. }
  549. let hash = H::hash_value_as_usize(hashes[j][index as usize]);
  550. debug_assert!(key.spdpf_keys[hash].is_some());
  551. debug_assert_eq!(simple_htable[hash][pos(hash, index) as usize], index);
  552. output += spdpf_evaluate_at(hash, index);
  553. }
  554. output
  555. })
  556. .collect();
  557. outputs
  558. }
  559. }
  560. #[cfg(test)]
  561. mod tests {
  562. use super::*;
  563. use crate::spdpf::DummySpDpf;
  564. use rand::distributions::{Distribution, Standard};
  565. use rand::{thread_rng, Rng};
  566. use std::num::Wrapping;
  567. use utils::hash::AesHashFunction;
  568. fn test_mpdpf_with_param<MPDPF: MultiPointDpf>(
  569. log_domain_size: u32,
  570. number_points: usize,
  571. precomputation: bool,
  572. ) where
  573. Standard: Distribution<MPDPF::Value>,
  574. {
  575. let domain_size = (1 << log_domain_size) as u64;
  576. assert!(number_points <= domain_size as usize);
  577. let alphas = {
  578. let mut alphas = Vec::<u64>::with_capacity(number_points);
  579. while alphas.len() < number_points {
  580. let x = thread_rng().gen_range(0..domain_size);
  581. match alphas.as_slice().binary_search(&x) {
  582. Ok(_) => continue,
  583. Err(i) => alphas.insert(i, x),
  584. }
  585. }
  586. alphas
  587. };
  588. let betas: Vec<MPDPF::Value> = (0..number_points).map(|_| thread_rng().gen()).collect();
  589. let mut mpdpf = MPDPF::new(domain_size as usize, number_points);
  590. if precomputation {
  591. mpdpf.precompute();
  592. }
  593. let (key_0, key_1) = mpdpf.generate_keys(&alphas, &betas);
  594. let out_0 = mpdpf.evaluate_domain(&key_0);
  595. let out_1 = mpdpf.evaluate_domain(&key_1);
  596. for i in 0..domain_size {
  597. let value = mpdpf.evaluate_at(&key_0, i) + mpdpf.evaluate_at(&key_1, i);
  598. assert_eq!(value, out_0[i as usize] + out_1[i as usize]);
  599. let expected_result = match alphas.binary_search(&i) {
  600. Ok(i) => betas[i],
  601. Err(_) => MPDPF::Value::zero(),
  602. };
  603. assert_eq!(value, expected_result, "wrong value at index {}", i);
  604. }
  605. }
  606. #[test]
  607. fn test_dummy_mpdpf() {
  608. type Value = Wrapping<u64>;
  609. for log_domain_size in 5..10 {
  610. for log_number_points in 0..5 {
  611. test_mpdpf_with_param::<DummyMpDpf<Value>>(
  612. log_domain_size,
  613. 1 << log_number_points,
  614. false,
  615. );
  616. }
  617. }
  618. }
  619. #[test]
  620. fn test_smart_mpdpf() {
  621. type Value = Wrapping<u64>;
  622. for log_domain_size in 5..7 {
  623. for log_number_points in 0..5 {
  624. for precomputation in [false, true] {
  625. test_mpdpf_with_param::<
  626. SmartMpDpf<Value, DummySpDpf<Value>, AesHashFunction<u16>>,
  627. >(log_domain_size, 1 << log_number_points, precomputation);
  628. }
  629. }
  630. }
  631. }
  632. }