use std::collections::VecDeque; use std::os::raw::c_uchar; use std::sync::mpsc::*; use std::thread::*; use rayon::prelude::*; use spiral_rs::client::PublicParameters; use spiral_rs::client::Query; use spiral_rs::server::process_query; use crate::db_encrypt; use crate::load_db_from_slice_mt; use crate::ot::*; use crate::params; use crate::to_vecdata; use crate::DbEntry; use crate::PreProcSingleMsg; use crate::PreProcSingleRespMsg; use crate::VecData; enum Command { PreProcMsg(Vec), QueryMsg(usize, usize, usize, DbEntry), } enum Response { PreProcResp(Vec), QueryResp(Vec), } // The internal client state for a single preprocess query struct PreProcSingleState<'a> { db_keys: Vec<[u8; 16]>, query: Query<'a>, } pub struct Server { incoming_cmd: SyncSender, outgoing_resp: Receiver, } impl Server { pub fn new(r: usize, pub_params: Vec) -> Self { let (incoming_cmd, incoming_cmd_recv) = sync_channel(0); let (outgoing_resp_send, outgoing_resp) = sync_channel(0); spawn(move || { let spiral_params = params::get_spiral_params(r); let pub_params = PublicParameters::deserialize(&spiral_params, &pub_params); let num_records = 1 << r; let num_records_mask = num_records - 1; // State for preprocessing queries let mut preproc_state: VecDeque = VecDeque::new(); // Wait for commands loop { match incoming_cmd_recv.recv() { Err(_) => break, Ok(Command::PreProcMsg(cliquery)) => { let num_preproc = cliquery.len(); let mut resp_state: Vec = Vec::with_capacity(num_preproc); let mut resp_msg: Vec = Vec::with_capacity(num_preproc); cliquery .into_par_iter() .map(|q| { let db_keys = gen_db_enc_keys(r); let query = Query::deserialize(&spiral_params, &q.spc_query); let ot_resp = otkey_serve(q.ot_query, &db_keys); ( PreProcSingleState { db_keys, query }, PreProcSingleRespMsg { ot_resp }, ) }) .unzip_into_vecs(&mut resp_state, &mut resp_msg); preproc_state.append(&mut VecDeque::from(resp_state)); let ret: Vec = bincode::serialize(&resp_msg).unwrap(); outgoing_resp_send.send(Response::PreProcResp(ret)).unwrap(); } Ok(Command::QueryMsg(offset, db, rot, blind)) => { // Panic if there's no preprocess state // available let nextstate = preproc_state.pop_front().unwrap(); // Encrypt the database with the keys, rotating // and blinding in the process. It is safe to // construct a slice out of the const pointer we // were handed because that pointer will stay // valid until we return something back to the // caller. let totoffset = (offset + rot) & num_records_mask; let num_threads = rayon::current_num_threads(); let encdb = unsafe { let dbslice = std::slice::from_raw_parts(db as *const DbEntry, num_records); db_encrypt( dbslice, &nextstate.db_keys, r, totoffset, blind, num_threads, ) }; // Load the encrypted db into Spiral let sps_db = load_db_from_slice_mt(&spiral_params, &encdb, num_threads); // Process the query let resp = process_query( &spiral_params, &pub_params, &nextstate.query, sps_db.as_slice(), ); outgoing_resp_send.send(Response::QueryResp(resp)).unwrap(); } // When adding new messages, the following line is // useful during development // _ => panic!("Received something unexpected in server loop"), } } }); Server { incoming_cmd, outgoing_resp, } } pub fn preproc_process(&self, msg: &[u8]) -> Vec { self.incoming_cmd .send(Command::PreProcMsg(bincode::deserialize(msg).unwrap())) .unwrap(); let ret = match self.outgoing_resp.recv() { Ok(Response::PreProcResp(x)) => x, _ => panic!("Received something unexpected in preproc_process"), }; ret } pub fn query_process( &self, msg: &[u8], db: *const DbEntry, rot: usize, blind: DbEntry, ) -> Vec { let offset = usize::from_le_bytes(msg.try_into().unwrap()); self.incoming_cmd .send(Command::QueryMsg(offset, db as usize, rot, blind)) .unwrap(); let ret = match self.outgoing_resp.recv() { Ok(Response::QueryResp(x)) => x, _ => panic!("Received something unexpected in query_process"), }; ret } } #[no_mangle] pub extern "C" fn spir_server_new( r: u8, pub_params: *const c_uchar, pub_params_len: usize, ) -> *mut Server { let pub_params_slice = unsafe { assert!(!pub_params.is_null()); std::slice::from_raw_parts(pub_params, pub_params_len) }; let mut pub_params_vec: Vec = Vec::new(); pub_params_vec.extend_from_slice(pub_params_slice); let server = Server::new(r as usize, pub_params_vec); Box::into_raw(Box::new(server)) } #[no_mangle] pub extern "C" fn spir_server_free(server: *mut Server) { if server.is_null() { return; } unsafe { Box::from_raw(server); } } #[no_mangle] pub extern "C" fn spir_server_preproc_process( serverptr: *mut Server, msgdata: *const c_uchar, msglen: usize, ) -> VecData { let server = unsafe { assert!(!serverptr.is_null()); &mut *serverptr }; let msg_slice = unsafe { assert!(!msgdata.is_null()); std::slice::from_raw_parts(msgdata, msglen) }; let retvec = server.preproc_process(msg_slice); to_vecdata(retvec) } #[no_mangle] pub extern "C" fn spir_server_query_process( serverptr: *mut Server, msgdata: *const c_uchar, msglen: usize, db: *const DbEntry, rot: usize, blind: DbEntry, ) -> VecData { let server = unsafe { assert!(!serverptr.is_null()); &mut *serverptr }; let msg_slice = unsafe { assert!(!msgdata.is_null()); std::slice::from_raw_parts(msgdata, msglen) }; let retvec = server.query_process(msg_slice, db, rot, blind); to_vecdata(retvec) }