server.rs 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. use std::collections::VecDeque;
  2. use std::os::raw::c_uchar;
  3. use std::sync::mpsc::*;
  4. use std::thread::*;
  5. use rayon::prelude::*;
  6. use spiral_rs::client::PublicParameters;
  7. use spiral_rs::client::Query;
  8. use spiral_rs::params::Params;
  9. use spiral_rs::server::process_query;
  10. use crate::db_encrypt;
  11. use crate::load_db_from_slice_mt;
  12. use crate::ot::*;
  13. use crate::params;
  14. use crate::to_vecdata;
  15. use crate::DbEntry;
  16. use crate::PreProcSingleMsg;
  17. use crate::PreProcSingleRespMsg;
  18. use crate::VecData;
  /// Commands sent from a `Server` handle to its worker thread.
  enum Command {
      /// Serialized public parameters.
      /// NOTE(review): never constructed in this file; the worker thread panics
      /// if it ever receives this variant.
      PubParams(Vec<u8>),
      /// A batch of decoded client preprocessing messages.
      PreProcMsg(Vec<PreProcSingleMsg>),
      /// An online query: `(offset, db, rot, blind)`. `db` is a `*const DbEntry`
      /// carried across the thread boundary as a `usize`.
      QueryMsg(usize, usize, usize, DbEntry),
  }
  /// Replies sent from the worker thread back to the `Server` handle, one per
  /// successfully handled command.
  enum Response {
      /// Bincode-serialized `Vec<PreProcSingleRespMsg>` for a preprocessing batch.
      PreProcResp(Vec<u8>),
      /// The bytes returned by `process_query` for one online query.
      QueryResp(Vec<u8>),
  }
  /// Server-side state retained for a single preprocessed query, kept until an
  /// online query consumes it (oldest first).
  struct PreProcSingleState<'a> {
      /// Per-query database encryption keys produced by `gen_db_enc_keys`.
      db_keys: Vec<[u8; 16]>,
      /// The client's Spiral query, deserialized against the Spiral params
      /// (the lifetime presumably ties it to those params — confirm).
      query: Query<'a>,
  }
  /// Handle to the server side of the protocol. All work happens on a dedicated
  /// background thread; this handle forwards requests to it and waits for replies.
  pub struct Server {
      /// log2 of the number of database records.
      /// NOTE(review): stored but never read in this file.
      r: usize,
      /// Handle of the worker thread. It is never joined in this file, so dropping
      /// the `Server` detaches the thread.
      thread_handle: JoinHandle<()>,
      /// Sending half of the command channel to the worker thread.
      incoming_cmd: SyncSender<Command>,
      /// Receiving half of the response channel from the worker thread.
      outgoing_resp: Receiver<Response>,
  }
  39. impl Server {
  40. pub fn new(r: usize, pub_params: Vec<u8>) -> Self {
  41. let (incoming_cmd, incoming_cmd_recv) = sync_channel(0);
  42. let (outgoing_resp_send, outgoing_resp) = sync_channel(0);
  43. let thread_handle = spawn(move || {
  44. let spiral_params = params::get_spiral_params(r);
  45. let pub_params = PublicParameters::deserialize(&spiral_params, &pub_params);
  46. let num_records = 1 << r;
  47. let num_records_mask = num_records - 1;
  48. // State for preprocessing queries
  49. let mut preproc_state: VecDeque<PreProcSingleState> = VecDeque::new();
  50. // Wait for commands
  51. loop {
  52. match incoming_cmd_recv.recv() {
  53. Err(_) => break,
  54. Ok(Command::PreProcMsg(cliquery)) => {
  55. let num_preproc = cliquery.len();
  56. let mut resp_state: Vec<PreProcSingleState> =
  57. Vec::with_capacity(num_preproc);
  58. let mut resp_msg: Vec<PreProcSingleRespMsg> =
  59. Vec::with_capacity(num_preproc);
  60. cliquery
  61. .into_par_iter()
  62. .map(|q| {
  63. let db_keys = gen_db_enc_keys(r);
  64. let query = Query::deserialize(&spiral_params, &q.spc_query);
  65. let ot_resp = otkey_serve(q.ot_query, &db_keys);
  66. (
  67. PreProcSingleState { db_keys, query },
  68. PreProcSingleRespMsg { ot_resp },
  69. )
  70. })
  71. .unzip_into_vecs(&mut resp_state, &mut resp_msg);
  72. preproc_state.append(&mut VecDeque::from(resp_state));
  73. let ret: Vec<u8> = bincode::serialize(&resp_msg).unwrap();
  74. outgoing_resp_send.send(Response::PreProcResp(ret)).unwrap();
  75. }
  76. Ok(Command::QueryMsg(offset, db, rot, blind)) => {
  77. // Panic if there's no preprocess state
  78. // available
  79. let nextstate = preproc_state.pop_front().unwrap();
  80. // Encrypt the database with the keys, rotating
  81. // and blinding in the process. It is safe to
  82. // construct a slice out of the const pointer we
  83. // were handed because that pointer will stay
  84. // valid until we return something back to the
  85. // caller.
  86. let totoffset = (offset + rot) & num_records_mask;
  87. let num_threads = rayon::current_num_threads();
  88. let encdb = unsafe {
  89. let dbslice =
  90. std::slice::from_raw_parts(db as *const DbEntry, num_records);
  91. db_encrypt(
  92. &dbslice,
  93. &nextstate.db_keys,
  94. r,
  95. totoffset,
  96. blind,
  97. num_threads,
  98. )
  99. };
  100. // Load the encrypted db into Spiral
  101. let sps_db = load_db_from_slice_mt(&spiral_params, &encdb, num_threads);
  102. // Process the query
  103. let resp = process_query(
  104. &spiral_params,
  105. &pub_params,
  106. &nextstate.query,
  107. sps_db.as_slice(),
  108. );
  109. outgoing_resp_send.send(Response::QueryResp(resp)).unwrap();
  110. }
  111. _ => panic!("Received something unexpected"),
  112. }
  113. }
  114. });
  115. Server {
  116. r,
  117. thread_handle,
  118. incoming_cmd,
  119. outgoing_resp,
  120. }
  121. }
  122. pub fn preproc_process(&self, msg: &[u8]) -> Vec<u8> {
  123. self.incoming_cmd
  124. .send(Command::PreProcMsg(bincode::deserialize(msg).unwrap()))
  125. .unwrap();
  126. let ret = match self.outgoing_resp.recv() {
  127. Ok(Response::PreProcResp(x)) => x,
  128. _ => panic!("Received something unexpected in preproc_process"),
  129. };
  130. ret
  131. }
  132. pub fn query_process(
  133. &self,
  134. msg: &[u8],
  135. db: *const DbEntry,
  136. rot: usize,
  137. blind: DbEntry,
  138. ) -> Vec<u8> {
  139. let offset = usize::from_le_bytes(msg.try_into().unwrap());
  140. self.incoming_cmd
  141. .send(Command::QueryMsg(offset, db as usize, rot, blind))
  142. .unwrap();
  143. let ret = match self.outgoing_resp.recv() {
  144. Ok(Response::QueryResp(x)) => x,
  145. _ => panic!("Received something unexpected in query_process"),
  146. };
  147. ret
  148. }
  149. }
  150. #[no_mangle]
  151. pub extern "C" fn spir_server_new(
  152. r: u8,
  153. pub_params: *const c_uchar,
  154. pub_params_len: usize,
  155. ) -> *mut Server {
  156. let pub_params_slice = unsafe {
  157. assert!(!pub_params.is_null());
  158. std::slice::from_raw_parts(pub_params, pub_params_len)
  159. };
  160. let mut pub_params_vec: Vec<u8> = Vec::new();
  161. pub_params_vec.extend_from_slice(pub_params_slice);
  162. let server = Server::new(r as usize, pub_params_vec);
  163. Box::into_raw(Box::new(server))
  164. }
  165. #[no_mangle]
  166. pub extern "C" fn spir_server_free(server: *mut Server) {
  167. if server.is_null() {
  168. return;
  169. }
  170. unsafe {
  171. Box::from_raw(server);
  172. }
  173. }
  174. #[no_mangle]
  175. pub extern "C" fn spir_server_preproc_process(
  176. serverptr: *mut Server,
  177. msgdata: *const c_uchar,
  178. msglen: usize,
  179. ) -> VecData {
  180. let server = unsafe {
  181. assert!(!serverptr.is_null());
  182. &mut *serverptr
  183. };
  184. let msg_slice = unsafe {
  185. assert!(!msgdata.is_null());
  186. std::slice::from_raw_parts(msgdata, msglen)
  187. };
  188. let retvec = server.preproc_process(&msg_slice);
  189. to_vecdata(retvec)
  190. }
  191. #[no_mangle]
  192. pub extern "C" fn spir_server_query_process(
  193. serverptr: *mut Server,
  194. msgdata: *const c_uchar,
  195. msglen: usize,
  196. db: *const DbEntry,
  197. rot: usize,
  198. blind: DbEntry,
  199. ) -> VecData {
  200. let server = unsafe {
  201. assert!(!serverptr.is_null());
  202. &mut *serverptr
  203. };
  204. let msg_slice = unsafe {
  205. assert!(!msgdata.is_null());
  206. std::slice::from_raw_parts(msgdata, msglen)
  207. };
  208. let retvec = server.query_process(&msg_slice, db, rot, blind);
  209. to_vecdata(retvec)
  210. }