bench_doprf.rs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. use clap::{CommandFactory, Parser};
  2. use communicator::tcp::{make_tcp_communicator, NetworkOptions, NetworkPartyInfo};
  3. use communicator::AbstractCommunicator;
  4. use ff::Field;
  5. use oram::doprf::{
  6. DOPrfParty1, DOPrfParty2, DOPrfParty3, JointDOPrf, MaskedDOPrfParty1, MaskedDOPrfParty2,
  7. MaskedDOPrfParty3,
  8. };
  9. use rand::SeedableRng;
  10. use rand_chacha::ChaChaRng;
  11. use std::process;
  12. use std::time::{Duration, Instant};
  13. use utils::field::Fp;
  14. const PARTY_1: usize = 0;
  15. const PARTY_2: usize = 1;
  16. const PARTY_3: usize = 2;
// Protocol variant selected with `--mode`; `main` dispatches each variant to the matching
// `bench_*` function and prints the variant (via the derived `Display`) in the results
// header.
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments on
// `ValueEnum` variants into `--help` text.
#[derive(Clone, Copy, Debug, PartialEq, Eq, clap::ValueEnum, strum_macros::Display)]
enum Mode {
    // Every party runs all three `DOPrfParty*` roles, in an order rotated by its own id.
    Alternating,
    // All parties run the same `JointDOPrf` protocol object.
    Joint,
    // Each party runs the `MaskedDOPrfParty*` role matching its id.
    Masked,
    // Each party runs the `DOPrfParty*` role matching its id.
    Plain,
}
// Command-line options of the benchmark binary.
// NOTE: the `///` comments on the fields are rendered by clap's derive as `--help` text,
// so treat them as user-facing strings; implementation notes use plain `//` comments.
#[derive(Debug, clap::Parser)]
struct Cli {
    /// ID of this party
    // The value parser restricts this to 0, 1 or 2 (a three-party protocol).
    #[arg(long, short = 'i', value_parser = clap::value_parser!(u32).range(0..3))]
    pub party_id: u32,
    /// Output bitsize of the DOPrf
    // Passed unchanged to every protocol constructor (`new(bitsize)`).
    // NOTE(review): Joint mode converts outputs to `u128`; bitsizes above 128 are
    // presumably unsupported there — confirm.
    #[arg(long, short = 's', value_parser = clap::value_parser!(u32).range(1..))]
    pub bitsize: u32,
    /// Number of evaluations to compute
    // Must be at least 1, which also keeps the per-evaluation averages in `main` finite.
    #[arg(long, short = 'n', value_parser = clap::value_parser!(u32).range(1..))]
    pub num_evaluations: u32,
    /// Which protocol variant to benchmark
    #[arg(long, short = 'm', value_enum)]
    pub mode: Mode,
    /// Which address to listen on for incoming connections
    #[arg(long, short = 'l')]
    pub listen_host: String,
    /// Which port to listen on for incoming connections
    // `range(1..)` rejects port 0.
    #[arg(long, short = 'p', value_parser = clap::value_parser!(u16).range(1..))]
    pub listen_port: u16,
    /// Connection info for each party
    // May be repeated; each occurrence is parsed by `parse_connect`. Parties without a
    // connect entry are left as `NetworkPartyInfo::Listen` in `main`.
    // The value_name omits its outer angle brackets, presumably because clap wraps it in
    // `<...>` so help renders `<PARTY_ID>:<HOST>:<PORT>` — confirm if editing.
    #[arg(long, short = 'c', value_name = "PARTY_ID>:<HOST>:<PORT", value_parser = parse_connect)]
    pub connect: Vec<(usize, String, u16)>,
    /// How long to try connecting before aborting
    // Forwarded as `NetworkOptions::connect_timeout_seconds`; defaults to 10.
    #[arg(long, short = 't', default_value_t = 10)]
    pub connect_timeout_seconds: usize,
}
  51. fn parse_connect(
  52. s: &str,
  53. ) -> Result<(usize, String, u16), Box<dyn std::error::Error + Send + Sync + 'static>> {
  54. let parts: Vec<_> = s.split(":").collect();
  55. if parts.len() != 3 {
  56. return Err(clap::Error::raw(
  57. clap::error::ErrorKind::ValueValidation,
  58. format!("'{}' has not the format '<party-id>:<host>:<post>'", s),
  59. )
  60. .into());
  61. }
  62. let party_id: usize = parts[0].parse()?;
  63. let host = parts[1];
  64. let port: u16 = parts[2].parse()?;
  65. if port == 0 {
  66. return Err(clap::Error::raw(
  67. clap::error::ErrorKind::ValueValidation,
  68. "the port needs to be positive",
  69. )
  70. .into());
  71. }
  72. Ok((party_id, host.to_owned(), port))
  73. }
  74. fn make_random_shares(n: usize) -> Vec<Fp> {
  75. let mut rng = ChaChaRng::from_seed([0u8; 32]);
  76. (0..n).map(|_| Fp::random(&mut rng)).collect()
  77. }
  78. fn bench_plain<C: AbstractCommunicator>(
  79. comm: &mut C,
  80. bitsize: usize,
  81. num_evaluations: usize,
  82. ) -> (Duration, Duration, Duration) {
  83. let shares = make_random_shares(num_evaluations);
  84. match comm.get_my_id() {
  85. PARTY_1 => {
  86. let mut p1 = DOPrfParty1::<Fp>::new(bitsize);
  87. let t_start = Instant::now();
  88. p1.init(comm).expect("init failed");
  89. let t_after_init = Instant::now();
  90. p1.preprocess(comm, num_evaluations)
  91. .expect("preprocess failed");
  92. let t_after_preprocess = Instant::now();
  93. for i in 0..num_evaluations {
  94. p1.eval(comm, 1, &[shares[i]]).expect("eval failed");
  95. }
  96. let t_after_eval = Instant::now();
  97. (
  98. t_after_init - t_start,
  99. t_after_preprocess - t_after_init,
  100. t_after_eval - t_after_preprocess,
  101. )
  102. }
  103. PARTY_2 => {
  104. let mut p2 = DOPrfParty2::<Fp>::new(bitsize);
  105. let t_start = Instant::now();
  106. p2.init(comm).expect("init failed");
  107. let t_after_init = Instant::now();
  108. p2.preprocess(comm, num_evaluations)
  109. .expect("preprocess failed");
  110. let t_after_preprocess = Instant::now();
  111. for i in 0..num_evaluations {
  112. p2.eval(comm, 1, &[shares[i]]).expect("eval failed");
  113. }
  114. let t_after_eval = Instant::now();
  115. (
  116. t_after_init - t_start,
  117. t_after_preprocess - t_after_init,
  118. t_after_eval - t_after_preprocess,
  119. )
  120. }
  121. PARTY_3 => {
  122. let mut p3 = DOPrfParty3::<Fp>::new(bitsize);
  123. let t_start = Instant::now();
  124. p3.init(comm).expect("init failed");
  125. let t_after_init = Instant::now();
  126. p3.preprocess(comm, num_evaluations)
  127. .expect("preprocess failed");
  128. let t_after_preprocess = Instant::now();
  129. for i in 0..num_evaluations {
  130. p3.eval(comm, 1, &[shares[i]]).expect("eval failed");
  131. }
  132. let t_after_eval = Instant::now();
  133. (
  134. t_after_init - t_start,
  135. t_after_preprocess - t_after_init,
  136. t_after_eval - t_after_preprocess,
  137. )
  138. }
  139. _ => panic!("invalid party id"),
  140. }
  141. }
  142. fn bench_masked<C: AbstractCommunicator>(
  143. comm: &mut C,
  144. bitsize: usize,
  145. num_evaluations: usize,
  146. ) -> (Duration, Duration, Duration) {
  147. let shares = make_random_shares(num_evaluations);
  148. match comm.get_my_id() {
  149. PARTY_1 => {
  150. let mut p1 = MaskedDOPrfParty1::<Fp>::new(bitsize);
  151. let t_start = Instant::now();
  152. p1.init(comm).expect("init failed");
  153. let t_after_init = Instant::now();
  154. p1.preprocess(comm, num_evaluations)
  155. .expect("preprocess failed");
  156. let t_after_preprocess = Instant::now();
  157. for i in 0..num_evaluations {
  158. p1.eval(comm, 1, &[shares[i]]).expect("eval failed");
  159. }
  160. let t_after_eval = Instant::now();
  161. (
  162. t_after_init - t_start,
  163. t_after_preprocess - t_after_init,
  164. t_after_eval - t_after_preprocess,
  165. )
  166. }
  167. PARTY_2 => {
  168. let mut p2 = MaskedDOPrfParty2::<Fp>::new(bitsize);
  169. let t_start = Instant::now();
  170. p2.init(comm).expect("init failed");
  171. let t_after_init = Instant::now();
  172. p2.preprocess(comm, num_evaluations)
  173. .expect("preprocess failed");
  174. let t_after_preprocess = Instant::now();
  175. for i in 0..num_evaluations {
  176. p2.eval(comm, 1, &[shares[i]]).expect("eval failed");
  177. }
  178. let t_after_eval = Instant::now();
  179. (
  180. t_after_init - t_start,
  181. t_after_preprocess - t_after_init,
  182. t_after_eval - t_after_preprocess,
  183. )
  184. }
  185. PARTY_3 => {
  186. let mut p3 = MaskedDOPrfParty3::<Fp>::new(bitsize);
  187. let t_start = Instant::now();
  188. p3.init(comm).expect("init failed");
  189. let t_after_init = Instant::now();
  190. p3.preprocess(comm, num_evaluations)
  191. .expect("preprocess failed");
  192. let t_after_preprocess = Instant::now();
  193. for i in 0..num_evaluations {
  194. p3.eval(comm, 1, &[shares[i]]).expect("eval failed");
  195. }
  196. let t_after_eval = Instant::now();
  197. (
  198. t_after_init - t_start,
  199. t_after_preprocess - t_after_init,
  200. t_after_eval - t_after_preprocess,
  201. )
  202. }
  203. _ => panic!("invalid party id"),
  204. }
  205. }
  206. fn bench_joint<C: AbstractCommunicator>(
  207. comm: &mut C,
  208. bitsize: usize,
  209. num_evaluations: usize,
  210. ) -> (Duration, Duration, Duration) {
  211. let shares = make_random_shares(num_evaluations);
  212. let mut p = JointDOPrf::<Fp>::new(bitsize);
  213. let t_start = Instant::now();
  214. p.init(comm).expect("init failed");
  215. let t_after_init = Instant::now();
  216. p.preprocess(comm, num_evaluations)
  217. .expect("preprocess failed");
  218. let t_after_preprocess = Instant::now();
  219. for i in 0..num_evaluations {
  220. p.eval_to_uint::<_, u128>(comm, &[shares[i]])
  221. .expect("eval failed");
  222. }
  223. let t_after_eval = Instant::now();
  224. (
  225. t_after_init - t_start,
  226. t_after_preprocess - t_after_init,
  227. t_after_eval - t_after_preprocess,
  228. )
  229. }
  230. fn bench_alternating<C: AbstractCommunicator>(
  231. comm: &mut C,
  232. bitsize: usize,
  233. num_evaluations: usize,
  234. ) -> (Duration, Duration, Duration) {
  235. let shares = make_random_shares(num_evaluations);
  236. let mut p1 = DOPrfParty1::<Fp>::new(bitsize);
  237. let mut p2 = DOPrfParty2::<Fp>::new(bitsize);
  238. let mut p3 = DOPrfParty3::<Fp>::new(bitsize);
  239. match comm.get_my_id() {
  240. PARTY_1 => {
  241. let t_start = Instant::now();
  242. p1.init(comm).expect("init failed");
  243. p2.init(comm).expect("init failed");
  244. p3.init(comm).expect("init failed");
  245. let t_after_init = Instant::now();
  246. p1.preprocess(comm, num_evaluations)
  247. .expect("preprocess failed");
  248. p2.preprocess(comm, num_evaluations)
  249. .expect("preprocess failed");
  250. p3.preprocess(comm, num_evaluations)
  251. .expect("preprocess failed");
  252. let t_after_preprocess = Instant::now();
  253. for i in 0..num_evaluations {
  254. p1.eval(comm, 1, &[shares[i]]).expect("eval failed");
  255. p2.eval(comm, 1, &[shares[i]]).expect("eval failed");
  256. p3.eval(comm, 1, &[shares[i]]).expect("eval failed");
  257. }
  258. let t_after_eval = Instant::now();
  259. (
  260. t_after_init - t_start,
  261. t_after_preprocess - t_after_init,
  262. t_after_eval - t_after_preprocess,
  263. )
  264. }
  265. PARTY_2 => {
  266. let t_start = Instant::now();
  267. p2.init(comm).expect("init failed");
  268. p3.init(comm).expect("init failed");
  269. p1.init(comm).expect("init failed");
  270. let t_after_init = Instant::now();
  271. p2.preprocess(comm, num_evaluations)
  272. .expect("preprocess failed");
  273. p3.preprocess(comm, num_evaluations)
  274. .expect("preprocess failed");
  275. p1.preprocess(comm, num_evaluations)
  276. .expect("preprocess failed");
  277. let t_after_preprocess = Instant::now();
  278. for i in 0..num_evaluations {
  279. p2.eval(comm, 1, &[shares[i]]).expect("eval failed");
  280. p3.eval(comm, 1, &[shares[i]]).expect("eval failed");
  281. p1.eval(comm, 1, &[shares[i]]).expect("eval failed");
  282. }
  283. let t_after_eval = Instant::now();
  284. (
  285. t_after_init - t_start,
  286. t_after_preprocess - t_after_init,
  287. t_after_eval - t_after_preprocess,
  288. )
  289. }
  290. PARTY_3 => {
  291. let t_start = Instant::now();
  292. p3.init(comm).expect("init failed");
  293. p1.init(comm).expect("init failed");
  294. p2.init(comm).expect("init failed");
  295. let t_after_init = Instant::now();
  296. p3.preprocess(comm, num_evaluations)
  297. .expect("preprocess failed");
  298. p1.preprocess(comm, num_evaluations)
  299. .expect("preprocess failed");
  300. p2.preprocess(comm, num_evaluations)
  301. .expect("preprocess failed");
  302. let t_after_preprocess = Instant::now();
  303. for i in 0..num_evaluations {
  304. p3.eval(comm, 1, &[shares[i]]).expect("eval failed");
  305. p1.eval(comm, 1, &[shares[i]]).expect("eval failed");
  306. p2.eval(comm, 1, &[shares[i]]).expect("eval failed");
  307. }
  308. let t_after_eval = Instant::now();
  309. (
  310. t_after_init - t_start,
  311. t_after_preprocess - t_after_init,
  312. t_after_eval - t_after_preprocess,
  313. )
  314. }
  315. _ => panic!("invalid party id"),
  316. }
  317. }
  318. fn main() {
  319. let cli = Cli::parse();
  320. let mut netopts = NetworkOptions {
  321. listen_host: cli.listen_host,
  322. listen_port: cli.listen_port,
  323. connect_info: vec![NetworkPartyInfo::Listen; 3],
  324. connect_timeout_seconds: cli.connect_timeout_seconds,
  325. };
  326. for c in cli.connect {
  327. if netopts.connect_info[c.0] != NetworkPartyInfo::Listen {
  328. println!(
  329. "{}",
  330. clap::Error::raw(
  331. clap::error::ErrorKind::ValueValidation,
  332. format!("multiple connect arguments for party {}", c.0),
  333. )
  334. .format(&mut Cli::command())
  335. );
  336. process::exit(1);
  337. }
  338. netopts.connect_info[c.0] = NetworkPartyInfo::Connect(c.1, c.2);
  339. }
  340. let mut comm = match make_tcp_communicator(3, cli.party_id as usize, &netopts) {
  341. Ok(comm) => comm,
  342. Err(e) => {
  343. eprintln!("network setup failed: {:?}", e);
  344. process::exit(1);
  345. }
  346. };
  347. let (d_init, d_preprocess, d_eval) = match cli.mode {
  348. Mode::Plain => bench_plain(
  349. &mut comm,
  350. cli.bitsize as usize,
  351. cli.num_evaluations as usize,
  352. ),
  353. Mode::Masked => bench_masked(
  354. &mut comm,
  355. cli.bitsize as usize,
  356. cli.num_evaluations as usize,
  357. ),
  358. Mode::Joint => bench_joint(
  359. &mut comm,
  360. cli.bitsize as usize,
  361. cli.num_evaluations as usize,
  362. ),
  363. Mode::Alternating => bench_alternating(
  364. &mut comm,
  365. cli.bitsize as usize,
  366. cli.num_evaluations as usize,
  367. ),
  368. };
  369. comm.shutdown();
  370. println!("=========== DOPrf ============");
  371. println!("mode: {}", cli.mode);
  372. println!("- {} bit output", cli.bitsize);
  373. println!("- {} evaluations", cli.num_evaluations);
  374. println!("time init: {:3.3} s", d_init.as_secs_f64());
  375. println!("time preprocess: {:3.3} s", d_preprocess.as_secs_f64());
  376. println!(
  377. " per evaluation: {:3.3} s",
  378. d_preprocess.as_secs_f64() / cli.num_evaluations as f64
  379. );
  380. println!("time eval: {:3.3} s", d_eval.as_secs_f64());
  381. println!(
  382. " per evaluation: {:3.3} s",
  383. d_eval.as_secs_f64() / cli.num_evaluations as f64
  384. );
  385. println!("==============================");
  386. }