|
@@ -7,6 +7,7 @@ use communicator::{AbstractCommunicator, Fut, Serializable};
|
|
use dpf::{mpdpf::MultiPointDpf, spdpf::SinglePointDpf};
|
|
use dpf::{mpdpf::MultiPointDpf, spdpf::SinglePointDpf};
|
|
use ff::PrimeField;
|
|
use ff::PrimeField;
|
|
use itertools::{izip, Itertools};
|
|
use itertools::{izip, Itertools};
|
|
|
|
+use rand::thread_rng;
|
|
use std::iter::repeat;
|
|
use std::iter::repeat;
|
|
use std::marker::PhantomData;
|
|
use std::marker::PhantomData;
|
|
use utils::field::{FromPrf, LegendreSymbol};
|
|
use utils::field::{FromPrf, LegendreSymbol};
|
|
@@ -28,7 +29,11 @@ where
|
|
instruction: InstructionShare<F>,
|
|
instruction: InstructionShare<F>,
|
|
) -> Result<F, Error>;
|
|
) -> Result<F, Error>;
|
|
|
|
|
|
- fn get_db<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<Vec<F>, Error>;
|
|
|
|
|
|
+ fn get_db<C: AbstractCommunicator>(
|
|
|
|
+ &mut self,
|
|
|
|
+ comm: &mut C,
|
|
|
|
+ rerandomize_shares: bool,
|
|
|
|
+ ) -> Result<Vec<F>, Error>;
|
|
}
|
|
}
|
|
|
|
|
|
const PARTY_1: usize = 0;
|
|
const PARTY_1: usize = 0;
|
|
@@ -459,13 +464,38 @@ where
|
|
Ok(read_value)
|
|
Ok(read_value)
|
|
}
|
|
}
|
|
|
|
|
|
- fn get_db<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<Vec<F>, Error> {
|
|
|
|
|
|
+ fn get_db<C: AbstractCommunicator>(
|
|
|
|
+ &mut self,
|
|
|
|
+ comm: &mut C,
|
|
|
|
+ rerandomize_shares: bool,
|
|
|
|
+ ) -> Result<Vec<F>, Error> {
|
|
assert!(self.is_initialized);
|
|
assert!(self.is_initialized);
|
|
|
|
|
|
if self.get_access_counter() > 0 {
|
|
if self.get_access_counter() > 0 {
|
|
self.refresh(comm)?;
|
|
self.refresh(comm)?;
|
|
}
|
|
}
|
|
- return Ok(self.memory_share[0..1 << self.log_db_size].to_vec());
|
|
|
|
|
|
+
|
|
|
|
+ if rerandomize_shares {
|
|
|
|
+ let fut = comm.receive_previous()?;
|
|
|
|
+ let mut rng = thread_rng();
|
|
|
|
+ let mask: Vec<_> = (0..1 << self.log_db_size)
|
|
|
|
+ .map(|_| F::random(&mut rng))
|
|
|
|
+ .collect();
|
|
|
|
+ let mut masked_share: Vec<_> = self.memory_share[0..1 << self.log_db_size]
|
|
|
|
+ .iter()
|
|
|
|
+ .zip(mask.iter())
|
|
|
|
+ .map(|(&x, &m)| x + m)
|
|
|
|
+ .collect();
|
|
|
|
+ comm.send_next(mask)?;
|
|
|
|
+ let mask_prev: Vec<F> = fut.get()?;
|
|
|
|
+ masked_share
|
|
|
|
+ .iter_mut()
|
|
|
|
+ .zip(mask_prev.iter())
|
|
|
|
+ .for_each(|(x, &mp)| *x -= mp);
|
|
|
|
+ Ok(masked_share)
|
|
|
|
+ } else {
|
|
|
|
+ Ok(self.memory_share[0..1 << self.log_db_size].to_vec())
|
|
|
|
+ }
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
@@ -528,7 +558,7 @@ mod tests {
|
|
thread::Builder::new()
|
|
thread::Builder::new()
|
|
.name(format!("Party {}", doram_party.get_party_id()))
|
|
.name(format!("Party {}", doram_party.get_party_id()))
|
|
.spawn(move || {
|
|
.spawn(move || {
|
|
- let output = doram_party.get_db(&mut comm).unwrap();
|
|
|
|
|
|
+ let output = doram_party.get_db(&mut comm, false).unwrap();
|
|
(doram_party, comm, output)
|
|
(doram_party, comm, output)
|
|
})
|
|
})
|
|
.unwrap()
|
|
.unwrap()
|