Переглянути джерело

bench: different thread pools for prep/online

Lennart Braun 1 рік тому
батько
коміт
3864b9ed23
1 змінених файлів з 54 додано та 15 видалено
  1. 54 15
      oram/examples/bench_doram.rs

+ 54 - 15
oram/examples/bench_doram.rs

@@ -39,6 +39,12 @@ struct Cli {
     /// How many threads to use for the computation
     #[arg(long, short = 't', default_value_t = 1)]
     pub threads: usize,
+    /// How many threads to use for the preprocessing phase (default: same as -t)
+    #[arg(long, default_value_t = -1)]
+    pub threads_prep: isize,
+    /// How many threads to use for the online phase (default: same as -t)
+    #[arg(long, default_value_t = -1)]
+    pub threads_online: isize,
     /// Output statistics in JSON
     #[arg(long, short = 'j')]
     pub json: bool,
@@ -61,7 +67,8 @@ struct BenchmarkResults {
     party_id: usize,
     log_db_size: u32,
     preprocess: bool,
-    threads: usize,
+    threads_prep: usize,
+    threads_online: usize,
     comm_stats_preprocess: HashMap<usize, CommunicationStats>,
     comm_stats_access: HashMap<usize, CommunicationStats>,
     runtimes: HashMap<String, Duration>,
@@ -80,11 +87,23 @@ impl BenchmarkResults {
             runtime_map.insert(step.to_string(), runtimes.get(step));
         }
 
+        let threads_prep = if cli.threads_prep < 0 {
+            cli.threads
+        } else {
+            cli.threads_prep as usize
+        };
+        let threads_online = if cli.threads_online < 0 {
+            cli.threads
+        } else {
+            cli.threads_online as usize
+        };
+
         Self {
             party_id: cli.party_id as usize,
             log_db_size: cli.log_db_size,
             preprocess: cli.preprocess,
-            threads: cli.threads,
+            threads_prep,
+            threads_online,
             comm_stats_preprocess: comm_stats_preprocess.clone(),
             comm_stats_access: comm_stats_access.clone(),
             runtimes: runtime_map,
@@ -127,10 +146,16 @@ fn main() {
         connect_timeout_seconds: cli.connect_timeout_seconds,
     };
 
-    rayon::ThreadPoolBuilder::new()
-        .num_threads(cli.threads)
-        .build_global()
-        .unwrap();
+    let threads_prep = if cli.threads_prep < 0 {
+        cli.threads
+    } else {
+        cli.threads_prep as usize
+    };
+    let threads_online = if cli.threads_online < 0 {
+        cli.threads
+    } else {
+        cli.threads_online as usize
+    };
 
     for c in cli.connect.iter() {
         if netopts.connect_info[c.0] != NetworkPartyInfo::Listen {
@@ -183,16 +208,23 @@ fn main() {
 
     doram.init(&mut comm, &db_share).expect("init failed");
 
+    let thread_pool_prep = rayon::ThreadPoolBuilder::new()
+        .num_threads(threads_prep)
+        .build()
+        .unwrap();
+
     comm.reset_stats();
     let mut runtimes = OramRuntimes::default();
 
     let d_preprocess = if cli.preprocess {
         let t_start = Instant::now();
 
-        runtimes = doram
-            .preprocess_with_runtimes(&mut comm, 1, Some(runtimes))
-            .expect("preprocess failed")
-            .unwrap();
+        runtimes = thread_pool_prep.install(|| {
+            doram
+                .preprocess_with_runtimes(&mut comm, 1, Some(runtimes))
+                .expect("preprocess failed")
+                .unwrap()
+        });
 
         t_start.elapsed()
     } else {
@@ -202,14 +234,21 @@ fn main() {
     let comm_stats_preprocess = comm.get_stats();
     comm.reset_stats();
 
+    let thread_pool_online = rayon::ThreadPoolBuilder::new()
+        .num_threads(threads_online)
+        .build()
+        .unwrap();
+
     let t_start = Instant::now();
     for (_i, inst) in instructions.iter().enumerate() {
         // println!("executing instruction #{i}: {inst:?}");
-        runtimes = doram
-            .access_with_runtimes(&mut comm, *inst, Some(runtimes))
-            .expect("access failed")
-            .1
-            .unwrap();
+        runtimes = thread_pool_online.install(|| {
+            doram
+                .access_with_runtimes(&mut comm, *inst, Some(runtimes))
+                .expect("access failed")
+                .1
+                .unwrap()
+        });
     }
     let d_accesses = Instant::now() - t_start;