// Group elements (points) are named with capital letters and scalars with
// lowercase letters, matching the usual mathematical notation.
  3. #![allow(non_snake_case)]
  4. pub mod aligned_memory_mt;
  5. pub mod params;
  6. pub mod spiral_mt;
  7. use aes::cipher::{BlockEncrypt, KeyInit};
  8. use aes::Aes128Enc;
  9. use aes::Block;
  10. use std::env;
  11. use std::mem;
  12. use std::time::Instant;
  13. use subtle::Choice;
  14. use subtle::ConditionallySelectable;
  15. use rand::RngCore;
  16. use sha2::Digest;
  17. use sha2::Sha256;
  18. use sha2::Sha512;
  19. use curve25519_dalek::constants as dalek_constants;
  20. use curve25519_dalek::ristretto::CompressedRistretto;
  21. use curve25519_dalek::ristretto::RistrettoBasepointTable;
  22. use curve25519_dalek::ristretto::RistrettoPoint;
  23. use curve25519_dalek::scalar::Scalar;
  24. use spiral_rs::client::*;
  25. use spiral_rs::params::*;
  26. use spiral_rs::server::*;
  27. use crate::spiral_mt::*;
  28. use lazy_static::lazy_static;
/// A single database record: a 64-bit unsigned integer, serialized as 8
/// little-endian bytes in the encrypted database.
type DbEntry = u64;
// Generators of the Ristretto group used by the Oblivious Transfer
// protocol, together with their precomputed multiplication tables:
//
// - OT_B is the standard Ristretto basepoint B.
// - OT_C is a second generator C obtained by hashing a fixed string to the
//   group, so the discrete-log relationship between B and C is unknown.
//
// lazy_static! builds each value on first access; one_time_setup() touches
// all four so that this work happens up front.
lazy_static! {
    pub static ref OT_B: RistrettoPoint = dalek_constants::RISTRETTO_BASEPOINT_POINT;
    pub static ref OT_C: RistrettoPoint =
        RistrettoPoint::hash_from_bytes::<Sha512>(b"OT Generator C");
    // Precomputed basepoint table for B, taken from curve25519-dalek's constants.
    pub static ref OT_B_TABLE: RistrettoBasepointTable = dalek_constants::RISTRETTO_BASEPOINT_TABLE;
    // Table for C, computed at runtime from OT_C.
    pub static ref OT_C_TABLE: RistrettoBasepointTable = RistrettoBasepointTable::create(&OT_C);
}
  40. // XOR a 16-byte slice into a Block (which will be used as an AES key)
  41. fn xor16(outar: &mut Block, inar: &[u8; 16]) {
  42. for i in 0..16 {
  43. outar[i] ^= inar[i];
  44. }
  45. }
  46. // Encrypt a database of 2^r elements, where each element is a DbEntry,
  47. // using the 2*r provided keys (r pairs of keys). Also add the provided
  48. // blinding factor to each element before encryption (the same blinding
  49. // factor for all elements). Each element is encrypted in AES counter
  50. // mode, with the counter being the element number and the key computed
  51. // as the XOR of r of the provided keys, one from each pair, according
  52. // to the bits of the element number. Outputs a byte vector containing
  53. // the encrypted database.
  54. fn encdb_xor_keys(db: &[DbEntry], keys: &[[u8; 16]], r: usize, blind: DbEntry) -> Vec<u8> {
  55. let num_records: usize = 1 << r;
  56. let mut ret = Vec::<u8>::with_capacity(num_records * mem::size_of::<DbEntry>());
  57. for j in 0..num_records {
  58. let mut key = Block::from([0u8; 16]);
  59. for i in 0..r {
  60. let bit = if (j & (1 << i)) == 0 { 0 } else { 1 };
  61. xor16(&mut key, &keys[2 * i + bit]);
  62. }
  63. let aes = Aes128Enc::new(&key);
  64. let mut block = Block::from([0u8; 16]);
  65. block[0..8].copy_from_slice(&j.to_le_bytes());
  66. aes.encrypt_block(&mut block);
  67. let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
  68. let encelem = (db[j].wrapping_add(blind)) ^ aeskeystream;
  69. ret.extend(encelem.to_le_bytes());
  70. }
  71. ret
  72. }
  73. // Generate the keys for encrypting the database
  74. fn gen_db_enc_keys(r: usize) -> Vec<[u8; 16]> {
  75. let mut keys: Vec<[u8; 16]> = Vec::new();
  76. let mut rng = rand::thread_rng();
  77. for _ in 0..2 * r {
  78. let mut k: [u8; 16] = [0; 16];
  79. rng.fill_bytes(&mut k);
  80. keys.push(k);
  81. }
  82. keys
  83. }
  84. // 1-out-of-2 Oblivious Transfer (OT)
  85. fn ot12_request(sel: Choice) -> ((Choice, Scalar), [u8; 32]) {
  86. let Btable: &RistrettoBasepointTable = &OT_B_TABLE;
  87. let C: &RistrettoPoint = &OT_C;
  88. let mut rng = rand07::thread_rng();
  89. let x = Scalar::random(&mut rng);
  90. let xB = &x * Btable;
  91. let CmxB = C - xB;
  92. let P = RistrettoPoint::conditional_select(&xB, &CmxB, sel);
  93. ((sel, x), P.compress().to_bytes())
  94. }
  95. fn ot12_serve(query: &[u8; 32], m0: &[u8; 16], m1: &[u8; 16]) -> [u8; 64] {
  96. let Btable: &RistrettoBasepointTable = &OT_B_TABLE;
  97. let Ctable: &RistrettoBasepointTable = &OT_C_TABLE;
  98. let mut rng = rand07::thread_rng();
  99. let y = Scalar::random(&mut rng);
  100. let yB = &y * Btable;
  101. let yC = &y * Ctable;
  102. let P = CompressedRistretto::from_slice(query).decompress().unwrap();
  103. let yP0 = y * P;
  104. let yP1 = yC - yP0;
  105. let mut HyP0 = Sha256::digest(yP0.compress().as_bytes());
  106. for i in 0..16 {
  107. HyP0[i] ^= m0[i];
  108. }
  109. let mut HyP1 = Sha256::digest(yP1.compress().as_bytes());
  110. for i in 0..16 {
  111. HyP1[i] ^= m1[i];
  112. }
  113. let mut ret = [0u8; 64];
  114. ret[0..32].copy_from_slice(yB.compress().as_bytes());
  115. ret[32..48].copy_from_slice(&HyP0[0..16]);
  116. ret[48..64].copy_from_slice(&HyP1[0..16]);
  117. ret
  118. }
  119. fn ot12_receive(state: (Choice, Scalar), response: &[u8; 64]) -> [u8; 16] {
  120. let yB = CompressedRistretto::from_slice(&response[0..32])
  121. .decompress()
  122. .unwrap();
  123. let yP = state.1 * yB;
  124. let mut HyP = Sha256::digest(yP.compress().as_bytes());
  125. for i in 0..16 {
  126. HyP[i] ^= u8::conditional_select(&response[32 + i], &response[48 + i], state.0);
  127. }
  128. HyP[0..16].try_into().unwrap()
  129. }
  130. // Obliviously fetch the key for element q of the database (which has
  131. // 2^r elements total). Each bit of q is used in a 1-out-of-2 OT to get
  132. // one of the keys in each of the r pairs of keys on the server side.
  133. // The resulting r keys are XORed together.
  134. fn otkey_request(q: usize, r: usize) -> (Vec<(Choice, Scalar)>, Vec<[u8; 32]>) {
  135. let mut state: Vec<(Choice, Scalar)> = Vec::with_capacity(r);
  136. let mut query: Vec<[u8; 32]> = Vec::with_capacity(r);
  137. for i in 0..r {
  138. let bit = ((q >> i) & 1) as u8;
  139. let (si, qi) = ot12_request(bit.into());
  140. state.push(si);
  141. query.push(qi);
  142. }
  143. (state, query)
  144. }
  145. fn otkey_serve(query: Vec<[u8; 32]>, keys: &Vec<[u8; 16]>) -> Vec<[u8; 64]> {
  146. let r = query.len();
  147. assert!(keys.len() == 2 * r);
  148. let mut response: Vec<[u8; 64]> = Vec::with_capacity(r);
  149. for i in 0..r {
  150. response.push(ot12_serve(&query[i], &keys[2 * i], &keys[2 * i + 1]));
  151. }
  152. response
  153. }
  154. fn otkey_receive(state: Vec<(Choice, Scalar)>, response: &Vec<[u8; 64]>) -> Block {
  155. let r = state.len();
  156. assert!(response.len() == r);
  157. let mut key = Block::from([0u8; 16]);
  158. for i in 0..r {
  159. xor16(&mut key, &ot12_receive(state[i], &response[i]));
  160. }
  161. key
  162. }
  163. // Having received the key for element q with r parallel 1-out-of-2 OTs,
  164. // and having received the encrypted element with (non-symmetric) PIR,
  165. // use the key to decrypt the element.
  166. fn otkey_decrypt(key: &Block, q: usize, encelement: DbEntry) -> DbEntry {
  167. let aes = Aes128Enc::new(key);
  168. let mut block = Block::from([0u8; 16]);
  169. block[0..8].copy_from_slice(&q.to_le_bytes());
  170. aes.encrypt_block(&mut block);
  171. let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
  172. encelement ^ aeskeystream
  173. }
  174. // Things that are only done once total, not once for each SPIR
  175. fn one_time_setup() {
  176. // Resolve the lazy statics
  177. let _B: &RistrettoPoint = &OT_B;
  178. let _Btable: &RistrettoBasepointTable = &OT_B_TABLE;
  179. let _C: &RistrettoPoint = &OT_C;
  180. let _Ctable: &RistrettoBasepointTable = &OT_C_TABLE;
  181. }
  182. fn print_params_summary(params: &Params) {
  183. let db_elem_size = params.item_size();
  184. let total_size = params.num_items() * db_elem_size;
  185. println!(
  186. "Using a {} x {} byte database ({} bytes total)",
  187. params.num_items(),
  188. db_elem_size,
  189. total_size
  190. );
  191. }
  192. fn main() {
  193. let args: Vec<String> = env::args().collect();
  194. if args.len() != 2 && args.len() != 3 {
  195. println!("Usage: {} r [num_threads]\nr = log_2(num_records)", args[0]);
  196. return;
  197. }
  198. let r: usize = args[1].parse().unwrap();
  199. let mut num_threads = 1usize;
  200. if args.len() == 3 {
  201. num_threads = args[2].parse().unwrap();
  202. }
  203. let num_records = 1 << r;
  204. println!("===== ONE-TIME SETUP =====\n");
  205. let otsetup_start = Instant::now();
  206. let spiral_params = params::get_spiral_params(r);
  207. let mut rng = rand::thread_rng();
  208. one_time_setup();
  209. let otsetup_us = otsetup_start.elapsed().as_micros();
  210. print_params_summary(&spiral_params);
  211. println!("OT one-time setup: {} µs", otsetup_us);
  212. // One-time setup for the Spiral client
  213. let spc_otsetup_start = Instant::now();
  214. let mut clientrng = rand::thread_rng();
  215. let mut client = Client::init(&spiral_params, &mut clientrng);
  216. let pub_params = client.generate_keys();
  217. let pub_params_buf = pub_params.serialize();
  218. let spc_otsetup_us = spc_otsetup_start.elapsed().as_micros();
  219. let spiral_blocking_factor = spiral_params.db_item_size / mem::size_of::<DbEntry>();
  220. println!(
  221. "Spiral client one-time setup: {} µs, {} bytes",
  222. spc_otsetup_us,
  223. pub_params_buf.len()
  224. );
  225. println!("\n===== PREPROCESSING =====\n");
  226. // Spiral preprocessing: create a PIR lookup for an element at a
  227. // random location
  228. let spc_query_start = Instant::now();
  229. let rand_idx = (rng.next_u64() as usize) % num_records;
  230. let rand_pir_idx = rand_idx / spiral_blocking_factor;
  231. println!("rand_idx = {} rand_pir_idx = {}", rand_idx, rand_pir_idx);
  232. let spc_query = client.generate_query(rand_pir_idx);
  233. let spc_query_buf = spc_query.serialize();
  234. let spc_query_us = spc_query_start.elapsed().as_micros();
  235. println!(
  236. "Spiral query: {} µs, {} bytes",
  237. spc_query_us,
  238. spc_query_buf.len()
  239. );
  240. // Create the database encryption keys and do the OT to fetch the
  241. // right one, but don't actually encrypt the database yet
  242. let dbkeys = gen_db_enc_keys(r);
  243. let otkeyreq_start = Instant::now();
  244. let (keystate, keyquery) = otkey_request(rand_idx, r);
  245. let keyquerysize = keyquery.len() * keyquery[0].len();
  246. let otkeyreq_us = otkeyreq_start.elapsed().as_micros();
  247. let otkeysrv_start = Instant::now();
  248. let keyresponse = otkey_serve(keyquery, &dbkeys);
  249. let keyrespsize = keyresponse.len() * keyresponse[0].len();
  250. let otkeysrv_us = otkeysrv_start.elapsed().as_micros();
  251. let otkeyrcv_start = Instant::now();
  252. let otkey = otkey_receive(keystate, &keyresponse);
  253. let otkeyrcv_us = otkeyrcv_start.elapsed().as_micros();
  254. println!("key OT query in {} µs, {} bytes", otkeyreq_us, keyquerysize);
  255. println!("key OT serve in {} µs, {} bytes", otkeysrv_us, keyrespsize);
  256. println!("key OT receive in {} µs", otkeyrcv_us);
  257. // Create a database with recognizable contents
  258. let mut db: Vec<DbEntry> = ((0 as DbEntry)..(num_records as DbEntry))
  259. .map(|x| 10000001 * x)
  260. .collect();
  261. println!("\n===== RUNTIME =====\n");
  262. // Pick the record we actually want to query
  263. let q = (rng.next_u64() as usize) % num_records;
  264. // Compute the offset from the record index we're actually looking
  265. // for to the random one we picked earlier. Tell it to the server,
  266. // who will rotate right the database by that amount before
  267. // encrypting it.
  268. let idx_offset = (num_records + rand_idx - q) % num_records;
  269. println!("Send to server {} bytes", 8 /* sizeof(idx_offset) */);
  270. // The server rotates, blinds, and encrypts the database
  271. let blind: DbEntry = 20;
  272. let encdb_start = Instant::now();
  273. db.rotate_right(idx_offset);
  274. let encdb = encdb_xor_keys(&db, &dbkeys, r, blind);
  275. let encdb_us = encdb_start.elapsed().as_micros();
  276. println!("Server encrypt database {} µs", encdb_us);
  277. // Load the encrypted database into Spiral
  278. let sps_loaddb_start = Instant::now();
  279. let sps_db = load_db_from_slice_mt(&spiral_params, &encdb, num_threads);
  280. let sps_loaddb_us = sps_loaddb_start.elapsed().as_micros();
  281. println!("Server load database {} µs", sps_loaddb_us);
  282. // Do the PIR query
  283. let sps_query_start = Instant::now();
  284. let sps_query = Query::deserialize(&spiral_params, &spc_query_buf);
  285. let sps_response = process_query(&spiral_params, &pub_params, &sps_query, sps_db.as_slice());
  286. let sps_query_us = sps_query_start.elapsed().as_micros();
  287. println!(
  288. "Server compute response {} µs, {} bytes (*including* the above expansion time)",
  289. sps_query_us,
  290. sps_response.len()
  291. );
  292. // Decode the response to yield the whole Spiral block
  293. let spc_recv_start = Instant::now();
  294. let encdbblock = client.decode_response(sps_response.as_slice());
  295. // Extract the one encrypted DbEntry we were looking for (and the
  296. // only one we are able to decrypt)
  297. let entry_in_block = rand_idx % spiral_blocking_factor;
  298. let loc_in_block = entry_in_block * mem::size_of::<DbEntry>();
  299. let loc_in_block_end = (entry_in_block + 1) * mem::size_of::<DbEntry>();
  300. let encdbentry = DbEntry::from_le_bytes(
  301. encdbblock[loc_in_block..loc_in_block_end]
  302. .try_into()
  303. .unwrap(),
  304. );
  305. let decdbentry = otkey_decrypt(&otkey, rand_idx, encdbentry);
  306. let spc_recv_us = spc_recv_start.elapsed().as_micros();
  307. println!("Client decode response {} µs", spc_recv_us);
  308. println!("index = {}, Response = {}", q, decdbentry);
  309. }