// We really want points to be capital letters and scalars to be // lowercase letters #![allow(non_snake_case)] mod aligned_memory_mt; pub mod client; mod ot; mod params; pub mod server; mod spiral_mt; use aes::cipher::{BlockEncrypt, KeyInit}; use aes::Aes128Enc; use aes::Block; use std::env; use std::mem; use std::time::Instant; use rand::RngCore; use std::os::raw::c_uchar; use rayon::scope; use rayon::ThreadPoolBuilder; use spiral_rs::client::*; use spiral_rs::params::*; use spiral_rs::server::*; use crate::spiral_mt::*; use crate::ot::{otkey_init, xor16}; pub type DbEntry = u64; // Encrypt a database of 2^r elements, where each element is a DbEntry, // using the 2*r provided keys (r pairs of keys). Also rotate the // database by rot positions, and add the provided blinding factor to // each element before encryption (the same blinding factor for all // elements). Each element is encrypted in AES counter mode, with the // counter being the element number and the key computed as the XOR of r // of the provided keys, one from each pair, according to the bits of // the element number. Outputs a byte vector containing the encrypted // database. pub fn encdb_xor_keys( db: &[DbEntry], keys: &[[u8; 16]], r: usize, rot: usize, blind: DbEntry, num_threads: usize, ) -> Vec { let num_records: usize = 1 << r; let num_record_mask: usize = num_records - 1; let negrot = num_records - rot; let mut ret = Vec::::with_capacity(num_records * mem::size_of::()); ret.resize(num_records * mem::size_of::(), 0); scope(|s| { let mut record_thread_start = 0usize; let records_per_thread_base = num_records / num_threads; let records_per_thread_extra = num_records % num_threads; let mut retslice = ret.as_mut_slice(); for thr in 0..num_threads { let records_this_thread = records_per_thread_base + if thr < records_per_thread_extra { 1 } else { 0 }; let record_thread_end = record_thread_start + records_this_thread; let (thread_ret, retslice_) = retslice.split_at_mut(records_this_thread * mem::size_of::()); retslice = retslice_; s.spawn(move |_| { let mut offset = 0usize; for j in record_thread_start..record_thread_end { let rec = (j + negrot) & num_record_mask; let mut key = Block::from([0u8; 16]); for i in 0..r { let bit = if (j & (1 << i)) == 0 { 0 } else { 1 }; xor16(&mut key, &keys[2 * i + bit]); } let aes = Aes128Enc::new(&key); let mut block = Block::from([0u8; 16]); block[0..8].copy_from_slice(&j.to_le_bytes()); aes.encrypt_block(&mut block); let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap()); let encelem = (db[rec].wrapping_add(blind)) ^ aeskeystream; thread_ret[offset..offset + mem::size_of::()] .copy_from_slice(&encelem.to_le_bytes()); offset += mem::size_of::(); } }); record_thread_start = record_thread_end; } }); ret } // Generate the keys for encrypting the database pub fn gen_db_enc_keys(r: usize) -> Vec<[u8; 16]> { let mut keys: Vec<[u8; 16]> = Vec::new(); let mut rng = rand::thread_rng(); for _ in 0..2 * r { let mut k: [u8; 16] = [0; 16]; rng.fill_bytes(&mut k); keys.push(k); } keys } // Having received the key for element q with r parallel 1-out-of-2 OTs, // and having received the encrypted element with (non-symmetric) PIR, // use the key to decrypt the element. pub fn otkey_decrypt(key: &Block, q: usize, encelement: DbEntry) -> DbEntry { let aes = Aes128Enc::new(key); let mut block = Block::from([0u8; 16]); block[0..8].copy_from_slice(&q.to_le_bytes()); aes.encrypt_block(&mut block); let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap()); encelement ^ aeskeystream } // Things that are only done once total, not once for each SPIR pub fn init(num_threads: usize) { otkey_init(); // Initialize the thread pool ThreadPoolBuilder::new() .num_threads(num_threads) .build_global() .unwrap(); } pub fn print_params_summary(params: &Params) { let db_elem_size = params.item_size(); let total_size = params.num_items() * db_elem_size; println!( "Using a {} x {} byte database ({} bytes total)", params.num_items(), db_elem_size, total_size ); } #[no_mangle] pub extern "C" fn spir_init(num_threads: u32) { init(num_threads as usize); } #[repr(C)] pub struct VecData { data: *const c_uchar, len: usize, cap: usize, } #[repr(C)] pub struct VecMutData { data: *mut c_uchar, len: usize, cap: usize, } #[no_mangle] pub extern "C" fn spir_vecdata_free(vecdata: VecMutData) { unsafe { Vec::from_raw_parts(vecdata.data, vecdata.len, vecdata.cap) }; }