16 Commits c205c0121a ... 913aa2596b

Author SHA1 Message Date
  Ian Goldberg 913aa2596b Update README.md 1 year ago
  Ian Goldberg 60314dbab9 Typo in C++ test program output 1 year ago
  Ian Goldberg 6661c0f068 Take some cargo clippy advice 1 year ago
  Ian Goldberg 5ae476d943 Clean up some compiler warnings 1 year ago
  Ian Goldberg 23df7c068e client query_finish API 1 year ago
  Ian Goldberg b49605e91c server query_process API 1 year ago
  Ian Goldberg 99ac4e28e7 client query API 1 year ago
  Ian Goldberg 6211a89d03 Touch up some API function names 1 year ago
  Ian Goldberg 473d2cb478 preproc_finish API 1 year ago
  Ian Goldberg 217b08a574 server preproc_PIRs API 1 year ago
  Ian Goldberg 67ef5574d3 Build the rust code with target-cpu=native 1 year ago
  Ian Goldberg eba0f97b3b client preproc_PIRs API 1 year ago
  Ian Goldberg 2bd0149ea8 Move the oblivious transfer code into its own module 1 year ago
  Ian Goldberg 801a3203b6 spir_client_free, spir_server_new, spir_server_free APIs 1 year ago
  Ian Goldberg 689d911823 spir_client_new API 1 year ago
  Ian Goldberg eb53841429 Start converting the crate to a library crate with a C++ interface 1 year ago
14 changed files with 1458 additions and 378 deletions
  1. 11 0
      Cargo.toml
  2. 74 24
      README.md
  3. 14 0
      cxx/Makefile
  4. 79 0
      cxx/spir.cpp
  5. 64 0
      cxx/spir.hpp
  6. 57 0
      cxx/spir_ffi.h
  7. 126 0
      cxx/spir_test.cpp
  8. 1 0
      src/aligned_memory_mt.rs
  9. 296 0
      src/client.rs
  10. 168 0
      src/lib.rs
  11. 194 352
      src/main.rs
  12. 150 0
      src/ot.rs
  13. 222 0
      src/server.rs
  14. 2 2
      src/spiral_mt.rs

+ 11 - 0
Cargo.toml

@@ -18,6 +18,17 @@ sha2 = "0.9"
 subtle = { package = "subtle-ng", version = "2.4" }
 spiral-rs = { git = "https://github.com/menonsamir/spiral-rs/", rev = "0f9bdc157" }
 rayon = "1.5"
+bincode = "1"
+serde = "1"
+serde_with = "2"
+
+[lib]
+crate_type = ["lib", "staticlib"]
+path = "src/lib.rs"
+
+[[bin]]
+name = "spiral-spir"
+path = "src/main.rs"
 
 [features]
 default = ["u64_backend"]

+ 74 - 24
README.md

@@ -2,6 +2,8 @@
 
 *Ian Goldberg (iang@uwaterloo.ca), July 2022*
 
+Last update August 2022
+
 This code implements Symmetric Private Information Retrieval, building
 on the Spiral PIR library (this code is not written by the Spiral
 authors).
@@ -57,42 +59,90 @@ We slightly optimize the protocol in that instead of the client sending both β0
 
 ## Running the code
 
-To build the code:
+In the August 2022 version, this code is now built as a Rust library
+that can be called from Rust or from C++.
+
+To build the Rust library and a Rust test program:
 
 `RUSTFLAGS="-C target-cpu=native" cargo build --release`
 
-To run the code:
+To build the C++ library that wraps the Rust library and a C++ test program:
+
+`make -C cxx`
+
+To run the Rust test program:
+
+`./target/release/spiral-spir 20 4 100 2`
+
+TO run the C++ test program:
 
-`./target/release/spiral-spir 20 4`
+`./cxx/spir_test 20 4 100 2`
 
-Where `20` is the value of r (that is, the database will have N=2^20 entries), and 4 is the number of threads to use (defaults to 1).  Each entry is 8 bytes.  There are three phases of execution: a one-time Spiral public key generation (this only has to be done once, regardless of how many SPIR queries you do), a preprocessing phase per SPIR query (this can be done _before_ knowing the contents of the database on the server side, or the desired index on the client side), and the runtime phase per SPIR query (once those two things are known).
+Here:
 
-A sample output (for r=20, 4 threads):
+  * `20` is the value of r (that is, the database will have N=2^20 entries)
+  * `4` is the number of threads to use (defaults to 1)
+  * `100` is the number of SPIR queries to prepare in the preprocessing
+    step; there is only one round trip from the client to the server
+    regardless of this number, but the message and response are larger
+    (defaults to 1)
+  * `2` is the number of SPIR queries to do, one at a time. Must be at
+    most the number of preprocessed queries, and defaults to that
+    number
+  
+Each entry is 8 bytes.  There are three phases of execution: a one-time Spiral public key generation (this only has to be done once, regardless of how many SPIR queries you do), a preprocessing phase per SPIR query (this can be done _before_ knowing the contents of the database on the server side, or the desired index on the client side), and the runtime phase per SPIR query (once those two things are known).
+
+A sample output (for r=20, 4 threads, 100 preprocessed SPIR queries, 2
+SPIR queries performed):
 
 ```
 ===== ONE-TIME SETUP =====
 
-Using a 2048 x 4096 byte database (8388608 bytes total)
-OT one-time setup: 3637 µs
-Spiral client one-time setup: 157144 µs, 10878976 bytes
+One-time setup: 201893 µs
+pub_params len = 10878976
 
 ===== PREPROCESSING =====
 
-rand_idx = 146516 rand_pir_idx = 286
-Spiral query: 457 µs, 32768 bytes
-key OT query in 324 µs, 640 bytes
-key OT serve in 1653 µs, 1280 bytes
-key OT receive in 1029 µs
-
-===== RUNTIME =====
-
-Send to server 8 bytes
-Server encrypt database 29738 µs
-Server load database 248825 µs
-expansion (took 101920 us).
-Server compute response 181293 µs, 14336 bytes (*including* the above expansion time)
-Client decode response 790 µs
-index = 919657, Response = 9196570919677
+num_preproc = 100
+Preprocessing client: 76869 µs
+preproc_msg len = 3342408
+Preprocessing server: 53409 µs
+preproc_resp len = 128808
+Preprocessing client finish: 26688 µs
+
+===== SPIR QUERY 1 =====
+
+Query client: 107 µs
+query_msg len = 8
+expansion (took 99798 us).
+Query server: 411422 µs
+query_resp len = 14336
+Query client finish: 832 µs
+idx = 778275; entry = 7783750778395
+
+===== SPIR QUERY 2 =====
+
+Query client: 73 µs
+query_msg len = 8
+expansion (took 90281 us).
+Query server: 406014 µs
+query_resp len = 14336
+Query client finish: 810 µs
+idx = 675158; entry = 6752580675278
 ```
 
-The various lines show the amount of wall-clock time taken for various parts of the computation and the amount of data transferred between the client and the server.  The last line shows the random index that was looked up, and the database value the client retrieved.  The value for index i should be (10000001*i+20).
+The various lines show the amount of wall-clock time taken for various
+parts of the computation and the amount of data transferred between the
+client and the server.  The last line of each SPIR query shows the
+random index that was looked up, and the database value the client
+retrieved.  The value for index i should be (10000001\*(i+100)+20).
+
+## Using the C++ library yourself
+
+Build the C++ library as above:
+
+`make -C cxx`
+
+Then `cxx/spir.hpp` and `cxx/libspir_cxx.a` are the files you'll need.
+\#include the former in your code, and link your program with the latter
+as well as `-lpthread -ldl`.

+ 14 - 0
cxx/Makefile

@@ -0,0 +1,14 @@
+CXXFLAGS = -O3 -Wall
+
+spir_test: spir_test.o libspir_cxx.a
+	g++ -o $@ $^ -lpthread -ldl
+
+libspir_cxx.a: spir.o ../target/release/libspiral_spir.a
+	cp ../target/release/libspiral_spir.a $@
+	ar r $@ $<
+
+../target/release/libspiral_spir.a: $(wildcard ../src/*.rs)
+	RUSTFLAGS="-C target-cpu=native" cargo build --release
+
+clean:
+	-rm -f libspir_cxx.a spir.o spir_test.o spir_test

+ 79 - 0
cxx/spir.cpp

@@ -0,0 +1,79 @@
+#include <stdio.h>
+#include "spir.hpp"
+#include "spir_ffi.h"
+
+void SPIR::init(uint32_t num_threads)
+{
+    spir_init(num_threads);
+}
+
+SPIR_Client::SPIR_Client(uint8_t r, string &pub_params)
+{
+    ClientNewRet ret = spir_client_new(r);
+    pub_params.assign(ret.pub_params.data, ret.pub_params.len);
+    spir_vecdata_free(ret.pub_params);
+    this->client = ret.client;
+}
+
+SPIR_Client::~SPIR_Client()
+{
+    spir_client_free(this->client);
+}
+
+string SPIR_Client::preproc(uint32_t num_preproc)
+{
+    VecData msg = spir_client_preproc(this->client, num_preproc);
+    string ret(msg.data, msg.len);
+    spir_vecdata_free(msg);
+    return ret;
+}
+
+void SPIR_Client::preproc_finish(const string &server_preproc)
+{
+    spir_client_preproc_finish(this->client, server_preproc.data(),
+        server_preproc.length());
+}
+
+string SPIR_Client::query(size_t idx)
+{
+    VecData msg = spir_client_query(this->client, idx);
+    string ret(msg.data, msg.len);
+    spir_vecdata_free(msg);
+    return ret;
+}
+
+SPIR::DBEntry SPIR_Client::query_finish(const string &server_resp)
+{
+    return spir_client_query_finish(this->client,
+        server_resp.data(), server_resp.length());
+}
+
+SPIR_Server::SPIR_Server(uint8_t r, const string &pub_params)
+{
+    this->server = spir_server_new(r, pub_params.data(),
+        pub_params.length());
+}
+
+SPIR_Server::~SPIR_Server()
+{
+    spir_server_free(this->server);
+}
+
+string SPIR_Server::preproc_process(const string &msg)
+{
+    VecData retmsg = spir_server_preproc_process(this->server, msg.data(),
+        msg.length());
+    string ret(retmsg.data, retmsg.len);
+    spir_vecdata_free(retmsg);
+    return ret;
+}
+
+string SPIR_Server::query_process(const string &client_query,
+        const SPIR::DBEntry *db, size_t rot, SPIR::DBEntry blind)
+{
+    VecData retmsg = spir_server_query_process(this->server,
+        client_query.data(), client_query.length(), db, rot, blind);
+    string ret(retmsg.data, retmsg.len);
+    spir_vecdata_free(retmsg);
+    return ret;
+}

+ 64 - 0
cxx/spir.hpp

@@ -0,0 +1,64 @@
+#ifndef __SPIR_HPP__
+#define __SPIR_HPP__
+
+#include <string>
+#include <stdint.h>
+
+using std::string;
+
+class SPIR {
+public:
+    typedef uint64_t DBEntry;  // The type of each DB entry (64 bits)
+
+    static void init(uint32_t nthreads);  // Call this once at startup
+};
+
+class SPIR_Client {
+public:
+    // constructor and destructor
+    SPIR_Client(uint8_t r, string &pub_params); // 2^r records in the database; pub_params will be _filled in_
+    ~SPIR_Client();
+
+    // preprocessing
+    string preproc(uint32_t num_pirs); // returns the string to send to the server
+
+    void preproc_finish(const string &server_preproc);
+
+    // SPIR query for index idx
+    string query(size_t idx); // returns the string to send to the server
+
+    // process the server's response to yield the server's db[(idx + rot)%N] + blind
+    // where N=2^r, idx is provided by the client above, and
+    // db, rot, and blind are provided by the server below
+    SPIR::DBEntry query_finish(const string &server_reply);
+
+private:
+    void *client;
+    SPIR_Client() = default;
+    SPIR_Client(const SPIR_Client&) = delete;
+    SPIR_Client& operator=(const SPIR_Client&) = delete;
+};
+
+class SPIR_Server {
+public:
+    // constructor and destructor
+    SPIR_Server(uint8_t r, const string &client_pub_params);
+    ~SPIR_Server();
+
+    // preprocessing
+    string preproc_process(const string &client_preproc); // returns the string to reply to the client
+  
+    // SPIR query on the given database of N=2^r records, each of type DBEntry
+    // rotate the database by rot, and blind each entry in the database additively with blind
+    // returns the string to reply to the client
+    string query_process(const string &client_query, const SPIR::DBEntry *db,
+        size_t rot, SPIR::DBEntry blind);
+
+private:
+    void *server;
+    SPIR_Server() = default;
+    SPIR_Server(const SPIR_Server&) = delete;
+    SPIR_Server& operator=(const SPIR_Server&) = delete;
+};
+
+#endif

+ 57 - 0
cxx/spir_ffi.h

@@ -0,0 +1,57 @@
+#ifndef __SPIR_FFI_H__
+#define __SPIR_FFI_H__
+
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef size_t DBEntry;
+
+typedef struct {
+    const char *data;
+    size_t len;
+    size_t capacity;
+} VecData;
+
+typedef struct {
+    void *client;
+    VecData pub_params;
+} ClientNewRet;
+
+extern void spir_init(uint32_t num_threads);
+
+extern ClientNewRet spir_client_new(uint8_t r);
+
+extern void spir_client_free(void *client);
+
+extern VecData spir_client_preproc(void *client, uint32_t num_preproc);
+
+extern void spir_client_preproc_finish(void *client,
+    const char *msgdata, size_t msglen);
+
+extern VecData spir_client_query(void *client, size_t idx);
+
+extern DBEntry spir_client_query_finish(void *client,
+    const char *msgdata, size_t msglen);
+
+extern void* spir_server_new(uint8_t r, const char *pub_params,
+    size_t pub_params_len);
+
+extern void spir_server_free(void *server);
+
+extern VecData spir_server_preproc_process(void *server,
+    const char *msgdata, size_t msglen);
+
+extern VecData spir_server_query_process(void *server,
+    const char *msgdata, size_t msglen, const DBEntry *db,
+    size_t rot, DBEntry blind);
+
+extern void spir_vecdata_free(VecData vecdata);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif

+ 126 - 0
cxx/spir_test.cpp

@@ -0,0 +1,126 @@
+#include <iostream>
+#include <stdlib.h>
+#include <sys/random.h>
+#include <sys/time.h>
+#include <unistd.h>
+#include "spir.hpp"
+
+using std::cout;
+using std::cerr;
+
+static inline size_t elapsed_us(const struct timeval *start)
+{
+    struct timeval end;
+    gettimeofday(&end, NULL);
+    return (end.tv_sec-start->tv_sec)*1000000 + end.tv_usec - start->tv_usec;
+}
+
+int main(int argc, char **argv)
+{
+    if (argc < 2 || argc > 5) {
+        cerr << "Usage: " << argv[0] << " r [num_threads [num_preproc [num_pirs]]]\n";
+        cerr << "r = log_2(num_records)\n";
+        exit(1);
+    }
+    uint32_t r, num_threads = 1, num_preproc = 1, num_pirs = 1;
+    r = strtoul(argv[1], NULL, 10);
+    size_t num_records = ((size_t) 1)<<r;
+    size_t num_records_mask = num_records - 1;
+    if (argc > 2) {
+        num_threads = strtoul(argv[2], NULL, 10);
+    }
+    if (argc > 3) {
+        num_preproc = strtoul(argv[3], NULL, 10);
+    }
+    if (argc > 4) {
+        num_pirs = strtoul(argv[4], NULL, 10);
+    } else {
+        num_pirs = num_preproc;
+    }
+
+    cout << "===== ONE-TIME SETUP =====\n\n";
+
+    struct timeval otsetup_start;
+    gettimeofday(&otsetup_start, NULL);
+
+    SPIR::init(num_threads);
+    string pub_params;
+    SPIR_Client client(r, pub_params);
+    SPIR_Server server(r, pub_params);
+
+    size_t otsetup_us = elapsed_us(&otsetup_start);
+    cout << "One-time setup: " << otsetup_us << " µs\n";
+    cout << "pub_params len = " << pub_params.length() << "\n";
+
+    cout << "\n===== PREPROCESSING =====\n\n";
+
+    cout << "num_preproc = " << num_preproc << "\n";
+
+    struct timeval preproc_client_start;
+    gettimeofday(&preproc_client_start, NULL);
+
+    string preproc_msg = client.preproc(num_preproc);
+    size_t preproc_client_us = elapsed_us(&preproc_client_start);
+    cout << "Preprocessing client: " << preproc_client_us << " µs\n";
+    cout << "preproc_msg len = " << preproc_msg.length() << "\n";
+
+    struct timeval preproc_server_start;
+    gettimeofday(&preproc_server_start, NULL);
+
+    string preproc_resp = server.preproc_process(preproc_msg);
+    size_t preproc_server_us = elapsed_us(&preproc_server_start);
+    cout << "Preprocessing server: " << preproc_server_us << " µs\n";
+    cout << "preproc_resp len = " << preproc_resp.length() << "\n";
+
+    struct timeval preproc_finish_start;
+    gettimeofday(&preproc_finish_start, NULL);
+
+    client.preproc_finish(preproc_resp);
+    size_t preproc_finish_us = elapsed_us(&preproc_finish_start);
+    cout << "Preprocessing client finish: " << preproc_finish_us << " µs\n";
+
+    // Create the database
+    SPIR::DBEntry *db = new SPIR::DBEntry[num_records];
+    for (size_t i=0; i<num_records; ++i) {
+        db[i] = i * 10000001;
+    }
+
+    for (size_t i=0; i<num_pirs; ++i) {
+        cout << "\n===== SPIR QUERY " << i+1 << " =====\n\n";
+
+        size_t idx;
+        if (getrandom(&idx, sizeof(idx), 0) != sizeof(idx)) {
+            cerr << "Failure in getrandom\n";
+            exit(1);
+        }
+        idx &= num_records_mask;
+
+        struct timeval query_client_start;
+        gettimeofday(&query_client_start, NULL);
+
+        string query_msg = client.query(idx);
+        size_t query_client_us = elapsed_us(&query_client_start);
+        cout << "Query client: " << query_client_us << " µs\n";
+        cout << "query_msg len = " << query_msg.length() << "\n";
+
+        struct timeval query_server_start;
+        gettimeofday(&query_server_start, NULL);
+
+        string query_resp = server.query_process(query_msg, db, 100, 20);
+        size_t query_server_us = elapsed_us(&query_server_start);
+        cout << "Query server: " << query_server_us << " µs\n";
+        cout << "query_resp len = " << query_resp.length() << "\n";
+
+        struct timeval query_finish_start;
+        gettimeofday(&query_finish_start, NULL);
+
+        SPIR::DBEntry entry = client.query_finish(query_resp);
+        size_t query_finish_us = elapsed_us(&query_finish_start);
+        cout << "Query client finish: " << query_finish_us << " µs\n";
+        cout << "idx = " << idx << "; entry = " << entry << "\n";
+    }
+
+    delete[] db;
+
+    return 0;
+}

+ 1 - 0
src/aligned_memory_mt.rs

@@ -1,3 +1,4 @@
+#![allow(dead_code)]
 /* This file is almost identical to the aligned_memory.rs file in the
    spiral-rs crate.  The name is modified from AlignedMemory to
    AlignedMemoryMT, and there is one (unsafe!) change to the API:

+ 296 - 0
src/client.rs

@@ -0,0 +1,296 @@
+use rand::RngCore;
+
+use std::collections::VecDeque;
+use std::mem;
+use std::os::raw::c_uchar;
+use std::sync::mpsc::*;
+use std::thread::*;
+
+use aes::Block;
+
+use rayon::prelude::*;
+
+use subtle::Choice;
+
+use curve25519_dalek::scalar::Scalar;
+
+use crate::dbentry_decrypt;
+use crate::ot::*;
+use crate::params;
+use crate::to_vecdata;
+use crate::DbEntry;
+use crate::PreProcSingleMsg;
+use crate::PreProcSingleRespMsg;
+use crate::VecData;
+
+enum Command {
+    PreProc(usize),
+    PreProcFinish(Vec<PreProcSingleRespMsg>),
+    Query(usize),
+    QueryFinish(Vec<u8>),
+}
+
+enum Response {
+    PubParams(Vec<u8>),
+    PreProcMsg(Vec<u8>),
+    PreProcDone,
+    QueryMsg(Vec<u8>),
+    QueryDone(DbEntry),
+}
+
+// The internal client state for a single outstanding preprocess query
+struct PreProcOutSingleState {
+    rand_idx: usize,
+    ot_state: Vec<(Choice, Scalar)>,
+}
+
+// The internal client state for a single preprocess ready to be used
+struct PreProcSingleState {
+    rand_idx: usize,
+    ot_key: Block,
+}
+
+pub struct Client {
+    incoming_cmd: SyncSender<Command>,
+    outgoing_resp: Receiver<Response>,
+}
+
+impl Client {
+    pub fn new(r: usize) -> (Self, Vec<u8>) {
+        let (incoming_cmd, incoming_cmd_recv) = sync_channel(0);
+        let (outgoing_resp_send, outgoing_resp) = sync_channel(0);
+        spawn(move || {
+            let spiral_params = params::get_spiral_params(r);
+            let mut clientrng = rand::thread_rng();
+            let mut rng = rand::thread_rng();
+            let mut spiral_client = spiral_rs::client::Client::init(&spiral_params, &mut clientrng);
+            let num_records = 1 << r;
+            let num_records_mask = num_records - 1;
+            let spiral_blocking_factor = spiral_params.db_item_size / mem::size_of::<DbEntry>();
+
+            // The first communication is the pub_params
+            let pub_params = spiral_client.generate_keys().serialize();
+            outgoing_resp_send
+                .send(Response::PubParams(pub_params))
+                .unwrap();
+
+            // State for outstanding preprocessing queries
+            let mut preproc_out_state: Vec<PreProcOutSingleState> = Vec::new();
+
+            // State for preprocessing queries ready to be used
+            let mut preproc_state: VecDeque<PreProcSingleState> = VecDeque::new();
+
+            // State for outstanding active queries
+            let mut query_state: VecDeque<PreProcSingleState> = VecDeque::new();
+
+            // Wait for commands
+            loop {
+                match incoming_cmd_recv.recv() {
+                    Err(_) => break,
+                    Ok(Command::PreProc(num_preproc)) => {
+                        // Ensure we don't already have outstanding
+                        // preprocessing state
+                        assert!(preproc_out_state.is_empty());
+                        let mut preproc_msg: Vec<PreProcSingleMsg> = Vec::new();
+                        for _ in 0..num_preproc {
+                            let rand_idx = (rng.next_u64() as usize) & num_records_mask;
+                            let rand_pir_idx = rand_idx / spiral_blocking_factor;
+                            let spc_query = spiral_client.generate_query(rand_pir_idx).serialize();
+                            let (ot_state, ot_query) = otkey_request(rand_idx, r);
+                            preproc_out_state.push(PreProcOutSingleState { rand_idx, ot_state });
+                            preproc_msg.push(PreProcSingleMsg {
+                                ot_query,
+                                spc_query,
+                            });
+                        }
+                        let ret: Vec<u8> = bincode::serialize(&preproc_msg).unwrap();
+                        outgoing_resp_send.send(Response::PreProcMsg(ret)).unwrap();
+                    }
+                    Ok(Command::PreProcFinish(srvresp)) => {
+                        let num_preproc = srvresp.len();
+                        assert!(preproc_out_state.len() == num_preproc);
+                        let mut newstate: VecDeque<PreProcSingleState> = preproc_out_state
+                            .into_par_iter()
+                            .zip(srvresp)
+                            .map(|(c, s)| {
+                                let ot_key = otkey_receive(c.ot_state, &s.ot_resp);
+                                PreProcSingleState {
+                                    rand_idx: c.rand_idx,
+                                    ot_key,
+                                }
+                            })
+                            .collect();
+                        preproc_state.append(&mut newstate);
+                        preproc_out_state = Vec::new();
+                        outgoing_resp_send.send(Response::PreProcDone).unwrap();
+                    }
+                    Ok(Command::Query(idx)) => {
+                        // panic if there are no preproc states
+                        // available
+                        let nextstate = preproc_state.pop_front().unwrap();
+                        let offset = (num_records + idx - nextstate.rand_idx) & num_records_mask;
+                        let mut querymsg: Vec<u8> = Vec::new();
+                        querymsg.extend(offset.to_le_bytes());
+                        query_state.push_back(nextstate);
+                        outgoing_resp_send
+                            .send(Response::QueryMsg(querymsg))
+                            .unwrap();
+                    }
+                    Ok(Command::QueryFinish(msg)) => {
+                        // panic if there is no outstanding state
+                        let nextstate = query_state.pop_front().unwrap();
+                        let encdbblock = spiral_client.decode_response(msg.as_slice());
+                        // Extract the one encrypted DbEntry we were
+                        // looking for (and the only one we are able to
+                        // decrypt)
+                        let entry_in_block = nextstate.rand_idx % spiral_blocking_factor;
+                        let loc_in_block = entry_in_block * mem::size_of::<DbEntry>();
+                        let loc_in_block_end = (entry_in_block + 1) * mem::size_of::<DbEntry>();
+                        let encdbentry = DbEntry::from_le_bytes(
+                            encdbblock[loc_in_block..loc_in_block_end]
+                                .try_into()
+                                .unwrap(),
+                        );
+                        let decdbentry =
+                            dbentry_decrypt(&nextstate.ot_key, nextstate.rand_idx, encdbentry);
+                        outgoing_resp_send
+                            .send(Response::QueryDone(decdbentry))
+                            .unwrap();
+                    }
+                    // When adding new messages, the following line is
+                    // useful during development
+                    // _ => panic!("Received something unexpected in client loop"),
+                }
+            }
+        });
+        let pub_params = match outgoing_resp.recv() {
+            Ok(Response::PubParams(x)) => x,
+            _ => panic!("Received something unexpected in client new"),
+        };
+
+        (
+            Client {
+                incoming_cmd,
+                outgoing_resp,
+            },
+            pub_params,
+        )
+    }
+
+    pub fn preproc(&self, num_preproc: usize) -> Vec<u8> {
+        self.incoming_cmd
+            .send(Command::PreProc(num_preproc))
+            .unwrap();
+        match self.outgoing_resp.recv() {
+            Ok(Response::PreProcMsg(x)) => x,
+            _ => panic!("Received something unexpected in preproc"),
+        }
+    }
+
+    pub fn preproc_finish(&self, msg: &[u8]) {
+        self.incoming_cmd
+            .send(Command::PreProcFinish(bincode::deserialize(msg).unwrap()))
+            .unwrap();
+        match self.outgoing_resp.recv() {
+            Ok(Response::PreProcDone) => (),
+            _ => panic!("Received something unexpected in preproc_finish"),
+        }
+    }
+
+    pub fn query(&self, idx: usize) -> Vec<u8> {
+        self.incoming_cmd.send(Command::Query(idx)).unwrap();
+        match self.outgoing_resp.recv() {
+            Ok(Response::QueryMsg(x)) => x,
+            _ => panic!("Received something unexpected in preproc"),
+        }
+    }
+
+    pub fn query_finish(&self, msg: &[u8]) -> DbEntry {
+        self.incoming_cmd
+            .send(Command::QueryFinish(msg.to_vec()))
+            .unwrap();
+        match self.outgoing_resp.recv() {
+            Ok(Response::QueryDone(entry)) => entry,
+            _ => panic!("Received something unexpected in preproc_finish"),
+        }
+    }
+}
+
+#[repr(C)]
+pub struct ClientNewRet {
+    client: *mut Client,
+    pub_params: VecData,
+}
+
+#[no_mangle]
+pub extern "C" fn spir_client_new(r: u8) -> ClientNewRet {
+    let (client, pub_params) = Client::new(r as usize);
+    ClientNewRet {
+        client: Box::into_raw(Box::new(client)),
+        pub_params: to_vecdata(pub_params),
+    }
+}
+
+#[no_mangle]
+pub extern "C" fn spir_client_free(client: *mut Client) {
+    if client.is_null() {
+        return;
+    }
+    unsafe {
+        Box::from_raw(client);
+    }
+}
+
+#[no_mangle]
+pub extern "C" fn spir_client_preproc(clientptr: *mut Client, num_preproc: u32) -> VecData {
+    let client = unsafe {
+        assert!(!clientptr.is_null());
+        &mut *clientptr
+    };
+    let retvec = client.preproc(num_preproc as usize);
+    to_vecdata(retvec)
+}
+
+#[no_mangle]
+pub extern "C" fn spir_client_preproc_finish(
+    clientptr: *mut Client,
+    msgdata: *const c_uchar,
+    msglen: usize,
+) {
+    let client = unsafe {
+        assert!(!clientptr.is_null());
+        &mut *clientptr
+    };
+    let msg_slice = unsafe {
+        assert!(!msgdata.is_null());
+        std::slice::from_raw_parts(msgdata, msglen)
+    };
+    client.preproc_finish(msg_slice);
+}
+
+#[no_mangle]
+pub extern "C" fn spir_client_query(clientptr: *mut Client, idx: usize) -> VecData {
+    let client = unsafe {
+        assert!(!clientptr.is_null());
+        &mut *clientptr
+    };
+    let retvec = client.query(idx);
+    to_vecdata(retvec)
+}
+
+#[no_mangle]
+pub extern "C" fn spir_client_query_finish(
+    clientptr: *mut Client,
+    msgdata: *const c_uchar,
+    msglen: usize,
+) -> DbEntry {
+    let client = unsafe {
+        assert!(!clientptr.is_null());
+        &mut *clientptr
+    };
+    let msg_slice = unsafe {
+        assert!(!msgdata.is_null());
+        std::slice::from_raw_parts(msgdata, msglen)
+    };
+    client.query_finish(msg_slice)
+}

+ 168 - 0
src/lib.rs

@@ -0,0 +1,168 @@
+mod aligned_memory_mt;
+pub mod client;
+mod ot;
+mod params;
+pub mod server;
+mod spiral_mt;
+
+use aes::cipher::{BlockEncrypt, KeyInit};
+use aes::Aes128Enc;
+use aes::Block;
+use std::mem;
+
+use serde::{Deserialize, Serialize};
+
+use std::os::raw::c_uchar;
+
+use rayon::scope;
+use rayon::ThreadPoolBuilder;
+
+use serde_with::serde_as;
+
+use spiral_rs::params::*;
+
+use crate::ot::{otkey_init, xor16};
+use crate::spiral_mt::*;
+
+pub type DbEntry = u64;
+
+// Encrypt a database of 2^r elements, where each element is a DbEntry,
+// using the 2*r provided keys (r pairs of keys).  Also rotate the
+// database by rot positions, and add the provided blinding factor to
+// each element before encryption (the same blinding factor for all
+// elements).  Each element is encrypted in AES counter mode, with the
+// counter being the element number and the key computed as the XOR of r
+// of the provided keys, one from each pair, according to the bits of
+// the element number.  Outputs a byte vector containing the encrypted
+// database.
+fn db_encrypt(
+    db: &[DbEntry],
+    keys: &[[u8; 16]],
+    r: usize,
+    rot: usize,
+    blind: DbEntry,
+    num_threads: usize,
+) -> Vec<u8> {
+    let num_records: usize = 1 << r;
+    let num_record_mask: usize = num_records - 1;
+    let mut ret = vec![0; num_records * mem::size_of::<DbEntry>()];
+    scope(|s| {
+        let mut record_thread_start = 0usize;
+        let records_per_thread_base = num_records / num_threads;
+        let records_per_thread_extra = num_records % num_threads;
+        let mut retslice = ret.as_mut_slice();
+        for thr in 0..num_threads {
+            let records_this_thread =
+                records_per_thread_base + if thr < records_per_thread_extra { 1 } else { 0 };
+            let record_thread_end = record_thread_start + records_this_thread;
+            let (thread_ret, retslice_) =
+                retslice.split_at_mut(records_this_thread * mem::size_of::<DbEntry>());
+            retslice = retslice_;
+            s.spawn(move |_| {
+                let mut offset = 0usize;
+                for j in record_thread_start..record_thread_end {
+                    let rec = (j + rot) & num_record_mask;
+                    let mut key = Block::from([0u8; 16]);
+                    for i in 0..r {
+                        let bit = if (j & (1 << i)) == 0 { 0 } else { 1 };
+                        xor16(&mut key, &keys[2 * i + bit]);
+                    }
+                    let aes = Aes128Enc::new(&key);
+                    let mut block = Block::from([0u8; 16]);
+                    block[0..8].copy_from_slice(&j.to_le_bytes());
+                    aes.encrypt_block(&mut block);
+                    let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
+                    let encelem = (db[rec].wrapping_add(blind)) ^ aeskeystream;
+                    thread_ret[offset..offset + mem::size_of::<DbEntry>()]
+                        .copy_from_slice(&encelem.to_le_bytes());
+                    offset += mem::size_of::<DbEntry>();
+                }
+            });
+            record_thread_start = record_thread_end;
+        }
+    });
+    ret
+}
+
+// Having received the key for element q with r parallel 1-out-of-2 OTs,
+// and having received the encrypted element with (non-symmetric) PIR,
+// use the key to decrypt the element.
+fn dbentry_decrypt(key: &Block, q: usize, encelement: DbEntry) -> DbEntry {
+    let aes = Aes128Enc::new(key);
+    let mut block = Block::from([0u8; 16]);
+    block[0..8].copy_from_slice(&q.to_le_bytes());
+    aes.encrypt_block(&mut block);
+    let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
+    encelement ^ aeskeystream
+}
+
+// Things that are only done once total, not once for each SPIR
+pub fn init(num_threads: usize) {
+    otkey_init();
+
+    // Initialize the thread pool
+    ThreadPoolBuilder::new()
+        .num_threads(num_threads)
+        .build_global()
+        .unwrap();
+}
+
+pub fn print_params_summary(params: &Params) {
+    let db_elem_size = params.item_size();
+    let total_size = params.num_items() * db_elem_size;
+    println!(
+        "Using a {} x {} byte database ({} bytes total)",
+        params.num_items(),
+        db_elem_size,
+        total_size
+    );
+}
+
+// The message format for a single preprocess query
+#[derive(Serialize, Deserialize)]
+struct PreProcSingleMsg {
+    ot_query: Vec<[u8; 32]>,
+    spc_query: Vec<u8>,
+}
+
+// The message format for a single preprocess response
+#[serde_as]
+#[derive(Serialize, Deserialize)]
+struct PreProcSingleRespMsg {
+    #[serde_as(as = "Vec<[_; 64]>")]
+    ot_resp: Vec<[u8; 64]>,
+}
+
+#[no_mangle]
+pub extern "C" fn spir_init(num_threads: u32) {
+    init(num_threads as usize);
+}
+
+#[repr(C)]
+pub struct VecData {
+    data: *const c_uchar,
+    len: usize,
+    cap: usize,
+}
+
+#[repr(C)]
+pub struct VecMutData {
+    data: *mut c_uchar,
+    len: usize,
+    cap: usize,
+}
+
+pub fn to_vecdata(v: Vec<u8>) -> VecData {
+    let vecdata = VecData {
+        data: v.as_ptr(),
+        len: v.len(),
+        cap: v.capacity(),
+    };
+    std::mem::forget(v);
+    vecdata
+}
+
+#[no_mangle]
+pub extern "C" fn spir_vecdata_free(vecdata: VecMutData) {
+    unsafe { Vec::from_raw_parts(vecdata.data, vecdata.len, vecdata.cap) };
+}

+ 194 - 352
src/main.rs

@@ -1,381 +1,223 @@
-// We really want points to be capital letters and scalars to be
-// lowercase letters
-#![allow(non_snake_case)]
-
-pub mod aligned_memory_mt;
-pub mod params;
-pub mod spiral_mt;
-
-use aes::cipher::{BlockEncrypt, KeyInit};
-use aes::Aes128Enc;
-use aes::Block;
 use std::env;
-use std::mem;
 use std::time::Instant;
-use subtle::Choice;
-use subtle::ConditionallySelectable;
 
 use rand::RngCore;
 
-use sha2::Digest;
-use sha2::Sha256;
-use sha2::Sha512;
-
-use curve25519_dalek::constants as dalek_constants;
-use curve25519_dalek::ristretto::CompressedRistretto;
-use curve25519_dalek::ristretto::RistrettoBasepointTable;
-use curve25519_dalek::ristretto::RistrettoPoint;
-use curve25519_dalek::scalar::Scalar;
-
-use rayon::scope;
-use rayon::ThreadPoolBuilder;
-
-use spiral_rs::client::*;
-use spiral_rs::params::*;
-use spiral_rs::server::*;
-
-use crate::spiral_mt::*;
-
-use lazy_static::lazy_static;
-
-type DbEntry = u64;
-
-// Generators of the Ristretto group (the standard B and another one C,
-// for which the DL relationship is unknown), and their precomputed
-// multiplication tables.  Used for the Oblivious Transfer protocol
-lazy_static! {
-    pub static ref OT_B: RistrettoPoint = dalek_constants::RISTRETTO_BASEPOINT_POINT;
-    pub static ref OT_C: RistrettoPoint =
-        RistrettoPoint::hash_from_bytes::<Sha512>(b"OT Generator C");
-    pub static ref OT_B_TABLE: RistrettoBasepointTable = dalek_constants::RISTRETTO_BASEPOINT_TABLE;
-    pub static ref OT_C_TABLE: RistrettoBasepointTable = RistrettoBasepointTable::create(&OT_C);
-}
-
-// XOR a 16-byte slice into a Block (which will be used as an AES key)
-fn xor16(outar: &mut Block, inar: &[u8; 16]) {
-    for i in 0..16 {
-        outar[i] ^= inar[i];
-    }
-}
-
-// Encrypt a database of 2^r elements, where each element is a DbEntry,
-// using the 2*r provided keys (r pairs of keys).  Also add the provided
-// blinding factor to each element before encryption (the same blinding
-// factor for all elements).  Each element is encrypted in AES counter
-// mode, with the counter being the element number and the key computed
-// as the XOR of r of the provided keys, one from each pair, according
-// to the bits of the element number.  Outputs a byte vector containing
-// the encrypted database.
-fn encdb_xor_keys(
-    db: &[DbEntry],
-    keys: &[[u8; 16]],
-    r: usize,
-    blind: DbEntry,
-    num_threads: usize,
-) -> Vec<u8> {
-    let num_records: usize = 1 << r;
-    let mut ret = Vec::<u8>::with_capacity(num_records * mem::size_of::<DbEntry>());
-    ret.resize(num_records * mem::size_of::<DbEntry>(), 0);
-    scope(|s| {
-        let mut record_thread_start = 0usize;
-        let records_per_thread_base = num_records / num_threads;
-        let records_per_thread_extra = num_records % num_threads;
-        let mut retslice = ret.as_mut_slice();
-        for thr in 0..num_threads {
-            let records_this_thread =
-                records_per_thread_base + if thr < records_per_thread_extra { 1 } else { 0 };
-            let record_thread_end = record_thread_start + records_this_thread;
-            let (thread_ret, retslice_) =
-                retslice.split_at_mut(records_this_thread * mem::size_of::<DbEntry>());
-            retslice = retslice_;
-            s.spawn(move |_| {
-                let mut offset = 0usize;
-                for j in record_thread_start..record_thread_end {
-                    let mut key = Block::from([0u8; 16]);
-                    for i in 0..r {
-                        let bit = if (j & (1 << i)) == 0 { 0 } else { 1 };
-                        xor16(&mut key, &keys[2 * i + bit]);
-                    }
-                    let aes = Aes128Enc::new(&key);
-                    let mut block = Block::from([0u8; 16]);
-                    block[0..8].copy_from_slice(&j.to_le_bytes());
-                    aes.encrypt_block(&mut block);
-                    let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
-                    let encelem = (db[j].wrapping_add(blind)) ^ aeskeystream;
-                    thread_ret[offset..offset + mem::size_of::<DbEntry>()]
-                        .copy_from_slice(&encelem.to_le_bytes());
-                    offset += mem::size_of::<DbEntry>();
-                }
-            });
-            record_thread_start = record_thread_end;
-        }
-    });
-    ret
-}
-
-// Generate the keys for encrypting the database
-fn gen_db_enc_keys(r: usize) -> Vec<[u8; 16]> {
-    let mut keys: Vec<[u8; 16]> = Vec::new();
-
-    let mut rng = rand::thread_rng();
-    for _ in 0..2 * r {
-        let mut k: [u8; 16] = [0; 16];
-        rng.fill_bytes(&mut k);
-        keys.push(k);
-    }
-    keys
-}
-
-// 1-out-of-2 Oblivious Transfer (OT)
-
-fn ot12_request(sel: Choice) -> ((Choice, Scalar), [u8; 32]) {
-    let Btable: &RistrettoBasepointTable = &OT_B_TABLE;
-    let C: &RistrettoPoint = &OT_C;
-    let mut rng = rand07::thread_rng();
-    let x = Scalar::random(&mut rng);
-    let xB = &x * Btable;
-    let CmxB = C - xB;
-    let P = RistrettoPoint::conditional_select(&xB, &CmxB, sel);
-    ((sel, x), P.compress().to_bytes())
-}
-
-fn ot12_serve(query: &[u8; 32], m0: &[u8; 16], m1: &[u8; 16]) -> [u8; 64] {
-    let Btable: &RistrettoBasepointTable = &OT_B_TABLE;
-    let Ctable: &RistrettoBasepointTable = &OT_C_TABLE;
-    let mut rng = rand07::thread_rng();
-    let y = Scalar::random(&mut rng);
-    let yB = &y * Btable;
-    let yC = &y * Ctable;
-    let P = CompressedRistretto::from_slice(query).decompress().unwrap();
-    let yP0 = y * P;
-    let yP1 = yC - yP0;
-    let mut HyP0 = Sha256::digest(yP0.compress().as_bytes());
-    for i in 0..16 {
-        HyP0[i] ^= m0[i];
-    }
-    let mut HyP1 = Sha256::digest(yP1.compress().as_bytes());
-    for i in 0..16 {
-        HyP1[i] ^= m1[i];
-    }
-    let mut ret = [0u8; 64];
-    ret[0..32].copy_from_slice(yB.compress().as_bytes());
-    ret[32..48].copy_from_slice(&HyP0[0..16]);
-    ret[48..64].copy_from_slice(&HyP1[0..16]);
-    ret
-}
-
-fn ot12_receive(state: (Choice, Scalar), response: &[u8; 64]) -> [u8; 16] {
-    let yB = CompressedRistretto::from_slice(&response[0..32])
-        .decompress()
-        .unwrap();
-    let yP = state.1 * yB;
-    let mut HyP = Sha256::digest(yP.compress().as_bytes());
-    for i in 0..16 {
-        HyP[i] ^= u8::conditional_select(&response[32 + i], &response[48 + i], state.0);
-    }
-    HyP[0..16].try_into().unwrap()
-}
-
-// Obliviously fetch the key for element q of the database (which has
-// 2^r elements total).  Each bit of q is used in a 1-out-of-2 OT to get
-// one of the keys in each of the r pairs of keys on the server side.
-// The resulting r keys are XORed together.
-
-fn otkey_request(q: usize, r: usize) -> (Vec<(Choice, Scalar)>, Vec<[u8; 32]>) {
-    let mut state: Vec<(Choice, Scalar)> = Vec::with_capacity(r);
-    let mut query: Vec<[u8; 32]> = Vec::with_capacity(r);
-    for i in 0..r {
-        let bit = ((q >> i) & 1) as u8;
-        let (si, qi) = ot12_request(bit.into());
-        state.push(si);
-        query.push(qi);
-    }
-    (state, query)
-}
-
-fn otkey_serve(query: Vec<[u8; 32]>, keys: &Vec<[u8; 16]>) -> Vec<[u8; 64]> {
-    let r = query.len();
-    assert!(keys.len() == 2 * r);
-    let mut response: Vec<[u8; 64]> = Vec::with_capacity(r);
-    for i in 0..r {
-        response.push(ot12_serve(&query[i], &keys[2 * i], &keys[2 * i + 1]));
-    }
-    response
-}
-
-fn otkey_receive(state: Vec<(Choice, Scalar)>, response: &Vec<[u8; 64]>) -> Block {
-    let r = state.len();
-    assert!(response.len() == r);
-    let mut key = Block::from([0u8; 16]);
-    for i in 0..r {
-        xor16(&mut key, &ot12_receive(state[i], &response[i]));
-    }
-    key
-}
-
-// Having received the key for element q with r parallel 1-out-of-2 OTs,
-// and having received the encrypted element with (non-symmetric) PIR,
-// use the key to decrypt the element.
-fn otkey_decrypt(key: &Block, q: usize, encelement: DbEntry) -> DbEntry {
-    let aes = Aes128Enc::new(key);
-    let mut block = Block::from([0u8; 16]);
-    block[0..8].copy_from_slice(&q.to_le_bytes());
-    aes.encrypt_block(&mut block);
-    let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
-    encelement ^ aeskeystream
-}
-
-// Things that are only done once total, not once for each SPIR
-fn one_time_setup() {
-    // Resolve the lazy statics
-    let _B: &RistrettoPoint = &OT_B;
-    let _Btable: &RistrettoBasepointTable = &OT_B_TABLE;
-    let _C: &RistrettoPoint = &OT_C;
-    let _Ctable: &RistrettoBasepointTable = &OT_C_TABLE;
-}
-
-fn print_params_summary(params: &Params) {
-    let db_elem_size = params.item_size();
-    let total_size = params.num_items() * db_elem_size;
-    println!(
-        "Using a {} x {} byte database ({} bytes total)",
-        params.num_items(),
-        db_elem_size,
-        total_size
-    );
-}
+use spiral_spir::client::Client;
+use spiral_spir::init;
+use spiral_spir::server::Server;
+use spiral_spir::DbEntry;
 
 fn main() {
     let args: Vec<String> = env::args().collect();
-    if args.len() != 2 && args.len() != 3 {
-        println!("Usage: {} r [num_threads]\nr = log_2(num_records)", args[0]);
+    if args.len() < 2 || args.len() > 5 {
+        println!(
+            "Usage: {} r [num_threads [num_preproc [num_pirs]]]\nr = log_2(num_records)",
+            args[0]
+        );
         return;
     }
     let r: usize = args[1].parse().unwrap();
     let mut num_threads = 1usize;
-    if args.len() == 3 {
+    let mut num_preproc = 1usize;
+    let num_pirs: usize;
+    if args.len() > 2 {
         num_threads = args[2].parse().unwrap();
     }
+    if args.len() > 3 {
+        num_preproc = args[3].parse().unwrap();
+    }
+    if args.len() > 4 {
+        num_pirs = args[4].parse().unwrap();
+    } else {
+        num_pirs = num_preproc;
+    }
     let num_records = 1 << r;
+    let num_records_mask = num_records - 1;
 
     println!("===== ONE-TIME SETUP =====\n");
 
     let otsetup_start = Instant::now();
-    let spiral_params = params::get_spiral_params(r);
-    let mut rng = rand::thread_rng();
-    ThreadPoolBuilder::new().num_threads(num_threads).build_global().unwrap();
-    one_time_setup();
+
+    init(num_threads);
+    let (client, pub_params) = Client::new(r);
+    let pub_params_len = pub_params.len();
+    let server = Server::new(r, pub_params);
+
     let otsetup_us = otsetup_start.elapsed().as_micros();
-    print_params_summary(&spiral_params);
-    println!("OT one-time setup: {} µs", otsetup_us);
-
-    // One-time setup for the Spiral client
-    let spc_otsetup_start = Instant::now();
-    let mut clientrng = rand::thread_rng();
-    let mut client = Client::init(&spiral_params, &mut clientrng);
-    let pub_params = client.generate_keys();
-    let pub_params_buf = pub_params.serialize();
-    let spc_otsetup_us = spc_otsetup_start.elapsed().as_micros();
-    let spiral_blocking_factor = spiral_params.db_item_size / mem::size_of::<DbEntry>();
-    println!(
-        "Spiral client one-time setup: {} µs, {} bytes",
-        spc_otsetup_us,
-        pub_params_buf.len()
-    );
+    println!("One-time setup: {} µs", otsetup_us);
+    println!("pub_params len = {}", pub_params_len);
 
     println!("\n===== PREPROCESSING =====\n");
 
-    // Spiral preprocessing: create a PIR lookup for an element at a
-    // random location
-    let spc_query_start = Instant::now();
-    let rand_idx = (rng.next_u64() as usize) % num_records;
-    let rand_pir_idx = rand_idx / spiral_blocking_factor;
-    println!("rand_idx = {} rand_pir_idx = {}", rand_idx, rand_pir_idx);
-    let spc_query = client.generate_query(rand_pir_idx);
-    let spc_query_buf = spc_query.serialize();
-    let spc_query_us = spc_query_start.elapsed().as_micros();
-    println!(
-        "Spiral query: {} µs, {} bytes",
-        spc_query_us,
-        spc_query_buf.len()
-    );
-
-    // Create the database encryption keys and do the OT to fetch the
-    // right one, but don't actually encrypt the database yet
-    let dbkeys = gen_db_enc_keys(r);
-    let otkeyreq_start = Instant::now();
-    let (keystate, keyquery) = otkey_request(rand_idx, r);
-    let keyquerysize = keyquery.len() * keyquery[0].len();
-    let otkeyreq_us = otkeyreq_start.elapsed().as_micros();
-    let otkeysrv_start = Instant::now();
-    let keyresponse = otkey_serve(keyquery, &dbkeys);
-    let keyrespsize = keyresponse.len() * keyresponse[0].len();
-    let otkeysrv_us = otkeysrv_start.elapsed().as_micros();
-    let otkeyrcv_start = Instant::now();
-    let otkey = otkey_receive(keystate, &keyresponse);
-    let otkeyrcv_us = otkeyrcv_start.elapsed().as_micros();
-    println!("key OT query in {} µs, {} bytes", otkeyreq_us, keyquerysize);
-    println!("key OT serve in {} µs, {} bytes", otkeysrv_us, keyrespsize);
-    println!("key OT receive in {} µs", otkeyrcv_us);
+    println!("num_preproc = {}", num_preproc);
+
+    let preproc_client_start = Instant::now();
+    let preproc_msg = client.preproc(num_preproc);
+    let preproc_client_us = preproc_client_start.elapsed().as_micros();
+
+    println!("Preprocessing client: {} µs", preproc_client_us);
+    println!("preproc_msg len = {}", preproc_msg.len());
+
+    let preproc_server_start = Instant::now();
+    let preproc_resp = server.preproc_process(&preproc_msg);
+    let preproc_server_us = preproc_server_start.elapsed().as_micros();
+
+    println!("Preprocessing server: {} µs", preproc_server_us);
+    println!("preproc_resp len = {}", preproc_resp.len());
+
+    let preproc_finish_start = Instant::now();
+    client.preproc_finish(&preproc_resp);
+    let preproc_finish_us = preproc_finish_start.elapsed().as_micros();
+
+    println!("Preprocessing client finish: {} µs", preproc_finish_us);
 
     // Create a database with recognizable contents
-    let mut db: Vec<DbEntry> = ((0 as DbEntry)..(num_records as DbEntry))
+    let db: Vec<DbEntry> = ((0 as DbEntry)..(num_records as DbEntry))
         .map(|x| 10000001 * x)
         .collect();
+    let dbptr = db.as_ptr();
+
+    let mut rng = rand::thread_rng();
+    for i in 1..num_pirs + 1 {
+        println!("\n===== SPIR QUERY {} =====\n", i);
+
+        let idx = (rng.next_u64() as usize) & num_records_mask;
+        let query_client_start = Instant::now();
+        let query_msg = client.query(idx);
+        let query_client_us = query_client_start.elapsed().as_micros();
+
+        println!("Query client: {} µs", query_client_us);
+        println!("query_msg len = {}", query_msg.len());
+
+        let query_server_start = Instant::now();
+        let query_resp = server.query_process(&query_msg, dbptr, 100, 20);
+        let query_server_us = query_server_start.elapsed().as_micros();
+
+        println!("Query server: {} µs", query_server_us);
+        println!("query_resp len = {}", query_resp.len());
+
+        let query_finish_start = Instant::now();
+        let entry = client.query_finish(&query_resp);
+        let query_finish_us = query_finish_start.elapsed().as_micros();
+
+        println!("Query client finish: {} µs", query_finish_us);
+        println!("idx = {}; entry = {}", idx, entry);
+    }
 
-    println!("\n===== RUNTIME =====\n");
-
-    // Pick the record we actually want to query
-    let q = (rng.next_u64() as usize) % num_records;
-
-    // Compute the offset from the record index we're actually looking
-    // for to the random one we picked earlier.  Tell it to the server,
-    // who will rotate right the database by that amount before
-    // encrypting it.
-    let idx_offset = (num_records + rand_idx - q) % num_records;
-
-    println!("Send to server {} bytes", 8 /* sizeof(idx_offset) */);
-
-    // The server rotates, blinds, and encrypts the database
-    let blind: DbEntry = 20;
-    let encdb_start = Instant::now();
-    db.rotate_right(idx_offset);
-    let encdb = encdb_xor_keys(&db, &dbkeys, r, blind, num_threads);
-    let encdb_us = encdb_start.elapsed().as_micros();
-    println!("Server encrypt database {} µs", encdb_us);
-
-    // Load the encrypted database into Spiral
-    let sps_loaddb_start = Instant::now();
-    let sps_db = load_db_from_slice_mt(&spiral_params, &encdb, num_threads);
-    let sps_loaddb_us = sps_loaddb_start.elapsed().as_micros();
-    println!("Server load database {} µs", sps_loaddb_us);
-
-    // Do the PIR query
-    let sps_query_start = Instant::now();
-    let sps_query = Query::deserialize(&spiral_params, &spc_query_buf);
-    let sps_response = process_query(&spiral_params, &pub_params, &sps_query, sps_db.as_slice());
-    let sps_query_us = sps_query_start.elapsed().as_micros();
-    println!(
-        "Server compute response {} µs, {} bytes (*including* the above expansion time)",
-        sps_query_us,
-        sps_response.len()
-    );
-
-    // Decode the response to yield the whole Spiral block
-    let spc_recv_start = Instant::now();
-    let encdbblock = client.decode_response(sps_response.as_slice());
-    // Extract the one encrypted DbEntry we were looking for (and the
-    // only one we are able to decrypt)
-    let entry_in_block = rand_idx % spiral_blocking_factor;
-    let loc_in_block = entry_in_block * mem::size_of::<DbEntry>();
-    let loc_in_block_end = (entry_in_block + 1) * mem::size_of::<DbEntry>();
-    let encdbentry = DbEntry::from_le_bytes(
-        encdbblock[loc_in_block..loc_in_block_end]
-            .try_into()
-            .unwrap(),
-    );
-    let decdbentry = otkey_decrypt(&otkey, rand_idx, encdbentry);
-    let spc_recv_us = spc_recv_start.elapsed().as_micros();
-    println!("Client decode response {} µs", spc_recv_us);
-    println!("index = {}, Response = {}", q, decdbentry);
+    /*
+        let spiral_params = params::get_spiral_params(r);
+        let mut rng = rand::thread_rng();
+        print_params_summary(&spiral_params);
+        println!("OT one-time setup: {} µs", otsetup_us);
+
+        // One-time setup for the Spiral client
+        let spc_otsetup_start = Instant::now();
+        let mut clientrng = rand::thread_rng();
+        let mut client = Client::init(&spiral_params, &mut clientrng);
+        let pub_params = client.generate_keys();
+        let pub_params_buf = pub_params.serialize();
+        let spc_otsetup_us = spc_otsetup_start.elapsed().as_micros();
+        let spiral_blocking_factor = spiral_params.db_item_size / mem::size_of::<DbEntry>();
+        println!(
+            "Spiral client one-time setup: {} µs, {} bytes",
+            spc_otsetup_us,
+            pub_params_buf.len()
+        );
+
+        println!("\n===== PREPROCESSING =====\n");
+
+        // Spiral preprocessing: create a PIR lookup for an element at a
+        // random location
+        let spc_query_start = Instant::now();
+        let rand_idx = (rng.next_u64() as usize) % num_records;
+        let rand_pir_idx = rand_idx / spiral_blocking_factor;
+        println!("rand_idx = {} rand_pir_idx = {}", rand_idx, rand_pir_idx);
+        let spc_query = client.generate_query(rand_pir_idx);
+        let spc_query_buf = spc_query.serialize();
+        let spc_query_us = spc_query_start.elapsed().as_micros();
+        println!(
+            "Spiral query: {} µs, {} bytes",
+            spc_query_us,
+            spc_query_buf.len()
+        );
+
+        // Create the database encryption keys and do the OT to fetch the
+        // right one, but don't actually encrypt the database yet
+        let dbkeys = gen_db_enc_keys(r);
+        let otkeyreq_start = Instant::now();
+        let (keystate, keyquery) = otkey_request(rand_idx, r);
+        let keyquerysize = keyquery.len() * keyquery[0].len();
+        let otkeyreq_us = otkeyreq_start.elapsed().as_micros();
+        let otkeysrv_start = Instant::now();
+        let keyresponse = otkey_serve(keyquery, &dbkeys);
+        let keyrespsize = keyresponse.len() * keyresponse[0].len();
+        let otkeysrv_us = otkeysrv_start.elapsed().as_micros();
+        let otkeyrcv_start = Instant::now();
+        let otkey = otkey_receive(keystate, &keyresponse);
+        let otkeyrcv_us = otkeyrcv_start.elapsed().as_micros();
+        println!("key OT query in {} µs, {} bytes", otkeyreq_us, keyquerysize);
+        println!("key OT serve in {} µs, {} bytes", otkeysrv_us, keyrespsize);
+        println!("key OT receive in {} µs", otkeyrcv_us);
+
+        // Create a database with recognizable contents
+        let db: Vec<DbEntry> = ((0 as DbEntry)..(num_records as DbEntry))
+            .map(|x| 10000001 * x)
+            .collect();
+
+        println!("\n===== RUNTIME =====\n");
+
+        // Pick the record we actually want to query
+        let q = (rng.next_u64() as usize) % num_records;
+
+        // Compute the offset from the record index we're actually looking
+        // for to the random one we picked earlier.  Tell it to the server,
+        // who will rotate right the database by that amount before
+        // encrypting it.
+        let idx_offset = (num_records + rand_idx - q) % num_records;
+
+        println!("Send to server {} bytes", 8 /* sizeof(idx_offset) */);
+
+        // The server rotates, blinds, and encrypts the database
+        let blind: DbEntry = 20;
+        let encdb_start = Instant::now();
+        let encdb = encdb_xor_keys(&db, &dbkeys, r, idx_offset, blind, num_threads);
+        let encdb_us = encdb_start.elapsed().as_micros();
+        println!("Server encrypt database {} µs", encdb_us);
+
+        // Load the encrypted database into Spiral
+        let sps_loaddb_start = Instant::now();
+        let sps_db = load_db_from_slice_mt(&spiral_params, &encdb, num_threads);
+        let sps_loaddb_us = sps_loaddb_start.elapsed().as_micros();
+        println!("Server load database {} µs", sps_loaddb_us);
+
+        // Do the PIR query
+        let sps_query_start = Instant::now();
+        let sps_query = Query::deserialize(&spiral_params, &spc_query_buf);
+        let sps_response = process_query(&spiral_params, &pub_params, &sps_query, sps_db.as_slice());
+        let sps_query_us = sps_query_start.elapsed().as_micros();
+        println!(
+            "Server compute response {} µs, {} bytes (*including* the above expansion time)",
+            sps_query_us,
+            sps_response.len()
+        );
+
+        // Decode the response to yield the whole Spiral block
+        let spc_recv_start = Instant::now();
+        let encdbblock = client.decode_response(sps_response.as_slice());
+        // Extract the one encrypted DbEntry we were looking for (and the
+        // only one we are able to decrypt)
+        let entry_in_block = rand_idx % spiral_blocking_factor;
+        let loc_in_block = entry_in_block * mem::size_of::<DbEntry>();
+        let loc_in_block_end = (entry_in_block + 1) * mem::size_of::<DbEntry>();
+        let encdbentry = DbEntry::from_le_bytes(
+            encdbblock[loc_in_block..loc_in_block_end]
+                .try_into()
+                .unwrap(),
+        );
+        let decdbentry = otkey_decrypt(&otkey, rand_idx, encdbentry);
+        let spc_recv_us = spc_recv_start.elapsed().as_micros();
+        println!("Client decode response {} µs", spc_recv_us);
+        println!("index = {}, Response = {}", q, decdbentry);
+    */
 }

+ 150 - 0
src/ot.rs

@@ -0,0 +1,150 @@
+// We really want points to be capital letters and scalars to be
+// lowercase letters
+#![allow(non_snake_case)]
+
+// Oblivious transfer
+
+use subtle::Choice;
+use subtle::ConditionallySelectable;
+
+use aes::Block;
+
+use rand::RngCore;
+
+use sha2::Digest;
+use sha2::Sha256;
+use sha2::Sha512;
+
+use curve25519_dalek::constants as dalek_constants;
+use curve25519_dalek::ristretto::CompressedRistretto;
+use curve25519_dalek::ristretto::RistrettoBasepointTable;
+use curve25519_dalek::ristretto::RistrettoPoint;
+use curve25519_dalek::scalar::Scalar;
+
+use lazy_static::lazy_static;
+
+// Generators of the Ristretto group (the standard B and another one C,
+// for which the DL relationship is unknown), and their precomputed
+// multiplication tables.  Used for the Oblivious Transfer protocol
+lazy_static! {
+    pub static ref OT_B: RistrettoPoint = dalek_constants::RISTRETTO_BASEPOINT_POINT;
+    pub static ref OT_C: RistrettoPoint =
+        RistrettoPoint::hash_from_bytes::<Sha512>(b"OT Generator C");
+    pub static ref OT_B_TABLE: RistrettoBasepointTable = dalek_constants::RISTRETTO_BASEPOINT_TABLE;
+    pub static ref OT_C_TABLE: RistrettoBasepointTable = RistrettoBasepointTable::create(&OT_C);
+}
+
+// 1-out-of-2 Oblivious Transfer (OT)
+
+fn ot12_request(sel: Choice) -> ((Choice, Scalar), [u8; 32]) {
+    let Btable: &RistrettoBasepointTable = &OT_B_TABLE;
+    let C: &RistrettoPoint = &OT_C;
+    let mut rng = rand07::thread_rng();
+    let x = Scalar::random(&mut rng);
+    let xB = &x * Btable;
+    let CmxB = C - xB;
+    let P = RistrettoPoint::conditional_select(&xB, &CmxB, sel);
+    ((sel, x), P.compress().to_bytes())
+}
+
+fn ot12_serve(query: &[u8; 32], m0: &[u8; 16], m1: &[u8; 16]) -> [u8; 64] {
+    let Btable: &RistrettoBasepointTable = &OT_B_TABLE;
+    let Ctable: &RistrettoBasepointTable = &OT_C_TABLE;
+    let mut rng = rand07::thread_rng();
+    let y = Scalar::random(&mut rng);
+    let yB = &y * Btable;
+    let yC = &y * Ctable;
+    let P = CompressedRistretto::from_slice(query).decompress().unwrap();
+    let yP0 = y * P;
+    let yP1 = yC - yP0;
+    let mut HyP0 = Sha256::digest(yP0.compress().as_bytes());
+    for i in 0..16 {
+        HyP0[i] ^= m0[i];
+    }
+    let mut HyP1 = Sha256::digest(yP1.compress().as_bytes());
+    for i in 0..16 {
+        HyP1[i] ^= m1[i];
+    }
+    let mut ret = [0u8; 64];
+    ret[0..32].copy_from_slice(yB.compress().as_bytes());
+    ret[32..48].copy_from_slice(&HyP0[0..16]);
+    ret[48..64].copy_from_slice(&HyP1[0..16]);
+    ret
+}
+
+fn ot12_receive(state: (Choice, Scalar), response: &[u8; 64]) -> [u8; 16] {
+    let yB = CompressedRistretto::from_slice(&response[0..32])
+        .decompress()
+        .unwrap();
+    let yP = state.1 * yB;
+    let mut HyP = Sha256::digest(yP.compress().as_bytes());
+    for i in 0..16 {
+        HyP[i] ^= u8::conditional_select(&response[32 + i], &response[48 + i], state.0);
+    }
+    HyP[0..16].try_into().unwrap()
+}
+
+// Obliviously fetch the key for element q of the database (which has
+// 2^r elements total).  Each bit of q is used in a 1-out-of-2 OT to get
+// one of the keys in each of the r pairs of keys on the server side.
+// The resulting r keys are XORed together.
+
+pub fn otkey_init() {
+    // Resolve the lazy statics
+    let _B: &RistrettoPoint = &OT_B;
+    let _Btable: &RistrettoBasepointTable = &OT_B_TABLE;
+    let _C: &RistrettoPoint = &OT_C;
+    let _Ctable: &RistrettoBasepointTable = &OT_C_TABLE;
+}
+
+pub fn otkey_request(q: usize, r: usize) -> (Vec<(Choice, Scalar)>, Vec<[u8; 32]>) {
+    let mut state: Vec<(Choice, Scalar)> = Vec::with_capacity(r);
+    let mut query: Vec<[u8; 32]> = Vec::with_capacity(r);
+    for i in 0..r {
+        let bit = ((q >> i) & 1) as u8;
+        let (si, qi) = ot12_request(bit.into());
+        state.push(si);
+        query.push(qi);
+    }
+    (state, query)
+}
+
+pub fn otkey_serve(query: Vec<[u8; 32]>, keys: &Vec<[u8; 16]>) -> Vec<[u8; 64]> {
+    let r = query.len();
+    assert!(keys.len() == 2 * r);
+    let mut response: Vec<[u8; 64]> = Vec::with_capacity(r);
+    for i in 0..r {
+        response.push(ot12_serve(&query[i], &keys[2 * i], &keys[2 * i + 1]));
+    }
+    response
+}
+
+// XOR a 16-byte slice into a Block (which will be used as an AES key)
+pub fn xor16(outar: &mut Block, inar: &[u8; 16]) {
+    for i in 0..16 {
+        outar[i] ^= inar[i];
+    }
+}
+
+pub fn otkey_receive(state: Vec<(Choice, Scalar)>, response: &Vec<[u8; 64]>) -> Block {
+    let r = state.len();
+    assert!(response.len() == r);
+    let mut key = Block::from([0u8; 16]);
+    for i in 0..r {
+        xor16(&mut key, &ot12_receive(state[i], &response[i]));
+    }
+    key
+}
+
+// Generate the keys for encrypting the database
+pub fn gen_db_enc_keys(r: usize) -> Vec<[u8; 16]> {
+    let mut keys: Vec<[u8; 16]> = Vec::new();
+
+    let mut rng = rand::thread_rng();
+    for _ in 0..2 * r {
+        let mut k: [u8; 16] = [0; 16];
+        rng.fill_bytes(&mut k);
+        keys.push(k);
+    }
+    keys
+}

+ 222 - 0
src/server.rs

@@ -0,0 +1,222 @@
+use std::collections::VecDeque;
+use std::os::raw::c_uchar;
+use std::sync::mpsc::*;
+use std::thread::*;
+
+use rayon::prelude::*;
+
+use spiral_rs::client::PublicParameters;
+use spiral_rs::client::Query;
+use spiral_rs::server::process_query;
+
+use crate::db_encrypt;
+use crate::load_db_from_slice_mt;
+use crate::ot::*;
+use crate::params;
+use crate::to_vecdata;
+use crate::DbEntry;
+use crate::PreProcSingleMsg;
+use crate::PreProcSingleRespMsg;
+use crate::VecData;
+
+enum Command {
+    PreProcMsg(Vec<PreProcSingleMsg>),
+    QueryMsg(usize, usize, usize, DbEntry),
+}
+
+enum Response {
+    PreProcResp(Vec<u8>),
+    QueryResp(Vec<u8>),
+}
+
+// The internal client state for a single preprocess query
+struct PreProcSingleState<'a> {
+    db_keys: Vec<[u8; 16]>,
+    query: Query<'a>,
+}
+
+pub struct Server {
+    incoming_cmd: SyncSender<Command>,
+    outgoing_resp: Receiver<Response>,
+}
+
+impl Server {
+    pub fn new(r: usize, pub_params: Vec<u8>) -> Self {
+        let (incoming_cmd, incoming_cmd_recv) = sync_channel(0);
+        let (outgoing_resp_send, outgoing_resp) = sync_channel(0);
+        spawn(move || {
+            let spiral_params = params::get_spiral_params(r);
+            let pub_params = PublicParameters::deserialize(&spiral_params, &pub_params);
+            let num_records = 1 << r;
+            let num_records_mask = num_records - 1;
+
+            // State for preprocessing queries
+            let mut preproc_state: VecDeque<PreProcSingleState> = VecDeque::new();
+
+            // Wait for commands
+            loop {
+                match incoming_cmd_recv.recv() {
+                    Err(_) => break,
+                    Ok(Command::PreProcMsg(cliquery)) => {
+                        let num_preproc = cliquery.len();
+                        let mut resp_state: Vec<PreProcSingleState> =
+                            Vec::with_capacity(num_preproc);
+                        let mut resp_msg: Vec<PreProcSingleRespMsg> =
+                            Vec::with_capacity(num_preproc);
+                        cliquery
+                            .into_par_iter()
+                            .map(|q| {
+                                let db_keys = gen_db_enc_keys(r);
+                                let query = Query::deserialize(&spiral_params, &q.spc_query);
+                                let ot_resp = otkey_serve(q.ot_query, &db_keys);
+                                (
+                                    PreProcSingleState { db_keys, query },
+                                    PreProcSingleRespMsg { ot_resp },
+                                )
+                            })
+                            .unzip_into_vecs(&mut resp_state, &mut resp_msg);
+                        preproc_state.append(&mut VecDeque::from(resp_state));
+                        let ret: Vec<u8> = bincode::serialize(&resp_msg).unwrap();
+                        outgoing_resp_send.send(Response::PreProcResp(ret)).unwrap();
+                    }
+                    Ok(Command::QueryMsg(offset, db, rot, blind)) => {
+                        // Panic if there's no preprocess state
+                        // available
+                        let nextstate = preproc_state.pop_front().unwrap();
+                        // Encrypt the database with the keys, rotating
+                        // and blinding in the process.  It is safe to
+                        // construct a slice out of the const pointer we
+                        // were handed because that pointer will stay
+                        // valid until we return something back to the
+                        // caller.
+                        let totoffset = (offset + rot) & num_records_mask;
+                        let num_threads = rayon::current_num_threads();
+                        let encdb = unsafe {
+                            let dbslice =
+                                std::slice::from_raw_parts(db as *const DbEntry, num_records);
+                            db_encrypt(
+                                dbslice,
+                                &nextstate.db_keys,
+                                r,
+                                totoffset,
+                                blind,
+                                num_threads,
+                            )
+                        };
+                        // Load the encrypted db into Spiral
+                        let sps_db = load_db_from_slice_mt(&spiral_params, &encdb, num_threads);
+                        // Process the query
+                        let resp = process_query(
+                            &spiral_params,
+                            &pub_params,
+                            &nextstate.query,
+                            sps_db.as_slice(),
+                        );
+                        outgoing_resp_send.send(Response::QueryResp(resp)).unwrap();
+                    }
+                    // When adding new messages, the following line is
+                    // useful during development
+                    // _ => panic!("Received something unexpected in server loop"),
+                }
+            }
+        });
+        Server {
+            incoming_cmd,
+            outgoing_resp,
+        }
+    }
+
+    pub fn preproc_process(&self, msg: &[u8]) -> Vec<u8> {
+        self.incoming_cmd
+            .send(Command::PreProcMsg(bincode::deserialize(msg).unwrap()))
+            .unwrap();
+        let ret = match self.outgoing_resp.recv() {
+            Ok(Response::PreProcResp(x)) => x,
+            _ => panic!("Received something unexpected in preproc_process"),
+        };
+        ret
+    }
+
+    pub fn query_process(
+        &self,
+        msg: &[u8],
+        db: *const DbEntry,
+        rot: usize,
+        blind: DbEntry,
+    ) -> Vec<u8> {
+        let offset = usize::from_le_bytes(msg.try_into().unwrap());
+        self.incoming_cmd
+            .send(Command::QueryMsg(offset, db as usize, rot, blind))
+            .unwrap();
+        let ret = match self.outgoing_resp.recv() {
+            Ok(Response::QueryResp(x)) => x,
+            _ => panic!("Received something unexpected in query_process"),
+        };
+        ret
+    }
+}
+
+#[no_mangle]
+pub extern "C" fn spir_server_new(
+    r: u8,
+    pub_params: *const c_uchar,
+    pub_params_len: usize,
+) -> *mut Server {
+    let pub_params_slice = unsafe {
+        assert!(!pub_params.is_null());
+        std::slice::from_raw_parts(pub_params, pub_params_len)
+    };
+    let mut pub_params_vec: Vec<u8> = Vec::new();
+    pub_params_vec.extend_from_slice(pub_params_slice);
+    let server = Server::new(r as usize, pub_params_vec);
+    Box::into_raw(Box::new(server))
+}
+
+#[no_mangle]
+pub extern "C" fn spir_server_free(server: *mut Server) {
+    if server.is_null() {
+        return;
+    }
+    unsafe {
+        Box::from_raw(server);
+    }
+}
+
+#[no_mangle]
+pub extern "C" fn spir_server_preproc_process(
+    serverptr: *mut Server,
+    msgdata: *const c_uchar,
+    msglen: usize,
+) -> VecData {
+    let server = unsafe {
+        assert!(!serverptr.is_null());
+        &mut *serverptr
+    };
+    let msg_slice = unsafe {
+        assert!(!msgdata.is_null());
+        std::slice::from_raw_parts(msgdata, msglen)
+    };
+    let retvec = server.preproc_process(msg_slice);
+    to_vecdata(retvec)
+}
+
+#[no_mangle]
+pub extern "C" fn spir_server_query_process(
+    serverptr: *mut Server,
+    msgdata: *const c_uchar,
+    msglen: usize,
+    db: *const DbEntry,
+    rot: usize,
+    blind: DbEntry,
+) -> VecData {
+    let server = unsafe {
+        assert!(!serverptr.is_null());
+        &mut *serverptr
+    };
+    let msg_slice = unsafe {
+        assert!(!msgdata.is_null());
+        std::slice::from_raw_parts(msgdata, msglen)
+    };
+    let retvec = server.query_process(msg_slice, db, rot, blind);
+    to_vecdata(retvec)
+}

+ 2 - 2
src/spiral_mt.rs

@@ -82,7 +82,7 @@ pub fn load_db_from_slice_mt(
                             let db_idx = j * num_per + ii;
 
                             let mut db_item =
-                                load_item_from_slice(&params, slice, instance, trial, db_idx);
+                                load_item_from_slice(params, slice, instance, trial, db_idx);
                             // db_item.reduce_mod(params.pt_modulus);
 
                             for z in 0..params.poly_len {
@@ -101,7 +101,7 @@ pub fn load_db_from_slice_mt(
                                 );
 
                                 unsafe {
-                                    vptr.offset(idx_dst as isize).write(
+                                    vptr.add(idx_dst).write(
                                         db_item_ntt.data[z]
                                             | (db_item_ntt.data[params.poly_len + z]
                                                 << PACKED_OFFSET_2),