Browse Source

preproc_finish API

Ian Goldberg 1 year ago
parent
commit
473d2cb478
7 changed files with 99 additions and 10 deletions
  1. 6 0
      cxx/spir.cpp
  2. 1 1
      cxx/spir.hpp
  3. 3 0
      cxx/spir_ffi.h
  4. 9 0
      cxx/spir_test.cpp
  5. 69 7
      src/client.rs
  6. 8 0
      src/main.rs
  7. 3 2
      src/server.rs

+ 6 - 0
cxx/spir.cpp

@@ -28,6 +28,12 @@ string SPIR_Client::preproc_PIRs(uint32_t num_preproc)
     return ret;
 }
 
+void SPIR_Client::preproc_finish(const string &server_preproc)
+{
+    spir_client_preproc_finish(this->client, server_preproc.data(),
+        server_preproc.length());
+}
+
 SPIR_Server::SPIR_Server(uint8_t r, const string &pub_params)
 {
     this->server = spir_server_new(r, pub_params.data(),

+ 1 - 1
cxx/spir.hpp

@@ -22,7 +22,7 @@ public:
     // preprocessing
     string preproc_PIRs(uint32_t num_pirs); // returns the string to send to the server
 
-    void preproc_handle(const string &server_preproc);
+    void preproc_finish(const string &server_preproc);
 
     // SPIR query for index idx
     string query(size_t idx); // returns the string to send to the server

+ 3 - 0
cxx/spir_ffi.h

@@ -26,6 +26,9 @@ extern void spir_client_free(void *client);
 
 extern VecData spir_client_preproc_PIRs(void *client, uint32_t num_preproc);
 
+extern void spir_client_preproc_finish(void *client,
+    const char *msgdata, size_t msglen);
+
 extern void* spir_server_new(uint8_t r, const char *pub_params,
     size_t pub_params_len);
 

+ 9 - 0
cxx/spir_test.cpp

@@ -51,6 +51,8 @@ int main(int argc, char **argv)
 
     cout << "\n===== PREPROCESSING =====\n\n";
 
+    cout << "num_preproc = " << num_preproc << "\n";
+
     struct timeval preproc_client_start;
     gettimeofday(&preproc_client_start, NULL);
 
@@ -67,5 +69,12 @@ int main(int argc, char **argv)
     cout << "Preprocessing server: " << preproc_server_us << " µs\n";
     cout << "preproc_response len = " << preproc_resp.length() << "\n";
 
+    struct timeval preproc_finish_start;
+    gettimeofday(&preproc_finish_start, NULL);
+
+    client.preproc_finish(preproc_resp);
+    size_t preproc_finish_us = elapsed_us(&preproc_finish_start);
+    cout << "Preprocessing client finish: " << preproc_finish_us << " µs\n";
+
     return 0;
 }

+ 69 - 7
src/client.rs

@@ -1,10 +1,16 @@
 use rand::rngs::ThreadRng;
 use rand::RngCore;
 
+use std::collections::VecDeque;
 use std::mem;
+use std::os::raw::c_uchar;
 use std::sync::mpsc::*;
 use std::thread::*;
 
+use aes::Block;
+
+use rayon::prelude::*;
+
 use subtle::Choice;
 
 use curve25519_dalek::ristretto::RistrettoPoint;
@@ -15,24 +21,32 @@ use crate::params;
 use crate::to_vecdata;
 use crate::DbEntry;
 use crate::PreProcSingleMsg;
+use crate::PreProcSingleRespMsg;
 use crate::VecData;
 
 enum Command {
     PreProc(usize),
-    PreProcResp(Vec<u8>),
+    PreProcFinish(Vec<PreProcSingleRespMsg>),
 }
 
 enum Response {
     PubParams(Vec<u8>),
     PreProcMsg(Vec<u8>),
+    PreProcDone,
 }
 
-// The internal client state for a single preprocess query
-struct PreProcSingleState {
+// The internal client state for a single outstanding preprocess query
+struct PreProcOutSingleState {
     rand_idx: usize,
     ot_state: Vec<(Choice, Scalar)>,
 }
 
+// The internal client state for a single preprocess ready to be used
+struct PreProcSingleState {
+    rand_idx: usize,
+    ot_key: Block,
+}
+
 pub struct Client {
     r: usize,
     thread_handle: JoinHandle<()>,
@@ -59,8 +73,11 @@ impl Client {
                 .send(Response::PubParams(pub_params))
                 .unwrap();
 
-            // State for preprocessing queries
-            let mut preproc_state: Vec<PreProcSingleState> = Vec::new();
+            // State for outstanding preprocessing queries
+            let mut preproc_out_state: Vec<PreProcOutSingleState> = Vec::new();
+
+            // State for preprocessing queries ready to be used
+            let mut preproc_state: VecDeque<PreProcSingleState> = VecDeque::new();
 
             // Wait for commands
             loop {
@@ -69,14 +86,14 @@ impl Client {
                     Ok(Command::PreProc(num_preproc)) => {
                         // Ensure we don't already have outstanding
                         // preprocessing state
-                        assert!(preproc_state.len() == 0);
+                        assert!(preproc_out_state.len() == 0);
                         let mut preproc_msg: Vec<PreProcSingleMsg> = Vec::new();
                         for _ in 0..num_preproc {
                             let rand_idx = (rng.next_u64() as usize) & num_records_mask;
                             let rand_pir_idx = rand_idx / spiral_blocking_factor;
                             let spc_query = spiral_client.generate_query(rand_pir_idx).serialize();
                             let (ot_state, ot_query) = otkey_request(rand_idx, r);
-                            preproc_state.push(PreProcSingleState { rand_idx, ot_state });
+                            preproc_out_state.push(PreProcOutSingleState { rand_idx, ot_state });
                             preproc_msg.push(PreProcSingleMsg {
                                 ot_query,
                                 spc_query,
@@ -85,6 +102,24 @@ impl Client {
                         let ret: Vec<u8> = bincode::serialize(&preproc_msg).unwrap();
                         outgoing_resp_send.send(Response::PreProcMsg(ret)).unwrap();
                     }
+                    Ok(Command::PreProcFinish(srvresp)) => {
+                        let num_preproc = srvresp.len();
+                        assert!(preproc_out_state.len() == num_preproc);
+                        let mut newstate: VecDeque<PreProcSingleState> = preproc_out_state
+                            .into_par_iter()
+                            .zip(srvresp)
+                            .map(|(c, s)| {
+                                let ot_key = otkey_receive(c.ot_state, &s.ot_resp);
+                                PreProcSingleState {
+                                    rand_idx: c.rand_idx,
+                                    ot_key,
+                                }
+                            })
+                            .collect();
+                        preproc_state.append(&mut newstate);
+                        preproc_out_state = Vec::new();
+                        outgoing_resp_send.send(Response::PreProcDone).unwrap();
+                    }
                     _ => panic!("Received something unexpected in client loop"),
                 }
             }
@@ -115,6 +150,16 @@ impl Client {
         };
         ret
     }
+
+    pub fn preproc_finish(&self, msg: &[u8]) {
+        self.incoming_cmd
+            .send(Command::PreProcFinish(bincode::deserialize(msg).unwrap()))
+            .unwrap();
+        match self.outgoing_resp.recv() {
+            Ok(Response::PreProcDone) => (),
+            _ => panic!("Received something unexpected in preproc_finish"),
+        }
+    }
 }
 
 #[repr(C)]
@@ -151,3 +196,20 @@ pub extern "C" fn spir_client_preproc_PIRs(clientptr: *mut Client, num_preproc:
     let retvec = client.preproc_PIRs(num_preproc as usize);
     to_vecdata(retvec)
 }
+
+#[no_mangle]
+pub extern "C" fn spir_client_preproc_finish(
+    clientptr: *mut Client,
+    msgdata: *const c_uchar,
+    msglen: usize,
+) {
+    let client = unsafe {
+        assert!(!clientptr.is_null());
+        &mut *clientptr
+    };
+    let msg_slice = unsafe {
+        assert!(!msgdata.is_null());
+        std::slice::from_raw_parts(msgdata, msglen)
+    };
+    client.preproc_finish(&msg_slice);
+}

+ 8 - 0
src/main.rs

@@ -50,6 +50,8 @@ fn main() {
 
     println!("\n===== PREPROCESSING =====\n");
 
+    println!("num_preproc = {}", num_preproc);
+
     let preproc_client_start = Instant::now();
     let preproc_msg = client.preproc_PIRs(num_preproc);
     let preproc_client_us = preproc_client_start.elapsed().as_micros();
@@ -64,6 +66,12 @@ fn main() {
     println!("Preprocessing server: {} µs", preproc_server_us);
     println!("preproc_response len = {}", preproc_resp.len());
 
+    let preproc_finish_start = Instant::now();
+    client.preproc_finish(&preproc_resp);
+    let preproc_finish_us = preproc_finish_start.elapsed().as_micros();
+
+    println!("Preprocessing client finish: {} µs", preproc_finish_us);
+
     /*
         let spiral_params = params::get_spiral_params(r);
         let mut rng = rand::thread_rng();

+ 3 - 2
src/server.rs

@@ -1,3 +1,4 @@
+use std::collections::VecDeque;
 use std::os::raw::c_uchar;
 use std::sync::mpsc::*;
 use std::thread::*;
@@ -46,7 +47,7 @@ impl Server {
             let pub_params = PublicParameters::deserialize(&spiral_params, &pub_params);
 
             // State for preprocessing queries
-            let mut preproc_state: Vec<PreProcSingleState> = Vec::new();
+            let mut preproc_state: VecDeque<PreProcSingleState> = VecDeque::new();
 
             // Wait for commands
             loop {
@@ -70,7 +71,7 @@ impl Server {
                                 )
                             })
                             .unzip_into_vecs(&mut resp_state, &mut resp_msg);
-                        preproc_state.append(&mut resp_state);
+                        preproc_state.append(&mut VecDeque::from(resp_state));
                         let ret: Vec<u8> = bincode::serialize(&resp_msg).unwrap();
                         outgoing_resp_send.send(Response::PreProcResp(ret)).unwrap();
                     }