// We really want points to be capital letters and scalars to be // lowercase letters #![allow(non_snake_case)] mod aligned_memory_mt; mod params; mod spiral_mt; pub mod client; pub mod server; use aes::cipher::{BlockEncrypt, KeyInit}; use aes::Aes128Enc; use aes::Block; use std::env; use std::mem; use std::time::Instant; use subtle::Choice; use subtle::ConditionallySelectable; use rand::RngCore; use sha2::Digest; use sha2::Sha256; use sha2::Sha512; use curve25519_dalek::constants as dalek_constants; use curve25519_dalek::ristretto::CompressedRistretto; use curve25519_dalek::ristretto::RistrettoBasepointTable; use curve25519_dalek::ristretto::RistrettoPoint; use curve25519_dalek::scalar::Scalar; use rayon::scope; use rayon::ThreadPoolBuilder; use spiral_rs::client::*; use spiral_rs::params::*; use spiral_rs::server::*; use crate::spiral_mt::*; use lazy_static::lazy_static; pub type DbEntry = u64; // Generators of the Ristretto group (the standard B and another one C, // for which the DL relationship is unknown), and their precomputed // multiplication tables. Used for the Oblivious Transfer protocol lazy_static! { pub static ref OT_B: RistrettoPoint = dalek_constants::RISTRETTO_BASEPOINT_POINT; pub static ref OT_C: RistrettoPoint = RistrettoPoint::hash_from_bytes::(b"OT Generator C"); pub static ref OT_B_TABLE: RistrettoBasepointTable = dalek_constants::RISTRETTO_BASEPOINT_TABLE; pub static ref OT_C_TABLE: RistrettoBasepointTable = RistrettoBasepointTable::create(&OT_C); } // XOR a 16-byte slice into a Block (which will be used as an AES key) fn xor16(outar: &mut Block, inar: &[u8; 16]) { for i in 0..16 { outar[i] ^= inar[i]; } } // Encrypt a database of 2^r elements, where each element is a DbEntry, // using the 2*r provided keys (r pairs of keys). Also rotate the // database by rot positions, and add the provided blinding factor to // each element before encryption (the same blinding factor for all // elements). Each element is encrypted in AES counter mode, with the // counter being the element number and the key computed as the XOR of r // of the provided keys, one from each pair, according to the bits of // the element number. Outputs a byte vector containing the encrypted // database. pub fn encdb_xor_keys( db: &[DbEntry], keys: &[[u8; 16]], r: usize, rot: usize, blind: DbEntry, num_threads: usize, ) -> Vec { let num_records: usize = 1 << r; let num_record_mask: usize = num_records - 1; let negrot = num_records - rot; let mut ret = Vec::::with_capacity(num_records * mem::size_of::()); ret.resize(num_records * mem::size_of::(), 0); scope(|s| { let mut record_thread_start = 0usize; let records_per_thread_base = num_records / num_threads; let records_per_thread_extra = num_records % num_threads; let mut retslice = ret.as_mut_slice(); for thr in 0..num_threads { let records_this_thread = records_per_thread_base + if thr < records_per_thread_extra { 1 } else { 0 }; let record_thread_end = record_thread_start + records_this_thread; let (thread_ret, retslice_) = retslice.split_at_mut(records_this_thread * mem::size_of::()); retslice = retslice_; s.spawn(move |_| { let mut offset = 0usize; for j in record_thread_start..record_thread_end { let rec = (j + negrot) & num_record_mask; let mut key = Block::from([0u8; 16]); for i in 0..r { let bit = if (j & (1 << i)) == 0 { 0 } else { 1 }; xor16(&mut key, &keys[2 * i + bit]); } let aes = Aes128Enc::new(&key); let mut block = Block::from([0u8; 16]); block[0..8].copy_from_slice(&j.to_le_bytes()); aes.encrypt_block(&mut block); let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap()); let encelem = (db[rec].wrapping_add(blind)) ^ aeskeystream; thread_ret[offset..offset + mem::size_of::()] .copy_from_slice(&encelem.to_le_bytes()); offset += mem::size_of::(); } }); record_thread_start = record_thread_end; } }); ret } // Generate the keys for encrypting the database pub fn gen_db_enc_keys(r: usize) -> Vec<[u8; 16]> { let mut keys: Vec<[u8; 16]> = Vec::new(); let mut rng = rand::thread_rng(); for _ in 0..2 * r { let mut k: [u8; 16] = [0; 16]; rng.fill_bytes(&mut k); keys.push(k); } keys } // 1-out-of-2 Oblivious Transfer (OT) fn ot12_request(sel: Choice) -> ((Choice, Scalar), [u8; 32]) { let Btable: &RistrettoBasepointTable = &OT_B_TABLE; let C: &RistrettoPoint = &OT_C; let mut rng = rand07::thread_rng(); let x = Scalar::random(&mut rng); let xB = &x * Btable; let CmxB = C - xB; let P = RistrettoPoint::conditional_select(&xB, &CmxB, sel); ((sel, x), P.compress().to_bytes()) } fn ot12_serve(query: &[u8; 32], m0: &[u8; 16], m1: &[u8; 16]) -> [u8; 64] { let Btable: &RistrettoBasepointTable = &OT_B_TABLE; let Ctable: &RistrettoBasepointTable = &OT_C_TABLE; let mut rng = rand07::thread_rng(); let y = Scalar::random(&mut rng); let yB = &y * Btable; let yC = &y * Ctable; let P = CompressedRistretto::from_slice(query).decompress().unwrap(); let yP0 = y * P; let yP1 = yC - yP0; let mut HyP0 = Sha256::digest(yP0.compress().as_bytes()); for i in 0..16 { HyP0[i] ^= m0[i]; } let mut HyP1 = Sha256::digest(yP1.compress().as_bytes()); for i in 0..16 { HyP1[i] ^= m1[i]; } let mut ret = [0u8; 64]; ret[0..32].copy_from_slice(yB.compress().as_bytes()); ret[32..48].copy_from_slice(&HyP0[0..16]); ret[48..64].copy_from_slice(&HyP1[0..16]); ret } fn ot12_receive(state: (Choice, Scalar), response: &[u8; 64]) -> [u8; 16] { let yB = CompressedRistretto::from_slice(&response[0..32]) .decompress() .unwrap(); let yP = state.1 * yB; let mut HyP = Sha256::digest(yP.compress().as_bytes()); for i in 0..16 { HyP[i] ^= u8::conditional_select(&response[32 + i], &response[48 + i], state.0); } HyP[0..16].try_into().unwrap() } // Obliviously fetch the key for element q of the database (which has // 2^r elements total). Each bit of q is used in a 1-out-of-2 OT to get // one of the keys in each of the r pairs of keys on the server side. // The resulting r keys are XORed together. pub fn otkey_request(q: usize, r: usize) -> (Vec<(Choice, Scalar)>, Vec<[u8; 32]>) { let mut state: Vec<(Choice, Scalar)> = Vec::with_capacity(r); let mut query: Vec<[u8; 32]> = Vec::with_capacity(r); for i in 0..r { let bit = ((q >> i) & 1) as u8; let (si, qi) = ot12_request(bit.into()); state.push(si); query.push(qi); } (state, query) } pub fn otkey_serve(query: Vec<[u8; 32]>, keys: &Vec<[u8; 16]>) -> Vec<[u8; 64]> { let r = query.len(); assert!(keys.len() == 2 * r); let mut response: Vec<[u8; 64]> = Vec::with_capacity(r); for i in 0..r { response.push(ot12_serve(&query[i], &keys[2 * i], &keys[2 * i + 1])); } response } pub fn otkey_receive(state: Vec<(Choice, Scalar)>, response: &Vec<[u8; 64]>) -> Block { let r = state.len(); assert!(response.len() == r); let mut key = Block::from([0u8; 16]); for i in 0..r { xor16(&mut key, &ot12_receive(state[i], &response[i])); } key } // Having received the key for element q with r parallel 1-out-of-2 OTs, // and having received the encrypted element with (non-symmetric) PIR, // use the key to decrypt the element. pub fn otkey_decrypt(key: &Block, q: usize, encelement: DbEntry) -> DbEntry { let aes = Aes128Enc::new(key); let mut block = Block::from([0u8; 16]); block[0..8].copy_from_slice(&q.to_le_bytes()); aes.encrypt_block(&mut block); let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap()); encelement ^ aeskeystream } // Things that are only done once total, not once for each SPIR pub fn init(num_threads: usize) { // Resolve the lazy statics let _B: &RistrettoPoint = &OT_B; let _Btable: &RistrettoBasepointTable = &OT_B_TABLE; let _C: &RistrettoPoint = &OT_C; let _Ctable: &RistrettoBasepointTable = &OT_C_TABLE; // Initialize the thread pool ThreadPoolBuilder::new().num_threads(num_threads).build_global().unwrap(); } pub fn print_params_summary(params: &Params) { let db_elem_size = params.item_size(); let total_size = params.num_items() * db_elem_size; println!( "Using a {} x {} byte database ({} bytes total)", params.num_items(), db_elem_size, total_size ); } #[no_mangle] pub extern "C" fn spir_init(num_threads: u32) { init(num_threads as usize); }