client.rs 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. use rand::rngs::ThreadRng;
  2. use rand::RngCore;
  3. use std::collections::VecDeque;
  4. use std::mem;
  5. use std::os::raw::c_uchar;
  6. use std::sync::mpsc::*;
  7. use std::thread::*;
  8. use aes::Block;
  9. use rayon::prelude::*;
  10. use subtle::Choice;
  11. use curve25519_dalek::ristretto::RistrettoPoint;
  12. use curve25519_dalek::scalar::Scalar;
  13. use crate::dbentry_decrypt;
  14. use crate::ot::*;
  15. use crate::params;
  16. use crate::to_vecdata;
  17. use crate::DbEntry;
  18. use crate::PreProcSingleMsg;
  19. use crate::PreProcSingleRespMsg;
  20. use crate::VecData;
  21. enum Command {
  22. PreProc(usize),
  23. PreProcFinish(Vec<PreProcSingleRespMsg>),
  24. Query(usize),
  25. QueryFinish(Vec<u8>),
  26. }
  27. enum Response {
  28. PubParams(Vec<u8>),
  29. PreProcMsg(Vec<u8>),
  30. PreProcDone,
  31. QueryMsg(Vec<u8>),
  32. QueryDone(DbEntry),
  33. }
  34. // The internal client state for a single outstanding preprocess query
  35. struct PreProcOutSingleState {
  36. rand_idx: usize,
  37. ot_state: Vec<(Choice, Scalar)>,
  38. }
  39. // The internal client state for a single preprocess ready to be used
  40. struct PreProcSingleState {
  41. rand_idx: usize,
  42. ot_key: Block,
  43. }
  44. pub struct Client {
  45. r: usize,
  46. thread_handle: JoinHandle<()>,
  47. incoming_cmd: SyncSender<Command>,
  48. outgoing_resp: Receiver<Response>,
  49. }
  50. impl Client {
  51. pub fn new(r: usize) -> (Self, Vec<u8>) {
  52. let (incoming_cmd, incoming_cmd_recv) = sync_channel(0);
  53. let (outgoing_resp_send, outgoing_resp) = sync_channel(0);
  54. let thread_handle = spawn(move || {
  55. let spiral_params = params::get_spiral_params(r);
  56. let mut clientrng = rand::thread_rng();
  57. let mut rng = rand::thread_rng();
  58. let mut spiral_client = spiral_rs::client::Client::init(&spiral_params, &mut clientrng);
  59. let num_records = 1 << r;
  60. let num_records_mask = num_records - 1;
  61. let spiral_blocking_factor = spiral_params.db_item_size / mem::size_of::<DbEntry>();
  62. // The first communication is the pub_params
  63. let pub_params = spiral_client.generate_keys().serialize();
  64. outgoing_resp_send
  65. .send(Response::PubParams(pub_params))
  66. .unwrap();
  67. // State for outstanding preprocessing queries
  68. let mut preproc_out_state: Vec<PreProcOutSingleState> = Vec::new();
  69. // State for preprocessing queries ready to be used
  70. let mut preproc_state: VecDeque<PreProcSingleState> = VecDeque::new();
  71. // State for outstanding active queries
  72. let mut query_state: VecDeque<PreProcSingleState> = VecDeque::new();
  73. // Wait for commands
  74. loop {
  75. match incoming_cmd_recv.recv() {
  76. Err(_) => break,
  77. Ok(Command::PreProc(num_preproc)) => {
  78. // Ensure we don't already have outstanding
  79. // preprocessing state
  80. assert!(preproc_out_state.len() == 0);
  81. let mut preproc_msg: Vec<PreProcSingleMsg> = Vec::new();
  82. for _ in 0..num_preproc {
  83. let rand_idx = (rng.next_u64() as usize) & num_records_mask;
  84. let rand_pir_idx = rand_idx / spiral_blocking_factor;
  85. let spc_query = spiral_client.generate_query(rand_pir_idx).serialize();
  86. let (ot_state, ot_query) = otkey_request(rand_idx, r);
  87. preproc_out_state.push(PreProcOutSingleState { rand_idx, ot_state });
  88. preproc_msg.push(PreProcSingleMsg {
  89. ot_query,
  90. spc_query,
  91. });
  92. }
  93. let ret: Vec<u8> = bincode::serialize(&preproc_msg).unwrap();
  94. outgoing_resp_send.send(Response::PreProcMsg(ret)).unwrap();
  95. }
  96. Ok(Command::PreProcFinish(srvresp)) => {
  97. let num_preproc = srvresp.len();
  98. assert!(preproc_out_state.len() == num_preproc);
  99. let mut newstate: VecDeque<PreProcSingleState> = preproc_out_state
  100. .into_par_iter()
  101. .zip(srvresp)
  102. .map(|(c, s)| {
  103. let ot_key = otkey_receive(c.ot_state, &s.ot_resp);
  104. PreProcSingleState {
  105. rand_idx: c.rand_idx,
  106. ot_key,
  107. }
  108. })
  109. .collect();
  110. preproc_state.append(&mut newstate);
  111. preproc_out_state = Vec::new();
  112. outgoing_resp_send.send(Response::PreProcDone).unwrap();
  113. }
  114. Ok(Command::Query(idx)) => {
  115. // panic if there are no preproc states
  116. // available
  117. let nextstate = preproc_state.pop_front().unwrap();
  118. let offset = (num_records + idx - nextstate.rand_idx) & num_records_mask;
  119. let mut querymsg: Vec<u8> = Vec::new();
  120. querymsg.extend(offset.to_le_bytes());
  121. query_state.push_back(nextstate);
  122. outgoing_resp_send
  123. .send(Response::QueryMsg(querymsg))
  124. .unwrap();
  125. }
  126. Ok(Command::QueryFinish(msg)) => {
  127. // panic if there is no outstanding state
  128. let nextstate = query_state.pop_front().unwrap();
  129. let encdbblock = spiral_client.decode_response(msg.as_slice());
  130. // Extract the one encrypted DbEntry we were
  131. // looking for (and the only one we are able to
  132. // decrypt)
  133. let entry_in_block = nextstate.rand_idx % spiral_blocking_factor;
  134. let loc_in_block = entry_in_block * mem::size_of::<DbEntry>();
  135. let loc_in_block_end = (entry_in_block + 1) * mem::size_of::<DbEntry>();
  136. let encdbentry = DbEntry::from_le_bytes(
  137. encdbblock[loc_in_block..loc_in_block_end]
  138. .try_into()
  139. .unwrap(),
  140. );
  141. let decdbentry =
  142. dbentry_decrypt(&nextstate.ot_key, nextstate.rand_idx, encdbentry);
  143. outgoing_resp_send
  144. .send(Response::QueryDone(decdbentry))
  145. .unwrap();
  146. }
  147. _ => panic!("Received something unexpected in client loop"),
  148. }
  149. }
  150. });
  151. let pub_params = match outgoing_resp.recv() {
  152. Ok(Response::PubParams(x)) => x,
  153. _ => panic!("Received something unexpected in client new"),
  154. };
  155. (
  156. Client {
  157. r,
  158. thread_handle,
  159. incoming_cmd,
  160. outgoing_resp,
  161. },
  162. pub_params,
  163. )
  164. }
  165. pub fn preproc(&self, num_preproc: usize) -> Vec<u8> {
  166. self.incoming_cmd
  167. .send(Command::PreProc(num_preproc))
  168. .unwrap();
  169. match self.outgoing_resp.recv() {
  170. Ok(Response::PreProcMsg(x)) => x,
  171. _ => panic!("Received something unexpected in preproc"),
  172. }
  173. }
  174. pub fn preproc_finish(&self, msg: &[u8]) {
  175. self.incoming_cmd
  176. .send(Command::PreProcFinish(bincode::deserialize(msg).unwrap()))
  177. .unwrap();
  178. match self.outgoing_resp.recv() {
  179. Ok(Response::PreProcDone) => (),
  180. _ => panic!("Received something unexpected in preproc_finish"),
  181. }
  182. }
  183. pub fn query(&self, idx: usize) -> Vec<u8> {
  184. self.incoming_cmd.send(Command::Query(idx)).unwrap();
  185. match self.outgoing_resp.recv() {
  186. Ok(Response::QueryMsg(x)) => x,
  187. _ => panic!("Received something unexpected in preproc"),
  188. }
  189. }
  190. pub fn query_finish(&self, msg: &[u8]) -> DbEntry {
  191. self.incoming_cmd
  192. .send(Command::QueryFinish(msg.to_vec()))
  193. .unwrap();
  194. match self.outgoing_resp.recv() {
  195. Ok(Response::QueryDone(entry)) => entry,
  196. _ => panic!("Received something unexpected in preproc_finish"),
  197. }
  198. }
  199. }
  200. #[repr(C)]
  201. pub struct ClientNewRet {
  202. client: *mut Client,
  203. pub_params: VecData,
  204. }
  205. #[no_mangle]
  206. pub extern "C" fn spir_client_new(r: u8) -> ClientNewRet {
  207. let (client, pub_params) = Client::new(r as usize);
  208. ClientNewRet {
  209. client: Box::into_raw(Box::new(client)),
  210. pub_params: to_vecdata(pub_params),
  211. }
  212. }
  213. #[no_mangle]
  214. pub extern "C" fn spir_client_free(client: *mut Client) {
  215. if client.is_null() {
  216. return;
  217. }
  218. unsafe {
  219. Box::from_raw(client);
  220. }
  221. }
  222. #[no_mangle]
  223. pub extern "C" fn spir_client_preproc(clientptr: *mut Client, num_preproc: u32) -> VecData {
  224. let client = unsafe {
  225. assert!(!clientptr.is_null());
  226. &mut *clientptr
  227. };
  228. let retvec = client.preproc(num_preproc as usize);
  229. to_vecdata(retvec)
  230. }
  231. #[no_mangle]
  232. pub extern "C" fn spir_client_preproc_finish(
  233. clientptr: *mut Client,
  234. msgdata: *const c_uchar,
  235. msglen: usize,
  236. ) {
  237. let client = unsafe {
  238. assert!(!clientptr.is_null());
  239. &mut *clientptr
  240. };
  241. let msg_slice = unsafe {
  242. assert!(!msgdata.is_null());
  243. std::slice::from_raw_parts(msgdata, msglen)
  244. };
  245. client.preproc_finish(&msg_slice);
  246. }
  247. #[no_mangle]
  248. pub extern "C" fn spir_client_query(clientptr: *mut Client, idx: usize) -> VecData {
  249. let client = unsafe {
  250. assert!(!clientptr.is_null());
  251. &mut *clientptr
  252. };
  253. let retvec = client.query(idx);
  254. to_vecdata(retvec)
  255. }
  256. #[no_mangle]
  257. pub extern "C" fn spir_client_query_finish(
  258. clientptr: *mut Client,
  259. msgdata: *const c_uchar,
  260. msglen: usize,
  261. ) -> DbEntry {
  262. let client = unsafe {
  263. assert!(!clientptr.is_null());
  264. &mut *clientptr
  265. };
  266. let msg_slice = unsafe {
  267. assert!(!msgdata.is_null());
  268. std::slice::from_raw_parts(msgdata, msglen)
  269. };
  270. client.query_finish(&msg_slice)
  271. }