stash.rs 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128
  1. use crate::common::{Error, InstructionShare};
  2. use crate::doprf::{
  3. DOPrfParty1, DOPrfParty2, DOPrfParty3, LegendrePrf, MaskedDOPrfParty1, MaskedDOPrfParty2,
  4. MaskedDOPrfParty3,
  5. };
  6. use crate::mask_index::{MaskIndex, MaskIndexProtocol};
  7. use crate::select::{Select, SelectProtocol};
  8. use communicator::{AbstractCommunicator, Fut, Serializable};
  9. use dpf::spdpf::SinglePointDpf;
  10. use ff::PrimeField;
  11. use rand::thread_rng;
  12. use std::marker::PhantomData;
  13. use std::time::{Duration, Instant};
  14. use utils::field::LegendreSymbol;
  15. #[derive(Clone, Copy, Debug, Default)]
  16. pub struct StashEntryShare<F: PrimeField> {
  17. pub address: F,
  18. pub value: F,
  19. pub old_value: F,
  20. }
  21. #[derive(Clone, Copy, Debug, Default)]
  22. pub struct StashStateShare<F: PrimeField> {
  23. pub flag: F,
  24. pub location: F,
  25. pub value: F,
  26. }
  /// Phase of the stash's read/write cycle; the protocol methods assert on it before running.
  #[derive(Debug, Clone, Copy, Eq, PartialEq)]
  enum State {
      /// Freshly constructed, initialization has not run yet.
      New,
      /// Initialized (or a write just completed); the next operation must be a read.
      AwaitingRead,
      /// A read just completed; the next operation must be the matching write.
      AwaitingWrite,
      /// `stash_size` accesses have been performed; no further access is possible.
      AccessesExhausted,
  }
  /// Index of the first party, used both as `party_id` and as communicator peer id.
  const PARTY_1: usize = 0;
  /// Index of the second party.
  const PARTY_2: usize = PARTY_1 + 1;
  /// Index of the third party.
  const PARTY_3: usize = PARTY_2 + 1;
  37. pub trait Stash<F: PrimeField> {
  38. fn get_party_id(&self) -> usize;
  39. fn get_stash_size(&self) -> usize;
  40. fn get_access_counter(&self) -> usize;
  41. fn reset(&mut self);
  42. fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error>;
  43. fn read<C: AbstractCommunicator>(
  44. &mut self,
  45. comm: &mut C,
  46. instruction: InstructionShare<F>,
  47. ) -> Result<StashStateShare<F>, Error>;
  48. fn write<C: AbstractCommunicator>(
  49. &mut self,
  50. comm: &mut C,
  51. instruction: InstructionShare<F>,
  52. stash_state: StashStateShare<F>,
  53. db_address_share: F,
  54. db_value_share: F,
  55. ) -> Result<(), Error>;
  56. fn get_stash_share(&self) -> (&[F], &[F], &[F]);
  57. }
  /// Computes the bit length of the PRF outputs used for a stash of `stash_size` entries: the
  /// number of bits needed to represent `stash_size`, plus 40.
  ///
  /// NOTE(review): the additional 40 bits presumably serve as a collision/security margin --
  /// confirm.
  fn compute_stash_prf_output_bitsize(stash_size: usize) -> usize {
      const EXTRA_BITS: usize = 40;
      let significant_bits = usize::BITS - stash_size.leading_zeros();
      significant_bits as usize + EXTRA_BITS
  }
  61. #[derive(Debug, Clone, Copy, PartialEq, Eq, strum_macros::EnumIter, strum_macros::Display)]
  62. pub enum ProtocolStep {
  63. Init = 0,
  64. ReadMaskedAddressTag,
  65. ReadDpfKeyGen,
  66. ReadLookupFlagLocation,
  67. ReadComputeLocation,
  68. ReadReshareFlag,
  69. ReadConvertToReplicated,
  70. ReadComputeMaskedIndex,
  71. ReadDpfKeyDistribution,
  72. ReadDpfEvaluations,
  73. WriteAddressTag,
  74. WriteStoreTriple,
  75. WriteSelectPreviousValue,
  76. WriteSelectValue,
  77. WriteComputeMaskedIndex,
  78. WriteDpfKeyDistribution,
  79. WriteDpfEvaluations,
  80. }
  /// Accumulated runtimes of the individual protocol steps, one slot per `ProtocolStep`
  /// variant (indexed by the variant's discriminant).
  #[derive(Debug, Clone, Copy, Default)]
  pub struct Runtimes {
      // One accumulated duration per protocol step.
      durations: [Duration; 17],
  }
  85. impl Runtimes {
  86. #[inline(always)]
  87. pub fn record(&mut self, id: ProtocolStep, duration: Duration) {
  88. self.durations[id as usize] += duration;
  89. }
  90. pub fn get(&self, id: ProtocolStep) -> Duration {
  91. self.durations[id as usize]
  92. }
  93. }
  94. pub struct StashProtocol<F, SPDPF>
  95. where
  96. F: PrimeField + LegendreSymbol + Serializable,
  97. SPDPF: SinglePointDpf<Value = F>,
  98. {
  99. party_id: usize,
  100. stash_size: usize,
  101. access_counter: usize,
  102. state: State,
  103. stash_addresses_share: Vec<F>,
  104. stash_values_share: Vec<F>,
  105. stash_old_values_share: Vec<F>,
  106. address_tag_list: Vec<u64>,
  107. doprf_party_1: Option<DOPrfParty1<F>>,
  108. doprf_party_2: Option<DOPrfParty2<F>>,
  109. doprf_party_3: Option<DOPrfParty3<F>>,
  110. masked_doprf_party_1: Option<MaskedDOPrfParty1<F>>,
  111. masked_doprf_party_2: Option<MaskedDOPrfParty2<F>>,
  112. masked_doprf_party_3: Option<MaskedDOPrfParty3<F>>,
  113. _phantom: PhantomData<SPDPF>,
  114. }
  115. impl<F, SPDPF> StashProtocol<F, SPDPF>
  116. where
  117. F: PrimeField + LegendreSymbol + Serializable,
  118. SPDPF: SinglePointDpf<Value = F>,
  119. SPDPF::Key: Serializable,
  120. {
  121. pub fn new(party_id: usize, stash_size: usize) -> Self {
  122. assert!(party_id < 3);
  123. assert!(stash_size > 0);
  124. assert!(compute_stash_prf_output_bitsize(stash_size) <= 64);
  125. Self {
  126. party_id,
  127. stash_size,
  128. access_counter: 0,
  129. state: State::New,
  130. stash_addresses_share: Vec::with_capacity(stash_size),
  131. stash_values_share: Vec::with_capacity(stash_size),
  132. stash_old_values_share: Vec::with_capacity(stash_size),
  133. address_tag_list: if party_id == PARTY_1 {
  134. Default::default()
  135. } else {
  136. Vec::with_capacity(stash_size)
  137. },
  138. doprf_party_1: None,
  139. doprf_party_2: None,
  140. doprf_party_3: None,
  141. masked_doprf_party_1: None,
  142. masked_doprf_party_2: None,
  143. masked_doprf_party_3: None,
  144. _phantom: PhantomData,
  145. }
  146. }
  147. fn init_with_runtimes<C: AbstractCommunicator>(
  148. &mut self,
  149. comm: &mut C,
  150. runtimes: Option<Runtimes>,
  151. ) -> Result<Option<Runtimes>, Error> {
  152. assert_eq!(self.state, State::New);
  153. let t_start = Instant::now();
  154. let prf_output_bitsize = compute_stash_prf_output_bitsize(self.stash_size);
  155. let legendre_prf_key = LegendrePrf::<F>::key_gen(prf_output_bitsize);
  156. // run DOPRF initilization
  157. match self.party_id {
  158. PARTY_1 => {
  159. let mut doprf_p1 = DOPrfParty1::from_legendre_prf_key(legendre_prf_key.clone());
  160. let mut mdoprf_p1 = MaskedDOPrfParty1::from_legendre_prf_key(legendre_prf_key);
  161. doprf_p1.init(comm)?;
  162. mdoprf_p1.init(comm)?;
  163. doprf_p1.preprocess(comm, self.stash_size)?;
  164. mdoprf_p1.preprocess(comm, self.stash_size)?;
  165. self.doprf_party_1 = Some(doprf_p1);
  166. self.masked_doprf_party_1 = Some(mdoprf_p1);
  167. }
  168. PARTY_2 => {
  169. let mut doprf_p2 = DOPrfParty2::new(prf_output_bitsize);
  170. let mut mdoprf_p2 = MaskedDOPrfParty2::new(prf_output_bitsize);
  171. doprf_p2.init(comm)?;
  172. mdoprf_p2.init(comm)?;
  173. doprf_p2.preprocess(comm, self.stash_size)?;
  174. mdoprf_p2.preprocess(comm, self.stash_size)?;
  175. self.doprf_party_2 = Some(doprf_p2);
  176. self.masked_doprf_party_2 = Some(mdoprf_p2);
  177. }
  178. PARTY_3 => {
  179. let mut doprf_p3 = DOPrfParty3::new(prf_output_bitsize);
  180. let mut mdoprf_p3 = MaskedDOPrfParty3::new(prf_output_bitsize);
  181. doprf_p3.init(comm)?;
  182. mdoprf_p3.init(comm)?;
  183. doprf_p3.preprocess(comm, self.stash_size)?;
  184. mdoprf_p3.preprocess(comm, self.stash_size)?;
  185. self.doprf_party_3 = Some(doprf_p3);
  186. self.masked_doprf_party_3 = Some(mdoprf_p3);
  187. }
  188. _ => panic!("invalid party id"),
  189. }
  190. let t_end = Instant::now();
  191. let runtimes = runtimes.map(|mut r| {
  192. r.record(ProtocolStep::Init, t_end - t_start);
  193. r
  194. });
  195. // panic!("not implemented");
  196. self.state = State::AwaitingRead;
  197. Ok(runtimes)
  198. }
  199. pub fn read_with_runtimes<C: AbstractCommunicator>(
  200. &mut self,
  201. comm: &mut C,
  202. instruction: InstructionShare<F>,
  203. runtimes: Option<Runtimes>,
  204. ) -> Result<(StashStateShare<F>, Option<Runtimes>), Error> {
  205. assert_eq!(self.state, State::AwaitingRead);
  206. assert!(self.access_counter < self.stash_size);
  207. // 0. If the stash is empty, we are done
  208. if self.access_counter == 0 {
  209. self.state = State::AwaitingWrite;
  210. return Ok((
  211. StashStateShare {
  212. flag: F::ZERO,
  213. location: F::ZERO,
  214. value: F::ZERO,
  215. },
  216. runtimes,
  217. ));
  218. }
  219. let t_start = Instant::now();
  220. let (
  221. flag_share,
  222. location_share,
  223. t_after_masked_address_tag,
  224. t_after_dpf_keygen,
  225. t_after_compute_flag_loc,
  226. ) = match self.party_id {
  227. PARTY_1 => {
  228. // 1. Compute tag y := PRF(k, <I.adr>) such that P1 obtains y + r and P2, P3 obtain the mask r.
  229. let masked_address_tag: u64 = {
  230. let mdoprf_p1 = self.masked_doprf_party_1.as_mut().unwrap();
  231. mdoprf_p1.eval_to_uint(comm, 1, &[instruction.address])?[0]
  232. };
  233. let t_after_masked_address_tag = Instant::now();
  234. // 2. Create and send DPF keys for the function f(x) = if x = y { 1 } else { 0 }
  235. {
  236. let domain_size = 1 << compute_stash_prf_output_bitsize(self.stash_size);
  237. let (dpf_key_2, dpf_key_3) =
  238. SPDPF::generate_keys(domain_size, masked_address_tag, F::ONE);
  239. comm.send(PARTY_2, dpf_key_2)?;
  240. comm.send(PARTY_3, dpf_key_3)?;
  241. }
  242. let t_after_dpf_keygen = Instant::now();
  243. // 3. The other parties compute shares of <flag>, <loc>, i.e., if the address is present in
  244. // the stash and if so, where it is. We just take 0s as our shares.
  245. (
  246. F::ZERO,
  247. F::ZERO,
  248. t_after_masked_address_tag,
  249. t_after_dpf_keygen,
  250. t_after_dpf_keygen,
  251. )
  252. }
  253. PARTY_2 | PARTY_3 => {
  254. // 1. Compute tag y := PRF(k, <I.adr>) such that P1 obtains y + r and P2, P3 obtain the mask r.
  255. let address_tag_mask: u64 = match self.party_id {
  256. PARTY_2 => {
  257. let mdoprf_p2 = self.masked_doprf_party_2.as_mut().unwrap();
  258. mdoprf_p2.eval_to_uint(comm, 1, &[instruction.address])?[0]
  259. }
  260. PARTY_3 => {
  261. let mdoprf_p3 = self.masked_doprf_party_3.as_mut().unwrap();
  262. mdoprf_p3.eval_to_uint(comm, 1, &[instruction.address])?[0]
  263. }
  264. _ => panic!("invalid party id"),
  265. };
  266. let t_after_masked_address_tag = Instant::now();
  267. // 2. Receive DPF key for the function f(x) = if x = y { 1 } else { 0 }
  268. let dpf_key_i: SPDPF::Key = {
  269. let fut = comm.receive(PARTY_1)?;
  270. fut.get()?
  271. };
  272. let t_after_dpf_keygen = Instant::now();
  273. // 3. Compute shares of <flag>, <loc>, i.e., if the address is present in the stash and if
  274. // so, where it is
  275. {
  276. let mut flag_share = F::ZERO;
  277. let mut location_share = F::ZERO;
  278. let mut j_as_field_element = F::ZERO;
  279. for j in 0..self.address_tag_list.len() {
  280. let dpf_value_j = SPDPF::evaluate_at(
  281. &dpf_key_i,
  282. self.address_tag_list[j] ^ address_tag_mask,
  283. );
  284. flag_share += dpf_value_j;
  285. location_share += j_as_field_element * dpf_value_j;
  286. j_as_field_element += F::ONE;
  287. }
  288. let t_after_compute_flag_loc = Instant::now();
  289. (
  290. flag_share,
  291. location_share,
  292. t_after_masked_address_tag,
  293. t_after_dpf_keygen,
  294. t_after_compute_flag_loc,
  295. )
  296. }
  297. }
  298. _ => panic!("invalid party id"),
  299. };
  300. // 4. Compute <loc> = if <flag> { <loc> } else { access_counter - 1 }
  301. let location_share = {
  302. let access_counter_share = if self.party_id == PARTY_1 {
  303. F::from_u128(self.access_counter as u128)
  304. } else {
  305. F::ZERO
  306. };
  307. SelectProtocol::select(comm, flag_share, location_share, access_counter_share)?
  308. };
  309. let t_after_location_share = Instant::now();
  310. // 5. Reshare <flag> among all three parties
  311. let flag_share = match self.party_id {
  312. PARTY_1 => {
  313. let flag_share = F::random(thread_rng());
  314. comm.send(PARTY_2, flag_share)?;
  315. flag_share
  316. }
  317. PARTY_2 => {
  318. let fut_1_2 = comm.receive::<F>(PARTY_1)?;
  319. flag_share - fut_1_2.get()?
  320. }
  321. _ => flag_share,
  322. };
  323. let t_after_flag_share = Instant::now();
  324. // 6. Read the value <val> from the stash (if <flag>) or read a zero value
  325. let (
  326. value_share,
  327. t_after_convert_to_replicated,
  328. t_after_masked_index,
  329. t_after_dpf_key_distr,
  330. ) = {
  331. // a) convert the stash into replicated secret sharing
  332. let fut_prev = comm.receive_previous::<Vec<F>>()?;
  333. comm.send_next(self.stash_values_share.to_vec())?;
  334. let stash_values_share_prev = fut_prev.get()?;
  335. let t_after_convert_to_replicated = Instant::now();
  336. // b) mask and reconstruct the stash index <loc>
  337. let index_bits = (self.access_counter as f64).log2().ceil() as u32;
  338. assert!(index_bits <= 16);
  339. let bit_mask = ((1 << index_bits) - 1) as u16;
  340. let (masked_loc, r_prev, r_next) =
  341. MaskIndexProtocol::mask_index(comm, index_bits, location_share)?;
  342. let t_after_masked_index = Instant::now();
  343. // c) use DPFs to read the stash value
  344. let fut_prev = comm.receive_previous::<SPDPF::Key>()?;
  345. let fut_next = comm.receive_next::<SPDPF::Key>()?;
  346. {
  347. let (dpf_key_prev, dpf_key_next) =
  348. SPDPF::generate_keys(1 << index_bits, masked_loc as u64, F::ONE);
  349. comm.send_previous(dpf_key_prev)?;
  350. comm.send_next(dpf_key_next)?;
  351. }
  352. let dpf_key_prev = fut_prev.get()?;
  353. let dpf_key_next = fut_next.get()?;
  354. let t_after_dpf_key_distr = Instant::now();
  355. let mut value_share = F::ZERO;
  356. for j in 0..self.access_counter {
  357. let index_prev = ((j as u16 + r_prev) & bit_mask) as u64;
  358. let index_next = ((j as u16 + r_next) & bit_mask) as u64;
  359. value_share +=
  360. SPDPF::evaluate_at(&dpf_key_prev, index_prev) * self.stash_values_share[j];
  361. value_share +=
  362. SPDPF::evaluate_at(&dpf_key_next, index_next) * stash_values_share_prev[j];
  363. }
  364. (
  365. value_share,
  366. t_after_convert_to_replicated,
  367. t_after_masked_index,
  368. t_after_dpf_key_distr,
  369. )
  370. };
  371. let t_after_dpf_eval = Instant::now();
  372. let runtimes = runtimes.map(|mut r| {
  373. r.record(
  374. ProtocolStep::ReadMaskedAddressTag,
  375. t_after_masked_address_tag - t_start,
  376. );
  377. r.record(
  378. ProtocolStep::ReadDpfKeyGen,
  379. t_after_dpf_keygen - t_after_masked_address_tag,
  380. );
  381. r.record(
  382. ProtocolStep::ReadLookupFlagLocation,
  383. t_after_compute_flag_loc - t_after_dpf_keygen,
  384. );
  385. r.record(
  386. ProtocolStep::ReadComputeLocation,
  387. t_after_location_share - t_after_compute_flag_loc,
  388. );
  389. r.record(
  390. ProtocolStep::ReadReshareFlag,
  391. t_after_flag_share - t_after_location_share,
  392. );
  393. r.record(
  394. ProtocolStep::ReadConvertToReplicated,
  395. t_after_convert_to_replicated - t_after_flag_share,
  396. );
  397. r.record(
  398. ProtocolStep::ReadComputeMaskedIndex,
  399. t_after_masked_index - t_after_convert_to_replicated,
  400. );
  401. r.record(
  402. ProtocolStep::ReadDpfKeyDistribution,
  403. t_after_dpf_key_distr - t_after_masked_index,
  404. );
  405. r.record(
  406. ProtocolStep::ReadDpfEvaluations,
  407. t_after_dpf_eval - t_after_dpf_key_distr,
  408. );
  409. r
  410. });
  411. self.state = State::AwaitingWrite;
  412. Ok((
  413. StashStateShare {
  414. flag: flag_share,
  415. location: location_share,
  416. value: value_share,
  417. },
  418. runtimes,
  419. ))
  420. }
  421. pub fn write_with_runtimes<C: AbstractCommunicator>(
  422. &mut self,
  423. comm: &mut C,
  424. instruction: InstructionShare<F>,
  425. stash_state: StashStateShare<F>,
  426. db_address_share: F,
  427. db_value_share: F,
  428. runtimes: Option<Runtimes>,
  429. ) -> Result<Option<Runtimes>, Error> {
  430. assert_eq!(self.state, State::AwaitingWrite);
  431. assert!(self.access_counter < self.stash_size);
  432. let t_start = Instant::now();
  433. // 1. Compute tag y := PRF(k, <db_adr>) such that P2, P3 obtain y.
  434. match self.party_id {
  435. PARTY_1 => {
  436. let doprf_p1 = self.doprf_party_1.as_mut().unwrap();
  437. // for now do preprocessing on the fly
  438. doprf_p1.preprocess(comm, 1)?;
  439. doprf_p1.eval(comm, 1, &[db_address_share])?;
  440. }
  441. PARTY_2 => {
  442. let address_tag: u64 = {
  443. let doprf_p2 = self.doprf_party_2.as_mut().unwrap();
  444. // for now do preprocessing on the fly
  445. doprf_p2.preprocess(comm, 1)?;
  446. let fut_3_2 = comm.receive(PARTY_3)?;
  447. doprf_p2.eval(comm, 1, &[db_address_share])?;
  448. fut_3_2.get()?
  449. };
  450. self.address_tag_list.push(address_tag);
  451. }
  452. PARTY_3 => {
  453. let address_tag: u64 = {
  454. let doprf_p3 = self.doprf_party_3.as_mut().unwrap();
  455. // for now do preprocessing on the fly
  456. doprf_p3.preprocess(comm, 1)?;
  457. let tag = doprf_p3.eval_to_uint(comm, 1, &[db_address_share])?[0];
  458. comm.send(PARTY_2, tag)?;
  459. tag
  460. };
  461. self.address_tag_list.push(address_tag);
  462. }
  463. _ => panic!("invalid party id"),
  464. }
  465. let t_after_address_tag = Instant::now();
  466. // 2. Insert new triple (<db_adr>, <db_val>, <db_val> into stash.
  467. self.stash_addresses_share.push(db_address_share);
  468. self.stash_values_share.push(db_value_share);
  469. self.stash_old_values_share.push(db_value_share);
  470. let t_after_store_triple = Instant::now();
  471. // 3. Update stash
  472. let previous_value_share =
  473. SelectProtocol::select(comm, stash_state.flag, stash_state.value, db_value_share)?;
  474. let t_after_select_previous_value = Instant::now();
  475. let value_share = SelectProtocol::select(
  476. comm,
  477. instruction.operation,
  478. instruction.value - previous_value_share,
  479. F::ZERO,
  480. )?;
  481. let t_after_select_value = Instant::now();
  482. let (t_after_masked_index, t_after_dpf_key_distr) = {
  483. // a) mask and reconstruct the stash index <loc>
  484. let index_bits = {
  485. let bits = usize::BITS - self.access_counter.leading_zeros();
  486. if bits > 0 {
  487. bits
  488. } else {
  489. 1
  490. }
  491. };
  492. assert!(index_bits <= 16);
  493. let bit_mask = ((1 << index_bits) - 1) as u16;
  494. let (masked_loc, r_prev, r_next) =
  495. MaskIndexProtocol::mask_index(comm, index_bits, stash_state.location)?;
  496. let t_after_masked_index = Instant::now();
  497. // b) use DPFs to read the stash value
  498. let fut_prev = comm.receive_previous::<SPDPF::Key>()?;
  499. let fut_next = comm.receive_next::<SPDPF::Key>()?;
  500. {
  501. let (dpf_key_prev, dpf_key_next) =
  502. SPDPF::generate_keys(1 << index_bits, masked_loc as u64, value_share);
  503. comm.send_previous(dpf_key_prev)?;
  504. comm.send_next(dpf_key_next)?;
  505. }
  506. let dpf_key_prev = fut_prev.get()?;
  507. let dpf_key_next = fut_next.get()?;
  508. let t_after_dpf_key_distr = Instant::now();
  509. for j in 0..=self.access_counter {
  510. let index_prev = ((j as u16).wrapping_add(r_prev) & bit_mask) as u64;
  511. let index_next = ((j as u16).wrapping_add(r_next) & bit_mask) as u64;
  512. self.stash_values_share[j] += SPDPF::evaluate_at(&dpf_key_prev, index_prev);
  513. self.stash_values_share[j] += SPDPF::evaluate_at(&dpf_key_next, index_next);
  514. }
  515. (t_after_masked_index, t_after_dpf_key_distr)
  516. };
  517. let t_after_dpf_eval = Instant::now();
  518. self.access_counter += 1;
  519. self.state = if self.access_counter == self.stash_size {
  520. State::AccessesExhausted
  521. } else {
  522. State::AwaitingRead
  523. };
  524. let runtimes = runtimes.map(|mut r| {
  525. r.record(ProtocolStep::WriteAddressTag, t_after_address_tag - t_start);
  526. r.record(
  527. ProtocolStep::WriteStoreTriple,
  528. t_after_store_triple - t_after_address_tag,
  529. );
  530. r.record(
  531. ProtocolStep::WriteSelectPreviousValue,
  532. t_after_select_previous_value - t_after_store_triple,
  533. );
  534. r.record(
  535. ProtocolStep::WriteSelectValue,
  536. t_after_select_value - t_after_select_previous_value,
  537. );
  538. r.record(
  539. ProtocolStep::WriteComputeMaskedIndex,
  540. t_after_masked_index - t_after_select_value,
  541. );
  542. r.record(
  543. ProtocolStep::WriteDpfKeyDistribution,
  544. t_after_dpf_key_distr - t_after_masked_index,
  545. );
  546. r.record(
  547. ProtocolStep::WriteDpfEvaluations,
  548. t_after_dpf_eval - t_after_dpf_key_distr,
  549. );
  550. r
  551. });
  552. Ok(runtimes)
  553. }
  554. }
  555. impl<F, SPDPF> Stash<F> for StashProtocol<F, SPDPF>
  556. where
  557. F: PrimeField + LegendreSymbol + Serializable,
  558. SPDPF: SinglePointDpf<Value = F>,
  559. SPDPF::Key: Serializable,
  560. {
  561. fn get_party_id(&self) -> usize {
  562. self.party_id
  563. }
  564. fn get_stash_size(&self) -> usize {
  565. self.stash_size
  566. }
  567. fn get_access_counter(&self) -> usize {
  568. self.access_counter
  569. }
  570. fn reset(&mut self) {
  571. *self = Self::new(self.party_id, self.stash_size);
  572. }
  573. fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
  574. self.init_with_runtimes(comm, None).map(|_| ())
  575. }
  576. fn read<C: AbstractCommunicator>(
  577. &mut self,
  578. comm: &mut C,
  579. instruction: InstructionShare<F>,
  580. ) -> Result<StashStateShare<F>, Error> {
  581. self.read_with_runtimes(comm, instruction, None)
  582. .map(|x| x.0)
  583. }
  584. fn write<C: AbstractCommunicator>(
  585. &mut self,
  586. comm: &mut C,
  587. instruction: InstructionShare<F>,
  588. stash_state: StashStateShare<F>,
  589. db_address_share: F,
  590. db_value_share: F,
  591. ) -> Result<(), Error> {
  592. self.write_with_runtimes(
  593. comm,
  594. instruction,
  595. stash_state,
  596. db_address_share,
  597. db_value_share,
  598. None,
  599. )
  600. .map(|_| ())
  601. }
  602. fn get_stash_share(&self) -> (&[F], &[F], &[F]) {
  603. (
  604. &self.stash_addresses_share,
  605. &self.stash_values_share,
  606. &self.stash_old_values_share,
  607. )
  608. }
  609. }
  610. #[cfg(test)]
  611. mod tests {
  612. use super::*;
  613. use crate::common::Operation;
  614. use communicator::unix::make_unix_communicators;
  615. use dpf::spdpf::DummySpDpf;
  616. use ff::Field;
  617. use std::thread;
  618. use utils::field::Fp;
  619. fn run_init<F>(
  620. mut stash_party: impl Stash<F> + Send + 'static,
  621. mut comm: impl AbstractCommunicator + Send + 'static,
  622. ) -> thread::JoinHandle<(impl Stash<F>, impl AbstractCommunicator)>
  623. where
  624. F: PrimeField + LegendreSymbol,
  625. {
  626. thread::spawn(move || {
  627. stash_party.init(&mut comm).unwrap();
  628. (stash_party, comm)
  629. })
  630. }
  631. fn run_read<F>(
  632. mut stash_party: impl Stash<F> + Send + 'static,
  633. mut comm: impl AbstractCommunicator + Send + 'static,
  634. instruction: InstructionShare<F>,
  635. ) -> thread::JoinHandle<(impl Stash<F>, impl AbstractCommunicator, StashStateShare<F>)>
  636. where
  637. F: PrimeField + LegendreSymbol,
  638. {
  639. thread::spawn(move || {
  640. let result = stash_party.read(&mut comm, instruction);
  641. (stash_party, comm, result.unwrap())
  642. })
  643. }
  644. fn run_write<F>(
  645. mut stash_party: impl Stash<F> + Send + 'static,
  646. mut comm: impl AbstractCommunicator + Send + 'static,
  647. instruction: InstructionShare<F>,
  648. stash_state: StashStateShare<F>,
  649. db_address_share: F,
  650. db_value_share: F,
  651. ) -> thread::JoinHandle<(impl Stash<F>, impl AbstractCommunicator)>
  652. where
  653. F: PrimeField + LegendreSymbol,
  654. {
  655. thread::spawn(move || {
  656. stash_party
  657. .write(
  658. &mut comm,
  659. instruction,
  660. stash_state,
  661. db_address_share,
  662. db_value_share,
  663. )
  664. .unwrap();
  665. (stash_party, comm)
  666. })
  667. }
  668. #[test]
  669. fn test_stash() {
  670. type SPDPF = DummySpDpf<Fp>;
  671. let stash_size = 128;
  672. let mut num_accesses = 0;
  673. let party_1 = StashProtocol::<Fp, SPDPF>::new(PARTY_1, stash_size);
  674. let party_2 = StashProtocol::<Fp, SPDPF>::new(PARTY_2, stash_size);
  675. let party_3 = StashProtocol::<Fp, SPDPF>::new(PARTY_3, stash_size);
  676. assert_eq!(party_1.get_party_id(), PARTY_1);
  677. assert_eq!(party_2.get_party_id(), PARTY_2);
  678. assert_eq!(party_3.get_party_id(), PARTY_3);
  679. assert_eq!(party_1.get_stash_size(), stash_size);
  680. assert_eq!(party_2.get_stash_size(), stash_size);
  681. assert_eq!(party_3.get_stash_size(), stash_size);
  682. let (comm_3, comm_2, comm_1) = {
  683. let mut comms = make_unix_communicators(3);
  684. (
  685. comms.pop().unwrap(),
  686. comms.pop().unwrap(),
  687. comms.pop().unwrap(),
  688. )
  689. };
  690. let h1 = run_init(party_1, comm_1);
  691. let h2 = run_init(party_2, comm_2);
  692. let h3 = run_init(party_3, comm_3);
  693. let (party_1, comm_1) = h1.join().unwrap();
  694. let (party_2, comm_2) = h2.join().unwrap();
  695. let (party_3, comm_3) = h3.join().unwrap();
  696. assert_eq!(party_1.get_access_counter(), 0);
  697. assert_eq!(party_2.get_access_counter(), 0);
  698. assert_eq!(party_3.get_access_counter(), 0);
  699. // write a value 42 to address adr = 3
  700. let value = 42;
  701. let address = 3;
  702. let inst_w3_1 = InstructionShare {
  703. operation: Operation::Write.encode(),
  704. address: Fp::from_u128(address),
  705. value: Fp::from_u128(value),
  706. };
  707. let inst_w3_2 = InstructionShare {
  708. operation: Fp::ZERO,
  709. address: Fp::ZERO,
  710. value: Fp::ZERO,
  711. };
  712. let inst_w3_3 = inst_w3_2.clone();
  713. let h1 = run_read(party_1, comm_1, inst_w3_1);
  714. let h2 = run_read(party_2, comm_2, inst_w3_2);
  715. let h3 = run_read(party_3, comm_3, inst_w3_3);
  716. let (party_1, comm_1, state_1) = h1.join().unwrap();
  717. let (party_2, comm_2, state_2) = h2.join().unwrap();
  718. let (party_3, comm_3, state_3) = h3.join().unwrap();
  719. // since the stash is empty, st.flag must be zero
  720. assert_eq!(state_1.flag + state_2.flag + state_3.flag, Fp::ZERO);
  721. assert_eq!(
  722. state_1.location + state_2.location + state_3.location,
  723. Fp::ZERO
  724. );
  725. let h1 = run_write(
  726. party_1,
  727. comm_1,
  728. inst_w3_1,
  729. state_1,
  730. inst_w3_1.address,
  731. Fp::from_u128(0x71),
  732. );
  733. let h2 = run_write(
  734. party_2,
  735. comm_2,
  736. inst_w3_2,
  737. state_1,
  738. inst_w3_2.address,
  739. Fp::from_u128(0x72),
  740. );
  741. let h3 = run_write(
  742. party_3,
  743. comm_3,
  744. inst_w3_3,
  745. state_1,
  746. inst_w3_3.address,
  747. Fp::from_u128(0x73),
  748. );
  749. let (party_1, comm_1) = h1.join().unwrap();
  750. let (party_2, comm_2) = h2.join().unwrap();
  751. let (party_3, comm_3) = h3.join().unwrap();
  752. num_accesses += 1;
  753. assert_eq!(party_1.get_access_counter(), num_accesses);
  754. assert_eq!(party_2.get_access_counter(), num_accesses);
  755. assert_eq!(party_3.get_access_counter(), num_accesses);
  756. {
  757. let (st_adrs_1, st_vals_1, st_old_vals_1) = party_1.get_stash_share();
  758. let (st_adrs_2, st_vals_2, st_old_vals_2) = party_2.get_stash_share();
  759. let (st_adrs_3, st_vals_3, st_old_vals_3) = party_3.get_stash_share();
  760. assert_eq!(st_adrs_1.len(), num_accesses);
  761. assert_eq!(st_vals_1.len(), num_accesses);
  762. assert_eq!(st_old_vals_1.len(), num_accesses);
  763. assert_eq!(st_adrs_2.len(), num_accesses);
  764. assert_eq!(st_vals_2.len(), num_accesses);
  765. assert_eq!(st_old_vals_2.len(), num_accesses);
  766. assert_eq!(st_adrs_3.len(), num_accesses);
  767. assert_eq!(st_vals_3.len(), num_accesses);
  768. assert_eq!(st_old_vals_3.len(), num_accesses);
  769. assert_eq!(
  770. st_adrs_1[0] + st_adrs_2[0] + st_adrs_3[0],
  771. Fp::from_u128(address)
  772. );
  773. assert_eq!(
  774. st_vals_1[0] + st_vals_2[0] + st_vals_3[0],
  775. Fp::from_u128(value)
  776. );
  777. }
  778. // read again from address adr = 3, we should get the value 42 back
  779. let inst_r3_1 = InstructionShare {
  780. operation: Operation::Read.encode(),
  781. address: Fp::from_u128(3),
  782. value: Fp::ZERO,
  783. };
  784. let inst_r3_2 = InstructionShare {
  785. operation: Fp::ZERO,
  786. address: Fp::ZERO,
  787. value: Fp::ZERO,
  788. };
  789. let inst_r3_3 = inst_r3_2.clone();
  790. let h1 = run_read(party_1, comm_1, inst_r3_1);
  791. let h2 = run_read(party_2, comm_2, inst_r3_2);
  792. let h3 = run_read(party_3, comm_3, inst_r3_3);
  793. let (party_1, comm_1, state_1) = h1.join().unwrap();
  794. let (party_2, comm_2, state_2) = h2.join().unwrap();
  795. let (party_3, comm_3, state_3) = h3.join().unwrap();
  796. let st_flag = state_1.flag + state_2.flag + state_3.flag;
  797. let st_location = state_1.location + state_2.location + state_3.location;
  798. let st_value = state_1.value + state_2.value + state_3.value;
  799. assert_eq!(st_flag, Fp::ONE);
  800. assert_eq!(st_location, Fp::from_u128(0));
  801. assert_eq!(st_value, Fp::from_u128(value));
  802. let h1 = run_write(
  803. party_1,
  804. comm_1,
  805. inst_r3_1,
  806. state_1,
  807. Fp::from_u128(0x83),
  808. Fp::from_u128(0x93),
  809. );
  810. let h2 = run_write(
  811. party_2,
  812. comm_2,
  813. inst_r3_2,
  814. state_1,
  815. Fp::from_u128(0x83),
  816. Fp::from_u128(0x93),
  817. );
  818. let h3 = run_write(
  819. party_3,
  820. comm_3,
  821. inst_r3_3,
  822. state_1,
  823. Fp::from_u128(0x83),
  824. Fp::from_u128(0x93),
  825. );
  826. let (party_1, comm_1) = h1.join().unwrap();
  827. let (party_2, comm_2) = h2.join().unwrap();
  828. let (party_3, comm_3) = h3.join().unwrap();
  829. num_accesses += 1;
  830. assert_eq!(party_1.get_access_counter(), num_accesses);
  831. assert_eq!(party_2.get_access_counter(), num_accesses);
  832. assert_eq!(party_3.get_access_counter(), num_accesses);
  833. {
  834. let (st_adrs_1, st_vals_1, st_old_vals_1) = party_1.get_stash_share();
  835. let (st_adrs_2, st_vals_2, st_old_vals_2) = party_2.get_stash_share();
  836. let (st_adrs_3, st_vals_3, st_old_vals_3) = party_3.get_stash_share();
  837. assert_eq!(st_adrs_1.len(), num_accesses);
  838. assert_eq!(st_vals_1.len(), num_accesses);
  839. assert_eq!(st_old_vals_1.len(), num_accesses);
  840. assert_eq!(st_adrs_2.len(), num_accesses);
  841. assert_eq!(st_vals_2.len(), num_accesses);
  842. assert_eq!(st_old_vals_2.len(), num_accesses);
  843. assert_eq!(st_adrs_3.len(), num_accesses);
  844. assert_eq!(st_vals_3.len(), num_accesses);
  845. assert_eq!(st_old_vals_3.len(), num_accesses);
  846. assert_eq!(
  847. st_adrs_1[0] + st_adrs_2[0] + st_adrs_3[0],
  848. Fp::from_u128(address)
  849. );
  850. assert_eq!(
  851. st_vals_1[0] + st_vals_2[0] + st_vals_3[0],
  852. Fp::from_u128(value)
  853. );
  854. }
  855. // now write a value 0x1337 to address adr = 3
  856. let old_value = value;
  857. let value = 0x1337;
  858. let address = 3;
  859. let inst_w3_1 = InstructionShare {
  860. operation: Operation::Write.encode(),
  861. address: Fp::from_u128(address),
  862. value: Fp::from_u128(value),
  863. };
  864. let inst_w3_2 = InstructionShare {
  865. operation: Fp::ZERO,
  866. address: Fp::ZERO,
  867. value: Fp::ZERO,
  868. };
  869. let inst_w3_3 = inst_w3_2.clone();
  870. let h1 = run_read(party_1, comm_1, inst_w3_1);
  871. let h2 = run_read(party_2, comm_2, inst_w3_2);
  872. let h3 = run_read(party_3, comm_3, inst_w3_3);
  873. let (party_1, comm_1, state_1) = h1.join().unwrap();
  874. let (party_2, comm_2, state_2) = h2.join().unwrap();
  875. let (party_3, comm_3, state_3) = h3.join().unwrap();
  876. // since we already wrote to the address, it should be present in the stash
  877. assert_eq!(state_1.flag + state_2.flag + state_3.flag, Fp::ONE);
  878. assert_eq!(
  879. state_1.location + state_2.location + state_3.location,
  880. Fp::ZERO
  881. );
  882. assert_eq!(
  883. state_1.value + state_2.value + state_3.value,
  884. Fp::from_u128(old_value)
  885. );
  886. let h1 = run_write(
  887. party_1,
  888. comm_1,
  889. inst_w3_1,
  890. state_1,
  891. // inst_w3_1.address,
  892. Fp::from_u128(0x61),
  893. Fp::from_u128(0x71),
  894. );
  895. let h2 = run_write(
  896. party_2,
  897. comm_2,
  898. inst_w3_2,
  899. state_2,
  900. // inst_w3_2.address,
  901. Fp::from_u128(0x62),
  902. Fp::from_u128(0x72),
  903. );
  904. let h3 = run_write(
  905. party_3,
  906. comm_3,
  907. inst_w3_3,
  908. state_3,
  909. // inst_w3_3.address,
  910. Fp::from_u128(0x63),
  911. Fp::from_u128(0x73),
  912. );
  913. let (party_1, comm_1) = h1.join().unwrap();
  914. let (party_2, comm_2) = h2.join().unwrap();
  915. let (party_3, comm_3) = h3.join().unwrap();
  916. num_accesses += 1;
  917. assert_eq!(party_1.get_access_counter(), num_accesses);
  918. assert_eq!(party_2.get_access_counter(), num_accesses);
  919. assert_eq!(party_3.get_access_counter(), num_accesses);
  920. {
  921. let (st_adrs_1, st_vals_1, st_old_vals_1) = party_1.get_stash_share();
  922. let (st_adrs_2, st_vals_2, st_old_vals_2) = party_2.get_stash_share();
  923. let (st_adrs_3, st_vals_3, st_old_vals_3) = party_3.get_stash_share();
  924. assert_eq!(st_adrs_1.len(), num_accesses);
  925. assert_eq!(st_vals_1.len(), num_accesses);
  926. assert_eq!(st_old_vals_1.len(), num_accesses);
  927. assert_eq!(st_adrs_2.len(), num_accesses);
  928. assert_eq!(st_vals_2.len(), num_accesses);
  929. assert_eq!(st_old_vals_2.len(), num_accesses);
  930. assert_eq!(st_adrs_3.len(), num_accesses);
  931. assert_eq!(st_vals_3.len(), num_accesses);
  932. assert_eq!(st_old_vals_3.len(), num_accesses);
  933. assert_eq!(
  934. st_adrs_1[0] + st_adrs_2[0] + st_adrs_3[0],
  935. Fp::from_u128(address)
  936. );
  937. assert_eq!(
  938. st_vals_1[0] + st_vals_2[0] + st_vals_3[0],
  939. Fp::from_u128(value)
  940. );
  941. }
  942. // read again from address adr = 3, we should get the value 0x1337 back
  943. let inst_r3_1 = InstructionShare {
  944. operation: Operation::Read.encode(),
  945. address: Fp::from_u128(3),
  946. value: Fp::ZERO,
  947. };
  948. let inst_r3_2 = InstructionShare {
  949. operation: Fp::ZERO,
  950. address: Fp::ZERO,
  951. value: Fp::ZERO,
  952. };
  953. let inst_r3_3 = inst_r3_2.clone();
  954. let h1 = run_read(party_1, comm_1, inst_r3_1);
  955. let h2 = run_read(party_2, comm_2, inst_r3_2);
  956. let h3 = run_read(party_3, comm_3, inst_r3_3);
  957. let (party_1, comm_1, state_1) = h1.join().unwrap();
  958. let (party_2, comm_2, state_2) = h2.join().unwrap();
  959. let (party_3, comm_3, state_3) = h3.join().unwrap();
  960. let st_flag = state_1.flag + state_2.flag + state_3.flag;
  961. let st_location = state_1.location + state_2.location + state_3.location;
  962. let st_value = state_1.value + state_2.value + state_3.value;
  963. assert_eq!(st_flag, Fp::ONE);
  964. assert_eq!(st_location, Fp::from_u128(0));
  965. assert_eq!(st_value, Fp::from_u128(value));
  966. let h1 = run_write(
  967. party_1,
  968. comm_1,
  969. inst_r3_1,
  970. state_1,
  971. Fp::from_u128(0x83),
  972. Fp::from_u128(0x93),
  973. );
  974. let h2 = run_write(
  975. party_2,
  976. comm_2,
  977. inst_r3_2,
  978. state_2,
  979. Fp::from_u128(0x83),
  980. Fp::from_u128(0x93),
  981. );
  982. let h3 = run_write(
  983. party_3,
  984. comm_3,
  985. inst_r3_3,
  986. state_3,
  987. Fp::from_u128(0x83),
  988. Fp::from_u128(0x93),
  989. );
  990. let (party_1, _comm_1) = h1.join().unwrap();
  991. let (party_2, _comm_2) = h2.join().unwrap();
  992. let (party_3, _comm_3) = h3.join().unwrap();
  993. num_accesses += 1;
  994. assert_eq!(party_1.get_access_counter(), num_accesses);
  995. assert_eq!(party_2.get_access_counter(), num_accesses);
  996. assert_eq!(party_3.get_access_counter(), num_accesses);
  997. {
  998. let (st_adrs_1, st_vals_1, st_old_vals_1) = party_1.get_stash_share();
  999. let (st_adrs_2, st_vals_2, st_old_vals_2) = party_2.get_stash_share();
  1000. let (st_adrs_3, st_vals_3, st_old_vals_3) = party_3.get_stash_share();
  1001. assert_eq!(st_adrs_1.len(), num_accesses);
  1002. assert_eq!(st_vals_1.len(), num_accesses);
  1003. assert_eq!(st_old_vals_1.len(), num_accesses);
  1004. assert_eq!(st_adrs_2.len(), num_accesses);
  1005. assert_eq!(st_vals_2.len(), num_accesses);
  1006. assert_eq!(st_old_vals_2.len(), num_accesses);
  1007. assert_eq!(st_adrs_3.len(), num_accesses);
  1008. assert_eq!(st_vals_3.len(), num_accesses);
  1009. assert_eq!(st_old_vals_3.len(), num_accesses);
  1010. assert_eq!(
  1011. st_adrs_1[0] + st_adrs_2[0] + st_adrs_3[0],
  1012. Fp::from_u128(address)
  1013. );
  1014. assert_eq!(
  1015. st_vals_1[0] + st_vals_2[0] + st_vals_3[0],
  1016. Fp::from_u128(value)
  1017. );
  1018. }
  1019. }
  1020. }