main.rs 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. // We really want points to be capital letters and scalars to be
  2. // lowercase letters
  3. #![allow(non_snake_case)]
  4. use std::env;
  5. use std::time::Instant;
  6. use spiral_rs::client::*;
  7. use spiral_rs::server::*;
  8. use spiral_spir::*;
  9. fn main() {
  10. let args: Vec<String> = env::args().collect();
  11. if args.len() < 2 || args.len() > 4 {
  12. println!("Usage: {} r [num_threads [num_pirs]]\nr = log_2(num_records)", args[0]);
  13. return;
  14. }
  15. let r: usize = args[1].parse().unwrap();
  16. let mut num_threads = 1usize;
  17. let mut num_pirs = 1usize;
  18. if args.len() > 2 {
  19. num_threads = args[2].parse().unwrap();
  20. }
  21. if args.len() > 3 {
  22. num_pirs = args[3].parse().unwrap();
  23. }
  24. let num_records = 1 << r;
  25. println!("===== ONE-TIME SETUP =====\n");
  26. let otsetup_start = Instant::now();
  27. init(num_threads);
  28. let otsetup_us = otsetup_start.elapsed().as_micros();
  29. println!("OT one-time setup: {} µs", otsetup_us);
  30. /*
  31. let otsetup_start = Instant::now();
  32. let spiral_params = params::get_spiral_params(r);
  33. let mut rng = rand::thread_rng();
  34. init(num_threads);
  35. let otsetup_us = otsetup_start.elapsed().as_micros();
  36. print_params_summary(&spiral_params);
  37. println!("OT one-time setup: {} µs", otsetup_us);
  38. // One-time setup for the Spiral client
  39. let spc_otsetup_start = Instant::now();
  40. let mut clientrng = rand::thread_rng();
  41. let mut client = Client::init(&spiral_params, &mut clientrng);
  42. let pub_params = client.generate_keys();
  43. let pub_params_buf = pub_params.serialize();
  44. let spc_otsetup_us = spc_otsetup_start.elapsed().as_micros();
  45. let spiral_blocking_factor = spiral_params.db_item_size / mem::size_of::<DbEntry>();
  46. println!(
  47. "Spiral client one-time setup: {} µs, {} bytes",
  48. spc_otsetup_us,
  49. pub_params_buf.len()
  50. );
  51. println!("\n===== PREPROCESSING =====\n");
  52. // Spiral preprocessing: create a PIR lookup for an element at a
  53. // random location
  54. let spc_query_start = Instant::now();
  55. let rand_idx = (rng.next_u64() as usize) % num_records;
  56. let rand_pir_idx = rand_idx / spiral_blocking_factor;
  57. println!("rand_idx = {} rand_pir_idx = {}", rand_idx, rand_pir_idx);
  58. let spc_query = client.generate_query(rand_pir_idx);
  59. let spc_query_buf = spc_query.serialize();
  60. let spc_query_us = spc_query_start.elapsed().as_micros();
  61. println!(
  62. "Spiral query: {} µs, {} bytes",
  63. spc_query_us,
  64. spc_query_buf.len()
  65. );
  66. // Create the database encryption keys and do the OT to fetch the
  67. // right one, but don't actually encrypt the database yet
  68. let dbkeys = gen_db_enc_keys(r);
  69. let otkeyreq_start = Instant::now();
  70. let (keystate, keyquery) = otkey_request(rand_idx, r);
  71. let keyquerysize = keyquery.len() * keyquery[0].len();
  72. let otkeyreq_us = otkeyreq_start.elapsed().as_micros();
  73. let otkeysrv_start = Instant::now();
  74. let keyresponse = otkey_serve(keyquery, &dbkeys);
  75. let keyrespsize = keyresponse.len() * keyresponse[0].len();
  76. let otkeysrv_us = otkeysrv_start.elapsed().as_micros();
  77. let otkeyrcv_start = Instant::now();
  78. let otkey = otkey_receive(keystate, &keyresponse);
  79. let otkeyrcv_us = otkeyrcv_start.elapsed().as_micros();
  80. println!("key OT query in {} µs, {} bytes", otkeyreq_us, keyquerysize);
  81. println!("key OT serve in {} µs, {} bytes", otkeysrv_us, keyrespsize);
  82. println!("key OT receive in {} µs", otkeyrcv_us);
  83. // Create a database with recognizable contents
  84. let db: Vec<DbEntry> = ((0 as DbEntry)..(num_records as DbEntry))
  85. .map(|x| 10000001 * x)
  86. .collect();
  87. println!("\n===== RUNTIME =====\n");
  88. // Pick the record we actually want to query
  89. let q = (rng.next_u64() as usize) % num_records;
  90. // Compute the offset from the record index we're actually looking
  91. // for to the random one we picked earlier. Tell it to the server,
  92. // who will rotate right the database by that amount before
  93. // encrypting it.
  94. let idx_offset = (num_records + rand_idx - q) % num_records;
  95. println!("Send to server {} bytes", 8 /* sizeof(idx_offset) */);
  96. // The server rotates, blinds, and encrypts the database
  97. let blind: DbEntry = 20;
  98. let encdb_start = Instant::now();
  99. let encdb = encdb_xor_keys(&db, &dbkeys, r, idx_offset, blind, num_threads);
  100. let encdb_us = encdb_start.elapsed().as_micros();
  101. println!("Server encrypt database {} µs", encdb_us);
  102. // Load the encrypted database into Spiral
  103. let sps_loaddb_start = Instant::now();
  104. let sps_db = load_db_from_slice_mt(&spiral_params, &encdb, num_threads);
  105. let sps_loaddb_us = sps_loaddb_start.elapsed().as_micros();
  106. println!("Server load database {} µs", sps_loaddb_us);
  107. // Do the PIR query
  108. let sps_query_start = Instant::now();
  109. let sps_query = Query::deserialize(&spiral_params, &spc_query_buf);
  110. let sps_response = process_query(&spiral_params, &pub_params, &sps_query, sps_db.as_slice());
  111. let sps_query_us = sps_query_start.elapsed().as_micros();
  112. println!(
  113. "Server compute response {} µs, {} bytes (*including* the above expansion time)",
  114. sps_query_us,
  115. sps_response.len()
  116. );
  117. // Decode the response to yield the whole Spiral block
  118. let spc_recv_start = Instant::now();
  119. let encdbblock = client.decode_response(sps_response.as_slice());
  120. // Extract the one encrypted DbEntry we were looking for (and the
  121. // only one we are able to decrypt)
  122. let entry_in_block = rand_idx % spiral_blocking_factor;
  123. let loc_in_block = entry_in_block * mem::size_of::<DbEntry>();
  124. let loc_in_block_end = (entry_in_block + 1) * mem::size_of::<DbEntry>();
  125. let encdbentry = DbEntry::from_le_bytes(
  126. encdbblock[loc_in_block..loc_in_block_end]
  127. .try_into()
  128. .unwrap(),
  129. );
  130. let decdbentry = otkey_decrypt(&otkey, rand_idx, encdbentry);
  131. let spc_recv_us = spc_recv_start.elapsed().as_micros();
  132. println!("Client decode response {} µs", spc_recv_us);
  133. println!("index = {}, Response = {}", q, decdbentry);
  134. */
  135. }