// We really want points to be capital letters and scalars to be // lowercase letters #![allow(non_snake_case)] pub mod params; use aes::cipher::{BlockEncrypt, KeyInit}; use aes::Aes128Enc; use aes::Block; use std::env; use std::io::Cursor; use std::mem; use std::time::Instant; use subtle::Choice; use subtle::ConditionallySelectable; use rand::RngCore; use sha2::Digest; use sha2::Sha256; use sha2::Sha512; use curve25519_dalek::constants as dalek_constants; use curve25519_dalek::ristretto::CompressedRistretto; use curve25519_dalek::ristretto::RistrettoBasepointTable; use curve25519_dalek::ristretto::RistrettoPoint; use curve25519_dalek::scalar::Scalar; use spiral_rs::client::*; use spiral_rs::params::*; use spiral_rs::server::*; use lazy_static::lazy_static; type DbEntry = u64; // Generators of the Ristretto group (the standard B and another one C, // for which the DL relationship is unknown), and their precomputed // multiplication tables. Used for the Oblivious Transfer protocol lazy_static! { pub static ref OT_B: RistrettoPoint = dalek_constants::RISTRETTO_BASEPOINT_POINT; pub static ref OT_C: RistrettoPoint = RistrettoPoint::hash_from_bytes::(b"OT Generator C"); pub static ref OT_B_TABLE: RistrettoBasepointTable = dalek_constants::RISTRETTO_BASEPOINT_TABLE; pub static ref OT_C_TABLE: RistrettoBasepointTable = RistrettoBasepointTable::create(&OT_C); } // XOR a 16-byte slice into a Block (which will be used as an AES key) fn xor16(outar: &mut Block, inar: &[u8; 16]) { for i in 0..16 { outar[i] ^= inar[i]; } } // Encrypt a database of 2^r elements, where each element is a DbEntry, // using the 2*r provided keys (r pairs of keys). Also add the provided // blinding factor to each element before encryption (the same blinding // factor for all elements). Each element is encrypted in AES counter // mode, with the counter being the element number and the key computed // as the XOR of r of the provided keys, one from each pair, according // to the bits of the element number. Outputs a byte vector containing // the encrypted database. fn encdb_xor_keys(db: &[DbEntry], keys: &[[u8; 16]], r: usize, blind: DbEntry) -> Vec { let num_records: usize = 1 << r; let mut ret = Vec::::with_capacity(num_records * mem::size_of::()); for j in 0..num_records { let mut key = Block::from([0u8; 16]); for i in 0..r { let bit = if (j & (1 << i)) == 0 { 0 } else { 1 }; xor16(&mut key, &keys[2 * i + bit]); } let aes = Aes128Enc::new(&key); let mut block = Block::from([0u8; 16]); block[0..8].copy_from_slice(&j.to_le_bytes()); aes.encrypt_block(&mut block); let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap()); let encelem = (db[j].wrapping_add(blind)) ^ aeskeystream; ret.extend(encelem.to_le_bytes()); } ret } // Generate the keys for encrypting the database fn gen_db_enc_keys(r: usize) -> Vec<[u8; 16]> { let mut keys: Vec<[u8; 16]> = Vec::new(); let mut rng = rand::thread_rng(); for _ in 0..2 * r { let mut k: [u8; 16] = [0; 16]; rng.fill_bytes(&mut k); keys.push(k); } keys } // 1-out-of-2 Oblivious Transfer (OT) fn ot12_request(sel: Choice) -> ((Choice, Scalar), [u8; 32]) { let Btable: &RistrettoBasepointTable = &OT_B_TABLE; let C: &RistrettoPoint = &OT_C; let mut rng = rand07::thread_rng(); let x = Scalar::random(&mut rng); let xB = &x * Btable; let CmxB = C - xB; let P = RistrettoPoint::conditional_select(&xB, &CmxB, sel); ((sel, x), P.compress().to_bytes()) } fn ot12_serve(query: &[u8; 32], m0: &[u8; 16], m1: &[u8; 16]) -> [u8; 64] { let Btable: &RistrettoBasepointTable = &OT_B_TABLE; let Ctable: &RistrettoBasepointTable = &OT_C_TABLE; let mut rng = rand07::thread_rng(); let y = Scalar::random(&mut rng); let yB = &y * Btable; let yC = &y * Ctable; let P = CompressedRistretto::from_slice(query).decompress().unwrap(); let yP0 = y * P; let yP1 = yC - yP0; let mut HyP0 = Sha256::digest(yP0.compress().as_bytes()); for i in 0..16 { HyP0[i] ^= m0[i]; } let mut HyP1 = Sha256::digest(yP1.compress().as_bytes()); for i in 0..16 { HyP1[i] ^= m1[i]; } let mut ret = [0u8; 64]; ret[0..32].copy_from_slice(yB.compress().as_bytes()); ret[32..48].copy_from_slice(&HyP0[0..16]); ret[48..64].copy_from_slice(&HyP1[0..16]); ret } fn ot12_receive(state: (Choice, Scalar), response: &[u8; 64]) -> [u8; 16] { let yB = CompressedRistretto::from_slice(&response[0..32]) .decompress() .unwrap(); let yP = state.1 * yB; let mut HyP = Sha256::digest(yP.compress().as_bytes()); for i in 0..16 { HyP[i] ^= u8::conditional_select(&response[32 + i], &response[48 + i], state.0); } HyP[0..16].try_into().unwrap() } // Obliviously fetch the key for element q of the database (which has // 2^r elements total). Each bit of q is used in a 1-out-of-2 OT to get // one of the keys in each of the r pairs of keys on the server side. // The resulting r keys are XORed together. fn otkey_request(q: usize, r: usize) -> (Vec<(Choice, Scalar)>, Vec<[u8; 32]>) { let mut state: Vec<(Choice, Scalar)> = Vec::with_capacity(r); let mut query: Vec<[u8; 32]> = Vec::with_capacity(r); for i in 0..r { let bit = ((q >> i) & 1) as u8; let (si, qi) = ot12_request(bit.into()); state.push(si); query.push(qi); } (state, query) } fn otkey_serve(query: Vec<[u8; 32]>, keys: &Vec<[u8; 16]>) -> Vec<[u8; 64]> { let r = query.len(); assert!(keys.len() == 2 * r); let mut response: Vec<[u8; 64]> = Vec::with_capacity(r); for i in 0..r { response.push(ot12_serve(&query[i], &keys[2 * i], &keys[2 * i + 1])); } response } fn otkey_receive(state: Vec<(Choice, Scalar)>, response: &Vec<[u8; 64]>) -> Block { let r = state.len(); assert!(response.len() == r); let mut key = Block::from([0u8; 16]); for i in 0..r { xor16(&mut key, &ot12_receive(state[i], &response[i])); } key } // Having received the key for element q with r parallel 1-out-of-2 OTs, // and having received the encrypted element with (non-symmetric) PIR, // use the key to decrypt the element. fn otkey_decrypt(key: &Block, q: usize, encelement: DbEntry) -> DbEntry { let aes = Aes128Enc::new(key); let mut block = Block::from([0u8; 16]); block[0..8].copy_from_slice(&q.to_le_bytes()); aes.encrypt_block(&mut block); let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap()); encelement ^ aeskeystream } // Things that are only done once total, not once for each SPIR fn one_time_setup() { // Resolve the lazy statics let _B: &RistrettoPoint = &OT_B; let _Btable: &RistrettoBasepointTable = &OT_B_TABLE; let _C: &RistrettoPoint = &OT_C; let _Ctable: &RistrettoBasepointTable = &OT_C_TABLE; } fn print_params_summary(params: &Params) { let db_elem_size = params.item_size(); let total_size = params.num_items() * db_elem_size; println!( "Using a {} x {} byte database ({} bytes total)", params.num_items(), db_elem_size, total_size ); } fn main() { let args: Vec = env::args().collect(); if args.len() != 2 { println!("Usage: {} r\nr = log_2(num_records)", args[0]); return; } let r: usize = args[1].parse().unwrap(); let num_records = 1 << r; println!("===== ONE-TIME SETUP =====\n"); let otsetup_start = Instant::now(); let spiral_params = params::get_spiral_params(r); let mut rng = rand::thread_rng(); one_time_setup(); let otsetup_us = otsetup_start.elapsed().as_micros(); print_params_summary(&spiral_params); println!("OT one-time setup: {} µs", otsetup_us); // One-time setup for the Spiral client let spc_otsetup_start = Instant::now(); let mut clientrng = rand::thread_rng(); let mut client = Client::init(&spiral_params, &mut clientrng); let pub_params = client.generate_keys(); let pub_params_buf = pub_params.serialize(); let spc_otsetup_us = spc_otsetup_start.elapsed().as_micros(); let spiral_blocking_factor = spiral_params.db_item_size / mem::size_of::(); println!( "Spiral client one-time setup: {} µs, {} bytes", spc_otsetup_us, pub_params_buf.len() ); println!("\n===== PREPROCESSING =====\n"); // Spiral preprocessing: create a PIR lookup for an element at a // random location let spc_query_start = Instant::now(); let rand_idx = (rng.next_u64() as usize) % num_records; let rand_pir_idx = rand_idx / spiral_blocking_factor; println!("rand_idx = {} rand_pir_idx = {}", rand_idx, rand_pir_idx); let spc_query = client.generate_query(rand_pir_idx); let spc_query_buf = spc_query.serialize(); let spc_query_us = spc_query_start.elapsed().as_micros(); println!( "Spiral query: {} µs, {} bytes", spc_query_us, spc_query_buf.len() ); // Create the database encryption keys and do the OT to fetch the // right one, but don't actually encrypt the database yet let dbkeys = gen_db_enc_keys(r); let otkeyreq_start = Instant::now(); let (keystate, keyquery) = otkey_request(rand_idx, r); let keyquerysize = keyquery.len() * keyquery[0].len(); let otkeyreq_us = otkeyreq_start.elapsed().as_micros(); let otkeysrv_start = Instant::now(); let keyresponse = otkey_serve(keyquery, &dbkeys); let keyrespsize = keyresponse.len() * keyresponse[0].len(); let otkeysrv_us = otkeysrv_start.elapsed().as_micros(); let otkeyrcv_start = Instant::now(); let otkey = otkey_receive(keystate, &keyresponse); let otkeyrcv_us = otkeyrcv_start.elapsed().as_micros(); println!("key OT query in {} µs, {} bytes", otkeyreq_us, keyquerysize); println!("key OT serve in {} µs, {} bytes", otkeysrv_us, keyrespsize); println!("key OT receive in {} µs", otkeyrcv_us); // Create a database with recognizable contents let mut db: Vec = ((0 as DbEntry)..(num_records as DbEntry)) .map(|x| 10000001 * x) .collect(); println!("\n===== RUNTIME =====\n"); // Pick the record we actually want to query let q = (rng.next_u64() as usize) % num_records; // Compute the offset from the record index we're actually looking // for to the random one we picked earlier. Tell it to the server, // who will rotate right the database by that amount before // encrypting it. let idx_offset = (num_records + rand_idx - q) % num_records; println!("Send to server {} bytes", 8 /* sizeof(idx_offset) */); // The server rotates, blinds, and encrypts the database let blind: DbEntry = 20; let encdb_start = Instant::now(); db.rotate_right(idx_offset); let encdb = encdb_xor_keys(&db, &dbkeys, r, blind); let encdb_us = encdb_start.elapsed().as_micros(); println!("Server encrypt database {} µs", encdb_us); // Load the encrypted database into Spiral let sps_loaddb_start = Instant::now(); let sps_db = load_db_from_seek(&spiral_params, &mut Cursor::new(encdb)); let sps_loaddb_us = sps_loaddb_start.elapsed().as_micros(); println!("Server load database {} µs", sps_loaddb_us); // Do the PIR query let sps_query_start = Instant::now(); let sps_query = Query::deserialize(&spiral_params, &spc_query_buf); let sps_response = process_query(&spiral_params, &pub_params, &sps_query, sps_db.as_slice()); let sps_query_us = sps_query_start.elapsed().as_micros(); println!( "Server compute response {} µs, {} bytes (*including* the above expansion time)", sps_query_us, sps_response.len() ); // Decode the response to yield the whole Spiral block let spc_recv_start = Instant::now(); let encdbblock = client.decode_response(sps_response.as_slice()); // Extract the one encrypted DbEntry we were looking for (and the // only one we are able to decrypt) let entry_in_block = rand_idx % spiral_blocking_factor; let loc_in_block = entry_in_block * mem::size_of::(); let loc_in_block_end = (entry_in_block + 1) * mem::size_of::(); let encdbentry = DbEntry::from_le_bytes( encdbblock[loc_in_block..loc_in_block_end] .try_into() .unwrap(), ); let decdbentry = otkey_decrypt(&otkey, rand_idx, encdbentry); let spc_recv_us = spc_recv_start.elapsed().as_micros(); println!("Client decode response {} µs", spc_recv_us); println!("index = {}, Response = {}", q, decdbentry); }