|
@@ -25,7 +25,7 @@ where
|
|
|
{
|
|
|
fn get_party_id(&self) -> usize;
|
|
|
|
|
|
- fn get_log_db_size(&self) -> u32;
|
|
|
+ fn get_db_size(&self) -> usize;
|
|
|
|
|
|
fn init<C: AbstractCommunicator>(&mut self, comm: &mut C, db_share: &[F]) -> Result<(), Error>;
|
|
|
|
|
@@ -200,7 +200,7 @@ where
|
|
|
SPDPF::Key: Serializable,
|
|
|
{
|
|
|
party_id: usize,
|
|
|
- log_db_size: u32,
|
|
|
+ db_size: usize,
|
|
|
stash_size: usize,
|
|
|
memory_size: usize,
|
|
|
memory_share: Vec<F>,
|
|
@@ -245,16 +245,15 @@ where
|
|
|
SPDPF: SinglePointDpf<Value = F> + Sync,
|
|
|
SPDPF::Key: Serializable + Sync,
|
|
|
{
|
|
|
- pub fn new(party_id: usize, log_db_size: u32) -> Self {
|
|
|
+ pub fn new(party_id: usize, db_size: usize) -> Self {
|
|
|
assert!(party_id < 3);
|
|
|
- assert_eq!(log_db_size & 1, 0);
|
|
|
- let stash_size = 1 << (log_db_size / 2);
|
|
|
- let memory_size = (1 << log_db_size) + stash_size;
|
|
|
+ let stash_size = (db_size as f64).sqrt().round() as usize;
|
|
|
+ let memory_size = db_size + stash_size;
|
|
|
let prf_output_bitsize = compute_oram_prf_output_bitsize(memory_size);
|
|
|
|
|
|
Self {
|
|
|
party_id,
|
|
|
- log_db_size,
|
|
|
+ db_size,
|
|
|
stash_size,
|
|
|
memory_size,
|
|
|
memory_share: Default::default(),
|
|
@@ -889,7 +888,7 @@ where
|
|
|
|
|
|
// 2. If the value was found in a stash, we read from the dummy address
|
|
|
let dummy_address_share = match self.party_id {
|
|
|
- PARTY_1 => F::from_u128((1 << self.log_db_size) + self.get_access_counter() as u128),
|
|
|
+ PARTY_1 => F::from_u128((self.db_size + self.get_access_counter()) as u128),
|
|
|
_ => F::ZERO,
|
|
|
};
|
|
|
let db_address_share = self.select_party.as_mut().unwrap().select(
|
|
@@ -978,13 +977,12 @@ where
|
|
|
self.party_id
|
|
|
}
|
|
|
|
|
|
- fn get_log_db_size(&self) -> u32 {
|
|
|
- self.log_db_size
|
|
|
+ fn get_db_size(&self) -> usize {
|
|
|
+ self.db_size
|
|
|
}
|
|
|
|
|
|
fn init<C: AbstractCommunicator>(&mut self, comm: &mut C, db_share: &[F]) -> Result<(), Error> {
|
|
|
- let db_size = 1 << self.log_db_size;
|
|
|
- assert_eq!(db_share.len(), db_size);
|
|
|
+ assert_eq!(db_share.len(), self.db_size);
|
|
|
|
|
|
// 1. Initialize memory share with given db share and pad with dummy values
|
|
|
self.memory_share = Vec::with_capacity(self.memory_size);
|
|
@@ -1031,10 +1029,8 @@ where
|
|
|
if rerandomize_shares {
|
|
|
let fut = comm.receive_previous()?;
|
|
|
let mut rng = thread_rng();
|
|
|
- let mask: Vec<_> = (0..1 << self.log_db_size)
|
|
|
- .map(|_| F::random(&mut rng))
|
|
|
- .collect();
|
|
|
- let mut masked_share: Vec<_> = self.memory_share[0..1 << self.log_db_size]
|
|
|
+ let mask: Vec<_> = (0..self.db_size).map(|_| F::random(&mut rng)).collect();
|
|
|
+ let mut masked_share: Vec<_> = self.memory_share[0..self.db_size]
|
|
|
.iter()
|
|
|
.zip(mask.iter())
|
|
|
.map(|(&x, &m)| x + m)
|
|
@@ -1047,7 +1043,7 @@ where
|
|
|
.for_each(|(x, &mp)| *x -= mp);
|
|
|
Ok(masked_share)
|
|
|
} else {
|
|
|
- Ok(self.memory_share[0..1 << self.log_db_size].to_vec())
|
|
|
+ Ok(self.memory_share[0..self.db_size].to_vec())
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -1122,24 +1118,54 @@ mod tests {
|
|
|
.unwrap()
|
|
|
}
|
|
|
|
|
|
- #[test]
|
|
|
- fn test_oram() {
|
|
|
- type SPDPF = DummySpDpf<Fp>;
|
|
|
- type MPDPF = DummyMpDpf<Fp>;
|
|
|
+ fn mk_read(address: u128, value: u128) -> InstructionShare<Fp> {
|
|
|
+ InstructionShare {
|
|
|
+ operation: Operation::Read.encode(),
|
|
|
+ address: Fp::from_u128(address),
|
|
|
+ value: Fp::from_u128(value),
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- let log_db_size = 4;
|
|
|
- let db_size = 1 << log_db_size;
|
|
|
- let stash_size = 1 << (4 >> 1);
|
|
|
+ fn mk_write(address: u128, value: u128) -> InstructionShare<Fp> {
|
|
|
+ InstructionShare {
|
|
|
+ operation: Operation::Write.encode(),
|
|
|
+ address: Fp::from_u128(address),
|
|
|
+ value: Fp::from_u128(value),
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- let party_1 = DistributedOramProtocol::<Fp, MPDPF, SPDPF>::new(PARTY_1, log_db_size);
|
|
|
- let party_2 = DistributedOramProtocol::<Fp, MPDPF, SPDPF>::new(PARTY_2, log_db_size);
|
|
|
- let party_3 = DistributedOramProtocol::<Fp, MPDPF, SPDPF>::new(PARTY_3, log_db_size);
|
|
|
+ const INST_ZERO_SHARE: InstructionShare<Fp> = InstructionShare {
|
|
|
+ operation: Fp::ZERO,
|
|
|
+ address: Fp::ZERO,
|
|
|
+ value: Fp::ZERO,
|
|
|
+ };
|
|
|
+
|
|
|
+ type SPDPF = DummySpDpf<Fp>;
|
|
|
+ type MPDPF = DummyMpDpf<Fp>;
|
|
|
+
|
|
|
+ fn setup(
|
|
|
+ db_size: usize,
|
|
|
+ ) -> (
|
|
|
+ (
|
|
|
+ impl DistributedOram<Fp>,
|
|
|
+ impl DistributedOram<Fp>,
|
|
|
+ impl DistributedOram<Fp>,
|
|
|
+ ),
|
|
|
+ (
|
|
|
+ impl AbstractCommunicator,
|
|
|
+ impl AbstractCommunicator,
|
|
|
+ impl AbstractCommunicator,
|
|
|
+ ),
|
|
|
+ ) {
|
|
|
+ let party_1 = DistributedOramProtocol::<Fp, MPDPF, SPDPF>::new(PARTY_1, db_size);
|
|
|
+ let party_2 = DistributedOramProtocol::<Fp, MPDPF, SPDPF>::new(PARTY_2, db_size);
|
|
|
+ let party_3 = DistributedOramProtocol::<Fp, MPDPF, SPDPF>::new(PARTY_3, db_size);
|
|
|
assert_eq!(party_1.get_party_id(), PARTY_1);
|
|
|
assert_eq!(party_2.get_party_id(), PARTY_2);
|
|
|
assert_eq!(party_3.get_party_id(), PARTY_3);
|
|
|
- assert_eq!(party_1.get_log_db_size(), log_db_size);
|
|
|
- assert_eq!(party_2.get_log_db_size(), log_db_size);
|
|
|
- assert_eq!(party_3.get_log_db_size(), log_db_size);
|
|
|
+ assert_eq!(party_1.get_db_size(), db_size);
|
|
|
+ assert_eq!(party_2.get_db_size(), db_size);
|
|
|
+ assert_eq!(party_3.get_db_size(), db_size);
|
|
|
|
|
|
let (comm_3, comm_2, comm_1) = {
|
|
|
let mut comms = make_unix_communicators(3);
|
|
@@ -1158,31 +1184,19 @@ mod tests {
|
|
|
let h1 = run_init(party_1, comm_1, &db_share_1);
|
|
|
let h2 = run_init(party_2, comm_2, &db_share_2);
|
|
|
let h3 = run_init(party_3, comm_3, &db_share_3);
|
|
|
- let (mut party_1, mut comm_1) = h1.join().unwrap();
|
|
|
- let (mut party_2, mut comm_2) = h2.join().unwrap();
|
|
|
- let (mut party_3, mut comm_3) = h3.join().unwrap();
|
|
|
-
|
|
|
- fn mk_read(address: u128, value: u128) -> InstructionShare<Fp> {
|
|
|
- InstructionShare {
|
|
|
- operation: Operation::Read.encode(),
|
|
|
- address: Fp::from_u128(address),
|
|
|
- value: Fp::from_u128(value),
|
|
|
- }
|
|
|
- }
|
|
|
+ let (party_1, comm_1) = h1.join().unwrap();
|
|
|
+ let (party_2, comm_2) = h2.join().unwrap();
|
|
|
+ let (party_3, comm_3) = h3.join().unwrap();
|
|
|
+ ((party_1, party_2, party_3), (comm_1, comm_2, comm_3))
|
|
|
+ }
|
|
|
|
|
|
- fn mk_write(address: u128, value: u128) -> InstructionShare<Fp> {
|
|
|
- InstructionShare {
|
|
|
- operation: Operation::Write.encode(),
|
|
|
- address: Fp::from_u128(address),
|
|
|
- value: Fp::from_u128(value),
|
|
|
- }
|
|
|
- }
|
|
|
+ #[test]
|
|
|
+ fn test_oram_even_exp() {
|
|
|
+ let db_size = 1 << 4;
|
|
|
+ let stash_size = (db_size as f64).sqrt().ceil() as usize;
|
|
|
|
|
|
- let inst_zero_share = InstructionShare {
|
|
|
- operation: Fp::ZERO,
|
|
|
- address: Fp::ZERO,
|
|
|
- value: Fp::ZERO,
|
|
|
- };
|
|
|
+ let ((mut party_1, mut party_2, mut party_3), (mut comm_1, mut comm_2, mut comm_3)) =
|
|
|
+ setup(db_size);
|
|
|
|
|
|
let number_cycles = 8;
|
|
|
let instructions = [
|
|
@@ -1272,45 +1286,194 @@ mod tests {
|
|
|
],
|
|
|
];
|
|
|
|
|
|
- // fn print_stash(
|
|
|
- // p1: &DistributedOramProtocol<Fp, MPDPF, SPDPF>,
|
|
|
- // p2: &DistributedOramProtocol<Fp, MPDPF, SPDPF>,
|
|
|
- // p3: &DistributedOramProtocol<Fp, MPDPF, SPDPF>,
|
|
|
- // ) {
|
|
|
- // let st1 = p1.get_stash().get_stash_share();
|
|
|
- // let st2 = p2.get_stash().get_stash_share();
|
|
|
- // let st3 = p3.get_stash().get_stash_share();
|
|
|
- // let adrs: Vec<_> = izip!(st1.0.iter(), st2.0.iter(), st3.0.iter(),)
|
|
|
- // .map(|(&x, &y, &z)| x + y + z)
|
|
|
- // .collect();
|
|
|
- // let vals: Vec<_> = izip!(st1.1.iter(), st2.1.iter(), st3.1.iter(),)
|
|
|
- // .map(|(&x, &y, &z)| x + y + z)
|
|
|
- // .collect();
|
|
|
- // let olds: Vec<_> = izip!(st1.2.iter(), st2.2.iter(), st3.2.iter(),)
|
|
|
- // .map(|(&x, &y, &z)| x + y + z)
|
|
|
- // .collect();
|
|
|
- // eprintln!("STASH: =======================");
|
|
|
- // eprintln!("adrs = {adrs:?}");
|
|
|
- // eprintln!("vals = {vals:?}");
|
|
|
- // eprintln!("olds = {olds:?}");
|
|
|
- // eprintln!("==============================");
|
|
|
- // }
|
|
|
+ for i in 0..number_cycles {
|
|
|
+ for j in 0..stash_size {
|
|
|
+ let inst = instructions[i * stash_size + j];
|
|
|
+ let expected_value = expected_values[i * stash_size + j];
|
|
|
+ let h1 = run_access(party_1, comm_1, inst);
|
|
|
+ let h2 = run_access(party_2, comm_2, INST_ZERO_SHARE);
|
|
|
+ let h3 = run_access(party_3, comm_3, INST_ZERO_SHARE);
|
|
|
+ let (p1, c1, value_1) = h1.join().unwrap();
|
|
|
+ let (p2, c2, value_2) = h2.join().unwrap();
|
|
|
+ let (p3, c3, value_3) = h3.join().unwrap();
|
|
|
+ (party_1, party_2, party_3) = (p1, p2, p3);
|
|
|
+ (comm_1, comm_2, comm_3) = (c1, c2, c3);
|
|
|
+ assert_eq!(value_1 + value_2 + value_3, expected_value);
|
|
|
+ }
|
|
|
+ let h1 = run_get_db(party_1, comm_1);
|
|
|
+ let h2 = run_get_db(party_2, comm_2);
|
|
|
+ let h3 = run_get_db(party_3, comm_3);
|
|
|
+ let (p1, c1, db_share_1) = h1.join().unwrap();
|
|
|
+ let (p2, c2, db_share_2) = h2.join().unwrap();
|
|
|
+ let (p3, c3, db_share_3) = h3.join().unwrap();
|
|
|
+ (party_1, party_2, party_3) = (p1, p2, p3);
|
|
|
+ (comm_1, comm_2, comm_3) = (c1, c2, c3);
|
|
|
+ let db: Vec<_> = izip!(db_share_1.iter(), db_share_2.iter(), db_share_3.iter())
|
|
|
+ .map(|(&x, &y, &z)| x + y + z)
|
|
|
+ .collect();
|
|
|
+ for k in 0..db_size {
|
|
|
+ assert_eq!(db[k], Fp::from_u128(expected_db_contents[i][k]));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn test_oram_odd_exp() {
|
|
|
+ let db_size = 1 << 5;
|
|
|
+ let stash_size = (db_size as f64).sqrt().ceil() as usize;
|
|
|
+
|
|
|
+ let ((mut party_1, mut party_2, mut party_3), (mut comm_1, mut comm_2, mut comm_3)) =
|
|
|
+ setup(db_size);
|
|
|
+
|
|
|
+ let number_cycles = 8;
|
|
|
+ let instructions = [
|
|
|
+ mk_write(26, 64),
|
|
|
+ mk_read(4, 141),
|
|
|
+ mk_write(25, 701),
|
|
|
+ mk_write(29, 927),
|
|
|
+ mk_read(28, 132),
|
|
|
+ mk_write(30, 990),
|
|
|
+ mk_write(23, 167),
|
|
|
+ mk_write(31, 347),
|
|
|
+ mk_write(26, 1020),
|
|
|
+ mk_write(20, 893),
|
|
|
+ mk_read(26, 805),
|
|
|
+ mk_write(3, 949),
|
|
|
+ mk_read(10, 195),
|
|
|
+ mk_write(29, 767),
|
|
|
+ mk_read(28, 107),
|
|
|
+ mk_write(30, 426),
|
|
|
+ mk_write(22, 605),
|
|
|
+ mk_write(0, 171),
|
|
|
+ mk_write(4, 210),
|
|
|
+ mk_read(12, 737),
|
|
|
+ mk_write(19, 977),
|
|
|
+ mk_read(16, 143),
|
|
|
+ mk_write(29, 775),
|
|
|
+ mk_read(28, 34),
|
|
|
+ mk_write(27, 95),
|
|
|
+ mk_write(30, 130),
|
|
|
+ mk_read(8, 89),
|
|
|
+ mk_read(23, 132),
|
|
|
+ mk_read(21, 12),
|
|
|
+ mk_read(4, 675),
|
|
|
+ mk_write(28, 225),
|
|
|
+ mk_write(5, 978),
|
|
|
+ mk_write(2, 833),
|
|
|
+ mk_write(1, 456),
|
|
|
+ mk_write(17, 921),
|
|
|
+ mk_read(26, 293),
|
|
|
+ mk_write(5, 474),
|
|
|
+ mk_write(7, 981),
|
|
|
+ mk_read(19, 189),
|
|
|
+ mk_write(1, 248),
|
|
|
+ mk_read(27, 573),
|
|
|
+ mk_read(17, 142),
|
|
|
+ mk_read(29, 945),
|
|
|
+ mk_read(16, 902),
|
|
|
+ mk_write(16, 799),
|
|
|
+ mk_read(28, 864),
|
|
|
+ mk_write(6, 986),
|
|
|
+ mk_read(2, 201),
|
|
|
+ ];
|
|
|
+ let expected_values = [
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(64),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(1020),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(927),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(990),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(767),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(426),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(167),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(210),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(1020),
|
|
|
+ Fp::from_u128(978),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(977),
|
|
|
+ Fp::from_u128(456),
|
|
|
+ Fp::from_u128(95),
|
|
|
+ Fp::from_u128(921),
|
|
|
+ Fp::from_u128(775),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(225),
|
|
|
+ Fp::from_u128(0),
|
|
|
+ Fp::from_u128(833),
|
|
|
+ ];
|
|
|
+ let expected_db_contents = [
|
|
|
+ [
|
|
|
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 701, 64,
|
|
|
+ 0, 0, 927, 990, 0,
|
|
|
+ ],
|
|
|
+ [
|
|
|
+ 0, 0, 0, 949, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 893, 0, 0, 167, 0,
|
|
|
+ 701, 1020, 0, 0, 927, 990, 347,
|
|
|
+ ],
|
|
|
+ [
|
|
|
+ 171, 0, 0, 949, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 893, 0, 605, 167,
|
|
|
+ 0, 701, 1020, 0, 0, 767, 426, 347,
|
|
|
+ ],
|
|
|
+ [
|
|
|
+ 171, 0, 0, 949, 210, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 977, 893, 0, 605,
|
|
|
+ 167, 0, 701, 1020, 0, 0, 775, 426, 347,
|
|
|
+ ],
|
|
|
+ [
|
|
|
+ 171, 0, 0, 949, 210, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 977, 893, 0, 605,
|
|
|
+ 167, 0, 701, 1020, 95, 0, 775, 130, 347,
|
|
|
+ ],
|
|
|
+ [
|
|
|
+ 171, 456, 833, 949, 210, 978, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 921, 0, 977, 893, 0,
|
|
|
+ 605, 167, 0, 701, 1020, 95, 225, 775, 130, 347,
|
|
|
+ ],
|
|
|
+ [
|
|
|
+ 171, 248, 833, 949, 210, 474, 0, 981, 0, 0, 0, 0, 0, 0, 0, 0, 0, 921, 0, 977, 893,
|
|
|
+ 0, 605, 167, 0, 701, 1020, 95, 225, 775, 130, 347,
|
|
|
+ ],
|
|
|
+ [
|
|
|
+ 171, 248, 833, 949, 210, 474, 986, 981, 0, 0, 0, 0, 0, 0, 0, 0, 799, 921, 0, 977,
|
|
|
+ 893, 0, 605, 167, 0, 701, 1020, 95, 225, 775, 130, 347,
|
|
|
+ ],
|
|
|
+ ];
|
|
|
|
|
|
for i in 0..number_cycles {
|
|
|
for j in 0..stash_size {
|
|
|
let inst = instructions[i * stash_size + j];
|
|
|
- // eprintln!("Running {inst:?}");
|
|
|
let expected_value = expected_values[i * stash_size + j];
|
|
|
let h1 = run_access(party_1, comm_1, inst);
|
|
|
- let h2 = run_access(party_2, comm_2, inst_zero_share);
|
|
|
- let h3 = run_access(party_3, comm_3, inst_zero_share);
|
|
|
+ let h2 = run_access(party_2, comm_2, INST_ZERO_SHARE);
|
|
|
+ let h3 = run_access(party_3, comm_3, INST_ZERO_SHARE);
|
|
|
let (p1, c1, value_1) = h1.join().unwrap();
|
|
|
let (p2, c2, value_2) = h2.join().unwrap();
|
|
|
let (p3, c3, value_3) = h3.join().unwrap();
|
|
|
(party_1, party_2, party_3) = (p1, p2, p3);
|
|
|
(comm_1, comm_2, comm_3) = (c1, c2, c3);
|
|
|
assert_eq!(value_1 + value_2 + value_3, expected_value);
|
|
|
- // print_stash(&party_1, &party_2, &party_3);
|
|
|
}
|
|
|
let h1 = run_get_db(party_1, comm_1);
|
|
|
let h2 = run_get_db(party_2, comm_2);
|
|
@@ -1323,8 +1486,6 @@ mod tests {
|
|
|
let db: Vec<_> = izip!(db_share_1.iter(), db_share_2.iter(), db_share_3.iter())
|
|
|
.map(|(&x, &y, &z)| x + y + z)
|
|
|
.collect();
|
|
|
- // eprintln!("expected = {:#x?}", expected_db_contents[i]);
|
|
|
- // eprintln!("db = {:#?}", db);
|
|
|
for k in 0..db_size {
|
|
|
assert_eq!(db[k], Fp::from_u128(expected_db_contents[i][k]));
|
|
|
}
|