doprf.rs 58 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768
  1. use crate::common::Error;
  2. use bincode;
  3. use bitvec;
  4. use communicator::{AbstractCommunicator, Fut, Serializable};
  5. use core::marker::PhantomData;
  6. use funty::Unsigned;
  7. use itertools::izip;
  8. use rand::{thread_rng, Rng, RngCore, SeedableRng};
  9. use rand_chacha::ChaChaRng;
  10. use std::iter::repeat;
  11. use utils::field::LegendreSymbol;
  12. pub type BitVec = bitvec::vec::BitVec<u8>;
  13. type BitSlice = bitvec::slice::BitSlice<u8>;
  14. #[derive(Clone, Debug, Eq, PartialEq, bincode::Encode, bincode::Decode)]
  15. pub struct LegendrePrfKey<F: LegendreSymbol> {
  16. pub keys: Vec<F>,
  17. }
  18. impl<F: LegendreSymbol> LegendrePrfKey<F> {
  19. pub fn get_output_bitsize(&self) -> usize {
  20. self.keys.len()
  21. }
  22. }
  23. /// Legendre PRF: F x F -> {0,1}^k
  24. pub struct LegendrePrf<F> {
  25. _phantom: PhantomData<F>,
  26. }
  27. impl<F: LegendreSymbol> LegendrePrf<F> {
  28. pub fn key_gen(output_bitsize: usize) -> LegendrePrfKey<F> {
  29. LegendrePrfKey {
  30. keys: (0..output_bitsize)
  31. .map(|_| F::random(thread_rng()))
  32. .collect(),
  33. }
  34. }
  35. pub fn eval<'a>(key: &'a LegendrePrfKey<F>, input: F) -> impl Iterator<Item = bool> + 'a {
  36. key.keys.iter().map(move |&k| {
  37. let ls = F::legendre_symbol(k + input);
  38. debug_assert!(ls != 0, "unlikely");
  39. ls == 1
  40. })
  41. }
  42. pub fn eval_bits(key: &LegendrePrfKey<F>, input: F) -> BitVec {
  43. let mut output = BitVec::with_capacity(key.keys.len());
  44. output.extend(Self::eval(key, input));
  45. output
  46. }
  47. pub fn eval_to_uint<T: Unsigned>(key: &LegendrePrfKey<F>, input: F) -> T {
  48. assert!(key.keys.len() <= T::BITS as usize);
  49. let mut output = T::ZERO;
  50. for (i, b) in Self::eval(key, input).enumerate() {
  51. if b {
  52. output |= T::ONE << i;
  53. }
  54. }
  55. output
  56. }
  57. }
  58. fn to_uint<T: Unsigned>(vs: impl IntoIterator<Item = impl IntoIterator<Item = bool>>) -> Vec<T> {
  59. vs.into_iter()
  60. .map(|v| {
  61. let mut output = T::ZERO;
  62. for (i, b) in v.into_iter().enumerate() {
  63. if b {
  64. output |= T::ONE << i;
  65. }
  66. }
  67. output
  68. })
  69. .collect()
  70. }
  71. type SharedSeed = [u8; 32];
  72. pub struct DOPrfParty1<F: LegendreSymbol> {
  73. _phantom: PhantomData<F>,
  74. output_bitsize: usize,
  75. shared_prg_1_2: Option<ChaChaRng>,
  76. shared_prg_1_3: Option<ChaChaRng>,
  77. legendre_prf_key: Option<LegendrePrfKey<F>>,
  78. is_initialized: bool,
  79. num_preprocessed_invocations: usize,
  80. preprocessed_squares: Vec<F>,
  81. preprocessed_mt_c1: Vec<F>,
  82. }
  83. impl<F> DOPrfParty1<F>
  84. where
  85. F: LegendreSymbol,
  86. {
  87. pub fn new(output_bitsize: usize) -> Self {
  88. assert!(output_bitsize > 0);
  89. Self {
  90. _phantom: PhantomData,
  91. output_bitsize,
  92. shared_prg_1_2: None,
  93. shared_prg_1_3: None,
  94. legendre_prf_key: None,
  95. is_initialized: false,
  96. num_preprocessed_invocations: 0,
  97. preprocessed_squares: Default::default(),
  98. preprocessed_mt_c1: Default::default(),
  99. }
  100. }
  101. pub fn from_legendre_prf_key(legendre_prf_key: LegendrePrfKey<F>) -> Self {
  102. let mut new = Self::new(legendre_prf_key.keys.len());
  103. new.legendre_prf_key = Some(legendre_prf_key);
  104. new
  105. }
  106. pub fn reset(&mut self) {
  107. *self = Self::new(self.output_bitsize)
  108. }
  109. pub fn reset_preprocessing(&mut self) {
  110. self.num_preprocessed_invocations = 0;
  111. self.preprocessed_squares = Default::default();
  112. self.preprocessed_mt_c1 = Default::default();
  113. }
  114. pub fn init_round_0(&mut self) -> (SharedSeed, ()) {
  115. assert!(!self.is_initialized);
  116. // sample and share a PRF key with Party 2
  117. self.shared_prg_1_2 = Some(ChaChaRng::from_seed(thread_rng().gen()));
  118. (self.shared_prg_1_2.as_ref().unwrap().get_seed(), ())
  119. }
  120. pub fn init_round_1(&mut self, _: (), shared_prg_seed_1_3: SharedSeed) {
  121. assert!(!self.is_initialized);
  122. // receive shared PRF key from Party 3
  123. self.shared_prg_1_3 = Some(ChaChaRng::from_seed(shared_prg_seed_1_3));
  124. if self.legendre_prf_key.is_none() {
  125. // generate Legendre PRF key
  126. self.legendre_prf_key = Some(LegendrePrf::key_gen(self.output_bitsize));
  127. }
  128. self.is_initialized = true;
  129. }
  130. pub fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
  131. let fut_3_1 = comm.receive_previous()?;
  132. let (msg_1_2, _) = self.init_round_0();
  133. comm.send_next(msg_1_2)?;
  134. self.init_round_1((), fut_3_1.get()?);
  135. Ok(())
  136. }
  137. pub fn get_legendre_prf_key(&self) -> LegendrePrfKey<F> {
  138. assert!(self.legendre_prf_key.is_some());
  139. self.legendre_prf_key.as_ref().unwrap().clone()
  140. }
  141. pub fn set_legendre_prf_key(&mut self, legendre_prf_key: LegendrePrfKey<F>) {
  142. assert!(!self.is_initialized);
  143. self.legendre_prf_key = Some(legendre_prf_key);
  144. }
  145. pub fn preprocess_round_0(&mut self, num: usize) -> ((), ()) {
  146. assert!(self.is_initialized);
  147. let n = num * self.output_bitsize;
  148. self.preprocessed_squares
  149. .extend((0..n).map(|_| F::random(self.shared_prg_1_2.as_mut().unwrap()).square()));
  150. ((), ())
  151. }
  152. pub fn preprocess_round_1(&mut self, num: usize, preprocessed_mt_c1: Vec<F>, _: ()) {
  153. assert!(self.is_initialized);
  154. let n = num * self.output_bitsize;
  155. assert_eq!(preprocessed_mt_c1.len(), n);
  156. self.preprocessed_mt_c1.extend(preprocessed_mt_c1);
  157. self.num_preprocessed_invocations += num;
  158. }
  159. pub fn preprocess<C: AbstractCommunicator>(
  160. &mut self,
  161. comm: &mut C,
  162. num: usize,
  163. ) -> Result<(), Error>
  164. where
  165. F: Serializable,
  166. {
  167. let fut_2_1 = comm.receive_next()?;
  168. self.preprocess_round_0(num);
  169. self.preprocess_round_1(num, fut_2_1.get()?, ());
  170. Ok(())
  171. }
  172. pub fn get_num_preprocessed_invocations(&self) -> usize {
  173. self.num_preprocessed_invocations
  174. }
  175. pub fn get_preprocessed_data(&self) -> (&[F], &[F]) {
  176. (&self.preprocessed_squares, &self.preprocessed_mt_c1)
  177. }
  178. pub fn check_preprocessing(&self) {
  179. let num = self.num_preprocessed_invocations;
  180. let n = num * self.output_bitsize;
  181. assert_eq!(self.preprocessed_squares.len(), n);
  182. assert_eq!(self.preprocessed_mt_c1.len(), n);
  183. }
  184. pub fn eval_round_1(
  185. &mut self,
  186. num: usize,
  187. shares1: &[F],
  188. masked_shares2: &[F],
  189. mult_e: &[F],
  190. ) -> ((), Vec<F>) {
  191. assert!(num <= self.num_preprocessed_invocations);
  192. let n = num * self.output_bitsize;
  193. assert_eq!(shares1.len(), num);
  194. assert_eq!(masked_shares2.len(), num);
  195. assert_eq!(mult_e.len(), num);
  196. let k = &self.legendre_prf_key.as_ref().unwrap().keys;
  197. assert_eq!(k.len(), self.output_bitsize);
  198. let output_shares_z1: Vec<F> = izip!(
  199. shares1
  200. .iter()
  201. .flat_map(|s1i| repeat(s1i).take(self.output_bitsize)),
  202. masked_shares2
  203. .iter()
  204. .flat_map(|ms2i| repeat(ms2i).take(self.output_bitsize)),
  205. k.iter().cycle(),
  206. self.preprocessed_squares.drain(0..n),
  207. self.preprocessed_mt_c1.drain(0..n),
  208. mult_e
  209. .iter()
  210. .flat_map(|e| repeat(e).take(self.output_bitsize)),
  211. )
  212. .map(|(&s1_i, &ms2_i, &k_j, sq_ij, c1_ij, &e_ij)| {
  213. sq_ij * (k_j + s1_i + ms2_i) + e_ij * sq_ij + c1_ij
  214. })
  215. .collect();
  216. self.num_preprocessed_invocations -= num;
  217. ((), output_shares_z1)
  218. }
  219. pub fn eval<C: AbstractCommunicator>(
  220. &mut self,
  221. comm: &mut C,
  222. num: usize,
  223. shares1: &[F],
  224. ) -> Result<(), Error>
  225. where
  226. F: Serializable,
  227. {
  228. assert_eq!(shares1.len(), num);
  229. let fut_2_1 = comm.receive_next::<Vec<_>>()?;
  230. let fut_3_1 = comm.receive_previous::<Vec<_>>()?;
  231. let (_, msg_1_3) = self.eval_round_1(num, shares1, &fut_2_1.get()?, &fut_3_1.get()?);
  232. comm.send_previous(msg_1_3)?;
  233. Ok(())
  234. }
  235. }
  236. pub struct DOPrfParty2<F: LegendreSymbol> {
  237. _phantom: PhantomData<F>,
  238. output_bitsize: usize,
  239. shared_prg_1_2: Option<ChaChaRng>,
  240. shared_prg_2_3: Option<ChaChaRng>,
  241. is_initialized: bool,
  242. num_preprocessed_invocations: usize,
  243. preprocessed_rerand_m2: Vec<F>,
  244. }
  245. impl<F> DOPrfParty2<F>
  246. where
  247. F: LegendreSymbol,
  248. {
  249. pub fn new(output_bitsize: usize) -> Self {
  250. assert!(output_bitsize > 0);
  251. Self {
  252. _phantom: PhantomData,
  253. output_bitsize,
  254. shared_prg_1_2: None,
  255. shared_prg_2_3: None,
  256. is_initialized: false,
  257. num_preprocessed_invocations: 0,
  258. preprocessed_rerand_m2: Default::default(),
  259. }
  260. }
  261. pub fn reset(&mut self) {
  262. *self = Self::new(self.output_bitsize)
  263. }
  264. pub fn reset_preprocessing(&mut self) {
  265. self.num_preprocessed_invocations = 0;
  266. self.preprocessed_rerand_m2 = Default::default();
  267. }
  268. pub fn init_round_0(&mut self) -> ((), SharedSeed) {
  269. assert!(!self.is_initialized);
  270. self.shared_prg_2_3 = Some(ChaChaRng::from_seed(thread_rng().gen()));
  271. ((), self.shared_prg_2_3.as_ref().unwrap().get_seed())
  272. }
  273. pub fn init_round_1(&mut self, shared_prg_seed_1_2: SharedSeed, _: ()) {
  274. assert!(!self.is_initialized);
  275. // receive shared PRF key from Party 1
  276. self.shared_prg_1_2 = Some(ChaChaRng::from_seed(shared_prg_seed_1_2));
  277. self.is_initialized = true;
  278. }
  279. pub fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
  280. let fut_1_2 = comm.receive_previous()?;
  281. let (_, msg_2_3) = self.init_round_0();
  282. comm.send_next(msg_2_3)?;
  283. self.init_round_1(fut_1_2.get()?, ());
  284. Ok(())
  285. }
  286. pub fn preprocess_round_0(&mut self, num: usize) -> (Vec<F>, ()) {
  287. assert!(self.is_initialized);
  288. let n = num * self.output_bitsize;
  289. let preprocessed_squares: Vec<F> = (0..n)
  290. .map(|_| F::random(self.shared_prg_1_2.as_mut().unwrap()).square())
  291. .collect();
  292. self.preprocessed_rerand_m2
  293. .extend((0..num).map(|_| F::random(self.shared_prg_2_3.as_mut().unwrap())));
  294. let preprocessed_mult_d: Vec<F> = (0..n)
  295. .map(|_| F::random(self.shared_prg_2_3.as_mut().unwrap()))
  296. .collect();
  297. let preprocessed_mt_b: Vec<F> = (0..num)
  298. .map(|_| F::random(self.shared_prg_2_3.as_mut().unwrap()))
  299. .collect();
  300. let preprocessed_mt_c3: Vec<F> = (0..n)
  301. .map(|_| F::random(self.shared_prg_2_3.as_mut().unwrap()))
  302. .collect();
  303. let preprocessed_c1: Vec<F> = izip!(
  304. preprocessed_squares.iter(),
  305. preprocessed_mult_d.iter(),
  306. preprocessed_mt_b
  307. .iter()
  308. .flat_map(|b| repeat(b).take(self.output_bitsize)),
  309. preprocessed_mt_c3.iter(),
  310. )
  311. .map(|(&s, &d, &b, &c3)| (s - d) * b - c3)
  312. .collect();
  313. self.num_preprocessed_invocations += num;
  314. (preprocessed_c1, ())
  315. }
  316. pub fn preprocess_round_1(&mut self, _: usize, _: (), _: ()) {
  317. assert!(self.is_initialized);
  318. }
  319. pub fn preprocess<C: AbstractCommunicator>(
  320. &mut self,
  321. comm: &mut C,
  322. num: usize,
  323. ) -> Result<(), Error>
  324. where
  325. F: Serializable,
  326. {
  327. let (msg_2_1, _) = self.preprocess_round_0(num);
  328. comm.send_previous(msg_2_1)?;
  329. self.preprocess_round_1(num, (), ());
  330. Ok(())
  331. }
  332. pub fn get_num_preprocessed_invocations(&self) -> usize {
  333. self.num_preprocessed_invocations
  334. }
  335. pub fn get_preprocessed_data(&self) -> &[F] {
  336. &self.preprocessed_rerand_m2
  337. }
  338. pub fn check_preprocessing(&self) {
  339. let num = self.num_preprocessed_invocations;
  340. assert_eq!(self.preprocessed_rerand_m2.len(), num);
  341. }
  342. pub fn eval_round_0(&mut self, num: usize, shares2: &[F]) -> (Vec<F>, ()) {
  343. assert!(num <= self.num_preprocessed_invocations);
  344. assert_eq!(shares2.len(), num);
  345. let masked_shares2: Vec<F> =
  346. izip!(shares2.iter(), self.preprocessed_rerand_m2.drain(0..num),)
  347. .map(|(&s2i, m2i)| s2i + m2i)
  348. .collect();
  349. self.num_preprocessed_invocations -= num;
  350. (masked_shares2, ())
  351. }
  352. pub fn eval<C: AbstractCommunicator>(
  353. &mut self,
  354. comm: &mut C,
  355. num: usize,
  356. shares2: &[F],
  357. ) -> Result<(), Error>
  358. where
  359. F: Serializable,
  360. {
  361. assert_eq!(shares2.len(), num);
  362. let (msg_2_1, _) = self.eval_round_0(1, shares2);
  363. comm.send_previous(msg_2_1)?;
  364. Ok(())
  365. }
  366. }
  367. pub struct DOPrfParty3<F: LegendreSymbol> {
  368. _phantom: PhantomData<F>,
  369. output_bitsize: usize,
  370. shared_prg_1_3: Option<ChaChaRng>,
  371. shared_prg_2_3: Option<ChaChaRng>,
  372. is_initialized: bool,
  373. num_preprocessed_invocations: usize,
  374. preprocessed_rerand_m3: Vec<F>,
  375. preprocessed_mt_b: Vec<F>,
  376. preprocessed_mt_c3: Vec<F>,
  377. preprocessed_mult_d: Vec<F>,
  378. mult_e: Vec<F>,
  379. }
  380. impl<F> DOPrfParty3<F>
  381. where
  382. F: LegendreSymbol,
  383. {
  384. pub fn new(output_bitsize: usize) -> Self {
  385. assert!(output_bitsize > 0);
  386. Self {
  387. _phantom: PhantomData,
  388. output_bitsize,
  389. shared_prg_1_3: None,
  390. shared_prg_2_3: None,
  391. is_initialized: false,
  392. num_preprocessed_invocations: 0,
  393. preprocessed_rerand_m3: Default::default(),
  394. preprocessed_mt_b: Default::default(),
  395. preprocessed_mt_c3: Default::default(),
  396. preprocessed_mult_d: Default::default(),
  397. mult_e: Default::default(),
  398. }
  399. }
  400. pub fn reset(&mut self) {
  401. *self = Self::new(self.output_bitsize)
  402. }
  403. pub fn reset_preprocessing(&mut self) {
  404. self.num_preprocessed_invocations = 0;
  405. self.preprocessed_rerand_m3 = Default::default();
  406. self.preprocessed_mt_b = Default::default();
  407. self.preprocessed_mt_c3 = Default::default();
  408. self.preprocessed_mult_d = Default::default();
  409. self.mult_e = Default::default();
  410. }
  411. pub fn init_round_0(&mut self) -> (SharedSeed, ()) {
  412. assert!(!self.is_initialized);
  413. self.shared_prg_1_3 = Some(ChaChaRng::from_seed(thread_rng().gen()));
  414. (self.shared_prg_1_3.as_ref().unwrap().get_seed(), ())
  415. }
  416. pub fn init_round_1(&mut self, _: (), shared_prg_seed_2_3: SharedSeed) {
  417. self.shared_prg_2_3 = Some(ChaChaRng::from_seed(shared_prg_seed_2_3));
  418. self.is_initialized = true;
  419. }
  420. pub fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
  421. let fut_2_3 = comm.receive_previous()?;
  422. let (msg_3_1, _) = self.init_round_0();
  423. comm.send_next(msg_3_1)?;
  424. self.init_round_1((), fut_2_3.get()?);
  425. Ok(())
  426. }
  427. pub fn preprocess_round_0(&mut self, num: usize) -> ((), ()) {
  428. assert!(self.is_initialized);
  429. let n = num * self.output_bitsize;
  430. self.preprocessed_rerand_m3
  431. .extend((0..num).map(|_| -F::random(self.shared_prg_2_3.as_mut().unwrap())));
  432. self.preprocessed_mult_d
  433. .extend((0..n).map(|_| F::random(self.shared_prg_2_3.as_mut().unwrap())));
  434. self.preprocessed_mt_b
  435. .extend((0..num).map(|_| F::random(self.shared_prg_2_3.as_mut().unwrap())));
  436. self.preprocessed_mt_c3
  437. .extend((0..n).map(|_| F::random(self.shared_prg_2_3.as_mut().unwrap())));
  438. ((), ())
  439. }
  440. pub fn preprocess_round_1(&mut self, num: usize, _: (), _: ()) {
  441. assert!(self.is_initialized);
  442. self.num_preprocessed_invocations += num;
  443. }
  444. pub fn preprocess<C: AbstractCommunicator>(
  445. &mut self,
  446. _comm: &mut C,
  447. num: usize,
  448. ) -> Result<(), Error>
  449. where
  450. F: Serializable,
  451. {
  452. self.preprocess_round_0(num);
  453. self.preprocess_round_1(num, (), ());
  454. Ok(())
  455. }
  456. pub fn get_num_preprocessed_invocations(&self) -> usize {
  457. self.num_preprocessed_invocations
  458. }
  459. pub fn get_preprocessed_data(&self) -> (&[F], &[F], &[F], &[F]) {
  460. (
  461. &self.preprocessed_rerand_m3,
  462. &self.preprocessed_mt_b,
  463. &self.preprocessed_mt_c3,
  464. &self.preprocessed_mult_d,
  465. )
  466. }
  467. pub fn check_preprocessing(&self) {
  468. let num = self.num_preprocessed_invocations;
  469. let n = num * self.output_bitsize;
  470. assert_eq!(self.preprocessed_rerand_m3.len(), num);
  471. assert_eq!(self.preprocessed_mt_b.len(), num);
  472. assert_eq!(self.preprocessed_mt_c3.len(), n);
  473. assert_eq!(self.preprocessed_mult_d.len(), n);
  474. }
  475. pub fn eval_round_0(&mut self, num: usize, shares3: &[F]) -> (Vec<F>, ()) {
  476. assert!(num <= self.num_preprocessed_invocations);
  477. assert_eq!(shares3.len(), num);
  478. self.mult_e = izip!(
  479. shares3.iter(),
  480. &self.preprocessed_rerand_m3[0..num],
  481. self.preprocessed_mt_b.drain(0..num),
  482. )
  483. .map(|(&s3_i, m3_i, b_i)| s3_i + m3_i - b_i)
  484. .collect();
  485. (self.mult_e.clone(), ())
  486. }
  487. pub fn eval_round_2(
  488. &mut self,
  489. num: usize,
  490. shares3: &[F],
  491. output_shares_z1: Vec<F>,
  492. _: (),
  493. ) -> Vec<BitVec> {
  494. assert!(num <= self.num_preprocessed_invocations);
  495. let n = num * self.output_bitsize;
  496. assert_eq!(shares3.len(), num);
  497. assert_eq!(output_shares_z1.len(), n);
  498. let lprf_inputs: Vec<F> = izip!(
  499. shares3
  500. .iter()
  501. .flat_map(|s3| repeat(s3).take(self.output_bitsize)),
  502. self.preprocessed_rerand_m3
  503. .drain(0..num)
  504. .flat_map(|m3| repeat(m3).take(self.output_bitsize)),
  505. self.preprocessed_mult_d.drain(0..n),
  506. self.mult_e
  507. .drain(0..num)
  508. .flat_map(|e| repeat(e).take(self.output_bitsize)),
  509. self.preprocessed_mt_c3.drain(0..n),
  510. output_shares_z1.iter(),
  511. )
  512. .map(|(&s3_i, m3_i, d_ij, e_i, c3_ij, &z1_ij)| {
  513. d_ij * (s3_i + m3_i) + c3_ij + z1_ij - d_ij * e_i
  514. })
  515. .collect();
  516. assert_eq!(lprf_inputs.len(), n);
  517. let output: Vec<BitVec> = lprf_inputs
  518. .chunks_exact(self.output_bitsize)
  519. .map(|chunk| {
  520. let mut bv = BitVec::with_capacity(self.output_bitsize);
  521. for &x in chunk.iter() {
  522. let ls = F::legendre_symbol(x);
  523. debug_assert!(ls != 0, "unlikely");
  524. bv.push(ls == 1);
  525. }
  526. bv
  527. })
  528. .collect();
  529. self.num_preprocessed_invocations -= num;
  530. output
  531. }
  532. pub fn eval<C: AbstractCommunicator>(
  533. &mut self,
  534. comm: &mut C,
  535. num: usize,
  536. shares3: &[F],
  537. ) -> Result<Vec<BitVec>, Error>
  538. where
  539. F: Serializable,
  540. {
  541. assert_eq!(shares3.len(), num);
  542. let fut_1_3 = comm.receive_next()?;
  543. let (msg_3_1, _) = self.eval_round_0(num, shares3);
  544. comm.send_next(msg_3_1)?;
  545. let output = self.eval_round_2(num, shares3, fut_1_3.get()?, ());
  546. Ok(output)
  547. }
  548. pub fn eval_to_uint<C: AbstractCommunicator, T: Unsigned>(
  549. &mut self,
  550. comm: &mut C,
  551. num: usize,
  552. shares3: &[F],
  553. ) -> Result<Vec<T>, Error>
  554. where
  555. F: Serializable,
  556. {
  557. assert!(self.output_bitsize <= T::BITS as usize);
  558. Ok(to_uint(self.eval(comm, num, shares3)?))
  559. }
  560. }
  561. pub struct JointDOPrf<F: LegendreSymbol> {
  562. output_bitsize: usize,
  563. doprf_p1_prev: DOPrfParty1<F>,
  564. doprf_p2_next: DOPrfParty2<F>,
  565. doprf_p3_mine: DOPrfParty3<F>,
  566. }
  567. impl<F: LegendreSymbol + Serializable> JointDOPrf<F> {
  568. pub fn new(output_bitsize: usize) -> Self {
  569. Self {
  570. output_bitsize,
  571. doprf_p1_prev: DOPrfParty1::new(output_bitsize),
  572. doprf_p2_next: DOPrfParty2::new(output_bitsize),
  573. doprf_p3_mine: DOPrfParty3::new(output_bitsize),
  574. }
  575. }
  576. pub fn reset(&mut self) {
  577. *self = Self::new(self.output_bitsize);
  578. }
  579. pub fn get_legendre_prf_key_prev(&self) -> LegendrePrfKey<F> {
  580. self.doprf_p1_prev.get_legendre_prf_key()
  581. }
  582. pub fn set_legendre_prf_key_prev(&mut self, legendre_prf_key: LegendrePrfKey<F>) {
  583. self.doprf_p1_prev.set_legendre_prf_key(legendre_prf_key)
  584. }
  585. pub fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
  586. let fut_prev = comm.receive_previous()?;
  587. let (msg_1_2, _) = self.doprf_p1_prev.init_round_0();
  588. let (_, msg_2_3) = self.doprf_p2_next.init_round_0();
  589. let (msg_3_1, _) = self.doprf_p3_mine.init_round_0();
  590. comm.send_next((msg_1_2, msg_2_3, msg_3_1))?;
  591. let (msg_1_2, msg_2_3, msg_3_1) = fut_prev.get()?;
  592. self.doprf_p1_prev.init_round_1((), msg_3_1);
  593. self.doprf_p2_next.init_round_1(msg_1_2, ());
  594. self.doprf_p3_mine.init_round_1((), msg_2_3);
  595. Ok(())
  596. }
  597. pub fn preprocess<C: AbstractCommunicator>(
  598. &mut self,
  599. comm: &mut C,
  600. num: usize,
  601. ) -> Result<(), Error> {
  602. let fut_2_1 = comm.receive_next()?;
  603. let (msg_2_1, _) = self.doprf_p2_next.preprocess_round_0(num);
  604. comm.send_previous(msg_2_1)?;
  605. self.doprf_p2_next.preprocess_round_1(num, (), ());
  606. self.doprf_p3_mine.preprocess_round_0(num);
  607. self.doprf_p3_mine.preprocess_round_1(num, (), ());
  608. self.doprf_p1_prev.preprocess_round_0(num);
  609. self.doprf_p1_prev
  610. .preprocess_round_1(num, fut_2_1.get()?, ());
  611. Ok(())
  612. }
  613. pub fn eval_to_uint<C: AbstractCommunicator, T: Unsigned>(
  614. &mut self,
  615. comm: &mut C,
  616. shares: &[F],
  617. ) -> Result<Vec<T>, Error> {
  618. let num = shares.len();
  619. let fut_2_1 = comm.receive_next::<Vec<_>>()?; // round 0
  620. let fut_3_1 = comm.receive_previous::<Vec<_>>()?; // round 0
  621. let fut_1_3 = comm.receive_next()?; // round 1
  622. let (msg_2_1, _) = self.doprf_p2_next.eval_round_0(num, shares);
  623. comm.send_previous(msg_2_1)?;
  624. let (msg_3_1, _) = self.doprf_p3_mine.eval_round_0(num, shares);
  625. comm.send_next(msg_3_1)?;
  626. let (_, msg_1_3) =
  627. self.doprf_p1_prev
  628. .eval_round_1(num, shares, &fut_2_1.get()?, &fut_3_1.get()?);
  629. comm.send_previous(msg_1_3)?;
  630. let output = self
  631. .doprf_p3_mine
  632. .eval_round_2(num, shares, fut_1_3.get()?, ());
  633. Ok(to_uint(output))
  634. }
  635. }
  636. pub struct MaskedDOPrfParty1<F: LegendreSymbol> {
  637. _phantom: PhantomData<F>,
  638. output_bitsize: usize,
  639. shared_prg_1_2: Option<ChaChaRng>,
  640. shared_prg_1_3: Option<ChaChaRng>,
  641. legendre_prf_key: Option<LegendrePrfKey<F>>,
  642. is_initialized: bool,
  643. num_preprocessed_invocations: usize,
  644. preprocessed_rerand_m1: Vec<F>,
  645. preprocessed_mt_a: Vec<F>,
  646. preprocessed_mt_c1: Vec<F>,
  647. preprocessed_mult_e: Vec<F>,
  648. mult_d: Vec<F>,
  649. }
  650. impl<F> MaskedDOPrfParty1<F>
  651. where
  652. F: LegendreSymbol,
  653. {
  654. pub fn new(output_bitsize: usize) -> Self {
  655. assert!(output_bitsize > 0);
  656. Self {
  657. _phantom: PhantomData,
  658. output_bitsize,
  659. shared_prg_1_2: None,
  660. shared_prg_1_3: None,
  661. legendre_prf_key: None,
  662. is_initialized: false,
  663. num_preprocessed_invocations: 0,
  664. preprocessed_rerand_m1: Default::default(),
  665. preprocessed_mt_a: Default::default(),
  666. preprocessed_mt_c1: Default::default(),
  667. preprocessed_mult_e: Default::default(),
  668. mult_d: Default::default(),
  669. }
  670. }
  671. pub fn from_legendre_prf_key(legendre_prf_key: LegendrePrfKey<F>) -> Self {
  672. let mut new = Self::new(legendre_prf_key.keys.len());
  673. new.legendre_prf_key = Some(legendre_prf_key);
  674. new
  675. }
  676. pub fn reset(&mut self) {
  677. *self = Self::new(self.output_bitsize)
  678. }
  679. pub fn reset_preprocessing(&mut self) {
  680. self.num_preprocessed_invocations = 0;
  681. self.preprocessed_rerand_m1 = Default::default();
  682. self.preprocessed_mt_a = Default::default();
  683. self.preprocessed_mt_c1 = Default::default();
  684. self.preprocessed_mult_e = Default::default();
  685. }
  686. pub fn init_round_0(&mut self) -> (SharedSeed, ()) {
  687. assert!(!self.is_initialized);
  688. // sample and share a PRF key with Party 2
  689. self.shared_prg_1_2 = Some(ChaChaRng::from_seed(thread_rng().gen()));
  690. (self.shared_prg_1_2.as_ref().unwrap().get_seed(), ())
  691. }
  692. pub fn init_round_1(&mut self, _: (), shared_prg_seed_1_3: SharedSeed) {
  693. assert!(!self.is_initialized);
  694. // receive shared PRF key from Party 3
  695. self.shared_prg_1_3 = Some(ChaChaRng::from_seed(shared_prg_seed_1_3));
  696. if self.legendre_prf_key.is_none() {
  697. // generate Legendre PRF key
  698. self.legendre_prf_key = Some(LegendrePrf::key_gen(self.output_bitsize));
  699. }
  700. self.is_initialized = true;
  701. }
  702. pub fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
  703. let fut_3_1 = comm.receive_previous()?;
  704. let (msg_1_2, _) = self.init_round_0();
  705. comm.send_next(msg_1_2)?;
  706. self.init_round_1((), fut_3_1.get()?);
  707. Ok(())
  708. }
  709. pub fn get_legendre_prf_key(&self) -> LegendrePrfKey<F> {
  710. assert!(self.is_initialized);
  711. self.legendre_prf_key.as_ref().unwrap().clone()
  712. }
  713. pub fn preprocess_round_0(&mut self, num: usize) -> ((), ()) {
  714. assert!(self.is_initialized);
  715. let n = num * self.output_bitsize;
  716. self.preprocessed_rerand_m1
  717. .extend((0..num).map(|_| F::random(self.shared_prg_1_2.as_mut().unwrap())));
  718. self.preprocessed_mt_a
  719. .extend((0..n).map(|_| F::random(self.shared_prg_1_2.as_mut().unwrap())));
  720. self.preprocessed_mt_c1
  721. .extend((0..n).map(|_| F::random(self.shared_prg_1_2.as_mut().unwrap())));
  722. self.preprocessed_mult_e
  723. .extend((0..n).map(|_| F::random(self.shared_prg_1_2.as_mut().unwrap())));
  724. ((), ())
  725. }
  726. pub fn preprocess_round_1(&mut self, num: usize, _: (), _: ()) {
  727. assert!(self.is_initialized);
  728. self.num_preprocessed_invocations += num;
  729. }
  730. pub fn preprocess<C: AbstractCommunicator>(
  731. &mut self,
  732. _comm: &mut C,
  733. num: usize,
  734. ) -> Result<(), Error> {
  735. self.preprocess_round_0(num);
  736. self.preprocess_round_1(num, (), ());
  737. Ok(())
  738. }
  739. pub fn get_num_preprocessed_invocations(&self) -> usize {
  740. self.num_preprocessed_invocations
  741. }
  742. pub fn get_preprocessed_data(&self) -> (&[F], &[F], &[F], &[F]) {
  743. (
  744. &self.preprocessed_rerand_m1,
  745. &self.preprocessed_mt_a,
  746. &self.preprocessed_mt_c1,
  747. &self.preprocessed_mult_e,
  748. )
  749. }
  750. pub fn check_preprocessing(&self) {
  751. let num = self.num_preprocessed_invocations;
  752. let n = num * self.output_bitsize;
  753. assert_eq!(self.preprocessed_rerand_m1.len(), num);
  754. assert_eq!(self.preprocessed_mt_a.len(), n);
  755. assert_eq!(self.preprocessed_mt_c1.len(), n);
  756. assert_eq!(self.preprocessed_mult_e.len(), n);
  757. }
  758. pub fn eval_round_0(&mut self, num: usize, shares1: &[F]) -> ((), Vec<F>) {
  759. assert!(num <= self.num_preprocessed_invocations);
  760. assert_eq!(shares1.len(), num);
  761. let n = num * self.output_bitsize;
  762. let k = &self.legendre_prf_key.as_ref().unwrap().keys;
  763. self.mult_d = izip!(
  764. k.iter().cycle(),
  765. shares1
  766. .iter()
  767. .flat_map(|s1| repeat(s1).take(self.output_bitsize)),
  768. self.preprocessed_rerand_m1
  769. .iter()
  770. .take(num)
  771. .flat_map(|m1| repeat(m1).take(self.output_bitsize)),
  772. self.preprocessed_mt_a.drain(0..n),
  773. )
  774. .map(|(&k_i, &s1_i, m1_i, a_i)| k_i + s1_i + m1_i - a_i)
  775. .collect();
  776. assert_eq!(self.mult_d.len(), n);
  777. ((), self.mult_d.clone())
  778. }
  779. pub fn eval_round_2(
  780. &mut self,
  781. num: usize,
  782. shares1: &[F],
  783. _: (),
  784. output_shares_z3: Vec<F>,
  785. ) -> Vec<BitVec> {
  786. assert!(num <= self.num_preprocessed_invocations);
  787. let n = num * self.output_bitsize;
  788. assert_eq!(shares1.len(), num);
  789. assert_eq!(output_shares_z3.len(), n);
  790. let k = &self.legendre_prf_key.as_ref().unwrap().keys;
  791. let lprf_inputs: Vec<F> = izip!(
  792. k.iter().cycle(),
  793. shares1
  794. .iter()
  795. .flat_map(|s1| repeat(s1).take(self.output_bitsize)),
  796. self.preprocessed_rerand_m1
  797. .drain(0..num)
  798. .flat_map(|m1| repeat(m1).take(self.output_bitsize)),
  799. self.preprocessed_mult_e.drain(0..n),
  800. self.mult_d.drain(..),
  801. self.preprocessed_mt_c1.drain(0..n),
  802. output_shares_z3.iter(),
  803. )
  804. .map(|(&k_j, &s1_i, m1_i, e_ij, d_ij, c1_ij, &z3_ij)| {
  805. e_ij * (k_j + s1_i + m1_i) + c1_ij + z3_ij - d_ij * e_ij
  806. })
  807. .collect();
  808. assert_eq!(lprf_inputs.len(), n);
  809. let output: Vec<BitVec> = lprf_inputs
  810. .chunks_exact(self.output_bitsize)
  811. .map(|chunk| {
  812. let mut bv = BitVec::with_capacity(self.output_bitsize);
  813. for &x in chunk.iter() {
  814. let ls = F::legendre_symbol(x);
  815. debug_assert!(ls != 0, "unlikely");
  816. bv.push(ls == 1);
  817. }
  818. bv
  819. })
  820. .collect();
  821. self.num_preprocessed_invocations -= num;
  822. output
  823. }
  824. pub fn eval<C: AbstractCommunicator>(
  825. &mut self,
  826. comm: &mut C,
  827. num: usize,
  828. shares1: &[F],
  829. ) -> Result<Vec<BitVec>, Error>
  830. where
  831. F: Serializable,
  832. {
  833. assert_eq!(shares1.len(), num);
  834. let fut_3_1 = comm.receive_previous()?;
  835. let (_, msg_1_3) = self.eval_round_0(num, shares1);
  836. comm.send_previous(msg_1_3)?;
  837. let output = self.eval_round_2(1, shares1, (), fut_3_1.get()?);
  838. Ok(output)
  839. }
  840. pub fn eval_to_uint<C: AbstractCommunicator, T: Unsigned>(
  841. &mut self,
  842. comm: &mut C,
  843. num: usize,
  844. shares1: &[F],
  845. ) -> Result<Vec<T>, Error>
  846. where
  847. F: Serializable,
  848. {
  849. assert!(self.output_bitsize <= T::BITS as usize);
  850. Ok(to_uint(self.eval(comm, num, shares1)?))
  851. }
  852. }
  853. pub struct MaskedDOPrfParty2<F: LegendreSymbol> {
  854. _phantom: PhantomData<F>,
  855. output_bitsize: usize,
  856. shared_prg_1_2: Option<ChaChaRng>,
  857. shared_prg_2_3: Option<ChaChaRng>,
  858. is_initialized: bool,
  859. num_preprocessed_invocations: usize,
  860. preprocessed_rerand_m2: Vec<F>,
  861. preprocessed_r: BitVec,
  862. }
  863. impl<F> MaskedDOPrfParty2<F>
  864. where
  865. F: LegendreSymbol,
  866. {
  867. pub fn new(output_bitsize: usize) -> Self {
  868. assert!(output_bitsize > 0);
  869. Self {
  870. _phantom: PhantomData,
  871. output_bitsize,
  872. shared_prg_1_2: None,
  873. shared_prg_2_3: None,
  874. is_initialized: false,
  875. num_preprocessed_invocations: 0,
  876. preprocessed_rerand_m2: Default::default(),
  877. preprocessed_r: Default::default(),
  878. }
  879. }
  880. pub fn reset(&mut self) {
  881. *self = Self::new(self.output_bitsize)
  882. }
  883. pub fn reset_preprocessing(&mut self) {
  884. self.num_preprocessed_invocations = 0;
  885. self.preprocessed_rerand_m2 = Default::default();
  886. }
  887. pub fn init_round_0(&mut self) -> ((), SharedSeed) {
  888. assert!(!self.is_initialized);
  889. self.shared_prg_2_3 = Some(ChaChaRng::from_seed(thread_rng().gen()));
  890. ((), self.shared_prg_2_3.as_ref().unwrap().get_seed())
  891. }
  892. pub fn init_round_1(&mut self, shared_prg_seed_1_2: SharedSeed, _: ()) {
  893. assert!(!self.is_initialized);
  894. // receive shared PRF key from Party 1
  895. self.shared_prg_1_2 = Some(ChaChaRng::from_seed(shared_prg_seed_1_2));
  896. self.is_initialized = true;
  897. }
  898. pub fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
  899. let fut_1_2 = comm.receive_previous()?;
  900. let (_, msg_2_3) = self.init_round_0();
  901. comm.send_next(msg_2_3)?;
  902. self.init_round_1(fut_1_2.get()?, ());
  903. Ok(())
  904. }
  905. pub fn preprocess_round_0(&mut self, num: usize) -> ((), Vec<F>) {
  906. assert!(self.is_initialized);
  907. let n = num * self.output_bitsize;
  908. let mut preprocessed_t: Vec<_> = (0..n)
  909. .map(|_| F::random(self.shared_prg_2_3.as_mut().unwrap()).square())
  910. .collect();
  911. debug_assert!(!preprocessed_t.contains(&F::ZERO));
  912. {
  913. let mut random_bytes = vec![0u8; (n + 7) / 8];
  914. self.shared_prg_2_3
  915. .as_mut()
  916. .unwrap()
  917. .fill_bytes(&mut random_bytes);
  918. let new_r_slice = BitSlice::from_slice(&random_bytes);
  919. self.preprocessed_r.extend(&new_r_slice[..n]);
  920. for (i, r_i) in new_r_slice.iter().by_vals().take(n).enumerate() {
  921. if r_i {
  922. preprocessed_t[i] *= F::get_non_random_qnr();
  923. }
  924. }
  925. }
  926. self.preprocessed_rerand_m2
  927. .extend((0..num).map(|_| -F::random(self.shared_prg_1_2.as_mut().unwrap())));
  928. let preprocessed_mt_a: Vec<F> = (0..n)
  929. .map(|_| F::random(self.shared_prg_1_2.as_mut().unwrap()))
  930. .collect();
  931. let preprocessed_mt_c1: Vec<F> = (0..n)
  932. .map(|_| F::random(self.shared_prg_1_2.as_mut().unwrap()))
  933. .collect();
  934. let preprocessed_mult_e: Vec<F> = (0..n)
  935. .map(|_| F::random(self.shared_prg_1_2.as_mut().unwrap()))
  936. .collect();
  937. let preprocessed_c3: Vec<F> = izip!(
  938. preprocessed_t.iter(),
  939. preprocessed_mult_e.iter(),
  940. preprocessed_mt_a.iter(),
  941. preprocessed_mt_c1.iter(),
  942. )
  943. .map(|(&t, &e, &a, &c1)| a * (t - e) - c1)
  944. .collect();
  945. self.num_preprocessed_invocations += num;
  946. ((), preprocessed_c3)
  947. }
  948. pub fn preprocess_round_1(&mut self, _: usize, _: (), _: ()) {
  949. assert!(self.is_initialized);
  950. }
  951. pub fn preprocess<C: AbstractCommunicator>(
  952. &mut self,
  953. comm: &mut C,
  954. num: usize,
  955. ) -> Result<(), Error>
  956. where
  957. F: Serializable,
  958. {
  959. let (_, msg_2_3) = self.preprocess_round_0(num);
  960. comm.send_next(msg_2_3)?;
  961. self.preprocess_round_1(num, (), ());
  962. Ok(())
  963. }
  964. pub fn get_num_preprocessed_invocations(&self) -> usize {
  965. self.num_preprocessed_invocations
  966. }
  967. pub fn get_preprocessed_data(&self) -> (&BitSlice, &[F]) {
  968. (&self.preprocessed_r, &self.preprocessed_rerand_m2)
  969. }
  970. pub fn check_preprocessing(&self) {
  971. let num = self.num_preprocessed_invocations;
  972. assert_eq!(self.preprocessed_rerand_m2.len(), num);
  973. }
  974. pub fn eval_round_0(&mut self, num: usize, shares2: &[F]) -> ((), Vec<F>) {
  975. assert!(num <= self.num_preprocessed_invocations);
  976. assert_eq!(shares2.len(), num);
  977. let masked_shares2: Vec<F> =
  978. izip!(shares2.iter(), self.preprocessed_rerand_m2.drain(0..num),)
  979. .map(|(&s2i, m2i)| s2i + m2i)
  980. .collect();
  981. assert_eq!(masked_shares2.len(), num);
  982. ((), masked_shares2)
  983. }
  984. pub fn eval_get_output(&mut self, num: usize) -> Vec<BitVec> {
  985. assert!(num <= self.num_preprocessed_invocations);
  986. let n = num * self.output_bitsize;
  987. let mut output = Vec::with_capacity(num);
  988. for chunk in self
  989. .preprocessed_r
  990. .chunks_exact(self.output_bitsize)
  991. .take(num)
  992. {
  993. output.push(chunk.to_bitvec());
  994. }
  995. let (_, last_r) = self.preprocessed_r.split_at(n);
  996. self.preprocessed_r = last_r.to_bitvec();
  997. self.num_preprocessed_invocations -= num;
  998. output
  999. }
  1000. pub fn eval<C: AbstractCommunicator>(
  1001. &mut self,
  1002. comm: &mut C,
  1003. num: usize,
  1004. shares2: &[F],
  1005. ) -> Result<Vec<BitVec>, Error>
  1006. where
  1007. F: Serializable,
  1008. {
  1009. assert_eq!(shares2.len(), num);
  1010. let (_, msg_2_3) = self.eval_round_0(num, shares2);
  1011. comm.send_next(msg_2_3)?;
  1012. let output = self.eval_get_output(num);
  1013. Ok(output)
  1014. }
  1015. pub fn eval_to_uint<C: AbstractCommunicator, T: Unsigned>(
  1016. &mut self,
  1017. comm: &mut C,
  1018. num: usize,
  1019. shares2: &[F],
  1020. ) -> Result<Vec<T>, Error>
  1021. where
  1022. F: Serializable,
  1023. {
  1024. assert!(self.output_bitsize <= T::BITS as usize);
  1025. Ok(to_uint(self.eval(comm, num, shares2)?))
  1026. }
  1027. }
  1028. pub struct MaskedDOPrfParty3<F: LegendreSymbol> {
  1029. _phantom: PhantomData<F>,
  1030. output_bitsize: usize,
  1031. shared_prg_1_3: Option<ChaChaRng>,
  1032. shared_prg_2_3: Option<ChaChaRng>,
  1033. is_initialized: bool,
  1034. num_preprocessed_invocations: usize,
  1035. preprocessed_r: BitVec,
  1036. preprocessed_t: Vec<F>,
  1037. preprocessed_mt_c3: Vec<F>,
  1038. }
  1039. impl<F> MaskedDOPrfParty3<F>
  1040. where
  1041. F: LegendreSymbol,
  1042. {
  1043. pub fn new(output_bitsize: usize) -> Self {
  1044. assert!(output_bitsize > 0);
  1045. Self {
  1046. _phantom: PhantomData,
  1047. output_bitsize,
  1048. shared_prg_1_3: None,
  1049. shared_prg_2_3: None,
  1050. is_initialized: false,
  1051. num_preprocessed_invocations: 0,
  1052. preprocessed_r: Default::default(),
  1053. preprocessed_t: Default::default(),
  1054. preprocessed_mt_c3: Default::default(),
  1055. }
  1056. }
  1057. pub fn reset(&mut self) {
  1058. *self = Self::new(self.output_bitsize)
  1059. }
  1060. pub fn reset_preprocessing(&mut self) {
  1061. self.num_preprocessed_invocations = 0;
  1062. self.preprocessed_t = Default::default();
  1063. self.preprocessed_mt_c3 = Default::default();
  1064. }
  1065. pub fn init_round_0(&mut self) -> (SharedSeed, ()) {
  1066. assert!(!self.is_initialized);
  1067. self.shared_prg_1_3 = Some(ChaChaRng::from_seed(thread_rng().gen()));
  1068. (self.shared_prg_1_3.as_ref().unwrap().get_seed(), ())
  1069. }
  1070. pub fn init_round_1(&mut self, _: (), shared_prg_seed_2_3: SharedSeed) {
  1071. self.shared_prg_2_3 = Some(ChaChaRng::from_seed(shared_prg_seed_2_3));
  1072. self.is_initialized = true;
  1073. }
  1074. pub fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
  1075. let fut_2_3 = comm.receive_previous()?;
  1076. let (msg_3_1, _) = self.init_round_0();
  1077. comm.send_next(msg_3_1)?;
  1078. self.init_round_1((), fut_2_3.get()?);
  1079. Ok(())
  1080. }
  1081. pub fn preprocess_round_0(&mut self, num: usize) -> ((), ()) {
  1082. assert!(self.is_initialized);
  1083. let n = num * self.output_bitsize;
  1084. let start_index = self.num_preprocessed_invocations * self.output_bitsize;
  1085. self.preprocessed_t
  1086. .extend((0..n).map(|_| F::random(self.shared_prg_2_3.as_mut().unwrap()).square()));
  1087. debug_assert!(!self.preprocessed_t[start_index..].contains(&F::ZERO));
  1088. {
  1089. let mut random_bytes = vec![0u8; (n + 7) / 8];
  1090. self.shared_prg_2_3
  1091. .as_mut()
  1092. .unwrap()
  1093. .fill_bytes(&mut random_bytes);
  1094. let new_r_slice = BitSlice::from_slice(&random_bytes);
  1095. self.preprocessed_r.extend(&new_r_slice[..n]);
  1096. for (i, r_i) in new_r_slice.iter().by_vals().take(n).enumerate() {
  1097. if r_i {
  1098. self.preprocessed_t[start_index + i] *= F::get_non_random_qnr();
  1099. }
  1100. }
  1101. }
  1102. ((), ())
  1103. }
  1104. pub fn preprocess_round_1(&mut self, num: usize, _: (), preprocessed_mt_c3: Vec<F>) {
  1105. assert!(self.is_initialized);
  1106. let n = num * self.output_bitsize;
  1107. assert_eq!(preprocessed_mt_c3.len(), n);
  1108. self.preprocessed_mt_c3.extend(preprocessed_mt_c3);
  1109. self.num_preprocessed_invocations += num;
  1110. }
  1111. pub fn preprocess<C: AbstractCommunicator>(
  1112. &mut self,
  1113. comm: &mut C,
  1114. num: usize,
  1115. ) -> Result<(), Error>
  1116. where
  1117. F: Serializable,
  1118. {
  1119. let fut_2_3 = comm.receive_previous()?;
  1120. self.preprocess_round_0(num);
  1121. self.preprocess_round_1(num, (), fut_2_3.get()?);
  1122. Ok(())
  1123. }
  1124. pub fn get_num_preprocessed_invocations(&self) -> usize {
  1125. self.num_preprocessed_invocations
  1126. }
  1127. pub fn get_preprocessed_data(&self) -> (&BitSlice, &[F], &[F]) {
  1128. (
  1129. &self.preprocessed_r,
  1130. &self.preprocessed_t,
  1131. &self.preprocessed_mt_c3,
  1132. )
  1133. }
  1134. pub fn check_preprocessing(&self) {
  1135. let num = self.num_preprocessed_invocations;
  1136. let n = num * self.output_bitsize;
  1137. assert_eq!(self.preprocessed_t.len(), n);
  1138. assert_eq!(self.preprocessed_mt_c3.len(), n);
  1139. }
  1140. pub fn eval_round_1(
  1141. &mut self,
  1142. num: usize,
  1143. shares3: &[F],
  1144. mult_d: &[F],
  1145. masked_shares2: &[F],
  1146. ) -> (Vec<F>, ()) {
  1147. assert!(num <= self.num_preprocessed_invocations);
  1148. let n = num * self.output_bitsize;
  1149. assert_eq!(shares3.len(), num);
  1150. assert_eq!(masked_shares2.len(), num);
  1151. assert_eq!(mult_d.len(), n);
  1152. let output_shares_z3: Vec<F> = izip!(
  1153. shares3
  1154. .iter()
  1155. .flat_map(|s1i| repeat(s1i).take(self.output_bitsize)),
  1156. masked_shares2
  1157. .iter()
  1158. .flat_map(|ms2i| repeat(ms2i).take(self.output_bitsize)),
  1159. self.preprocessed_t.drain(0..n),
  1160. self.preprocessed_mt_c3.drain(0..n),
  1161. mult_d,
  1162. )
  1163. .map(|(&s3_i, &ms2_i, t_ij, c3_ij, &d_ij)| t_ij * (s3_i + ms2_i) + d_ij * t_ij + c3_ij)
  1164. .collect();
  1165. (output_shares_z3, ())
  1166. }
  1167. pub fn eval_get_output(&mut self, num: usize) -> Vec<BitVec> {
  1168. assert!(num <= self.num_preprocessed_invocations);
  1169. let n = num * self.output_bitsize;
  1170. let mut output = Vec::with_capacity(num);
  1171. for chunk in self
  1172. .preprocessed_r
  1173. .chunks_exact(self.output_bitsize)
  1174. .take(num)
  1175. {
  1176. output.push(chunk.to_bitvec());
  1177. }
  1178. let (_, last_r) = self.preprocessed_r.split_at(n);
  1179. self.preprocessed_r = last_r.to_bitvec();
  1180. self.num_preprocessed_invocations -= num;
  1181. output
  1182. }
  1183. pub fn eval<C: AbstractCommunicator>(
  1184. &mut self,
  1185. comm: &mut C,
  1186. num: usize,
  1187. shares3: &[F],
  1188. ) -> Result<Vec<BitVec>, Error>
  1189. where
  1190. F: Serializable,
  1191. {
  1192. assert_eq!(shares3.len(), num);
  1193. let fut_1_3 = comm.receive_next::<Vec<_>>()?;
  1194. let fut_2_3 = comm.receive_previous::<Vec<_>>()?;
  1195. let (msg_3_1, _) = self.eval_round_1(1, shares3, &fut_1_3.get()?, &fut_2_3.get()?);
  1196. comm.send_next(msg_3_1)?;
  1197. let output = self.eval_get_output(num);
  1198. Ok(output)
  1199. }
  1200. pub fn eval_to_uint<C: AbstractCommunicator, T: Unsigned>(
  1201. &mut self,
  1202. comm: &mut C,
  1203. num: usize,
  1204. shares3: &[F],
  1205. ) -> Result<Vec<T>, Error>
  1206. where
  1207. F: Serializable,
  1208. {
  1209. assert!(self.output_bitsize <= T::BITS as usize);
  1210. Ok(to_uint(self.eval(comm, num, shares3)?))
  1211. }
  1212. }
  1213. #[cfg(test)]
  1214. mod tests {
  1215. use super::*;
  1216. use bincode;
  1217. use ff::Field;
  1218. use utils::field::Fp;
  1219. fn doprf_init(
  1220. party_1: &mut DOPrfParty1<Fp>,
  1221. party_2: &mut DOPrfParty2<Fp>,
  1222. party_3: &mut DOPrfParty3<Fp>,
  1223. ) {
  1224. let (msg_1_2, msg_1_3) = party_1.init_round_0();
  1225. let (msg_2_1, msg_2_3) = party_2.init_round_0();
  1226. let (msg_3_1, msg_3_2) = party_3.init_round_0();
  1227. party_1.init_round_1(msg_2_1, msg_3_1);
  1228. party_2.init_round_1(msg_1_2, msg_3_2);
  1229. party_3.init_round_1(msg_1_3, msg_2_3);
  1230. }
  1231. fn doprf_preprocess(
  1232. party_1: &mut DOPrfParty1<Fp>,
  1233. party_2: &mut DOPrfParty2<Fp>,
  1234. party_3: &mut DOPrfParty3<Fp>,
  1235. num: usize,
  1236. ) {
  1237. let (msg_1_2, msg_1_3) = party_1.preprocess_round_0(num);
  1238. let (msg_2_1, msg_2_3) = party_2.preprocess_round_0(num);
  1239. let (msg_3_1, msg_3_2) = party_3.preprocess_round_0(num);
  1240. party_1.preprocess_round_1(num, msg_2_1, msg_3_1);
  1241. party_2.preprocess_round_1(num, msg_1_2, msg_3_2);
  1242. party_3.preprocess_round_1(num, msg_1_3, msg_2_3);
  1243. }
  1244. fn doprf_eval(
  1245. party_1: &mut DOPrfParty1<Fp>,
  1246. party_2: &mut DOPrfParty2<Fp>,
  1247. party_3: &mut DOPrfParty3<Fp>,
  1248. shares_1: &[Fp],
  1249. shares_2: &[Fp],
  1250. shares_3: &[Fp],
  1251. num: usize,
  1252. ) -> Vec<BitVec> {
  1253. assert_eq!(shares_1.len(), num);
  1254. assert_eq!(shares_2.len(), num);
  1255. assert_eq!(shares_3.len(), num);
  1256. let (msg_2_1, msg_2_3) = party_2.eval_round_0(num, &shares_2);
  1257. let (msg_3_1, _) = party_3.eval_round_0(num, &shares_3);
  1258. let (_, msg_1_3) = party_1.eval_round_1(num, &shares_1, &msg_2_1, &msg_3_1);
  1259. let output = party_3.eval_round_2(num, &shares_3, msg_1_3, msg_2_3);
  1260. output
  1261. }
  1262. #[test]
  1263. fn test_doprf() {
  1264. let output_bitsize = 42;
  1265. let mut party_1 = DOPrfParty1::<Fp>::new(output_bitsize);
  1266. let mut party_2 = DOPrfParty2::<Fp>::new(output_bitsize);
  1267. let mut party_3 = DOPrfParty3::<Fp>::new(output_bitsize);
  1268. doprf_init(&mut party_1, &mut party_2, &mut party_3);
  1269. // preprocess num invocations
  1270. let num = 20;
  1271. doprf_preprocess(&mut party_1, &mut party_2, &mut party_3, num);
  1272. assert_eq!(party_1.get_num_preprocessed_invocations(), num);
  1273. assert_eq!(party_2.get_num_preprocessed_invocations(), num);
  1274. assert_eq!(party_3.get_num_preprocessed_invocations(), num);
  1275. party_1.check_preprocessing();
  1276. party_2.check_preprocessing();
  1277. party_3.check_preprocessing();
  1278. // preprocess another n invocations
  1279. doprf_preprocess(&mut party_1, &mut party_2, &mut party_3, num);
  1280. let num = 2 * num;
  1281. assert_eq!(party_1.get_num_preprocessed_invocations(), num);
  1282. assert_eq!(party_2.get_num_preprocessed_invocations(), num);
  1283. assert_eq!(party_3.get_num_preprocessed_invocations(), num);
  1284. party_1.check_preprocessing();
  1285. party_2.check_preprocessing();
  1286. party_3.check_preprocessing();
  1287. // verify preprocessed data
  1288. {
  1289. let n = num * output_bitsize;
  1290. let (squares, mt_c1) = party_1.get_preprocessed_data();
  1291. let rerand_m2 = party_2.get_preprocessed_data();
  1292. let (rerand_m3, mt_b, mt_c3, mult_d) = party_3.get_preprocessed_data();
  1293. assert_eq!(squares.len(), n);
  1294. assert!(squares.iter().all(|&x| Fp::legendre_symbol(x) == 1));
  1295. assert_eq!(rerand_m2.len(), num);
  1296. assert_eq!(rerand_m3.len(), num);
  1297. assert!(izip!(rerand_m2.iter(), rerand_m3.iter()).all(|(&m2, &m3)| m2 + m3 == Fp::ZERO));
  1298. let mt_a: Vec<Fp> = squares
  1299. .iter()
  1300. .zip(mult_d.iter())
  1301. .map(|(&s, &d)| s - d)
  1302. .collect();
  1303. assert_eq!(mult_d.len(), n);
  1304. assert_eq!(mt_a.len(), n);
  1305. assert_eq!(mt_b.len(), num);
  1306. assert_eq!(mt_c1.len(), n);
  1307. assert_eq!(mt_c3.len(), n);
  1308. let mut triple_it = izip!(
  1309. mt_a.iter(),
  1310. mt_b.iter().flat_map(|b| repeat(b).take(output_bitsize)),
  1311. mt_c1.iter(),
  1312. mt_c3.iter()
  1313. );
  1314. assert_eq!(triple_it.clone().count(), n);
  1315. assert!(triple_it.all(|(&a, &b, &c1, &c3)| a * b == c1 + c3));
  1316. }
  1317. // perform n evaluations
  1318. let num = 15;
  1319. let shares_1: Vec<Fp> = (0..num).map(|_| Fp::random(thread_rng())).collect();
  1320. let shares_2: Vec<Fp> = (0..num).map(|_| Fp::random(thread_rng())).collect();
  1321. let shares_3: Vec<Fp> = (0..num).map(|_| Fp::random(thread_rng())).collect();
  1322. let output = doprf_eval(
  1323. &mut party_1,
  1324. &mut party_2,
  1325. &mut party_3,
  1326. &shares_1,
  1327. &shares_2,
  1328. &shares_3,
  1329. num,
  1330. );
  1331. assert_eq!(party_1.get_num_preprocessed_invocations(), 25);
  1332. assert_eq!(party_2.get_num_preprocessed_invocations(), 25);
  1333. assert_eq!(party_3.get_num_preprocessed_invocations(), 25);
  1334. party_1.check_preprocessing();
  1335. party_2.check_preprocessing();
  1336. party_3.check_preprocessing();
  1337. assert_eq!(output.len(), num);
  1338. assert!(output.iter().all(|bv| bv.len() == output_bitsize));
  1339. // check that the output matches the non-distributed version
  1340. let legendre_prf_key = party_1.get_legendre_prf_key();
  1341. for i in 0..num {
  1342. let input_i = shares_1[i] + shares_2[i] + shares_3[i];
  1343. let output_i = LegendrePrf::<Fp>::eval_bits(&legendre_prf_key, input_i);
  1344. assert_eq!(output[i], output_i);
  1345. }
  1346. }
  1347. fn mdoprf_init(
  1348. party_1: &mut MaskedDOPrfParty1<Fp>,
  1349. party_2: &mut MaskedDOPrfParty2<Fp>,
  1350. party_3: &mut MaskedDOPrfParty3<Fp>,
  1351. ) {
  1352. let (msg_1_2, msg_1_3) = party_1.init_round_0();
  1353. let (msg_2_1, msg_2_3) = party_2.init_round_0();
  1354. let (msg_3_1, msg_3_2) = party_3.init_round_0();
  1355. party_1.init_round_1(msg_2_1, msg_3_1);
  1356. party_2.init_round_1(msg_1_2, msg_3_2);
  1357. party_3.init_round_1(msg_1_3, msg_2_3);
  1358. }
  1359. fn mdoprf_preprocess(
  1360. party_1: &mut MaskedDOPrfParty1<Fp>,
  1361. party_2: &mut MaskedDOPrfParty2<Fp>,
  1362. party_3: &mut MaskedDOPrfParty3<Fp>,
  1363. num: usize,
  1364. ) {
  1365. let (msg_1_2, msg_1_3) = party_1.preprocess_round_0(num);
  1366. let (msg_2_1, msg_2_3) = party_2.preprocess_round_0(num);
  1367. let (msg_3_1, msg_3_2) = party_3.preprocess_round_0(num);
  1368. party_1.preprocess_round_1(num, msg_2_1, msg_3_1);
  1369. party_2.preprocess_round_1(num, msg_1_2, msg_3_2);
  1370. party_3.preprocess_round_1(num, msg_1_3, msg_2_3);
  1371. }
  1372. fn mdoprf_eval(
  1373. party_1: &mut MaskedDOPrfParty1<Fp>,
  1374. party_2: &mut MaskedDOPrfParty2<Fp>,
  1375. party_3: &mut MaskedDOPrfParty3<Fp>,
  1376. shares_1: &[Fp],
  1377. shares_2: &[Fp],
  1378. shares_3: &[Fp],
  1379. num: usize,
  1380. ) -> (Vec<BitVec>, Vec<BitVec>, Vec<BitVec>) {
  1381. assert_eq!(shares_1.len(), num);
  1382. assert_eq!(shares_2.len(), num);
  1383. assert_eq!(shares_3.len(), num);
  1384. let (_, msg_1_3) = party_1.eval_round_0(num, &shares_1);
  1385. let (_, msg_2_3) = party_2.eval_round_0(num, &shares_2);
  1386. let (msg_3_1, _) = party_3.eval_round_1(num, &shares_3, &msg_1_3, &msg_2_3);
  1387. let masked_output = party_1.eval_round_2(num, &shares_1, (), msg_3_1);
  1388. let mask2 = party_2.eval_get_output(num);
  1389. let mask3 = party_3.eval_get_output(num);
  1390. (masked_output, mask2, mask3)
  1391. }
  1392. #[test]
  1393. fn test_masked_doprf() {
  1394. let output_bitsize = 42;
  1395. let mut party_1 = MaskedDOPrfParty1::<Fp>::new(output_bitsize);
  1396. let mut party_2 = MaskedDOPrfParty2::<Fp>::new(output_bitsize);
  1397. let mut party_3 = MaskedDOPrfParty3::<Fp>::new(output_bitsize);
  1398. mdoprf_init(&mut party_1, &mut party_2, &mut party_3);
  1399. // preprocess num invocations
  1400. let num = 20;
  1401. mdoprf_preprocess(&mut party_1, &mut party_2, &mut party_3, num);
  1402. assert_eq!(party_1.get_num_preprocessed_invocations(), num);
  1403. assert_eq!(party_2.get_num_preprocessed_invocations(), num);
  1404. assert_eq!(party_3.get_num_preprocessed_invocations(), num);
  1405. party_1.check_preprocessing();
  1406. party_2.check_preprocessing();
  1407. party_3.check_preprocessing();
  1408. // preprocess another n invocations
  1409. mdoprf_preprocess(&mut party_1, &mut party_2, &mut party_3, num);
  1410. let num = 2 * num;
  1411. assert_eq!(party_1.get_num_preprocessed_invocations(), num);
  1412. assert_eq!(party_2.get_num_preprocessed_invocations(), num);
  1413. assert_eq!(party_3.get_num_preprocessed_invocations(), num);
  1414. party_1.check_preprocessing();
  1415. party_2.check_preprocessing();
  1416. party_3.check_preprocessing();
  1417. // verify preprocessed data
  1418. {
  1419. let n = num * output_bitsize;
  1420. let (rerand_m1, mt_a, mt_c1, mult_e) = party_1.get_preprocessed_data();
  1421. let (r2, rerand_m2) = party_2.get_preprocessed_data();
  1422. let (r3, ts, mt_c3) = party_3.get_preprocessed_data();
  1423. assert_eq!(r2.len(), n);
  1424. assert_eq!(r2, r3);
  1425. assert_eq!(ts.len(), n);
  1426. assert!(r2.iter().by_vals().zip(ts.iter()).all(|(r_i, &t_i)| {
  1427. if r_i {
  1428. Fp::legendre_symbol(t_i) == -1
  1429. } else {
  1430. Fp::legendre_symbol(t_i) == 1
  1431. }
  1432. }));
  1433. assert_eq!(rerand_m1.len(), num);
  1434. assert_eq!(rerand_m2.len(), num);
  1435. assert!(izip!(rerand_m1.iter(), rerand_m2.iter()).all(|(&m1, &m2)| m1 + m2 == Fp::ZERO));
  1436. let mt_b: Vec<Fp> = ts.iter().zip(mult_e.iter()).map(|(&t, &e)| t - e).collect();
  1437. assert_eq!(mult_e.len(), n);
  1438. assert_eq!(mt_a.len(), n);
  1439. assert_eq!(mt_b.len(), n);
  1440. assert_eq!(mt_c1.len(), n);
  1441. assert_eq!(mt_c3.len(), n);
  1442. let mut triple_it = izip!(mt_a.iter(), mt_b.iter(), mt_c1.iter(), mt_c3.iter());
  1443. assert_eq!(triple_it.clone().count(), n);
  1444. assert!(triple_it.all(|(&a, &b, &c1, &c3)| a * b == c1 + c3));
  1445. }
  1446. // perform n evaluations
  1447. let num = 15;
  1448. let shares_1: Vec<Fp> = (0..num).map(|_| Fp::random(thread_rng())).collect();
  1449. let shares_2: Vec<Fp> = (0..num).map(|_| Fp::random(thread_rng())).collect();
  1450. let shares_3: Vec<Fp> = (0..num).map(|_| Fp::random(thread_rng())).collect();
  1451. let (masked_output, mask2, mask3) = mdoprf_eval(
  1452. &mut party_1,
  1453. &mut party_2,
  1454. &mut party_3,
  1455. &shares_1,
  1456. &shares_2,
  1457. &shares_3,
  1458. num,
  1459. );
  1460. assert_eq!(party_1.get_num_preprocessed_invocations(), 25);
  1461. assert_eq!(party_2.get_num_preprocessed_invocations(), 25);
  1462. assert_eq!(party_3.get_num_preprocessed_invocations(), 25);
  1463. party_1.check_preprocessing();
  1464. party_2.check_preprocessing();
  1465. party_3.check_preprocessing();
  1466. assert_eq!(masked_output.len(), num);
  1467. assert!(masked_output.iter().all(|bv| bv.len() == output_bitsize));
  1468. assert_eq!(mask2.len(), num);
  1469. assert_eq!(mask2, mask3);
  1470. assert!(mask2.iter().all(|bv| bv.len() == output_bitsize));
  1471. // check that the output matches the non-distributed version
  1472. let legendre_prf_key = party_1.get_legendre_prf_key();
  1473. for i in 0..num {
  1474. let input_i = shares_1[i] + shares_2[i] + shares_3[i];
  1475. let expected_output_i = LegendrePrf::<Fp>::eval_bits(&legendre_prf_key, input_i);
  1476. let output_i = masked_output[i].clone() ^ &mask2[i];
  1477. assert_eq!(output_i, expected_output_i);
  1478. }
  1479. // preprocess another n invocations
  1480. mdoprf_preprocess(&mut party_1, &mut party_2, &mut party_3, num);
  1481. // perform another n evaluations on the same inputs
  1482. let num = 15;
  1483. let (new_masked_output, new_mask2, new_mask3) = mdoprf_eval(
  1484. &mut party_1,
  1485. &mut party_2,
  1486. &mut party_3,
  1487. &shares_1,
  1488. &shares_2,
  1489. &shares_3,
  1490. num,
  1491. );
  1492. assert_eq!(party_1.get_num_preprocessed_invocations(), 25);
  1493. assert_eq!(party_2.get_num_preprocessed_invocations(), 25);
  1494. assert_eq!(party_3.get_num_preprocessed_invocations(), 25);
  1495. party_1.check_preprocessing();
  1496. party_2.check_preprocessing();
  1497. party_3.check_preprocessing();
  1498. assert_eq!(new_masked_output.len(), num);
  1499. assert!(new_masked_output
  1500. .iter()
  1501. .all(|bv| bv.len() == output_bitsize));
  1502. assert_eq!(new_mask2.len(), num);
  1503. assert_eq!(new_mask2, new_mask3);
  1504. assert!(new_mask2.iter().all(|bv| bv.len() == output_bitsize));
  1505. // check that the new output matches the previous one
  1506. for i in 0..num {
  1507. let expected_output_i = masked_output[i].clone() ^ &mask2[i];
  1508. let output_i = new_masked_output[i].clone() ^ &new_mask2[i];
  1509. assert_eq!(output_i, expected_output_i);
  1510. }
  1511. }
  1512. #[test]
  1513. fn test_masked_doprf_single() {
  1514. let output_bitsize = 42;
  1515. let mut party_1 = MaskedDOPrfParty1::<Fp>::new(output_bitsize);
  1516. let mut party_2 = MaskedDOPrfParty2::<Fp>::new(output_bitsize);
  1517. let mut party_3 = MaskedDOPrfParty3::<Fp>::new(output_bitsize);
  1518. mdoprf_init(&mut party_1, &mut party_2, &mut party_3);
  1519. let share_1 = Fp::random(thread_rng());
  1520. let share_2 = Fp::random(thread_rng());
  1521. let share_3 = Fp::random(thread_rng());
  1522. mdoprf_preprocess(&mut party_1, &mut party_2, &mut party_3, 1);
  1523. let (masked_output_1, mask2_1, mask3_1) = mdoprf_eval(
  1524. &mut party_1,
  1525. &mut party_2,
  1526. &mut party_3,
  1527. &[share_1],
  1528. &[share_2],
  1529. &[share_3],
  1530. 1,
  1531. );
  1532. mdoprf_preprocess(&mut party_1, &mut party_2, &mut party_3, 1);
  1533. let (masked_output_2, mask2_2, mask3_2) = mdoprf_eval(
  1534. &mut party_1,
  1535. &mut party_2,
  1536. &mut party_3,
  1537. &[share_1],
  1538. &[share_2],
  1539. &[share_3],
  1540. 1,
  1541. );
  1542. mdoprf_preprocess(&mut party_1, &mut party_2, &mut party_3, 1);
  1543. let (masked_output_3, mask2_3, mask3_3) = mdoprf_eval(
  1544. &mut party_1,
  1545. &mut party_2,
  1546. &mut party_3,
  1547. &[share_1],
  1548. &[share_2],
  1549. &[share_3],
  1550. 1,
  1551. );
  1552. assert_eq!(mask2_1, mask3_1);
  1553. assert_eq!(mask2_2, mask3_2);
  1554. assert_eq!(mask2_3, mask3_3);
  1555. let plain_output = masked_output_1[0].clone() ^ mask2_1[0].clone();
  1556. assert_eq!(
  1557. masked_output_2[0].clone() ^ mask2_2[0].clone(),
  1558. plain_output
  1559. );
  1560. assert_eq!(
  1561. masked_output_3[0].clone() ^ mask2_3[0].clone(),
  1562. plain_output
  1563. );
  1564. }
  1565. #[test]
  1566. fn test_serialization() {
  1567. let original_key = LegendrePrf::<Fp>::key_gen(42);
  1568. let encoded_key =
  1569. bincode::encode_to_vec(&original_key, bincode::config::standard()).unwrap();
  1570. let (decoded_key, _size): (LegendrePrfKey<Fp>, usize) =
  1571. bincode::decode_from_slice(&encoded_key, bincode::config::standard()).unwrap();
  1572. assert_eq!(decoded_key, original_key);
  1573. }
  1574. }