server.rs 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. use std::collections::VecDeque;
  2. use std::os::raw::c_uchar;
  3. use std::sync::mpsc::*;
  4. use std::thread::*;
  5. use rayon::prelude::*;
  6. use spiral_rs::client::PublicParameters;
  7. use spiral_rs::client::Query;
  8. use spiral_rs::server::process_query;
  9. use crate::db_encrypt;
  10. use crate::load_db_from_slice_mt;
  11. use crate::ot::*;
  12. use crate::params;
  13. use crate::to_vecdata;
  14. use crate::DbEntry;
  15. use crate::PreProcSingleMsg;
  16. use crate::PreProcSingleRespMsg;
  17. use crate::VecData;
  18. enum Command {
  19. PreProcMsg(Vec<PreProcSingleMsg>),
  20. QueryMsg(usize, usize, usize, DbEntry),
  21. }
  22. enum Response {
  23. PreProcResp(Vec<u8>),
  24. QueryResp(Vec<u8>),
  25. }
  26. // The internal client state for a single preprocess query
  27. struct PreProcSingleState<'a> {
  28. db_keys: Vec<[u8; 16]>,
  29. query: Query<'a>,
  30. }
  31. pub struct Server {
  32. incoming_cmd: SyncSender<Command>,
  33. outgoing_resp: Receiver<Response>,
  34. }
  35. impl Server {
  36. pub fn new(r: usize, pub_params: Vec<u8>) -> Self {
  37. let (incoming_cmd, incoming_cmd_recv) = sync_channel(0);
  38. let (outgoing_resp_send, outgoing_resp) = sync_channel(0);
  39. spawn(move || {
  40. let spiral_params = params::get_spiral_params(r);
  41. let pub_params = PublicParameters::deserialize(&spiral_params, &pub_params);
  42. let num_records = 1 << r;
  43. let num_records_mask = num_records - 1;
  44. // State for preprocessing queries
  45. let mut preproc_state: VecDeque<PreProcSingleState> = VecDeque::new();
  46. // Wait for commands
  47. loop {
  48. match incoming_cmd_recv.recv() {
  49. Err(_) => break,
  50. Ok(Command::PreProcMsg(cliquery)) => {
  51. let num_preproc = cliquery.len();
  52. let mut resp_state: Vec<PreProcSingleState> =
  53. Vec::with_capacity(num_preproc);
  54. let mut resp_msg: Vec<PreProcSingleRespMsg> =
  55. Vec::with_capacity(num_preproc);
  56. cliquery
  57. .into_par_iter()
  58. .map(|q| {
  59. let db_keys = gen_db_enc_keys(r);
  60. let query = Query::deserialize(&spiral_params, &q.spc_query);
  61. let ot_resp = otkey_serve(q.ot_query, &db_keys);
  62. (
  63. PreProcSingleState { db_keys, query },
  64. PreProcSingleRespMsg { ot_resp },
  65. )
  66. })
  67. .unzip_into_vecs(&mut resp_state, &mut resp_msg);
  68. preproc_state.append(&mut VecDeque::from(resp_state));
  69. let ret: Vec<u8> = bincode::serialize(&resp_msg).unwrap();
  70. outgoing_resp_send.send(Response::PreProcResp(ret)).unwrap();
  71. }
  72. Ok(Command::QueryMsg(offset, db, rot, blind)) => {
  73. // Panic if there's no preprocess state
  74. // available
  75. let nextstate = preproc_state.pop_front().unwrap();
  76. // Encrypt the database with the keys, rotating
  77. // and blinding in the process. It is safe to
  78. // construct a slice out of the const pointer we
  79. // were handed because that pointer will stay
  80. // valid until we return something back to the
  81. // caller.
  82. let totoffset = (offset + rot) & num_records_mask;
  83. let num_threads = rayon::current_num_threads();
  84. let encdb = unsafe {
  85. let dbslice =
  86. std::slice::from_raw_parts(db as *const DbEntry, num_records);
  87. db_encrypt(
  88. dbslice,
  89. &nextstate.db_keys,
  90. r,
  91. totoffset,
  92. blind,
  93. num_threads,
  94. )
  95. };
  96. // Load the encrypted db into Spiral
  97. let sps_db = load_db_from_slice_mt(&spiral_params, &encdb, num_threads);
  98. // Process the query
  99. let resp = process_query(
  100. &spiral_params,
  101. &pub_params,
  102. &nextstate.query,
  103. sps_db.as_slice(),
  104. );
  105. outgoing_resp_send.send(Response::QueryResp(resp)).unwrap();
  106. }
  107. // When adding new messages, the following line is
  108. // useful during development
  109. // _ => panic!("Received something unexpected in server loop"),
  110. }
  111. }
  112. });
  113. Server {
  114. incoming_cmd,
  115. outgoing_resp,
  116. }
  117. }
  118. pub fn preproc_process(&self, msg: &[u8]) -> Vec<u8> {
  119. self.incoming_cmd
  120. .send(Command::PreProcMsg(bincode::deserialize(msg).unwrap()))
  121. .unwrap();
  122. let ret = match self.outgoing_resp.recv() {
  123. Ok(Response::PreProcResp(x)) => x,
  124. _ => panic!("Received something unexpected in preproc_process"),
  125. };
  126. ret
  127. }
  128. pub fn query_process(
  129. &self,
  130. msg: &[u8],
  131. db: *const DbEntry,
  132. rot: usize,
  133. blind: DbEntry,
  134. ) -> Vec<u8> {
  135. let offset = usize::from_le_bytes(msg.try_into().unwrap());
  136. self.incoming_cmd
  137. .send(Command::QueryMsg(offset, db as usize, rot, blind))
  138. .unwrap();
  139. let ret = match self.outgoing_resp.recv() {
  140. Ok(Response::QueryResp(x)) => x,
  141. _ => panic!("Received something unexpected in query_process"),
  142. };
  143. ret
  144. }
  145. }
  146. #[no_mangle]
  147. pub extern "C" fn spir_server_new(
  148. r: u8,
  149. pub_params: *const c_uchar,
  150. pub_params_len: usize,
  151. ) -> *mut Server {
  152. let pub_params_slice = unsafe {
  153. assert!(!pub_params.is_null());
  154. std::slice::from_raw_parts(pub_params, pub_params_len)
  155. };
  156. let mut pub_params_vec: Vec<u8> = Vec::new();
  157. pub_params_vec.extend_from_slice(pub_params_slice);
  158. let server = Server::new(r as usize, pub_params_vec);
  159. Box::into_raw(Box::new(server))
  160. }
  161. #[no_mangle]
  162. pub extern "C" fn spir_server_free(server: *mut Server) {
  163. if server.is_null() {
  164. return;
  165. }
  166. unsafe {
  167. Box::from_raw(server);
  168. }
  169. }
  170. #[no_mangle]
  171. pub extern "C" fn spir_server_preproc_process(
  172. serverptr: *mut Server,
  173. msgdata: *const c_uchar,
  174. msglen: usize,
  175. ) -> VecData {
  176. let server = unsafe {
  177. assert!(!serverptr.is_null());
  178. &mut *serverptr
  179. };
  180. let msg_slice = unsafe {
  181. assert!(!msgdata.is_null());
  182. std::slice::from_raw_parts(msgdata, msglen)
  183. };
  184. let retvec = server.preproc_process(msg_slice);
  185. to_vecdata(retvec)
  186. }
  187. #[no_mangle]
  188. pub extern "C" fn spir_server_query_process(
  189. serverptr: *mut Server,
  190. msgdata: *const c_uchar,
  191. msglen: usize,
  192. db: *const DbEntry,
  193. rot: usize,
  194. blind: DbEntry,
  195. ) -> VecData {
  196. let server = unsafe {
  197. assert!(!serverptr.is_null());
  198. &mut *serverptr
  199. };
  200. let msg_slice = unsafe {
  201. assert!(!msgdata.is_null());
  202. std::slice::from_raw_parts(msgdata, msglen)
  203. };
  204. let retvec = server.query_process(msg_slice, db, rot, blind);
  205. to_vecdata(retvec)
  206. }