소스 검색

oram: allow arbitrary db sizes

Lennart Braun 1 년 전
부모
커밋
e31e134071
2개의 변경된 파일245개의 추가작업 그리고 96개의 파일을 삭제
  1. 2 14
      oram/examples/bench_doram.rs
  2. 243 82
      oram/src/oram.rs

+ 2 - 14
oram/examples/bench_doram.rs

@@ -31,7 +31,7 @@ struct Cli {
     #[arg(long, short = 'i', value_parser = clap::value_parser!(u32).range(0..3))]
     pub party_id: u32,
     /// Log2 of the database size, must be even
-    #[arg(long, short = 's', value_parser = parse_log_db_size)]
+    #[arg(long, short = 's', value_parser = clap::value_parser!(u32).range(4..))]
     pub log_db_size: u32,
     /// Use preprocessing
     #[arg(long)]
@@ -93,18 +93,6 @@ impl BenchmarkResults {
     }
 }
 
-fn parse_log_db_size(s: &str) -> Result<u32, Box<dyn std::error::Error + Send + Sync + 'static>> {
-    let log_db_size: u32 = s.parse()?;
-    if log_db_size & 1 == 1 {
-        return Err(clap::Error::raw(
-            clap::error::ErrorKind::InvalidValue,
-            format!("log_db_size must be even"),
-        )
-        .into());
-    }
-    Ok(log_db_size)
-}
-
 fn parse_connect(
     s: &str,
 ) -> Result<(usize, String, u16), Box<dyn std::error::Error + Send + Sync + 'static>> {
@@ -167,7 +155,7 @@ fn main() {
         }
     };
 
-    let mut doram = DOram::new(cli.party_id as usize, cli.log_db_size);
+    let mut doram = DOram::new(cli.party_id as usize, 1 << cli.log_db_size);
 
     let db_size = 1 << cli.log_db_size;
     let db_share: Vec<_> = vec![Fp::ZERO; db_size];

+ 243 - 82
oram/src/oram.rs

@@ -25,7 +25,7 @@ where
 {
     fn get_party_id(&self) -> usize;
 
-    fn get_log_db_size(&self) -> u32;
+    fn get_db_size(&self) -> usize;
 
     fn init<C: AbstractCommunicator>(&mut self, comm: &mut C, db_share: &[F]) -> Result<(), Error>;
 
@@ -200,7 +200,7 @@ where
     SPDPF::Key: Serializable,
 {
     party_id: usize,
-    log_db_size: u32,
+    db_size: usize,
     stash_size: usize,
     memory_size: usize,
     memory_share: Vec<F>,
@@ -245,16 +245,15 @@ where
     SPDPF: SinglePointDpf<Value = F> + Sync,
     SPDPF::Key: Serializable + Sync,
 {
-    pub fn new(party_id: usize, log_db_size: u32) -> Self {
+    pub fn new(party_id: usize, db_size: usize) -> Self {
         assert!(party_id < 3);
-        assert_eq!(log_db_size & 1, 0);
-        let stash_size = 1 << (log_db_size / 2);
-        let memory_size = (1 << log_db_size) + stash_size;
+        let stash_size = (db_size as f64).sqrt().round() as usize;
+        let memory_size = db_size + stash_size;
         let prf_output_bitsize = compute_oram_prf_output_bitsize(memory_size);
 
         Self {
             party_id,
-            log_db_size,
+            db_size,
             stash_size,
             memory_size,
             memory_share: Default::default(),
@@ -889,7 +888,7 @@ where
 
         // 2. If the value was found in a stash, we read from the dummy address
         let dummy_address_share = match self.party_id {
-            PARTY_1 => F::from_u128((1 << self.log_db_size) + self.get_access_counter() as u128),
+            PARTY_1 => F::from_u128((self.db_size + self.get_access_counter()) as u128),
             _ => F::ZERO,
         };
         let db_address_share = self.select_party.as_mut().unwrap().select(
@@ -978,13 +977,12 @@ where
         self.party_id
     }
 
-    fn get_log_db_size(&self) -> u32 {
-        self.log_db_size
+    fn get_db_size(&self) -> usize {
+        self.db_size
     }
 
     fn init<C: AbstractCommunicator>(&mut self, comm: &mut C, db_share: &[F]) -> Result<(), Error> {
-        let db_size = 1 << self.log_db_size;
-        assert_eq!(db_share.len(), db_size);
+        assert_eq!(db_share.len(), self.db_size);
 
         // 1. Initialize memory share with given db share and pad with dummy values
         self.memory_share = Vec::with_capacity(self.memory_size);
@@ -1031,10 +1029,8 @@ where
         if rerandomize_shares {
             let fut = comm.receive_previous()?;
             let mut rng = thread_rng();
-            let mask: Vec<_> = (0..1 << self.log_db_size)
-                .map(|_| F::random(&mut rng))
-                .collect();
-            let mut masked_share: Vec<_> = self.memory_share[0..1 << self.log_db_size]
+            let mask: Vec<_> = (0..self.db_size).map(|_| F::random(&mut rng)).collect();
+            let mut masked_share: Vec<_> = self.memory_share[0..self.db_size]
                 .iter()
                 .zip(mask.iter())
                 .map(|(&x, &m)| x + m)
@@ -1047,7 +1043,7 @@ where
                 .for_each(|(x, &mp)| *x -= mp);
             Ok(masked_share)
         } else {
-            Ok(self.memory_share[0..1 << self.log_db_size].to_vec())
+            Ok(self.memory_share[0..self.db_size].to_vec())
         }
     }
 }
@@ -1122,24 +1118,54 @@ mod tests {
             .unwrap()
     }
 
-    #[test]
-    fn test_oram() {
-        type SPDPF = DummySpDpf<Fp>;
-        type MPDPF = DummyMpDpf<Fp>;
+    fn mk_read(address: u128, value: u128) -> InstructionShare<Fp> {
+        InstructionShare {
+            operation: Operation::Read.encode(),
+            address: Fp::from_u128(address),
+            value: Fp::from_u128(value),
+        }
+    }
 
-        let log_db_size = 4;
-        let db_size = 1 << log_db_size;
-        let stash_size = 1 << (4 >> 1);
+    fn mk_write(address: u128, value: u128) -> InstructionShare<Fp> {
+        InstructionShare {
+            operation: Operation::Write.encode(),
+            address: Fp::from_u128(address),
+            value: Fp::from_u128(value),
+        }
+    }
 
-        let party_1 = DistributedOramProtocol::<Fp, MPDPF, SPDPF>::new(PARTY_1, log_db_size);
-        let party_2 = DistributedOramProtocol::<Fp, MPDPF, SPDPF>::new(PARTY_2, log_db_size);
-        let party_3 = DistributedOramProtocol::<Fp, MPDPF, SPDPF>::new(PARTY_3, log_db_size);
+    const INST_ZERO_SHARE: InstructionShare<Fp> = InstructionShare {
+        operation: Fp::ZERO,
+        address: Fp::ZERO,
+        value: Fp::ZERO,
+    };
+
+    type SPDPF = DummySpDpf<Fp>;
+    type MPDPF = DummyMpDpf<Fp>;
+
+    fn setup(
+        db_size: usize,
+    ) -> (
+        (
+            impl DistributedOram<Fp>,
+            impl DistributedOram<Fp>,
+            impl DistributedOram<Fp>,
+        ),
+        (
+            impl AbstractCommunicator,
+            impl AbstractCommunicator,
+            impl AbstractCommunicator,
+        ),
+    ) {
+        let party_1 = DistributedOramProtocol::<Fp, MPDPF, SPDPF>::new(PARTY_1, db_size);
+        let party_2 = DistributedOramProtocol::<Fp, MPDPF, SPDPF>::new(PARTY_2, db_size);
+        let party_3 = DistributedOramProtocol::<Fp, MPDPF, SPDPF>::new(PARTY_3, db_size);
         assert_eq!(party_1.get_party_id(), PARTY_1);
         assert_eq!(party_2.get_party_id(), PARTY_2);
         assert_eq!(party_3.get_party_id(), PARTY_3);
-        assert_eq!(party_1.get_log_db_size(), log_db_size);
-        assert_eq!(party_2.get_log_db_size(), log_db_size);
-        assert_eq!(party_3.get_log_db_size(), log_db_size);
+        assert_eq!(party_1.get_db_size(), db_size);
+        assert_eq!(party_2.get_db_size(), db_size);
+        assert_eq!(party_3.get_db_size(), db_size);
 
         let (comm_3, comm_2, comm_1) = {
             let mut comms = make_unix_communicators(3);
@@ -1158,31 +1184,19 @@ mod tests {
         let h1 = run_init(party_1, comm_1, &db_share_1);
         let h2 = run_init(party_2, comm_2, &db_share_2);
         let h3 = run_init(party_3, comm_3, &db_share_3);
-        let (mut party_1, mut comm_1) = h1.join().unwrap();
-        let (mut party_2, mut comm_2) = h2.join().unwrap();
-        let (mut party_3, mut comm_3) = h3.join().unwrap();
-
-        fn mk_read(address: u128, value: u128) -> InstructionShare<Fp> {
-            InstructionShare {
-                operation: Operation::Read.encode(),
-                address: Fp::from_u128(address),
-                value: Fp::from_u128(value),
-            }
-        }
+        let (party_1, comm_1) = h1.join().unwrap();
+        let (party_2, comm_2) = h2.join().unwrap();
+        let (party_3, comm_3) = h3.join().unwrap();
+        ((party_1, party_2, party_3), (comm_1, comm_2, comm_3))
+    }
 
-        fn mk_write(address: u128, value: u128) -> InstructionShare<Fp> {
-            InstructionShare {
-                operation: Operation::Write.encode(),
-                address: Fp::from_u128(address),
-                value: Fp::from_u128(value),
-            }
-        }
+    #[test]
+    fn test_oram_even_exp() {
+        let db_size = 1 << 4;
+        let stash_size = (db_size as f64).sqrt().ceil() as usize;
 
-        let inst_zero_share = InstructionShare {
-            operation: Fp::ZERO,
-            address: Fp::ZERO,
-            value: Fp::ZERO,
-        };
+        let ((mut party_1, mut party_2, mut party_3), (mut comm_1, mut comm_2, mut comm_3)) =
+            setup(db_size);
 
         let number_cycles = 8;
         let instructions = [
@@ -1272,45 +1286,194 @@ mod tests {
             ],
         ];
 
-        // fn print_stash(
-        //     p1: &DistributedOramProtocol<Fp, MPDPF, SPDPF>,
-        //     p2: &DistributedOramProtocol<Fp, MPDPF, SPDPF>,
-        //     p3: &DistributedOramProtocol<Fp, MPDPF, SPDPF>,
-        // ) {
-        //     let st1 = p1.get_stash().get_stash_share();
-        //     let st2 = p2.get_stash().get_stash_share();
-        //     let st3 = p3.get_stash().get_stash_share();
-        //     let adrs: Vec<_> = izip!(st1.0.iter(), st2.0.iter(), st3.0.iter(),)
-        //         .map(|(&x, &y, &z)| x + y + z)
-        //         .collect();
-        //     let vals: Vec<_> = izip!(st1.1.iter(), st2.1.iter(), st3.1.iter(),)
-        //         .map(|(&x, &y, &z)| x + y + z)
-        //         .collect();
-        //     let olds: Vec<_> = izip!(st1.2.iter(), st2.2.iter(), st3.2.iter(),)
-        //         .map(|(&x, &y, &z)| x + y + z)
-        //         .collect();
-        //     eprintln!("STASH: =======================");
-        //     eprintln!("adrs = {adrs:?}");
-        //     eprintln!("vals = {vals:?}");
-        //     eprintln!("olds = {olds:?}");
-        //     eprintln!("==============================");
-        // }
+        for i in 0..number_cycles {
+            for j in 0..stash_size {
+                let inst = instructions[i * stash_size + j];
+                let expected_value = expected_values[i * stash_size + j];
+                let h1 = run_access(party_1, comm_1, inst);
+                let h2 = run_access(party_2, comm_2, INST_ZERO_SHARE);
+                let h3 = run_access(party_3, comm_3, INST_ZERO_SHARE);
+                let (p1, c1, value_1) = h1.join().unwrap();
+                let (p2, c2, value_2) = h2.join().unwrap();
+                let (p3, c3, value_3) = h3.join().unwrap();
+                (party_1, party_2, party_3) = (p1, p2, p3);
+                (comm_1, comm_2, comm_3) = (c1, c2, c3);
+                assert_eq!(value_1 + value_2 + value_3, expected_value);
+            }
+            let h1 = run_get_db(party_1, comm_1);
+            let h2 = run_get_db(party_2, comm_2);
+            let h3 = run_get_db(party_3, comm_3);
+            let (p1, c1, db_share_1) = h1.join().unwrap();
+            let (p2, c2, db_share_2) = h2.join().unwrap();
+            let (p3, c3, db_share_3) = h3.join().unwrap();
+            (party_1, party_2, party_3) = (p1, p2, p3);
+            (comm_1, comm_2, comm_3) = (c1, c2, c3);
+            let db: Vec<_> = izip!(db_share_1.iter(), db_share_2.iter(), db_share_3.iter())
+                .map(|(&x, &y, &z)| x + y + z)
+                .collect();
+            for k in 0..db_size {
+                assert_eq!(db[k], Fp::from_u128(expected_db_contents[i][k]));
+            }
+        }
+    }
+
+    #[test]
+    fn test_oram_odd_exp() {
+        let db_size = 1 << 5;
+        let stash_size = (db_size as f64).sqrt().ceil() as usize;
+
+        let ((mut party_1, mut party_2, mut party_3), (mut comm_1, mut comm_2, mut comm_3)) =
+            setup(db_size);
+
+        let number_cycles = 8;
+        let instructions = [
+            mk_write(26, 64),
+            mk_read(4, 141),
+            mk_write(25, 701),
+            mk_write(29, 927),
+            mk_read(28, 132),
+            mk_write(30, 990),
+            mk_write(23, 167),
+            mk_write(31, 347),
+            mk_write(26, 1020),
+            mk_write(20, 893),
+            mk_read(26, 805),
+            mk_write(3, 949),
+            mk_read(10, 195),
+            mk_write(29, 767),
+            mk_read(28, 107),
+            mk_write(30, 426),
+            mk_write(22, 605),
+            mk_write(0, 171),
+            mk_write(4, 210),
+            mk_read(12, 737),
+            mk_write(19, 977),
+            mk_read(16, 143),
+            mk_write(29, 775),
+            mk_read(28, 34),
+            mk_write(27, 95),
+            mk_write(30, 130),
+            mk_read(8, 89),
+            mk_read(23, 132),
+            mk_read(21, 12),
+            mk_read(4, 675),
+            mk_write(28, 225),
+            mk_write(5, 978),
+            mk_write(2, 833),
+            mk_write(1, 456),
+            mk_write(17, 921),
+            mk_read(26, 293),
+            mk_write(5, 474),
+            mk_write(7, 981),
+            mk_read(19, 189),
+            mk_write(1, 248),
+            mk_read(27, 573),
+            mk_read(17, 142),
+            mk_read(29, 945),
+            mk_read(16, 902),
+            mk_write(16, 799),
+            mk_read(28, 864),
+            mk_write(6, 986),
+            mk_read(2, 201),
+        ];
+        let expected_values = [
+            Fp::from_u128(0),
+            Fp::from_u128(0),
+            Fp::from_u128(0),
+            Fp::from_u128(0),
+            Fp::from_u128(0),
+            Fp::from_u128(0),
+            Fp::from_u128(0),
+            Fp::from_u128(0),
+            Fp::from_u128(64),
+            Fp::from_u128(0),
+            Fp::from_u128(1020),
+            Fp::from_u128(0),
+            Fp::from_u128(0),
+            Fp::from_u128(927),
+            Fp::from_u128(0),
+            Fp::from_u128(990),
+            Fp::from_u128(0),
+            Fp::from_u128(0),
+            Fp::from_u128(0),
+            Fp::from_u128(0),
+            Fp::from_u128(0),
+            Fp::from_u128(0),
+            Fp::from_u128(767),
+            Fp::from_u128(0),
+            Fp::from_u128(0),
+            Fp::from_u128(426),
+            Fp::from_u128(0),
+            Fp::from_u128(167),
+            Fp::from_u128(0),
+            Fp::from_u128(210),
+            Fp::from_u128(0),
+            Fp::from_u128(0),
+            Fp::from_u128(0),
+            Fp::from_u128(0),
+            Fp::from_u128(0),
+            Fp::from_u128(1020),
+            Fp::from_u128(978),
+            Fp::from_u128(0),
+            Fp::from_u128(977),
+            Fp::from_u128(456),
+            Fp::from_u128(95),
+            Fp::from_u128(921),
+            Fp::from_u128(775),
+            Fp::from_u128(0),
+            Fp::from_u128(0),
+            Fp::from_u128(225),
+            Fp::from_u128(0),
+            Fp::from_u128(833),
+        ];
+        let expected_db_contents = [
+            [
+                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 701, 64,
+                0, 0, 927, 990, 0,
+            ],
+            [
+                0, 0, 0, 949, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 893, 0, 0, 167, 0,
+                701, 1020, 0, 0, 927, 990, 347,
+            ],
+            [
+                171, 0, 0, 949, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 893, 0, 605, 167,
+                0, 701, 1020, 0, 0, 767, 426, 347,
+            ],
+            [
+                171, 0, 0, 949, 210, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 977, 893, 0, 605,
+                167, 0, 701, 1020, 0, 0, 775, 426, 347,
+            ],
+            [
+                171, 0, 0, 949, 210, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 977, 893, 0, 605,
+                167, 0, 701, 1020, 95, 0, 775, 130, 347,
+            ],
+            [
+                171, 456, 833, 949, 210, 978, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 921, 0, 977, 893, 0,
+                605, 167, 0, 701, 1020, 95, 225, 775, 130, 347,
+            ],
+            [
+                171, 248, 833, 949, 210, 474, 0, 981, 0, 0, 0, 0, 0, 0, 0, 0, 0, 921, 0, 977, 893,
+                0, 605, 167, 0, 701, 1020, 95, 225, 775, 130, 347,
+            ],
+            [
+                171, 248, 833, 949, 210, 474, 986, 981, 0, 0, 0, 0, 0, 0, 0, 0, 799, 921, 0, 977,
+                893, 0, 605, 167, 0, 701, 1020, 95, 225, 775, 130, 347,
+            ],
+        ];
 
         for i in 0..number_cycles {
             for j in 0..stash_size {
                 let inst = instructions[i * stash_size + j];
-                // eprintln!("Running {inst:?}");
                 let expected_value = expected_values[i * stash_size + j];
                 let h1 = run_access(party_1, comm_1, inst);
-                let h2 = run_access(party_2, comm_2, inst_zero_share);
-                let h3 = run_access(party_3, comm_3, inst_zero_share);
+                let h2 = run_access(party_2, comm_2, INST_ZERO_SHARE);
+                let h3 = run_access(party_3, comm_3, INST_ZERO_SHARE);
                 let (p1, c1, value_1) = h1.join().unwrap();
                 let (p2, c2, value_2) = h2.join().unwrap();
                 let (p3, c3, value_3) = h3.join().unwrap();
                 (party_1, party_2, party_3) = (p1, p2, p3);
                 (comm_1, comm_2, comm_3) = (c1, c2, c3);
                 assert_eq!(value_1 + value_2 + value_3, expected_value);
-                // print_stash(&party_1, &party_2, &party_3);
             }
             let h1 = run_get_db(party_1, comm_1);
             let h2 = run_get_db(party_2, comm_2);
@@ -1323,8 +1486,6 @@ mod tests {
             let db: Vec<_> = izip!(db_share_1.iter(), db_share_2.iter(), db_share_3.iter())
                 .map(|(&x, &y, &z)| x + y + z)
                 .collect();
-            // eprintln!("expected = {:#x?}", expected_db_contents[i]);
-            // eprintln!("db = {:#?}", db);
             for k in 0..db_size {
                 assert_eq!(db[k], Fp::from_u128(expected_db_contents[i][k]));
             }