Browse Source

client query API

Ian Goldberg 1 year ago
parent
commit
99ac4e28e7
6 changed files with 92 additions and 7 deletions
  1. 8 0
      cxx/spir.cpp
  2. 2 0
      cxx/spir_ffi.h
  3. 24 1
      cxx/spir_test.cpp
  4. 36 0
      src/client.rs
  5. 18 6
      src/main.rs
  6. 4 0
      src/ot.rs

+ 8 - 0
cxx/spir.cpp

@@ -34,6 +34,14 @@ void SPIR_Client::preproc_finish(const string &server_preproc)
         server_preproc.length());
 }
 
+string SPIR_Client::query(size_t idx)
+{
+    VecData msg = spir_client_query(this->client, idx);
+    string ret(msg.data, msg.len);
+    spir_vecdata_free(msg);
+    return ret;
+}
+
 SPIR_Server::SPIR_Server(uint8_t r, const string &pub_params)
 {
     this->server = spir_server_new(r, pub_params.data(),

+ 2 - 0
cxx/spir_ffi.h

@@ -29,6 +29,8 @@ extern VecData spir_client_preproc(void *client, uint32_t num_preproc);
 extern void spir_client_preproc_finish(void *client,
     const char *msgdata, size_t msglen);
 
+extern VecData spir_client_query(void *client, size_t idx);
+
 extern void* spir_server_new(uint8_t r, const char *pub_params,
     size_t pub_params_len);
 

+ 24 - 1
cxx/spir_test.cpp

@@ -1,5 +1,6 @@
 #include <iostream>
 #include <stdlib.h>
+#include <sys/random.h>
 #include <sys/time.h>
 #include <unistd.h>
 #include "spir.hpp"
@@ -23,6 +24,8 @@ int main(int argc, char **argv)
     }
     uint32_t r, num_threads = 1, num_preproc = 1, num_pirs = 1;
     r = strtoul(argv[1], NULL, 10);
+    size_t num_records = ((size_t) 1)<<r;
+    size_t num_records_mask = num_records - 1;
     if (argc > 2) {
         num_threads = strtoul(argv[2], NULL, 10);
     }
@@ -67,7 +70,7 @@ int main(int argc, char **argv)
     string preproc_resp = server.preproc_process(preproc_msg);
     size_t preproc_server_us = elapsed_us(&preproc_server_start);
     cout << "Preprocessing server: " << preproc_server_us << " µs\n";
-    cout << "preproc_response len = " << preproc_resp.length() << "\n";
+    cout << "preproc_resp len = " << preproc_resp.length() << "\n";
 
     struct timeval preproc_finish_start;
     gettimeofday(&preproc_finish_start, NULL);
@@ -76,5 +79,25 @@ int main(int argc, char **argv)
     size_t preproc_finish_us = elapsed_us(&preproc_finish_start);
     cout << "Preprocessing client finish: " << preproc_finish_us << " µs\n";
 
+    for (size_t i=0; i<num_pirs; ++i) {
+        cout << "\n===== SPIR QUERY " << i+1 << " =====\n\n";
+
+        size_t idx;
+        if (getrandom(&idx, sizeof(idx), 0) != sizeof(idx)) {
+            cerr << "Failure in getrandom\n";
+            exit(1);
+        }
+        idx &= num_records_mask;
+
+        struct timeval query_client_start;
+        gettimeofday(&query_client_start, NULL);
+
+        string query_msg = client.query(idx);
+        size_t query_client_us = elapsed_us(&query_client_start);
+        cout << "Preprocessing client: " << query_client_us << " µs\n";
+        cout << "query_msg len = " << query_msg.length() << "\n";
+
+    }
+
     return 0;
 }

+ 36 - 0
src/client.rs

@@ -27,12 +27,14 @@ use crate::VecData;
 enum Command {
     PreProc(usize),
     PreProcFinish(Vec<PreProcSingleRespMsg>),
+    Query(usize),
 }
 
 enum Response {
     PubParams(Vec<u8>),
     PreProcMsg(Vec<u8>),
     PreProcDone,
+    QueryMsg(Vec<u8>),
 }
 
 // The internal client state for a single outstanding preprocess query
@@ -79,6 +81,9 @@ impl Client {
             // State for preprocessing queries ready to be used
             let mut preproc_state: VecDeque<PreProcSingleState> = VecDeque::new();
 
+            // State for outstanding active queries
+            let mut query_state: VecDeque<PreProcSingleState> = VecDeque::new();
+
             // Wait for commands
             loop {
                 match incoming_cmd_recv.recv() {
@@ -120,6 +125,18 @@ impl Client {
                         preproc_out_state = Vec::new();
                         outgoing_resp_send.send(Response::PreProcDone).unwrap();
                     }
+                    Ok(Command::Query(idx)) => {
+                        // panic if there are no preproc states
+                        // available
+                        let nextstate = preproc_state.pop_front().unwrap();
+                        let offset = (num_records + nextstate.rand_idx - idx) & num_records_mask;
+                        let mut querymsg: Vec<u8> = Vec::new();
+                        querymsg.extend(idx.to_le_bytes());
+                        query_state.push_back(nextstate);
+                        outgoing_resp_send
+                            .send(Response::QueryMsg(querymsg))
+                            .unwrap();
+                    }
                     _ => panic!("Received something unexpected in client loop"),
                 }
             }
@@ -160,6 +177,15 @@ impl Client {
             _ => panic!("Received something unexpected in preproc_finish"),
         }
     }
+
+    pub fn query(&self, idx: usize) -> Vec<u8> {
+        self.incoming_cmd.send(Command::Query(idx)).unwrap();
+        let ret = match self.outgoing_resp.recv() {
+            Ok(Response::QueryMsg(x)) => x,
+            _ => panic!("Received something unexpected in preproc"),
+        };
+        ret
+    }
 }
 
 #[repr(C)]
@@ -213,3 +239,13 @@ pub extern "C" fn spir_client_preproc_finish(
     };
     client.preproc_finish(&msg_slice);
 }
+
+#[no_mangle]
+pub extern "C" fn spir_client_query(clientptr: *mut Client, idx: usize) -> VecData {
+    let client = unsafe {
+        assert!(!clientptr.is_null());
+        &mut *clientptr
+    };
+    let retvec = client.query(idx);
+    to_vecdata(retvec)
+}

+ 18 - 6
src/main.rs

@@ -1,10 +1,8 @@
-// We really want points to be capital letters and scalars to be
-// lowercase letters
-#![allow(non_snake_case)]
-
 use std::env;
 use std::time::Instant;
 
+use rand::RngCore;
+
 use spiral_spir::client::Client;
 use spiral_spir::init;
 use spiral_spir::server::Server;
@@ -21,7 +19,7 @@ fn main() {
     let r: usize = args[1].parse().unwrap();
     let mut num_threads = 1usize;
     let mut num_preproc = 1usize;
-    let mut num_pirs = 1usize;
+    let num_pirs: usize;
     if args.len() > 2 {
         num_threads = args[2].parse().unwrap();
     }
@@ -34,6 +32,7 @@ fn main() {
         num_pirs = num_preproc;
     }
     let num_records = 1 << r;
+    let num_records_mask = num_records - 1;
 
     println!("===== ONE-TIME SETUP =====\n");
 
@@ -64,7 +63,7 @@ fn main() {
     let preproc_server_us = preproc_server_start.elapsed().as_micros();
 
     println!("Preprocessing server: {} µs", preproc_server_us);
-    println!("preproc_response len = {}", preproc_resp.len());
+    println!("preproc_resp len = {}", preproc_resp.len());
 
     let preproc_finish_start = Instant::now();
     client.preproc_finish(&preproc_resp);
@@ -72,6 +71,19 @@ fn main() {
 
     println!("Preprocessing client finish: {} µs", preproc_finish_us);
 
+    let mut rng = rand::thread_rng();
+    for i in 1..num_pirs + 1 {
+        println!("\n===== SPIR QUERY {} =====\n", i);
+
+        let idx = (rng.next_u64() as usize) & num_records_mask;
+        let query_client_start = Instant::now();
+        let query_msg = client.query(idx);
+        let query_client_us = query_client_start.elapsed().as_micros();
+
+        println!("Query client: {} µs", query_client_us);
+        println!("query_msg len = {}", query_msg.len());
+    }
+
     /*
         let spiral_params = params::get_spiral_params(r);
         let mut rng = rand::thread_rng();

+ 4 - 0
src/ot.rs

@@ -1,3 +1,7 @@
+// We really want points to be capital letters and scalars to be
+// lowercase letters
+#![allow(non_snake_case)]
+
 // Oblivious transfer
 
 use subtle::Choice;