Browse Source

server query_process API

Ian Goldberg 1 year ago
parent
commit
b49605e91c
6 changed files with 129 additions and 1 deletions
  1. 10 0
      cxx/spir.cpp
  2. 6 0
      cxx/spir_ffi.h
  3. 16 0
      cxx/spir_test.cpp
  4. 1 1
      src/client.rs
  5. 14 0
      src/main.rs
  6. 82 0
      src/server.rs

+ 10 - 0
cxx/spir.cpp

@@ -61,3 +61,13 @@ string SPIR_Server::preproc_process(const string &msg)
     spir_vecdata_free(retmsg);
     return ret;
 }
+
+string SPIR_Server::query_process(const string &client_query,
+        const SPIR::DBEntry *db, size_t rot, SPIR::DBEntry blind)
+{
+    VecData retmsg = spir_server_query_process(this->server,
+        client_query.data(), client_query.length(), db, rot, blind);
+    string ret(retmsg.data, retmsg.len);
+    spir_vecdata_free(retmsg);
+    return ret;
+}

+ 6 - 0
cxx/spir_ffi.h

@@ -7,6 +7,8 @@
 extern "C" {
 #endif
 
+typedef size_t DBEntry;
+
 typedef struct {
     const char *data;
     size_t len;
@@ -39,6 +41,10 @@ extern void spir_server_free(void *server);
 extern VecData spir_server_preproc_process(void *server,
     const char *msgdata, size_t msglen);
 
+extern VecData spir_server_query_process(void *server,
+    const char *msgdata, size_t msglen, const DBEntry *db,
+    size_t rot, DBEntry blind);
+
 extern void spir_vecdata_free(VecData vecdata);
 
 #ifdef __cplusplus

+ 16 - 0
cxx/spir_test.cpp

@@ -79,6 +79,12 @@ int main(int argc, char **argv)
     size_t preproc_finish_us = elapsed_us(&preproc_finish_start);
     cout << "Preprocessing client finish: " << preproc_finish_us << " µs\n";
 
+    // Create the database
+    SPIR::DBEntry *db = new SPIR::DBEntry[num_records];
+    for (size_t i=0; i<num_records; ++i) {
+        db[i] = i * 10000001;
+    }
+
     for (size_t i=0; i<num_pirs; ++i) {
         cout << "\n===== SPIR QUERY " << i+1 << " =====\n\n";
 
@@ -97,7 +103,17 @@ int main(int argc, char **argv)
         cout << "Preprocessing client: " << query_client_us << " µs\n";
         cout << "query_msg len = " << query_msg.length() << "\n";
 
+        struct timeval query_server_start;
+        gettimeofday(&query_server_start, NULL);
+
+        string query_resp = server.query_process(query_msg, db, 100, 20);
+        size_t query_server_us = elapsed_us(&query_server_start);
+        cout << "Query server: " << query_server_us << " µs\n";
+        cout << "query_resp len = " << query_resp.length() << "\n";
+
     }
 
+    delete[] db;
+
     return 0;
 }

+ 1 - 1
src/client.rs

@@ -131,7 +131,7 @@ impl Client {
                         let nextstate = preproc_state.pop_front().unwrap();
                         let offset = (num_records + nextstate.rand_idx - idx) & num_records_mask;
                         let mut querymsg: Vec<u8> = Vec::new();
-                        querymsg.extend(idx.to_le_bytes());
+                        querymsg.extend(offset.to_le_bytes());
                         query_state.push_back(nextstate);
                         outgoing_resp_send
                             .send(Response::QueryMsg(querymsg))

+ 14 - 0
src/main.rs

@@ -6,6 +6,7 @@ use rand::RngCore;
 use spiral_spir::client::Client;
 use spiral_spir::init;
 use spiral_spir::server::Server;
+use spiral_spir::DbEntry;
 
 fn main() {
     let args: Vec<String> = env::args().collect();
@@ -71,6 +72,12 @@ fn main() {
 
     println!("Preprocessing client finish: {} µs", preproc_finish_us);
 
+    // Create a database with recognizable contents
+    let db: Vec<DbEntry> = ((0 as DbEntry)..(num_records as DbEntry))
+        .map(|x| 10000001 * x)
+        .collect();
+    let dbptr = db.as_ptr();
+
     let mut rng = rand::thread_rng();
     for i in 1..num_pirs + 1 {
         println!("\n===== SPIR QUERY {} =====\n", i);
@@ -82,6 +89,13 @@ fn main() {
 
         println!("Query client: {} µs", query_client_us);
         println!("query_msg len = {}", query_msg.len());
+
+        let query_server_start = Instant::now();
+        let query_resp = server.query_process(&query_msg, dbptr, 100, 20);
+        let query_server_us = query_server_start.elapsed().as_micros();
+
+        println!("Query server: {} µs", query_server_us);
+        println!("query_resp len = {}", query_resp.len());
     }
 
     /*

+ 82 - 0
src/server.rs

@@ -8,10 +8,14 @@ use rayon::prelude::*;
 use spiral_rs::client::PublicParameters;
 use spiral_rs::client::Query;
 use spiral_rs::params::Params;
+use spiral_rs::server::process_query;
 
+use crate::encdb_xor_keys;
+use crate::load_db_from_slice_mt;
 use crate::ot::*;
 use crate::params;
 use crate::to_vecdata;
+use crate::DbEntry;
 use crate::PreProcSingleMsg;
 use crate::PreProcSingleRespMsg;
 use crate::VecData;
@@ -19,10 +23,12 @@ use crate::VecData;
 enum Command {
     PubParams(Vec<u8>),
     PreProcMsg(Vec<PreProcSingleMsg>),
+    QueryMsg(usize, usize, usize, DbEntry),
 }
 
 enum Response {
     PreProcResp(Vec<u8>),
+    QueryResp(Vec<u8>),
 }
 
 // The internal client state for a single preprocess query
@@ -45,6 +51,8 @@ impl Server {
         let thread_handle = spawn(move || {
             let spiral_params = params::get_spiral_params(r);
             let pub_params = PublicParameters::deserialize(&spiral_params, &pub_params);
+            let num_records = 1 << r;
+            let num_records_mask = num_records - 1;
 
             // State for preprocessing queries
             let mut preproc_state: VecDeque<PreProcSingleState> = VecDeque::new();
@@ -75,6 +83,41 @@ impl Server {
                         let ret: Vec<u8> = bincode::serialize(&resp_msg).unwrap();
                         outgoing_resp_send.send(Response::PreProcResp(ret)).unwrap();
                     }
+                    Ok(Command::QueryMsg(offset, db, rot, blind)) => {
+                        // Panic if there's no preprocess state
+                        // available
+                        let nextstate = preproc_state.pop_front().unwrap();
+                        // Encrypt the database with the keys, rotating
+                        // and blinding in the process.  It is safe to
+                        // construct a slice out of the const pointer we
+                        // were handed because that pointer will stay
+                        // valid until we return something back to the
+                        // caller.
+                        let totoffset = (offset + rot) & num_records_mask;
+                        let num_threads = rayon::current_num_threads();
+                        let encdb = unsafe {
+                            let dbslice =
+                                std::slice::from_raw_parts(db as *const DbEntry, num_records);
+                            encdb_xor_keys(
+                                &dbslice,
+                                &nextstate.db_keys,
+                                r,
+                                totoffset,
+                                blind,
+                                num_threads,
+                            )
+                        };
+                        // Load the encrypted db into Spiral
+                        let sps_db = load_db_from_slice_mt(&spiral_params, &encdb, num_threads);
+                        // Process the query
+                        let resp = process_query(
+                            &spiral_params,
+                            &pub_params,
+                            &nextstate.query,
+                            sps_db.as_slice(),
+                        );
+                        outgoing_resp_send.send(Response::QueryResp(resp)).unwrap();
+                    }
                     _ => panic!("Received something unexpected"),
                 }
             }
@@ -97,6 +140,24 @@ impl Server {
         };
         ret
     }
+
+    pub fn query_process(
+        &self,
+        msg: &[u8],
+        db: *const DbEntry,
+        rot: usize,
+        blind: DbEntry,
+    ) -> Vec<u8> {
+        let offset = usize::from_le_bytes(msg.try_into().unwrap());
+        self.incoming_cmd
+            .send(Command::QueryMsg(offset, db as usize, rot, blind))
+            .unwrap();
+        let ret = match self.outgoing_resp.recv() {
+            Ok(Response::QueryResp(x)) => x,
+            _ => panic!("Received something unexpected in query_process"),
+        };
+        ret
+    }
 }
 
 #[no_mangle]
@@ -142,3 +203,24 @@ pub extern "C" fn spir_server_preproc_process(
     let retvec = server.preproc_process(&msg_slice);
     to_vecdata(retvec)
 }
+
+#[no_mangle]
+pub extern "C" fn spir_server_query_process(
+    serverptr: *mut Server,
+    msgdata: *const c_uchar,
+    msglen: usize,
+    db: *const DbEntry,
+    rot: usize,
+    blind: DbEntry,
+) -> VecData {
+    let server = unsafe {
+        assert!(!serverptr.is_null());
+        &mut *serverptr
+    };
+    let msg_slice = unsafe {
+        assert!(!msgdata.is_null());
+        std::slice::from_raw_parts(msgdata, msglen)
+    };
+    let retvec = server.query_process(&msg_slice, db, rot, blind);
+    to_vecdata(retvec)
+}