spdpf.rs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. //! Trait definitions and implementations of single-point distributed point functions (SP-DPFs).
  2. use bincode;
  3. use core::fmt::Debug;
  4. use core::marker::PhantomData;
  5. use core::ops::{Add, Neg, Sub};
  6. use num::traits::Zero;
  7. use rand::{thread_rng, Rng};
  8. use utils::bit_decompose::bit_decompose;
  9. use utils::fixed_key_aes::FixedKeyAes;
  10. use utils::pseudorandom_conversion::{PRConvertTo, PRConverter};
  11. /// Trait for the keys of a single-point DPF scheme.
  12. pub trait SinglePointDpfKey: Clone + Debug {
  13. /// Return the party ID, 0 or 1, corresponding to this key.
  14. fn get_party_id(&self) -> usize;
  15. /// Return the domain size of the shared function.
  16. fn get_domain_size(&self) -> usize;
  17. }
  18. /// Trait for a single-point DPF scheme.
  19. pub trait SinglePointDpf {
  20. /// The key type of the scheme.
  21. type Key: SinglePointDpfKey;
  22. /// The value type of the scheme.
  23. type Value: Add<Output = Self::Value> + Copy + Debug + Eq + Zero;
  24. /// Key generation for a given `domain_size`, an index `alpha` and a value `beta`.
  25. ///
  26. /// The shared point function is `f: {0, ..., domain_size - 1} -> Self::Value` such that
  27. /// `f(alpha) = beta` and `f(x) = 0` for `x != alpha`.
  28. fn generate_keys(domain_size: usize, alpha: u64, beta: Self::Value) -> (Self::Key, Self::Key);
  29. /// Evaluation using a DPF key on a single `index` from `{0, ..., domain_size - 1}`.
  30. fn evaluate_at(key: &Self::Key, index: u64) -> Self::Value;
  31. /// Evaluation using a DPF key on the whole domain.
  32. ///
  33. /// This might be implemented more efficiently than just repeatedly calling
  34. /// [`Self::evaluate_at`].
  35. fn evaluate_domain(key: &Self::Key) -> Vec<Self::Value> {
  36. (0..key.get_domain_size())
  37. .map(|x| Self::evaluate_at(key, x as u64))
  38. .collect()
  39. }
  40. }
/// Key type for the insecure [DummySpDpf] scheme, which trivially contains the defining parameters
/// `alpha` and `beta`.
///
/// NOTE(review): the derived `bincode` encoding presumably follows the field declaration order —
/// check compatibility before reordering or inserting fields.
#[derive(Clone, Copy, Debug, bincode::Encode, bincode::Decode)]
pub struct DummySpDpfKey<V: Copy + Debug> {
    /// party id, 0 or 1; only the key of party 0 evaluates to `beta` at `alpha`
    party_id: usize,
    /// size of the domain `{0, ..., domain_size - 1}`
    domain_size: usize,
    /// the point at which the shared function is non-zero (stored in the clear)
    alpha: u64,
    /// the value of the shared function at `alpha` (stored in the clear)
    beta: V,
}
  50. impl<V> SinglePointDpfKey for DummySpDpfKey<V>
  51. where
  52. V: Copy + Debug,
  53. {
  54. fn get_party_id(&self) -> usize {
  55. self.party_id
  56. }
  57. fn get_domain_size(&self) -> usize {
  58. self.domain_size
  59. }
  60. }
  61. /// Insecure SP-DPF scheme for testing purposes.
  62. pub struct DummySpDpf<V>
  63. where
  64. V: Add<Output = V> + Copy + Debug + Eq + Zero,
  65. {
  66. phantom: PhantomData<V>,
  67. }
  68. impl<V> SinglePointDpf for DummySpDpf<V>
  69. where
  70. V: Add<Output = V> + Copy + Debug + Eq + Zero,
  71. {
  72. type Key = DummySpDpfKey<V>;
  73. type Value = V;
  74. fn generate_keys(domain_size: usize, alpha: u64, beta: V) -> (Self::Key, Self::Key) {
  75. assert!(alpha < domain_size as u64);
  76. (
  77. DummySpDpfKey {
  78. party_id: 0,
  79. domain_size,
  80. alpha,
  81. beta,
  82. },
  83. DummySpDpfKey {
  84. party_id: 1,
  85. domain_size,
  86. alpha,
  87. beta,
  88. },
  89. )
  90. }
  91. fn evaluate_at(key: &Self::Key, index: u64) -> V {
  92. if key.get_party_id() == 0 && index == key.alpha {
  93. key.beta
  94. } else {
  95. V::zero()
  96. }
  97. }
  98. fn evaluate_domain(key: &Self::Key) -> Vec<Self::Value> {
  99. let mut output = vec![V::zero(); key.domain_size];
  100. if key.get_party_id() == 0 {
  101. output[key.alpha as usize] = key.beta;
  102. }
  103. output
  104. }
  105. }
/// Key type for the [HalfTreeSpDpf] scheme.
///
/// Below, `h = ceil(log2(domain_size))` denotes the height of the evaluation tree.
///
/// NOTE(review): the derived `bincode` encoding presumably follows the field declaration order —
/// check compatibility before reordering or inserting fields.
#[derive(Clone, Debug, bincode::Encode, bincode::Decode)]
pub struct HalfTreeSpDpfKey<V: Copy + Debug> {
    /// party id `b` (0 or 1); party 1 negates its outputs
    party_id: usize,
    /// size of the DPF's domain `{0, ..., domain_size - 1}`
    domain_size: usize,
    /// root seed `(s_b^0 || t_b^0)` where the control bit `t_b^0` is the LSB
    /// (unused, i.e. default, when `domain_size == 1`)
    party_seed: u128,
    /// inner-level correction words `CW_1, ..., CW_(h-1)`, so the vector has `h - 1` entries
    /// (empty when `domain_size == 1`)
    correction_words: Vec<u128>,
    /// high part of the last-level correction word `CW_h = (HCW, [LCW[0], LCW[1]])`; its LSB is
    /// always cleared
    hcw: u128,
    /// low parts (one bit per child) of the last-level correction word
    /// `CW_h = (HCW, [LCW[0], LCW[1]])`
    lcw: [bool; 2],
    /// output correction word `CW_(h+1)`; when `domain_size == 1` it instead holds this party's
    /// additive share of `beta`
    correction_word_np1: V,
}
  124. impl<V> SinglePointDpfKey for HalfTreeSpDpfKey<V>
  125. where
  126. V: Copy + Debug,
  127. {
  128. fn get_party_id(&self) -> usize {
  129. self.party_id
  130. }
  131. fn get_domain_size(&self) -> usize {
  132. self.domain_size
  133. }
  134. }
  135. /// Implementation of the Half-Tree DPF scheme from Guo et al. ([ePrint 2022/1431, Figure
  136. /// 8](https://eprint.iacr.org/2022/1431.pdf#page=18)).
  137. pub struct HalfTreeSpDpf<V>
  138. where
  139. V: Add<Output = V> + Copy + Debug + Eq + Zero,
  140. {
  141. phantom: PhantomData<V>,
  142. }
  143. impl<V> HalfTreeSpDpf<V>
  144. where
  145. V: Add<Output = V> + Sub<Output = V> + Copy + Debug + Eq + Zero,
  146. {
  147. const FIXED_KEY_AES_KEY: [u8; 16] =
  148. 0xdead_beef_1337_4247_dead_beef_1337_4247_u128.to_le_bytes();
  149. const HASH_KEY: u128 = 0xc000_ffee_c0ff_ffee_c0ff_eeee_c00f_feee_u128;
  150. }
  151. impl<V> SinglePointDpf for HalfTreeSpDpf<V>
  152. where
  153. V: Add<Output = V> + Sub<Output = V> + Neg<Output = V> + Copy + Debug + Eq + Zero,
  154. PRConverter: PRConvertTo<V>,
  155. {
  156. type Key = HalfTreeSpDpfKey<V>;
  157. type Value = V;
  158. fn generate_keys(domain_size: usize, alpha: u64, beta: V) -> (Self::Key, Self::Key) {
  159. assert!(alpha < domain_size as u64);
  160. let mut rng = thread_rng();
  161. if domain_size == 1 {
  162. // simply secret-share beta
  163. let beta_0: V = PRConverter::convert(rng.gen::<u128>());
  164. let beta_1: V = beta - beta_0;
  165. return (
  166. HalfTreeSpDpfKey {
  167. party_id: 0,
  168. domain_size,
  169. party_seed: Default::default(),
  170. correction_words: Default::default(),
  171. hcw: Default::default(),
  172. lcw: Default::default(),
  173. correction_word_np1: beta_0,
  174. },
  175. HalfTreeSpDpfKey {
  176. party_id: 1,
  177. domain_size,
  178. party_seed: Default::default(),
  179. correction_words: Default::default(),
  180. hcw: Default::default(),
  181. lcw: Default::default(),
  182. correction_word_np1: beta_1,
  183. },
  184. );
  185. }
  186. let fkaes = FixedKeyAes::new(Self::FIXED_KEY_AES_KEY);
  187. let hash = |x: u128| fkaes.hash_ccr(Self::HASH_KEY ^ x);
  188. let convert = |x: u128| -> V { PRConverter::convert(x) };
  189. let tree_height = (domain_size as f64).log2().ceil() as usize;
  190. let alpha_bits: Vec<bool> = bit_decompose(alpha, tree_height);
  191. let delta = rng.gen::<u128>() | 1u128;
  192. let mut correction_words = Vec::<u128>::with_capacity(tree_height - 1);
  193. let mut st_0 = rng.gen::<u128>();
  194. let mut st_1 = st_0 ^ delta;
  195. let party_seeds = (st_0, st_1);
  196. debug_assert_eq!(alpha_bits.len(), tree_height);
  197. for alpha_i in alpha_bits.iter().copied().take(tree_height - 1) {
  198. let cw_i = hash(st_0) ^ hash(st_1) ^ ((1 - alpha_i as u128) * delta);
  199. st_0 = hash(st_0) ^ (alpha_i as u128 * st_0) ^ ((st_0 & 1) * cw_i);
  200. st_1 = hash(st_1) ^ (alpha_i as u128 * st_1) ^ ((st_1 & 1) * cw_i);
  201. correction_words.push(cw_i);
  202. }
  203. let high_low = [[hash(st_0), hash(st_0 ^ 1)], [hash(st_1), hash(st_1 ^ 1)]];
  204. const HIGH_MASK: u128 = u128::MAX - 1;
  205. const LOW_MASK: u128 = 1u128;
  206. let a_n = alpha_bits[tree_height - 1];
  207. let hcw = (high_low[0][1 - a_n as usize] ^ high_low[1][1 - a_n as usize]) & HIGH_MASK;
  208. let lcw = [
  209. ((high_low[0][0] ^ high_low[1][0] ^ (1 - a_n as u128)) & LOW_MASK) != 0,
  210. ((high_low[0][1] ^ high_low[1][1] ^ a_n as u128) & LOW_MASK) != 0,
  211. ];
  212. st_0 = high_low[0][a_n as usize] ^ ((st_0 & 1) * (hcw | lcw[a_n as usize] as u128));
  213. st_1 = high_low[1][a_n as usize] ^ ((st_1 & 1) * (hcw | lcw[a_n as usize] as u128));
  214. let correction_word_np1: V = match (st_0 & 1).wrapping_sub(st_1 & 1) {
  215. u128::MAX => convert(st_0 >> 1) - convert(st_1 >> 1) - beta,
  216. 0 => V::zero(),
  217. 1 => convert(st_1 >> 1) - convert(st_0 >> 1) + beta,
  218. _ => panic!("should not happend, since matching a difference of two bits"),
  219. };
  220. (
  221. HalfTreeSpDpfKey {
  222. party_id: 0,
  223. domain_size,
  224. party_seed: party_seeds.0,
  225. correction_words: correction_words.clone(),
  226. hcw,
  227. lcw,
  228. correction_word_np1,
  229. },
  230. HalfTreeSpDpfKey {
  231. party_id: 1,
  232. domain_size,
  233. party_seed: party_seeds.1,
  234. correction_words,
  235. hcw,
  236. lcw,
  237. correction_word_np1,
  238. },
  239. )
  240. }
  241. fn evaluate_at(key: &Self::Key, index: u64) -> V {
  242. assert!(key.domain_size > 0);
  243. assert!(index < key.domain_size as u64);
  244. if key.domain_size == 1 {
  245. // beta is simply secret-shared
  246. return key.correction_word_np1;
  247. }
  248. let fkaes = FixedKeyAes::new(Self::FIXED_KEY_AES_KEY);
  249. let hash = |x: u128| fkaes.hash_ccr(Self::HASH_KEY ^ x);
  250. let convert = |x: u128| -> V { PRConverter::convert(x) };
  251. let tree_height = (key.domain_size as f64).log2().ceil() as usize;
  252. let index_bits: Vec<bool> = bit_decompose(index, tree_height);
  253. debug_assert_eq!(index_bits.len(), tree_height);
  254. let mut st_b = key.party_seed;
  255. for (index_bit_i, correction_word_i) in index_bits
  256. .iter()
  257. .copied()
  258. .zip(key.correction_words.iter())
  259. .take(tree_height - 1)
  260. {
  261. st_b = hash(st_b) ^ (index_bit_i as u128 * st_b) ^ ((st_b & 1) * correction_word_i);
  262. }
  263. let x_n = index_bits[tree_height - 1];
  264. let high_low_b_xn = hash(st_b ^ x_n as u128);
  265. st_b = high_low_b_xn ^ ((st_b & 1) * (key.hcw | key.lcw[x_n as usize] as u128));
  266. let value = convert(st_b >> 1)
  267. + if st_b & 1 == 0 {
  268. V::zero()
  269. } else {
  270. key.correction_word_np1
  271. };
  272. if key.party_id == 0 {
  273. value
  274. } else {
  275. V::zero() - value
  276. }
  277. }
  278. fn evaluate_domain(key: &Self::Key) -> Vec<V> {
  279. assert!(key.domain_size > 0);
  280. if key.domain_size == 1 {
  281. // beta is simply secret-shared
  282. return vec![key.correction_word_np1];
  283. }
  284. let fkaes = FixedKeyAes::new(Self::FIXED_KEY_AES_KEY);
  285. let hash = |x: u128| fkaes.hash_ccr(Self::HASH_KEY ^ x);
  286. let convert = |x: u128| -> V { PRConverter::convert(x) };
  287. let tree_height = (key.domain_size as f64).log2().ceil() as usize;
  288. let last_index = key.domain_size - 1;
  289. let mut seeds = vec![0u128; key.domain_size];
  290. seeds[0] = key.party_seed;
  291. // since the last layer is handled separately, we only need the following block if we have
  292. // more than one layer
  293. if tree_height > 1 {
  294. // iterate over the tree layer by layer
  295. for i in 0..(tree_height - 1) {
  296. // expand each node in this layer;
  297. // we need to iterate from right to left, since we reuse the same buffer
  298. for j in (0..(last_index >> (tree_height - i)) + 1).rev() {
  299. // for j in (0..(1 << i)).rev() {
  300. let st = seeds[j];
  301. let st_0 = hash(st) ^ ((st & 1) * key.correction_words[i]);
  302. let st_1 = hash(st) ^ st ^ ((st & 1) * key.correction_words[i]);
  303. seeds[2 * j] = st_0;
  304. seeds[2 * j + 1] = st_1;
  305. }
  306. }
  307. }
  308. // expand last layer
  309. {
  310. // handle the last expansion separately, since we might not need both outputs
  311. let j = last_index >> 1;
  312. let st = seeds[j];
  313. let st_0 = hash(st) ^ ((st & 1) * (key.hcw | key.lcw[0] as u128));
  314. seeds[2 * j] = st_0;
  315. // check if we need both outputs
  316. if key.domain_size & 1 == 0 {
  317. let st_1 = hash(st ^ 1) ^ ((st & 1) * (key.hcw | key.lcw[1] as u128));
  318. seeds[2 * j + 1] = st_1;
  319. }
  320. // handle the other expansions as usual
  321. for j in (0..(last_index >> 1)).rev() {
  322. let st = seeds[j];
  323. let st_0 = hash(st) ^ ((st & 1) * (key.hcw | key.lcw[0] as u128));
  324. let st_1 = hash(st ^ 1) ^ ((st & 1) * (key.hcw | key.lcw[1] as u128));
  325. seeds[2 * j] = st_0;
  326. seeds[2 * j + 1] = st_1;
  327. }
  328. }
  329. // convert leaves into V elements
  330. if key.party_id == 0 {
  331. seeds
  332. .iter()
  333. .map(|st_b| {
  334. let mut tmp = convert(st_b >> 1);
  335. if st_b & 1 == 1 {
  336. tmp = tmp + key.correction_word_np1;
  337. }
  338. tmp
  339. })
  340. .collect()
  341. } else {
  342. seeds
  343. .iter()
  344. .map(|st_b| {
  345. let mut tmp = convert(st_b >> 1);
  346. if st_b & 1 == 1 {
  347. tmp = tmp + key.correction_word_np1;
  348. }
  349. -tmp
  350. })
  351. .collect()
  352. }
  353. }
  354. }
  355. #[cfg(test)]
  356. mod tests {
  357. use super::*;
  358. use core::num::Wrapping;
  359. use rand::distributions::{Distribution, Standard};
  360. use rand::{thread_rng, Rng};
  361. fn test_spdpf_with_param<SPDPF: SinglePointDpf>(domain_size: usize, alpha: Option<u64>)
  362. where
  363. Standard: Distribution<SPDPF::Value>,
  364. {
  365. let alpha = if alpha.is_some() {
  366. alpha.unwrap()
  367. } else {
  368. thread_rng().gen_range(0..domain_size as u64)
  369. };
  370. let beta = thread_rng().gen();
  371. let (key_0, key_1) = SPDPF::generate_keys(domain_size, alpha, beta);
  372. let out_0 = SPDPF::evaluate_domain(&key_0);
  373. let out_1 = SPDPF::evaluate_domain(&key_1);
  374. assert_eq!(out_0.len(), domain_size);
  375. assert_eq!(out_1.len(), domain_size);
  376. for i in 0..domain_size as u64 {
  377. let value = SPDPF::evaluate_at(&key_0, i) + SPDPF::evaluate_at(&key_1, i);
  378. assert_eq!(
  379. value,
  380. out_0[i as usize] + out_1[i as usize],
  381. "evaluate_at/domain mismatch at position {i}"
  382. );
  383. if i == alpha {
  384. assert_eq!(
  385. value, beta,
  386. "incorrect value != beta at position alpha = {i}"
  387. );
  388. } else {
  389. assert_eq!(
  390. value,
  391. SPDPF::Value::zero(),
  392. "incorrect value != 0 at position {i}"
  393. );
  394. }
  395. }
  396. }
  397. #[test]
  398. fn test_spdpf_dummy() {
  399. for log_domain_size in 0..10 {
  400. test_spdpf_with_param::<DummySpDpf<u64>>(1 << log_domain_size, None);
  401. }
  402. }
  403. #[test]
  404. fn test_spdpf_half_tree_power_of_two_domain() {
  405. for log_domain_size in 0..10 {
  406. test_spdpf_with_param::<HalfTreeSpDpf<Wrapping<u64>>>(1 << log_domain_size, None);
  407. }
  408. }
  409. #[test]
  410. fn test_spdpf_half_tree_random_domain() {
  411. for _ in 0..10 {
  412. let domain_size = thread_rng().gen_range(1..(1 << 10));
  413. test_spdpf_with_param::<HalfTreeSpDpf<Wrapping<u64>>>(domain_size, None);
  414. }
  415. }
  416. #[test]
  417. fn test_spdpf_half_tree_exhaustive_params() {
  418. for domain_size in 1..=32 {
  419. for alpha in 0..domain_size as u64 {
  420. test_spdpf_with_param::<HalfTreeSpDpf<Wrapping<u64>>>(domain_size, Some(alpha));
  421. }
  422. }
  423. }
  424. }