server.rs 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. use std::collections::VecDeque;
  2. use std::os::raw::c_uchar;
  3. use std::sync::mpsc::*;
  4. use std::thread::*;
  5. use rayon::prelude::*;
  6. use spiral_rs::client::PublicParameters;
  7. use spiral_rs::client::Query;
  8. use spiral_rs::params::Params;
  9. use crate::ot::*;
  10. use crate::params;
  11. use crate::to_vecdata;
  12. use crate::PreProcSingleMsg;
  13. use crate::PreProcSingleRespMsg;
  14. use crate::VecData;
  15. enum Command {
  16. PubParams(Vec<u8>),
  17. PreProcMsg(Vec<PreProcSingleMsg>),
  18. }
  /// Replies sent from the worker thread back to the `Server` handle.
  enum Response {
      /// bincode-encoded `Vec<PreProcSingleRespMsg>` answering a `Command::PreProcMsg`.
      PreProcResp(Vec<u8>),
  }
  22. // The internal client state for a single preprocess query
  23. struct PreProcSingleState<'a> {
  24. db_keys: Vec<[u8; 16]>,
  25. query: Query<'a>,
  26. }
  27. pub struct Server {
  28. r: usize,
  29. thread_handle: JoinHandle<()>,
  30. incoming_cmd: SyncSender<Command>,
  31. outgoing_resp: Receiver<Response>,
  32. }
  33. impl Server {
  34. pub fn new(r: usize, pub_params: Vec<u8>) -> Self {
  35. let (incoming_cmd, incoming_cmd_recv) = sync_channel(0);
  36. let (outgoing_resp_send, outgoing_resp) = sync_channel(0);
  37. let thread_handle = spawn(move || {
  38. let spiral_params = params::get_spiral_params(r);
  39. let pub_params = PublicParameters::deserialize(&spiral_params, &pub_params);
  40. // State for preprocessing queries
  41. let mut preproc_state: VecDeque<PreProcSingleState> = VecDeque::new();
  42. // Wait for commands
  43. loop {
  44. match incoming_cmd_recv.recv() {
  45. Err(_) => break,
  46. Ok(Command::PreProcMsg(cliquery)) => {
  47. let num_preproc = cliquery.len();
  48. let mut resp_state: Vec<PreProcSingleState> =
  49. Vec::with_capacity(num_preproc);
  50. let mut resp_msg: Vec<PreProcSingleRespMsg> =
  51. Vec::with_capacity(num_preproc);
  52. cliquery
  53. .into_par_iter()
  54. .map(|q| {
  55. let db_keys = gen_db_enc_keys(r);
  56. let query = Query::deserialize(&spiral_params, &q.spc_query);
  57. let ot_resp = otkey_serve(q.ot_query, &db_keys);
  58. (
  59. PreProcSingleState { db_keys, query },
  60. PreProcSingleRespMsg { ot_resp },
  61. )
  62. })
  63. .unzip_into_vecs(&mut resp_state, &mut resp_msg);
  64. preproc_state.append(&mut VecDeque::from(resp_state));
  65. let ret: Vec<u8> = bincode::serialize(&resp_msg).unwrap();
  66. outgoing_resp_send.send(Response::PreProcResp(ret)).unwrap();
  67. }
  68. _ => panic!("Received something unexpected"),
  69. }
  70. }
  71. });
  72. Server {
  73. r,
  74. thread_handle,
  75. incoming_cmd,
  76. outgoing_resp,
  77. }
  78. }
  79. pub fn preproc_process(&self, msg: &[u8]) -> Vec<u8> {
  80. self.incoming_cmd
  81. .send(Command::PreProcMsg(bincode::deserialize(msg).unwrap()))
  82. .unwrap();
  83. let ret = match self.outgoing_resp.recv() {
  84. Ok(Response::PreProcResp(x)) => x,
  85. _ => panic!("Received something unexpected in preproc_process"),
  86. };
  87. ret
  88. }
  89. }
  90. #[no_mangle]
  91. pub extern "C" fn spir_server_new(
  92. r: u8,
  93. pub_params: *const c_uchar,
  94. pub_params_len: usize,
  95. ) -> *mut Server {
  96. let pub_params_slice = unsafe {
  97. assert!(!pub_params.is_null());
  98. std::slice::from_raw_parts(pub_params, pub_params_len)
  99. };
  100. let mut pub_params_vec: Vec<u8> = Vec::new();
  101. pub_params_vec.extend_from_slice(pub_params_slice);
  102. let server = Server::new(r as usize, pub_params_vec);
  103. Box::into_raw(Box::new(server))
  104. }
  105. #[no_mangle]
  106. pub extern "C" fn spir_server_free(server: *mut Server) {
  107. if server.is_null() {
  108. return;
  109. }
  110. unsafe {
  111. Box::from_raw(server);
  112. }
  113. }
  114. #[no_mangle]
  115. pub extern "C" fn spir_server_preproc_process(
  116. serverptr: *mut Server,
  117. msgdata: *const c_uchar,
  118. msglen: usize,
  119. ) -> VecData {
  120. let server = unsafe {
  121. assert!(!serverptr.is_null());
  122. &mut *serverptr
  123. };
  124. let msg_slice = unsafe {
  125. assert!(!msgdata.is_null());
  126. std::slice::from_raw_parts(msgdata, msglen)
  127. };
  128. let retvec = server.preproc_process(&msg_slice);
  129. to_vecdata(retvec)
  130. }