main.rs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. // We really want points to be capital letters and scalars to be
  2. // lowercase letters
  3. #![allow(non_snake_case)]
  4. pub mod aligned_memory_mt;
  5. pub mod params;
  6. pub mod spiral_mt;
  7. use aes::cipher::{BlockEncrypt, KeyInit};
  8. use aes::Aes128Enc;
  9. use aes::Block;
  10. use std::env;
  11. use std::mem;
  12. use std::time::Instant;
  13. use subtle::Choice;
  14. use subtle::ConditionallySelectable;
  15. use rand::RngCore;
  16. use sha2::Digest;
  17. use sha2::Sha256;
  18. use sha2::Sha512;
  19. use curve25519_dalek::constants as dalek_constants;
  20. use curve25519_dalek::ristretto::CompressedRistretto;
  21. use curve25519_dalek::ristretto::RistrettoBasepointTable;
  22. use curve25519_dalek::ristretto::RistrettoPoint;
  23. use curve25519_dalek::scalar::Scalar;
  24. use crossbeam::thread;
  25. use spiral_rs::client::*;
  26. use spiral_rs::params::*;
  27. use spiral_rs::server::*;
  28. use crate::spiral_mt::*;
  29. use lazy_static::lazy_static;
  30. type DbEntry = u64;
  31. // Generators of the Ristretto group (the standard B and another one C,
  32. // for which the DL relationship is unknown), and their precomputed
  33. // multiplication tables. Used for the Oblivious Transfer protocol
  34. lazy_static! {
  35. pub static ref OT_B: RistrettoPoint = dalek_constants::RISTRETTO_BASEPOINT_POINT;
  36. pub static ref OT_C: RistrettoPoint =
  37. RistrettoPoint::hash_from_bytes::<Sha512>(b"OT Generator C");
  38. pub static ref OT_B_TABLE: RistrettoBasepointTable = dalek_constants::RISTRETTO_BASEPOINT_TABLE;
  39. pub static ref OT_C_TABLE: RistrettoBasepointTable = RistrettoBasepointTable::create(&OT_C);
  40. }
  41. // XOR a 16-byte slice into a Block (which will be used as an AES key)
  42. fn xor16(outar: &mut Block, inar: &[u8; 16]) {
  43. for i in 0..16 {
  44. outar[i] ^= inar[i];
  45. }
  46. }
  47. // Encrypt a database of 2^r elements, where each element is a DbEntry,
  48. // using the 2*r provided keys (r pairs of keys). Also add the provided
  49. // blinding factor to each element before encryption (the same blinding
  50. // factor for all elements). Each element is encrypted in AES counter
  51. // mode, with the counter being the element number and the key computed
  52. // as the XOR of r of the provided keys, one from each pair, according
  53. // to the bits of the element number. Outputs a byte vector containing
  54. // the encrypted database.
  55. fn encdb_xor_keys(
  56. db: &[DbEntry],
  57. keys: &[[u8; 16]],
  58. r: usize,
  59. blind: DbEntry,
  60. num_threads: usize,
  61. ) -> Vec<u8> {
  62. let num_records: usize = 1 << r;
  63. let mut ret = Vec::<u8>::with_capacity(num_records * mem::size_of::<DbEntry>());
  64. ret.resize(num_records * mem::size_of::<DbEntry>(), 0);
  65. thread::scope(|s| {
  66. let mut record_thread_start = 0usize;
  67. let records_per_thread_base = num_records / num_threads;
  68. let records_per_thread_extra = num_records % num_threads;
  69. let mut retslice = ret.as_mut_slice();
  70. for thr in 0..num_threads {
  71. let records_this_thread =
  72. records_per_thread_base + if thr < records_per_thread_extra { 1 } else { 0 };
  73. let record_thread_end = record_thread_start + records_this_thread;
  74. let (thread_ret, retslice_) =
  75. retslice.split_at_mut(records_this_thread * mem::size_of::<DbEntry>());
  76. retslice = retslice_;
  77. s.spawn(move |_| {
  78. let mut offset = 0usize;
  79. for j in record_thread_start..record_thread_end {
  80. let mut key = Block::from([0u8; 16]);
  81. for i in 0..r {
  82. let bit = if (j & (1 << i)) == 0 { 0 } else { 1 };
  83. xor16(&mut key, &keys[2 * i + bit]);
  84. }
  85. let aes = Aes128Enc::new(&key);
  86. let mut block = Block::from([0u8; 16]);
  87. block[0..8].copy_from_slice(&j.to_le_bytes());
  88. aes.encrypt_block(&mut block);
  89. let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
  90. let encelem = (db[j].wrapping_add(blind)) ^ aeskeystream;
  91. thread_ret[offset..offset + mem::size_of::<DbEntry>()]
  92. .copy_from_slice(&encelem.to_le_bytes());
  93. offset += mem::size_of::<DbEntry>();
  94. }
  95. });
  96. record_thread_start = record_thread_end;
  97. }
  98. })
  99. .unwrap();
  100. ret
  101. }
  102. // Generate the keys for encrypting the database
  103. fn gen_db_enc_keys(r: usize) -> Vec<[u8; 16]> {
  104. let mut keys: Vec<[u8; 16]> = Vec::new();
  105. let mut rng = rand::thread_rng();
  106. for _ in 0..2 * r {
  107. let mut k: [u8; 16] = [0; 16];
  108. rng.fill_bytes(&mut k);
  109. keys.push(k);
  110. }
  111. keys
  112. }
  113. // 1-out-of-2 Oblivious Transfer (OT)
  114. fn ot12_request(sel: Choice) -> ((Choice, Scalar), [u8; 32]) {
  115. let Btable: &RistrettoBasepointTable = &OT_B_TABLE;
  116. let C: &RistrettoPoint = &OT_C;
  117. let mut rng = rand07::thread_rng();
  118. let x = Scalar::random(&mut rng);
  119. let xB = &x * Btable;
  120. let CmxB = C - xB;
  121. let P = RistrettoPoint::conditional_select(&xB, &CmxB, sel);
  122. ((sel, x), P.compress().to_bytes())
  123. }
  124. fn ot12_serve(query: &[u8; 32], m0: &[u8; 16], m1: &[u8; 16]) -> [u8; 64] {
  125. let Btable: &RistrettoBasepointTable = &OT_B_TABLE;
  126. let Ctable: &RistrettoBasepointTable = &OT_C_TABLE;
  127. let mut rng = rand07::thread_rng();
  128. let y = Scalar::random(&mut rng);
  129. let yB = &y * Btable;
  130. let yC = &y * Ctable;
  131. let P = CompressedRistretto::from_slice(query).decompress().unwrap();
  132. let yP0 = y * P;
  133. let yP1 = yC - yP0;
  134. let mut HyP0 = Sha256::digest(yP0.compress().as_bytes());
  135. for i in 0..16 {
  136. HyP0[i] ^= m0[i];
  137. }
  138. let mut HyP1 = Sha256::digest(yP1.compress().as_bytes());
  139. for i in 0..16 {
  140. HyP1[i] ^= m1[i];
  141. }
  142. let mut ret = [0u8; 64];
  143. ret[0..32].copy_from_slice(yB.compress().as_bytes());
  144. ret[32..48].copy_from_slice(&HyP0[0..16]);
  145. ret[48..64].copy_from_slice(&HyP1[0..16]);
  146. ret
  147. }
  148. fn ot12_receive(state: (Choice, Scalar), response: &[u8; 64]) -> [u8; 16] {
  149. let yB = CompressedRistretto::from_slice(&response[0..32])
  150. .decompress()
  151. .unwrap();
  152. let yP = state.1 * yB;
  153. let mut HyP = Sha256::digest(yP.compress().as_bytes());
  154. for i in 0..16 {
  155. HyP[i] ^= u8::conditional_select(&response[32 + i], &response[48 + i], state.0);
  156. }
  157. HyP[0..16].try_into().unwrap()
  158. }
  159. // Obliviously fetch the key for element q of the database (which has
  160. // 2^r elements total). Each bit of q is used in a 1-out-of-2 OT to get
  161. // one of the keys in each of the r pairs of keys on the server side.
  162. // The resulting r keys are XORed together.
  163. fn otkey_request(q: usize, r: usize) -> (Vec<(Choice, Scalar)>, Vec<[u8; 32]>) {
  164. let mut state: Vec<(Choice, Scalar)> = Vec::with_capacity(r);
  165. let mut query: Vec<[u8; 32]> = Vec::with_capacity(r);
  166. for i in 0..r {
  167. let bit = ((q >> i) & 1) as u8;
  168. let (si, qi) = ot12_request(bit.into());
  169. state.push(si);
  170. query.push(qi);
  171. }
  172. (state, query)
  173. }
  174. fn otkey_serve(query: Vec<[u8; 32]>, keys: &Vec<[u8; 16]>) -> Vec<[u8; 64]> {
  175. let r = query.len();
  176. assert!(keys.len() == 2 * r);
  177. let mut response: Vec<[u8; 64]> = Vec::with_capacity(r);
  178. for i in 0..r {
  179. response.push(ot12_serve(&query[i], &keys[2 * i], &keys[2 * i + 1]));
  180. }
  181. response
  182. }
  183. fn otkey_receive(state: Vec<(Choice, Scalar)>, response: &Vec<[u8; 64]>) -> Block {
  184. let r = state.len();
  185. assert!(response.len() == r);
  186. let mut key = Block::from([0u8; 16]);
  187. for i in 0..r {
  188. xor16(&mut key, &ot12_receive(state[i], &response[i]));
  189. }
  190. key
  191. }
  192. // Having received the key for element q with r parallel 1-out-of-2 OTs,
  193. // and having received the encrypted element with (non-symmetric) PIR,
  194. // use the key to decrypt the element.
  195. fn otkey_decrypt(key: &Block, q: usize, encelement: DbEntry) -> DbEntry {
  196. let aes = Aes128Enc::new(key);
  197. let mut block = Block::from([0u8; 16]);
  198. block[0..8].copy_from_slice(&q.to_le_bytes());
  199. aes.encrypt_block(&mut block);
  200. let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
  201. encelement ^ aeskeystream
  202. }
  203. // Things that are only done once total, not once for each SPIR
  204. fn one_time_setup() {
  205. // Resolve the lazy statics
  206. let _B: &RistrettoPoint = &OT_B;
  207. let _Btable: &RistrettoBasepointTable = &OT_B_TABLE;
  208. let _C: &RistrettoPoint = &OT_C;
  209. let _Ctable: &RistrettoBasepointTable = &OT_C_TABLE;
  210. }
  211. fn print_params_summary(params: &Params) {
  212. let db_elem_size = params.item_size();
  213. let total_size = params.num_items() * db_elem_size;
  214. println!(
  215. "Using a {} x {} byte database ({} bytes total)",
  216. params.num_items(),
  217. db_elem_size,
  218. total_size
  219. );
  220. }
  221. fn main() {
  222. let args: Vec<String> = env::args().collect();
  223. if args.len() != 2 && args.len() != 3 {
  224. println!("Usage: {} r [num_threads]\nr = log_2(num_records)", args[0]);
  225. return;
  226. }
  227. let r: usize = args[1].parse().unwrap();
  228. let mut num_threads = 1usize;
  229. if args.len() == 3 {
  230. num_threads = args[2].parse().unwrap();
  231. }
  232. let num_records = 1 << r;
  233. println!("===== ONE-TIME SETUP =====\n");
  234. let otsetup_start = Instant::now();
  235. let spiral_params = params::get_spiral_params(r);
  236. let mut rng = rand::thread_rng();
  237. one_time_setup();
  238. let otsetup_us = otsetup_start.elapsed().as_micros();
  239. print_params_summary(&spiral_params);
  240. println!("OT one-time setup: {} µs", otsetup_us);
  241. // One-time setup for the Spiral client
  242. let spc_otsetup_start = Instant::now();
  243. let mut clientrng = rand::thread_rng();
  244. let mut client = Client::init(&spiral_params, &mut clientrng);
  245. let pub_params = client.generate_keys();
  246. let pub_params_buf = pub_params.serialize();
  247. let spc_otsetup_us = spc_otsetup_start.elapsed().as_micros();
  248. let spiral_blocking_factor = spiral_params.db_item_size / mem::size_of::<DbEntry>();
  249. println!(
  250. "Spiral client one-time setup: {} µs, {} bytes",
  251. spc_otsetup_us,
  252. pub_params_buf.len()
  253. );
  254. println!("\n===== PREPROCESSING =====\n");
  255. // Spiral preprocessing: create a PIR lookup for an element at a
  256. // random location
  257. let spc_query_start = Instant::now();
  258. let rand_idx = (rng.next_u64() as usize) % num_records;
  259. let rand_pir_idx = rand_idx / spiral_blocking_factor;
  260. println!("rand_idx = {} rand_pir_idx = {}", rand_idx, rand_pir_idx);
  261. let spc_query = client.generate_query(rand_pir_idx);
  262. let spc_query_buf = spc_query.serialize();
  263. let spc_query_us = spc_query_start.elapsed().as_micros();
  264. println!(
  265. "Spiral query: {} µs, {} bytes",
  266. spc_query_us,
  267. spc_query_buf.len()
  268. );
  269. // Create the database encryption keys and do the OT to fetch the
  270. // right one, but don't actually encrypt the database yet
  271. let dbkeys = gen_db_enc_keys(r);
  272. let otkeyreq_start = Instant::now();
  273. let (keystate, keyquery) = otkey_request(rand_idx, r);
  274. let keyquerysize = keyquery.len() * keyquery[0].len();
  275. let otkeyreq_us = otkeyreq_start.elapsed().as_micros();
  276. let otkeysrv_start = Instant::now();
  277. let keyresponse = otkey_serve(keyquery, &dbkeys);
  278. let keyrespsize = keyresponse.len() * keyresponse[0].len();
  279. let otkeysrv_us = otkeysrv_start.elapsed().as_micros();
  280. let otkeyrcv_start = Instant::now();
  281. let otkey = otkey_receive(keystate, &keyresponse);
  282. let otkeyrcv_us = otkeyrcv_start.elapsed().as_micros();
  283. println!("key OT query in {} µs, {} bytes", otkeyreq_us, keyquerysize);
  284. println!("key OT serve in {} µs, {} bytes", otkeysrv_us, keyrespsize);
  285. println!("key OT receive in {} µs", otkeyrcv_us);
  286. // Create a database with recognizable contents
  287. let mut db: Vec<DbEntry> = ((0 as DbEntry)..(num_records as DbEntry))
  288. .map(|x| 10000001 * x)
  289. .collect();
  290. println!("\n===== RUNTIME =====\n");
  291. // Pick the record we actually want to query
  292. let q = (rng.next_u64() as usize) % num_records;
  293. // Compute the offset from the record index we're actually looking
  294. // for to the random one we picked earlier. Tell it to the server,
  295. // who will rotate right the database by that amount before
  296. // encrypting it.
  297. let idx_offset = (num_records + rand_idx - q) % num_records;
  298. println!("Send to server {} bytes", 8 /* sizeof(idx_offset) */);
  299. // The server rotates, blinds, and encrypts the database
  300. let blind: DbEntry = 20;
  301. let encdb_start = Instant::now();
  302. db.rotate_right(idx_offset);
  303. let encdb = encdb_xor_keys(&db, &dbkeys, r, blind, num_threads);
  304. let encdb_us = encdb_start.elapsed().as_micros();
  305. println!("Server encrypt database {} µs", encdb_us);
  306. // Load the encrypted database into Spiral
  307. let sps_loaddb_start = Instant::now();
  308. let sps_db = load_db_from_slice_mt(&spiral_params, &encdb, num_threads);
  309. let sps_loaddb_us = sps_loaddb_start.elapsed().as_micros();
  310. println!("Server load database {} µs", sps_loaddb_us);
  311. // Do the PIR query
  312. let sps_query_start = Instant::now();
  313. let sps_query = Query::deserialize(&spiral_params, &spc_query_buf);
  314. let sps_response = process_query(&spiral_params, &pub_params, &sps_query, sps_db.as_slice());
  315. let sps_query_us = sps_query_start.elapsed().as_micros();
  316. println!(
  317. "Server compute response {} µs, {} bytes (*including* the above expansion time)",
  318. sps_query_us,
  319. sps_response.len()
  320. );
  321. // Decode the response to yield the whole Spiral block
  322. let spc_recv_start = Instant::now();
  323. let encdbblock = client.decode_response(sps_response.as_slice());
  324. // Extract the one encrypted DbEntry we were looking for (and the
  325. // only one we are able to decrypt)
  326. let entry_in_block = rand_idx % spiral_blocking_factor;
  327. let loc_in_block = entry_in_block * mem::size_of::<DbEntry>();
  328. let loc_in_block_end = (entry_in_block + 1) * mem::size_of::<DbEntry>();
  329. let encdbentry = DbEntry::from_le_bytes(
  330. encdbblock[loc_in_block..loc_in_block_end]
  331. .try_into()
  332. .unwrap(),
  333. );
  334. let decdbentry = otkey_decrypt(&otkey, rand_idx, encdbentry);
  335. let spc_recv_us = spc_recv_start.elapsed().as_micros();
  336. println!("Client decode response {} µs", spc_recv_us);
  337. println!("index = {}, Response = {}", q, decdbentry);
  338. }