123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222 |
- use std::collections::VecDeque;
- use std::os::raw::c_uchar;
- use std::sync::mpsc::*;
- use std::thread::*;
- use rayon::prelude::*;
- use spiral_rs::client::PublicParameters;
- use spiral_rs::client::Query;
- use spiral_rs::server::process_query;
- use crate::db_encrypt;
- use crate::load_db_from_slice_mt;
- use crate::ot::*;
- use crate::params;
- use crate::to_vecdata;
- use crate::DbEntry;
- use crate::PreProcSingleMsg;
- use crate::PreProcSingleRespMsg;
- use crate::VecData;
- enum Command {
- PreProcMsg(Vec<PreProcSingleMsg>),
- QueryMsg(usize, usize, usize, DbEntry),
- }
/// A reply sent from the worker thread back to its `Server` handle; each
/// variant carries an already-serialized message.
enum Response {
    /// Serialized answer to a `Command::PreProcMsg`.
    PreProcResp(Vec<u8>),
    /// Serialized answer to a `Command::QueryMsg`.
    QueryResp(Vec<u8>),
}
- // The internal client state for a single preprocess query
- struct PreProcSingleState<'a> {
- db_keys: Vec<[u8; 16]>,
- query: Query<'a>,
- }
- pub struct Server {
- incoming_cmd: SyncSender<Command>,
- outgoing_resp: Receiver<Response>,
- }
- impl Server {
- pub fn new(r: usize, pub_params: Vec<u8>) -> Self {
- let (incoming_cmd, incoming_cmd_recv) = sync_channel(0);
- let (outgoing_resp_send, outgoing_resp) = sync_channel(0);
- spawn(move || {
- let spiral_params = params::get_spiral_params(r);
- let pub_params = PublicParameters::deserialize(&spiral_params, &pub_params);
- let num_records = 1 << r;
- let num_records_mask = num_records - 1;
- // State for preprocessing queries
- let mut preproc_state: VecDeque<PreProcSingleState> = VecDeque::new();
- // Wait for commands
- loop {
- match incoming_cmd_recv.recv() {
- Err(_) => break,
- Ok(Command::PreProcMsg(cliquery)) => {
- let num_preproc = cliquery.len();
- let mut resp_state: Vec<PreProcSingleState> =
- Vec::with_capacity(num_preproc);
- let mut resp_msg: Vec<PreProcSingleRespMsg> =
- Vec::with_capacity(num_preproc);
- cliquery
- .into_par_iter()
- .map(|q| {
- let db_keys = gen_db_enc_keys(r);
- let query = Query::deserialize(&spiral_params, &q.spc_query);
- let ot_resp = otkey_serve(q.ot_query, &db_keys);
- (
- PreProcSingleState { db_keys, query },
- PreProcSingleRespMsg { ot_resp },
- )
- })
- .unzip_into_vecs(&mut resp_state, &mut resp_msg);
- preproc_state.append(&mut VecDeque::from(resp_state));
- let ret: Vec<u8> = bincode::serialize(&resp_msg).unwrap();
- outgoing_resp_send.send(Response::PreProcResp(ret)).unwrap();
- }
- Ok(Command::QueryMsg(offset, db, rot, blind)) => {
- // Panic if there's no preprocess state
- // available
- let nextstate = preproc_state.pop_front().unwrap();
- // Encrypt the database with the keys, rotating
- // and blinding in the process. It is safe to
- // construct a slice out of the const pointer we
- // were handed because that pointer will stay
- // valid until we return something back to the
- // caller.
- let totoffset = (offset + rot) & num_records_mask;
- let num_threads = rayon::current_num_threads();
- let encdb = unsafe {
- let dbslice =
- std::slice::from_raw_parts(db as *const DbEntry, num_records);
- db_encrypt(
- dbslice,
- &nextstate.db_keys,
- r,
- totoffset,
- blind,
- num_threads,
- )
- };
- // Load the encrypted db into Spiral
- let sps_db = load_db_from_slice_mt(&spiral_params, &encdb, num_threads);
- // Process the query
- let resp = process_query(
- &spiral_params,
- &pub_params,
- &nextstate.query,
- sps_db.as_slice(),
- );
- outgoing_resp_send.send(Response::QueryResp(resp)).unwrap();
- }
- // When adding new messages, the following line is
- // useful during development
- // _ => panic!("Received something unexpected in server loop"),
- }
- }
- });
- Server {
- incoming_cmd,
- outgoing_resp,
- }
- }
- pub fn preproc_process(&self, msg: &[u8]) -> Vec<u8> {
- self.incoming_cmd
- .send(Command::PreProcMsg(bincode::deserialize(msg).unwrap()))
- .unwrap();
- let ret = match self.outgoing_resp.recv() {
- Ok(Response::PreProcResp(x)) => x,
- _ => panic!("Received something unexpected in preproc_process"),
- };
- ret
- }
- pub fn query_process(
- &self,
- msg: &[u8],
- db: *const DbEntry,
- rot: usize,
- blind: DbEntry,
- ) -> Vec<u8> {
- let offset = usize::from_le_bytes(msg.try_into().unwrap());
- self.incoming_cmd
- .send(Command::QueryMsg(offset, db as usize, rot, blind))
- .unwrap();
- let ret = match self.outgoing_resp.recv() {
- Ok(Response::QueryResp(x)) => x,
- _ => panic!("Received something unexpected in query_process"),
- };
- ret
- }
- }
- #[no_mangle]
- pub extern "C" fn spir_server_new(
- r: u8,
- pub_params: *const c_uchar,
- pub_params_len: usize,
- ) -> *mut Server {
- let pub_params_slice = unsafe {
- assert!(!pub_params.is_null());
- std::slice::from_raw_parts(pub_params, pub_params_len)
- };
- let mut pub_params_vec: Vec<u8> = Vec::new();
- pub_params_vec.extend_from_slice(pub_params_slice);
- let server = Server::new(r as usize, pub_params_vec);
- Box::into_raw(Box::new(server))
- }
- #[no_mangle]
- pub extern "C" fn spir_server_free(server: *mut Server) {
- if server.is_null() {
- return;
- }
- unsafe {
- Box::from_raw(server);
- }
- }
- #[no_mangle]
- pub extern "C" fn spir_server_preproc_process(
- serverptr: *mut Server,
- msgdata: *const c_uchar,
- msglen: usize,
- ) -> VecData {
- let server = unsafe {
- assert!(!serverptr.is_null());
- &mut *serverptr
- };
- let msg_slice = unsafe {
- assert!(!msgdata.is_null());
- std::slice::from_raw_parts(msgdata, msglen)
- };
- let retvec = server.preproc_process(msg_slice);
- to_vecdata(retvec)
- }
- #[no_mangle]
- pub extern "C" fn spir_server_query_process(
- serverptr: *mut Server,
- msgdata: *const c_uchar,
- msglen: usize,
- db: *const DbEntry,
- rot: usize,
- blind: DbEntry,
- ) -> VecData {
- let server = unsafe {
- assert!(!serverptr.is_null());
- &mut *serverptr
- };
- let msg_slice = unsafe {
- assert!(!msgdata.is_null());
- std::slice::from_raw_parts(msgdata, msglen)
- };
- let retvec = server.query_process(msg_slice, db, rot, blind);
- to_vecdata(retvec)
- }
|