stash.rs 40 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162
  1. //! Stash protocol implementation.
  2. use crate::common::{Error, InstructionShare};
  3. use crate::doprf::{
  4. DOPrfParty1, DOPrfParty2, DOPrfParty3, LegendrePrf, MaskedDOPrfParty1, MaskedDOPrfParty2,
  5. MaskedDOPrfParty3,
  6. };
  7. use crate::mask_index::{MaskIndex, MaskIndexProtocol};
  8. use crate::select::{Select, SelectProtocol};
  9. use communicator::{AbstractCommunicator, Fut, Serializable};
  10. use dpf::spdpf::SinglePointDpf;
  11. use ff::PrimeField;
  12. use rand::thread_rng;
  13. use rayon::prelude::*;
  14. use std::marker::PhantomData;
  15. use std::time::{Duration, Instant};
  16. use utils::field::LegendreSymbol;
  17. /// Result of a stash read.
  18. ///
  19. /// All values are shared.
  20. #[derive(Clone, Copy, Debug, Default)]
  21. pub struct StashStateShare<F: PrimeField> {
  22. /// Share of 1 if the searched address was present in the stash, and share of 0 otherwise.
  23. pub flag: F,
  24. /// Possible location of the found entry in the stash.
  25. pub location: F,
  26. /// Possible value of the found entry.
  27. pub value: F,
  28. }
  29. /// State of the stash protocol.
  30. #[derive(Clone, Copy, Debug, PartialEq, Eq)]
  31. enum State {
  32. New,
  33. AwaitingRead,
  34. AwaitingWrite,
  35. AccessesExhausted,
  36. }
  37. const PARTY_1: usize = 0;
  38. const PARTY_2: usize = 1;
  39. const PARTY_3: usize = 2;
  40. /// Definition of the stash interface.
  41. pub trait Stash<F: PrimeField> {
  42. /// Return ID of the current party.
  43. fn get_party_id(&self) -> usize;
  44. /// Return capacity of the stash.
  45. fn get_stash_size(&self) -> usize;
  46. /// Return current access counter.
  47. fn get_access_counter(&self) -> usize;
  48. /// Reset the data structure to be used again.
  49. fn reset(&mut self);
  50. /// Initialize the stash.
  51. fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error>;
  52. /// Perform a read from the stash.
  53. fn read<C: AbstractCommunicator>(
  54. &mut self,
  55. comm: &mut C,
  56. instruction: InstructionShare<F>,
  57. ) -> Result<StashStateShare<F>, Error>;
  58. /// Perform a write into the stash.
  59. fn write<C: AbstractCommunicator>(
  60. &mut self,
  61. comm: &mut C,
  62. instruction: InstructionShare<F>,
  63. stash_state: StashStateShare<F>,
  64. db_address_share: F,
  65. db_value_share: F,
  66. ) -> Result<(), Error>;
  67. /// Get an additive share of the stash.
  68. fn get_stash_share(&self) -> (&[F], &[F], &[F]);
  69. }
  70. fn compute_stash_prf_output_bitsize(stash_size: usize) -> usize {
  71. (usize::BITS - stash_size.leading_zeros()) as usize + 40
  72. }
  73. /// Protocol steps of the stash initialization, read, and write.
  74. #[allow(missing_docs)]
  75. #[derive(Debug, Clone, Copy, PartialEq, Eq, strum_macros::EnumIter, strum_macros::Display)]
  76. pub enum ProtocolStep {
  77. Init = 0,
  78. ReadMaskedAddressTag,
  79. ReadDpfKeyGen,
  80. ReadLookupFlagLocation,
  81. ReadComputeLocation,
  82. ReadReshareFlag,
  83. ReadConvertToReplicated,
  84. ReadComputeMaskedIndex,
  85. ReadDpfKeyDistribution,
  86. ReadDpfEvaluations,
  87. WriteAddressTag,
  88. WriteStoreTriple,
  89. WriteSelectPreviousValue,
  90. WriteSelectValue,
  91. WriteComputeMaskedIndex,
  92. WriteDpfKeyDistribution,
  93. WriteDpfEvaluations,
  94. }
  95. /// Collection of accumulated runtimes for the protocol steps.
  96. #[derive(Debug, Default, Clone, Copy)]
  97. pub struct Runtimes {
  98. durations: [Duration; 17],
  99. }
  100. impl Runtimes {
  101. /// Add another duration to the accumulated runtimes for a protocol step.
  102. #[inline(always)]
  103. pub fn record(&mut self, id: ProtocolStep, duration: Duration) {
  104. self.durations[id as usize] += duration;
  105. }
  106. /// Get the accumulated durations for a protocol step.
  107. pub fn get(&self, id: ProtocolStep) -> Duration {
  108. self.durations[id as usize]
  109. }
  110. }
  111. /// Implementation of the stash protocol.
  112. pub struct StashProtocol<F, SPDPF>
  113. where
  114. F: PrimeField + LegendreSymbol + Serializable,
  115. SPDPF: SinglePointDpf<Value = F>,
  116. {
  117. party_id: usize,
  118. stash_size: usize,
  119. access_counter: usize,
  120. state: State,
  121. stash_addresses_share: Vec<F>,
  122. stash_values_share: Vec<F>,
  123. stash_old_values_share: Vec<F>,
  124. address_tag_list: Vec<u64>,
  125. select_party: Option<SelectProtocol<F>>,
  126. doprf_party_1: Option<DOPrfParty1<F>>,
  127. doprf_party_2: Option<DOPrfParty2<F>>,
  128. doprf_party_3: Option<DOPrfParty3<F>>,
  129. masked_doprf_party_1: Option<MaskedDOPrfParty1<F>>,
  130. masked_doprf_party_2: Option<MaskedDOPrfParty2<F>>,
  131. masked_doprf_party_3: Option<MaskedDOPrfParty3<F>>,
  132. _phantom: PhantomData<SPDPF>,
  133. }
  134. impl<F, SPDPF> StashProtocol<F, SPDPF>
  135. where
  136. F: PrimeField + LegendreSymbol + Serializable,
  137. SPDPF: SinglePointDpf<Value = F>,
  138. SPDPF::Key: Serializable + Sync,
  139. {
  140. /// Create new instance of the stash protocol for a party `{0, 1, 2}` and given size.
  141. pub fn new(party_id: usize, stash_size: usize) -> Self {
  142. assert!(party_id < 3);
  143. assert!(stash_size > 0);
  144. assert!(compute_stash_prf_output_bitsize(stash_size) <= 64);
  145. Self {
  146. party_id,
  147. stash_size,
  148. access_counter: 0,
  149. state: State::New,
  150. stash_addresses_share: Vec::with_capacity(stash_size),
  151. stash_values_share: Vec::with_capacity(stash_size),
  152. stash_old_values_share: Vec::with_capacity(stash_size),
  153. address_tag_list: if party_id == PARTY_1 {
  154. Default::default()
  155. } else {
  156. Vec::with_capacity(stash_size)
  157. },
  158. select_party: None,
  159. doprf_party_1: None,
  160. doprf_party_2: None,
  161. doprf_party_3: None,
  162. masked_doprf_party_1: None,
  163. masked_doprf_party_2: None,
  164. masked_doprf_party_3: None,
  165. _phantom: PhantomData,
  166. }
  167. }
  168. fn init_with_runtimes<C: AbstractCommunicator>(
  169. &mut self,
  170. comm: &mut C,
  171. runtimes: Option<Runtimes>,
  172. ) -> Result<Option<Runtimes>, Error> {
  173. assert_eq!(self.state, State::New);
  174. let t_start = Instant::now();
  175. let prf_output_bitsize = compute_stash_prf_output_bitsize(self.stash_size);
  176. let legendre_prf_key = LegendrePrf::<F>::key_gen(prf_output_bitsize);
  177. // run DOPRF initilization
  178. match self.party_id {
  179. PARTY_1 => {
  180. let mut doprf_p1 = DOPrfParty1::from_legendre_prf_key(legendre_prf_key.clone());
  181. let mut mdoprf_p1 = MaskedDOPrfParty1::from_legendre_prf_key(legendre_prf_key);
  182. doprf_p1.init(comm)?;
  183. mdoprf_p1.init(comm)?;
  184. doprf_p1.preprocess(comm, self.stash_size)?;
  185. mdoprf_p1.preprocess(comm, self.stash_size)?;
  186. self.doprf_party_1 = Some(doprf_p1);
  187. self.masked_doprf_party_1 = Some(mdoprf_p1);
  188. }
  189. PARTY_2 => {
  190. let mut doprf_p2 = DOPrfParty2::new(prf_output_bitsize);
  191. let mut mdoprf_p2 = MaskedDOPrfParty2::new(prf_output_bitsize);
  192. doprf_p2.init(comm)?;
  193. mdoprf_p2.init(comm)?;
  194. doprf_p2.preprocess(comm, self.stash_size)?;
  195. mdoprf_p2.preprocess(comm, self.stash_size)?;
  196. self.doprf_party_2 = Some(doprf_p2);
  197. self.masked_doprf_party_2 = Some(mdoprf_p2);
  198. }
  199. PARTY_3 => {
  200. let mut doprf_p3 = DOPrfParty3::new(prf_output_bitsize);
  201. let mut mdoprf_p3 = MaskedDOPrfParty3::new(prf_output_bitsize);
  202. doprf_p3.init(comm)?;
  203. mdoprf_p3.init(comm)?;
  204. doprf_p3.preprocess(comm, self.stash_size)?;
  205. mdoprf_p3.preprocess(comm, self.stash_size)?;
  206. self.doprf_party_3 = Some(doprf_p3);
  207. self.masked_doprf_party_3 = Some(mdoprf_p3);
  208. }
  209. _ => panic!("invalid party id"),
  210. }
  211. // run Select initialiation and preprocessing
  212. {
  213. let mut select_party = SelectProtocol::default();
  214. select_party.init(comm)?;
  215. select_party.preprocess(comm, 3 * self.stash_size)?;
  216. self.select_party = Some(select_party);
  217. }
  218. let t_end = Instant::now();
  219. let runtimes = runtimes.map(|mut r| {
  220. r.record(ProtocolStep::Init, t_end - t_start);
  221. r
  222. });
  223. self.state = State::AwaitingRead;
  224. Ok(runtimes)
  225. }
  226. /// Perform a stash read and collect runtime data.
  227. pub fn read_with_runtimes<C: AbstractCommunicator>(
  228. &mut self,
  229. comm: &mut C,
  230. instruction: InstructionShare<F>,
  231. runtimes: Option<Runtimes>,
  232. ) -> Result<(StashStateShare<F>, Option<Runtimes>), Error> {
  233. assert_eq!(self.state, State::AwaitingRead);
  234. assert!(self.access_counter < self.stash_size);
  235. // 0. If the stash is empty, we are done
  236. if self.access_counter == 0 {
  237. self.state = State::AwaitingWrite;
  238. return Ok((
  239. StashStateShare {
  240. flag: F::ZERO,
  241. location: F::ZERO,
  242. value: F::ZERO,
  243. },
  244. runtimes,
  245. ));
  246. }
  247. let t_start = Instant::now();
  248. let (
  249. flag_share,
  250. location_share,
  251. t_after_masked_address_tag,
  252. t_after_dpf_keygen,
  253. t_after_compute_flag_loc,
  254. ) = match self.party_id {
  255. PARTY_1 => {
  256. // 1. Compute tag y := PRF(k, <I.adr>) such that P1 obtains y + r and P2, P3 obtain the mask r.
  257. let masked_address_tag: u64 = {
  258. let mdoprf_p1 = self.masked_doprf_party_1.as_mut().unwrap();
  259. mdoprf_p1.eval_to_uint(comm, 1, &[instruction.address])?[0]
  260. };
  261. let t_after_masked_address_tag = Instant::now();
  262. // 2. Create and send DPF keys for the function f(x) = if x = y { 1 } else { 0 }
  263. {
  264. let domain_size = 1 << compute_stash_prf_output_bitsize(self.stash_size);
  265. let (dpf_key_2, dpf_key_3) =
  266. SPDPF::generate_keys(domain_size, masked_address_tag, F::ONE);
  267. comm.send(PARTY_2, dpf_key_2)?;
  268. comm.send(PARTY_3, dpf_key_3)?;
  269. }
  270. let t_after_dpf_keygen = Instant::now();
  271. // 3. The other parties compute shares of <flag>, <loc>, i.e., if the address is present in
  272. // the stash and if so, where it is. We just take 0s as our shares.
  273. (
  274. F::ZERO,
  275. F::ZERO,
  276. t_after_masked_address_tag,
  277. t_after_dpf_keygen,
  278. t_after_dpf_keygen,
  279. )
  280. }
  281. PARTY_2 | PARTY_3 => {
  282. // 1. Compute tag y := PRF(k, <I.adr>) such that P1 obtains y + r and P2, P3 obtain the mask r.
  283. let address_tag_mask: u64 = match self.party_id {
  284. PARTY_2 => {
  285. let mdoprf_p2 = self.masked_doprf_party_2.as_mut().unwrap();
  286. mdoprf_p2.eval_to_uint(comm, 1, &[instruction.address])?[0]
  287. }
  288. PARTY_3 => {
  289. let mdoprf_p3 = self.masked_doprf_party_3.as_mut().unwrap();
  290. mdoprf_p3.eval_to_uint(comm, 1, &[instruction.address])?[0]
  291. }
  292. _ => panic!("invalid party id"),
  293. };
  294. let t_after_masked_address_tag = Instant::now();
  295. // 2. Receive DPF key for the function f(x) = if x = y { 1 } else { 0 }
  296. let dpf_key_i: SPDPF::Key = {
  297. let fut = comm.receive(PARTY_1)?;
  298. fut.get()?
  299. };
  300. let t_after_dpf_keygen = Instant::now();
  301. // 3. Compute shares of <flag>, <loc>, i.e., if the address is present in the stash and if
  302. // so, where it is
  303. {
  304. let (flag_share, location_share) = self
  305. .address_tag_list
  306. .par_iter()
  307. .enumerate()
  308. .map(|(j, tag_j)| {
  309. let dpf_value_j =
  310. SPDPF::evaluate_at(&dpf_key_i, tag_j ^ address_tag_mask);
  311. (dpf_value_j, F::from_u128(j as u128) * dpf_value_j)
  312. })
  313. .reduce(|| (F::ZERO, F::ZERO), |(a, b), (c, d)| (a + c, b + d));
  314. let t_after_compute_flag_loc = Instant::now();
  315. (
  316. flag_share,
  317. location_share,
  318. t_after_masked_address_tag,
  319. t_after_dpf_keygen,
  320. t_after_compute_flag_loc,
  321. )
  322. }
  323. }
  324. _ => panic!("invalid party id"),
  325. };
  326. // 4. Compute <loc> = if <flag> { <loc> } else { access_counter - 1 }
  327. let location_share = {
  328. let access_counter_share = if self.party_id == PARTY_1 {
  329. F::from_u128(self.access_counter as u128)
  330. } else {
  331. F::ZERO
  332. };
  333. self.select_party.as_mut().unwrap().select(
  334. comm,
  335. flag_share,
  336. location_share,
  337. access_counter_share,
  338. )?
  339. };
  340. let t_after_location_share = Instant::now();
  341. // 5. Reshare <flag> among all three parties
  342. let flag_share = match self.party_id {
  343. PARTY_1 => {
  344. let flag_share = F::random(thread_rng());
  345. comm.send(PARTY_2, flag_share)?;
  346. flag_share
  347. }
  348. PARTY_2 => {
  349. let fut_1_2 = comm.receive::<F>(PARTY_1)?;
  350. flag_share - fut_1_2.get()?
  351. }
  352. _ => flag_share,
  353. };
  354. let t_after_flag_share = Instant::now();
  355. // 6. Read the value <val> from the stash (if <flag>) or read a zero value
  356. let (
  357. value_share,
  358. t_after_convert_to_replicated,
  359. t_after_masked_index,
  360. t_after_dpf_key_distr,
  361. ) = {
  362. // a) convert the stash into replicated secret sharing
  363. let fut_prev = comm.receive_previous::<Vec<F>>()?;
  364. comm.send_slice_next(self.stash_values_share.as_ref())?;
  365. let stash_values_share_prev = fut_prev.get()?;
  366. let t_after_convert_to_replicated = Instant::now();
  367. // b) mask and reconstruct the stash index <loc>
  368. let index_bits = (self.access_counter as f64).log2().ceil() as u32;
  369. assert!(index_bits <= 16);
  370. let bit_mask = ((1 << index_bits) - 1) as u16;
  371. let (masked_loc, r_prev, r_next) =
  372. MaskIndexProtocol::mask_index(comm, index_bits, location_share)?;
  373. let t_after_masked_index = Instant::now();
  374. // c) use DPFs to read the stash value
  375. let fut_prev = comm.receive_previous::<SPDPF::Key>()?;
  376. let fut_next = comm.receive_next::<SPDPF::Key>()?;
  377. {
  378. let (dpf_key_prev, dpf_key_next) =
  379. SPDPF::generate_keys(1 << index_bits, masked_loc as u64, F::ONE);
  380. comm.send_previous(dpf_key_prev)?;
  381. comm.send_next(dpf_key_next)?;
  382. }
  383. let dpf_key_prev = fut_prev.get()?;
  384. let dpf_key_next = fut_next.get()?;
  385. let t_after_dpf_key_distr = Instant::now();
  386. let value_share: F = (0..self.access_counter)
  387. .into_par_iter()
  388. .map(|j| {
  389. let index_prev = ((j as u16 + r_prev) & bit_mask) as u64;
  390. let index_next = ((j as u16 + r_next) & bit_mask) as u64;
  391. SPDPF::evaluate_at(&dpf_key_prev, index_prev) * self.stash_values_share[j]
  392. + SPDPF::evaluate_at(&dpf_key_next, index_next) * stash_values_share_prev[j]
  393. })
  394. .sum();
  395. (
  396. value_share,
  397. t_after_convert_to_replicated,
  398. t_after_masked_index,
  399. t_after_dpf_key_distr,
  400. )
  401. };
  402. let t_after_dpf_eval = Instant::now();
  403. let runtimes = runtimes.map(|mut r| {
  404. r.record(
  405. ProtocolStep::ReadMaskedAddressTag,
  406. t_after_masked_address_tag - t_start,
  407. );
  408. r.record(
  409. ProtocolStep::ReadDpfKeyGen,
  410. t_after_dpf_keygen - t_after_masked_address_tag,
  411. );
  412. r.record(
  413. ProtocolStep::ReadLookupFlagLocation,
  414. t_after_compute_flag_loc - t_after_dpf_keygen,
  415. );
  416. r.record(
  417. ProtocolStep::ReadComputeLocation,
  418. t_after_location_share - t_after_compute_flag_loc,
  419. );
  420. r.record(
  421. ProtocolStep::ReadReshareFlag,
  422. t_after_flag_share - t_after_location_share,
  423. );
  424. r.record(
  425. ProtocolStep::ReadConvertToReplicated,
  426. t_after_convert_to_replicated - t_after_flag_share,
  427. );
  428. r.record(
  429. ProtocolStep::ReadComputeMaskedIndex,
  430. t_after_masked_index - t_after_convert_to_replicated,
  431. );
  432. r.record(
  433. ProtocolStep::ReadDpfKeyDistribution,
  434. t_after_dpf_key_distr - t_after_masked_index,
  435. );
  436. r.record(
  437. ProtocolStep::ReadDpfEvaluations,
  438. t_after_dpf_eval - t_after_dpf_key_distr,
  439. );
  440. r
  441. });
  442. self.state = State::AwaitingWrite;
  443. Ok((
  444. StashStateShare {
  445. flag: flag_share,
  446. location: location_share,
  447. value: value_share,
  448. },
  449. runtimes,
  450. ))
  451. }
  452. /// Perform a stash write and collect runtime data.
  453. pub fn write_with_runtimes<C: AbstractCommunicator>(
  454. &mut self,
  455. comm: &mut C,
  456. instruction: InstructionShare<F>,
  457. stash_state: StashStateShare<F>,
  458. db_address_share: F,
  459. db_value_share: F,
  460. runtimes: Option<Runtimes>,
  461. ) -> Result<Option<Runtimes>, Error> {
  462. assert_eq!(self.state, State::AwaitingWrite);
  463. assert!(self.access_counter < self.stash_size);
  464. let t_start = Instant::now();
  465. // 1. Compute tag y := PRF(k, <db_adr>) such that P2, P3 obtain y.
  466. match self.party_id {
  467. PARTY_1 => {
  468. let doprf_p1 = self.doprf_party_1.as_mut().unwrap();
  469. doprf_p1.eval(comm, 1, &[db_address_share])?;
  470. }
  471. PARTY_2 => {
  472. let address_tag: u64 = {
  473. let doprf_p2 = self.doprf_party_2.as_mut().unwrap();
  474. let fut_3_2 = comm.receive(PARTY_3)?;
  475. doprf_p2.eval(comm, 1, &[db_address_share])?;
  476. fut_3_2.get()?
  477. };
  478. self.address_tag_list.push(address_tag);
  479. }
  480. PARTY_3 => {
  481. let address_tag: u64 = {
  482. let doprf_p3 = self.doprf_party_3.as_mut().unwrap();
  483. let tag = doprf_p3.eval_to_uint(comm, 1, &[db_address_share])?[0];
  484. comm.send(PARTY_2, tag)?;
  485. tag
  486. };
  487. self.address_tag_list.push(address_tag);
  488. }
  489. _ => panic!("invalid party id"),
  490. }
  491. let t_after_address_tag = Instant::now();
  492. // 2. Insert new triple (<db_adr>, <db_val>, <db_val> into stash.
  493. self.stash_addresses_share.push(db_address_share);
  494. self.stash_values_share.push(db_value_share);
  495. self.stash_old_values_share.push(db_value_share);
  496. let t_after_store_triple = Instant::now();
  497. // 3. Update stash
  498. let previous_value_share = self.select_party.as_mut().unwrap().select(
  499. comm,
  500. stash_state.flag,
  501. stash_state.value,
  502. db_value_share,
  503. )?;
  504. let t_after_select_previous_value = Instant::now();
  505. let value_share = self.select_party.as_mut().unwrap().select(
  506. comm,
  507. instruction.operation,
  508. instruction.value - previous_value_share,
  509. F::ZERO,
  510. )?;
  511. let t_after_select_value = Instant::now();
  512. let (t_after_masked_index, t_after_dpf_key_distr) = {
  513. // a) mask and reconstruct the stash index <loc>
  514. let index_bits = {
  515. let bits = usize::BITS - self.access_counter.leading_zeros();
  516. if bits > 0 {
  517. bits
  518. } else {
  519. 1
  520. }
  521. };
  522. assert!(index_bits <= 16);
  523. let bit_mask = ((1 << index_bits) - 1) as u16;
  524. let (masked_loc, r_prev, r_next) =
  525. MaskIndexProtocol::mask_index(comm, index_bits, stash_state.location)?;
  526. let t_after_masked_index = Instant::now();
  527. // b) use DPFs to read the stash value
  528. let fut_prev = comm.receive_previous::<SPDPF::Key>()?;
  529. let fut_next = comm.receive_next::<SPDPF::Key>()?;
  530. {
  531. let (dpf_key_prev, dpf_key_next) =
  532. SPDPF::generate_keys(1 << index_bits, masked_loc as u64, value_share);
  533. comm.send_previous(dpf_key_prev)?;
  534. comm.send_next(dpf_key_next)?;
  535. }
  536. let dpf_key_prev = fut_prev.get()?;
  537. let dpf_key_next = fut_next.get()?;
  538. let t_after_dpf_key_distr = Instant::now();
  539. self.stash_values_share
  540. .par_iter_mut()
  541. .enumerate()
  542. .for_each(|(j, svs_j)| {
  543. let index_prev = ((j as u16).wrapping_add(r_prev) & bit_mask) as u64;
  544. let index_next = ((j as u16).wrapping_add(r_next) & bit_mask) as u64;
  545. *svs_j += SPDPF::evaluate_at(&dpf_key_prev, index_prev)
  546. + SPDPF::evaluate_at(&dpf_key_next, index_next);
  547. });
  548. (t_after_masked_index, t_after_dpf_key_distr)
  549. };
  550. let t_after_dpf_eval = Instant::now();
  551. self.access_counter += 1;
  552. self.state = if self.access_counter == self.stash_size {
  553. State::AccessesExhausted
  554. } else {
  555. State::AwaitingRead
  556. };
  557. let runtimes = runtimes.map(|mut r| {
  558. r.record(ProtocolStep::WriteAddressTag, t_after_address_tag - t_start);
  559. r.record(
  560. ProtocolStep::WriteStoreTriple,
  561. t_after_store_triple - t_after_address_tag,
  562. );
  563. r.record(
  564. ProtocolStep::WriteSelectPreviousValue,
  565. t_after_select_previous_value - t_after_store_triple,
  566. );
  567. r.record(
  568. ProtocolStep::WriteSelectValue,
  569. t_after_select_value - t_after_select_previous_value,
  570. );
  571. r.record(
  572. ProtocolStep::WriteComputeMaskedIndex,
  573. t_after_masked_index - t_after_select_value,
  574. );
  575. r.record(
  576. ProtocolStep::WriteDpfKeyDistribution,
  577. t_after_dpf_key_distr - t_after_masked_index,
  578. );
  579. r.record(
  580. ProtocolStep::WriteDpfEvaluations,
  581. t_after_dpf_eval - t_after_dpf_key_distr,
  582. );
  583. r
  584. });
  585. Ok(runtimes)
  586. }
  587. }
  588. impl<F, SPDPF> Stash<F> for StashProtocol<F, SPDPF>
  589. where
  590. F: PrimeField + LegendreSymbol + Serializable,
  591. SPDPF: SinglePointDpf<Value = F>,
  592. SPDPF::Key: Serializable + Sync,
  593. {
  594. fn get_party_id(&self) -> usize {
  595. self.party_id
  596. }
  597. fn get_stash_size(&self) -> usize {
  598. self.stash_size
  599. }
  600. fn get_access_counter(&self) -> usize {
  601. self.access_counter
  602. }
  603. fn reset(&mut self) {
  604. *self = Self::new(self.party_id, self.stash_size);
  605. }
  606. fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
  607. self.init_with_runtimes(comm, None).map(|_| ())
  608. }
  609. fn read<C: AbstractCommunicator>(
  610. &mut self,
  611. comm: &mut C,
  612. instruction: InstructionShare<F>,
  613. ) -> Result<StashStateShare<F>, Error> {
  614. self.read_with_runtimes(comm, instruction, None)
  615. .map(|x| x.0)
  616. }
  617. fn write<C: AbstractCommunicator>(
  618. &mut self,
  619. comm: &mut C,
  620. instruction: InstructionShare<F>,
  621. stash_state: StashStateShare<F>,
  622. db_address_share: F,
  623. db_value_share: F,
  624. ) -> Result<(), Error> {
  625. self.write_with_runtimes(
  626. comm,
  627. instruction,
  628. stash_state,
  629. db_address_share,
  630. db_value_share,
  631. None,
  632. )
  633. .map(|_| ())
  634. }
  635. fn get_stash_share(&self) -> (&[F], &[F], &[F]) {
  636. (
  637. &self.stash_addresses_share,
  638. &self.stash_values_share,
  639. &self.stash_old_values_share,
  640. )
  641. }
  642. }
  643. #[cfg(test)]
  644. mod tests {
  645. use super::*;
  646. use crate::common::Operation;
  647. use communicator::unix::make_unix_communicators;
  648. use dpf::spdpf::DummySpDpf;
  649. use ff::Field;
  650. use std::thread;
  651. use utils::field::Fp;
  652. fn run_init<F>(
  653. mut stash_party: impl Stash<F> + Send + 'static,
  654. mut comm: impl AbstractCommunicator + Send + 'static,
  655. ) -> thread::JoinHandle<(impl Stash<F>, impl AbstractCommunicator)>
  656. where
  657. F: PrimeField + LegendreSymbol,
  658. {
  659. thread::spawn(move || {
  660. stash_party.init(&mut comm).unwrap();
  661. (stash_party, comm)
  662. })
  663. }
  664. fn run_read<F>(
  665. mut stash_party: impl Stash<F> + Send + 'static,
  666. mut comm: impl AbstractCommunicator + Send + 'static,
  667. instruction: InstructionShare<F>,
  668. ) -> thread::JoinHandle<(impl Stash<F>, impl AbstractCommunicator, StashStateShare<F>)>
  669. where
  670. F: PrimeField + LegendreSymbol,
  671. {
  672. thread::spawn(move || {
  673. let result = stash_party.read(&mut comm, instruction);
  674. (stash_party, comm, result.unwrap())
  675. })
  676. }
  677. fn run_write<F>(
  678. mut stash_party: impl Stash<F> + Send + 'static,
  679. mut comm: impl AbstractCommunicator + Send + 'static,
  680. instruction: InstructionShare<F>,
  681. stash_state: StashStateShare<F>,
  682. db_address_share: F,
  683. db_value_share: F,
  684. ) -> thread::JoinHandle<(impl Stash<F>, impl AbstractCommunicator)>
  685. where
  686. F: PrimeField + LegendreSymbol,
  687. {
  688. thread::spawn(move || {
  689. stash_party
  690. .write(
  691. &mut comm,
  692. instruction,
  693. stash_state,
  694. db_address_share,
  695. db_value_share,
  696. )
  697. .unwrap();
  698. (stash_party, comm)
  699. })
  700. }
  701. #[test]
  702. fn test_stash() {
  703. type SPDPF = DummySpDpf<Fp>;
  704. let stash_size = 128;
  705. let mut num_accesses = 0;
  706. let party_1 = StashProtocol::<Fp, SPDPF>::new(PARTY_1, stash_size);
  707. let party_2 = StashProtocol::<Fp, SPDPF>::new(PARTY_2, stash_size);
  708. let party_3 = StashProtocol::<Fp, SPDPF>::new(PARTY_3, stash_size);
  709. assert_eq!(party_1.get_party_id(), PARTY_1);
  710. assert_eq!(party_2.get_party_id(), PARTY_2);
  711. assert_eq!(party_3.get_party_id(), PARTY_3);
  712. assert_eq!(party_1.get_stash_size(), stash_size);
  713. assert_eq!(party_2.get_stash_size(), stash_size);
  714. assert_eq!(party_3.get_stash_size(), stash_size);
  715. let (comm_3, comm_2, comm_1) = {
  716. let mut comms = make_unix_communicators(3);
  717. (
  718. comms.pop().unwrap(),
  719. comms.pop().unwrap(),
  720. comms.pop().unwrap(),
  721. )
  722. };
  723. let h1 = run_init(party_1, comm_1);
  724. let h2 = run_init(party_2, comm_2);
  725. let h3 = run_init(party_3, comm_3);
  726. let (party_1, comm_1) = h1.join().unwrap();
  727. let (party_2, comm_2) = h2.join().unwrap();
  728. let (party_3, comm_3) = h3.join().unwrap();
  729. assert_eq!(party_1.get_access_counter(), 0);
  730. assert_eq!(party_2.get_access_counter(), 0);
  731. assert_eq!(party_3.get_access_counter(), 0);
  732. // write a value 42 to address adr = 3
  733. let value = 42;
  734. let address = 3;
  735. let inst_w3_1 = InstructionShare {
  736. operation: Operation::Write.encode(),
  737. address: Fp::from_u128(address),
  738. value: Fp::from_u128(value),
  739. };
  740. let inst_w3_2 = InstructionShare {
  741. operation: Fp::ZERO,
  742. address: Fp::ZERO,
  743. value: Fp::ZERO,
  744. };
  745. let inst_w3_3 = inst_w3_2.clone();
  746. let h1 = run_read(party_1, comm_1, inst_w3_1);
  747. let h2 = run_read(party_2, comm_2, inst_w3_2);
  748. let h3 = run_read(party_3, comm_3, inst_w3_3);
  749. let (party_1, comm_1, state_1) = h1.join().unwrap();
  750. let (party_2, comm_2, state_2) = h2.join().unwrap();
  751. let (party_3, comm_3, state_3) = h3.join().unwrap();
  752. // since the stash is empty, st.flag must be zero
  753. assert_eq!(state_1.flag + state_2.flag + state_3.flag, Fp::ZERO);
  754. assert_eq!(
  755. state_1.location + state_2.location + state_3.location,
  756. Fp::ZERO
  757. );
  758. let h1 = run_write(
  759. party_1,
  760. comm_1,
  761. inst_w3_1,
  762. state_1,
  763. inst_w3_1.address,
  764. Fp::from_u128(0x71),
  765. );
  766. let h2 = run_write(
  767. party_2,
  768. comm_2,
  769. inst_w3_2,
  770. state_1,
  771. inst_w3_2.address,
  772. Fp::from_u128(0x72),
  773. );
  774. let h3 = run_write(
  775. party_3,
  776. comm_3,
  777. inst_w3_3,
  778. state_1,
  779. inst_w3_3.address,
  780. Fp::from_u128(0x73),
  781. );
  782. let (party_1, comm_1) = h1.join().unwrap();
  783. let (party_2, comm_2) = h2.join().unwrap();
  784. let (party_3, comm_3) = h3.join().unwrap();
  785. num_accesses += 1;
  786. assert_eq!(party_1.get_access_counter(), num_accesses);
  787. assert_eq!(party_2.get_access_counter(), num_accesses);
  788. assert_eq!(party_3.get_access_counter(), num_accesses);
  789. {
  790. let (st_adrs_1, st_vals_1, st_old_vals_1) = party_1.get_stash_share();
  791. let (st_adrs_2, st_vals_2, st_old_vals_2) = party_2.get_stash_share();
  792. let (st_adrs_3, st_vals_3, st_old_vals_3) = party_3.get_stash_share();
  793. assert_eq!(st_adrs_1.len(), num_accesses);
  794. assert_eq!(st_vals_1.len(), num_accesses);
  795. assert_eq!(st_old_vals_1.len(), num_accesses);
  796. assert_eq!(st_adrs_2.len(), num_accesses);
  797. assert_eq!(st_vals_2.len(), num_accesses);
  798. assert_eq!(st_old_vals_2.len(), num_accesses);
  799. assert_eq!(st_adrs_3.len(), num_accesses);
  800. assert_eq!(st_vals_3.len(), num_accesses);
  801. assert_eq!(st_old_vals_3.len(), num_accesses);
  802. assert_eq!(
  803. st_adrs_1[0] + st_adrs_2[0] + st_adrs_3[0],
  804. Fp::from_u128(address)
  805. );
  806. assert_eq!(
  807. st_vals_1[0] + st_vals_2[0] + st_vals_3[0],
  808. Fp::from_u128(value)
  809. );
  810. }
  811. // read again from address adr = 3, we should get the value 42 back
  812. let inst_r3_1 = InstructionShare {
  813. operation: Operation::Read.encode(),
  814. address: Fp::from_u128(3),
  815. value: Fp::ZERO,
  816. };
  817. let inst_r3_2 = InstructionShare {
  818. operation: Fp::ZERO,
  819. address: Fp::ZERO,
  820. value: Fp::ZERO,
  821. };
  822. let inst_r3_3 = inst_r3_2.clone();
  823. let h1 = run_read(party_1, comm_1, inst_r3_1);
  824. let h2 = run_read(party_2, comm_2, inst_r3_2);
  825. let h3 = run_read(party_3, comm_3, inst_r3_3);
  826. let (party_1, comm_1, state_1) = h1.join().unwrap();
  827. let (party_2, comm_2, state_2) = h2.join().unwrap();
  828. let (party_3, comm_3, state_3) = h3.join().unwrap();
  829. let st_flag = state_1.flag + state_2.flag + state_3.flag;
  830. let st_location = state_1.location + state_2.location + state_3.location;
  831. let st_value = state_1.value + state_2.value + state_3.value;
  832. assert_eq!(st_flag, Fp::ONE);
  833. assert_eq!(st_location, Fp::from_u128(0));
  834. assert_eq!(st_value, Fp::from_u128(value));
  835. let h1 = run_write(
  836. party_1,
  837. comm_1,
  838. inst_r3_1,
  839. state_1,
  840. Fp::from_u128(0x83),
  841. Fp::from_u128(0x93),
  842. );
  843. let h2 = run_write(
  844. party_2,
  845. comm_2,
  846. inst_r3_2,
  847. state_1,
  848. Fp::from_u128(0x83),
  849. Fp::from_u128(0x93),
  850. );
  851. let h3 = run_write(
  852. party_3,
  853. comm_3,
  854. inst_r3_3,
  855. state_1,
  856. Fp::from_u128(0x83),
  857. Fp::from_u128(0x93),
  858. );
  859. let (party_1, comm_1) = h1.join().unwrap();
  860. let (party_2, comm_2) = h2.join().unwrap();
  861. let (party_3, comm_3) = h3.join().unwrap();
  862. num_accesses += 1;
  863. assert_eq!(party_1.get_access_counter(), num_accesses);
  864. assert_eq!(party_2.get_access_counter(), num_accesses);
  865. assert_eq!(party_3.get_access_counter(), num_accesses);
  866. {
  867. let (st_adrs_1, st_vals_1, st_old_vals_1) = party_1.get_stash_share();
  868. let (st_adrs_2, st_vals_2, st_old_vals_2) = party_2.get_stash_share();
  869. let (st_adrs_3, st_vals_3, st_old_vals_3) = party_3.get_stash_share();
  870. assert_eq!(st_adrs_1.len(), num_accesses);
  871. assert_eq!(st_vals_1.len(), num_accesses);
  872. assert_eq!(st_old_vals_1.len(), num_accesses);
  873. assert_eq!(st_adrs_2.len(), num_accesses);
  874. assert_eq!(st_vals_2.len(), num_accesses);
  875. assert_eq!(st_old_vals_2.len(), num_accesses);
  876. assert_eq!(st_adrs_3.len(), num_accesses);
  877. assert_eq!(st_vals_3.len(), num_accesses);
  878. assert_eq!(st_old_vals_3.len(), num_accesses);
  879. assert_eq!(
  880. st_adrs_1[0] + st_adrs_2[0] + st_adrs_3[0],
  881. Fp::from_u128(address)
  882. );
  883. assert_eq!(
  884. st_vals_1[0] + st_vals_2[0] + st_vals_3[0],
  885. Fp::from_u128(value)
  886. );
  887. }
  888. // now write a value 0x1337 to address adr = 3
  889. let old_value = value;
  890. let value = 0x1337;
  891. let address = 3;
  892. let inst_w3_1 = InstructionShare {
  893. operation: Operation::Write.encode(),
  894. address: Fp::from_u128(address),
  895. value: Fp::from_u128(value),
  896. };
  897. let inst_w3_2 = InstructionShare {
  898. operation: Fp::ZERO,
  899. address: Fp::ZERO,
  900. value: Fp::ZERO,
  901. };
  902. let inst_w3_3 = inst_w3_2.clone();
  903. let h1 = run_read(party_1, comm_1, inst_w3_1);
  904. let h2 = run_read(party_2, comm_2, inst_w3_2);
  905. let h3 = run_read(party_3, comm_3, inst_w3_3);
  906. let (party_1, comm_1, state_1) = h1.join().unwrap();
  907. let (party_2, comm_2, state_2) = h2.join().unwrap();
  908. let (party_3, comm_3, state_3) = h3.join().unwrap();
  909. // since we already wrote to the address, it should be present in the stash
  910. assert_eq!(state_1.flag + state_2.flag + state_3.flag, Fp::ONE);
  911. assert_eq!(
  912. state_1.location + state_2.location + state_3.location,
  913. Fp::ZERO
  914. );
  915. assert_eq!(
  916. state_1.value + state_2.value + state_3.value,
  917. Fp::from_u128(old_value)
  918. );
  919. let h1 = run_write(
  920. party_1,
  921. comm_1,
  922. inst_w3_1,
  923. state_1,
  924. // inst_w3_1.address,
  925. Fp::from_u128(0x61),
  926. Fp::from_u128(0x71),
  927. );
  928. let h2 = run_write(
  929. party_2,
  930. comm_2,
  931. inst_w3_2,
  932. state_2,
  933. // inst_w3_2.address,
  934. Fp::from_u128(0x62),
  935. Fp::from_u128(0x72),
  936. );
  937. let h3 = run_write(
  938. party_3,
  939. comm_3,
  940. inst_w3_3,
  941. state_3,
  942. // inst_w3_3.address,
  943. Fp::from_u128(0x63),
  944. Fp::from_u128(0x73),
  945. );
  946. let (party_1, comm_1) = h1.join().unwrap();
  947. let (party_2, comm_2) = h2.join().unwrap();
  948. let (party_3, comm_3) = h3.join().unwrap();
  949. num_accesses += 1;
  950. assert_eq!(party_1.get_access_counter(), num_accesses);
  951. assert_eq!(party_2.get_access_counter(), num_accesses);
  952. assert_eq!(party_3.get_access_counter(), num_accesses);
  953. {
  954. let (st_adrs_1, st_vals_1, st_old_vals_1) = party_1.get_stash_share();
  955. let (st_adrs_2, st_vals_2, st_old_vals_2) = party_2.get_stash_share();
  956. let (st_adrs_3, st_vals_3, st_old_vals_3) = party_3.get_stash_share();
  957. assert_eq!(st_adrs_1.len(), num_accesses);
  958. assert_eq!(st_vals_1.len(), num_accesses);
  959. assert_eq!(st_old_vals_1.len(), num_accesses);
  960. assert_eq!(st_adrs_2.len(), num_accesses);
  961. assert_eq!(st_vals_2.len(), num_accesses);
  962. assert_eq!(st_old_vals_2.len(), num_accesses);
  963. assert_eq!(st_adrs_3.len(), num_accesses);
  964. assert_eq!(st_vals_3.len(), num_accesses);
  965. assert_eq!(st_old_vals_3.len(), num_accesses);
  966. assert_eq!(
  967. st_adrs_1[0] + st_adrs_2[0] + st_adrs_3[0],
  968. Fp::from_u128(address)
  969. );
  970. assert_eq!(
  971. st_vals_1[0] + st_vals_2[0] + st_vals_3[0],
  972. Fp::from_u128(value)
  973. );
  974. }
  975. // read again from address adr = 3, we should get the value 0x1337 back
  976. let inst_r3_1 = InstructionShare {
  977. operation: Operation::Read.encode(),
  978. address: Fp::from_u128(3),
  979. value: Fp::ZERO,
  980. };
  981. let inst_r3_2 = InstructionShare {
  982. operation: Fp::ZERO,
  983. address: Fp::ZERO,
  984. value: Fp::ZERO,
  985. };
  986. let inst_r3_3 = inst_r3_2.clone();
  987. let h1 = run_read(party_1, comm_1, inst_r3_1);
  988. let h2 = run_read(party_2, comm_2, inst_r3_2);
  989. let h3 = run_read(party_3, comm_3, inst_r3_3);
  990. let (party_1, comm_1, state_1) = h1.join().unwrap();
  991. let (party_2, comm_2, state_2) = h2.join().unwrap();
  992. let (party_3, comm_3, state_3) = h3.join().unwrap();
  993. let st_flag = state_1.flag + state_2.flag + state_3.flag;
  994. let st_location = state_1.location + state_2.location + state_3.location;
  995. let st_value = state_1.value + state_2.value + state_3.value;
  996. assert_eq!(st_flag, Fp::ONE);
  997. assert_eq!(st_location, Fp::from_u128(0));
  998. assert_eq!(st_value, Fp::from_u128(value));
  999. let h1 = run_write(
  1000. party_1,
  1001. comm_1,
  1002. inst_r3_1,
  1003. state_1,
  1004. Fp::from_u128(0x83),
  1005. Fp::from_u128(0x93),
  1006. );
  1007. let h2 = run_write(
  1008. party_2,
  1009. comm_2,
  1010. inst_r3_2,
  1011. state_2,
  1012. Fp::from_u128(0x83),
  1013. Fp::from_u128(0x93),
  1014. );
  1015. let h3 = run_write(
  1016. party_3,
  1017. comm_3,
  1018. inst_r3_3,
  1019. state_3,
  1020. Fp::from_u128(0x83),
  1021. Fp::from_u128(0x93),
  1022. );
  1023. let (party_1, _comm_1) = h1.join().unwrap();
  1024. let (party_2, _comm_2) = h2.join().unwrap();
  1025. let (party_3, _comm_3) = h3.join().unwrap();
  1026. num_accesses += 1;
  1027. assert_eq!(party_1.get_access_counter(), num_accesses);
  1028. assert_eq!(party_2.get_access_counter(), num_accesses);
  1029. assert_eq!(party_3.get_access_counter(), num_accesses);
  1030. {
  1031. let (st_adrs_1, st_vals_1, st_old_vals_1) = party_1.get_stash_share();
  1032. let (st_adrs_2, st_vals_2, st_old_vals_2) = party_2.get_stash_share();
  1033. let (st_adrs_3, st_vals_3, st_old_vals_3) = party_3.get_stash_share();
  1034. assert_eq!(st_adrs_1.len(), num_accesses);
  1035. assert_eq!(st_vals_1.len(), num_accesses);
  1036. assert_eq!(st_old_vals_1.len(), num_accesses);
  1037. assert_eq!(st_adrs_2.len(), num_accesses);
  1038. assert_eq!(st_vals_2.len(), num_accesses);
  1039. assert_eq!(st_old_vals_2.len(), num_accesses);
  1040. assert_eq!(st_adrs_3.len(), num_accesses);
  1041. assert_eq!(st_vals_3.len(), num_accesses);
  1042. assert_eq!(st_old_vals_3.len(), num_accesses);
  1043. assert_eq!(
  1044. st_adrs_1[0] + st_adrs_2[0] + st_adrs_3[0],
  1045. Fp::from_u128(address)
  1046. );
  1047. assert_eq!(
  1048. st_vals_1[0] + st_vals_2[0] + st_vals_3[0],
  1049. Fp::from_u128(value)
  1050. );
  1051. }
  1052. }
  1053. }