pub mod client; mod ot; mod params; pub mod server; mod spiral_mt; use aes::cipher::{BlockEncrypt, KeyInit}; use aes::Aes128Enc; use aes::Block; use std::mem; use serde::{Deserialize, Serialize}; use std::os::raw::c_uchar; use rayon::scope; use rayon::ThreadPoolBuilder; use serde_with::serde_as; use spiral_rs::params::*; use crate::ot::{otkey_init, xor16}; use crate::spiral_mt::*; pub type DbEntry = u64; // Encrypt a database of 2^r elements, where each element is a DbEntry, // using the 2*r provided keys (r pairs of keys). Also rotate the // database by rot positions, and add the provided blinding factor to // each element before encryption (the same blinding factor for all // elements). Each element is encrypted in AES counter mode, with the // counter being the element number and the key computed as the XOR of r // of the provided keys, one from each pair, according to the bits of // the element number. Outputs a byte vector containing the encrypted // database. fn db_encrypt( db: &[DbEntry], keys: &[[u8; 16]], r: usize, rot: usize, blind: DbEntry, num_threads: usize, ) -> Vec { let num_records: usize = 1 << r; let num_record_mask: usize = num_records - 1; let mut ret = vec![0; num_records * mem::size_of::()]; scope(|s| { let mut record_thread_start = 0usize; let records_per_thread_base = num_records / num_threads; let records_per_thread_extra = num_records % num_threads; let mut retslice = ret.as_mut_slice(); for thr in 0..num_threads { let records_this_thread = records_per_thread_base + if thr < records_per_thread_extra { 1 } else { 0 }; let record_thread_end = record_thread_start + records_this_thread; let (thread_ret, retslice_) = retslice.split_at_mut(records_this_thread * mem::size_of::()); retslice = retslice_; s.spawn(move |_| { let mut offset = 0usize; for j in record_thread_start..record_thread_end { let rec = (j + rot) & num_record_mask; let mut key = Block::from([0u8; 16]); for i in 0..r { let bit = if (j & (1 << i)) == 0 { 0 } else { 1 }; xor16(&mut key, &keys[2 * i + bit]); } let aes = Aes128Enc::new(&key); let mut block = Block::from([0u8; 16]); block[0..8].copy_from_slice(&j.to_le_bytes()); aes.encrypt_block(&mut block); let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap()); let encelem = (db[rec].wrapping_add(blind)) ^ aeskeystream; thread_ret[offset..offset + mem::size_of::()] .copy_from_slice(&encelem.to_le_bytes()); offset += mem::size_of::(); } }); record_thread_start = record_thread_end; } }); ret } // Having received the key for element q with r parallel 1-out-of-2 OTs, // and having received the encrypted element with (non-symmetric) PIR, // use the key to decrypt the element. fn dbentry_decrypt(key: &Block, q: usize, encelement: DbEntry) -> DbEntry { let aes = Aes128Enc::new(key); let mut block = Block::from([0u8; 16]); block[0..8].copy_from_slice(&q.to_le_bytes()); aes.encrypt_block(&mut block); let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap()); encelement ^ aeskeystream } // Things that are only done once total, not once for each SPIR pub fn init(num_threads: usize) { otkey_init(); // Initialize the thread pool ThreadPoolBuilder::new() .num_threads(num_threads) .build_global() .unwrap(); } pub fn print_params_summary(params: &Params) { let db_elem_size = params.item_size(); let total_size = params.num_items() * db_elem_size; println!( "Using a {} x {} byte database ({} bytes total)", params.num_items(), db_elem_size, total_size ); } // The message format for a single preprocess query #[derive(Serialize, Deserialize)] struct PreProcSingleMsg { ot_query: Vec<[u8; 32]>, spc_query: Vec, } // The message format for a single preprocess response #[serde_as] #[derive(Serialize, Deserialize)] struct PreProcSingleRespMsg { #[serde_as(as = "Vec<[_; 64]>")] ot_resp: Vec<[u8; 64]>, } #[no_mangle] pub extern "C" fn spir_init(num_threads: u32) { init(num_threads as usize); } #[repr(C)] pub struct VecData { data: *const c_uchar, len: usize, cap: usize, } #[repr(C)] pub struct VecMutData { data: *mut c_uchar, len: usize, cap: usize, } pub fn to_vecdata(v: Vec) -> VecData { let vecdata = VecData { data: v.as_ptr(), len: v.len(), cap: v.capacity(), }; std::mem::forget(v); vecdata } #[no_mangle] pub extern "C" fn spir_vecdata_free(vecdata: VecMutData) { unsafe { Vec::from_raw_parts(vecdata.data, vecdata.len, vecdata.cap) }; }