| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128 |
- use crate::common::{Error, InstructionShare};
- use crate::doprf::{
- DOPrfParty1, DOPrfParty2, DOPrfParty3, LegendrePrf, MaskedDOPrfParty1, MaskedDOPrfParty2,
- MaskedDOPrfParty3,
- };
- use crate::mask_index::{MaskIndex, MaskIndexProtocol};
- use crate::select::{Select, SelectProtocol};
- use communicator::{AbstractCommunicator, Fut, Serializable};
- use dpf::spdpf::SinglePointDpf;
- use ff::PrimeField;
- use rand::thread_rng;
- use std::marker::PhantomData;
- use std::time::{Duration, Instant};
- use utils::field::LegendreSymbol;
- #[derive(Clone, Copy, Debug, Default)]
- pub struct StashEntryShare<F: PrimeField> {
- pub address: F,
- pub value: F,
- pub old_value: F,
- }
- #[derive(Clone, Copy, Debug, Default)]
- pub struct StashStateShare<F: PrimeField> {
- pub flag: F,
- pub location: F,
- pub value: F,
- }
/// Lifecycle of a stash instance: `New` → (`AwaitingRead` → `AwaitingWrite`)* → `AccessesExhausted`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Constructed, but the setup phase has not been run yet.
    New,
    /// Ready for the next lookup.
    AwaitingRead,
    /// A lookup has happened; the matching write is due.
    AwaitingWrite,
    /// Every one of the `stash_size` accesses has been used up.
    AccessesExhausted,
}
/// Id of the first party (P1), which holds the PRF key and deals DPF keys during lookups.
const PARTY_1: usize = 0;
/// Id of the second party (P2).
const PARTY_2: usize = PARTY_1 + 1;
/// Id of the third party (P3).
const PARTY_3: usize = PARTY_2 + 1;
- pub trait Stash<F: PrimeField> {
- fn get_party_id(&self) -> usize;
- fn get_stash_size(&self) -> usize;
- fn get_access_counter(&self) -> usize;
- fn reset(&mut self);
- fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error>;
- fn read<C: AbstractCommunicator>(
- &mut self,
- comm: &mut C,
- instruction: InstructionShare<F>,
- ) -> Result<StashStateShare<F>, Error>;
- fn write<C: AbstractCommunicator>(
- &mut self,
- comm: &mut C,
- instruction: InstructionShare<F>,
- stash_state: StashStateShare<F>,
- db_address_share: F,
- db_value_share: F,
- ) -> Result<(), Error>;
- fn get_stash_share(&self) -> (&[F], &[F], &[F]);
- }
/// Returns the output bitsize of the PRF used to derive stash address tags.
///
/// This is the number of bits needed to represent `stash_size`, plus 40 extra bits. The extra
/// bits are presumably a statistical margin that keeps accidental tag collisions unlikely
/// (TODO: confirm the intended security parameter).
fn compute_stash_prf_output_bitsize(stash_size: usize) -> usize {
    const EXTRA_TAG_BITS: usize = 40;
    let size_bits = usize::BITS - stash_size.leading_zeros();
    size_bits as usize + EXTRA_TAG_BITS
}
- #[derive(Debug, Clone, Copy, PartialEq, Eq, strum_macros::EnumIter, strum_macros::Display)]
- pub enum ProtocolStep {
- Init = 0,
- ReadMaskedAddressTag,
- ReadDpfKeyGen,
- ReadLookupFlagLocation,
- ReadComputeLocation,
- ReadReshareFlag,
- ReadConvertToReplicated,
- ReadComputeMaskedIndex,
- ReadDpfKeyDistribution,
- ReadDpfEvaluations,
- WriteAddressTag,
- WriteStoreTriple,
- WriteSelectPreviousValue,
- WriteSelectValue,
- WriteComputeMaskedIndex,
- WriteDpfKeyDistribution,
- WriteDpfEvaluations,
- }
- #[derive(Debug, Default, Clone, Copy)]
- pub struct Runtimes {
- durations: [Duration; 17],
- }
- impl Runtimes {
- #[inline(always)]
- pub fn record(&mut self, id: ProtocolStep, duration: Duration) {
- self.durations[id as usize] += duration;
- }
- pub fn get(&self, id: ProtocolStep) -> Duration {
- self.durations[id as usize]
- }
- }
- pub struct StashProtocol<F, SPDPF>
- where
- F: PrimeField + LegendreSymbol + Serializable,
- SPDPF: SinglePointDpf<Value = F>,
- {
- party_id: usize,
- stash_size: usize,
- access_counter: usize,
- state: State,
- stash_addresses_share: Vec<F>,
- stash_values_share: Vec<F>,
- stash_old_values_share: Vec<F>,
- address_tag_list: Vec<u64>,
- doprf_party_1: Option<DOPrfParty1<F>>,
- doprf_party_2: Option<DOPrfParty2<F>>,
- doprf_party_3: Option<DOPrfParty3<F>>,
- masked_doprf_party_1: Option<MaskedDOPrfParty1<F>>,
- masked_doprf_party_2: Option<MaskedDOPrfParty2<F>>,
- masked_doprf_party_3: Option<MaskedDOPrfParty3<F>>,
- _phantom: PhantomData<SPDPF>,
- }
- impl<F, SPDPF> StashProtocol<F, SPDPF>
- where
- F: PrimeField + LegendreSymbol + Serializable,
- SPDPF: SinglePointDpf<Value = F>,
- SPDPF::Key: Serializable,
- {
- pub fn new(party_id: usize, stash_size: usize) -> Self {
- assert!(party_id < 3);
- assert!(stash_size > 0);
- assert!(compute_stash_prf_output_bitsize(stash_size) <= 64);
- Self {
- party_id,
- stash_size,
- access_counter: 0,
- state: State::New,
- stash_addresses_share: Vec::with_capacity(stash_size),
- stash_values_share: Vec::with_capacity(stash_size),
- stash_old_values_share: Vec::with_capacity(stash_size),
- address_tag_list: if party_id == PARTY_1 {
- Default::default()
- } else {
- Vec::with_capacity(stash_size)
- },
- doprf_party_1: None,
- doprf_party_2: None,
- doprf_party_3: None,
- masked_doprf_party_1: None,
- masked_doprf_party_2: None,
- masked_doprf_party_3: None,
- _phantom: PhantomData,
- }
- }
- fn init_with_runtimes<C: AbstractCommunicator>(
- &mut self,
- comm: &mut C,
- runtimes: Option<Runtimes>,
- ) -> Result<Option<Runtimes>, Error> {
- assert_eq!(self.state, State::New);
- let t_start = Instant::now();
- let prf_output_bitsize = compute_stash_prf_output_bitsize(self.stash_size);
- let legendre_prf_key = LegendrePrf::<F>::key_gen(prf_output_bitsize);
- // run DOPRF initilization
- match self.party_id {
- PARTY_1 => {
- let mut doprf_p1 = DOPrfParty1::from_legendre_prf_key(legendre_prf_key.clone());
- let mut mdoprf_p1 = MaskedDOPrfParty1::from_legendre_prf_key(legendre_prf_key);
- doprf_p1.init(comm)?;
- mdoprf_p1.init(comm)?;
- doprf_p1.preprocess(comm, self.stash_size)?;
- mdoprf_p1.preprocess(comm, self.stash_size)?;
- self.doprf_party_1 = Some(doprf_p1);
- self.masked_doprf_party_1 = Some(mdoprf_p1);
- }
- PARTY_2 => {
- let mut doprf_p2 = DOPrfParty2::new(prf_output_bitsize);
- let mut mdoprf_p2 = MaskedDOPrfParty2::new(prf_output_bitsize);
- doprf_p2.init(comm)?;
- mdoprf_p2.init(comm)?;
- doprf_p2.preprocess(comm, self.stash_size)?;
- mdoprf_p2.preprocess(comm, self.stash_size)?;
- self.doprf_party_2 = Some(doprf_p2);
- self.masked_doprf_party_2 = Some(mdoprf_p2);
- }
- PARTY_3 => {
- let mut doprf_p3 = DOPrfParty3::new(prf_output_bitsize);
- let mut mdoprf_p3 = MaskedDOPrfParty3::new(prf_output_bitsize);
- doprf_p3.init(comm)?;
- mdoprf_p3.init(comm)?;
- doprf_p3.preprocess(comm, self.stash_size)?;
- mdoprf_p3.preprocess(comm, self.stash_size)?;
- self.doprf_party_3 = Some(doprf_p3);
- self.masked_doprf_party_3 = Some(mdoprf_p3);
- }
- _ => panic!("invalid party id"),
- }
- let t_end = Instant::now();
- let runtimes = runtimes.map(|mut r| {
- r.record(ProtocolStep::Init, t_end - t_start);
- r
- });
- // panic!("not implemented");
- self.state = State::AwaitingRead;
- Ok(runtimes)
- }
- pub fn read_with_runtimes<C: AbstractCommunicator>(
- &mut self,
- comm: &mut C,
- instruction: InstructionShare<F>,
- runtimes: Option<Runtimes>,
- ) -> Result<(StashStateShare<F>, Option<Runtimes>), Error> {
- assert_eq!(self.state, State::AwaitingRead);
- assert!(self.access_counter < self.stash_size);
- // 0. If the stash is empty, we are done
- if self.access_counter == 0 {
- self.state = State::AwaitingWrite;
- return Ok((
- StashStateShare {
- flag: F::ZERO,
- location: F::ZERO,
- value: F::ZERO,
- },
- runtimes,
- ));
- }
- let t_start = Instant::now();
- let (
- flag_share,
- location_share,
- t_after_masked_address_tag,
- t_after_dpf_keygen,
- t_after_compute_flag_loc,
- ) = match self.party_id {
- PARTY_1 => {
- // 1. Compute tag y := PRF(k, <I.adr>) such that P1 obtains y + r and P2, P3 obtain the mask r.
- let masked_address_tag: u64 = {
- let mdoprf_p1 = self.masked_doprf_party_1.as_mut().unwrap();
- mdoprf_p1.eval_to_uint(comm, 1, &[instruction.address])?[0]
- };
- let t_after_masked_address_tag = Instant::now();
- // 2. Create and send DPF keys for the function f(x) = if x = y { 1 } else { 0 }
- {
- let domain_size = 1 << compute_stash_prf_output_bitsize(self.stash_size);
- let (dpf_key_2, dpf_key_3) =
- SPDPF::generate_keys(domain_size, masked_address_tag, F::ONE);
- comm.send(PARTY_2, dpf_key_2)?;
- comm.send(PARTY_3, dpf_key_3)?;
- }
- let t_after_dpf_keygen = Instant::now();
- // 3. The other parties compute shares of <flag>, <loc>, i.e., if the address is present in
- // the stash and if so, where it is. We just take 0s as our shares.
- (
- F::ZERO,
- F::ZERO,
- t_after_masked_address_tag,
- t_after_dpf_keygen,
- t_after_dpf_keygen,
- )
- }
- PARTY_2 | PARTY_3 => {
- // 1. Compute tag y := PRF(k, <I.adr>) such that P1 obtains y + r and P2, P3 obtain the mask r.
- let address_tag_mask: u64 = match self.party_id {
- PARTY_2 => {
- let mdoprf_p2 = self.masked_doprf_party_2.as_mut().unwrap();
- mdoprf_p2.eval_to_uint(comm, 1, &[instruction.address])?[0]
- }
- PARTY_3 => {
- let mdoprf_p3 = self.masked_doprf_party_3.as_mut().unwrap();
- mdoprf_p3.eval_to_uint(comm, 1, &[instruction.address])?[0]
- }
- _ => panic!("invalid party id"),
- };
- let t_after_masked_address_tag = Instant::now();
- // 2. Receive DPF key for the function f(x) = if x = y { 1 } else { 0 }
- let dpf_key_i: SPDPF::Key = {
- let fut = comm.receive(PARTY_1)?;
- fut.get()?
- };
- let t_after_dpf_keygen = Instant::now();
- // 3. Compute shares of <flag>, <loc>, i.e., if the address is present in the stash and if
- // so, where it is
- {
- let mut flag_share = F::ZERO;
- let mut location_share = F::ZERO;
- let mut j_as_field_element = F::ZERO;
- for j in 0..self.address_tag_list.len() {
- let dpf_value_j = SPDPF::evaluate_at(
- &dpf_key_i,
- self.address_tag_list[j] ^ address_tag_mask,
- );
- flag_share += dpf_value_j;
- location_share += j_as_field_element * dpf_value_j;
- j_as_field_element += F::ONE;
- }
- let t_after_compute_flag_loc = Instant::now();
- (
- flag_share,
- location_share,
- t_after_masked_address_tag,
- t_after_dpf_keygen,
- t_after_compute_flag_loc,
- )
- }
- }
- _ => panic!("invalid party id"),
- };
- // 4. Compute <loc> = if <flag> { <loc> } else { access_counter - 1 }
- let location_share = {
- let access_counter_share = if self.party_id == PARTY_1 {
- F::from_u128(self.access_counter as u128)
- } else {
- F::ZERO
- };
- SelectProtocol::select(comm, flag_share, location_share, access_counter_share)?
- };
- let t_after_location_share = Instant::now();
- // 5. Reshare <flag> among all three parties
- let flag_share = match self.party_id {
- PARTY_1 => {
- let flag_share = F::random(thread_rng());
- comm.send(PARTY_2, flag_share)?;
- flag_share
- }
- PARTY_2 => {
- let fut_1_2 = comm.receive::<F>(PARTY_1)?;
- flag_share - fut_1_2.get()?
- }
- _ => flag_share,
- };
- let t_after_flag_share = Instant::now();
- // 6. Read the value <val> from the stash (if <flag>) or read a zero value
- let (
- value_share,
- t_after_convert_to_replicated,
- t_after_masked_index,
- t_after_dpf_key_distr,
- ) = {
- // a) convert the stash into replicated secret sharing
- let fut_prev = comm.receive_previous::<Vec<F>>()?;
- comm.send_next(self.stash_values_share.to_vec())?;
- let stash_values_share_prev = fut_prev.get()?;
- let t_after_convert_to_replicated = Instant::now();
- // b) mask and reconstruct the stash index <loc>
- let index_bits = (self.access_counter as f64).log2().ceil() as u32;
- assert!(index_bits <= 16);
- let bit_mask = ((1 << index_bits) - 1) as u16;
- let (masked_loc, r_prev, r_next) =
- MaskIndexProtocol::mask_index(comm, index_bits, location_share)?;
- let t_after_masked_index = Instant::now();
- // c) use DPFs to read the stash value
- let fut_prev = comm.receive_previous::<SPDPF::Key>()?;
- let fut_next = comm.receive_next::<SPDPF::Key>()?;
- {
- let (dpf_key_prev, dpf_key_next) =
- SPDPF::generate_keys(1 << index_bits, masked_loc as u64, F::ONE);
- comm.send_previous(dpf_key_prev)?;
- comm.send_next(dpf_key_next)?;
- }
- let dpf_key_prev = fut_prev.get()?;
- let dpf_key_next = fut_next.get()?;
- let t_after_dpf_key_distr = Instant::now();
- let mut value_share = F::ZERO;
- for j in 0..self.access_counter {
- let index_prev = ((j as u16 + r_prev) & bit_mask) as u64;
- let index_next = ((j as u16 + r_next) & bit_mask) as u64;
- value_share +=
- SPDPF::evaluate_at(&dpf_key_prev, index_prev) * self.stash_values_share[j];
- value_share +=
- SPDPF::evaluate_at(&dpf_key_next, index_next) * stash_values_share_prev[j];
- }
- (
- value_share,
- t_after_convert_to_replicated,
- t_after_masked_index,
- t_after_dpf_key_distr,
- )
- };
- let t_after_dpf_eval = Instant::now();
- let runtimes = runtimes.map(|mut r| {
- r.record(
- ProtocolStep::ReadMaskedAddressTag,
- t_after_masked_address_tag - t_start,
- );
- r.record(
- ProtocolStep::ReadDpfKeyGen,
- t_after_dpf_keygen - t_after_masked_address_tag,
- );
- r.record(
- ProtocolStep::ReadLookupFlagLocation,
- t_after_compute_flag_loc - t_after_dpf_keygen,
- );
- r.record(
- ProtocolStep::ReadComputeLocation,
- t_after_location_share - t_after_compute_flag_loc,
- );
- r.record(
- ProtocolStep::ReadReshareFlag,
- t_after_flag_share - t_after_location_share,
- );
- r.record(
- ProtocolStep::ReadConvertToReplicated,
- t_after_convert_to_replicated - t_after_flag_share,
- );
- r.record(
- ProtocolStep::ReadComputeMaskedIndex,
- t_after_masked_index - t_after_convert_to_replicated,
- );
- r.record(
- ProtocolStep::ReadDpfKeyDistribution,
- t_after_dpf_key_distr - t_after_masked_index,
- );
- r.record(
- ProtocolStep::ReadDpfEvaluations,
- t_after_dpf_eval - t_after_dpf_key_distr,
- );
- r
- });
- self.state = State::AwaitingWrite;
- Ok((
- StashStateShare {
- flag: flag_share,
- location: location_share,
- value: value_share,
- },
- runtimes,
- ))
- }
- pub fn write_with_runtimes<C: AbstractCommunicator>(
- &mut self,
- comm: &mut C,
- instruction: InstructionShare<F>,
- stash_state: StashStateShare<F>,
- db_address_share: F,
- db_value_share: F,
- runtimes: Option<Runtimes>,
- ) -> Result<Option<Runtimes>, Error> {
- assert_eq!(self.state, State::AwaitingWrite);
- assert!(self.access_counter < self.stash_size);
- let t_start = Instant::now();
- // 1. Compute tag y := PRF(k, <db_adr>) such that P2, P3 obtain y.
- match self.party_id {
- PARTY_1 => {
- let doprf_p1 = self.doprf_party_1.as_mut().unwrap();
- // for now do preprocessing on the fly
- doprf_p1.preprocess(comm, 1)?;
- doprf_p1.eval(comm, 1, &[db_address_share])?;
- }
- PARTY_2 => {
- let address_tag: u64 = {
- let doprf_p2 = self.doprf_party_2.as_mut().unwrap();
- // for now do preprocessing on the fly
- doprf_p2.preprocess(comm, 1)?;
- let fut_3_2 = comm.receive(PARTY_3)?;
- doprf_p2.eval(comm, 1, &[db_address_share])?;
- fut_3_2.get()?
- };
- self.address_tag_list.push(address_tag);
- }
- PARTY_3 => {
- let address_tag: u64 = {
- let doprf_p3 = self.doprf_party_3.as_mut().unwrap();
- // for now do preprocessing on the fly
- doprf_p3.preprocess(comm, 1)?;
- let tag = doprf_p3.eval_to_uint(comm, 1, &[db_address_share])?[0];
- comm.send(PARTY_2, tag)?;
- tag
- };
- self.address_tag_list.push(address_tag);
- }
- _ => panic!("invalid party id"),
- }
- let t_after_address_tag = Instant::now();
- // 2. Insert new triple (<db_adr>, <db_val>, <db_val> into stash.
- self.stash_addresses_share.push(db_address_share);
- self.stash_values_share.push(db_value_share);
- self.stash_old_values_share.push(db_value_share);
- let t_after_store_triple = Instant::now();
- // 3. Update stash
- let previous_value_share =
- SelectProtocol::select(comm, stash_state.flag, stash_state.value, db_value_share)?;
- let t_after_select_previous_value = Instant::now();
- let value_share = SelectProtocol::select(
- comm,
- instruction.operation,
- instruction.value - previous_value_share,
- F::ZERO,
- )?;
- let t_after_select_value = Instant::now();
- let (t_after_masked_index, t_after_dpf_key_distr) = {
- // a) mask and reconstruct the stash index <loc>
- let index_bits = {
- let bits = usize::BITS - self.access_counter.leading_zeros();
- if bits > 0 {
- bits
- } else {
- 1
- }
- };
- assert!(index_bits <= 16);
- let bit_mask = ((1 << index_bits) - 1) as u16;
- let (masked_loc, r_prev, r_next) =
- MaskIndexProtocol::mask_index(comm, index_bits, stash_state.location)?;
- let t_after_masked_index = Instant::now();
- // b) use DPFs to read the stash value
- let fut_prev = comm.receive_previous::<SPDPF::Key>()?;
- let fut_next = comm.receive_next::<SPDPF::Key>()?;
- {
- let (dpf_key_prev, dpf_key_next) =
- SPDPF::generate_keys(1 << index_bits, masked_loc as u64, value_share);
- comm.send_previous(dpf_key_prev)?;
- comm.send_next(dpf_key_next)?;
- }
- let dpf_key_prev = fut_prev.get()?;
- let dpf_key_next = fut_next.get()?;
- let t_after_dpf_key_distr = Instant::now();
- for j in 0..=self.access_counter {
- let index_prev = ((j as u16).wrapping_add(r_prev) & bit_mask) as u64;
- let index_next = ((j as u16).wrapping_add(r_next) & bit_mask) as u64;
- self.stash_values_share[j] += SPDPF::evaluate_at(&dpf_key_prev, index_prev);
- self.stash_values_share[j] += SPDPF::evaluate_at(&dpf_key_next, index_next);
- }
- (t_after_masked_index, t_after_dpf_key_distr)
- };
- let t_after_dpf_eval = Instant::now();
- self.access_counter += 1;
- self.state = if self.access_counter == self.stash_size {
- State::AccessesExhausted
- } else {
- State::AwaitingRead
- };
- let runtimes = runtimes.map(|mut r| {
- r.record(ProtocolStep::WriteAddressTag, t_after_address_tag - t_start);
- r.record(
- ProtocolStep::WriteStoreTriple,
- t_after_store_triple - t_after_address_tag,
- );
- r.record(
- ProtocolStep::WriteSelectPreviousValue,
- t_after_select_previous_value - t_after_store_triple,
- );
- r.record(
- ProtocolStep::WriteSelectValue,
- t_after_select_value - t_after_select_previous_value,
- );
- r.record(
- ProtocolStep::WriteComputeMaskedIndex,
- t_after_masked_index - t_after_select_value,
- );
- r.record(
- ProtocolStep::WriteDpfKeyDistribution,
- t_after_dpf_key_distr - t_after_masked_index,
- );
- r.record(
- ProtocolStep::WriteDpfEvaluations,
- t_after_dpf_eval - t_after_dpf_key_distr,
- );
- r
- });
- Ok(runtimes)
- }
- }
- impl<F, SPDPF> Stash<F> for StashProtocol<F, SPDPF>
- where
- F: PrimeField + LegendreSymbol + Serializable,
- SPDPF: SinglePointDpf<Value = F>,
- SPDPF::Key: Serializable,
- {
- fn get_party_id(&self) -> usize {
- self.party_id
- }
- fn get_stash_size(&self) -> usize {
- self.stash_size
- }
- fn get_access_counter(&self) -> usize {
- self.access_counter
- }
- fn reset(&mut self) {
- *self = Self::new(self.party_id, self.stash_size);
- }
- fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
- self.init_with_runtimes(comm, None).map(|_| ())
- }
- fn read<C: AbstractCommunicator>(
- &mut self,
- comm: &mut C,
- instruction: InstructionShare<F>,
- ) -> Result<StashStateShare<F>, Error> {
- self.read_with_runtimes(comm, instruction, None)
- .map(|x| x.0)
- }
- fn write<C: AbstractCommunicator>(
- &mut self,
- comm: &mut C,
- instruction: InstructionShare<F>,
- stash_state: StashStateShare<F>,
- db_address_share: F,
- db_value_share: F,
- ) -> Result<(), Error> {
- self.write_with_runtimes(
- comm,
- instruction,
- stash_state,
- db_address_share,
- db_value_share,
- None,
- )
- .map(|_| ())
- }
- fn get_stash_share(&self) -> (&[F], &[F], &[F]) {
- (
- &self.stash_addresses_share,
- &self.stash_values_share,
- &self.stash_old_values_share,
- )
- }
- }
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::Operation;
    use communicator::unix::make_unix_communicators;
    use dpf::spdpf::DummySpDpf;
    use ff::Field;
    use std::thread;
    use utils::field::Fp;

    /// Runs `Stash::init` for one party on its own thread and hands party and communicator back.
    fn run_init<F>(
        mut stash_party: impl Stash<F> + Send + 'static,
        mut comm: impl AbstractCommunicator + Send + 'static,
    ) -> thread::JoinHandle<(impl Stash<F>, impl AbstractCommunicator)>
    where
        F: PrimeField + LegendreSymbol,
    {
        thread::spawn(move || {
            stash_party.init(&mut comm).unwrap();
            (stash_party, comm)
        })
    }

    /// Runs `Stash::read` for one party on its own thread and also returns its state share.
    fn run_read<F>(
        mut stash_party: impl Stash<F> + Send + 'static,
        mut comm: impl AbstractCommunicator + Send + 'static,
        instruction: InstructionShare<F>,
    ) -> thread::JoinHandle<(impl Stash<F>, impl AbstractCommunicator, StashStateShare<F>)>
    where
        F: PrimeField + LegendreSymbol,
    {
        thread::spawn(move || {
            let result = stash_party.read(&mut comm, instruction);
            (stash_party, comm, result.unwrap())
        })
    }

    /// Runs `Stash::write` for one party on its own thread.
    fn run_write<F>(
        mut stash_party: impl Stash<F> + Send + 'static,
        mut comm: impl AbstractCommunicator + Send + 'static,
        instruction: InstructionShare<F>,
        stash_state: StashStateShare<F>,
        db_address_share: F,
        db_value_share: F,
    ) -> thread::JoinHandle<(impl Stash<F>, impl AbstractCommunicator)>
    where
        F: PrimeField + LegendreSymbol,
    {
        thread::spawn(move || {
            stash_party
                .write(
                    &mut comm,
                    instruction,
                    stash_state,
                    db_address_share,
                    db_value_share,
                )
                .unwrap();
            (stash_party, comm)
        })
    }

    /// Splits an instruction into three additive shares.
    ///
    /// P1 receives the plain values and P2, P3 receive zeros.
    fn share_instruction(
        operation: Fp,
        address: u128,
        value: u128,
    ) -> (InstructionShare<Fp>, InstructionShare<Fp>, InstructionShare<Fp>) {
        let share_1 = InstructionShare {
            operation,
            address: Fp::from_u128(address),
            value: Fp::from_u128(value),
        };
        let zero_share = InstructionShare {
            operation: Fp::ZERO,
            address: Fp::ZERO,
            value: Fp::ZERO,
        };
        (share_1, zero_share, zero_share)
    }

    /// Asserts the stash state after `num_accesses` accesses.
    ///
    /// Every party must have performed `num_accesses` accesses and hold stash shares of that
    /// length, and the first entry must reconstruct to (`address`, `value`).
    fn assert_stash_contents(
        party_1: &impl Stash<Fp>,
        party_2: &impl Stash<Fp>,
        party_3: &impl Stash<Fp>,
        num_accesses: usize,
        address: u128,
        value: u128,
    ) {
        assert_eq!(party_1.get_access_counter(), num_accesses);
        assert_eq!(party_2.get_access_counter(), num_accesses);
        assert_eq!(party_3.get_access_counter(), num_accesses);
        let shares = [
            party_1.get_stash_share(),
            party_2.get_stash_share(),
            party_3.get_stash_share(),
        ];
        for &(adrs, vals, old_vals) in shares.iter() {
            assert_eq!(adrs.len(), num_accesses);
            assert_eq!(vals.len(), num_accesses);
            assert_eq!(old_vals.len(), num_accesses);
        }
        let address_sum = shares
            .iter()
            .fold(Fp::ZERO, |acc, &(adrs, _, _)| acc + adrs[0]);
        assert_eq!(address_sum, Fp::from_u128(address));
        let value_sum = shares
            .iter()
            .fold(Fp::ZERO, |acc, &(_, vals, _)| acc + vals[0]);
        assert_eq!(value_sum, Fp::from_u128(value));
    }

    #[test]
    fn test_stash() {
        type SPDPF = DummySpDpf<Fp>;
        let stash_size = 128;
        let mut num_accesses = 0;

        let party_1 = StashProtocol::<Fp, SPDPF>::new(PARTY_1, stash_size);
        let party_2 = StashProtocol::<Fp, SPDPF>::new(PARTY_2, stash_size);
        let party_3 = StashProtocol::<Fp, SPDPF>::new(PARTY_3, stash_size);
        assert_eq!(party_1.get_party_id(), PARTY_1);
        assert_eq!(party_2.get_party_id(), PARTY_2);
        assert_eq!(party_3.get_party_id(), PARTY_3);
        assert_eq!(party_1.get_stash_size(), stash_size);
        assert_eq!(party_2.get_stash_size(), stash_size);
        assert_eq!(party_3.get_stash_size(), stash_size);

        let (comm_3, comm_2, comm_1) = {
            let mut comms = make_unix_communicators(3);
            (
                comms.pop().unwrap(),
                comms.pop().unwrap(),
                comms.pop().unwrap(),
            )
        };

        let h1 = run_init(party_1, comm_1);
        let h2 = run_init(party_2, comm_2);
        let h3 = run_init(party_3, comm_3);
        let (party_1, comm_1) = h1.join().unwrap();
        let (party_2, comm_2) = h2.join().unwrap();
        let (party_3, comm_3) = h3.join().unwrap();
        assert_eq!(party_1.get_access_counter(), 0);
        assert_eq!(party_2.get_access_counter(), 0);
        assert_eq!(party_3.get_access_counter(), 0);

        // write a value 42 to address adr = 3
        let value = 42;
        let address = 3;
        let (inst_w3_1, inst_w3_2, inst_w3_3) =
            share_instruction(Operation::Write.encode(), address, value);
        let h1 = run_read(party_1, comm_1, inst_w3_1);
        let h2 = run_read(party_2, comm_2, inst_w3_2);
        let h3 = run_read(party_3, comm_3, inst_w3_3);
        let (party_1, comm_1, state_1) = h1.join().unwrap();
        let (party_2, comm_2, state_2) = h2.join().unwrap();
        let (party_3, comm_3, state_3) = h3.join().unwrap();
        // since the stash is empty, st.flag must be zero
        assert_eq!(state_1.flag + state_2.flag + state_3.flag, Fp::ZERO);
        assert_eq!(
            state_1.location + state_2.location + state_3.location,
            Fp::ZERO
        );
        // each party passes its own state share to write
        let h1 = run_write(
            party_1,
            comm_1,
            inst_w3_1,
            state_1,
            inst_w3_1.address,
            Fp::from_u128(0x71),
        );
        let h2 = run_write(
            party_2,
            comm_2,
            inst_w3_2,
            state_2,
            inst_w3_2.address,
            Fp::from_u128(0x72),
        );
        let h3 = run_write(
            party_3,
            comm_3,
            inst_w3_3,
            state_3,
            inst_w3_3.address,
            Fp::from_u128(0x73),
        );
        let (party_1, comm_1) = h1.join().unwrap();
        let (party_2, comm_2) = h2.join().unwrap();
        let (party_3, comm_3) = h3.join().unwrap();
        num_accesses += 1;
        assert_stash_contents(&party_1, &party_2, &party_3, num_accesses, address, value);

        // read again from address adr = 3, we should get the value 42 back
        let (inst_r3_1, inst_r3_2, inst_r3_3) =
            share_instruction(Operation::Read.encode(), 3, 0);
        let h1 = run_read(party_1, comm_1, inst_r3_1);
        let h2 = run_read(party_2, comm_2, inst_r3_2);
        let h3 = run_read(party_3, comm_3, inst_r3_3);
        let (party_1, comm_1, state_1) = h1.join().unwrap();
        let (party_2, comm_2, state_2) = h2.join().unwrap();
        let (party_3, comm_3, state_3) = h3.join().unwrap();
        let st_flag = state_1.flag + state_2.flag + state_3.flag;
        let st_location = state_1.location + state_2.location + state_3.location;
        let st_value = state_1.value + state_2.value + state_3.value;
        assert_eq!(st_flag, Fp::ONE);
        assert_eq!(st_location, Fp::from_u128(0));
        assert_eq!(st_value, Fp::from_u128(value));
        let h1 = run_write(
            party_1,
            comm_1,
            inst_r3_1,
            state_1,
            Fp::from_u128(0x83),
            Fp::from_u128(0x93),
        );
        let h2 = run_write(
            party_2,
            comm_2,
            inst_r3_2,
            state_2,
            Fp::from_u128(0x83),
            Fp::from_u128(0x93),
        );
        let h3 = run_write(
            party_3,
            comm_3,
            inst_r3_3,
            state_3,
            Fp::from_u128(0x83),
            Fp::from_u128(0x93),
        );
        let (party_1, comm_1) = h1.join().unwrap();
        let (party_2, comm_2) = h2.join().unwrap();
        let (party_3, comm_3) = h3.join().unwrap();
        num_accesses += 1;
        assert_stash_contents(&party_1, &party_2, &party_3, num_accesses, address, value);

        // now write a value 0x1337 to address adr = 3
        let old_value = value;
        let value = 0x1337;
        let address = 3;
        let (inst_w3_1, inst_w3_2, inst_w3_3) =
            share_instruction(Operation::Write.encode(), address, value);
        let h1 = run_read(party_1, comm_1, inst_w3_1);
        let h2 = run_read(party_2, comm_2, inst_w3_2);
        let h3 = run_read(party_3, comm_3, inst_w3_3);
        let (party_1, comm_1, state_1) = h1.join().unwrap();
        let (party_2, comm_2, state_2) = h2.join().unwrap();
        let (party_3, comm_3, state_3) = h3.join().unwrap();
        // since we already wrote to the address, it should be present in the stash
        assert_eq!(state_1.flag + state_2.flag + state_3.flag, Fp::ONE);
        assert_eq!(
            state_1.location + state_2.location + state_3.location,
            Fp::ZERO
        );
        assert_eq!(
            state_1.value + state_2.value + state_3.value,
            Fp::from_u128(old_value)
        );
        let h1 = run_write(
            party_1,
            comm_1,
            inst_w3_1,
            state_1,
            Fp::from_u128(0x61),
            Fp::from_u128(0x71),
        );
        let h2 = run_write(
            party_2,
            comm_2,
            inst_w3_2,
            state_2,
            Fp::from_u128(0x62),
            Fp::from_u128(0x72),
        );
        let h3 = run_write(
            party_3,
            comm_3,
            inst_w3_3,
            state_3,
            Fp::from_u128(0x63),
            Fp::from_u128(0x73),
        );
        let (party_1, comm_1) = h1.join().unwrap();
        let (party_2, comm_2) = h2.join().unwrap();
        let (party_3, comm_3) = h3.join().unwrap();
        num_accesses += 1;
        assert_stash_contents(&party_1, &party_2, &party_3, num_accesses, address, value);

        // read again from address adr = 3, we should get the value 0x1337 back
        let (inst_r3_1, inst_r3_2, inst_r3_3) =
            share_instruction(Operation::Read.encode(), 3, 0);
        let h1 = run_read(party_1, comm_1, inst_r3_1);
        let h2 = run_read(party_2, comm_2, inst_r3_2);
        let h3 = run_read(party_3, comm_3, inst_r3_3);
        let (party_1, comm_1, state_1) = h1.join().unwrap();
        let (party_2, comm_2, state_2) = h2.join().unwrap();
        let (party_3, comm_3, state_3) = h3.join().unwrap();
        let st_flag = state_1.flag + state_2.flag + state_3.flag;
        let st_location = state_1.location + state_2.location + state_3.location;
        let st_value = state_1.value + state_2.value + state_3.value;
        assert_eq!(st_flag, Fp::ONE);
        assert_eq!(st_location, Fp::from_u128(0));
        assert_eq!(st_value, Fp::from_u128(value));
        let h1 = run_write(
            party_1,
            comm_1,
            inst_r3_1,
            state_1,
            Fp::from_u128(0x83),
            Fp::from_u128(0x93),
        );
        let h2 = run_write(
            party_2,
            comm_2,
            inst_r3_2,
            state_2,
            Fp::from_u128(0x83),
            Fp::from_u128(0x93),
        );
        let h3 = run_write(
            party_3,
            comm_3,
            inst_r3_3,
            state_3,
            Fp::from_u128(0x83),
            Fp::from_u128(0x93),
        );
        let (party_1, _comm_1) = h1.join().unwrap();
        let (party_2, _comm_2) = h2.join().unwrap();
        let (party_3, _comm_3) = h3.join().unwrap();
        num_accesses += 1;
        assert_stash_contents(&party_1, &party_2, &party_3, num_accesses, address, value);
    }
}
|