Browse Source

stash: combine parties into single implementation

Lennart Braun 2 years ago
parent
commit
65b4d824e2
1 changed files with 194 additions and 567 deletions
  1. 194 567
      oram/src/stash.rs

+ 194 - 567
oram/src/stash.rs

@@ -15,7 +15,6 @@ use std::marker::PhantomData;
 use utils::field::LegendreSymbol;
 
 type BitVec = bitvec::vec::BitVec<u8>;
-// type BitSlice = bitvec::slice::BitSlice<u8>;
 
 #[derive(Clone, Copy, Debug, Default)]
 pub struct StashEntryShare<F: PrimeField> {
@@ -184,53 +183,70 @@ where
     Ok(())
 }
 
-pub struct StashParty1<F, SPDPF>
+pub struct StashProtocol<F, SPDPF>
 where
     F: PrimeField + LegendreSymbol + Serializable,
     SPDPF: SinglePointDpf<Value = F>,
 {
+    party_id: usize,
     stash_size: usize,
     access_counter: usize,
     state: State,
     stash_addresses_share: Vec<F>,
     stash_values_share: Vec<F>,
     stash_old_values_share: Vec<F>,
+    address_tag_list: Vec<u64>,
     doprf_party_1: Option<DOPrfParty1<F>>,
+    doprf_party_2: Option<DOPrfParty2<F>>,
+    doprf_party_3: Option<DOPrfParty3<F>>,
     masked_doprf_party_1: Option<MaskedDOPrfParty1<F>>,
+    masked_doprf_party_2: Option<MaskedDOPrfParty2<F>>,
+    masked_doprf_party_3: Option<MaskedDOPrfParty3<F>>,
     _phantom: PhantomData<SPDPF>,
 }
 
-impl<F, SPDPF> StashParty1<F, SPDPF>
+impl<F, SPDPF> StashProtocol<F, SPDPF>
 where
     F: PrimeField + LegendreSymbol + Serializable,
     SPDPF: SinglePointDpf<Value = F>,
 {
-    pub fn new(stash_size: usize) -> Self {
+    pub fn new(party_id: usize, stash_size: usize) -> Self {
+        assert!(party_id < 3);
         assert!(stash_size > 0);
         assert!(compute_stash_prf_output_bitsize(stash_size) <= 64);
 
         Self {
+            party_id,
             stash_size,
             access_counter: 0,
             state: State::New,
             stash_addresses_share: Vec::with_capacity(stash_size),
             stash_values_share: Vec::with_capacity(stash_size),
             stash_old_values_share: Vec::with_capacity(stash_size),
+            address_tag_list: if party_id == PARTY_1 {
+                Default::default()
+            } else {
+                Vec::with_capacity(stash_size)
+            },
             doprf_party_1: None,
+            doprf_party_2: None,
+            doprf_party_3: None,
             masked_doprf_party_1: None,
+            masked_doprf_party_2: None,
+            masked_doprf_party_3: None,
             _phantom: PhantomData,
         }
     }
 }
 
-impl<F, SPDPF> Stash<F> for StashParty1<F, SPDPF>
+impl<F, SPDPF> Stash<F> for StashProtocol<F, SPDPF>
 where
     F: PrimeField + LegendreSymbol + Serializable,
     SPDPF: SinglePointDpf<Value = F>,
     SPDPF::Key: Serializable,
 {
     fn get_party_id(&self) -> usize {
-        1
+        self.party_id
     }
 
     fn get_stash_size(&self) -> usize {
@@ -246,16 +262,34 @@ where
 
         let prf_output_bitsize = compute_stash_prf_output_bitsize(self.stash_size);
         let legendre_prf_key = LegendrePrf::<F>::key_gen(prf_output_bitsize);
-        self.doprf_party_1 = Some(DOPrfParty1::from_legendre_prf_key(legendre_prf_key.clone()));
-        self.masked_doprf_party_1 =
-            Some(MaskedDOPrfParty1::from_legendre_prf_key(legendre_prf_key));
 
         // run DOPRF initilization
-        {
-            let doprf_p1 = self.doprf_party_1.as_mut().unwrap();
-            doprf_p1.init(comm)?;
-            let mdoprf_p1 = self.masked_doprf_party_1.as_mut().unwrap();
-            mdoprf_p1.init(comm)?;
+        match self.party_id {
+            PARTY_1 => {
+                let mut doprf_p1 = DOPrfParty1::from_legendre_prf_key(legendre_prf_key.clone());
+                let mut mdoprf_p1 = MaskedDOPrfParty1::from_legendre_prf_key(legendre_prf_key);
+                doprf_p1.init(comm)?;
+                mdoprf_p1.init(comm)?;
+                self.doprf_party_1 = Some(doprf_p1);
+                self.masked_doprf_party_1 = Some(mdoprf_p1);
+            }
+            PARTY_2 => {
+                let mut doprf_p2 = DOPrfParty2::new(prf_output_bitsize);
+                let mut mdoprf_p2 = MaskedDOPrfParty2::new(prf_output_bitsize);
+                doprf_p2.init(comm)?;
+                mdoprf_p2.init(comm)?;
+                self.doprf_party_2 = Some(doprf_p2);
+                self.masked_doprf_party_2 = Some(mdoprf_p2);
+            }
+            PARTY_3 => {
+                let mut doprf_p3 = DOPrfParty3::new(prf_output_bitsize);
+                let mut mdoprf_p3 = MaskedDOPrfParty3::new(prf_output_bitsize);
+                doprf_p3.init(comm)?;
+                mdoprf_p3.init(comm)?;
+                self.doprf_party_3 = Some(doprf_p3);
+                self.masked_doprf_party_3 = Some(mdoprf_p3);
+            }
+            _ => panic!("invalid party id"),
         }
 
         // panic!("not implemented");
@@ -281,40 +315,100 @@ where
             });
         }
 
-        // 1. Compute tag y := PRF(k, <I.adr>) such that P1 obtains y + r and P2, P3 obtain the mask r.
-        let masked_address_tag = {
-            let mdoprf_p1 = self.masked_doprf_party_1.as_mut().unwrap();
-            // for now do preprocessing on the fly
-            mdoprf_p1.preprocess(comm, 1)?;
-            let mut masked_tag = mdoprf_p1.eval(comm, 1, &[instruction.address])?;
-            bits_to_u64(masked_tag.pop().unwrap())
+        let (flag_share, location_share) = match self.party_id {
+            PARTY_1 => {
+                // 1. Compute tag y := PRF(k, <I.adr>) such that P1 obtains y + r and P2, P3 obtain the mask r.
+                let masked_address_tag = {
+                    let mdoprf_p1 = self.masked_doprf_party_1.as_mut().unwrap();
+                    // for now do preprocessing on the fly
+                    mdoprf_p1.preprocess(comm, 1)?;
+                    let mut masked_tag = mdoprf_p1.eval(comm, 1, &[instruction.address])?;
+                    bits_to_u64(masked_tag.pop().unwrap())
+                };
+
+                // 2. Create and send DPF keys for the function f(x) = if x = y { 1 } else { 0 }
+                {
+                    let domain_size = 1 << compute_stash_prf_output_bitsize(self.stash_size);
+                    let (dpf_key_2, dpf_key_3) =
+                        SPDPF::generate_keys(domain_size, masked_address_tag, F::ONE);
+                    comm.send(PARTY_2, dpf_key_2)?;
+                    comm.send(PARTY_3, dpf_key_3)?;
+                }
+
+                // 3. The other parties compute shares of <flag>, <loc>, i.e., if the address is present in
+                //    the stash and if so, where it is. We just take 0s as our shares.
+                (F::ZERO, F::ZERO)
+            }
+            PARTY_2 | PARTY_3 => {
+                // 1. Compute tag y := PRF(k, <I.adr>) such that P1 obtains y + r and P2, P3 obtain the mask r.
+                let address_tag_mask = {
+                    let mut mask = match self.party_id {
+                        PARTY_2 => {
+                            let mdoprf_p2 = self.masked_doprf_party_2.as_mut().unwrap();
+                            // for now do preprocessing on the fly
+                            mdoprf_p2.preprocess(comm, 1)?;
+                            mdoprf_p2.eval(comm, 1, &[instruction.address])?
+                        }
+                        PARTY_3 => {
+                            let mdoprf_p3 = self.masked_doprf_party_3.as_mut().unwrap();
+                            // for now do preprocessing on the fly
+                            mdoprf_p3.preprocess(comm, 1)?;
+                            mdoprf_p3.eval(comm, 1, &[instruction.address])?
+                        }
+                        _ => panic!("invalid party id"),
+                    };
+                    bits_to_u64(mask.pop().unwrap())
+                };
+
+                // 2. Receive DPF key for the function f(x) = if x = y { 1 } else { 0 }
+                let dpf_key_i: SPDPF::Key = {
+                    let fut = comm.receive(PARTY_1)?;
+                    fut.get()?
+                };
+
+                // 3. Compute shares of <flag>, <loc>, i.e., if the address is present in the stash and if
+                //    so, where it is
+                {
+                    let mut flag_share = F::ZERO;
+                    let mut location_share = F::ZERO;
+                    let mut j_as_field_element = F::ZERO;
+                    for j in 0..self.address_tag_list.len() {
+                        let dpf_value_j = SPDPF::evaluate_at(
+                            &dpf_key_i,
+                            self.address_tag_list[j] ^ address_tag_mask,
+                        );
+                        flag_share += dpf_value_j;
+                        location_share += j_as_field_element * dpf_value_j;
+                        j_as_field_element += F::ONE;
+                    }
+                    (flag_share, location_share)
+                }
+            }
+            _ => panic!("invalid party id"),
         };
 
-        // 2. Create and send DPF keys for the function f(x) = if x = y { 1 } else { 0 }
-        {
-            let domain_size = 1 << compute_stash_prf_output_bitsize(self.stash_size);
-            let (dpf_key_2, dpf_key_3) =
-                SPDPF::generate_keys(domain_size, masked_address_tag, F::ONE);
-            comm.send(PARTY_2, dpf_key_2)?;
-            comm.send(PARTY_3, dpf_key_3)?;
-        }
-
-        // 3. The other parties compute shares of <flag>, <loc>, i.e., if the address is present in
-        //    the stash and if so, where it is
-
         // 4. Compute <loc> = if <flag> { <loc> } else { access_counter - 1 }
-        let location_share = SelectProtocol::select(
-            comm,
-            F::ZERO,
-            F::ZERO,
-            F::from_u128(self.access_counter as u128),
-        )?;
+        let location_share = {
+            let access_counter_share = if self.party_id == PARTY_1 {
+                F::from_u128(self.access_counter as u128)
+            } else {
+                F::ZERO
+            };
+            SelectProtocol::select(comm, flag_share, location_share, access_counter_share)?
+        };
 
         // 5. Reshare <flag> among all three parties
-        let flag_share = {
-            let flag_share = F::random(thread_rng());
-            comm.send(PARTY_2, flag_share)?;
-            flag_share
+        let flag_share = match self.party_id {
+            PARTY_1 => {
+                let flag_share = F::random(thread_rng());
+                comm.send(PARTY_2, flag_share)?;
+                flag_share
+            }
+            PARTY_2 => {
+                let fut_1_2 = comm.receive::<F>(PARTY_1)?;
+                flag_share - fut_1_2.get()?
+            }
+            _ => flag_share,
         };
 
         // 6. Read the value <val> from the stash (if <flag>) or read a zero value
@@ -325,7 +419,6 @@ where
             &self.stash_values_share,
         )?;
 
-        // TODO: handle an empty stash differently
         self.state = State::AwaitingWrite;
         Ok(StashStateShare {
             flag: flag_share,
@@ -346,12 +439,38 @@ where
         assert!(self.access_counter < self.stash_size);
 
         // 1. Compute tag y := PRF(k, <db_adr>) such that P2, P3 obtain y.
-        {
-            let doprf_p1 = self.doprf_party_1.as_mut().unwrap();
-            // for now do preprocessing on the fly
-            doprf_p1.preprocess(comm, 1)?;
-            doprf_p1.eval(comm, 1, &[db_address_share])?;
-        };
+        match self.party_id {
+            PARTY_1 => {
+                let doprf_p1 = self.doprf_party_1.as_mut().unwrap();
+                // for now do preprocessing on the fly
+                doprf_p1.preprocess(comm, 1)?;
+                doprf_p1.eval(comm, 1, &[db_address_share])?;
+            }
+            PARTY_2 => {
+                let address_tag: u64 = {
+                    let doprf_p2 = self.doprf_party_2.as_mut().unwrap();
+                    // for now do preprocessing on the fly
+                    doprf_p2.preprocess(comm, 1)?;
+                    let fut_3_2 = comm.receive(PARTY_3)?;
+                    doprf_p2.eval(comm, 1, &[db_address_share])?;
+                    fut_3_2.get()?
+                };
+                self.address_tag_list.push(address_tag);
+            }
+            PARTY_3 => {
+                let address_tag: u64 = {
+                    let doprf_p3 = self.doprf_party_3.as_mut().unwrap();
+                    // for now do preprocessing on the fly
+                    doprf_p3.preprocess(comm, 1)?;
+                    let mut tag = doprf_p3.eval(comm, 1, &[db_address_share])?;
+                    let tag = bits_to_u64(tag.pop().unwrap());
+                    comm.send(PARTY_2, tag)?;
+                    tag
+                };
+                self.address_tag_list.push(address_tag);
+            }
+            _ => panic!("invalid party id"),
+        }
 
         // 2. Insert new triple (<db_adr>, <db_val>, <db_val> into stash.
         self.stash_addresses_share.push(db_address_share);
@@ -359,533 +478,41 @@ where
         self.stash_old_values_share.push(db_value_share);
 
         // 3. Update stash
-        // - if I.op = write, we want to write I.val to index loc
-        // - if I.op = read, we need to write db_val to index c
-        //   (since that has been done already in step 2, it is essentially a no-op)
-        {
-            let previous_value_share =
-                SelectProtocol::select(comm, stash_state.flag, stash_state.value, db_value_share)?;
-            let location_share =
-                SelectProtocol::select(comm, instruction.operation, stash_state.location, -F::ONE)?;
-            let value_share = SelectProtocol::select(
+        let previous_value_share =
+            SelectProtocol::select(comm, stash_state.flag, stash_state.value, db_value_share)?;
+        let location_share = {
+            let invalid_location_share = if self.party_id == PARTY_1 {
+                -F::ONE
+            } else {
+                F::ZERO
+            };
+            SelectProtocol::select(
                 comm,
                 instruction.operation,
-                instruction.value - previous_value_share,
-                F::ZERO,
-            )?;
-            // eprintln!(
-            //     "P{}: prev_val = {:?}, loc = {:?}, upd = {:?}, vals = {:?}",
-            //     comm.get_my_id() + 1,
-            //     previous_value_share,
-            //     location_share,
-            //     value_share,
-            //     self.stash_values_share
-            // );
-            stash_write_value::<C, F, SPDPF>(
-                comm,
-                self.access_counter,
-                location_share,
-                value_share,
-                &mut self.stash_values_share,
-            )?;
-            // eprintln!(
-            //     "P{}: prev_val = {:?}, loc = {:?}, upd = {:?}, vals = {:?}",
-            //     comm.get_my_id() + 1,
-            //     previous_value_share,
-            //     location_share,
-            //     value_share,
-            //     self.stash_values_share
-            // );
-        }
-
-        // todo!("not implemented");
-        self.access_counter += 1;
-        self.state = if self.access_counter == self.stash_size {
-            State::AccessesExhausted
-        } else {
-            State::AwaitingRead
+                stash_state.location,
+                invalid_location_share,
+            )?
         };
-        Ok(())
-    }
-
-    fn get_stash_share(&self) -> (&[F], &[F], &[F]) {
-        (
-            &self.stash_addresses_share,
-            &self.stash_values_share,
-            &self.stash_old_values_share,
-        )
-    }
-}
-
-pub struct StashParty2<F, SPDPF>
-where
-    F: PrimeField + LegendreSymbol + Serializable,
-    SPDPF: SinglePointDpf<Value = F>,
-{
-    stash_size: usize,
-    access_counter: usize,
-    state: State,
-    stash_addresses_share: Vec<F>,
-    stash_values_share: Vec<F>,
-    stash_old_values_share: Vec<F>,
-    address_tag_list: Vec<u64>,
-    doprf_party_2: Option<DOPrfParty2<F>>,
-    masked_doprf_party_2: Option<MaskedDOPrfParty2<F>>,
-    _phantom: PhantomData<SPDPF>,
-}
-
-impl<F, SPDPF> StashParty2<F, SPDPF>
-where
-    F: PrimeField + LegendreSymbol + Serializable,
-    SPDPF: SinglePointDpf<Value = F>,
-{
-    pub fn new(stash_size: usize) -> Self {
-        assert!(stash_size > 0);
-        assert!(compute_stash_prf_output_bitsize(stash_size) <= 64);
-
-        Self {
-            stash_size,
-            access_counter: 0,
-            state: State::New,
-            stash_addresses_share: Vec::with_capacity(stash_size),
-            stash_values_share: Vec::with_capacity(stash_size),
-            stash_old_values_share: Vec::with_capacity(stash_size),
-            address_tag_list: Vec::with_capacity(stash_size),
-            doprf_party_2: None,
-            masked_doprf_party_2: None,
-            _phantom: PhantomData,
-        }
-    }
-}
-
-impl<F, SPDPF> Stash<F> for StashParty2<F, SPDPF>
-where
-    F: PrimeField + LegendreSymbol + Serializable,
-    SPDPF: SinglePointDpf<Value = F>,
-    SPDPF::Key: Serializable,
-{
-    fn get_party_id(&self) -> usize {
-        2
-    }
-
-    fn get_stash_size(&self) -> usize {
-        self.stash_size
-    }
-
-    fn get_access_counter(&self) -> usize {
-        self.access_counter
-    }
-
-    fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
-        assert_eq!(self.state, State::New);
-
-        let prf_output_bitsize = compute_stash_prf_output_bitsize(self.stash_size);
-        self.doprf_party_2 = Some(DOPrfParty2::new(prf_output_bitsize));
-        self.masked_doprf_party_2 = Some(MaskedDOPrfParty2::new(prf_output_bitsize));
-
-        // run DOPRF initilization
-        {
-            let doprf_p2 = self.doprf_party_2.as_mut().unwrap();
-            doprf_p2.init(comm)?;
-            let mdoprf_p2 = self.masked_doprf_party_2.as_mut().unwrap();
-            mdoprf_p2.init(comm)?;
-        }
-
-        // panic!("not implemented");
-        self.state = State::AwaitingRead;
-        Ok(())
-    }
-
-    fn read<C: AbstractCommunicator>(
-        &mut self,
-        comm: &mut C,
-        instruction: InstructionShare<F>,
-    ) -> Result<StashStateShare<F>, Error> {
-        assert_eq!(self.state, State::AwaitingRead);
-        assert!(self.access_counter < self.stash_size);
-
-        // 0. If the stash is empty, we are done
-        if self.access_counter == 0 {
-            self.state = State::AwaitingWrite;
-            return Ok(StashStateShare {
-                flag: F::ZERO,
-                location: F::ZERO,
-                value: F::ZERO,
-            });
-        }
-
-        // 1. Compute tag y := PRF(k, <I.adr>) such that P1 obtains y + r and P2, P3 obtain the mask r.
-        let address_tag_mask = {
-            let mdoprf_p2 = self.masked_doprf_party_2.as_mut().unwrap();
-            // for now do preprocessing on the fly
-            mdoprf_p2.preprocess(comm, 1)?;
-            let mut mask = mdoprf_p2.eval(comm, 1, &[instruction.address])?;
-            bits_to_u64(mask.pop().unwrap())
-        };
-
-        // 2. Receive DPF key for the function f(x) = if x = y { 1 } else { 0 }
-        let dpf_key_2: SPDPF::Key = {
-            let fut = comm.receive(PARTY_1)?;
-            fut.get()?
-        };
-
-        // 3. Compute shares of <flag>, <loc>, i.e., if the address is present in the stash and if
-        //    so, where it is
-        let (flag_share, location_share) = {
-            let mut flag_share = F::ZERO;
-            let mut location_share = F::ZERO;
-            let mut j_as_field_element = F::ZERO;
-            for j in 0..self.address_tag_list.len() {
-                let dpf_value_j =
-                    SPDPF::evaluate_at(&dpf_key_2, self.address_tag_list[j] ^ address_tag_mask);
-                flag_share += dpf_value_j;
-                location_share += j_as_field_element * dpf_value_j;
-                j_as_field_element += F::ONE;
-            }
-            (flag_share, location_share)
-        };
-
-        // 4. Compute <loc> = if <flag> { <loc> } else { access_counter - 1 }
-        let location_share = SelectProtocol::select(comm, flag_share, location_share, F::ZERO)?;
-
-        // 5. Reshare <flag> among all three parties
-        let flag_share = {
-            let fut_1_2 = comm.receive::<F>(PARTY_1)?;
-            flag_share - fut_1_2.get()?
-        };
-
-        // 6. Read the value <val> from the stash (if <flag>) or read a zero value
-        let value_share = stash_read_value::<C, F, SPDPF>(
+        let value_share = SelectProtocol::select(
             comm,
-            self.access_counter,
-            location_share,
-            &self.stash_values_share,
+            instruction.operation,
+            instruction.value - previous_value_share,
+            F::ZERO,
         )?;
-
-        self.state = State::AwaitingWrite;
-        Ok(StashStateShare {
-            flag: flag_share,
-            location: location_share,
-            value: value_share,
-        })
-    }
-
-    fn write<C: AbstractCommunicator>(
-        &mut self,
-        comm: &mut C,
-        instruction: InstructionShare<F>,
-        stash_state: StashStateShare<F>,
-        db_address_share: F,
-        db_value_share: F,
-    ) -> Result<(), Error> {
-        assert_eq!(self.state, State::AwaitingWrite);
-        assert!(self.access_counter < self.stash_size);
-
-        // 1. Compute tag y := PRF(k, <db_adr>) such that P2, P3 obtain y and append y to the tag
-        //    list.
-        let address_tag: u64 = {
-            let doprf_p2 = self.doprf_party_2.as_mut().unwrap();
-            // for now do preprocessing on the fly
-            doprf_p2.preprocess(comm, 1)?;
-            let fut_3_2 = comm.receive(PARTY_3)?;
-            doprf_p2.eval(comm, 1, &[db_address_share])?;
-            fut_3_2.get()?
-        };
-        self.address_tag_list.push(address_tag);
-
-        // 2. Insert new triple (<db_adr>, <db_val>, <db_val> into stash.
-        self.stash_addresses_share.push(db_address_share);
-        self.stash_values_share.push(db_value_share);
-        self.stash_old_values_share.push(db_value_share);
-
-        // 3. Update stash
-        // - if I.op = write, we want to write I.val to index loc
-        // - if I.op = read, we need to write db_val to index c
-        //   (since that has been done already in step 2, it is essentially a no-op)
-        {
-            let previous_value_share =
-                SelectProtocol::select(comm, stash_state.flag, stash_state.value, db_value_share)?;
-            let location_share =
-                SelectProtocol::select(comm, instruction.operation, stash_state.location, -F::ONE)?;
-            let value_share = SelectProtocol::select(
-                comm,
-                instruction.operation,
-                instruction.value - previous_value_share,
-                F::ZERO,
-            )?;
-            // eprintln!(
-            //     "P{}: prev_val = {:?}, loc = {:?}, upd = {:?}, vals = {:?}",
-            //     comm.get_my_id() + 1,
-            //     previous_value_share,
-            //     location_share,
-            //     value_share,
-            //     self.stash_values_share
-            // );
-            stash_write_value::<C, F, SPDPF>(
-                comm,
-                self.access_counter,
-                location_share,
-                value_share,
-                &mut self.stash_values_share,
-            )?;
-            // eprintln!(
-            //     "P{}: prev_val = {:?}, loc = {:?}, upd = {:?}, vals = {:?}",
-            //     comm.get_my_id() + 1,
-            //     previous_value_share,
-            //     location_share,
-            //     value_share,
-            //     self.stash_values_share
-            // );
-        }
-
-        self.access_counter += 1;
-        self.state = if self.access_counter == self.stash_size {
-            State::AccessesExhausted
-        } else {
-            State::AwaitingRead
-        };
-        Ok(())
-    }
-
-    fn get_stash_share(&self) -> (&[F], &[F], &[F]) {
-        (
-            &self.stash_addresses_share,
-            &self.stash_values_share,
-            &self.stash_old_values_share,
-        )
-    }
-}
-
-pub struct StashParty3<F, SPDPF>
-where
-    F: PrimeField + LegendreSymbol + Serializable,
-    SPDPF: SinglePointDpf<Value = F>,
-{
-    stash_size: usize,
-    access_counter: usize,
-    state: State,
-    stash_addresses_share: Vec<F>,
-    stash_values_share: Vec<F>,
-    stash_old_values_share: Vec<F>,
-    address_tag_list: Vec<u64>,
-    doprf_party_3: Option<DOPrfParty3<F>>,
-    masked_doprf_party_3: Option<MaskedDOPrfParty3<F>>,
-    _phantom: PhantomData<SPDPF>,
-}
-
-impl<F, SPDPF> StashParty3<F, SPDPF>
-where
-    F: PrimeField + LegendreSymbol + Serializable,
-    SPDPF: SinglePointDpf<Value = F>,
-{
-    pub fn new(stash_size: usize) -> Self {
-        assert!(stash_size > 0);
-        assert!(compute_stash_prf_output_bitsize(stash_size) <= 64);
-
-        Self {
-            stash_size,
-            access_counter: 0,
-            state: State::New,
-            stash_addresses_share: Vec::with_capacity(stash_size),
-            stash_values_share: Vec::with_capacity(stash_size),
-            stash_old_values_share: Vec::with_capacity(stash_size),
-            address_tag_list: Vec::with_capacity(stash_size),
-            doprf_party_3: None,
-            masked_doprf_party_3: None,
-            _phantom: PhantomData,
-        }
-    }
-}
-
-impl<F, SPDPF> Stash<F> for StashParty3<F, SPDPF>
-where
-    F: PrimeField + LegendreSymbol + Serializable,
-    SPDPF: SinglePointDpf<Value = F>,
-    SPDPF::Key: Serializable,
-{
-    fn get_party_id(&self) -> usize {
-        3
-    }
-
-    fn get_stash_size(&self) -> usize {
-        self.stash_size
-    }
-
-    fn get_access_counter(&self) -> usize {
-        self.access_counter
-    }
-
-    fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
-        assert_eq!(self.state, State::New);
-
-        let prf_output_bitsize = compute_stash_prf_output_bitsize(self.stash_size);
-        self.doprf_party_3 = Some(DOPrfParty3::new(prf_output_bitsize));
-        self.masked_doprf_party_3 = Some(MaskedDOPrfParty3::new(prf_output_bitsize));
-
-        // run DOPRF initilization
-        {
-            let doprf_p3 = self.doprf_party_3.as_mut().unwrap();
-            doprf_p3.init(comm)?;
-            let mdoprf_p3 = self.masked_doprf_party_3.as_mut().unwrap();
-            mdoprf_p3.init(comm)?;
-        }
-
-        // panic!("not implemented");
-        self.state = State::AwaitingRead;
-        Ok(())
-    }
-
-    fn read<C: AbstractCommunicator>(
-        &mut self,
-        comm: &mut C,
-        instruction: InstructionShare<F>,
-    ) -> Result<StashStateShare<F>, Error> {
-        assert_eq!(self.state, State::AwaitingRead);
-        assert!(self.access_counter < self.stash_size);
-
-        // 0. If the stash is empty, we are done
-        if self.access_counter == 0 {
-            self.state = State::AwaitingWrite;
-            return Ok(StashStateShare {
-                flag: F::ZERO,
-                location: F::ZERO,
-                value: F::ZERO,
-            });
-        }
-
-        // 1. Compute tag y := PRF(k, <I.adr>) such that P1 obtains y + r and P2, P3 obtain the mask r.
-        let address_tag_mask = {
-            let mdoprf_p3 = self.masked_doprf_party_3.as_mut().unwrap();
-
-            // for now do preprocessing on the fly
-            mdoprf_p3.preprocess(comm, 1)?;
-            let mut mask = mdoprf_p3.eval(comm, 1, &[instruction.address])?;
-            bits_to_u64(mask.pop().unwrap())
-        };
-
-        // 2. Receive DPF key for the function f(x) = if x = y { 1 } else { 0 }
-        let dpf_key_3: SPDPF::Key = {
-            let fut = comm.receive(PARTY_1)?;
-            fut.get()?
-        };
-
-        // 3. Compute shares of <flag>, <loc>, i.e., if the address is present in the stash and if
-        //    so, where it is
-        let (flag_share, location_share) = {
-            let mut flag_share = F::ZERO;
-            let mut location_share = F::ZERO;
-            let mut j_as_field_element = F::ZERO;
-            for j in 0..self.address_tag_list.len() {
-                let dpf_value_j =
-                    SPDPF::evaluate_at(&dpf_key_3, self.address_tag_list[j] ^ address_tag_mask);
-                flag_share += dpf_value_j;
-                location_share += j_as_field_element * dpf_value_j;
-                j_as_field_element += F::ONE;
-            }
-            (flag_share, location_share)
-        };
-
-        // 4. Compute <loc> = if <flag> { <loc> } else { access_counter - 1 }
-        let location_share = SelectProtocol::select(comm, flag_share, location_share, F::ZERO)?;
-
-        // 5. Reshare <flag> among all three parties (nothing to do for P3)
-
-        // 6. Read the value <val> from the stash (if <flag>) or read a zero value
-        let value_share = stash_read_value::<C, F, SPDPF>(
+        stash_write_value::<C, F, SPDPF>(
             comm,
             self.access_counter,
             location_share,
-            &self.stash_values_share,
+            value_share,
+            &mut self.stash_values_share,
         )?;
 
-        self.state = State::AwaitingWrite;
-        Ok(StashStateShare {
-            flag: flag_share,
-            location: location_share,
-            value: value_share,
-        })
-    }
-
-    fn write<C: AbstractCommunicator>(
-        &mut self,
-        comm: &mut C,
-        instruction: InstructionShare<F>,
-        stash_state: StashStateShare<F>,
-        db_address_share: F,
-        db_value_share: F,
-    ) -> Result<(), Error> {
-        assert_eq!(self.state, State::AwaitingWrite);
-        assert!(self.access_counter < self.stash_size);
-
-        // 1. Compute tag y := PRF(k, <db_adr>) such that P2, P3 obtain y and append y to the tag
-        //    list.
-        let address_tag: u64 = {
-            let doprf_p3 = self.doprf_party_3.as_mut().unwrap();
-            // for now do preprocessing on the fly
-            doprf_p3.preprocess(comm, 1)?;
-            let mut tag = doprf_p3.eval(comm, 1, &[db_address_share])?;
-            let tag = bits_to_u64(tag.pop().unwrap());
-            comm.send(PARTY_2, tag)?;
-            tag
-        };
-        self.address_tag_list.push(address_tag);
-
-        // 2. Insert new triple (<db_adr>, <db_val>, <db_val> into stash.
-        self.stash_addresses_share.push(db_address_share);
-        self.stash_values_share.push(db_value_share);
-        self.stash_old_values_share.push(db_value_share);
-
-        // 3. Update stash
-        // - if I.op = write, we want to write I.val to index loc
-        //      if flag = true then loc < c
-        //          -> previous value in st.val
-        //      if flag = false then loc = c
-        //          -> previous value is db_val
-        // - if I.op = read, we need don't need to write anything
-        //   (but still touch every entry)
-        {
-            let previous_value_share =
-                SelectProtocol::select(comm, stash_state.flag, stash_state.value, db_value_share)?;
-            let location_share =
-                SelectProtocol::select(comm, instruction.operation, stash_state.location, -F::ONE)?;
-            let value_share = SelectProtocol::select(
-                comm,
-                instruction.operation,
-                instruction.value - previous_value_share,
-                F::ZERO,
-            )?;
-            // eprintln!(
-            //     "P{}: prev_val = {:?}, loc = {:?}, upd = {:?}, vals = {:?}",
-            //     comm.get_my_id() + 1,
-            //     previous_value_share,
-            //     location_share,
-            //     value_share,
-            //     self.stash_values_share
-            // );
-            stash_write_value::<C, F, SPDPF>(
-                comm,
-                self.access_counter,
-                location_share,
-                value_share,
-                &mut self.stash_values_share,
-            )?;
-            // eprintln!(
-            //     "P{}: prev_val = {:?}, loc = {:?}, upd = {:?}, vals = {:?}",
-            //     comm.get_my_id() + 1,
-            //     previous_value_share,
-            //     location_share,
-            //     value_share,
-            //     self.stash_values_share
-            // );
-        }
-
         self.access_counter += 1;
         self.state = if self.access_counter == self.stash_size {
             State::AccessesExhausted
         } else {
             State::AwaitingRead
         };
-
         Ok(())
     }
 
@@ -967,12 +594,12 @@ mod tests {
         let stash_size = 128;
         let mut num_accesses = 0;
 
-        let party_1 = StashParty1::<Fp, SPDPF>::new(stash_size);
-        let party_2 = StashParty2::<Fp, SPDPF>::new(stash_size);
-        let party_3 = StashParty3::<Fp, SPDPF>::new(stash_size);
-        assert_eq!(party_1.get_party_id(), 1);
-        assert_eq!(party_2.get_party_id(), 2);
-        assert_eq!(party_3.get_party_id(), 3);
+        let party_1 = StashProtocol::<Fp, SPDPF>::new(PARTY_1, stash_size);
+        let party_2 = StashProtocol::<Fp, SPDPF>::new(PARTY_2, stash_size);
+        let party_3 = StashProtocol::<Fp, SPDPF>::new(PARTY_3, stash_size);
+        assert_eq!(party_1.get_party_id(), PARTY_1);
+        assert_eq!(party_2.get_party_id(), PARTY_2);
+        assert_eq!(party_3.get_party_id(), PARTY_3);
         assert_eq!(party_1.get_stash_size(), stash_size);
         assert_eq!(party_2.get_stash_size(), stash_size);
         assert_eq!(party_3.get_stash_size(), stash_size);