Browse Source

bench: fix #accesses per epoch computation

Lennart Braun 1 year ago
parent
commit
1e6f63d5ba
2 changed files with 20 additions and 15 deletions
  1. 14 13
      oram/examples/bench_doram.rs
  2. 6 2
      oram/src/oram.rs

+ 14 - 13
oram/examples/bench_doram.rs

@@ -189,11 +189,11 @@ fn main() {
 
     let db_size = 1 << cli.log_db_size;
     let db_share: Vec<_> = vec![Fp::ZERO; db_size];
-    let stash_size = 1 << (cli.log_db_size >> 1);
+    let num_accesses_per_epoch = doram.get_stash_size();
 
     let instructions = if cli.party_id == 0 {
         let mut rng = ChaChaRng::from_seed([0u8; 32]);
-        (0..stash_size)
+        (0..num_accesses_per_epoch)
             .map(|_| InstructionShare {
                 operation: Operation::Write.encode(),
                 address: Fp::from_u128(rng.gen_range(0..db_size) as u128),
@@ -207,7 +207,7 @@ fn main() {
                 address: Fp::ZERO,
                 value: Fp::ZERO
             };
-            stash_size
+            num_accesses_per_epoch
         ]
     };
 
@@ -249,16 +249,17 @@ fn main() {
         .unwrap();
 
     let t_start = Instant::now();
-    for (_i, inst) in instructions.iter().enumerate() {
-        // println!("executing instruction #{i}: {inst:?}");
-        runtimes = thread_pool_online.install(|| {
-            doram
+    runtimes = thread_pool_online.install(|| {
+        for (_i, inst) in instructions.iter().enumerate() {
+            // println!("executing instruction #{i}: {inst:?}");
+            runtimes = doram
                 .access_with_runtimes(&mut comm, *inst, Some(runtimes))
                 .expect("access failed")
                 .1
-                .unwrap()
-        });
-    }
+                .unwrap();
+        }
+        runtimes
+    });
     let d_accesses = Instant::now() - t_start;
 
     let comm_stats_access = comm.get_stats();
@@ -279,7 +280,7 @@ fn main() {
         );
         println!(
             "   per accesses:  {:10.3} ms",
-            d_preprocess.as_secs_f64() * 1000.0 / stash_size as f64
+            d_preprocess.as_secs_f64() * 1000.0 / num_accesses_per_epoch as f64
         );
         println!(
             "time accesses:    {:10.3} ms{}",
@@ -292,9 +293,9 @@ fn main() {
         );
         println!(
             "   per accesses:  {:10.3} ms",
-            d_accesses.as_secs_f64() * 1000.0 / stash_size as f64
+            d_accesses.as_secs_f64() * 1000.0 / num_accesses_per_epoch as f64
         );
-        runtimes.print(cli.party_id as usize + 1, stash_size);
+        runtimes.print(cli.party_id as usize + 1, num_accesses_per_epoch);
         println!("communication preprocessing: {comm_stats_preprocess:#?}");
         println!("communication accesses: {comm_stats_access:#?}");
     }

+ 6 - 2
oram/src/oram.rs

@@ -298,6 +298,10 @@ where
         self.stash.as_ref().unwrap()
     }
 
+    pub fn get_stash_size(&self) -> usize {
+        self.stash_size
+    }
+
     fn pos_prev(&self, tag: u128) -> usize {
         debug_assert_eq!(self.memory_index_tags_prev_sorted.len(), self.memory_size);
         self.memory_index_tags_prev_sorted
@@ -1193,7 +1197,7 @@ mod tests {
     #[test]
     fn test_oram_even_exp() {
         let db_size = 1 << 4;
-        let stash_size = (db_size as f64).sqrt().ceil() as usize;
+        let stash_size = (db_size as f64).sqrt().round() as usize;
 
         let ((mut party_1, mut party_2, mut party_3), (mut comm_1, mut comm_2, mut comm_3)) =
             setup(db_size);
@@ -1320,7 +1324,7 @@ mod tests {
     #[test]
     fn test_oram_odd_exp() {
         let db_size = 1 << 5;
-        let stash_size = (db_size as f64).sqrt().ceil() as usize;
+        let stash_size = (db_size as f64).sqrt().round() as usize;
 
         let ((mut party_1, mut party_2, mut party_3), (mut comm_1, mut comm_2, mut comm_3)) =
             setup(db_size);