Browse Source

oram: improve benchmark output

Lennart Braun 2 years ago
parent
commit
75c79e1052
5 changed files with 190 additions and 44 deletions
  1. 3 0
      oram/Cargo.toml
  2. 90 39
      oram/examples/bench_doram.rs
  3. 1 0
      oram/src/lib.rs
  4. 7 5
      oram/src/oram.rs
  5. 89 0
      oram/src/tools.rs

+ 3 - 0
oram/Cargo.toml

@@ -13,12 +13,14 @@ bincode = "2.0.0-rc.2"
 bitvec = "1.0.1"
 ff = "0.13.0"
 funty = "2.0.0"
+git-version = "0.3.5"
 itertools = "0.10.5"
 num-bigint = "0.4.3"
 num-traits = "0.2.15"
 rand = "0.8.5"
 rand_chacha = "0.3.1"
 rayon = "1.6.1"
+serde = { version = "1.0", features = ["derive"] }
 strum = { version = "0.24.1", features = ["derive"] }
 strum_macros = "0.24"
 
@@ -26,6 +28,7 @@ strum_macros = "0.24"
 cuckoo = { path = "../cuckoo" }
 clap = { version = "4.1.4", features = ["derive"] }
 criterion = "0.4.0"
+serde_json = "1.0"
 
 [[bench]]
 name = "doprf"

+ 90 - 39
oram/examples/bench_doram.rs

@@ -1,17 +1,25 @@
 use clap::{CommandFactory, Parser};
 use communicator::tcp::{make_tcp_communicator, NetworkOptions, NetworkPartyInfo};
-use communicator::AbstractCommunicator;
+use communicator::{AbstractCommunicator, CommunicationStats};
 use cuckoo::hash::AesHashFunction;
 use dpf::mpdpf::SmartMpDpf;
 use dpf::spdpf::HalfTreeSpDpf;
 use ff::{Field, PrimeField};
 use oram::common::{InstructionShare, Operation};
-use oram::oram::{DistributedOram, DistributedOramProtocol, Runtimes};
+use oram::oram::{
+    DistributedOram, DistributedOramProtocol, ProtocolStep as OramProtocolStep,
+    Runtimes as OramRuntimes,
+};
+use oram::tools::BenchmarkMetaData;
 use rand::{Rng, SeedableRng};
 use rand_chacha::ChaChaRng;
 use rayon;
+use serde;
+use serde_json;
+use std::collections::HashMap;
 use std::process;
-use std::time::Instant;
+use std::time::{Duration, Instant};
+use strum::IntoEnumIterator;
 use utils::field::Fp;
 
 type MPDPF = SmartMpDpf<Fp, HalfTreeSpDpf<Fp>, AesHashFunction<u16>>;
@@ -31,6 +39,9 @@ struct Cli {
     /// How many threads to use for the computation
     #[arg(long, short = 't', default_value_t = 1)]
     pub threads: usize,
+    /// Output statistics in JSON
+    #[arg(long, short = 'j')]
+    pub json: bool,
     /// Which address to listen on for incoming connections
     #[arg(long, short = 'l')]
     pub listen_host: String,
@@ -45,6 +56,43 @@ struct Cli {
     pub connect_timeout_seconds: usize,
 }
 
+#[derive(Debug, Clone, serde::Serialize)]
+struct BenchmarkResults {
+    party_id: usize,
+    log_db_size: u32,
+    preprocess: bool,
+    threads: usize,
+    comm_stats_preprocess: HashMap<usize, CommunicationStats>,
+    comm_stats_access: HashMap<usize, CommunicationStats>,
+    runtimes: HashMap<String, Duration>,
+    meta: BenchmarkMetaData,
+}
+
+impl BenchmarkResults {
+    pub fn new(
+        cli: &Cli,
+        comm_stats_preprocess: &HashMap<usize, CommunicationStats>,
+        comm_stats_access: &HashMap<usize, CommunicationStats>,
+        runtimes: &OramRuntimes,
+    ) -> Self {
+        let mut runtime_map = HashMap::new();
+        for step in OramProtocolStep::iter() {
+            runtime_map.insert(step.to_string(), runtimes.get(step));
+        }
+
+        Self {
+            party_id: cli.party_id as usize,
+            log_db_size: cli.log_db_size,
+            preprocess: cli.preprocess,
+            threads: cli.threads,
+            comm_stats_preprocess: comm_stats_preprocess.clone(),
+            comm_stats_access: comm_stats_access.clone(),
+            runtimes: runtime_map,
+            meta: BenchmarkMetaData::collect(),
+        }
+    }
+}
+
 fn parse_log_db_size(s: &str) -> Result<u32, Box<dyn std::error::Error + Send + Sync + 'static>> {
     let log_db_size: u32 = s.parse()?;
     if log_db_size & 1 == 1 {
@@ -85,7 +133,7 @@ fn main() {
     let cli = Cli::parse();
 
     let mut netopts = NetworkOptions {
-        listen_host: cli.listen_host,
+        listen_host: cli.listen_host.clone(),
         listen_port: cli.listen_port,
         connect_info: vec![NetworkPartyInfo::Listen; 3],
         connect_timeout_seconds: cli.connect_timeout_seconds,
@@ -96,7 +144,7 @@ fn main() {
         .build_global()
         .unwrap();
 
-    for c in cli.connect {
+    for c in cli.connect.iter() {
         if netopts.connect_info[c.0] != NetworkPartyInfo::Listen {
             println!(
                 "{}",
@@ -108,7 +156,7 @@ fn main() {
             );
             process::exit(1);
         }
-        netopts.connect_info[c.0] = NetworkPartyInfo::Connect(c.1, c.2);
+        netopts.connect_info[c.0] = NetworkPartyInfo::Connect(c.1.clone(), c.2);
     }
 
     let mut comm = match make_tcp_communicator(3, cli.party_id as usize, &netopts) {
@@ -145,13 +193,10 @@ fn main() {
         ]
     };
 
-    let t_start = Instant::now();
-
     doram.init(&mut comm, &db_share).expect("init failed");
 
-    let d_init = Instant::now() - t_start;
-
-    let mut runtimes = Runtimes::default();
+    comm.reset_stats();
+    let mut runtimes = OramRuntimes::default();
 
     let d_preprocess = if cli.preprocess {
         let t_start = Instant::now();
@@ -166,6 +211,9 @@ fn main() {
         Default::default()
     };
 
+    let comm_stats_preprocess = comm.get_stats();
+    comm.reset_stats();
+
     let t_start = Instant::now();
     for (_i, inst) in instructions.iter().enumerate() {
         // println!("executing instruction #{i}: {inst:?}");
@@ -177,36 +225,39 @@ fn main() {
     }
     let d_accesses = Instant::now() - t_start;
 
-    println!(
-        "time init:        {:10.3} ms",
-        d_init.as_secs_f64() * 1000.0
-    );
-    println!(
-        "time preprocess:  {:10.3} ms",
-        d_preprocess.as_secs_f64() * 1000.0
-    );
-    println!(
-        "   per accesses:  {:10.3} ms",
-        d_preprocess.as_secs_f64() * 1000.0 / stash_size as f64
-    );
-    println!(
-        "time accesses:    {:10.3} ms{}",
-        d_accesses.as_secs_f64() * 1000.0,
-        if cli.preprocess {
-            "  (online only)"
-        } else {
-            ""
-        }
-    );
-    println!(
-        "   per accesses:  {:10.3} ms",
-        d_accesses.as_secs_f64() * 1000.0 / stash_size as f64
-    );
+    let comm_stats_access = comm.get_stats();
 
-    let comm_stats = comm.get_stats();
     comm.shutdown();
 
-    runtimes.print(cli.party_id as usize + 1, stash_size);
+    let results =
+        BenchmarkResults::new(&cli, &comm_stats_preprocess, &comm_stats_access, &runtimes);
 
-    println!("{comm_stats:#?}");
+    if cli.json {
+        println!("{}", serde_json::to_string(&results).unwrap());
+    } else {
+        println!(
+            "time preprocess:  {:10.3} ms",
+            d_preprocess.as_secs_f64() * 1000.0
+        );
+        println!(
+            "   per accesses:  {:10.3} ms",
+            d_preprocess.as_secs_f64() * 1000.0 / stash_size as f64
+        );
+        println!(
+            "time accesses:    {:10.3} ms{}",
+            d_accesses.as_secs_f64() * 1000.0,
+            if cli.preprocess {
+                "  (online only)"
+            } else {
+                ""
+            }
+        );
+        println!(
+            "   per accesses:  {:10.3} ms",
+            d_accesses.as_secs_f64() * 1000.0 / stash_size as f64
+        );
+        runtimes.print(cli.party_id as usize + 1, stash_size);
+        println!("communication preprocessing: {comm_stats_preprocess:#?}");
+        println!("communication accesses: {comm_stats_access:#?}");
+    }
 }

+ 1 - 0
oram/src/lib.rs

@@ -5,3 +5,4 @@ pub mod oram;
 pub mod p_ot;
 pub mod select;
 pub mod stash;
+pub mod tools;

+ 7 - 5
oram/src/oram.rs

@@ -69,6 +69,7 @@ pub enum ProtocolStep {
     PreprocessDOPrf,
     PreprocessPOt,
     PreprocessSelect,
+    Access,
     AccessStashRead,
     AccessAddressSelection,
     AccessDatabaseRead,
@@ -91,7 +92,7 @@ pub enum ProtocolStep {
 
 #[derive(Debug, Default, Clone, Copy)]
 pub struct Runtimes {
-    durations: [Duration; 29],
+    durations: [Duration; 30],
     stash_runtimes: StashRuntimes,
 }
 
@@ -675,10 +676,6 @@ where
         );
 
         let runtimes = runtimes.map(|mut r| {
-            r.record(
-                ProtocolStep::Preprocess,
-                t_after_receiving_index_tags_mine - t_start,
-            );
             r.record(
                 ProtocolStep::PreprocessLPRFKeyGenPrev,
                 t_after_gen_lpks_prev - t_start,
@@ -719,6 +716,10 @@ where
                 ProtocolStep::PreprocessSelect,
                 t_after_preprocess_select - t_after_preprocess_pot,
             );
+            r.record(
+                ProtocolStep::Preprocess,
+                t_after_preprocess_select - t_start,
+            );
             r
         });
 
@@ -959,6 +960,7 @@ where
                 ProtocolStep::AccessRefresh,
                 t_after_refresh - t_after_value_selection,
             );
+            r.record(ProtocolStep::Access, t_after_refresh - t_start);
             r
         });
 

+ 89 - 0
oram/src/tools.rs

@@ -0,0 +1,89 @@
+use git_version::git_version;
+use serde::Serialize;
+use std::fs::{read_to_string, File};
+use std::io::{BufRead, BufReader};
+use std::process;
+
+#[derive(Clone, Debug, Serialize)]
+pub struct BenchmarkMetaData {
+    pub hostname: String,
+    pub username: String,
+    pub timestamp: String,
+    pub cmdline: Vec<String>,
+    pub pid: u32,
+    pub git_version: String,
+}
+
+impl BenchmarkMetaData {
+    pub fn collect() -> Self {
+        BenchmarkMetaData {
+            hostname: get_hostname(),
+            username: get_username(),
+            timestamp: get_timestamp(),
+            cmdline: get_cmdline(),
+            pid: get_pid(),
+            git_version: git_version!(args = ["--abbrev=40", "--always", "--dirty"]).to_string(),
+        }
+    }
+}
+
+pub fn run_command_with_args(cmd: &str, args: &[&str]) -> String {
+    String::from_utf8(
+        process::Command::new(cmd)
+            .args(args)
+            .output()
+            .expect("process failed")
+            .stdout,
+    )
+    .expect("utf-8 decoding failed")
+    .trim()
+    .to_string()
+}
+
+pub fn run_command(cmd: &str) -> String {
+    String::from_utf8(
+        process::Command::new(cmd)
+            .output()
+            .expect("process failed")
+            .stdout,
+    )
+    .expect("utf-8 decoding failed")
+    .trim()
+    .to_string()
+}
+
+pub fn read_file(path: &str) -> String {
+    read_to_string(path).expect("read_to_string failed")
+}
+
+pub fn get_username() -> String {
+    run_command("whoami")
+}
+
+pub fn get_hostname() -> String {
+    read_file("/proc/sys/kernel/hostname").trim().to_string()
+}
+
+pub fn get_timestamp() -> String {
+    run_command_with_args("date", &["--iso-8601=s"])
+}
+
+pub fn get_cmdline() -> Vec<String> {
+    let f = File::open("/proc/self/cmdline").expect("cannot open file");
+    let mut reader = BufReader::new(f);
+    let mut cmdline: Vec<String> = Vec::new();
+    loop {
+        let mut bytes = Vec::<u8>::new();
+        let num_bytes = reader.read_until(0, &mut bytes).expect("read failed");
+        if num_bytes == 0 {
+            break;
+        }
+        bytes.pop();
+        cmdline.push(String::from_utf8(bytes).expect("utf-8 decoding failed"))
+    }
+    cmdline
+}
+
+pub fn get_pid() -> u32 {
+    process::id()
+}