|
@@ -39,6 +39,12 @@ struct Cli {
|
|
|
/// How many threads to use for the computation
|
|
|
#[arg(long, short = 't', default_value_t = 1)]
|
|
|
pub threads: usize,
|
|
|
+ /// How many threads to use for the preprocessing phase (default: same as -t)
|
|
|
+ #[arg(long, default_value_t = -1)]
|
|
|
+ pub threads_prep: isize,
|
|
|
+ /// How many threads to use for the online phase (default: same as -t)
|
|
|
+ #[arg(long, default_value_t = -1)]
|
|
|
+ pub threads_online: isize,
|
|
|
/// Output statistics in JSON
|
|
|
#[arg(long, short = 'j')]
|
|
|
pub json: bool,
|
|
@@ -61,7 +67,8 @@ struct BenchmarkResults {
|
|
|
party_id: usize,
|
|
|
log_db_size: u32,
|
|
|
preprocess: bool,
|
|
|
- threads: usize,
|
|
|
+ threads_prep: usize,
|
|
|
+ threads_online: usize,
|
|
|
comm_stats_preprocess: HashMap<usize, CommunicationStats>,
|
|
|
comm_stats_access: HashMap<usize, CommunicationStats>,
|
|
|
runtimes: HashMap<String, Duration>,
|
|
@@ -80,11 +87,23 @@ impl BenchmarkResults {
|
|
|
runtime_map.insert(step.to_string(), runtimes.get(step));
|
|
|
}
|
|
|
|
|
|
+ let threads_prep = if cli.threads_prep < 0 {
|
|
|
+ cli.threads
|
|
|
+ } else {
|
|
|
+ cli.threads_prep as usize
|
|
|
+ };
|
|
|
+ let threads_online = if cli.threads_online < 0 {
|
|
|
+ cli.threads
|
|
|
+ } else {
|
|
|
+ cli.threads_online as usize
|
|
|
+ };
|
|
|
+
|
|
|
Self {
|
|
|
party_id: cli.party_id as usize,
|
|
|
log_db_size: cli.log_db_size,
|
|
|
preprocess: cli.preprocess,
|
|
|
- threads: cli.threads,
|
|
|
+ threads_prep,
|
|
|
+ threads_online,
|
|
|
comm_stats_preprocess: comm_stats_preprocess.clone(),
|
|
|
comm_stats_access: comm_stats_access.clone(),
|
|
|
runtimes: runtime_map,
|
|
@@ -127,10 +146,16 @@ fn main() {
|
|
|
connect_timeout_seconds: cli.connect_timeout_seconds,
|
|
|
};
|
|
|
|
|
|
- rayon::ThreadPoolBuilder::new()
|
|
|
- .num_threads(cli.threads)
|
|
|
- .build_global()
|
|
|
- .unwrap();
|
|
|
+ let threads_prep = if cli.threads_prep < 0 {
|
|
|
+ cli.threads
|
|
|
+ } else {
|
|
|
+ cli.threads_prep as usize
|
|
|
+ };
|
|
|
+ let threads_online = if cli.threads_online < 0 {
|
|
|
+ cli.threads
|
|
|
+ } else {
|
|
|
+ cli.threads_online as usize
|
|
|
+ };
|
|
|
|
|
|
for c in cli.connect.iter() {
|
|
|
if netopts.connect_info[c.0] != NetworkPartyInfo::Listen {
|
|
@@ -183,16 +208,23 @@ fn main() {
|
|
|
|
|
|
doram.init(&mut comm, &db_share).expect("init failed");
|
|
|
|
|
|
+ let thread_pool_prep = rayon::ThreadPoolBuilder::new()
|
|
|
+ .num_threads(threads_prep)
|
|
|
+ .build()
|
|
|
+ .unwrap();
|
|
|
+
|
|
|
comm.reset_stats();
|
|
|
let mut runtimes = OramRuntimes::default();
|
|
|
|
|
|
let d_preprocess = if cli.preprocess {
|
|
|
let t_start = Instant::now();
|
|
|
|
|
|
- runtimes = doram
|
|
|
- .preprocess_with_runtimes(&mut comm, 1, Some(runtimes))
|
|
|
- .expect("preprocess failed")
|
|
|
- .unwrap();
|
|
|
+ runtimes = thread_pool_prep.install(|| {
|
|
|
+ doram
|
|
|
+ .preprocess_with_runtimes(&mut comm, 1, Some(runtimes))
|
|
|
+ .expect("preprocess failed")
|
|
|
+ .unwrap()
|
|
|
+ });
|
|
|
|
|
|
t_start.elapsed()
|
|
|
} else {
|
|
@@ -202,14 +234,21 @@ fn main() {
|
|
|
let comm_stats_preprocess = comm.get_stats();
|
|
|
comm.reset_stats();
|
|
|
|
|
|
+ let thread_pool_online = rayon::ThreadPoolBuilder::new()
|
|
|
+ .num_threads(threads_online)
|
|
|
+ .build()
|
|
|
+ .unwrap();
|
|
|
+
|
|
|
let t_start = Instant::now();
|
|
|
for (_i, inst) in instructions.iter().enumerate() {
|
|
|
// println!("executing instruction #{i}: {inst:?}");
|
|
|
- runtimes = doram
|
|
|
- .access_with_runtimes(&mut comm, *inst, Some(runtimes))
|
|
|
- .expect("access failed")
|
|
|
- .1
|
|
|
- .unwrap();
|
|
|
+ runtimes = thread_pool_online.install(|| {
|
|
|
+ doram
|
|
|
+ .access_with_runtimes(&mut comm, *inst, Some(runtimes))
|
|
|
+ .expect("access failed")
|
|
|
+ .1
|
|
|
+ .unwrap()
|
|
|
+ });
|
|
|
}
|
|
|
let d_accesses = Instant::now() - t_start;
|
|
|
|