Browse Source

Start converting the crate to a library crate with a C++ interface

Ian Goldberg 1 year ago
parent
commit
eb53841429
10 changed files with 440 additions and 244 deletions
  1. 8 0
      Cargo.toml
  2. 12 0
      cxx/Makefile
  3. 7 0
      cxx/spir.cpp
  4. 50 0
      cxx/spir.hpp
  5. 16 0
      cxx/spir_ffi.h
  6. 24 0
      cxx/spir_test.cpp
  7. 43 0
      src/client.rs
  8. 261 0
      src/lib.rs
  9. 17 244
      src/main.rs
  10. 2 0
      src/server.rs

+ 8 - 0
Cargo.toml

@@ -19,6 +19,14 @@ subtle = { package = "subtle-ng", version = "2.4" }
 spiral-rs = { git = "https://github.com/menonsamir/spiral-rs/", rev = "0f9bdc157" }
 rayon = "1.5"
 
+[lib]
+crate_type = ["lib", "staticlib"]
+path = "src/lib.rs"
+
+[[bin]]
+name = "spiral-spir"
+path = "src/main.rs"
+
 [features]
 default = ["u64_backend"]
 u32_backend = ["curve25519-dalek/u32_backend"]

+ 12 - 0
cxx/Makefile

@@ -0,0 +1,12 @@
+spir_test: spir_test.o libspir_cxx.a
+	g++ -o $@ $^ -lpthread -ldl
+
+libspir_cxx.a: spir.o ../target/release/libspiral_spir.a
+	cp ../target/release/libspiral_spir.a $@
+	ar r $@ $<
+
+../target/release/libspiral_spir.a: $(wildcard ../src/*.rs)
+	cargo build --release
+
+clean:
+	-rm -f libspir_cxx.a spir.o spir_test.o spir_test

+ 7 - 0
cxx/spir.cpp

@@ -0,0 +1,7 @@
+#include "spir.hpp"
+#include "spir_ffi.h"
+
+void SPIR::init(uint32_t num_threads)
+{
+    spir_init(num_threads);
+}

+ 50 - 0
cxx/spir.hpp

@@ -0,0 +1,50 @@
+#ifndef __SPIR_HPP__
+#define __SPIR_HPP__
+
+#include <string>
+#include <stdint.h>
+
+using std::string;
+
+class SPIR {
+public:
+    typedef uint64_t DBEntry;  // The type of each DB entry (64 bits)
+
+    static void init(uint32_t nthreads);  // Call this once at startup
+};
+
+class SPIR_Client {
+public:
+    // constructor
+    SPIR_Client(uint8_t r, string &pub_params); // 2^r records in the database; pub_params will be _filled in_
+
+    // preprocessing
+    string preproc_PIRs(uint32_t num_pirs); // returns the string to send to the server
+
+    void preproc_handle(const string &server_preproc);
+
+    // SPIR query for index idx
+    string query(size_t idx); // returns the string to send to the server
+
+    // process the server's response to yield the server's db[(idx + rot)%N] + blind
+    // where N=2^r, idx is provided by the client above, and
+    // db, rot, and blind are provided by the server below
+    SPIR::DBEntry process_reply(const string &server_reply);
+};
+
+class SPIR_Server {
+public:
+    // constructor
+    SPIR_Server(uint8_t r, const string &client_pub_params);
+
+    // preprocessing
+    string preproc_PIR(const string &client_preproc); // returns the string to reply to the client
+  
+    // SPIR query on the given database of N=2^r records, each of type DBEntry
+    // rotate the database by rot, and blind each entry in the database additively with blind
+    // returns the string to reply to the client
+    string process_query(const string &client_query, const SPIR::DBEntry *db,
+        size_t rot, SPIR::DBEntry blind);
+};
+
+#endif

+ 16 - 0
cxx/spir_ffi.h

@@ -0,0 +1,16 @@
+#ifndef __SPIR_FFI_H__
+#define __SPIR_FFI_H__
+
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+extern void spir_init(uint32_t num_threads);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif

+ 24 - 0
cxx/spir_test.cpp

@@ -0,0 +1,24 @@
+#include <iostream>
+#include <stdlib.h>
+#include "spir.hpp"
+
+int main(int argc, char **argv)
+{
+    if (argc < 2 || argc > 4) {
+        std::cerr << "Usage: " << argv[0] << " r [num_threads [num_pirs]]\n";
+        std::cerr << "r = log_2(num_records)\n";
+        exit(1);
+    }
+    uint32_t r, num_threads = 1, num_pirs = 1;
+    r = strtoul(argv[1], NULL, 10);
+    if (argc > 2) {
+        num_threads = strtoul(argv[2], NULL, 10);
+    }
+    if (argc > 3) {
+        num_pirs = strtoul(argv[3], NULL, 10);
+    }
+
+    SPIR::init(num_threads);
+
+    return 0;
+}

+ 43 - 0
src/client.rs

@@ -0,0 +1,43 @@
+use rand::RngCore;
+use rand::rngs::ThreadRng;
+
+use std::thread::*;
+use std::sync::mpsc::*;
+
+use crate::params;
+
+pub struct Client {
+    r: usize,
+    thread_handle: JoinHandle<()>,
+    incoming_query: SyncSender<usize>,
+    incoming_response: SyncSender<Vec<u8>>,
+    outgoing_data: Receiver<Vec<u8>>,
+    pub pub_params: Vec<u8>,
+}
+
+impl Client {
+    pub fn new(r: usize) -> Self {
+        let (incoming_query, incoming_query_recv) = sync_channel(0);
+        let (incoming_response, incoming_response_recv) = sync_channel(0);
+        let (outgoing_data_send, outgoing_data) = sync_channel(0);
+        let thread_handle = spawn(move || {
+            let spiral_params = params::get_spiral_params(r);
+            let mut clientrng = rand::thread_rng();
+            let mut spiral_client =
+                spiral_rs::client::Client::init(&spiral_params, &mut clientrng);
+
+            // The first communication is the pub_params
+            let pub_params = spiral_client.generate_keys().serialize();
+            outgoing_data_send.send(pub_params).unwrap();
+        });
+        let pub_params = outgoing_data.recv().unwrap();
+        Client {
+            r,
+            thread_handle,
+            incoming_query,
+            incoming_response,
+            outgoing_data,
+            pub_params,
+        }
+    }
+}

+ 261 - 0
src/lib.rs

@@ -0,0 +1,261 @@
+// We really want points to be capital letters and scalars to be
+// lowercase letters
+#![allow(non_snake_case)]
+
+mod aligned_memory_mt;
+mod params;
+mod spiral_mt;
+pub mod client;
+pub mod server;
+
+use aes::cipher::{BlockEncrypt, KeyInit};
+use aes::Aes128Enc;
+use aes::Block;
+use std::env;
+use std::mem;
+use std::time::Instant;
+use subtle::Choice;
+use subtle::ConditionallySelectable;
+
+use rand::RngCore;
+
+use sha2::Digest;
+use sha2::Sha256;
+use sha2::Sha512;
+
+use curve25519_dalek::constants as dalek_constants;
+use curve25519_dalek::ristretto::CompressedRistretto;
+use curve25519_dalek::ristretto::RistrettoBasepointTable;
+use curve25519_dalek::ristretto::RistrettoPoint;
+use curve25519_dalek::scalar::Scalar;
+
+use rayon::scope;
+use rayon::ThreadPoolBuilder;
+
+use spiral_rs::client::*;
+use spiral_rs::params::*;
+use spiral_rs::server::*;
+
+use crate::spiral_mt::*;
+
+use lazy_static::lazy_static;
+
+pub type DbEntry = u64;
+
+// Generators of the Ristretto group (the standard B and another one C,
+// for which the DL relationship is unknown), and their precomputed
+// multiplication tables.  Used for the Oblivious Transfer protocol
+lazy_static! {
+    pub static ref OT_B: RistrettoPoint = dalek_constants::RISTRETTO_BASEPOINT_POINT;
+    pub static ref OT_C: RistrettoPoint =
+        RistrettoPoint::hash_from_bytes::<Sha512>(b"OT Generator C");
+    pub static ref OT_B_TABLE: RistrettoBasepointTable = dalek_constants::RISTRETTO_BASEPOINT_TABLE;
+    pub static ref OT_C_TABLE: RistrettoBasepointTable = RistrettoBasepointTable::create(&OT_C);
+}
+
+// XOR a 16-byte slice into a Block (which will be used as an AES key)
+fn xor16(outar: &mut Block, inar: &[u8; 16]) {
+    for i in 0..16 {
+        outar[i] ^= inar[i];
+    }
+}
+
+// Encrypt a database of 2^r elements, where each element is a DbEntry,
+// using the 2*r provided keys (r pairs of keys).  Also rotate the
+// database by rot positions, and add the provided blinding factor to
+// each element before encryption (the same blinding factor for all
+// elements).  Each element is encrypted in AES counter mode, with the
+// counter being the element number and the key computed as the XOR of r
+// of the provided keys, one from each pair, according to the bits of
+// the element number.  Outputs a byte vector containing the encrypted
+// database.
+pub fn encdb_xor_keys(
+    db: &[DbEntry],
+    keys: &[[u8; 16]],
+    r: usize,
+    rot: usize,
+    blind: DbEntry,
+    num_threads: usize,
+) -> Vec<u8> {
+    let num_records: usize = 1 << r;
+    let num_record_mask: usize = num_records - 1;
+    let negrot = num_records - rot;
+    let mut ret = Vec::<u8>::with_capacity(num_records * mem::size_of::<DbEntry>());
+    ret.resize(num_records * mem::size_of::<DbEntry>(), 0);
+    scope(|s| {
+        let mut record_thread_start = 0usize;
+        let records_per_thread_base = num_records / num_threads;
+        let records_per_thread_extra = num_records % num_threads;
+        let mut retslice = ret.as_mut_slice();
+        for thr in 0..num_threads {
+            let records_this_thread =
+                records_per_thread_base + if thr < records_per_thread_extra { 1 } else { 0 };
+            let record_thread_end = record_thread_start + records_this_thread;
+            let (thread_ret, retslice_) =
+                retslice.split_at_mut(records_this_thread * mem::size_of::<DbEntry>());
+            retslice = retslice_;
+            s.spawn(move |_| {
+                let mut offset = 0usize;
+                for j in record_thread_start..record_thread_end {
+                    let rec = (j + negrot) & num_record_mask;
+                    let mut key = Block::from([0u8; 16]);
+                    for i in 0..r {
+                        let bit = if (j & (1 << i)) == 0 { 0 } else { 1 };
+                        xor16(&mut key, &keys[2 * i + bit]);
+                    }
+                    let aes = Aes128Enc::new(&key);
+                    let mut block = Block::from([0u8; 16]);
+                    block[0..8].copy_from_slice(&j.to_le_bytes());
+                    aes.encrypt_block(&mut block);
+                    let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
+                    let encelem = (db[rec].wrapping_add(blind)) ^ aeskeystream;
+                    thread_ret[offset..offset + mem::size_of::<DbEntry>()]
+                        .copy_from_slice(&encelem.to_le_bytes());
+                    offset += mem::size_of::<DbEntry>();
+                }
+            });
+            record_thread_start = record_thread_end;
+        }
+    });
+    ret
+}
+
+// Generate the keys for encrypting the database
+pub fn gen_db_enc_keys(r: usize) -> Vec<[u8; 16]> {
+    let mut keys: Vec<[u8; 16]> = Vec::new();
+
+    let mut rng = rand::thread_rng();
+    for _ in 0..2 * r {
+        let mut k: [u8; 16] = [0; 16];
+        rng.fill_bytes(&mut k);
+        keys.push(k);
+    }
+    keys
+}
+
+// 1-out-of-2 Oblivious Transfer (OT)
+
+fn ot12_request(sel: Choice) -> ((Choice, Scalar), [u8; 32]) {
+    let Btable: &RistrettoBasepointTable = &OT_B_TABLE;
+    let C: &RistrettoPoint = &OT_C;
+    let mut rng = rand07::thread_rng();
+    let x = Scalar::random(&mut rng);
+    let xB = &x * Btable;
+    let CmxB = C - xB;
+    let P = RistrettoPoint::conditional_select(&xB, &CmxB, sel);
+    ((sel, x), P.compress().to_bytes())
+}
+
+fn ot12_serve(query: &[u8; 32], m0: &[u8; 16], m1: &[u8; 16]) -> [u8; 64] {
+    let Btable: &RistrettoBasepointTable = &OT_B_TABLE;
+    let Ctable: &RistrettoBasepointTable = &OT_C_TABLE;
+    let mut rng = rand07::thread_rng();
+    let y = Scalar::random(&mut rng);
+    let yB = &y * Btable;
+    let yC = &y * Ctable;
+    let P = CompressedRistretto::from_slice(query).decompress().unwrap();
+    let yP0 = y * P;
+    let yP1 = yC - yP0;
+    let mut HyP0 = Sha256::digest(yP0.compress().as_bytes());
+    for i in 0..16 {
+        HyP0[i] ^= m0[i];
+    }
+    let mut HyP1 = Sha256::digest(yP1.compress().as_bytes());
+    for i in 0..16 {
+        HyP1[i] ^= m1[i];
+    }
+    let mut ret = [0u8; 64];
+    ret[0..32].copy_from_slice(yB.compress().as_bytes());
+    ret[32..48].copy_from_slice(&HyP0[0..16]);
+    ret[48..64].copy_from_slice(&HyP1[0..16]);
+    ret
+}
+
+fn ot12_receive(state: (Choice, Scalar), response: &[u8; 64]) -> [u8; 16] {
+    let yB = CompressedRistretto::from_slice(&response[0..32])
+        .decompress()
+        .unwrap();
+    let yP = state.1 * yB;
+    let mut HyP = Sha256::digest(yP.compress().as_bytes());
+    for i in 0..16 {
+        HyP[i] ^= u8::conditional_select(&response[32 + i], &response[48 + i], state.0);
+    }
+    HyP[0..16].try_into().unwrap()
+}
+
+// Obliviously fetch the key for element q of the database (which has
+// 2^r elements total).  Each bit of q is used in a 1-out-of-2 OT to get
+// one of the keys in each of the r pairs of keys on the server side.
+// The resulting r keys are XORed together.
+
+pub fn otkey_request(q: usize, r: usize) -> (Vec<(Choice, Scalar)>, Vec<[u8; 32]>) {
+    let mut state: Vec<(Choice, Scalar)> = Vec::with_capacity(r);
+    let mut query: Vec<[u8; 32]> = Vec::with_capacity(r);
+    for i in 0..r {
+        let bit = ((q >> i) & 1) as u8;
+        let (si, qi) = ot12_request(bit.into());
+        state.push(si);
+        query.push(qi);
+    }
+    (state, query)
+}
+
+pub fn otkey_serve(query: Vec<[u8; 32]>, keys: &Vec<[u8; 16]>) -> Vec<[u8; 64]> {
+    let r = query.len();
+    assert!(keys.len() == 2 * r);
+    let mut response: Vec<[u8; 64]> = Vec::with_capacity(r);
+    for i in 0..r {
+        response.push(ot12_serve(&query[i], &keys[2 * i], &keys[2 * i + 1]));
+    }
+    response
+}
+
+pub fn otkey_receive(state: Vec<(Choice, Scalar)>, response: &Vec<[u8; 64]>) -> Block {
+    let r = state.len();
+    assert!(response.len() == r);
+    let mut key = Block::from([0u8; 16]);
+    for i in 0..r {
+        xor16(&mut key, &ot12_receive(state[i], &response[i]));
+    }
+    key
+}
+
+// Having received the key for element q with r parallel 1-out-of-2 OTs,
+// and having received the encrypted element with (non-symmetric) PIR,
+// use the key to decrypt the element.
+pub fn otkey_decrypt(key: &Block, q: usize, encelement: DbEntry) -> DbEntry {
+    let aes = Aes128Enc::new(key);
+    let mut block = Block::from([0u8; 16]);
+    block[0..8].copy_from_slice(&q.to_le_bytes());
+    aes.encrypt_block(&mut block);
+    let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
+    encelement ^ aeskeystream
+}
+
+// Things that are only done once total, not once for each SPIR
+pub fn init(num_threads: usize) {
+    // Resolve the lazy statics
+    let _B: &RistrettoPoint = &OT_B;
+    let _Btable: &RistrettoBasepointTable = &OT_B_TABLE;
+    let _C: &RistrettoPoint = &OT_C;
+    let _Ctable: &RistrettoBasepointTable = &OT_C_TABLE;
+
+    // Initialize the thread pool
+    ThreadPoolBuilder::new().num_threads(num_threads).build_global().unwrap();
+}
+
+pub fn print_params_summary(params: &Params) {
+    let db_elem_size = params.item_size();
+    let total_size = params.num_items() * db_elem_size;
+    println!(
+        "Using a {} x {} byte database ({} bytes total)",
+        params.num_items(),
+        db_elem_size,
+        total_size
+    );
+}
+
+#[no_mangle]
+pub extern "C" fn spir_init(num_threads: u32) {
+    init(num_threads as usize);
+}

+ 17 - 244
src/main.rs

@@ -2,269 +2,42 @@
 // lowercase letters
 #![allow(non_snake_case)]
 
-pub mod aligned_memory_mt;
-pub mod params;
-pub mod spiral_mt;
-
-use aes::cipher::{BlockEncrypt, KeyInit};
-use aes::Aes128Enc;
-use aes::Block;
 use std::env;
-use std::mem;
 use std::time::Instant;
-use subtle::Choice;
-use subtle::ConditionallySelectable;
-
-use rand::RngCore;
-
-use sha2::Digest;
-use sha2::Sha256;
-use sha2::Sha512;
-
-use curve25519_dalek::constants as dalek_constants;
-use curve25519_dalek::ristretto::CompressedRistretto;
-use curve25519_dalek::ristretto::RistrettoBasepointTable;
-use curve25519_dalek::ristretto::RistrettoPoint;
-use curve25519_dalek::scalar::Scalar;
-
-use rayon::scope;
-use rayon::ThreadPoolBuilder;
 
 use spiral_rs::client::*;
-use spiral_rs::params::*;
 use spiral_rs::server::*;
 
-use crate::spiral_mt::*;
-
-use lazy_static::lazy_static;
-
-type DbEntry = u64;
-
-// Generators of the Ristretto group (the standard B and another one C,
-// for which the DL relationship is unknown), and their precomputed
-// multiplication tables.  Used for the Oblivious Transfer protocol
-lazy_static! {
-    pub static ref OT_B: RistrettoPoint = dalek_constants::RISTRETTO_BASEPOINT_POINT;
-    pub static ref OT_C: RistrettoPoint =
-        RistrettoPoint::hash_from_bytes::<Sha512>(b"OT Generator C");
-    pub static ref OT_B_TABLE: RistrettoBasepointTable = dalek_constants::RISTRETTO_BASEPOINT_TABLE;
-    pub static ref OT_C_TABLE: RistrettoBasepointTable = RistrettoBasepointTable::create(&OT_C);
-}
-
-// XOR a 16-byte slice into a Block (which will be used as an AES key)
-fn xor16(outar: &mut Block, inar: &[u8; 16]) {
-    for i in 0..16 {
-        outar[i] ^= inar[i];
-    }
-}
-
-// Encrypt a database of 2^r elements, where each element is a DbEntry,
-// using the 2*r provided keys (r pairs of keys).  Also add the provided
-// blinding factor to each element before encryption (the same blinding
-// factor for all elements).  Each element is encrypted in AES counter
-// mode, with the counter being the element number and the key computed
-// as the XOR of r of the provided keys, one from each pair, according
-// to the bits of the element number.  Outputs a byte vector containing
-// the encrypted database.
-fn encdb_xor_keys(
-    db: &[DbEntry],
-    keys: &[[u8; 16]],
-    r: usize,
-    blind: DbEntry,
-    num_threads: usize,
-) -> Vec<u8> {
-    let num_records: usize = 1 << r;
-    let mut ret = Vec::<u8>::with_capacity(num_records * mem::size_of::<DbEntry>());
-    ret.resize(num_records * mem::size_of::<DbEntry>(), 0);
-    scope(|s| {
-        let mut record_thread_start = 0usize;
-        let records_per_thread_base = num_records / num_threads;
-        let records_per_thread_extra = num_records % num_threads;
-        let mut retslice = ret.as_mut_slice();
-        for thr in 0..num_threads {
-            let records_this_thread =
-                records_per_thread_base + if thr < records_per_thread_extra { 1 } else { 0 };
-            let record_thread_end = record_thread_start + records_this_thread;
-            let (thread_ret, retslice_) =
-                retslice.split_at_mut(records_this_thread * mem::size_of::<DbEntry>());
-            retslice = retslice_;
-            s.spawn(move |_| {
-                let mut offset = 0usize;
-                for j in record_thread_start..record_thread_end {
-                    let mut key = Block::from([0u8; 16]);
-                    for i in 0..r {
-                        let bit = if (j & (1 << i)) == 0 { 0 } else { 1 };
-                        xor16(&mut key, &keys[2 * i + bit]);
-                    }
-                    let aes = Aes128Enc::new(&key);
-                    let mut block = Block::from([0u8; 16]);
-                    block[0..8].copy_from_slice(&j.to_le_bytes());
-                    aes.encrypt_block(&mut block);
-                    let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
-                    let encelem = (db[j].wrapping_add(blind)) ^ aeskeystream;
-                    thread_ret[offset..offset + mem::size_of::<DbEntry>()]
-                        .copy_from_slice(&encelem.to_le_bytes());
-                    offset += mem::size_of::<DbEntry>();
-                }
-            });
-            record_thread_start = record_thread_end;
-        }
-    });
-    ret
-}
-
-// Generate the keys for encrypting the database
-fn gen_db_enc_keys(r: usize) -> Vec<[u8; 16]> {
-    let mut keys: Vec<[u8; 16]> = Vec::new();
-
-    let mut rng = rand::thread_rng();
-    for _ in 0..2 * r {
-        let mut k: [u8; 16] = [0; 16];
-        rng.fill_bytes(&mut k);
-        keys.push(k);
-    }
-    keys
-}
-
-// 1-out-of-2 Oblivious Transfer (OT)
-
-fn ot12_request(sel: Choice) -> ((Choice, Scalar), [u8; 32]) {
-    let Btable: &RistrettoBasepointTable = &OT_B_TABLE;
-    let C: &RistrettoPoint = &OT_C;
-    let mut rng = rand07::thread_rng();
-    let x = Scalar::random(&mut rng);
-    let xB = &x * Btable;
-    let CmxB = C - xB;
-    let P = RistrettoPoint::conditional_select(&xB, &CmxB, sel);
-    ((sel, x), P.compress().to_bytes())
-}
-
-fn ot12_serve(query: &[u8; 32], m0: &[u8; 16], m1: &[u8; 16]) -> [u8; 64] {
-    let Btable: &RistrettoBasepointTable = &OT_B_TABLE;
-    let Ctable: &RistrettoBasepointTable = &OT_C_TABLE;
-    let mut rng = rand07::thread_rng();
-    let y = Scalar::random(&mut rng);
-    let yB = &y * Btable;
-    let yC = &y * Ctable;
-    let P = CompressedRistretto::from_slice(query).decompress().unwrap();
-    let yP0 = y * P;
-    let yP1 = yC - yP0;
-    let mut HyP0 = Sha256::digest(yP0.compress().as_bytes());
-    for i in 0..16 {
-        HyP0[i] ^= m0[i];
-    }
-    let mut HyP1 = Sha256::digest(yP1.compress().as_bytes());
-    for i in 0..16 {
-        HyP1[i] ^= m1[i];
-    }
-    let mut ret = [0u8; 64];
-    ret[0..32].copy_from_slice(yB.compress().as_bytes());
-    ret[32..48].copy_from_slice(&HyP0[0..16]);
-    ret[48..64].copy_from_slice(&HyP1[0..16]);
-    ret
-}
-
-fn ot12_receive(state: (Choice, Scalar), response: &[u8; 64]) -> [u8; 16] {
-    let yB = CompressedRistretto::from_slice(&response[0..32])
-        .decompress()
-        .unwrap();
-    let yP = state.1 * yB;
-    let mut HyP = Sha256::digest(yP.compress().as_bytes());
-    for i in 0..16 {
-        HyP[i] ^= u8::conditional_select(&response[32 + i], &response[48 + i], state.0);
-    }
-    HyP[0..16].try_into().unwrap()
-}
-
-// Obliviously fetch the key for element q of the database (which has
-// 2^r elements total).  Each bit of q is used in a 1-out-of-2 OT to get
-// one of the keys in each of the r pairs of keys on the server side.
-// The resulting r keys are XORed together.
-
-fn otkey_request(q: usize, r: usize) -> (Vec<(Choice, Scalar)>, Vec<[u8; 32]>) {
-    let mut state: Vec<(Choice, Scalar)> = Vec::with_capacity(r);
-    let mut query: Vec<[u8; 32]> = Vec::with_capacity(r);
-    for i in 0..r {
-        let bit = ((q >> i) & 1) as u8;
-        let (si, qi) = ot12_request(bit.into());
-        state.push(si);
-        query.push(qi);
-    }
-    (state, query)
-}
-
-fn otkey_serve(query: Vec<[u8; 32]>, keys: &Vec<[u8; 16]>) -> Vec<[u8; 64]> {
-    let r = query.len();
-    assert!(keys.len() == 2 * r);
-    let mut response: Vec<[u8; 64]> = Vec::with_capacity(r);
-    for i in 0..r {
-        response.push(ot12_serve(&query[i], &keys[2 * i], &keys[2 * i + 1]));
-    }
-    response
-}
-
-fn otkey_receive(state: Vec<(Choice, Scalar)>, response: &Vec<[u8; 64]>) -> Block {
-    let r = state.len();
-    assert!(response.len() == r);
-    let mut key = Block::from([0u8; 16]);
-    for i in 0..r {
-        xor16(&mut key, &ot12_receive(state[i], &response[i]));
-    }
-    key
-}
-
-// Having received the key for element q with r parallel 1-out-of-2 OTs,
-// and having received the encrypted element with (non-symmetric) PIR,
-// use the key to decrypt the element.
-fn otkey_decrypt(key: &Block, q: usize, encelement: DbEntry) -> DbEntry {
-    let aes = Aes128Enc::new(key);
-    let mut block = Block::from([0u8; 16]);
-    block[0..8].copy_from_slice(&q.to_le_bytes());
-    aes.encrypt_block(&mut block);
-    let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
-    encelement ^ aeskeystream
-}
-
-// Things that are only done once total, not once for each SPIR
-fn one_time_setup() {
-    // Resolve the lazy statics
-    let _B: &RistrettoPoint = &OT_B;
-    let _Btable: &RistrettoBasepointTable = &OT_B_TABLE;
-    let _C: &RistrettoPoint = &OT_C;
-    let _Ctable: &RistrettoBasepointTable = &OT_C_TABLE;
-}
-
-fn print_params_summary(params: &Params) {
-    let db_elem_size = params.item_size();
-    let total_size = params.num_items() * db_elem_size;
-    println!(
-        "Using a {} x {} byte database ({} bytes total)",
-        params.num_items(),
-        db_elem_size,
-        total_size
-    );
-}
+use spiral_spir::*;
 
 fn main() {
     let args: Vec<String> = env::args().collect();
-    if args.len() != 2 && args.len() != 3 {
-        println!("Usage: {} r [num_threads]\nr = log_2(num_records)", args[0]);
+    if args.len() < 2 || args.len() > 4 {
+        println!("Usage: {} r [num_threads [num_pirs]]\nr = log_2(num_records)", args[0]);
         return;
     }
     let r: usize = args[1].parse().unwrap();
     let mut num_threads = 1usize;
-    if args.len() == 3 {
+    let mut num_pirs = 1usize;
+    if args.len() > 2 {
         num_threads = args[2].parse().unwrap();
     }
+    if args.len() > 3 {
+        num_pirs = args[3].parse().unwrap();
+    }
     let num_records = 1 << r;
 
     println!("===== ONE-TIME SETUP =====\n");
 
+    let otsetup_start = Instant::now();
+    init(num_threads);
+    let otsetup_us = otsetup_start.elapsed().as_micros();
+    println!("OT one-time setup: {} µs", otsetup_us);
+/*
     let otsetup_start = Instant::now();
     let spiral_params = params::get_spiral_params(r);
     let mut rng = rand::thread_rng();
-    ThreadPoolBuilder::new().num_threads(num_threads).build_global().unwrap();
-    one_time_setup();
+    init(num_threads);
     let otsetup_us = otsetup_start.elapsed().as_micros();
     print_params_summary(&spiral_params);
     println!("OT one-time setup: {} µs", otsetup_us);
@@ -319,7 +92,7 @@ fn main() {
     println!("key OT receive in {} µs", otkeyrcv_us);
 
     // Create a database with recognizable contents
-    let mut db: Vec<DbEntry> = ((0 as DbEntry)..(num_records as DbEntry))
+    let db: Vec<DbEntry> = ((0 as DbEntry)..(num_records as DbEntry))
         .map(|x| 10000001 * x)
         .collect();
 
@@ -339,8 +112,7 @@ fn main() {
     // The server rotates, blinds, and encrypts the database
     let blind: DbEntry = 20;
     let encdb_start = Instant::now();
-    db.rotate_right(idx_offset);
-    let encdb = encdb_xor_keys(&db, &dbkeys, r, blind, num_threads);
+    let encdb = encdb_xor_keys(&db, &dbkeys, r, idx_offset, blind, num_threads);
     let encdb_us = encdb_start.elapsed().as_micros();
     println!("Server encrypt database {} µs", encdb_us);
 
@@ -378,4 +150,5 @@ fn main() {
     let spc_recv_us = spc_recv_start.elapsed().as_micros();
     println!("Client decode response {} µs", spc_recv_us);
     println!("index = {}, Response = {}", q, decdbentry);
+*/
 }

+ 2 - 0
src/server.rs

@@ -0,0 +1,2 @@
+pub struct Server {
+}