Browse Source

spir_client_new API

Ian Goldberg 1 year ago
parent
commit
689d911823
8 changed files with 262 additions and 149 deletions
  1. 2 0
      cxx/Makefile
  2. 8 0
      cxx/spir.cpp
  3. 15 0
      cxx/spir_ffi.h
  4. 26 2
      cxx/spir_test.cpp
  5. 60 21
      src/client.rs
  6. 27 3
      src/lib.rs
  7. 123 121
      src/main.rs
  8. 1 2
      src/server.rs

+ 2 - 0
cxx/Makefile

@@ -1,3 +1,5 @@
+CXXFLAGS = -O3 -Wall
+
 spir_test: spir_test.o libspir_cxx.a
 	g++ -o $@ $^ -lpthread -ldl
 

+ 8 - 0
cxx/spir.cpp

@@ -1,3 +1,4 @@
+#include <stdio.h>
 #include "spir.hpp"
 #include "spir_ffi.h"
 
@@ -5,3 +6,10 @@ void SPIR::init(uint32_t num_threads)
 {
     spir_init(num_threads);
 }
+
+SPIR_Client::SPIR_Client(uint8_t r, string &pub_params)
+{
+    ClientNewRet ret = spir_client_new(r);
+    pub_params.assign(ret.pub_params.data, ret.pub_params.len);
+    spir_vecdata_free(ret.pub_params);
+}

+ 15 - 0
cxx/spir_ffi.h

@@ -7,8 +7,23 @@
 extern "C" {
 #endif
 
+typedef struct {
+    const char *data;
+    size_t len;
+    size_t capacity;
+} VecData;
+
+typedef struct {
+    void *client;
+    VecData pub_params;
+} ClientNewRet;
+
 extern void spir_init(uint32_t num_threads);
 
+extern ClientNewRet spir_client_new(uint8_t r);
+
+extern void spir_vecdata_free(VecData vecdata);
+
 #ifdef __cplusplus
 }
 #endif

+ 26 - 2
cxx/spir_test.cpp

@@ -1,12 +1,23 @@
 #include <iostream>
 #include <stdlib.h>
+#include <sys/time.h>
 #include "spir.hpp"
 
+using std::cout;
+using std::cerr;
+
+static inline size_t elapsed_us(const struct timeval *start)
+{
+    struct timeval end;
+    gettimeofday(&end, NULL);
+    return (end.tv_sec-start->tv_sec)*1000000 + end.tv_usec - start->tv_usec;
+}
+
 int main(int argc, char **argv)
 {
     if (argc < 2 || argc > 4) {
-        std::cerr << "Usage: " << argv[0] << " r [num_threads [num_pirs]]\n";
-        std::cerr << "r = log_2(num_records)\n";
+        cerr << "Usage: " << argv[0] << " r [num_threads [num_pirs]]\n";
+        cerr << "r = log_2(num_records)\n";
         exit(1);
     }
     uint32_t r, num_threads = 1, num_pirs = 1;
@@ -18,7 +29,20 @@ int main(int argc, char **argv)
         num_pirs = strtoul(argv[3], NULL, 10);
     }
 
+    cout << "===== ONE-TIME SETUP =====\n\n";
+
+    struct timeval otsetup_start;
+    gettimeofday(&otsetup_start, NULL);
+
     SPIR::init(num_threads);
+    string pub_params;
+    SPIR_Client client(r, pub_params);
+    printf("%u %u %u %u\n", (unsigned char)pub_params[0], (unsigned
+    char)pub_params[1], (unsigned char)pub_params[2],
+    (unsigned char)pub_params[3]);
+
+    size_t otsetup_us = elapsed_us(&otsetup_start);
+    cout << "OT one-time setup: " << otsetup_us << " µs\n";
 
     return 0;
 }

+ 60 - 21
src/client.rs

@@ -1,43 +1,82 @@
-use rand::RngCore;
 use rand::rngs::ThreadRng;
+use rand::RngCore;
 
-use std::thread::*;
 use std::sync::mpsc::*;
+use std::thread::*;
 
 use crate::params;
+use crate::VecData;
+
+enum Command {
+    PreProc(usize),
+    PreProcHandle(Vec<u8>),
+}
+
+enum Response {
+    PubParams(Vec<u8>),
+    PreProcMsg(Vec<u8>),
+}
 
 pub struct Client {
     r: usize,
     thread_handle: JoinHandle<()>,
-    incoming_query: SyncSender<usize>,
-    incoming_response: SyncSender<Vec<u8>>,
-    outgoing_data: Receiver<Vec<u8>>,
-    pub pub_params: Vec<u8>,
+    incoming_cmd: SyncSender<Command>,
+    outgoing_resp: Receiver<Response>,
 }
 
 impl Client {
-    pub fn new(r: usize) -> Self {
-        let (incoming_query, incoming_query_recv) = sync_channel(0);
-        let (incoming_response, incoming_response_recv) = sync_channel(0);
-        let (outgoing_data_send, outgoing_data) = sync_channel(0);
+    pub fn new(r: usize) -> (Self, Vec<u8>) {
+        let (incoming_cmd, incoming_cmd_recv) = sync_channel(0);
+        let (outgoing_resp_send, outgoing_resp) = sync_channel(0);
         let thread_handle = spawn(move || {
             let spiral_params = params::get_spiral_params(r);
             let mut clientrng = rand::thread_rng();
-            let mut spiral_client =
-                spiral_rs::client::Client::init(&spiral_params, &mut clientrng);
+            let mut spiral_client = spiral_rs::client::Client::init(&spiral_params, &mut clientrng);
 
             // The first communication is the pub_params
             let pub_params = spiral_client.generate_keys().serialize();
-            outgoing_data_send.send(pub_params).unwrap();
+            println!(
+                "{} {} {} {}",
+                pub_params[0], pub_params[1], pub_params[2], pub_params[3]
+            );
+            outgoing_resp_send
+                .send(Response::PubParams(pub_params))
+                .unwrap();
         });
-        let pub_params = outgoing_data.recv().unwrap();
-        Client {
-            r,
-            thread_handle,
-            incoming_query,
-            incoming_response,
-            outgoing_data,
+        let pub_params = match outgoing_resp.recv() {
+            Ok(Response::PubParams(x)) => x,
+            _ => panic!("Received something unexpected"),
+        };
+
+        (
+            Client {
+                r,
+                thread_handle,
+                incoming_cmd,
+                outgoing_resp,
+            },
             pub_params,
-        }
+        )
+    }
+}
+
+#[repr(C)]
+pub struct ClientNewRet {
+    client: *mut Client,
+    pub_params: VecData,
+}
+
+#[no_mangle]
+pub extern "C" fn spir_client_new(r: u8) -> ClientNewRet {
+    let (client, pub_params) = Client::new(r as usize);
+    let vecdata = VecData {
+        data: pub_params.as_ptr(),
+        len: pub_params.len(),
+        cap: pub_params.capacity(),
+    };
+    std::mem::forget(pub_params);
+    ClientNewRet {
+        client: Box::into_raw(Box::new(client)),
+        pub_params: vecdata,
     }
 }

+ 27 - 3
src/lib.rs

@@ -3,10 +3,10 @@
 #![allow(non_snake_case)]
 
 mod aligned_memory_mt;
-mod params;
-mod spiral_mt;
 pub mod client;
+mod params;
 pub mod server;
+mod spiral_mt;
 
 use aes::cipher::{BlockEncrypt, KeyInit};
 use aes::Aes128Enc;
@@ -23,6 +23,8 @@ use sha2::Digest;
 use sha2::Sha256;
 use sha2::Sha512;
 
+use std::os::raw::c_uchar;
+
 use curve25519_dalek::constants as dalek_constants;
 use curve25519_dalek::ristretto::CompressedRistretto;
 use curve25519_dalek::ristretto::RistrettoBasepointTable;
@@ -241,7 +243,10 @@ pub fn init(num_threads: usize) {
     let _Ctable: &RistrettoBasepointTable = &OT_C_TABLE;
 
     // Initialize the thread pool
-    ThreadPoolBuilder::new().num_threads(num_threads).build_global().unwrap();
+    ThreadPoolBuilder::new()
+        .num_threads(num_threads)
+        .build_global()
+        .unwrap();
 }
 
 pub fn print_params_summary(params: &Params) {
@@ -259,3 +264,22 @@ pub fn print_params_summary(params: &Params) {
 pub extern "C" fn spir_init(num_threads: u32) {
     init(num_threads as usize);
 }
+
+#[repr(C)]
+pub struct VecData {
+    data: *const c_uchar,
+    len: usize,
+    cap: usize,
+}
+
+#[repr(C)]
+pub struct VecMutData {
+    data: *mut c_uchar,
+    len: usize,
+    cap: usize,
+}
+
+#[no_mangle]
+pub extern "C" fn spir_vecdata_free(vecdata: VecMutData) {
+    unsafe { Vec::from_raw_parts(vecdata.data, vecdata.len, vecdata.cap) };
+}

+ 123 - 121
src/main.rs

@@ -5,15 +5,16 @@
 use std::env;
 use std::time::Instant;
 
-use spiral_rs::client::*;
-use spiral_rs::server::*;
-
+use spiral_spir::client::Client;
 use spiral_spir::*;
 
 fn main() {
     let args: Vec<String> = env::args().collect();
     if args.len() < 2 || args.len() > 4 {
-        println!("Usage: {} r [num_threads [num_pirs]]\nr = log_2(num_records)", args[0]);
+        println!(
+            "Usage: {} r [num_threads [num_pirs]]\nr = log_2(num_records)",
+            args[0]
+        );
         return;
     }
     let r: usize = args[1].parse().unwrap();
@@ -31,124 +32,125 @@ fn main() {
 
     let otsetup_start = Instant::now();
     init(num_threads);
+    let (client, pub_params) = Client::new(r);
+    println!("pub_params len = {}", pub_params.len());
+
     let otsetup_us = otsetup_start.elapsed().as_micros();
     println!("OT one-time setup: {} µs", otsetup_us);
-/*
-    let otsetup_start = Instant::now();
-    let spiral_params = params::get_spiral_params(r);
-    let mut rng = rand::thread_rng();
-    init(num_threads);
-    let otsetup_us = otsetup_start.elapsed().as_micros();
-    print_params_summary(&spiral_params);
-    println!("OT one-time setup: {} µs", otsetup_us);
 
-    // One-time setup for the Spiral client
-    let spc_otsetup_start = Instant::now();
-    let mut clientrng = rand::thread_rng();
-    let mut client = Client::init(&spiral_params, &mut clientrng);
-    let pub_params = client.generate_keys();
-    let pub_params_buf = pub_params.serialize();
-    let spc_otsetup_us = spc_otsetup_start.elapsed().as_micros();
-    let spiral_blocking_factor = spiral_params.db_item_size / mem::size_of::<DbEntry>();
-    println!(
-        "Spiral client one-time setup: {} µs, {} bytes",
-        spc_otsetup_us,
-        pub_params_buf.len()
-    );
-
-    println!("\n===== PREPROCESSING =====\n");
-
-    // Spiral preprocessing: create a PIR lookup for an element at a
-    // random location
-    let spc_query_start = Instant::now();
-    let rand_idx = (rng.next_u64() as usize) % num_records;
-    let rand_pir_idx = rand_idx / spiral_blocking_factor;
-    println!("rand_idx = {} rand_pir_idx = {}", rand_idx, rand_pir_idx);
-    let spc_query = client.generate_query(rand_pir_idx);
-    let spc_query_buf = spc_query.serialize();
-    let spc_query_us = spc_query_start.elapsed().as_micros();
-    println!(
-        "Spiral query: {} µs, {} bytes",
-        spc_query_us,
-        spc_query_buf.len()
-    );
-
-    // Create the database encryption keys and do the OT to fetch the
-    // right one, but don't actually encrypt the database yet
-    let dbkeys = gen_db_enc_keys(r);
-    let otkeyreq_start = Instant::now();
-    let (keystate, keyquery) = otkey_request(rand_idx, r);
-    let keyquerysize = keyquery.len() * keyquery[0].len();
-    let otkeyreq_us = otkeyreq_start.elapsed().as_micros();
-    let otkeysrv_start = Instant::now();
-    let keyresponse = otkey_serve(keyquery, &dbkeys);
-    let keyrespsize = keyresponse.len() * keyresponse[0].len();
-    let otkeysrv_us = otkeysrv_start.elapsed().as_micros();
-    let otkeyrcv_start = Instant::now();
-    let otkey = otkey_receive(keystate, &keyresponse);
-    let otkeyrcv_us = otkeyrcv_start.elapsed().as_micros();
-    println!("key OT query in {} µs, {} bytes", otkeyreq_us, keyquerysize);
-    println!("key OT serve in {} µs, {} bytes", otkeysrv_us, keyrespsize);
-    println!("key OT receive in {} µs", otkeyrcv_us);
-
-    // Create a database with recognizable contents
-    let db: Vec<DbEntry> = ((0 as DbEntry)..(num_records as DbEntry))
-        .map(|x| 10000001 * x)
-        .collect();
-
-    println!("\n===== RUNTIME =====\n");
-
-    // Pick the record we actually want to query
-    let q = (rng.next_u64() as usize) % num_records;
-
-    // Compute the offset from the record index we're actually looking
-    // for to the random one we picked earlier.  Tell it to the server,
-    // who will rotate right the database by that amount before
-    // encrypting it.
-    let idx_offset = (num_records + rand_idx - q) % num_records;
-
-    println!("Send to server {} bytes", 8 /* sizeof(idx_offset) */);
-
-    // The server rotates, blinds, and encrypts the database
-    let blind: DbEntry = 20;
-    let encdb_start = Instant::now();
-    let encdb = encdb_xor_keys(&db, &dbkeys, r, idx_offset, blind, num_threads);
-    let encdb_us = encdb_start.elapsed().as_micros();
-    println!("Server encrypt database {} µs", encdb_us);
-
-    // Load the encrypted database into Spiral
-    let sps_loaddb_start = Instant::now();
-    let sps_db = load_db_from_slice_mt(&spiral_params, &encdb, num_threads);
-    let sps_loaddb_us = sps_loaddb_start.elapsed().as_micros();
-    println!("Server load database {} µs", sps_loaddb_us);
-
-    // Do the PIR query
-    let sps_query_start = Instant::now();
-    let sps_query = Query::deserialize(&spiral_params, &spc_query_buf);
-    let sps_response = process_query(&spiral_params, &pub_params, &sps_query, sps_db.as_slice());
-    let sps_query_us = sps_query_start.elapsed().as_micros();
-    println!(
-        "Server compute response {} µs, {} bytes (*including* the above expansion time)",
-        sps_query_us,
-        sps_response.len()
-    );
-
-    // Decode the response to yield the whole Spiral block
-    let spc_recv_start = Instant::now();
-    let encdbblock = client.decode_response(sps_response.as_slice());
-    // Extract the one encrypted DbEntry we were looking for (and the
-    // only one we are able to decrypt)
-    let entry_in_block = rand_idx % spiral_blocking_factor;
-    let loc_in_block = entry_in_block * mem::size_of::<DbEntry>();
-    let loc_in_block_end = (entry_in_block + 1) * mem::size_of::<DbEntry>();
-    let encdbentry = DbEntry::from_le_bytes(
-        encdbblock[loc_in_block..loc_in_block_end]
-            .try_into()
-            .unwrap(),
-    );
-    let decdbentry = otkey_decrypt(&otkey, rand_idx, encdbentry);
-    let spc_recv_us = spc_recv_start.elapsed().as_micros();
-    println!("Client decode response {} µs", spc_recv_us);
-    println!("index = {}, Response = {}", q, decdbentry);
-*/
+    /*
+        let spiral_params = params::get_spiral_params(r);
+        let mut rng = rand::thread_rng();
+        print_params_summary(&spiral_params);
+        println!("OT one-time setup: {} µs", otsetup_us);
+
+        // One-time setup for the Spiral client
+        let spc_otsetup_start = Instant::now();
+        let mut clientrng = rand::thread_rng();
+        let mut client = Client::init(&spiral_params, &mut clientrng);
+        let pub_params = client.generate_keys();
+        let pub_params_buf = pub_params.serialize();
+        let spc_otsetup_us = spc_otsetup_start.elapsed().as_micros();
+        let spiral_blocking_factor = spiral_params.db_item_size / mem::size_of::<DbEntry>();
+        println!(
+            "Spiral client one-time setup: {} µs, {} bytes",
+            spc_otsetup_us,
+            pub_params_buf.len()
+        );
+
+        println!("\n===== PREPROCESSING =====\n");
+
+        // Spiral preprocessing: create a PIR lookup for an element at a
+        // random location
+        let spc_query_start = Instant::now();
+        let rand_idx = (rng.next_u64() as usize) % num_records;
+        let rand_pir_idx = rand_idx / spiral_blocking_factor;
+        println!("rand_idx = {} rand_pir_idx = {}", rand_idx, rand_pir_idx);
+        let spc_query = client.generate_query(rand_pir_idx);
+        let spc_query_buf = spc_query.serialize();
+        let spc_query_us = spc_query_start.elapsed().as_micros();
+        println!(
+            "Spiral query: {} µs, {} bytes",
+            spc_query_us,
+            spc_query_buf.len()
+        );
+
+        // Create the database encryption keys and do the OT to fetch the
+        // right one, but don't actually encrypt the database yet
+        let dbkeys = gen_db_enc_keys(r);
+        let otkeyreq_start = Instant::now();
+        let (keystate, keyquery) = otkey_request(rand_idx, r);
+        let keyquerysize = keyquery.len() * keyquery[0].len();
+        let otkeyreq_us = otkeyreq_start.elapsed().as_micros();
+        let otkeysrv_start = Instant::now();
+        let keyresponse = otkey_serve(keyquery, &dbkeys);
+        let keyrespsize = keyresponse.len() * keyresponse[0].len();
+        let otkeysrv_us = otkeysrv_start.elapsed().as_micros();
+        let otkeyrcv_start = Instant::now();
+        let otkey = otkey_receive(keystate, &keyresponse);
+        let otkeyrcv_us = otkeyrcv_start.elapsed().as_micros();
+        println!("key OT query in {} µs, {} bytes", otkeyreq_us, keyquerysize);
+        println!("key OT serve in {} µs, {} bytes", otkeysrv_us, keyrespsize);
+        println!("key OT receive in {} µs", otkeyrcv_us);
+
+        // Create a database with recognizable contents
+        let db: Vec<DbEntry> = ((0 as DbEntry)..(num_records as DbEntry))
+            .map(|x| 10000001 * x)
+            .collect();
+
+        println!("\n===== RUNTIME =====\n");
+
+        // Pick the record we actually want to query
+        let q = (rng.next_u64() as usize) % num_records;
+
+        // Compute the offset from the record index we're actually looking
+        // for to the random one we picked earlier.  Tell it to the server,
+        // who will rotate right the database by that amount before
+        // encrypting it.
+        let idx_offset = (num_records + rand_idx - q) % num_records;
+
+        println!("Send to server {} bytes", 8 /* sizeof(idx_offset) */);
+
+        // The server rotates, blinds, and encrypts the database
+        let blind: DbEntry = 20;
+        let encdb_start = Instant::now();
+        let encdb = encdb_xor_keys(&db, &dbkeys, r, idx_offset, blind, num_threads);
+        let encdb_us = encdb_start.elapsed().as_micros();
+        println!("Server encrypt database {} µs", encdb_us);
+
+        // Load the encrypted database into Spiral
+        let sps_loaddb_start = Instant::now();
+        let sps_db = load_db_from_slice_mt(&spiral_params, &encdb, num_threads);
+        let sps_loaddb_us = sps_loaddb_start.elapsed().as_micros();
+        println!("Server load database {} µs", sps_loaddb_us);
+
+        // Do the PIR query
+        let sps_query_start = Instant::now();
+        let sps_query = Query::deserialize(&spiral_params, &spc_query_buf);
+        let sps_response = process_query(&spiral_params, &pub_params, &sps_query, sps_db.as_slice());
+        let sps_query_us = sps_query_start.elapsed().as_micros();
+        println!(
+            "Server compute response {} µs, {} bytes (*including* the above expansion time)",
+            sps_query_us,
+            sps_response.len()
+        );
+
+        // Decode the response to yield the whole Spiral block
+        let spc_recv_start = Instant::now();
+        let encdbblock = client.decode_response(sps_response.as_slice());
+        // Extract the one encrypted DbEntry we were looking for (and the
+        // only one we are able to decrypt)
+        let entry_in_block = rand_idx % spiral_blocking_factor;
+        let loc_in_block = entry_in_block * mem::size_of::<DbEntry>();
+        let loc_in_block_end = (entry_in_block + 1) * mem::size_of::<DbEntry>();
+        let encdbentry = DbEntry::from_le_bytes(
+            encdbblock[loc_in_block..loc_in_block_end]
+                .try_into()
+                .unwrap(),
+        );
+        let decdbentry = otkey_decrypt(&otkey, rand_idx, encdbentry);
+        let spc_recv_us = spc_recv_start.elapsed().as_micros();
+        println!("Client decode response {} µs", spc_recv_us);
+        println!("index = {}, Response = {}", q, decdbentry);
+    */
 }

+ 1 - 2
src/server.rs

@@ -1,2 +1 @@
-pub struct Server {
-}
+pub struct Server {}