Browse Source

spir_client_free, spir_server_new, spir_server_free APIs

Ian Goldberg 1 year ago
parent
commit
801a3203b6
7 changed files with 163 additions and 21 deletions
  1. 17 0
      cxx/spir.cpp
  2. 16 2
      cxx/spir.hpp
  3. 7 0
      cxx/spir_ffi.h
  4. 13 8
      cxx/spir_test.cpp
  5. 21 5
      src/client.rs
  6. 14 5
      src/main.rs
  7. 75 1
      src/server.rs

+ 17 - 0
cxx/spir.cpp

@@ -12,4 +12,21 @@ SPIR_Client::SPIR_Client(uint8_t r, string &pub_params)
     ClientNewRet ret = spir_client_new(r);
     pub_params.assign(ret.pub_params.data, ret.pub_params.len);
     spir_vecdata_free(ret.pub_params);
+    this->client = ret.client;
+}
+
+SPIR_Client::~SPIR_Client()
+{
+    spir_client_free(this->client);
+}
+
+SPIR_Server::SPIR_Server(uint8_t r, const string &pub_params)
+{
+    this->server = spir_server_new(r, pub_params.data(),
+        pub_params.length());
+}
+
+SPIR_Server::~SPIR_Server()
+{
+    spir_server_free(this->server);
 }

+ 16 - 2
cxx/spir.hpp

@@ -15,8 +15,9 @@ public:
 
 class SPIR_Client {
 public:
-    // constructor
+    // constructor and destructor
     SPIR_Client(uint8_t r, string &pub_params); // 2^r records in the database; pub_params will be _filled in_
+    ~SPIR_Client();
 
     // preprocessing
     string preproc_PIRs(uint32_t num_pirs); // returns the string to send to the server
@@ -30,12 +31,19 @@ public:
     // where N=2^r, idx is provided by the client above, and
     // db, rot, and blind are provided by the server below
     SPIR::DBEntry process_reply(const string &server_reply);
+
+private:
+    void *client;
+    SPIR_Client() = default;
+    SPIR_Client(const SPIR_Client&) = delete;
+    SPIR_Client& operator=(const SPIR_Client&) = delete;
 };
 
 class SPIR_Server {
 public:
-    // constructor
+    // constructor and destructor
     SPIR_Server(uint8_t r, const string &client_pub_params);
+    ~SPIR_Server();
 
     // preprocessing
     string preproc_PIR(const string &client_preproc); // returns the string to reply to the client
@@ -45,6 +53,12 @@ public:
     // returns the string to reply to the client
     string process_query(const string &client_query, const SPIR::DBEntry *db,
         size_t rot, SPIR::DBEntry blind);
+
+private:
+    void *server;
+    SPIR_Server() = default;
+    SPIR_Server(const SPIR_Server&) = delete;
+    SPIR_Server& operator=(const SPIR_Server&) = delete;
 };
 
 #endif

+ 7 - 0
cxx/spir_ffi.h

@@ -22,6 +22,13 @@ extern void spir_init(uint32_t num_threads);
 
 extern ClientNewRet spir_client_new(uint8_t r);
 
+extern void spir_client_free(void *client);
+
+extern void* spir_server_new(uint8_t r, const char *pub_params,
+    size_t pub_params_len);
+
+extern void spir_server_free(void *server);
+
 extern void spir_vecdata_free(VecData vecdata);
 
 #ifdef __cplusplus

+ 13 - 8
cxx/spir_test.cpp

@@ -1,6 +1,7 @@
 #include <iostream>
 #include <stdlib.h>
 #include <sys/time.h>
+#include <unistd.h>
 #include "spir.hpp"
 
 using std::cout;
@@ -15,18 +16,23 @@ static inline size_t elapsed_us(const struct timeval *start)
 
 int main(int argc, char **argv)
 {
-    if (argc < 2 || argc > 4) {
-        cerr << "Usage: " << argv[0] << " r [num_threads [num_pirs]]\n";
+    if (argc < 2 || argc > 5) {
+        cerr << "Usage: " << argv[0] << " r [num_threads [num_preproc [num_pirs]]]\n";
         cerr << "r = log_2(num_records)\n";
         exit(1);
     }
-    uint32_t r, num_threads = 1, num_pirs = 1;
+    uint32_t r, num_threads = 1, num_preproc = 1, num_pirs = 1;
     r = strtoul(argv[1], NULL, 10);
     if (argc > 2) {
         num_threads = strtoul(argv[2], NULL, 10);
     }
     if (argc > 3) {
-        num_pirs = strtoul(argv[3], NULL, 10);
+        num_preproc = strtoul(argv[3], NULL, 10);
+    }
+    if (argc > 4) {
+        num_pirs = strtoul(argv[4], NULL, 10);
+    } else {
+        num_pirs = num_preproc;
     }
 
     cout << "===== ONE-TIME SETUP =====\n\n";
@@ -37,12 +43,11 @@ int main(int argc, char **argv)
     SPIR::init(num_threads);
     string pub_params;
     SPIR_Client client(r, pub_params);
-    printf("%u %u %u %u\n", (unsigned char)pub_params[0], (unsigned
-    char)pub_params[1], (unsigned char)pub_params[2],
-    (unsigned char)pub_params[3]);
+    cout << "pub_params len = " << pub_params.length() << "\n";
+    SPIR_Server server(r, pub_params);
 
     size_t otsetup_us = elapsed_us(&otsetup_start);
-    cout << "OT one-time setup: " << otsetup_us << " µs\n";
+    cout << "One-time setup: " << otsetup_us << " µs\n";
 
     return 0;
 }

+ 21 - 5
src/client.rs

@@ -9,7 +9,7 @@ use crate::VecData;
 
 enum Command {
     PreProc(usize),
-    PreProcHandle(Vec<u8>),
+    PreProcResp(Vec<u8>),
 }
 
 enum Response {
@@ -35,13 +35,19 @@ impl Client {
 
             // The first communication is the pub_params
             let pub_params = spiral_client.generate_keys().serialize();
-            println!(
-                "{} {} {} {}",
-                pub_params[0], pub_params[1], pub_params[2], pub_params[3]
-            );
             outgoing_resp_send
                 .send(Response::PubParams(pub_params))
                 .unwrap();
+
+            // Wait for commands
+            loop {
+                println!("Client waiting");
+                match incoming_cmd_recv.recv() {
+                    Err(_) => break,
+                    _ => panic!("Received something unexpected"),
+                }
+            }
+            println!("Client ending");
         });
         let pub_params = match outgoing_resp.recv() {
             Ok(Response::PubParams(x)) => x,
@@ -80,3 +86,13 @@ pub extern "C" fn spir_client_new(r: u8) -> ClientNewRet {
         pub_params: vecdata,
     }
 }
+
+#[no_mangle]
+pub extern "C" fn spir_client_free(client: *mut Client) {
+    if client.is_null() {
+        return;
+    }
+    unsafe {
+        Box::from_raw(client);
+    }
+}

+ 14 - 5
src/main.rs

@@ -6,37 +6,46 @@ use std::env;
 use std::time::Instant;
 
 use spiral_spir::client::Client;
-use spiral_spir::*;
+use spiral_spir::init;
+use spiral_spir::server::Server;
 
 fn main() {
     let args: Vec<String> = env::args().collect();
-    if args.len() < 2 || args.len() > 4 {
+    if args.len() < 2 || args.len() > 5 {
         println!(
-            "Usage: {} r [num_threads [num_pirs]]\nr = log_2(num_records)",
+            "Usage: {} r [num_threads [num_preproc [num_pirs]]]\nr = log_2(num_records)",
             args[0]
         );
         return;
     }
     let r: usize = args[1].parse().unwrap();
     let mut num_threads = 1usize;
+    let mut num_preproc = 1usize;
     let mut num_pirs = 1usize;
     if args.len() > 2 {
         num_threads = args[2].parse().unwrap();
     }
     if args.len() > 3 {
-        num_pirs = args[3].parse().unwrap();
+        num_preproc = args[3].parse().unwrap();
+    }
+    if args.len() > 4 {
+        num_pirs = args[4].parse().unwrap();
+    } else {
+        num_pirs = num_preproc;
     }
     let num_records = 1 << r;
 
     println!("===== ONE-TIME SETUP =====\n");
 
     let otsetup_start = Instant::now();
+
     init(num_threads);
     let (client, pub_params) = Client::new(r);
     println!("pub_params len = {}", pub_params.len());
+    let server = Server::new(r, pub_params);
 
     let otsetup_us = otsetup_start.elapsed().as_micros();
-    println!("OT one-time setup: {} µs", otsetup_us);
+    println!("One-time setup: {} µs", otsetup_us);
 
     /*
         let spiral_params = params::get_spiral_params(r);

+ 75 - 1
src/server.rs

@@ -1 +1,75 @@
-pub struct Server {}
+use std::os::raw::c_uchar;
+use std::sync::mpsc::*;
+use std::thread::*;
+
+use spiral_rs::client::PublicParameters;
+use spiral_rs::params::Params;
+
+use crate::params;
+
+enum Command {
+    PubParams(Vec<u8>),
+    PreProcMsg(Vec<u8>),
+}
+
+enum Response {
+    PreProcResp(Vec<u8>),
+}
+
+pub struct Server {
+    r: usize,
+    thread_handle: JoinHandle<()>,
+    incoming_cmd: SyncSender<Command>,
+    outgoing_resp: Receiver<Response>,
+}
+
+impl Server {
+    pub fn new(r: usize, pub_params: Vec<u8>) -> Self {
+        let (incoming_cmd, incoming_cmd_recv) = sync_channel(0);
+        let (outgoing_resp_send, outgoing_resp) = sync_channel(0);
+        let thread_handle = spawn(move || {
+            let spiral_params = params::get_spiral_params(r);
+            let pub_params = PublicParameters::deserialize(&spiral_params, &pub_params);
+            loop {
+                println!("Waiting");
+                match incoming_cmd_recv.recv() {
+                    Err(_) => break,
+                    _ => panic!("Received something unexpected"),
+                }
+            }
+            println!("Ending");
+        });
+        Server {
+            r,
+            thread_handle,
+            incoming_cmd,
+            outgoing_resp,
+        }
+    }
+}
+
+#[no_mangle]
+pub extern "C" fn spir_server_new(
+    r: u8,
+    pub_params: *const c_uchar,
+    pub_params_len: usize,
+) -> *mut Server {
+    let pub_params_slice = unsafe {
+        assert!(!pub_params.is_null());
+        std::slice::from_raw_parts(pub_params, pub_params_len)
+    };
+    let mut pub_params_vec: Vec<u8> = Vec::new();
+    pub_params_vec.extend_from_slice(pub_params_slice);
+    let server = Server::new(r as usize, pub_params_vec);
+    Box::into_raw(Box::new(server))
+}
+
+#[no_mangle]
+pub extern "C" fn spir_server_free(server: *mut Server) {
+    if server.is_null() {
+        return;
+    }
+    unsafe {
+        Box::from_raw(server);
+    }
+}