Browse Source

client preproc_PIRs API

Ian Goldberg 1 year ago
parent
commit
eba0f97b3b
8 changed files with 120 additions and 13 deletions
  1. 2 0
      Cargo.toml
  2. 8 0
      cxx/spir.cpp
  3. 2 0
      cxx/spir_ffi.h
  4. 11 1
      cxx/spir_test.cpp
  5. 66 9
      src/client.rs
  6. 20 1
      src/lib.rs
  7. 11 1
      src/main.rs
  8. 0 1
      src/ot.rs

+ 2 - 0
Cargo.toml

@@ -18,6 +18,8 @@ sha2 = "0.9"
 subtle = { package = "subtle-ng", version = "2.4" }
 spiral-rs = { git = "https://github.com/menonsamir/spiral-rs/", rev = "0f9bdc157" }
 rayon = "1.5"
+bincode = "1"
+serde = "1"
 
 [lib]
 crate_type = ["lib", "staticlib"]

+ 8 - 0
cxx/spir.cpp

@@ -20,6 +20,14 @@ SPIR_Client::~SPIR_Client()
     spir_client_free(this->client);
 }
 
+string SPIR_Client::preproc_PIRs(uint32_t num_preproc)
+{
+    VecData msg = spir_client_preproc_PIRs(this->client, num_preproc);
+    string ret(msg.data, msg.len);
+    spir_vecdata_free(msg);
+    return ret;
+}
+
 SPIR_Server::SPIR_Server(uint8_t r, const string &pub_params)
 {
     this->server = spir_server_new(r, pub_params.data(),

+ 2 - 0
cxx/spir_ffi.h

@@ -24,6 +24,8 @@ extern ClientNewRet spir_client_new(uint8_t r);
 
 extern void spir_client_free(void *client);
 
+extern VecData spir_client_preproc_PIRs(void *client, uint32_t num_preproc);
+
 extern void* spir_server_new(uint8_t r, const char *pub_params,
     size_t pub_params_len);
 

+ 11 - 1
cxx/spir_test.cpp

@@ -43,11 +43,21 @@ int main(int argc, char **argv)
     SPIR::init(num_threads);
     string pub_params;
     SPIR_Client client(r, pub_params);
-    cout << "pub_params len = " << pub_params.length() << "\n";
     SPIR_Server server(r, pub_params);
 
     size_t otsetup_us = elapsed_us(&otsetup_start);
     cout << "One-time setup: " << otsetup_us << " µs\n";
+    cout << "pub_params len = " << pub_params.length() << "\n";
+
+    cout << "\n===== PREPROCESSING =====\n\n";
+
+    struct timeval preproc_client_start;
+    gettimeofday(&preproc_client_start, NULL);
+
+    string preproc_msg = client.preproc_PIRs(num_preproc);
+    size_t preproc_client_us = elapsed_us(&preproc_client_start);
+    cout << "Preprocessing client: " << preproc_client_us << " µs\n";
+    cout << "preproc_msg len = " << preproc_msg.length() << "\n";
 
     return 0;
 }

+ 66 - 9
src/client.rs

@@ -1,10 +1,20 @@
 use rand::rngs::ThreadRng;
 use rand::RngCore;
 
+use std::mem;
 use std::sync::mpsc::*;
 use std::thread::*;
 
+use subtle::Choice;
+
+use curve25519_dalek::ristretto::RistrettoPoint;
+use curve25519_dalek::scalar::Scalar;
+
+use crate::ot::*;
 use crate::params;
+use crate::to_vecdata;
+use crate::DbEntry;
+use crate::PreProcSingleMsg;
 use crate::VecData;
 
 enum Command {
@@ -17,6 +27,12 @@ enum Response {
     PreProcMsg(Vec<u8>),
 }
 
+// The internal client state for a single preprocess query
+struct PreProcSingleState {
+    rand_idx: usize,
+    ot_state: Vec<(Choice, Scalar)>,
+}
+
 pub struct Client {
     r: usize,
     thread_handle: JoinHandle<()>,
@@ -31,7 +47,11 @@ impl Client {
         let thread_handle = spawn(move || {
             let spiral_params = params::get_spiral_params(r);
             let mut clientrng = rand::thread_rng();
+            let mut rng = rand::thread_rng();
             let mut spiral_client = spiral_rs::client::Client::init(&spiral_params, &mut clientrng);
+            let num_records = 1 << r;
+            let num_records_mask = num_records - 1;
+            let spiral_blocking_factor = spiral_params.db_item_size / mem::size_of::<DbEntry>();
 
             // The first communication is the pub_params
             let pub_params = spiral_client.generate_keys().serialize();
@@ -39,17 +59,39 @@ impl Client {
                 .send(Response::PubParams(pub_params))
                 .unwrap();
 
+            // State for preprocessing queries
+            let mut preproc_state: Vec<PreProcSingleState> = Vec::new();
+
             // Wait for commands
             loop {
                 match incoming_cmd_recv.recv() {
                     Err(_) => break,
-                    _ => panic!("Received something unexpected"),
+                    Ok(Command::PreProc(num_preproc)) => {
+                        // Ensure we don't already have outstanding
+                        // preprocessing state
+                        assert!(preproc_state.len() == 0);
+                        let mut preproc_msg: Vec<PreProcSingleMsg> = Vec::new();
+                        for _ in 0..num_preproc {
+                            let rand_idx = (rng.next_u64() as usize) & num_records_mask;
+                            let rand_pir_idx = rand_idx / spiral_blocking_factor;
+                            let spc_query = spiral_client.generate_query(rand_pir_idx).serialize();
+                            let (ot_state, ot_query) = otkey_request(rand_idx, r);
+                            preproc_state.push(PreProcSingleState { rand_idx, ot_state });
+                            preproc_msg.push(PreProcSingleMsg {
+                                ot_query,
+                                spc_query,
+                            });
+                        }
+                        let ret: Vec<u8> = bincode::serialize(&preproc_msg).unwrap();
+                        outgoing_resp_send.send(Response::PreProcMsg(ret)).unwrap();
+                    }
+                    _ => panic!("Received something unexpected in client loop"),
                 }
             }
         });
         let pub_params = match outgoing_resp.recv() {
             Ok(Response::PubParams(x)) => x,
-            _ => panic!("Received something unexpected"),
+            _ => panic!("Received something unexpected in client new"),
         };
 
         (
@@ -62,6 +104,17 @@ impl Client {
             pub_params,
         )
     }
+
+    pub fn preproc_PIRs(&self, num_preproc: usize) -> Vec<u8> {
+        self.incoming_cmd
+            .send(Command::PreProc(num_preproc))
+            .unwrap();
+        let ret = match self.outgoing_resp.recv() {
+            Ok(Response::PreProcMsg(x)) => x,
+            _ => panic!("Received something unexpected in preproc_PIRs"),
+        };
+        ret
+    }
 }
 
 #[repr(C)]
@@ -73,15 +126,9 @@ pub struct ClientNewRet {
 #[no_mangle]
 pub extern "C" fn spir_client_new(r: u8) -> ClientNewRet {
     let (client, pub_params) = Client::new(r as usize);
-    let vecdata = VecData {
-        data: pub_params.as_ptr(),
-        len: pub_params.len(),
-        cap: pub_params.capacity(),
-    };
-    std::mem::forget(pub_params);
     ClientNewRet {
         client: Box::into_raw(Box::new(client)),
-        pub_params: vecdata,
+        pub_params: to_vecdata(pub_params),
     }
 }
 
@@ -94,3 +141,13 @@ pub extern "C" fn spir_client_free(client: *mut Client) {
         Box::from_raw(client);
     }
 }
+
+#[no_mangle]
+pub extern "C" fn spir_client_preproc_PIRs(clientptr: *mut Client, num_preproc: u32) -> VecData {
+    let client = unsafe {
+        assert!(!clientptr.is_null());
+        &mut *clientptr
+    };
+    let retvec = client.preproc_PIRs(num_preproc as usize);
+    to_vecdata(retvec)
+}

+ 20 - 1
src/lib.rs

@@ -18,6 +18,8 @@ use std::time::Instant;
 
 use rand::RngCore;
 
+use serde::{Deserialize, Serialize};
+
 use std::os::raw::c_uchar;
 
 use rayon::scope;
@@ -27,8 +29,8 @@ use spiral_rs::client::*;
 use spiral_rs::params::*;
 use spiral_rs::server::*;
 
-use crate::spiral_mt::*;
 use crate::ot::{otkey_init, xor16};
+use crate::spiral_mt::*;
 
 pub type DbEntry = u64;
 
@@ -139,6 +141,13 @@ pub fn print_params_summary(params: &Params) {
     );
 }
 
+// The message format for a single preprocess query
+#[derive(Serialize, Deserialize)]
+struct PreProcSingleMsg {
+    ot_query: Vec<[u8; 32]>,
+    spc_query: Vec<u8>,
+}
+
 #[no_mangle]
 pub extern "C" fn spir_init(num_threads: u32) {
     init(num_threads as usize);
@@ -158,6 +167,16 @@ pub struct VecMutData {
     cap: usize,
 }
 
+pub fn to_vecdata(v: Vec<u8>) -> VecData {
+    let vecdata = VecData {
+        data: v.as_ptr(),
+        len: v.len(),
+        cap: v.capacity(),
+    };
+    std::mem::forget(v);
+    vecdata
+}
+
 #[no_mangle]
 pub extern "C" fn spir_vecdata_free(vecdata: VecMutData) {
     unsafe { Vec::from_raw_parts(vecdata.data, vecdata.len, vecdata.cap) };

+ 11 - 1
src/main.rs

@@ -41,11 +41,21 @@ fn main() {
 
     init(num_threads);
     let (client, pub_params) = Client::new(r);
-    println!("pub_params len = {}", pub_params.len());
+    let pub_params_len = pub_params.len();
     let server = Server::new(r, pub_params);
 
     let otsetup_us = otsetup_start.elapsed().as_micros();
     println!("One-time setup: {} µs", otsetup_us);
+    println!("pub_params len = {}", pub_params_len);
+
+    println!("\n===== PREPROCESSING =====\n");
+
+    let preproc_client_start = Instant::now();
+    let preproc_msg = client.preproc_PIRs(num_preproc);
+    let preproc_client_us = preproc_client_start.elapsed().as_micros();
+
+    println!("Preprocessing client: {} µs", preproc_client_us);
+    println!("preproc_msg len = {}", preproc_msg.len());
 
     /*
         let spiral_params = params::get_spiral_params(r);

+ 0 - 1
src/ot.rs

@@ -129,4 +129,3 @@ pub fn otkey_receive(state: Vec<(Choice, Scalar)>, response: &Vec<[u8; 64]>) ->
     }
     key
 }
-