spdpf.rs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. use bincode;
  2. use core::fmt::Debug;
  3. use core::marker::PhantomData;
  4. use core::ops::{Add, Neg, Sub};
  5. use num::traits::Zero;
  6. use rand::{thread_rng, Rng};
  7. use utils::bit_decompose::bit_decompose;
  8. use utils::fixed_key_aes::FixedKeyAes;
  9. use utils::pseudorandom_conversion::{PRConvertTo, PRConverter};
  /// Key material held by one party of a two-party single-point DPF.
  ///
  /// A key records the domain `[0, domain_size)` it was generated for and the party that owns it.
  pub trait SinglePointDpfKey: Clone + Debug {
      /// Number of points in the DPF's domain.
      fn get_domain_size(&self) -> usize;

      /// Identifier of the party holding this key (the schemes in this file use `0` and `1`).
      fn get_party_id(&self) -> usize;
  }
  14. pub trait SinglePointDpf {
  15. type Key: SinglePointDpfKey;
  16. type Value: Add<Output = Self::Value> + Copy + Debug + Eq + Zero;
  17. fn generate_keys(domain_size: usize, alpha: u64, beta: Self::Value) -> (Self::Key, Self::Key);
  18. fn evaluate_at(key: &Self::Key, index: u64) -> Self::Value;
  19. fn evaluate_domain(key: &Self::Key) -> Vec<Self::Value> {
  20. (0..key.get_domain_size())
  21. .map(|x| Self::evaluate_at(&key, x as u64))
  22. .collect()
  23. }
  24. }
  25. #[derive(Clone, Copy, Debug, bincode::Encode, bincode::Decode)]
  26. pub struct DummySpDpfKey<V: Copy + Debug> {
  27. party_id: usize,
  28. domain_size: usize,
  29. alpha: u64,
  30. beta: V,
  31. }
  32. impl<V> SinglePointDpfKey for DummySpDpfKey<V>
  33. where
  34. V: Copy + Debug,
  35. {
  36. fn get_party_id(&self) -> usize {
  37. self.party_id
  38. }
  39. fn get_domain_size(&self) -> usize {
  40. self.domain_size
  41. }
  42. }
  43. pub struct DummySpDpf<V>
  44. where
  45. V: Add<Output = V> + Copy + Debug + Eq + Zero,
  46. {
  47. phantom: PhantomData<V>,
  48. }
  49. impl<V> SinglePointDpf for DummySpDpf<V>
  50. where
  51. V: Add<Output = V> + Copy + Debug + Eq + Zero,
  52. {
  53. type Key = DummySpDpfKey<V>;
  54. type Value = V;
  55. fn generate_keys(domain_size: usize, alpha: u64, beta: V) -> (Self::Key, Self::Key) {
  56. assert!(alpha < domain_size as u64);
  57. (
  58. DummySpDpfKey {
  59. party_id: 0,
  60. domain_size,
  61. alpha,
  62. beta,
  63. },
  64. DummySpDpfKey {
  65. party_id: 1,
  66. domain_size,
  67. alpha,
  68. beta,
  69. },
  70. )
  71. }
  72. fn evaluate_at(key: &Self::Key, index: u64) -> V {
  73. if key.get_party_id() == 0 && index == key.alpha {
  74. key.beta
  75. } else {
  76. V::zero()
  77. }
  78. }
  79. fn evaluate_domain(key: &Self::Key) -> Vec<Self::Value> {
  80. let mut output = vec![V::zero(); key.domain_size];
  81. if key.get_party_id() == 0 {
  82. output[key.alpha as usize] = key.beta;
  83. }
  84. output
  85. }
  86. }
  87. /// Implementation of the Half-Tree DPF scheme from Guo et al. (ePrint 2022/1431, Figure 8)
  88. #[derive(Clone, Debug, bincode::Encode, bincode::Decode)]
  89. pub struct HalfTreeSpDpfKey<V: Copy + Debug> {
  90. /// party id b
  91. party_id: usize,
  92. /// size n of the DPF's domain [n]
  93. domain_size: usize,
  94. /// (s_b^0 || t_b^0) and t_b^0 is the LSB
  95. party_seed: u128,
  96. /// vector of length n: CW_1, ..., CW_(n-1)
  97. correction_words: Vec<u128>,
  98. /// high part of CW_n = (HCW, [LCW[0], LCW[1]])
  99. hcw: u128,
  100. /// low parts of CW_n = (HCW, [LCW[0], LCW[1]])
  101. lcw: [bool; 2],
  102. /// CW_(n+1)
  103. correction_word_np1: V,
  104. }
  105. impl<V> SinglePointDpfKey for HalfTreeSpDpfKey<V>
  106. where
  107. V: Copy + Debug,
  108. {
  109. fn get_party_id(&self) -> usize {
  110. self.party_id
  111. }
  112. fn get_domain_size(&self) -> usize {
  113. self.domain_size
  114. }
  115. }
  116. pub struct HalfTreeSpDpf<V>
  117. where
  118. V: Add<Output = V> + Copy + Debug + Eq + Zero,
  119. {
  120. phantom: PhantomData<V>,
  121. }
  122. impl<V> HalfTreeSpDpf<V>
  123. where
  124. V: Add<Output = V> + Sub<Output = V> + Copy + Debug + Eq + Zero,
  125. {
  126. const FIXED_KEY_AES_KEY: [u8; 16] =
  127. 0xdead_beef_1337_4247_dead_beef_1337_4247_u128.to_le_bytes();
  128. const HASH_KEY: u128 = 0xc000ffee_c0ffffee_c0ffeeee_c00ffeee_u128;
  129. }
  130. impl<V> SinglePointDpf for HalfTreeSpDpf<V>
  131. where
  132. V: Add<Output = V> + Sub<Output = V> + Neg<Output = V> + Copy + Debug + Eq + Zero,
  133. PRConverter: PRConvertTo<V>,
  134. {
  135. type Key = HalfTreeSpDpfKey<V>;
  136. type Value = V;
  137. fn generate_keys(domain_size: usize, alpha: u64, beta: V) -> (Self::Key, Self::Key) {
  138. assert!(alpha < domain_size as u64);
  139. let mut rng = thread_rng();
  140. if domain_size == 1 {
  141. // simply secret-share beta
  142. let beta_0: V = PRConverter::convert(rng.gen::<u128>());
  143. let beta_1: V = beta - beta_0;
  144. return (
  145. HalfTreeSpDpfKey {
  146. party_id: 0,
  147. domain_size,
  148. party_seed: Default::default(),
  149. correction_words: Default::default(),
  150. hcw: Default::default(),
  151. lcw: Default::default(),
  152. correction_word_np1: beta_0,
  153. },
  154. HalfTreeSpDpfKey {
  155. party_id: 1,
  156. domain_size,
  157. party_seed: Default::default(),
  158. correction_words: Default::default(),
  159. hcw: Default::default(),
  160. lcw: Default::default(),
  161. correction_word_np1: beta_1,
  162. },
  163. );
  164. }
  165. let fkaes = FixedKeyAes::new(Self::FIXED_KEY_AES_KEY);
  166. let hash = |x: u128| fkaes.hash_ccr(Self::HASH_KEY ^ x);
  167. let convert = |x: u128| -> V { PRConverter::convert(x) };
  168. let tree_height = (domain_size as f64).log2().ceil() as usize;
  169. let alpha_bits: Vec<bool> = bit_decompose(alpha, tree_height);
  170. let delta = rng.gen::<u128>() | 1u128;
  171. let mut correction_words = Vec::<u128>::with_capacity(tree_height - 1);
  172. let mut st_0 = rng.gen::<u128>();
  173. let mut st_1 = st_0 ^ delta;
  174. let party_seeds = (st_0, st_1);
  175. for i in 0..(tree_height - 1) as usize {
  176. let cw_i = hash(st_0) ^ hash(st_1) ^ (1 - alpha_bits[i] as u128) * delta;
  177. st_0 = hash(st_0) ^ alpha_bits[i] as u128 * (st_0) ^ (st_0 & 1) * cw_i;
  178. st_1 = hash(st_1) ^ alpha_bits[i] as u128 * (st_1) ^ (st_1 & 1) * cw_i;
  179. correction_words.push(cw_i);
  180. }
  181. let high_low = [[hash(st_0), hash(st_0 ^ 1)], [hash(st_1), hash(st_1 ^ 1)]];
  182. const HIGH_MASK: u128 = u128::MAX - 1;
  183. const LOW_MASK: u128 = 1u128;
  184. let a_n = alpha_bits[tree_height - 1];
  185. let hcw = (high_low[0][1 - a_n as usize] ^ high_low[1][1 - a_n as usize]) & HIGH_MASK;
  186. let lcw = [
  187. ((high_low[0][0] ^ high_low[1][0] ^ (1 - a_n as u128)) & LOW_MASK) != 0,
  188. ((high_low[0][1] ^ high_low[1][1] ^ a_n as u128) & LOW_MASK) != 0,
  189. ];
  190. st_0 = high_low[0][a_n as usize] ^ (st_0 & 1) * (hcw | lcw[a_n as usize] as u128);
  191. st_1 = high_low[1][a_n as usize] ^ (st_1 & 1) * (hcw | lcw[a_n as usize] as u128);
  192. let correction_word_np1: V = match (st_0 & 1).wrapping_sub(st_1 & 1) {
  193. u128::MAX => convert(st_0 >> 1) - convert(st_1 >> 1) - beta,
  194. 0 => V::zero(),
  195. 1 => convert(st_1 >> 1) - convert(st_0 >> 1) + beta,
  196. _ => panic!("should not happend, since matching a difference of two bits"),
  197. };
  198. (
  199. HalfTreeSpDpfKey {
  200. party_id: 0,
  201. domain_size,
  202. party_seed: party_seeds.0,
  203. correction_words: correction_words.clone(),
  204. hcw,
  205. lcw,
  206. correction_word_np1,
  207. },
  208. HalfTreeSpDpfKey {
  209. party_id: 1,
  210. domain_size,
  211. party_seed: party_seeds.1,
  212. correction_words,
  213. hcw,
  214. lcw,
  215. correction_word_np1,
  216. },
  217. )
  218. }
  219. fn evaluate_at(key: &Self::Key, index: u64) -> V {
  220. assert!(key.domain_size > 0);
  221. assert!(index < key.domain_size as u64);
  222. if key.domain_size == 1 {
  223. // beta is simply secret-shared
  224. return key.correction_word_np1;
  225. }
  226. let fkaes = FixedKeyAes::new(Self::FIXED_KEY_AES_KEY);
  227. let hash = |x: u128| fkaes.hash_ccr(Self::HASH_KEY ^ x);
  228. let convert = |x: u128| -> V { PRConverter::convert(x) };
  229. let tree_height = (key.domain_size as f64).log2().ceil() as usize;
  230. let index_bits: Vec<bool> = bit_decompose(index, tree_height);
  231. let mut st_b = key.party_seed;
  232. for i in 0..tree_height - 1 {
  233. st_b =
  234. hash(st_b) ^ (index_bits[i] as u128 * st_b) ^ (st_b & 1) * key.correction_words[i];
  235. }
  236. let x_n = index_bits[tree_height - 1];
  237. let high_low_b_xn = hash(st_b ^ x_n as u128);
  238. st_b = high_low_b_xn ^ (st_b & 1) * (key.hcw | key.lcw[x_n as usize] as u128);
  239. let value = convert(st_b >> 1)
  240. + if st_b & 1 == 0 {
  241. V::zero()
  242. } else {
  243. key.correction_word_np1
  244. };
  245. if key.party_id == 0 {
  246. value
  247. } else {
  248. V::zero() - value
  249. }
  250. }
  251. fn evaluate_domain(key: &Self::Key) -> Vec<V> {
  252. assert!(key.domain_size > 0);
  253. if key.domain_size == 1 {
  254. // beta is simply secret-shared
  255. return vec![key.correction_word_np1];
  256. }
  257. let fkaes = FixedKeyAes::new(Self::FIXED_KEY_AES_KEY);
  258. let hash = |x: u128| fkaes.hash_ccr(Self::HASH_KEY ^ x);
  259. let convert = |x: u128| -> V { PRConverter::convert(x) };
  260. let tree_height = (key.domain_size as f64).log2().ceil() as usize;
  261. let last_index = key.domain_size - 1;
  262. let mut seeds = vec![0u128; key.domain_size];
  263. seeds[0] = key.party_seed;
  264. // since the last layer is handled separately, we only need the following block if we have
  265. // more than one layer
  266. if tree_height > 1 {
  267. // iterate over the tree layer by layer
  268. for i in 0..(tree_height - 1) {
  269. // expand each node in this layer;
  270. // we need to iterate from right to left, since we reuse the same buffer
  271. for j in (0..(last_index >> (tree_height - i)) + 1).rev() {
  272. // for j in (0..(1 << i)).rev() {
  273. let st = seeds[j];
  274. let st_0 = hash(st) ^ (st & 1) * key.correction_words[i];
  275. let st_1 = hash(st) ^ st ^ (st & 1) * key.correction_words[i];
  276. seeds[2 * j] = st_0;
  277. seeds[2 * j + 1] = st_1;
  278. }
  279. }
  280. }
  281. // expand last layer
  282. {
  283. // handle the last expansion separately, since we might not need both outputs
  284. let j = last_index >> 1;
  285. let st = seeds[j];
  286. let st_0 = hash(st) ^ (st & 1) * (key.hcw | key.lcw[0] as u128);
  287. seeds[2 * j] = st_0;
  288. // check if we need both outputs
  289. if key.domain_size & 1 == 0 {
  290. let st_1 = hash(st ^ 1 as u128) ^ (st & 1) * (key.hcw | key.lcw[1] as u128);
  291. seeds[2 * j + 1] = st_1;
  292. }
  293. // handle the other expansions as usual
  294. for j in (0..(last_index >> 1)).rev() {
  295. let st = seeds[j];
  296. let st_0 = hash(st) ^ (st & 1) * (key.hcw | key.lcw[0] as u128);
  297. let st_1 = hash(st ^ 1 as u128) ^ (st & 1) * (key.hcw | key.lcw[1] as u128);
  298. seeds[2 * j] = st_0;
  299. seeds[2 * j + 1] = st_1;
  300. }
  301. }
  302. // convert leaves into V elements
  303. if key.party_id == 0 {
  304. seeds
  305. .iter()
  306. .map(|st_b| {
  307. let mut tmp = convert(st_b >> 1);
  308. if st_b & 1 == 1 {
  309. tmp = tmp + key.correction_word_np1;
  310. }
  311. tmp
  312. })
  313. .collect()
  314. } else {
  315. seeds
  316. .iter()
  317. .map(|st_b| {
  318. let mut tmp = convert(st_b >> 1);
  319. if st_b & 1 == 1 {
  320. tmp = tmp + key.correction_word_np1;
  321. }
  322. -tmp
  323. })
  324. .collect()
  325. }
  326. }
  327. }
  #[cfg(test)]
  mod tests {
      use super::*;
      use core::num::Wrapping;
      use rand::distributions::{Distribution, Standard};
      use rand::{thread_rng, Rng};

      /// Generates keys for a random `beta` at `alpha` (drawn uniformly from the domain when
      /// `None`). Checks that `evaluate_at` and `evaluate_domain` agree and that the summed
      /// shares reconstruct the point function.
      fn test_spdpf_with_param<SPDPF: SinglePointDpf>(domain_size: usize, alpha: Option<u64>)
      where
          Standard: Distribution<SPDPF::Value>,
      {
          let alpha = alpha.unwrap_or_else(|| thread_rng().gen_range(0..domain_size as u64));
          let beta = thread_rng().gen();
          let (key_0, key_1) = SPDPF::generate_keys(domain_size, alpha, beta);
          let out_0 = SPDPF::evaluate_domain(&key_0);
          let out_1 = SPDPF::evaluate_domain(&key_1);
          assert_eq!(out_0.len(), domain_size);
          assert_eq!(out_1.len(), domain_size);
          for i in 0..domain_size as u64 {
              let value = SPDPF::evaluate_at(&key_0, i) + SPDPF::evaluate_at(&key_1, i);
              assert_eq!(
                  value,
                  out_0[i as usize] + out_1[i as usize],
                  "evaluate_at/domain mismatch at position {i}"
              );
              if i == alpha {
                  assert_eq!(
                      value, beta,
                      "incorrect value != beta at position alpha = {i}"
                  );
              } else {
                  assert_eq!(
                      value,
                      SPDPF::Value::zero(),
                      "incorrect value != 0 at position {i}"
                  );
              }
          }
      }

      #[test]
      fn test_spdpf_dummy() {
          for log_domain_size in 0..10 {
              test_spdpf_with_param::<DummySpDpf<u64>>(1 << log_domain_size, None);
          }
      }

      #[test]
      fn test_spdpf_half_tree_power_of_two_domain() {
          for log_domain_size in 0..10 {
              test_spdpf_with_param::<HalfTreeSpDpf<Wrapping<u64>>>(1 << log_domain_size, None);
          }
      }

      #[test]
      fn test_spdpf_half_tree_random_domain() {
          for _ in 0..10 {
              let domain_size = thread_rng().gen_range(1..(1 << 10));
              test_spdpf_with_param::<HalfTreeSpDpf<Wrapping<u64>>>(domain_size, None);
          }
      }

      #[test]
      fn test_spdpf_half_tree_exhaustive_params() {
          for domain_size in 1..=32 {
              for alpha in 0..domain_size as u64 {
                  test_spdpf_with_param::<HalfTreeSpDpf<Wrapping<u64>>>(domain_size, Some(alpha));
              }
          }
      }
  }