123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168 |
- mod aligned_memory_mt;
- pub mod client;
- mod ot;
- mod params;
- pub mod server;
- mod spiral_mt;
- use aes::cipher::{BlockEncrypt, KeyInit};
- use aes::Aes128Enc;
- use aes::Block;
- use std::mem;
- use serde::{Deserialize, Serialize};
- use std::os::raw::c_uchar;
- use rayon::scope;
- use rayon::ThreadPoolBuilder;
- use serde_with::serde_as;
- use spiral_rs::params::*;
- use crate::ot::{otkey_init, xor16};
- use crate::spiral_mt::*;
/// A single database element. `db_encrypt` serializes each element
/// little-endian into `size_of::<DbEntry>()` (8) bytes of its output.
pub type DbEntry = u64;
- // Encrypt a database of 2^r elements, where each element is a DbEntry,
- // using the 2*r provided keys (r pairs of keys). Also rotate the
- // database by rot positions, and add the provided blinding factor to
- // each element before encryption (the same blinding factor for all
- // elements). Each element is encrypted in AES counter mode, with the
- // counter being the element number and the key computed as the XOR of r
- // of the provided keys, one from each pair, according to the bits of
- // the element number. Outputs a byte vector containing the encrypted
- // database.
- fn db_encrypt(
- db: &[DbEntry],
- keys: &[[u8; 16]],
- r: usize,
- rot: usize,
- blind: DbEntry,
- num_threads: usize,
- ) -> Vec<u8> {
- let num_records: usize = 1 << r;
- let num_record_mask: usize = num_records - 1;
- let mut ret = vec![0; num_records * mem::size_of::<DbEntry>()];
- scope(|s| {
- let mut record_thread_start = 0usize;
- let records_per_thread_base = num_records / num_threads;
- let records_per_thread_extra = num_records % num_threads;
- let mut retslice = ret.as_mut_slice();
- for thr in 0..num_threads {
- let records_this_thread =
- records_per_thread_base + if thr < records_per_thread_extra { 1 } else { 0 };
- let record_thread_end = record_thread_start + records_this_thread;
- let (thread_ret, retslice_) =
- retslice.split_at_mut(records_this_thread * mem::size_of::<DbEntry>());
- retslice = retslice_;
- s.spawn(move |_| {
- let mut offset = 0usize;
- for j in record_thread_start..record_thread_end {
- let rec = (j + rot) & num_record_mask;
- let mut key = Block::from([0u8; 16]);
- for i in 0..r {
- let bit = if (j & (1 << i)) == 0 { 0 } else { 1 };
- xor16(&mut key, &keys[2 * i + bit]);
- }
- let aes = Aes128Enc::new(&key);
- let mut block = Block::from([0u8; 16]);
- block[0..8].copy_from_slice(&j.to_le_bytes());
- aes.encrypt_block(&mut block);
- let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
- let encelem = (db[rec].wrapping_add(blind)) ^ aeskeystream;
- thread_ret[offset..offset + mem::size_of::<DbEntry>()]
- .copy_from_slice(&encelem.to_le_bytes());
- offset += mem::size_of::<DbEntry>();
- }
- });
- record_thread_start = record_thread_end;
- }
- });
- ret
- }
- // Having received the key for element q with r parallel 1-out-of-2 OTs,
- // and having received the encrypted element with (non-symmetric) PIR,
- // use the key to decrypt the element.
- fn dbentry_decrypt(key: &Block, q: usize, encelement: DbEntry) -> DbEntry {
- let aes = Aes128Enc::new(key);
- let mut block = Block::from([0u8; 16]);
- block[0..8].copy_from_slice(&q.to_le_bytes());
- aes.encrypt_block(&mut block);
- let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
- encelement ^ aeskeystream
- }
- // Things that are only done once total, not once for each SPIR
- pub fn init(num_threads: usize) {
- otkey_init();
- // Initialize the thread pool
- ThreadPoolBuilder::new()
- .num_threads(num_threads)
- .build_global()
- .unwrap();
- }
- pub fn print_params_summary(params: &Params) {
- let db_elem_size = params.item_size();
- let total_size = params.num_items() * db_elem_size;
- println!(
- "Using a {} x {} byte database ({} bytes total)",
- params.num_items(),
- db_elem_size,
- total_size
- );
- }
/// The message format for a single preprocess query, (de)serialized with serde.
///
/// Field order is part of the serialized format; do not reorder.
#[derive(Serialize, Deserialize)]
struct PreProcSingleMsg {
    /// OT query messages, 32 bytes each (presumably one per parallel OT —
    /// confirm against the client code).
    ot_query: Vec<[u8; 32]>,
    /// Opaque serialized query bytes. NOTE(review): the meaning of "spc" is not
    /// visible here (possibly the Spiral PIR client query); confirm.
    spc_query: Vec<u8>,
}
/// The message format for a single preprocess response, (de)serialized with
/// serde.
#[serde_as]
#[derive(Serialize, Deserialize)]
struct PreProcSingleRespMsg {
    /// OT response messages, 64 bytes each.
    // serde only implements (de)serialization for arrays of up to 32 elements,
    // so serde_with's `[_; 64]` adapter is needed for the 64-byte entries.
    #[serde_as(as = "Vec<[_; 64]>")]
    ot_resp: Vec<[u8; 64]>,
}
- #[no_mangle]
- pub extern "C" fn spir_init(num_threads: u32) {
- init(num_threads as usize);
- }
/// C-compatible description of a `Vec<u8>` whose buffer has been handed out
/// of Rust without being freed (produced by `to_vecdata`).
///
/// `#[repr(C)]`: the field order is part of the C ABI; do not reorder.
#[repr(C)]
pub struct VecData {
    /// Pointer to the first byte of the buffer.
    data: *const c_uchar,
    /// Number of initialized bytes in the buffer.
    len: usize,
    /// Allocated capacity; required to rebuild the `Vec` and free it correctly.
    cap: usize,
}
/// Mutable-pointer counterpart of `VecData`, with the same `#[repr(C)]` field
/// layout. `spir_vecdata_free` takes this type to rebuild and drop the `Vec`.
#[repr(C)]
pub struct VecMutData {
    /// Pointer to the first byte of the buffer.
    data: *mut c_uchar,
    /// Number of initialized bytes in the buffer.
    len: usize,
    /// Allocated capacity of the buffer.
    cap: usize,
}
- pub fn to_vecdata(v: Vec<u8>) -> VecData {
- let vecdata = VecData {
- data: v.as_ptr(),
- len: v.len(),
- cap: v.capacity(),
- };
- std::mem::forget(v);
- vecdata
- }
- #[no_mangle]
- pub extern "C" fn spir_vecdata_free(vecdata: VecMutData) {
- unsafe { Vec::from_raw_parts(vecdata.data, vecdata.len, vecdata.cap) };
- }
|