// server.rs
  use std::os::raw::c_uchar;
  use std::sync::mpsc::*;
  use std::thread::*;

  use rayon::prelude::*;
  use spiral_rs::client::PublicParameters;
  use spiral_rs::client::Query;
  use spiral_rs::params::Params;

  use crate::ot::*;
  use crate::params;
  use crate::to_vecdata;
  use crate::PreProcSingleMsg;
  use crate::PreProcSingleRespMsg;
  use crate::VecData;
  14. enum Command {
  15. PubParams(Vec<u8>),
  16. PreProcMsg(Vec<PreProcSingleMsg>),
  17. }
  /// Replies sent from the worker thread back to the `Server` handle.
  enum Response {
      /// The bincode-serialized reply to a `Command::PreProcMsg` batch.
      PreProcResp(Vec<u8>),
  }
  21. // The internal client state for a single preprocess query
  22. struct PreProcSingleState<'a> {
  23. db_keys: Vec<[u8; 16]>,
  24. query: Query<'a>,
  25. }
  26. pub struct Server {
  27. r: usize,
  28. thread_handle: JoinHandle<()>,
  29. incoming_cmd: SyncSender<Command>,
  30. outgoing_resp: Receiver<Response>,
  31. }
  32. impl Server {
  33. pub fn new(r: usize, pub_params: Vec<u8>) -> Self {
  34. let (incoming_cmd, incoming_cmd_recv) = sync_channel(0);
  35. let (outgoing_resp_send, outgoing_resp) = sync_channel(0);
  36. let thread_handle = spawn(move || {
  37. let spiral_params = params::get_spiral_params(r);
  38. let pub_params = PublicParameters::deserialize(&spiral_params, &pub_params);
  39. // State for preprocessing queries
  40. let mut preproc_state: Vec<PreProcSingleState> = Vec::new();
  41. // Wait for commands
  42. loop {
  43. match incoming_cmd_recv.recv() {
  44. Err(_) => break,
  45. Ok(Command::PreProcMsg(cliquery)) => {
  46. let num_preproc = cliquery.len();
  47. let mut resp_state: Vec<PreProcSingleState> =
  48. Vec::with_capacity(num_preproc);
  49. let mut resp_msg: Vec<PreProcSingleRespMsg> =
  50. Vec::with_capacity(num_preproc);
  51. cliquery
  52. .into_par_iter()
  53. .map(|q| {
  54. let db_keys = gen_db_enc_keys(r);
  55. let query = Query::deserialize(&spiral_params, &q.spc_query);
  56. let ot_resp = otkey_serve(q.ot_query, &db_keys);
  57. (
  58. PreProcSingleState { db_keys, query },
  59. PreProcSingleRespMsg { ot_resp },
  60. )
  61. })
  62. .unzip_into_vecs(&mut resp_state, &mut resp_msg);
  63. preproc_state.append(&mut resp_state);
  64. let ret: Vec<u8> = bincode::serialize(&resp_msg).unwrap();
  65. outgoing_resp_send.send(Response::PreProcResp(ret)).unwrap();
  66. }
  67. _ => panic!("Received something unexpected"),
  68. }
  69. }
  70. });
  71. Server {
  72. r,
  73. thread_handle,
  74. incoming_cmd,
  75. outgoing_resp,
  76. }
  77. }
  78. pub fn preproc_PIRs(&self, msg: &[u8]) -> Vec<u8> {
  79. self.incoming_cmd
  80. .send(Command::PreProcMsg(bincode::deserialize(msg).unwrap()))
  81. .unwrap();
  82. let ret = match self.outgoing_resp.recv() {
  83. Ok(Response::PreProcResp(x)) => x,
  84. _ => panic!("Received something unexpected in preproc_PIRs"),
  85. };
  86. ret
  87. }
  88. }
  89. #[no_mangle]
  90. pub extern "C" fn spir_server_new(
  91. r: u8,
  92. pub_params: *const c_uchar,
  93. pub_params_len: usize,
  94. ) -> *mut Server {
  95. let pub_params_slice = unsafe {
  96. assert!(!pub_params.is_null());
  97. std::slice::from_raw_parts(pub_params, pub_params_len)
  98. };
  99. let mut pub_params_vec: Vec<u8> = Vec::new();
  100. pub_params_vec.extend_from_slice(pub_params_slice);
  101. let server = Server::new(r as usize, pub_params_vec);
  102. Box::into_raw(Box::new(server))
  103. }
  104. #[no_mangle]
  105. pub extern "C" fn spir_server_free(server: *mut Server) {
  106. if server.is_null() {
  107. return;
  108. }
  109. unsafe {
  110. Box::from_raw(server);
  111. }
  112. }
  113. #[no_mangle]
  114. pub extern "C" fn spir_server_preproc_PIRs(
  115. serverptr: *mut Server,
  116. msgdata: *const c_uchar,
  117. msglen: usize,
  118. ) -> VecData {
  119. let server = unsafe {
  120. assert!(!serverptr.is_null());
  121. &mut *serverptr
  122. };
  123. let msg_slice = unsafe {
  124. assert!(!msgdata.is_null());
  125. std::slice::from_raw_parts(msgdata, msglen)
  126. };
  127. let retvec = server.preproc_PIRs(&msg_slice);
  128. to_vecdata(retvec)
  129. }