doprf.rs 64 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908
  1. //! Implementation of the Legendre PRF and protocols for the distributed oblivious PRF evaluation.
  2. //!
  3. //! Contains the unmasked and the masked variants.
  4. use crate::common::Error;
  5. use bincode;
  6. use bitvec;
  7. use communicator::{AbstractCommunicator, Fut, Serializable};
  8. use core::marker::PhantomData;
  9. use funty::Unsigned;
  10. use itertools::izip;
  11. use rand::{thread_rng, Rng, RngCore, SeedableRng};
  12. use rand_chacha::ChaChaRng;
  13. use std::iter::repeat;
  14. use utils::field::LegendreSymbol;
  15. /// Bit vector.
  16. pub type BitVec = bitvec::vec::BitVec<u8>;
  17. type BitSlice = bitvec::slice::BitSlice<u8>;
  18. /// Key for a [`LegendrePrf`].
  19. #[derive(Clone, Debug, Eq, PartialEq, bincode::Encode, bincode::Decode)]
  20. pub struct LegendrePrfKey<F: LegendreSymbol> {
  21. /// Keys for each output bit.
  22. pub keys: Vec<F>,
  23. }
  24. impl<F: LegendreSymbol> LegendrePrfKey<F> {
  25. /// Return the number of bits output by the PRF.
  26. pub fn get_output_bitsize(&self) -> usize {
  27. self.keys.len()
  28. }
  29. }
  30. /// Multi-bit Legendre PRF: `F x F -> {0,1}^k`.
  31. pub struct LegendrePrf<F> {
  32. _phantom: PhantomData<F>,
  33. }
  34. impl<F: LegendreSymbol> LegendrePrf<F> {
  35. /// Generate a Legendre PRF key for the given number of output bits.
  36. pub fn key_gen(output_bitsize: usize) -> LegendrePrfKey<F> {
  37. LegendrePrfKey {
  38. keys: (0..output_bitsize)
  39. .map(|_| F::random(thread_rng()))
  40. .collect(),
  41. }
  42. }
  43. /// Evaluate the PRF to obtain an iterator of bits.
  44. pub fn eval(key: &LegendrePrfKey<F>, input: F) -> impl Iterator<Item = bool> + '_ {
  45. key.keys.iter().map(move |&k| {
  46. let ls = F::legendre_symbol(k + input);
  47. debug_assert!(ls != 0, "unlikely");
  48. ls == 1
  49. })
  50. }
  51. /// Evaluate the PRF to obtain a bit vector.
  52. pub fn eval_bits(key: &LegendrePrfKey<F>, input: F) -> BitVec {
  53. let mut output = BitVec::with_capacity(key.keys.len());
  54. output.extend(Self::eval(key, input));
  55. output
  56. }
  57. /// Evaluate the PRF to obtain an integer.
  58. pub fn eval_to_uint<T: Unsigned>(key: &LegendrePrfKey<F>, input: F) -> T {
  59. assert!(key.keys.len() <= T::BITS as usize);
  60. let mut output = T::ZERO;
  61. for (i, b) in Self::eval(key, input).enumerate() {
  62. if b {
  63. output |= T::ONE << i;
  64. }
  65. }
  66. output
  67. }
  68. }
  69. fn to_uint<T: Unsigned>(vs: impl IntoIterator<Item = impl IntoIterator<Item = bool>>) -> Vec<T> {
  70. vs.into_iter()
  71. .map(|v| {
  72. let mut output = T::ZERO;
  73. for (i, b) in v.into_iter().enumerate() {
  74. if b {
  75. output |= T::ONE << i;
  76. }
  77. }
  78. output
  79. })
  80. .collect()
  81. }
  /// 32-byte seed exchanged between two parties to key a shared `ChaChaRng`.
  type SharedSeed = [u8; 32];
  83. /// Party 1 of the *unmasked* DOPRF protocol.
  84. pub struct DOPrfParty1<F: LegendreSymbol> {
  85. _phantom: PhantomData<F>,
  86. output_bitsize: usize,
  87. shared_prg_1_2: Option<ChaChaRng>,
  88. shared_prg_1_3: Option<ChaChaRng>,
  89. legendre_prf_key: Option<LegendrePrfKey<F>>,
  90. is_initialized: bool,
  91. num_preprocessed_invocations: usize,
  92. preprocessed_squares: Vec<F>,
  93. preprocessed_mt_c1: Vec<F>,
  94. }
  95. impl<F> DOPrfParty1<F>
  96. where
  97. F: LegendreSymbol,
  98. {
  99. /// Create a new instance with the given `output_bitsize`.
  100. pub fn new(output_bitsize: usize) -> Self {
  101. assert!(output_bitsize > 0);
  102. Self {
  103. _phantom: PhantomData,
  104. output_bitsize,
  105. shared_prg_1_2: None,
  106. shared_prg_1_3: None,
  107. legendre_prf_key: None,
  108. is_initialized: false,
  109. num_preprocessed_invocations: 0,
  110. preprocessed_squares: Default::default(),
  111. preprocessed_mt_c1: Default::default(),
  112. }
  113. }
  114. /// Create an instance from an existing Legendre PRF key.
  115. pub fn from_legendre_prf_key(legendre_prf_key: LegendrePrfKey<F>) -> Self {
  116. let mut new = Self::new(legendre_prf_key.keys.len());
  117. new.legendre_prf_key = Some(legendre_prf_key);
  118. new
  119. }
  120. /// Reset this instance.
  121. pub fn reset(&mut self) {
  122. *self = Self::new(self.output_bitsize)
  123. }
  124. /// Delete all preprocessed data.
  125. pub fn reset_preprocessing(&mut self) {
  126. self.num_preprocessed_invocations = 0;
  127. self.preprocessed_squares = Default::default();
  128. self.preprocessed_mt_c1 = Default::default();
  129. }
  130. /// Step 0 of the initialization protocol.
  131. pub fn init_round_0(&mut self) -> (SharedSeed, ()) {
  132. assert!(!self.is_initialized);
  133. // sample and share a PRF key with Party 2
  134. self.shared_prg_1_2 = Some(ChaChaRng::from_seed(thread_rng().gen()));
  135. (self.shared_prg_1_2.as_ref().unwrap().get_seed(), ())
  136. }
  137. /// Step 1 of the initialization protocol.
  138. pub fn init_round_1(&mut self, _: (), shared_prg_seed_1_3: SharedSeed) {
  139. assert!(!self.is_initialized);
  140. // receive shared PRF key from Party 3
  141. self.shared_prg_1_3 = Some(ChaChaRng::from_seed(shared_prg_seed_1_3));
  142. if self.legendre_prf_key.is_none() {
  143. // generate Legendre PRF key
  144. self.legendre_prf_key = Some(LegendrePrf::key_gen(self.output_bitsize));
  145. }
  146. self.is_initialized = true;
  147. }
  148. /// Run the initialization protocol.
  149. pub fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
  150. let fut_3_1 = comm.receive_previous()?;
  151. let (msg_1_2, _) = self.init_round_0();
  152. comm.send_next(msg_1_2)?;
  153. self.init_round_1((), fut_3_1.get()?);
  154. Ok(())
  155. }
  156. /// Return the Legendre PRF key.
  157. pub fn get_legendre_prf_key(&self) -> LegendrePrfKey<F> {
  158. assert!(self.legendre_prf_key.is_some());
  159. self.legendre_prf_key.as_ref().unwrap().clone()
  160. }
  161. /// Set the Legendre PRF key.
  162. ///
  163. /// Can only done before the initialization protocol is run.
  164. pub fn set_legendre_prf_key(&mut self, legendre_prf_key: LegendrePrfKey<F>) {
  165. assert!(!self.is_initialized);
  166. self.legendre_prf_key = Some(legendre_prf_key);
  167. }
  168. /// Step 0 of the preprocessing protocol.
  169. pub fn preprocess_round_0(&mut self, num: usize) -> ((), ()) {
  170. assert!(self.is_initialized);
  171. let n = num * self.output_bitsize;
  172. self.preprocessed_squares
  173. .extend((0..n).map(|_| F::random(self.shared_prg_1_2.as_mut().unwrap()).square()));
  174. ((), ())
  175. }
  176. /// Step 1 of the preprocessing protocol.
  177. pub fn preprocess_round_1(&mut self, num: usize, preprocessed_mt_c1: Vec<F>, _: ()) {
  178. assert!(self.is_initialized);
  179. let n = num * self.output_bitsize;
  180. assert_eq!(preprocessed_mt_c1.len(), n);
  181. self.preprocessed_mt_c1.extend(preprocessed_mt_c1);
  182. self.num_preprocessed_invocations += num;
  183. }
  184. /// Run the preprocessing protocol for `num` evaluations.
  185. pub fn preprocess<C: AbstractCommunicator>(
  186. &mut self,
  187. comm: &mut C,
  188. num: usize,
  189. ) -> Result<(), Error>
  190. where
  191. F: Serializable,
  192. {
  193. let fut_2_1 = comm.receive_next()?;
  194. self.preprocess_round_0(num);
  195. self.preprocess_round_1(num, fut_2_1.get()?, ());
  196. Ok(())
  197. }
  198. /// Return the number of preprocessed invocations available.
  199. pub fn get_num_preprocessed_invocations(&self) -> usize {
  200. self.num_preprocessed_invocations
  201. }
  202. /// Return the preprocessed data.
  203. #[doc(hidden)]
  204. pub fn get_preprocessed_data(&self) -> (&[F], &[F]) {
  205. (&self.preprocessed_squares, &self.preprocessed_mt_c1)
  206. }
  207. /// Perform some self-test of the preprocessed data.
  208. #[doc(hidden)]
  209. pub fn check_preprocessing(&self) {
  210. let num = self.num_preprocessed_invocations;
  211. let n = num * self.output_bitsize;
  212. assert_eq!(self.preprocessed_squares.len(), n);
  213. assert_eq!(self.preprocessed_mt_c1.len(), n);
  214. }
  215. /// Step 1 of the evaluation protocol.
  216. pub fn eval_round_1(
  217. &mut self,
  218. num: usize,
  219. shares1: &[F],
  220. masked_shares2: &[F],
  221. mult_e: &[F],
  222. ) -> ((), Vec<F>) {
  223. assert!(num <= self.num_preprocessed_invocations);
  224. let n = num * self.output_bitsize;
  225. assert_eq!(shares1.len(), num);
  226. assert_eq!(masked_shares2.len(), num);
  227. assert_eq!(mult_e.len(), num);
  228. let k = &self.legendre_prf_key.as_ref().unwrap().keys;
  229. assert_eq!(k.len(), self.output_bitsize);
  230. let output_shares_z1: Vec<F> = izip!(
  231. shares1
  232. .iter()
  233. .flat_map(|s1i| repeat(s1i).take(self.output_bitsize)),
  234. masked_shares2
  235. .iter()
  236. .flat_map(|ms2i| repeat(ms2i).take(self.output_bitsize)),
  237. k.iter().cycle(),
  238. self.preprocessed_squares.drain(0..n),
  239. self.preprocessed_mt_c1.drain(0..n),
  240. mult_e
  241. .iter()
  242. .flat_map(|e| repeat(e).take(self.output_bitsize)),
  243. )
  244. .map(|(&s1_i, &ms2_i, &k_j, sq_ij, c1_ij, &e_ij)| {
  245. sq_ij * (k_j + s1_i + ms2_i) + e_ij * sq_ij + c1_ij
  246. })
  247. .collect();
  248. self.num_preprocessed_invocations -= num;
  249. ((), output_shares_z1)
  250. }
  251. /// Run the evaluation protocol.
  252. pub fn eval<C: AbstractCommunicator>(
  253. &mut self,
  254. comm: &mut C,
  255. num: usize,
  256. shares1: &[F],
  257. ) -> Result<(), Error>
  258. where
  259. F: Serializable,
  260. {
  261. assert_eq!(shares1.len(), num);
  262. let fut_2_1 = comm.receive_next::<Vec<_>>()?;
  263. let fut_3_1 = comm.receive_previous::<Vec<_>>()?;
  264. let (_, msg_1_3) = self.eval_round_1(num, shares1, &fut_2_1.get()?, &fut_3_1.get()?);
  265. comm.send_previous(msg_1_3)?;
  266. Ok(())
  267. }
  268. }
  269. /// Party 2 of the *unmasked* DOPRF protocol.
  270. pub struct DOPrfParty2<F: LegendreSymbol> {
  271. _phantom: PhantomData<F>,
  272. output_bitsize: usize,
  273. shared_prg_1_2: Option<ChaChaRng>,
  274. shared_prg_2_3: Option<ChaChaRng>,
  275. is_initialized: bool,
  276. num_preprocessed_invocations: usize,
  277. preprocessed_rerand_m2: Vec<F>,
  278. }
  279. impl<F> DOPrfParty2<F>
  280. where
  281. F: LegendreSymbol,
  282. {
  283. /// Create a new instance with the given `output_bitsize`.
  284. pub fn new(output_bitsize: usize) -> Self {
  285. assert!(output_bitsize > 0);
  286. Self {
  287. _phantom: PhantomData,
  288. output_bitsize,
  289. shared_prg_1_2: None,
  290. shared_prg_2_3: None,
  291. is_initialized: false,
  292. num_preprocessed_invocations: 0,
  293. preprocessed_rerand_m2: Default::default(),
  294. }
  295. }
  296. /// Reset this instance.
  297. pub fn reset(&mut self) {
  298. *self = Self::new(self.output_bitsize)
  299. }
  300. /// Delete all preprocessed data.
  301. pub fn reset_preprocessing(&mut self) {
  302. self.num_preprocessed_invocations = 0;
  303. self.preprocessed_rerand_m2 = Default::default();
  304. }
  305. /// Step 0 of the initialization protocol.
  306. pub fn init_round_0(&mut self) -> ((), SharedSeed) {
  307. assert!(!self.is_initialized);
  308. self.shared_prg_2_3 = Some(ChaChaRng::from_seed(thread_rng().gen()));
  309. ((), self.shared_prg_2_3.as_ref().unwrap().get_seed())
  310. }
  311. /// Step 1 of the initialization protocol.
  312. pub fn init_round_1(&mut self, shared_prg_seed_1_2: SharedSeed, _: ()) {
  313. assert!(!self.is_initialized);
  314. // receive shared PRF key from Party 1
  315. self.shared_prg_1_2 = Some(ChaChaRng::from_seed(shared_prg_seed_1_2));
  316. self.is_initialized = true;
  317. }
  318. /// Run the initialization protocol.
  319. pub fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
  320. let fut_1_2 = comm.receive_previous()?;
  321. let (_, msg_2_3) = self.init_round_0();
  322. comm.send_next(msg_2_3)?;
  323. self.init_round_1(fut_1_2.get()?, ());
  324. Ok(())
  325. }
  326. /// Step 0 of the preprocessing protocol.
  327. pub fn preprocess_round_0(&mut self, num: usize) -> (Vec<F>, ()) {
  328. assert!(self.is_initialized);
  329. let n = num * self.output_bitsize;
  330. let preprocessed_squares: Vec<F> = (0..n)
  331. .map(|_| F::random(self.shared_prg_1_2.as_mut().unwrap()).square())
  332. .collect();
  333. self.preprocessed_rerand_m2
  334. .extend((0..num).map(|_| F::random(self.shared_prg_2_3.as_mut().unwrap())));
  335. let preprocessed_mult_d: Vec<F> = (0..n)
  336. .map(|_| F::random(self.shared_prg_2_3.as_mut().unwrap()))
  337. .collect();
  338. let preprocessed_mt_b: Vec<F> = (0..num)
  339. .map(|_| F::random(self.shared_prg_2_3.as_mut().unwrap()))
  340. .collect();
  341. let preprocessed_mt_c3: Vec<F> = (0..n)
  342. .map(|_| F::random(self.shared_prg_2_3.as_mut().unwrap()))
  343. .collect();
  344. let preprocessed_c1: Vec<F> = izip!(
  345. preprocessed_squares.iter(),
  346. preprocessed_mult_d.iter(),
  347. preprocessed_mt_b
  348. .iter()
  349. .flat_map(|b| repeat(b).take(self.output_bitsize)),
  350. preprocessed_mt_c3.iter(),
  351. )
  352. .map(|(&s, &d, &b, &c3)| (s - d) * b - c3)
  353. .collect();
  354. self.num_preprocessed_invocations += num;
  355. (preprocessed_c1, ())
  356. }
  357. /// Step 1 of the preprocessing protocol.
  358. pub fn preprocess_round_1(&mut self, _: usize, _: (), _: ()) {
  359. assert!(self.is_initialized);
  360. }
  361. /// Run the preprocessing protocol for `num` evaluations.
  362. pub fn preprocess<C: AbstractCommunicator>(
  363. &mut self,
  364. comm: &mut C,
  365. num: usize,
  366. ) -> Result<(), Error>
  367. where
  368. F: Serializable,
  369. {
  370. let (msg_2_1, _) = self.preprocess_round_0(num);
  371. comm.send_previous(msg_2_1)?;
  372. self.preprocess_round_1(num, (), ());
  373. Ok(())
  374. }
  375. /// Return the number of preprocessed invocations available.
  376. pub fn get_num_preprocessed_invocations(&self) -> usize {
  377. self.num_preprocessed_invocations
  378. }
  379. /// Return the preprocessed data.
  380. #[doc(hidden)]
  381. pub fn get_preprocessed_data(&self) -> &[F] {
  382. &self.preprocessed_rerand_m2
  383. }
  384. /// Perform some self-test of the preprocessed data.
  385. #[doc(hidden)]
  386. pub fn check_preprocessing(&self) {
  387. let num = self.num_preprocessed_invocations;
  388. assert_eq!(self.preprocessed_rerand_m2.len(), num);
  389. }
  390. /// Step 0 of the evaluation protocol.
  391. pub fn eval_round_0(&mut self, num: usize, shares2: &[F]) -> (Vec<F>, ()) {
  392. assert!(num <= self.num_preprocessed_invocations);
  393. assert_eq!(shares2.len(), num);
  394. let masked_shares2: Vec<F> =
  395. izip!(shares2.iter(), self.preprocessed_rerand_m2.drain(0..num),)
  396. .map(|(&s2i, m2i)| s2i + m2i)
  397. .collect();
  398. self.num_preprocessed_invocations -= num;
  399. (masked_shares2, ())
  400. }
  401. /// Run the evaluation protocol.
  402. pub fn eval<C: AbstractCommunicator>(
  403. &mut self,
  404. comm: &mut C,
  405. num: usize,
  406. shares2: &[F],
  407. ) -> Result<(), Error>
  408. where
  409. F: Serializable,
  410. {
  411. assert_eq!(shares2.len(), num);
  412. let (msg_2_1, _) = self.eval_round_0(1, shares2);
  413. comm.send_previous(msg_2_1)?;
  414. Ok(())
  415. }
  416. }
  417. /// Party 3 of the *unmasked* DOPRF protocol.
  418. pub struct DOPrfParty3<F: LegendreSymbol> {
  419. _phantom: PhantomData<F>,
  420. output_bitsize: usize,
  421. shared_prg_1_3: Option<ChaChaRng>,
  422. shared_prg_2_3: Option<ChaChaRng>,
  423. is_initialized: bool,
  424. num_preprocessed_invocations: usize,
  425. preprocessed_rerand_m3: Vec<F>,
  426. preprocessed_mt_b: Vec<F>,
  427. preprocessed_mt_c3: Vec<F>,
  428. preprocessed_mult_d: Vec<F>,
  429. mult_e: Vec<F>,
  430. }
  431. impl<F> DOPrfParty3<F>
  432. where
  433. F: LegendreSymbol,
  434. {
  435. /// Create a new instance with the given `output_bitsize`.
  436. pub fn new(output_bitsize: usize) -> Self {
  437. assert!(output_bitsize > 0);
  438. Self {
  439. _phantom: PhantomData,
  440. output_bitsize,
  441. shared_prg_1_3: None,
  442. shared_prg_2_3: None,
  443. is_initialized: false,
  444. num_preprocessed_invocations: 0,
  445. preprocessed_rerand_m3: Default::default(),
  446. preprocessed_mt_b: Default::default(),
  447. preprocessed_mt_c3: Default::default(),
  448. preprocessed_mult_d: Default::default(),
  449. mult_e: Default::default(),
  450. }
  451. }
  452. /// Reset this instance.
  453. pub fn reset(&mut self) {
  454. *self = Self::new(self.output_bitsize)
  455. }
  456. /// Delete all preprocessed data.
  457. pub fn reset_preprocessing(&mut self) {
  458. self.num_preprocessed_invocations = 0;
  459. self.preprocessed_rerand_m3 = Default::default();
  460. self.preprocessed_mt_b = Default::default();
  461. self.preprocessed_mt_c3 = Default::default();
  462. self.preprocessed_mult_d = Default::default();
  463. self.mult_e = Default::default();
  464. }
  465. /// Step 0 of the initialization protocol.
  466. pub fn init_round_0(&mut self) -> (SharedSeed, ()) {
  467. assert!(!self.is_initialized);
  468. self.shared_prg_1_3 = Some(ChaChaRng::from_seed(thread_rng().gen()));
  469. (self.shared_prg_1_3.as_ref().unwrap().get_seed(), ())
  470. }
  471. /// Step 1 of the initialization protocol.
  472. pub fn init_round_1(&mut self, _: (), shared_prg_seed_2_3: SharedSeed) {
  473. self.shared_prg_2_3 = Some(ChaChaRng::from_seed(shared_prg_seed_2_3));
  474. self.is_initialized = true;
  475. }
  476. /// Run the initialization protocol.
  477. pub fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
  478. let fut_2_3 = comm.receive_previous()?;
  479. let (msg_3_1, _) = self.init_round_0();
  480. comm.send_next(msg_3_1)?;
  481. self.init_round_1((), fut_2_3.get()?);
  482. Ok(())
  483. }
  484. /// Step 0 of the preprocessing protocol.
  485. pub fn preprocess_round_0(&mut self, num: usize) -> ((), ()) {
  486. assert!(self.is_initialized);
  487. let n = num * self.output_bitsize;
  488. self.preprocessed_rerand_m3
  489. .extend((0..num).map(|_| -F::random(self.shared_prg_2_3.as_mut().unwrap())));
  490. self.preprocessed_mult_d
  491. .extend((0..n).map(|_| F::random(self.shared_prg_2_3.as_mut().unwrap())));
  492. self.preprocessed_mt_b
  493. .extend((0..num).map(|_| F::random(self.shared_prg_2_3.as_mut().unwrap())));
  494. self.preprocessed_mt_c3
  495. .extend((0..n).map(|_| F::random(self.shared_prg_2_3.as_mut().unwrap())));
  496. ((), ())
  497. }
  498. /// Step 1 of the preprocessing protocol.
  499. pub fn preprocess_round_1(&mut self, num: usize, _: (), _: ()) {
  500. assert!(self.is_initialized);
  501. self.num_preprocessed_invocations += num;
  502. }
  503. /// Run the preprocessing protocol for `num` evaluations.
  504. pub fn preprocess<C: AbstractCommunicator>(
  505. &mut self,
  506. _comm: &mut C,
  507. num: usize,
  508. ) -> Result<(), Error>
  509. where
  510. F: Serializable,
  511. {
  512. self.preprocess_round_0(num);
  513. self.preprocess_round_1(num, (), ());
  514. Ok(())
  515. }
  516. /// Return the number of preprocessed invocations available.
  517. pub fn get_num_preprocessed_invocations(&self) -> usize {
  518. self.num_preprocessed_invocations
  519. }
  520. /// Return the preprocessed data.
  521. #[doc(hidden)]
  522. pub fn get_preprocessed_data(&self) -> (&[F], &[F], &[F], &[F]) {
  523. (
  524. &self.preprocessed_rerand_m3,
  525. &self.preprocessed_mt_b,
  526. &self.preprocessed_mt_c3,
  527. &self.preprocessed_mult_d,
  528. )
  529. }
  530. /// Perform some self-test of the preprocessed data.
  531. #[doc(hidden)]
  532. pub fn check_preprocessing(&self) {
  533. let num = self.num_preprocessed_invocations;
  534. let n = num * self.output_bitsize;
  535. assert_eq!(self.preprocessed_rerand_m3.len(), num);
  536. assert_eq!(self.preprocessed_mt_b.len(), num);
  537. assert_eq!(self.preprocessed_mt_c3.len(), n);
  538. assert_eq!(self.preprocessed_mult_d.len(), n);
  539. }
  540. /// Step 0 of the evaluation protocol.
  541. pub fn eval_round_0(&mut self, num: usize, shares3: &[F]) -> (Vec<F>, ()) {
  542. assert!(num <= self.num_preprocessed_invocations);
  543. assert_eq!(shares3.len(), num);
  544. self.mult_e = izip!(
  545. shares3.iter(),
  546. &self.preprocessed_rerand_m3[0..num],
  547. self.preprocessed_mt_b.drain(0..num),
  548. )
  549. .map(|(&s3_i, m3_i, b_i)| s3_i + m3_i - b_i)
  550. .collect();
  551. (self.mult_e.clone(), ())
  552. }
  553. /// Step 2 of the evaluation protocol.
  554. pub fn eval_round_2(
  555. &mut self,
  556. num: usize,
  557. shares3: &[F],
  558. output_shares_z1: Vec<F>,
  559. _: (),
  560. ) -> Vec<BitVec> {
  561. assert!(num <= self.num_preprocessed_invocations);
  562. let n = num * self.output_bitsize;
  563. assert_eq!(shares3.len(), num);
  564. assert_eq!(output_shares_z1.len(), n);
  565. let lprf_inputs: Vec<F> = izip!(
  566. shares3
  567. .iter()
  568. .flat_map(|s3| repeat(s3).take(self.output_bitsize)),
  569. self.preprocessed_rerand_m3
  570. .drain(0..num)
  571. .flat_map(|m3| repeat(m3).take(self.output_bitsize)),
  572. self.preprocessed_mult_d.drain(0..n),
  573. self.mult_e
  574. .drain(0..num)
  575. .flat_map(|e| repeat(e).take(self.output_bitsize)),
  576. self.preprocessed_mt_c3.drain(0..n),
  577. output_shares_z1.iter(),
  578. )
  579. .map(|(&s3_i, m3_i, d_ij, e_i, c3_ij, &z1_ij)| {
  580. d_ij * (s3_i + m3_i) + c3_ij + z1_ij - d_ij * e_i
  581. })
  582. .collect();
  583. assert_eq!(lprf_inputs.len(), n);
  584. let output: Vec<BitVec> = lprf_inputs
  585. .chunks_exact(self.output_bitsize)
  586. .map(|chunk| {
  587. let mut bv = BitVec::with_capacity(self.output_bitsize);
  588. for &x in chunk.iter() {
  589. let ls = F::legendre_symbol(x);
  590. debug_assert!(ls != 0, "unlikely");
  591. bv.push(ls == 1);
  592. }
  593. bv
  594. })
  595. .collect();
  596. self.num_preprocessed_invocations -= num;
  597. output
  598. }
  599. /// Run the evaluation protocol to obtain bit vectors.
  600. pub fn eval<C: AbstractCommunicator>(
  601. &mut self,
  602. comm: &mut C,
  603. num: usize,
  604. shares3: &[F],
  605. ) -> Result<Vec<BitVec>, Error>
  606. where
  607. F: Serializable,
  608. {
  609. assert_eq!(shares3.len(), num);
  610. let fut_1_3 = comm.receive_next()?;
  611. let (msg_3_1, _) = self.eval_round_0(num, shares3);
  612. comm.send_next(msg_3_1)?;
  613. let output = self.eval_round_2(num, shares3, fut_1_3.get()?, ());
  614. Ok(output)
  615. }
  616. /// Run the evaluation protocol to obtain integers.
  617. pub fn eval_to_uint<C: AbstractCommunicator, T: Unsigned>(
  618. &mut self,
  619. comm: &mut C,
  620. num: usize,
  621. shares3: &[F],
  622. ) -> Result<Vec<T>, Error>
  623. where
  624. F: Serializable,
  625. {
  626. assert!(self.output_bitsize <= T::BITS as usize);
  627. Ok(to_uint(self.eval(comm, num, shares3)?))
  628. }
  629. }
  630. /// Combination of three instances of the *unmasked* DOPRF protocol such that this party acts as
  631. /// each role 1, 2, 3 in the different instances.
  632. pub struct JointDOPrf<F: LegendreSymbol> {
  633. output_bitsize: usize,
  634. doprf_p1_prev: DOPrfParty1<F>,
  635. doprf_p2_next: DOPrfParty2<F>,
  636. doprf_p3_mine: DOPrfParty3<F>,
  637. }
  638. impl<F: LegendreSymbol + Serializable> JointDOPrf<F> {
  639. /// Create a new instance with the given `output_bitsize`.
  640. pub fn new(output_bitsize: usize) -> Self {
  641. Self {
  642. output_bitsize,
  643. doprf_p1_prev: DOPrfParty1::new(output_bitsize),
  644. doprf_p2_next: DOPrfParty2::new(output_bitsize),
  645. doprf_p3_mine: DOPrfParty3::new(output_bitsize),
  646. }
  647. }
  648. /// Reset this instance.
  649. pub fn reset(&mut self) {
  650. *self = Self::new(self.output_bitsize);
  651. }
  652. /// Return the Legendre PRF key.
  653. pub fn get_legendre_prf_key_prev(&self) -> LegendrePrfKey<F> {
  654. self.doprf_p1_prev.get_legendre_prf_key()
  655. }
  656. /// Set the Legendre PRF key.
  657. ///
  658. /// Can only done before the initialization protocol is run.
  659. pub fn set_legendre_prf_key_prev(&mut self, legendre_prf_key: LegendrePrfKey<F>) {
  660. self.doprf_p1_prev.set_legendre_prf_key(legendre_prf_key)
  661. }
  662. /// Run the initialization protocol.
  663. pub fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
  664. let fut_prev = comm.receive_previous()?;
  665. let (msg_1_2, _) = self.doprf_p1_prev.init_round_0();
  666. let (_, msg_2_3) = self.doprf_p2_next.init_round_0();
  667. let (msg_3_1, _) = self.doprf_p3_mine.init_round_0();
  668. comm.send_next((msg_1_2, msg_2_3, msg_3_1))?;
  669. let (msg_1_2, msg_2_3, msg_3_1) = fut_prev.get()?;
  670. self.doprf_p1_prev.init_round_1((), msg_3_1);
  671. self.doprf_p2_next.init_round_1(msg_1_2, ());
  672. self.doprf_p3_mine.init_round_1((), msg_2_3);
  673. Ok(())
  674. }
  675. /// Run the preprocessing protocol for `num` evaluations.
  676. pub fn preprocess<C: AbstractCommunicator>(
  677. &mut self,
  678. comm: &mut C,
  679. num: usize,
  680. ) -> Result<(), Error> {
  681. let fut_2_1 = comm.receive_next()?;
  682. let (msg_2_1, _) = self.doprf_p2_next.preprocess_round_0(num);
  683. comm.send_previous(msg_2_1)?;
  684. self.doprf_p2_next.preprocess_round_1(num, (), ());
  685. self.doprf_p3_mine.preprocess_round_0(num);
  686. self.doprf_p3_mine.preprocess_round_1(num, (), ());
  687. self.doprf_p1_prev.preprocess_round_0(num);
  688. self.doprf_p1_prev
  689. .preprocess_round_1(num, fut_2_1.get()?, ());
  690. Ok(())
  691. }
  692. /// Run the evaluation protocol to obtain integers.
  693. pub fn eval_to_uint<C: AbstractCommunicator, T: Unsigned>(
  694. &mut self,
  695. comm: &mut C,
  696. shares: &[F],
  697. ) -> Result<Vec<T>, Error> {
  698. let num = shares.len();
  699. let fut_2_1 = comm.receive_next::<Vec<_>>()?; // round 0
  700. let fut_3_1 = comm.receive_previous::<Vec<_>>()?; // round 0
  701. let fut_1_3 = comm.receive_next()?; // round 1
  702. let (msg_2_1, _) = self.doprf_p2_next.eval_round_0(num, shares);
  703. comm.send_previous(msg_2_1)?;
  704. let (msg_3_1, _) = self.doprf_p3_mine.eval_round_0(num, shares);
  705. comm.send_next(msg_3_1)?;
  706. let (_, msg_1_3) =
  707. self.doprf_p1_prev
  708. .eval_round_1(num, shares, &fut_2_1.get()?, &fut_3_1.get()?);
  709. comm.send_previous(msg_1_3)?;
  710. let output = self
  711. .doprf_p3_mine
  712. .eval_round_2(num, shares, fut_1_3.get()?, ());
  713. Ok(to_uint(output))
  714. }
  715. }
  716. /// Party 1 of the *masked* DOPRF protocol.
  717. pub struct MaskedDOPrfParty1<F: LegendreSymbol> {
  718. _phantom: PhantomData<F>,
  719. output_bitsize: usize,
  720. shared_prg_1_2: Option<ChaChaRng>,
  721. shared_prg_1_3: Option<ChaChaRng>,
  722. legendre_prf_key: Option<LegendrePrfKey<F>>,
  723. is_initialized: bool,
  724. num_preprocessed_invocations: usize,
  725. preprocessed_rerand_m1: Vec<F>,
  726. preprocessed_mt_a: Vec<F>,
  727. preprocessed_mt_c1: Vec<F>,
  728. preprocessed_mult_e: Vec<F>,
  729. mult_d: Vec<F>,
  730. }
  731. impl<F> MaskedDOPrfParty1<F>
  732. where
  733. F: LegendreSymbol,
  734. {
  735. /// Create a new instance with the given `output_bitsize`.
  736. pub fn new(output_bitsize: usize) -> Self {
  737. assert!(output_bitsize > 0);
  738. Self {
  739. _phantom: PhantomData,
  740. output_bitsize,
  741. shared_prg_1_2: None,
  742. shared_prg_1_3: None,
  743. legendre_prf_key: None,
  744. is_initialized: false,
  745. num_preprocessed_invocations: 0,
  746. preprocessed_rerand_m1: Default::default(),
  747. preprocessed_mt_a: Default::default(),
  748. preprocessed_mt_c1: Default::default(),
  749. preprocessed_mult_e: Default::default(),
  750. mult_d: Default::default(),
  751. }
  752. }
  753. /// Create an instance from an existing Legendre PRF key.
  754. pub fn from_legendre_prf_key(legendre_prf_key: LegendrePrfKey<F>) -> Self {
  755. let mut new = Self::new(legendre_prf_key.keys.len());
  756. new.legendre_prf_key = Some(legendre_prf_key);
  757. new
  758. }
  759. /// Reset this instance.
  760. pub fn reset(&mut self) {
  761. *self = Self::new(self.output_bitsize)
  762. }
  763. /// Delete all preprocessed data.
  764. pub fn reset_preprocessing(&mut self) {
  765. self.num_preprocessed_invocations = 0;
  766. self.preprocessed_rerand_m1 = Default::default();
  767. self.preprocessed_mt_a = Default::default();
  768. self.preprocessed_mt_c1 = Default::default();
  769. self.preprocessed_mult_e = Default::default();
  770. }
  771. /// Step 0 of the initialization protocol.
  772. pub fn init_round_0(&mut self) -> (SharedSeed, ()) {
  773. assert!(!self.is_initialized);
  774. // sample and share a PRF key with Party 2
  775. self.shared_prg_1_2 = Some(ChaChaRng::from_seed(thread_rng().gen()));
  776. (self.shared_prg_1_2.as_ref().unwrap().get_seed(), ())
  777. }
  778. /// Step 1 of the initialization protocol.
  779. pub fn init_round_1(&mut self, _: (), shared_prg_seed_1_3: SharedSeed) {
  780. assert!(!self.is_initialized);
  781. // receive shared PRF key from Party 3
  782. self.shared_prg_1_3 = Some(ChaChaRng::from_seed(shared_prg_seed_1_3));
  783. if self.legendre_prf_key.is_none() {
  784. // generate Legendre PRF key
  785. self.legendre_prf_key = Some(LegendrePrf::key_gen(self.output_bitsize));
  786. }
  787. self.is_initialized = true;
  788. }
  789. /// Run the initialization protocol.
  790. pub fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
  791. let fut_3_1 = comm.receive_previous()?;
  792. let (msg_1_2, _) = self.init_round_0();
  793. comm.send_next(msg_1_2)?;
  794. self.init_round_1((), fut_3_1.get()?);
  795. Ok(())
  796. }
  797. /// Return the Legendre PRF key.
  798. pub fn get_legendre_prf_key(&self) -> LegendrePrfKey<F> {
  799. assert!(self.is_initialized);
  800. self.legendre_prf_key.as_ref().unwrap().clone()
  801. }
  802. /// Step 0 of the preprocessing protocol.
  803. pub fn preprocess_round_0(&mut self, num: usize) -> ((), ()) {
  804. assert!(self.is_initialized);
  805. let n = num * self.output_bitsize;
  806. self.preprocessed_rerand_m1
  807. .extend((0..num).map(|_| F::random(self.shared_prg_1_2.as_mut().unwrap())));
  808. self.preprocessed_mt_a
  809. .extend((0..n).map(|_| F::random(self.shared_prg_1_2.as_mut().unwrap())));
  810. self.preprocessed_mt_c1
  811. .extend((0..n).map(|_| F::random(self.shared_prg_1_2.as_mut().unwrap())));
  812. self.preprocessed_mult_e
  813. .extend((0..n).map(|_| F::random(self.shared_prg_1_2.as_mut().unwrap())));
  814. ((), ())
  815. }
  816. /// Step 1 of the preprocessing protocol.
  817. pub fn preprocess_round_1(&mut self, num: usize, _: (), _: ()) {
  818. assert!(self.is_initialized);
  819. self.num_preprocessed_invocations += num;
  820. }
  821. /// Run the preprocessing protocol for `num` evaluations.
  822. pub fn preprocess<C: AbstractCommunicator>(
  823. &mut self,
  824. _comm: &mut C,
  825. num: usize,
  826. ) -> Result<(), Error> {
  827. self.preprocess_round_0(num);
  828. self.preprocess_round_1(num, (), ());
  829. Ok(())
  830. }
  831. /// Return the number of preprocessed invocations available.
  832. pub fn get_num_preprocessed_invocations(&self) -> usize {
  833. self.num_preprocessed_invocations
  834. }
  835. /// Return the preprocessed data.
  836. #[doc(hidden)]
  837. pub fn get_preprocessed_data(&self) -> (&[F], &[F], &[F], &[F]) {
  838. (
  839. &self.preprocessed_rerand_m1,
  840. &self.preprocessed_mt_a,
  841. &self.preprocessed_mt_c1,
  842. &self.preprocessed_mult_e,
  843. )
  844. }
  845. /// Perform some self-test of the preprocessed data.
  846. #[doc(hidden)]
  847. pub fn check_preprocessing(&self) {
  848. let num = self.num_preprocessed_invocations;
  849. let n = num * self.output_bitsize;
  850. assert_eq!(self.preprocessed_rerand_m1.len(), num);
  851. assert_eq!(self.preprocessed_mt_a.len(), n);
  852. assert_eq!(self.preprocessed_mt_c1.len(), n);
  853. assert_eq!(self.preprocessed_mult_e.len(), n);
  854. }
  855. /// Step 0 of the evaluation protocol.
  856. pub fn eval_round_0(&mut self, num: usize, shares1: &[F]) -> ((), Vec<F>) {
  857. assert!(num <= self.num_preprocessed_invocations);
  858. assert_eq!(shares1.len(), num);
  859. let n = num * self.output_bitsize;
  860. let k = &self.legendre_prf_key.as_ref().unwrap().keys;
  861. self.mult_d = izip!(
  862. k.iter().cycle(),
  863. shares1
  864. .iter()
  865. .flat_map(|s1| repeat(s1).take(self.output_bitsize)),
  866. self.preprocessed_rerand_m1
  867. .iter()
  868. .take(num)
  869. .flat_map(|m1| repeat(m1).take(self.output_bitsize)),
  870. self.preprocessed_mt_a.drain(0..n),
  871. )
  872. .map(|(&k_i, &s1_i, m1_i, a_i)| k_i + s1_i + m1_i - a_i)
  873. .collect();
  874. assert_eq!(self.mult_d.len(), n);
  875. ((), self.mult_d.clone())
  876. }
  877. /// Step 2 of the evaluation protocol.
  878. pub fn eval_round_2(
  879. &mut self,
  880. num: usize,
  881. shares1: &[F],
  882. _: (),
  883. output_shares_z3: Vec<F>,
  884. ) -> Vec<BitVec> {
  885. assert!(num <= self.num_preprocessed_invocations);
  886. let n = num * self.output_bitsize;
  887. assert_eq!(shares1.len(), num);
  888. assert_eq!(output_shares_z3.len(), n);
  889. let k = &self.legendre_prf_key.as_ref().unwrap().keys;
  890. let lprf_inputs: Vec<F> = izip!(
  891. k.iter().cycle(),
  892. shares1
  893. .iter()
  894. .flat_map(|s1| repeat(s1).take(self.output_bitsize)),
  895. self.preprocessed_rerand_m1
  896. .drain(0..num)
  897. .flat_map(|m1| repeat(m1).take(self.output_bitsize)),
  898. self.preprocessed_mult_e.drain(0..n),
  899. self.mult_d.drain(..),
  900. self.preprocessed_mt_c1.drain(0..n),
  901. output_shares_z3.iter(),
  902. )
  903. .map(|(&k_j, &s1_i, m1_i, e_ij, d_ij, c1_ij, &z3_ij)| {
  904. e_ij * (k_j + s1_i + m1_i) + c1_ij + z3_ij - d_ij * e_ij
  905. })
  906. .collect();
  907. assert_eq!(lprf_inputs.len(), n);
  908. let output: Vec<BitVec> = lprf_inputs
  909. .chunks_exact(self.output_bitsize)
  910. .map(|chunk| {
  911. let mut bv = BitVec::with_capacity(self.output_bitsize);
  912. for &x in chunk.iter() {
  913. let ls = F::legendre_symbol(x);
  914. debug_assert!(ls != 0, "unlikely");
  915. bv.push(ls == 1);
  916. }
  917. bv
  918. })
  919. .collect();
  920. self.num_preprocessed_invocations -= num;
  921. output
  922. }
  923. /// Run the evaluation protocol to obtain bit vectors.
  924. pub fn eval<C: AbstractCommunicator>(
  925. &mut self,
  926. comm: &mut C,
  927. num: usize,
  928. shares1: &[F],
  929. ) -> Result<Vec<BitVec>, Error>
  930. where
  931. F: Serializable,
  932. {
  933. assert_eq!(shares1.len(), num);
  934. let fut_3_1 = comm.receive_previous()?;
  935. let (_, msg_1_3) = self.eval_round_0(num, shares1);
  936. comm.send_previous(msg_1_3)?;
  937. let output = self.eval_round_2(1, shares1, (), fut_3_1.get()?);
  938. Ok(output)
  939. }
  940. /// Run the evaluation protocol to obtain integers.
  941. pub fn eval_to_uint<C: AbstractCommunicator, T: Unsigned>(
  942. &mut self,
  943. comm: &mut C,
  944. num: usize,
  945. shares1: &[F],
  946. ) -> Result<Vec<T>, Error>
  947. where
  948. F: Serializable,
  949. {
  950. assert!(self.output_bitsize <= T::BITS as usize);
  951. Ok(to_uint(self.eval(comm, num, shares1)?))
  952. }
  953. }
  954. /// Party 2 of the *masked* DOPRF protocol.
  955. pub struct MaskedDOPrfParty2<F: LegendreSymbol> {
  956. _phantom: PhantomData<F>,
  957. output_bitsize: usize,
  958. shared_prg_1_2: Option<ChaChaRng>,
  959. shared_prg_2_3: Option<ChaChaRng>,
  960. is_initialized: bool,
  961. num_preprocessed_invocations: usize,
  962. preprocessed_rerand_m2: Vec<F>,
  963. preprocessed_r: BitVec,
  964. }
  965. impl<F> MaskedDOPrfParty2<F>
  966. where
  967. F: LegendreSymbol,
  968. {
  969. /// Create a new instance with the given `output_bitsize`.
  970. pub fn new(output_bitsize: usize) -> Self {
  971. assert!(output_bitsize > 0);
  972. Self {
  973. _phantom: PhantomData,
  974. output_bitsize,
  975. shared_prg_1_2: None,
  976. shared_prg_2_3: None,
  977. is_initialized: false,
  978. num_preprocessed_invocations: 0,
  979. preprocessed_rerand_m2: Default::default(),
  980. preprocessed_r: Default::default(),
  981. }
  982. }
  983. /// Reset this instance.
  984. pub fn reset(&mut self) {
  985. *self = Self::new(self.output_bitsize)
  986. }
  987. /// Delete all preprocessed data.
  988. pub fn reset_preprocessing(&mut self) {
  989. self.num_preprocessed_invocations = 0;
  990. self.preprocessed_rerand_m2 = Default::default();
  991. }
  992. /// Step 0 of the initialization protocol.
  993. pub fn init_round_0(&mut self) -> ((), SharedSeed) {
  994. assert!(!self.is_initialized);
  995. self.shared_prg_2_3 = Some(ChaChaRng::from_seed(thread_rng().gen()));
  996. ((), self.shared_prg_2_3.as_ref().unwrap().get_seed())
  997. }
  998. /// Step 1 of the initialization protocol.
  999. pub fn init_round_1(&mut self, shared_prg_seed_1_2: SharedSeed, _: ()) {
  1000. assert!(!self.is_initialized);
  1001. // receive shared PRF key from Party 1
  1002. self.shared_prg_1_2 = Some(ChaChaRng::from_seed(shared_prg_seed_1_2));
  1003. self.is_initialized = true;
  1004. }
  1005. /// Run the initialization protocol.
  1006. pub fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
  1007. let fut_1_2 = comm.receive_previous()?;
  1008. let (_, msg_2_3) = self.init_round_0();
  1009. comm.send_next(msg_2_3)?;
  1010. self.init_round_1(fut_1_2.get()?, ());
  1011. Ok(())
  1012. }
  1013. /// Step 0 of the preprocessing protocol.
  1014. pub fn preprocess_round_0(&mut self, num: usize) -> ((), Vec<F>) {
  1015. assert!(self.is_initialized);
  1016. let n = num * self.output_bitsize;
  1017. let mut preprocessed_t: Vec<_> = (0..n)
  1018. .map(|_| F::random(self.shared_prg_2_3.as_mut().unwrap()).square())
  1019. .collect();
  1020. debug_assert!(!preprocessed_t.contains(&F::ZERO));
  1021. {
  1022. let mut random_bytes = vec![0u8; (n + 7) / 8];
  1023. self.shared_prg_2_3
  1024. .as_mut()
  1025. .unwrap()
  1026. .fill_bytes(&mut random_bytes);
  1027. let new_r_slice = BitSlice::from_slice(&random_bytes);
  1028. self.preprocessed_r.extend(&new_r_slice[..n]);
  1029. for (i, r_i) in new_r_slice.iter().by_vals().take(n).enumerate() {
  1030. if r_i {
  1031. preprocessed_t[i] *= F::get_non_random_qnr();
  1032. }
  1033. }
  1034. }
  1035. self.preprocessed_rerand_m2
  1036. .extend((0..num).map(|_| -F::random(self.shared_prg_1_2.as_mut().unwrap())));
  1037. let preprocessed_mt_a: Vec<F> = (0..n)
  1038. .map(|_| F::random(self.shared_prg_1_2.as_mut().unwrap()))
  1039. .collect();
  1040. let preprocessed_mt_c1: Vec<F> = (0..n)
  1041. .map(|_| F::random(self.shared_prg_1_2.as_mut().unwrap()))
  1042. .collect();
  1043. let preprocessed_mult_e: Vec<F> = (0..n)
  1044. .map(|_| F::random(self.shared_prg_1_2.as_mut().unwrap()))
  1045. .collect();
  1046. let preprocessed_c3: Vec<F> = izip!(
  1047. preprocessed_t.iter(),
  1048. preprocessed_mult_e.iter(),
  1049. preprocessed_mt_a.iter(),
  1050. preprocessed_mt_c1.iter(),
  1051. )
  1052. .map(|(&t, &e, &a, &c1)| a * (t - e) - c1)
  1053. .collect();
  1054. self.num_preprocessed_invocations += num;
  1055. ((), preprocessed_c3)
  1056. }
  1057. /// Step 1 of the preprocessing protocol.
  1058. pub fn preprocess_round_1(&mut self, _: usize, _: (), _: ()) {
  1059. assert!(self.is_initialized);
  1060. }
  1061. /// Run the preprocessing protocol for `num` evaluations.
  1062. pub fn preprocess<C: AbstractCommunicator>(
  1063. &mut self,
  1064. comm: &mut C,
  1065. num: usize,
  1066. ) -> Result<(), Error>
  1067. where
  1068. F: Serializable,
  1069. {
  1070. let (_, msg_2_3) = self.preprocess_round_0(num);
  1071. comm.send_next(msg_2_3)?;
  1072. self.preprocess_round_1(num, (), ());
  1073. Ok(())
  1074. }
  1075. /// Return the number of preprocessed invocations available.
  1076. pub fn get_num_preprocessed_invocations(&self) -> usize {
  1077. self.num_preprocessed_invocations
  1078. }
  1079. /// Return the preprocessed data.
  1080. #[doc(hidden)]
  1081. pub fn get_preprocessed_data(&self) -> (&BitSlice, &[F]) {
  1082. (&self.preprocessed_r, &self.preprocessed_rerand_m2)
  1083. }
  1084. /// Perform some self-test of the preprocessed data.
  1085. #[doc(hidden)]
  1086. pub fn check_preprocessing(&self) {
  1087. let num = self.num_preprocessed_invocations;
  1088. assert_eq!(self.preprocessed_rerand_m2.len(), num);
  1089. }
  1090. /// Step 0 of the evaluation protocol.
  1091. pub fn eval_round_0(&mut self, num: usize, shares2: &[F]) -> ((), Vec<F>) {
  1092. assert!(num <= self.num_preprocessed_invocations);
  1093. assert_eq!(shares2.len(), num);
  1094. let masked_shares2: Vec<F> =
  1095. izip!(shares2.iter(), self.preprocessed_rerand_m2.drain(0..num),)
  1096. .map(|(&s2i, m2i)| s2i + m2i)
  1097. .collect();
  1098. assert_eq!(masked_shares2.len(), num);
  1099. ((), masked_shares2)
  1100. }
  1101. /// Final step of the evaluation protocol.
  1102. pub fn eval_get_output(&mut self, num: usize) -> Vec<BitVec> {
  1103. assert!(num <= self.num_preprocessed_invocations);
  1104. let n = num * self.output_bitsize;
  1105. let mut output = Vec::with_capacity(num);
  1106. for chunk in self
  1107. .preprocessed_r
  1108. .chunks_exact(self.output_bitsize)
  1109. .take(num)
  1110. {
  1111. output.push(chunk.to_bitvec());
  1112. }
  1113. let (_, last_r) = self.preprocessed_r.split_at(n);
  1114. self.preprocessed_r = last_r.to_bitvec();
  1115. self.num_preprocessed_invocations -= num;
  1116. output
  1117. }
  1118. /// Run the evaluation protocol to obtain bit vectors.
  1119. pub fn eval<C: AbstractCommunicator>(
  1120. &mut self,
  1121. comm: &mut C,
  1122. num: usize,
  1123. shares2: &[F],
  1124. ) -> Result<Vec<BitVec>, Error>
  1125. where
  1126. F: Serializable,
  1127. {
  1128. assert_eq!(shares2.len(), num);
  1129. let (_, msg_2_3) = self.eval_round_0(num, shares2);
  1130. comm.send_next(msg_2_3)?;
  1131. let output = self.eval_get_output(num);
  1132. Ok(output)
  1133. }
  1134. /// Run the evaluation protocol to obtain integers.
  1135. pub fn eval_to_uint<C: AbstractCommunicator, T: Unsigned>(
  1136. &mut self,
  1137. comm: &mut C,
  1138. num: usize,
  1139. shares2: &[F],
  1140. ) -> Result<Vec<T>, Error>
  1141. where
  1142. F: Serializable,
  1143. {
  1144. assert!(self.output_bitsize <= T::BITS as usize);
  1145. Ok(to_uint(self.eval(comm, num, shares2)?))
  1146. }
  1147. }
  1148. /// Party 3 of the *masked* DOPRF protocol.
  1149. pub struct MaskedDOPrfParty3<F: LegendreSymbol> {
  1150. _phantom: PhantomData<F>,
  1151. output_bitsize: usize,
  1152. shared_prg_1_3: Option<ChaChaRng>,
  1153. shared_prg_2_3: Option<ChaChaRng>,
  1154. is_initialized: bool,
  1155. num_preprocessed_invocations: usize,
  1156. preprocessed_r: BitVec,
  1157. preprocessed_t: Vec<F>,
  1158. preprocessed_mt_c3: Vec<F>,
  1159. }
  1160. impl<F> MaskedDOPrfParty3<F>
  1161. where
  1162. F: LegendreSymbol,
  1163. {
  1164. /// Create a new instance with the given `output_bitsize`.
  1165. pub fn new(output_bitsize: usize) -> Self {
  1166. assert!(output_bitsize > 0);
  1167. Self {
  1168. _phantom: PhantomData,
  1169. output_bitsize,
  1170. shared_prg_1_3: None,
  1171. shared_prg_2_3: None,
  1172. is_initialized: false,
  1173. num_preprocessed_invocations: 0,
  1174. preprocessed_r: Default::default(),
  1175. preprocessed_t: Default::default(),
  1176. preprocessed_mt_c3: Default::default(),
  1177. }
  1178. }
  1179. /// Reset this instance.
  1180. pub fn reset(&mut self) {
  1181. *self = Self::new(self.output_bitsize)
  1182. }
  1183. /// Delete all preprocessed data.
  1184. pub fn reset_preprocessing(&mut self) {
  1185. self.num_preprocessed_invocations = 0;
  1186. self.preprocessed_t = Default::default();
  1187. self.preprocessed_mt_c3 = Default::default();
  1188. }
  1189. /// Step 0 of the initialization protocol.
  1190. pub fn init_round_0(&mut self) -> (SharedSeed, ()) {
  1191. assert!(!self.is_initialized);
  1192. self.shared_prg_1_3 = Some(ChaChaRng::from_seed(thread_rng().gen()));
  1193. (self.shared_prg_1_3.as_ref().unwrap().get_seed(), ())
  1194. }
  1195. /// Step 1 of the initialization protocol.
  1196. pub fn init_round_1(&mut self, _: (), shared_prg_seed_2_3: SharedSeed) {
  1197. self.shared_prg_2_3 = Some(ChaChaRng::from_seed(shared_prg_seed_2_3));
  1198. self.is_initialized = true;
  1199. }
  1200. /// Run the initialization protocol.
  1201. pub fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
  1202. let fut_2_3 = comm.receive_previous()?;
  1203. let (msg_3_1, _) = self.init_round_0();
  1204. comm.send_next(msg_3_1)?;
  1205. self.init_round_1((), fut_2_3.get()?);
  1206. Ok(())
  1207. }
  1208. /// Step 0 of the preprocessing protocol.
  1209. pub fn preprocess_round_0(&mut self, num: usize) -> ((), ()) {
  1210. assert!(self.is_initialized);
  1211. let n = num * self.output_bitsize;
  1212. let start_index = self.num_preprocessed_invocations * self.output_bitsize;
  1213. self.preprocessed_t
  1214. .extend((0..n).map(|_| F::random(self.shared_prg_2_3.as_mut().unwrap()).square()));
  1215. debug_assert!(!self.preprocessed_t[start_index..].contains(&F::ZERO));
  1216. {
  1217. let mut random_bytes = vec![0u8; (n + 7) / 8];
  1218. self.shared_prg_2_3
  1219. .as_mut()
  1220. .unwrap()
  1221. .fill_bytes(&mut random_bytes);
  1222. let new_r_slice = BitSlice::from_slice(&random_bytes);
  1223. self.preprocessed_r.extend(&new_r_slice[..n]);
  1224. for (i, r_i) in new_r_slice.iter().by_vals().take(n).enumerate() {
  1225. if r_i {
  1226. self.preprocessed_t[start_index + i] *= F::get_non_random_qnr();
  1227. }
  1228. }
  1229. }
  1230. ((), ())
  1231. }
  1232. /// Step 1 of the preprocessing protocol.
  1233. pub fn preprocess_round_1(&mut self, num: usize, _: (), preprocessed_mt_c3: Vec<F>) {
  1234. assert!(self.is_initialized);
  1235. let n = num * self.output_bitsize;
  1236. assert_eq!(preprocessed_mt_c3.len(), n);
  1237. self.preprocessed_mt_c3.extend(preprocessed_mt_c3);
  1238. self.num_preprocessed_invocations += num;
  1239. }
  1240. /// Run the preprocessing protocol for `num` evaluations.
  1241. pub fn preprocess<C: AbstractCommunicator>(
  1242. &mut self,
  1243. comm: &mut C,
  1244. num: usize,
  1245. ) -> Result<(), Error>
  1246. where
  1247. F: Serializable,
  1248. {
  1249. let fut_2_3 = comm.receive_previous()?;
  1250. self.preprocess_round_0(num);
  1251. self.preprocess_round_1(num, (), fut_2_3.get()?);
  1252. Ok(())
  1253. }
  1254. /// Return the number of preprocessed invocations available.
  1255. pub fn get_num_preprocessed_invocations(&self) -> usize {
  1256. self.num_preprocessed_invocations
  1257. }
  1258. /// Return the preprocessed data.
  1259. #[doc(hidden)]
  1260. pub fn get_preprocessed_data(&self) -> (&BitSlice, &[F], &[F]) {
  1261. (
  1262. &self.preprocessed_r,
  1263. &self.preprocessed_t,
  1264. &self.preprocessed_mt_c3,
  1265. )
  1266. }
  1267. /// Perform some self-test of the preprocessed data.
  1268. #[doc(hidden)]
  1269. pub fn check_preprocessing(&self) {
  1270. let num = self.num_preprocessed_invocations;
  1271. let n = num * self.output_bitsize;
  1272. assert_eq!(self.preprocessed_t.len(), n);
  1273. assert_eq!(self.preprocessed_mt_c3.len(), n);
  1274. }
  1275. /// Step 1 of the evaluation protocol.
  1276. pub fn eval_round_1(
  1277. &mut self,
  1278. num: usize,
  1279. shares3: &[F],
  1280. mult_d: &[F],
  1281. masked_shares2: &[F],
  1282. ) -> (Vec<F>, ()) {
  1283. assert!(num <= self.num_preprocessed_invocations);
  1284. let n = num * self.output_bitsize;
  1285. assert_eq!(shares3.len(), num);
  1286. assert_eq!(masked_shares2.len(), num);
  1287. assert_eq!(mult_d.len(), n);
  1288. let output_shares_z3: Vec<F> = izip!(
  1289. shares3
  1290. .iter()
  1291. .flat_map(|s1i| repeat(s1i).take(self.output_bitsize)),
  1292. masked_shares2
  1293. .iter()
  1294. .flat_map(|ms2i| repeat(ms2i).take(self.output_bitsize)),
  1295. self.preprocessed_t.drain(0..n),
  1296. self.preprocessed_mt_c3.drain(0..n),
  1297. mult_d,
  1298. )
  1299. .map(|(&s3_i, &ms2_i, t_ij, c3_ij, &d_ij)| t_ij * (s3_i + ms2_i) + d_ij * t_ij + c3_ij)
  1300. .collect();
  1301. (output_shares_z3, ())
  1302. }
  1303. /// Final step of the evaluation protocol.
  1304. pub fn eval_get_output(&mut self, num: usize) -> Vec<BitVec> {
  1305. assert!(num <= self.num_preprocessed_invocations);
  1306. let n = num * self.output_bitsize;
  1307. let mut output = Vec::with_capacity(num);
  1308. for chunk in self
  1309. .preprocessed_r
  1310. .chunks_exact(self.output_bitsize)
  1311. .take(num)
  1312. {
  1313. output.push(chunk.to_bitvec());
  1314. }
  1315. let (_, last_r) = self.preprocessed_r.split_at(n);
  1316. self.preprocessed_r = last_r.to_bitvec();
  1317. self.num_preprocessed_invocations -= num;
  1318. output
  1319. }
  1320. /// Run the evaluation protocol to obtain bit vectors.
  1321. pub fn eval<C: AbstractCommunicator>(
  1322. &mut self,
  1323. comm: &mut C,
  1324. num: usize,
  1325. shares3: &[F],
  1326. ) -> Result<Vec<BitVec>, Error>
  1327. where
  1328. F: Serializable,
  1329. {
  1330. assert_eq!(shares3.len(), num);
  1331. let fut_1_3 = comm.receive_next::<Vec<_>>()?;
  1332. let fut_2_3 = comm.receive_previous::<Vec<_>>()?;
  1333. let (msg_3_1, _) = self.eval_round_1(1, shares3, &fut_1_3.get()?, &fut_2_3.get()?);
  1334. comm.send_next(msg_3_1)?;
  1335. let output = self.eval_get_output(num);
  1336. Ok(output)
  1337. }
  1338. /// Run the evaluation protocol to obtain integers.
  1339. pub fn eval_to_uint<C: AbstractCommunicator, T: Unsigned>(
  1340. &mut self,
  1341. comm: &mut C,
  1342. num: usize,
  1343. shares3: &[F],
  1344. ) -> Result<Vec<T>, Error>
  1345. where
  1346. F: Serializable,
  1347. {
  1348. assert!(self.output_bitsize <= T::BITS as usize);
  1349. Ok(to_uint(self.eval(comm, num, shares3)?))
  1350. }
  1351. }
  1352. #[cfg(test)]
  1353. mod tests {
  1354. use super::*;
  1355. use bincode;
  1356. use ff::Field;
  1357. use utils::field::Fp;
  1358. fn doprf_init(
  1359. party_1: &mut DOPrfParty1<Fp>,
  1360. party_2: &mut DOPrfParty2<Fp>,
  1361. party_3: &mut DOPrfParty3<Fp>,
  1362. ) {
  1363. let (msg_1_2, msg_1_3) = party_1.init_round_0();
  1364. let (msg_2_1, msg_2_3) = party_2.init_round_0();
  1365. let (msg_3_1, msg_3_2) = party_3.init_round_0();
  1366. party_1.init_round_1(msg_2_1, msg_3_1);
  1367. party_2.init_round_1(msg_1_2, msg_3_2);
  1368. party_3.init_round_1(msg_1_3, msg_2_3);
  1369. }
  1370. fn doprf_preprocess(
  1371. party_1: &mut DOPrfParty1<Fp>,
  1372. party_2: &mut DOPrfParty2<Fp>,
  1373. party_3: &mut DOPrfParty3<Fp>,
  1374. num: usize,
  1375. ) {
  1376. let (msg_1_2, msg_1_3) = party_1.preprocess_round_0(num);
  1377. let (msg_2_1, msg_2_3) = party_2.preprocess_round_0(num);
  1378. let (msg_3_1, msg_3_2) = party_3.preprocess_round_0(num);
  1379. party_1.preprocess_round_1(num, msg_2_1, msg_3_1);
  1380. party_2.preprocess_round_1(num, msg_1_2, msg_3_2);
  1381. party_3.preprocess_round_1(num, msg_1_3, msg_2_3);
  1382. }
  1383. fn doprf_eval(
  1384. party_1: &mut DOPrfParty1<Fp>,
  1385. party_2: &mut DOPrfParty2<Fp>,
  1386. party_3: &mut DOPrfParty3<Fp>,
  1387. shares_1: &[Fp],
  1388. shares_2: &[Fp],
  1389. shares_3: &[Fp],
  1390. num: usize,
  1391. ) -> Vec<BitVec> {
  1392. assert_eq!(shares_1.len(), num);
  1393. assert_eq!(shares_2.len(), num);
  1394. assert_eq!(shares_3.len(), num);
  1395. let (msg_2_1, msg_2_3) = party_2.eval_round_0(num, &shares_2);
  1396. let (msg_3_1, _) = party_3.eval_round_0(num, &shares_3);
  1397. let (_, msg_1_3) = party_1.eval_round_1(num, &shares_1, &msg_2_1, &msg_3_1);
  1398. let output = party_3.eval_round_2(num, &shares_3, msg_1_3, msg_2_3);
  1399. output
  1400. }
  // End-to-end test of the DOPRF: initialization, two preprocessing batches,
  // consistency of the preprocessed data, and comparison of the evaluation
  // output with the plain Legendre PRF.
  #[test]
  fn test_doprf() {
      let output_bitsize = 42;
      let mut party_1 = DOPrfParty1::<Fp>::new(output_bitsize);
      let mut party_2 = DOPrfParty2::<Fp>::new(output_bitsize);
      let mut party_3 = DOPrfParty3::<Fp>::new(output_bitsize);
      doprf_init(&mut party_1, &mut party_2, &mut party_3);

      // preprocess two batches and check the bookkeeping after each of them
      let batch_size = 20;
      for batches_done in 1..=2 {
          doprf_preprocess(&mut party_1, &mut party_2, &mut party_3, batch_size);
          let expected = batches_done * batch_size;
          assert_eq!(party_1.get_num_preprocessed_invocations(), expected);
          assert_eq!(party_2.get_num_preprocessed_invocations(), expected);
          assert_eq!(party_3.get_num_preprocessed_invocations(), expected);
          party_1.check_preprocessing();
          party_2.check_preprocessing();
          party_3.check_preprocessing();
      }
      let num_preprocessed = 2 * batch_size;

      // verify preprocessed data
      {
          let n = num_preprocessed * output_bitsize;
          let (squares, mt_c1) = party_1.get_preprocessed_data();
          let rerand_m2 = party_2.get_preprocessed_data();
          let (rerand_m3, mt_b, mt_c3, mult_d) = party_3.get_preprocessed_data();

          // the squares really are quadratic residues
          assert_eq!(squares.len(), n);
          assert!(squares.iter().all(|&sq| Fp::legendre_symbol(sq) == 1));

          // the re-randomization values of parties 2 and 3 cancel out
          assert_eq!(rerand_m2.len(), num_preprocessed);
          assert_eq!(rerand_m3.len(), num_preprocessed);
          assert!(rerand_m2
              .iter()
              .zip(rerand_m3.iter())
              .all(|(&m2, &m3)| m2 + m3 == Fp::ZERO));

          // reconstruct `a` and check the multiplication triples
          let mt_a: Vec<Fp> = izip!(squares.iter(), mult_d.iter())
              .map(|(&sq, &d)| sq - d)
              .collect();
          assert_eq!(mult_d.len(), n);
          assert_eq!(mt_a.len(), n);
          assert_eq!(mt_b.len(), num_preprocessed);
          assert_eq!(mt_c1.len(), n);
          assert_eq!(mt_c3.len(), n);
          let mut triples = izip!(
              mt_a.iter(),
              mt_b.iter().flat_map(|b| repeat(b).take(output_bitsize)),
              mt_c1.iter(),
              mt_c3.iter()
          );
          assert_eq!(triples.clone().count(), n);
          assert!(triples.all(|(&a, &b, &c1, &c3)| a * b == c1 + c3));
      }

      // perform a batch of evaluations on random shares
      let num = 15;
      let shares_1: Vec<Fp> = (0..num).map(|_| Fp::random(thread_rng())).collect();
      let shares_2: Vec<Fp> = (0..num).map(|_| Fp::random(thread_rng())).collect();
      let shares_3: Vec<Fp> = (0..num).map(|_| Fp::random(thread_rng())).collect();
      let output = doprf_eval(
          &mut party_1,
          &mut party_2,
          &mut party_3,
          &shares_1,
          &shares_2,
          &shares_3,
          num,
      );

      // 40 preprocessed - 15 used = 25 left
      let remaining = num_preprocessed - num;
      assert_eq!(party_1.get_num_preprocessed_invocations(), remaining);
      assert_eq!(party_2.get_num_preprocessed_invocations(), remaining);
      assert_eq!(party_3.get_num_preprocessed_invocations(), remaining);
      party_1.check_preprocessing();
      party_2.check_preprocessing();
      party_3.check_preprocessing();
      assert_eq!(output.len(), num);
      assert!(output.iter().all(|bv| bv.len() == output_bitsize));

      // check that the output matches the non-distributed version
      let legendre_prf_key = party_1.get_legendre_prf_key();
      for (i, output_i) in output.iter().enumerate() {
          let input_i = shares_1[i] + shares_2[i] + shares_3[i];
          let expected_i = LegendrePrf::<Fp>::eval_bits(&legendre_prf_key, input_i);
          assert_eq!(*output_i, expected_i);
      }
  }
  1486. fn mdoprf_init(
  1487. party_1: &mut MaskedDOPrfParty1<Fp>,
  1488. party_2: &mut MaskedDOPrfParty2<Fp>,
  1489. party_3: &mut MaskedDOPrfParty3<Fp>,
  1490. ) {
  1491. let (msg_1_2, msg_1_3) = party_1.init_round_0();
  1492. let (msg_2_1, msg_2_3) = party_2.init_round_0();
  1493. let (msg_3_1, msg_3_2) = party_3.init_round_0();
  1494. party_1.init_round_1(msg_2_1, msg_3_1);
  1495. party_2.init_round_1(msg_1_2, msg_3_2);
  1496. party_3.init_round_1(msg_1_3, msg_2_3);
  1497. }
  1498. fn mdoprf_preprocess(
  1499. party_1: &mut MaskedDOPrfParty1<Fp>,
  1500. party_2: &mut MaskedDOPrfParty2<Fp>,
  1501. party_3: &mut MaskedDOPrfParty3<Fp>,
  1502. num: usize,
  1503. ) {
  1504. let (msg_1_2, msg_1_3) = party_1.preprocess_round_0(num);
  1505. let (msg_2_1, msg_2_3) = party_2.preprocess_round_0(num);
  1506. let (msg_3_1, msg_3_2) = party_3.preprocess_round_0(num);
  1507. party_1.preprocess_round_1(num, msg_2_1, msg_3_1);
  1508. party_2.preprocess_round_1(num, msg_1_2, msg_3_2);
  1509. party_3.preprocess_round_1(num, msg_1_3, msg_2_3);
  1510. }
  1511. fn mdoprf_eval(
  1512. party_1: &mut MaskedDOPrfParty1<Fp>,
  1513. party_2: &mut MaskedDOPrfParty2<Fp>,
  1514. party_3: &mut MaskedDOPrfParty3<Fp>,
  1515. shares_1: &[Fp],
  1516. shares_2: &[Fp],
  1517. shares_3: &[Fp],
  1518. num: usize,
  1519. ) -> (Vec<BitVec>, Vec<BitVec>, Vec<BitVec>) {
  1520. assert_eq!(shares_1.len(), num);
  1521. assert_eq!(shares_2.len(), num);
  1522. assert_eq!(shares_3.len(), num);
  1523. let (_, msg_1_3) = party_1.eval_round_0(num, &shares_1);
  1524. let (_, msg_2_3) = party_2.eval_round_0(num, &shares_2);
  1525. let (msg_3_1, _) = party_3.eval_round_1(num, &shares_3, &msg_1_3, &msg_2_3);
  1526. let masked_output = party_1.eval_round_2(num, &shares_1, (), msg_3_1);
  1527. let mask2 = party_2.eval_get_output(num);
  1528. let mask3 = party_3.eval_get_output(num);
  1529. (masked_output, mask2, mask3)
  1530. }
  /// End-to-end test of the three-party masked DOPRF.
  ///
  /// Covers two preprocessing batches, a consistency check of the raw
  /// preprocessing material, a batch of evaluations compared against the plain
  /// Legendre PRF, and a second batch on the same inputs that must unmask to the
  /// same outputs.
  #[test]
  fn test_masked_doprf() {
      let output_bitsize = 42;
      let mut party_1 = MaskedDOPrfParty1::<Fp>::new(output_bitsize);
      let mut party_2 = MaskedDOPrfParty2::<Fp>::new(output_bitsize);
      let mut party_3 = MaskedDOPrfParty3::<Fp>::new(output_bitsize);
      mdoprf_init(&mut party_1, &mut party_2, &mut party_3);

      // First preprocessing batch.
      let batch_size = 20;
      mdoprf_preprocess(&mut party_1, &mut party_2, &mut party_3, batch_size);
      assert_eq!(party_1.get_num_preprocessed_invocations(), batch_size);
      assert_eq!(party_2.get_num_preprocessed_invocations(), batch_size);
      assert_eq!(party_3.get_num_preprocessed_invocations(), batch_size);
      party_1.check_preprocessing();
      party_2.check_preprocessing();
      party_3.check_preprocessing();

      // Second batch of the same size: the counters must report the sum.
      mdoprf_preprocess(&mut party_1, &mut party_2, &mut party_3, batch_size);
      let total_preprocessed = 2 * batch_size;
      assert_eq!(
          party_1.get_num_preprocessed_invocations(),
          total_preprocessed
      );
      assert_eq!(
          party_2.get_num_preprocessed_invocations(),
          total_preprocessed
      );
      assert_eq!(
          party_3.get_num_preprocessed_invocations(),
          total_preprocessed
      );
      party_1.check_preprocessing();
      party_2.check_preprocessing();
      party_3.check_preprocessing();

      // Inspect the raw preprocessing material held by the three parties.
      {
          // One entry per output bit of every preprocessed invocation.
          let num_bits = total_preprocessed * output_bitsize;
          let (rerand_m1, mt_a, mt_c1, mult_e) = party_1.get_preprocessed_data();
          let (r2, rerand_m2) = party_2.get_preprocessed_data();
          let (r3, ts, mt_c3) = party_3.get_preprocessed_data();

          // Parties 2 and 3 hold the same bits r, and every t_i encodes r_i in
          // its Legendre symbol: bit set <=> symbol -1, bit clear <=> symbol 1.
          assert_eq!(r2.len(), num_bits);
          assert_eq!(r2, r3);
          assert_eq!(ts.len(), num_bits);
          let symbols_match_bits = r2.iter().by_vals().zip(ts.iter()).all(|(bit, &t)| {
              let symbol = Fp::legendre_symbol(t);
              if bit {
                  symbol == -1
              } else {
                  symbol == 1
              }
          });
          assert!(symbols_match_bits);

          // The re-randomization values of parties 1 and 2 cancel out.
          assert_eq!(rerand_m1.len(), total_preprocessed);
          assert_eq!(rerand_m2.len(), total_preprocessed);
          let rerand_cancels = rerand_m1
              .iter()
              .zip(rerand_m2.iter())
              .all(|(&m1, &m2)| m1 + m2 == Fp::ZERO);
          assert!(rerand_cancels);

          // With b = t - e, the tuples (a, b, c1 + c3) must satisfy a * b = c1 + c3.
          let mt_b: Vec<Fp> = ts
              .iter()
              .zip(mult_e.iter())
              .map(|(&t, &e)| t - e)
              .collect();
          assert_eq!(mult_e.len(), num_bits);
          assert_eq!(mt_a.len(), num_bits);
          assert_eq!(mt_b.len(), num_bits);
          assert_eq!(mt_c1.len(), num_bits);
          assert_eq!(mt_c3.len(), num_bits);
          let mut triples = izip!(mt_a.iter(), mt_b.iter(), mt_c1.iter(), mt_c3.iter());
          assert_eq!(triples.clone().count(), num_bits);
          assert!(triples.all(|(&a, &b, &c1, &c3)| a * b == c1 + c3));
      }

      // Evaluate on fresh random input shares.
      let num_evals = 15;
      let random_shares =
          || -> Vec<Fp> { (0..num_evals).map(|_| Fp::random(thread_rng())).collect() };
      let shares_1 = random_shares();
      let shares_2 = random_shares();
      let shares_3 = random_shares();
      let (masked_output, mask2, mask3) = mdoprf_eval(
          &mut party_1,
          &mut party_2,
          &mut party_3,
          &shares_1,
          &shares_2,
          &shares_3,
          num_evals,
      );

      // 40 preprocessed invocations minus 15 evaluations leaves 25.
      let remaining = total_preprocessed - num_evals;
      assert_eq!(party_1.get_num_preprocessed_invocations(), remaining);
      assert_eq!(party_2.get_num_preprocessed_invocations(), remaining);
      assert_eq!(party_3.get_num_preprocessed_invocations(), remaining);
      party_1.check_preprocessing();
      party_2.check_preprocessing();
      party_3.check_preprocessing();

      assert_eq!(masked_output.len(), num_evals);
      assert!(masked_output.iter().all(|bits| bits.len() == output_bitsize));
      assert_eq!(mask2.len(), num_evals);
      assert_eq!(mask2, mask3);
      assert!(mask2.iter().all(|bits| bits.len() == output_bitsize));

      // Unmasking must reproduce the non-distributed Legendre PRF on the
      // reconstructed input.
      let legendre_prf_key = party_1.get_legendre_prf_key();
      for (idx, (masked, mask)) in masked_output.iter().zip(mask2.iter()).enumerate() {
          let plain_input = shares_1[idx] + shares_2[idx] + shares_3[idx];
          let expected = LegendrePrf::<Fp>::eval_bits(&legendre_prf_key, plain_input);
          let unmasked = masked.clone() ^ mask;
          assert_eq!(unmasked, expected);
      }

      // Refill with as many invocations as are about to be consumed, then
      // evaluate again on the very same inputs.
      mdoprf_preprocess(&mut party_1, &mut party_2, &mut party_3, num_evals);
      let (new_masked_output, new_mask2, new_mask3) = mdoprf_eval(
          &mut party_1,
          &mut party_2,
          &mut party_3,
          &shares_1,
          &shares_2,
          &shares_3,
          num_evals,
      );

      // Refill and consumption cancel, so the counters are unchanged.
      assert_eq!(party_1.get_num_preprocessed_invocations(), remaining);
      assert_eq!(party_2.get_num_preprocessed_invocations(), remaining);
      assert_eq!(party_3.get_num_preprocessed_invocations(), remaining);
      party_1.check_preprocessing();
      party_2.check_preprocessing();
      party_3.check_preprocessing();

      assert_eq!(new_masked_output.len(), num_evals);
      assert!(new_masked_output
          .iter()
          .all(|bits| bits.len() == output_bitsize));
      assert_eq!(new_mask2.len(), num_evals);
      assert_eq!(new_mask2, new_mask3);
      assert!(new_mask2.iter().all(|bits| bits.len() == output_bitsize));

      // Both runs must unmask to identical outputs.
      let paired_runs = izip!(
          masked_output.iter(),
          mask2.iter(),
          new_masked_output.iter(),
          new_mask2.iter()
      );
      for (old_masked, old_mask, new_masked, new_mask) in paired_runs {
          let expected = old_masked.clone() ^ old_mask;
          let unmasked = new_masked.clone() ^ new_mask;
          assert_eq!(unmasked, expected);
      }
  }
  /// Evaluate the masked DOPRF three times on the same single input, running a
  /// preprocessing for one invocation before each evaluation, and check that
  /// all three runs unmask to the same output.
  #[test]
  fn test_masked_doprf_single() {
      let output_bitsize = 42;
      let mut party_1 = MaskedDOPrfParty1::<Fp>::new(output_bitsize);
      let mut party_2 = MaskedDOPrfParty2::<Fp>::new(output_bitsize);
      let mut party_3 = MaskedDOPrfParty3::<Fp>::new(output_bitsize);
      mdoprf_init(&mut party_1, &mut party_2, &mut party_3);

      let share_1 = Fp::random(thread_rng());
      let share_2 = Fp::random(thread_rng());
      let share_3 = Fp::random(thread_rng());

      // Three rounds of "preprocess one invocation, then consume it".
      let mut runs = Vec::with_capacity(3);
      for _ in 0..3 {
          mdoprf_preprocess(&mut party_1, &mut party_2, &mut party_3, 1);
          runs.push(mdoprf_eval(
              &mut party_1,
              &mut party_2,
              &mut party_3,
              &[share_1],
              &[share_2],
              &[share_3],
              1,
          ));
      }

      // Parties 2 and 3 must hold the same mask in every run.
      for (_, mask2, mask3) in &runs {
          assert_eq!(mask2, mask3);
      }

      // Unmask each run; the later runs must agree with the first one.
      let unmasked: Vec<_> = runs
          .iter()
          .map(|(masked_output, mask2, _)| masked_output[0].clone() ^ mask2[0].clone())
          .collect();
      let plain_output = &unmasked[0];
      assert_eq!(&unmasked[1], plain_output);
      assert_eq!(&unmasked[2], plain_output);
  }
  /// A Legendre PRF key must survive a bincode encode/decode round trip
  /// unchanged.
  #[test]
  fn test_serialization() {
      let original_key = LegendrePrf::<Fp>::key_gen(42);

      let key_bytes = bincode::encode_to_vec(&original_key, bincode::config::standard()).unwrap();

      // `decode_from_slice` also reports the number of bytes read; only the
      // decoded key matters here.
      let decoded_key: LegendrePrfKey<Fp> =
          bincode::decode_from_slice(&key_bytes, bincode::config::standard())
              .map(|(key, _bytes_read)| key)
              .unwrap();

      assert_eq!(decoded_key, original_key);
  }
  1713. }