Browse Source

server preproc_PIRs API

Ian Goldberg 1 year ago
parent
commit
217b08a574
9 changed files with 125 additions and 17 deletions
  1. 1 0
      Cargo.toml
  2. 9 0
      cxx/spir.cpp
  3. 1 1
      cxx/spir.hpp
  4. 3 0
      cxx/spir_ffi.h
  5. 8 0
      cxx/spir_test.cpp
  6. 10 15
      src/lib.rs
  7. 7 0
      src/main.rs
  8. 15 0
      src/ot.rs
  9. 71 1
      src/server.rs

+ 1 - 0
Cargo.toml

@@ -20,6 +20,7 @@ spiral-rs = { git = "https://github.com/menonsamir/spiral-rs/", rev = "0f9bdc157
 rayon = "1.5"
 bincode = "1"
 serde = "1"
+serde_with = "2"
 
 [lib]
 crate_type = ["lib", "staticlib"]

+ 9 - 0
cxx/spir.cpp

@@ -38,3 +38,12 @@ SPIR_Server::~SPIR_Server()
 {
     spir_server_free(this->server);
 }
+
+string SPIR_Server::preproc_PIRs(const string &msg)
+{
+    VecData retmsg = spir_server_preproc_PIRs(this->server, msg.data(),
+        msg.length());
+    string ret(retmsg.data, retmsg.len);
+    spir_vecdata_free(retmsg);
+    return ret;
+}

+ 1 - 1
cxx/spir.hpp

@@ -46,7 +46,7 @@ public:
     ~SPIR_Server();
 
     // preprocessing
-    string preproc_PIR(const string &client_preproc); // returns the string to reply to the client
+    string preproc_PIRs(const string &client_preproc); // returns the string to reply to the client
   
     // SPIR query on the given database of N=2^r records, each of type DBEntry
     // rotate the database by rot, and blind each entry in the database additively with blind

+ 3 - 0
cxx/spir_ffi.h

@@ -31,6 +31,9 @@ extern void* spir_server_new(uint8_t r, const char *pub_params,
 
 extern void spir_server_free(void *server);
 
+extern VecData spir_server_preproc_PIRs(void *server,
+    const char *msgdata, size_t msglen);
+
 extern void spir_vecdata_free(VecData vecdata);
 
 #ifdef __cplusplus

+ 8 - 0
cxx/spir_test.cpp

@@ -59,5 +59,13 @@ int main(int argc, char **argv)
     cout << "Preprocessing client: " << preproc_client_us << " µs\n";
     cout << "preproc_msg len = " << preproc_msg.length() << "\n";
 
+    struct timeval preproc_server_start;
+    gettimeofday(&preproc_server_start, NULL);
+
+    string preproc_resp = server.preproc_PIRs(preproc_msg);
+    size_t preproc_server_us = elapsed_us(&preproc_server_start);
+    cout << "Preprocessing server: " << preproc_server_us << " µs\n";
+    cout << "preproc_response len = " << preproc_resp.length() << "\n";
+
     return 0;
 }

+ 10 - 15
src/lib.rs

@@ -16,8 +16,6 @@ use std::env;
 use std::mem;
 use std::time::Instant;
 
-use rand::RngCore;
-
 use serde::{Deserialize, Serialize};
 
 use std::os::raw::c_uchar;
@@ -25,6 +23,8 @@ use std::os::raw::c_uchar;
 use rayon::scope;
 use rayon::ThreadPoolBuilder;
 
+use serde_with::serde_as;
+
 use spiral_rs::client::*;
 use spiral_rs::params::*;
 use spiral_rs::server::*;
@@ -94,19 +94,6 @@ pub fn encdb_xor_keys(
     ret
 }
 
-// Generate the keys for encrypting the database
-pub fn gen_db_enc_keys(r: usize) -> Vec<[u8; 16]> {
-    let mut keys: Vec<[u8; 16]> = Vec::new();
-
-    let mut rng = rand::thread_rng();
-    for _ in 0..2 * r {
-        let mut k: [u8; 16] = [0; 16];
-        rng.fill_bytes(&mut k);
-        keys.push(k);
-    }
-    keys
-}
-
 // Having received the key for element q with r parallel 1-out-of-2 OTs,
 // and having received the encrypted element with (non-symmetric) PIR,
 // use the key to decrypt the element.
@@ -148,6 +135,14 @@ struct PreProcSingleMsg {
     spc_query: Vec<u8>,
 }
 
+// The message format for a single preprocess response
+#[serde_as]
+#[derive(Serialize, Deserialize)]
+struct PreProcSingleRespMsg {
+    #[serde_as(as = "Vec<[_; 64]>")]
+    ot_resp: Vec<[u8; 64]>,
+}
+
 #[no_mangle]
 pub extern "C" fn spir_init(num_threads: u32) {
     init(num_threads as usize);

+ 7 - 0
src/main.rs

@@ -57,6 +57,13 @@ fn main() {
     println!("Preprocessing client: {} µs", preproc_client_us);
     println!("preproc_msg len = {}", preproc_msg.len());
 
+    let preproc_server_start = Instant::now();
+    let preproc_resp = server.preproc_PIRs(&preproc_msg);
+    let preproc_server_us = preproc_server_start.elapsed().as_micros();
+
+    println!("Preprocessing server: {} µs", preproc_server_us);
+    println!("preproc_response len = {}", preproc_resp.len());
+
     /*
         let spiral_params = params::get_spiral_params(r);
         let mut rng = rand::thread_rng();

+ 15 - 0
src/ot.rs

@@ -5,6 +5,8 @@ use subtle::ConditionallySelectable;
 
 use aes::Block;
 
+use rand::RngCore;
+
 use sha2::Digest;
 use sha2::Sha256;
 use sha2::Sha512;
@@ -129,3 +131,16 @@ pub fn otkey_receive(state: Vec<(Choice, Scalar)>, response: &Vec<[u8; 64]>) ->
     }
     key
 }
+
+// Generate the keys for encrypting the database
+pub fn gen_db_enc_keys(r: usize) -> Vec<[u8; 16]> {
+    let mut keys: Vec<[u8; 16]> = Vec::new();
+
+    let mut rng = rand::thread_rng();
+    for _ in 0..2 * r {
+        let mut k: [u8; 16] = [0; 16];
+        rng.fill_bytes(&mut k);
+        keys.push(k);
+    }
+    keys
+}

+ 71 - 1
src/server.rs

@@ -2,20 +2,34 @@ use std::os::raw::c_uchar;
 use std::sync::mpsc::*;
 use std::thread::*;
 
+use rayon::prelude::*;
+
 use spiral_rs::client::PublicParameters;
+use spiral_rs::client::Query;
 use spiral_rs::params::Params;
 
+use crate::ot::*;
 use crate::params;
+use crate::to_vecdata;
+use crate::PreProcSingleMsg;
+use crate::PreProcSingleRespMsg;
+use crate::VecData;
 
 enum Command {
     PubParams(Vec<u8>),
-    PreProcMsg(Vec<u8>),
+    PreProcMsg(Vec<PreProcSingleMsg>),
 }
 
 enum Response {
     PreProcResp(Vec<u8>),
 }
 
+// The internal client state for a single preprocess query
+struct PreProcSingleState<'a> {
+    db_keys: Vec<[u8; 16]>,
+    query: Query<'a>,
+}
+
 pub struct Server {
     r: usize,
     thread_handle: JoinHandle<()>,
@@ -30,9 +44,36 @@ impl Server {
         let thread_handle = spawn(move || {
             let spiral_params = params::get_spiral_params(r);
             let pub_params = PublicParameters::deserialize(&spiral_params, &pub_params);
+
+            // State for preprocessing queries
+            let mut preproc_state: Vec<PreProcSingleState> = Vec::new();
+
+            // Wait for commands
             loop {
                 match incoming_cmd_recv.recv() {
                     Err(_) => break,
+                    Ok(Command::PreProcMsg(cliquery)) => {
+                        let num_preproc = cliquery.len();
+                        let mut resp_state: Vec<PreProcSingleState> =
+                            Vec::with_capacity(num_preproc);
+                        let mut resp_msg: Vec<PreProcSingleRespMsg> =
+                            Vec::with_capacity(num_preproc);
+                        cliquery
+                            .into_par_iter()
+                            .map(|q| {
+                                let db_keys = gen_db_enc_keys(r);
+                                let query = Query::deserialize(&spiral_params, &q.spc_query);
+                                let ot_resp = otkey_serve(q.ot_query, &db_keys);
+                                (
+                                    PreProcSingleState { db_keys, query },
+                                    PreProcSingleRespMsg { ot_resp },
+                                )
+                            })
+                            .unzip_into_vecs(&mut resp_state, &mut resp_msg);
+                        preproc_state.append(&mut resp_state);
+                        let ret: Vec<u8> = bincode::serialize(&resp_msg).unwrap();
+                        outgoing_resp_send.send(Response::PreProcResp(ret)).unwrap();
+                    }
                     _ => panic!("Received something unexpected"),
                 }
             }
@@ -44,6 +85,17 @@ impl Server {
             outgoing_resp,
         }
     }
+
+    pub fn preproc_PIRs(&self, msg: &[u8]) -> Vec<u8> {
+        self.incoming_cmd
+            .send(Command::PreProcMsg(bincode::deserialize(msg).unwrap()))
+            .unwrap();
+        let ret = match self.outgoing_resp.recv() {
+            Ok(Response::PreProcResp(x)) => x,
+            _ => panic!("Received something unexpected in preproc_PIRs"),
+        };
+        ret
+    }
 }
 
 #[no_mangle]
@@ -71,3 +123,21 @@ pub extern "C" fn spir_server_free(server: *mut Server) {
         Box::from_raw(server);
     }
 }
+
+#[no_mangle]
+pub extern "C" fn spir_server_preproc_PIRs(
+    serverptr: *mut Server,
+    msgdata: *const c_uchar,
+    msglen: usize,
+) -> VecData {
+    let server = unsafe {
+        assert!(!serverptr.is_null());
+        &mut *serverptr
+    };
+    let msg_slice = unsafe {
+        assert!(!msgdata.is_null());
+        std::slice::from_raw_parts(msgdata, msglen)
+    };
+    let retvec = server.preproc_PIRs(&msg_slice);
+    to_vecdata(retvec)
+}