Browse Source

client query_finish API

Ian Goldberg 1 year ago
parent
commit
23df7c068e
7 changed files with 84 additions and 13 deletions
  1. 6 0
      cxx/spir.cpp
  2. 3 0
      cxx/spir_ffi.h
  3. 7 0
      cxx/spir_test.cpp
  4. 56 7
      src/client.rs
  5. 3 4
      src/lib.rs
  6. 7 0
      src/main.rs
  7. 2 2
      src/server.rs

+ 6 - 0
cxx/spir.cpp

@@ -42,6 +42,12 @@ string SPIR_Client::query(size_t idx)
     return ret;
 }
 
+SPIR::DBEntry SPIR_Client::query_finish(const string &server_resp)
+{
+    return spir_client_query_finish(this->client,
+        server_resp.data(), server_resp.length());
+}
+
 SPIR_Server::SPIR_Server(uint8_t r, const string &pub_params)
 {
     this->server = spir_server_new(r, pub_params.data(),

+ 3 - 0
cxx/spir_ffi.h

@@ -33,6 +33,9 @@ extern void spir_client_preproc_finish(void *client,
 
 extern VecData spir_client_query(void *client, size_t idx);
 
+extern DBEntry spir_client_query_finish(void *client,
+    const char *msgdata, size_t msglen);
+
 extern void* spir_server_new(uint8_t r, const char *pub_params,
     size_t pub_params_len);
 

+ 7 - 0
cxx/spir_test.cpp

@@ -111,6 +111,13 @@ int main(int argc, char **argv)
         cout << "Query server: " << query_server_us << " µs\n";
         cout << "query_resp len = " << query_resp.length() << "\n";
 
+        struct timeval query_finish_start;
+        gettimeofday(&query_finish_start, NULL);
+
+        SPIR::DBEntry entry = client.query_finish(query_resp);
+        size_t query_finish_us = elapsed_us(&query_finish_start);
+        cout << "Query client finish: " << query_finish_us << " µs\n";
+        cout << "idx = " << idx << "; entry = " << entry << "\n";
     }
 
     delete[] db;

+ 56 - 7
src/client.rs

@@ -16,6 +16,7 @@ use subtle::Choice;
 use curve25519_dalek::ristretto::RistrettoPoint;
 use curve25519_dalek::scalar::Scalar;
 
+use crate::dbentry_decrypt;
 use crate::ot::*;
 use crate::params;
 use crate::to_vecdata;
@@ -28,6 +29,7 @@ enum Command {
     PreProc(usize),
     PreProcFinish(Vec<PreProcSingleRespMsg>),
     Query(usize),
+    QueryFinish(Vec<u8>),
 }
 
 enum Response {
@@ -35,6 +37,7 @@ enum Response {
     PreProcMsg(Vec<u8>),
     PreProcDone,
     QueryMsg(Vec<u8>),
+    QueryDone(DbEntry),
 }
 
 // The internal client state for a single outstanding preprocess query
@@ -129,7 +132,7 @@ impl Client {
                         // panic if there are no preproc states
                         // available
                         let nextstate = preproc_state.pop_front().unwrap();
-                        let offset = (num_records + nextstate.rand_idx - idx) & num_records_mask;
+                        let offset = (num_records + idx - nextstate.rand_idx) & num_records_mask;
                         let mut querymsg: Vec<u8> = Vec::new();
                         querymsg.extend(offset.to_le_bytes());
                         query_state.push_back(nextstate);
@@ -137,6 +140,27 @@ impl Client {
                             .send(Response::QueryMsg(querymsg))
                             .unwrap();
                     }
+                    Ok(Command::QueryFinish(msg)) => {
+                        // panic if there is no outstanding state
+                        let nextstate = query_state.pop_front().unwrap();
+                        let encdbblock = spiral_client.decode_response(msg.as_slice());
+                        // Extract the one encrypted DbEntry we were
+                        // looking for (and the only one we are able to
+                        // decrypt)
+                        let entry_in_block = nextstate.rand_idx % spiral_blocking_factor;
+                        let loc_in_block = entry_in_block * mem::size_of::<DbEntry>();
+                        let loc_in_block_end = (entry_in_block + 1) * mem::size_of::<DbEntry>();
+                        let encdbentry = DbEntry::from_le_bytes(
+                            encdbblock[loc_in_block..loc_in_block_end]
+                                .try_into()
+                                .unwrap(),
+                        );
+                        let decdbentry =
+                            dbentry_decrypt(&nextstate.ot_key, nextstate.rand_idx, encdbentry);
+                        outgoing_resp_send
+                            .send(Response::QueryDone(decdbentry))
+                            .unwrap();
+                    }
                     _ => panic!("Received something unexpected in client loop"),
                 }
             }
@@ -161,11 +185,10 @@ impl Client {
         self.incoming_cmd
             .send(Command::PreProc(num_preproc))
             .unwrap();
-        let ret = match self.outgoing_resp.recv() {
+        match self.outgoing_resp.recv() {
             Ok(Response::PreProcMsg(x)) => x,
             _ => panic!("Received something unexpected in preproc"),
-        };
-        ret
+        }
     }
 
     pub fn preproc_finish(&self, msg: &[u8]) {
@@ -180,11 +203,20 @@ impl Client {
 
     pub fn query(&self, idx: usize) -> Vec<u8> {
         self.incoming_cmd.send(Command::Query(idx)).unwrap();
-        let ret = match self.outgoing_resp.recv() {
+        match self.outgoing_resp.recv() {
             Ok(Response::QueryMsg(x)) => x,
             _ => panic!("Received something unexpected in preproc"),
-        };
-        ret
+        }
+    }
+
+    pub fn query_finish(&self, msg: &[u8]) -> DbEntry {
+        self.incoming_cmd
+            .send(Command::QueryFinish(msg.to_vec()))
+            .unwrap();
+        match self.outgoing_resp.recv() {
+            Ok(Response::QueryDone(entry)) => entry,
+            _ => panic!("Received something unexpected in preproc_finish"),
+        }
     }
 }
 
@@ -249,3 +281,20 @@ pub extern "C" fn spir_client_query(clientptr: *mut Client, idx: usize) -> VecDa
     let retvec = client.query(idx);
     to_vecdata(retvec)
 }
+
+#[no_mangle]
+pub extern "C" fn spir_client_query_finish(
+    clientptr: *mut Client,
+    msgdata: *const c_uchar,
+    msglen: usize,
+) -> DbEntry {
+    let client = unsafe {
+        assert!(!clientptr.is_null());
+        &mut *clientptr
+    };
+    let msg_slice = unsafe {
+        assert!(!msgdata.is_null());
+        std::slice::from_raw_parts(msgdata, msglen)
+    };
+    client.query_finish(&msg_slice)
+}

+ 3 - 4
src/lib.rs

@@ -43,7 +43,7 @@ pub type DbEntry = u64;
 // of the provided keys, one from each pair, according to the bits of
 // the element number.  Outputs a byte vector containing the encrypted
 // database.
-pub fn encdb_xor_keys(
+fn db_encrypt(
     db: &[DbEntry],
     keys: &[[u8; 16]],
     r: usize,
@@ -53,7 +53,6 @@ pub fn encdb_xor_keys(
 ) -> Vec<u8> {
     let num_records: usize = 1 << r;
     let num_record_mask: usize = num_records - 1;
-    let negrot = num_records - rot;
     let mut ret = Vec::<u8>::with_capacity(num_records * mem::size_of::<DbEntry>());
     ret.resize(num_records * mem::size_of::<DbEntry>(), 0);
     scope(|s| {
@@ -71,7 +70,7 @@ pub fn encdb_xor_keys(
             s.spawn(move |_| {
                 let mut offset = 0usize;
                 for j in record_thread_start..record_thread_end {
-                    let rec = (j + negrot) & num_record_mask;
+                    let rec = (j + rot) & num_record_mask;
                     let mut key = Block::from([0u8; 16]);
                     for i in 0..r {
                         let bit = if (j & (1 << i)) == 0 { 0 } else { 1 };
@@ -97,7 +96,7 @@ pub fn encdb_xor_keys(
 // Having received the key for element q with r parallel 1-out-of-2 OTs,
 // and having received the encrypted element with (non-symmetric) PIR,
 // use the key to decrypt the element.
-pub fn otkey_decrypt(key: &Block, q: usize, encelement: DbEntry) -> DbEntry {
+fn dbentry_decrypt(key: &Block, q: usize, encelement: DbEntry) -> DbEntry {
     let aes = Aes128Enc::new(key);
     let mut block = Block::from([0u8; 16]);
     block[0..8].copy_from_slice(&q.to_le_bytes());

+ 7 - 0
src/main.rs

@@ -96,6 +96,13 @@ fn main() {
 
         println!("Query server: {} µs", query_server_us);
         println!("query_resp len = {}", query_resp.len());
+
+        let query_finish_start = Instant::now();
+        let entry = client.query_finish(&query_resp);
+        let query_finish_us = query_finish_start.elapsed().as_micros();
+
+        println!("Query client finish: {} µs", query_finish_us);
+        println!("idx = {}; entry = {}", idx, entry);
     }
 
     /*

+ 2 - 2
src/server.rs

@@ -10,7 +10,7 @@ use spiral_rs::client::Query;
 use spiral_rs::params::Params;
 use spiral_rs::server::process_query;
 
-use crate::encdb_xor_keys;
+use crate::db_encrypt;
 use crate::load_db_from_slice_mt;
 use crate::ot::*;
 use crate::params;
@@ -98,7 +98,7 @@ impl Server {
                         let encdb = unsafe {
                             let dbslice =
                                 std::slice::from_raw_parts(db as *const DbEntry, num_records);
-                            encdb_xor_keys(
+                            db_encrypt(
                                 &dbslice,
                                 &nextstate.db_keys,
                                 r,