|
@@ -1,17 +1,25 @@
|
|
|
use clap::{CommandFactory, Parser};
|
|
|
use communicator::tcp::{make_tcp_communicator, NetworkOptions, NetworkPartyInfo};
|
|
|
-use communicator::AbstractCommunicator;
|
|
|
+use communicator::{AbstractCommunicator, CommunicationStats};
|
|
|
use cuckoo::hash::AesHashFunction;
|
|
|
use dpf::mpdpf::SmartMpDpf;
|
|
|
use dpf::spdpf::HalfTreeSpDpf;
|
|
|
use ff::{Field, PrimeField};
|
|
|
use oram::common::{InstructionShare, Operation};
|
|
|
-use oram::oram::{DistributedOram, DistributedOramProtocol, Runtimes};
|
|
|
+use oram::oram::{
|
|
|
+ DistributedOram, DistributedOramProtocol, ProtocolStep as OramProtocolStep,
|
|
|
+ Runtimes as OramRuntimes,
|
|
|
+};
|
|
|
+use oram::tools::BenchmarkMetaData;
|
|
|
use rand::{Rng, SeedableRng};
|
|
|
use rand_chacha::ChaChaRng;
|
|
|
use rayon;
|
|
|
+use serde;
|
|
|
+use serde_json;
|
|
|
+use std::collections::HashMap;
|
|
|
use std::process;
|
|
|
-use std::time::Instant;
|
|
|
+use std::time::{Duration, Instant};
|
|
|
+use strum::IntoEnumIterator;
|
|
|
use utils::field::Fp;
|
|
|
|
|
|
type MPDPF = SmartMpDpf<Fp, HalfTreeSpDpf<Fp>, AesHashFunction<u16>>;
|
|
@@ -31,6 +39,9 @@ struct Cli {
|
|
|
/// How many threads to use for the computation
|
|
|
#[arg(long, short = 't', default_value_t = 1)]
|
|
|
pub threads: usize,
|
|
|
+ /// Output statistics in JSON
|
|
|
+ #[arg(long, short = 'j')]
|
|
|
+ pub json: bool,
|
|
|
/// Which address to listen on for incoming connections
|
|
|
#[arg(long, short = 'l')]
|
|
|
pub listen_host: String,
|
|
@@ -45,6 +56,43 @@ struct Cli {
|
|
|
pub connect_timeout_seconds: usize,
|
|
|
}
|
|
|
|
|
|
+#[derive(Debug, Clone, serde::Serialize)]
|
|
|
+struct BenchmarkResults {
|
|
|
+ party_id: usize,
|
|
|
+ log_db_size: u32,
|
|
|
+ preprocess: bool,
|
|
|
+ threads: usize,
|
|
|
+ comm_stats_preprocess: HashMap<usize, CommunicationStats>,
|
|
|
+ comm_stats_access: HashMap<usize, CommunicationStats>,
|
|
|
+ runtimes: HashMap<String, Duration>,
|
|
|
+ meta: BenchmarkMetaData,
|
|
|
+}
|
|
|
+
|
|
|
+impl BenchmarkResults {
|
|
|
+ pub fn new(
|
|
|
+ cli: &Cli,
|
|
|
+ comm_stats_preprocess: &HashMap<usize, CommunicationStats>,
|
|
|
+ comm_stats_access: &HashMap<usize, CommunicationStats>,
|
|
|
+ runtimes: &OramRuntimes,
|
|
|
+ ) -> Self {
|
|
|
+ let mut runtime_map = HashMap::new();
|
|
|
+ for step in OramProtocolStep::iter() {
|
|
|
+ runtime_map.insert(step.to_string(), runtimes.get(step));
|
|
|
+ }
|
|
|
+
|
|
|
+ Self {
|
|
|
+ party_id: cli.party_id as usize,
|
|
|
+ log_db_size: cli.log_db_size,
|
|
|
+ preprocess: cli.preprocess,
|
|
|
+ threads: cli.threads,
|
|
|
+ comm_stats_preprocess: comm_stats_preprocess.clone(),
|
|
|
+ comm_stats_access: comm_stats_access.clone(),
|
|
|
+ runtimes: runtime_map,
|
|
|
+ meta: BenchmarkMetaData::collect(),
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
fn parse_log_db_size(s: &str) -> Result<u32, Box<dyn std::error::Error + Send + Sync + 'static>> {
|
|
|
let log_db_size: u32 = s.parse()?;
|
|
|
if log_db_size & 1 == 1 {
|
|
@@ -85,7 +133,7 @@ fn main() {
|
|
|
let cli = Cli::parse();
|
|
|
|
|
|
let mut netopts = NetworkOptions {
|
|
|
- listen_host: cli.listen_host,
|
|
|
+ listen_host: cli.listen_host.clone(),
|
|
|
listen_port: cli.listen_port,
|
|
|
connect_info: vec![NetworkPartyInfo::Listen; 3],
|
|
|
connect_timeout_seconds: cli.connect_timeout_seconds,
|
|
@@ -96,7 +144,7 @@ fn main() {
|
|
|
.build_global()
|
|
|
.unwrap();
|
|
|
|
|
|
- for c in cli.connect {
|
|
|
+ for c in cli.connect.iter() {
|
|
|
if netopts.connect_info[c.0] != NetworkPartyInfo::Listen {
|
|
|
println!(
|
|
|
"{}",
|
|
@@ -108,7 +156,7 @@ fn main() {
|
|
|
);
|
|
|
process::exit(1);
|
|
|
}
|
|
|
- netopts.connect_info[c.0] = NetworkPartyInfo::Connect(c.1, c.2);
|
|
|
+ netopts.connect_info[c.0] = NetworkPartyInfo::Connect(c.1.clone(), c.2);
|
|
|
}
|
|
|
|
|
|
let mut comm = match make_tcp_communicator(3, cli.party_id as usize, &netopts) {
|
|
@@ -145,13 +193,10 @@ fn main() {
|
|
|
]
|
|
|
};
|
|
|
|
|
|
- let t_start = Instant::now();
|
|
|
-
|
|
|
doram.init(&mut comm, &db_share).expect("init failed");
|
|
|
|
|
|
- let d_init = Instant::now() - t_start;
|
|
|
-
|
|
|
- let mut runtimes = Runtimes::default();
|
|
|
+ comm.reset_stats();
|
|
|
+ let mut runtimes = OramRuntimes::default();
|
|
|
|
|
|
let d_preprocess = if cli.preprocess {
|
|
|
let t_start = Instant::now();
|
|
@@ -166,6 +211,9 @@ fn main() {
|
|
|
Default::default()
|
|
|
};
|
|
|
|
|
|
+ let comm_stats_preprocess = comm.get_stats();
|
|
|
+ comm.reset_stats();
|
|
|
+
|
|
|
let t_start = Instant::now();
|
|
|
for (_i, inst) in instructions.iter().enumerate() {
|
|
|
// println!("executing instruction #{i}: {inst:?}");
|
|
@@ -177,36 +225,39 @@ fn main() {
|
|
|
}
|
|
|
let d_accesses = Instant::now() - t_start;
|
|
|
|
|
|
- println!(
|
|
|
- "time init: {:10.3} ms",
|
|
|
- d_init.as_secs_f64() * 1000.0
|
|
|
- );
|
|
|
- println!(
|
|
|
- "time preprocess: {:10.3} ms",
|
|
|
- d_preprocess.as_secs_f64() * 1000.0
|
|
|
- );
|
|
|
- println!(
|
|
|
- " per accesses: {:10.3} ms",
|
|
|
- d_preprocess.as_secs_f64() * 1000.0 / stash_size as f64
|
|
|
- );
|
|
|
- println!(
|
|
|
- "time accesses: {:10.3} ms{}",
|
|
|
- d_accesses.as_secs_f64() * 1000.0,
|
|
|
- if cli.preprocess {
|
|
|
- " (online only)"
|
|
|
- } else {
|
|
|
- ""
|
|
|
- }
|
|
|
- );
|
|
|
- println!(
|
|
|
- " per accesses: {:10.3} ms",
|
|
|
- d_accesses.as_secs_f64() * 1000.0 / stash_size as f64
|
|
|
- );
|
|
|
+ let comm_stats_access = comm.get_stats();
|
|
|
|
|
|
- let comm_stats = comm.get_stats();
|
|
|
comm.shutdown();
|
|
|
|
|
|
- runtimes.print(cli.party_id as usize + 1, stash_size);
|
|
|
+ let results =
|
|
|
+ BenchmarkResults::new(&cli, &comm_stats_preprocess, &comm_stats_access, &runtimes);
|
|
|
|
|
|
- println!("{comm_stats:#?}");
|
|
|
+ if cli.json {
|
|
|
+ println!("{}", serde_json::to_string(&results).unwrap());
|
|
|
+ } else {
|
|
|
+ println!(
|
|
|
+ "time preprocess: {:10.3} ms",
|
|
|
+ d_preprocess.as_secs_f64() * 1000.0
|
|
|
+ );
|
|
|
+ println!(
|
|
|
+ " per accesses: {:10.3} ms",
|
|
|
+ d_preprocess.as_secs_f64() * 1000.0 / stash_size as f64
|
|
|
+ );
|
|
|
+ println!(
|
|
|
+ "time accesses: {:10.3} ms{}",
|
|
|
+ d_accesses.as_secs_f64() * 1000.0,
|
|
|
+ if cli.preprocess {
|
|
|
+ " (online only)"
|
|
|
+ } else {
|
|
|
+ ""
|
|
|
+ }
|
|
|
+ );
|
|
|
+ println!(
|
|
|
+ " per accesses: {:10.3} ms",
|
|
|
+ d_accesses.as_secs_f64() * 1000.0 / stash_size as f64
|
|
|
+ );
|
|
|
+ runtimes.print(cli.party_id as usize + 1, stash_size);
|
|
|
+ println!("communication preprocessing: {comm_stats_preprocess:#?}");
|
|
|
+ println!("communication accesses: {comm_stats_access:#?}");
|
|
|
+ }
|
|
|
}
|