|
@@ -2,13 +2,14 @@
|
|
|
// lowercase letters
|
|
|
#![allow(non_snake_case)]
|
|
|
|
|
|
+pub mod aligned_memory_mt;
|
|
|
pub mod params;
|
|
|
+pub mod spiral_mt;
|
|
|
|
|
|
use aes::cipher::{BlockEncrypt, KeyInit};
|
|
|
use aes::Aes128Enc;
|
|
|
use aes::Block;
|
|
|
use std::env;
|
|
|
-use std::io::Cursor;
|
|
|
use std::mem;
|
|
|
use std::time::Instant;
|
|
|
use subtle::Choice;
|
|
@@ -26,10 +27,15 @@ use curve25519_dalek::ristretto::RistrettoBasepointTable;
|
|
|
use curve25519_dalek::ristretto::RistrettoPoint;
|
|
|
use curve25519_dalek::scalar::Scalar;
|
|
|
|
|
|
+use rayon::scope;
|
|
|
+use rayon::ThreadPoolBuilder;
|
|
|
+
|
|
|
use spiral_rs::client::*;
|
|
|
use spiral_rs::params::*;
|
|
|
use spiral_rs::server::*;
|
|
|
|
|
|
+use crate::spiral_mt::*;
|
|
|
+
|
|
|
use lazy_static::lazy_static;
|
|
|
|
|
|
type DbEntry = u64;
|
|
@@ -60,23 +66,50 @@ fn xor16(outar: &mut Block, inar: &[u8; 16]) {
|
|
|
// as the XOR of r of the provided keys, one from each pair, according
|
|
|
// to the bits of the element number. Outputs a byte vector containing
|
|
|
// the encrypted database.
|
|
|
-fn encdb_xor_keys(db: &[DbEntry], keys: &[[u8; 16]], r: usize, blind: DbEntry) -> Vec<u8> {
|
|
|
+fn encdb_xor_keys(
|
|
|
+ db: &[DbEntry],
|
|
|
+ keys: &[[u8; 16]],
|
|
|
+ r: usize,
|
|
|
+ blind: DbEntry,
|
|
|
+ num_threads: usize,
|
|
|
+) -> Vec<u8> {
|
|
|
let num_records: usize = 1 << r;
|
|
|
let mut ret = Vec::<u8>::with_capacity(num_records * mem::size_of::<DbEntry>());
|
|
|
- for j in 0..num_records {
|
|
|
- let mut key = Block::from([0u8; 16]);
|
|
|
- for i in 0..r {
|
|
|
- let bit = if (j & (1 << i)) == 0 { 0 } else { 1 };
|
|
|
- xor16(&mut key, &keys[2 * i + bit]);
|
|
|
+ ret.resize(num_records * mem::size_of::<DbEntry>(), 0);
|
|
|
+ scope(|s| {
|
|
|
+ let mut record_thread_start = 0usize;
|
|
|
+ let records_per_thread_base = num_records / num_threads;
|
|
|
+ let records_per_thread_extra = num_records % num_threads;
|
|
|
+ let mut retslice = ret.as_mut_slice();
|
|
|
+ for thr in 0..num_threads {
|
|
|
+ let records_this_thread =
|
|
|
+ records_per_thread_base + if thr < records_per_thread_extra { 1 } else { 0 };
|
|
|
+ let record_thread_end = record_thread_start + records_this_thread;
|
|
|
+ let (thread_ret, retslice_) =
|
|
|
+ retslice.split_at_mut(records_this_thread * mem::size_of::<DbEntry>());
|
|
|
+ retslice = retslice_;
|
|
|
+ s.spawn(move |_| {
|
|
|
+ let mut offset = 0usize;
|
|
|
+ for j in record_thread_start..record_thread_end {
|
|
|
+ let mut key = Block::from([0u8; 16]);
|
|
|
+ for i in 0..r {
|
|
|
+ let bit = if (j & (1 << i)) == 0 { 0 } else { 1 };
|
|
|
+ xor16(&mut key, &keys[2 * i + bit]);
|
|
|
+ }
|
|
|
+ let aes = Aes128Enc::new(&key);
|
|
|
+ let mut block = Block::from([0u8; 16]);
|
|
|
+ block[0..8].copy_from_slice(&j.to_le_bytes());
|
|
|
+ aes.encrypt_block(&mut block);
|
|
|
+ let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
|
|
|
+ let encelem = (db[j].wrapping_add(blind)) ^ aeskeystream;
|
|
|
+ thread_ret[offset..offset + mem::size_of::<DbEntry>()]
|
|
|
+ .copy_from_slice(&encelem.to_le_bytes());
|
|
|
+ offset += mem::size_of::<DbEntry>();
|
|
|
+ }
|
|
|
+ });
|
|
|
+ record_thread_start = record_thread_end;
|
|
|
}
|
|
|
- let aes = Aes128Enc::new(&key);
|
|
|
- let mut block = Block::from([0u8; 16]);
|
|
|
- block[0..8].copy_from_slice(&j.to_le_bytes());
|
|
|
- aes.encrypt_block(&mut block);
|
|
|
- let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
|
|
|
- let encelem = (db[j].wrapping_add(blind)) ^ aeskeystream;
|
|
|
- ret.extend(encelem.to_le_bytes());
|
|
|
- }
|
|
|
+ });
|
|
|
ret
|
|
|
}
|
|
|
|
|
@@ -214,11 +247,15 @@ fn print_params_summary(params: &Params) {
|
|
|
|
|
|
fn main() {
|
|
|
let args: Vec<String> = env::args().collect();
|
|
|
- if args.len() != 2 {
|
|
|
- println!("Usage: {} r\nr = log_2(num_records)", args[0]);
|
|
|
+ if args.len() != 2 && args.len() != 3 {
|
|
|
+ println!("Usage: {} r [num_threads]\nr = log_2(num_records)", args[0]);
|
|
|
return;
|
|
|
}
|
|
|
let r: usize = args[1].parse().unwrap();
|
|
|
+ let mut num_threads = 1usize;
|
|
|
+ if args.len() == 3 {
|
|
|
+ num_threads = args[2].parse().unwrap();
|
|
|
+ }
|
|
|
let num_records = 1 << r;
|
|
|
|
|
|
println!("===== ONE-TIME SETUP =====\n");
|
|
@@ -226,6 +263,7 @@ fn main() {
|
|
|
let otsetup_start = Instant::now();
|
|
|
let spiral_params = params::get_spiral_params(r);
|
|
|
let mut rng = rand::thread_rng();
|
|
|
+ ThreadPoolBuilder::new().num_threads(num_threads).build_global().unwrap();
|
|
|
one_time_setup();
|
|
|
let otsetup_us = otsetup_start.elapsed().as_micros();
|
|
|
print_params_summary(&spiral_params);
|
|
@@ -302,13 +340,13 @@ fn main() {
|
|
|
let blind: DbEntry = 20;
|
|
|
let encdb_start = Instant::now();
|
|
|
db.rotate_right(idx_offset);
|
|
|
- let encdb = encdb_xor_keys(&db, &dbkeys, r, blind);
|
|
|
+ let encdb = encdb_xor_keys(&db, &dbkeys, r, blind, num_threads);
|
|
|
let encdb_us = encdb_start.elapsed().as_micros();
|
|
|
println!("Server encrypt database {} µs", encdb_us);
|
|
|
|
|
|
// Load the encrypted database into Spiral
|
|
|
let sps_loaddb_start = Instant::now();
|
|
|
- let sps_db = load_db_from_seek(&spiral_params, &mut Cursor::new(encdb));
|
|
|
+ let sps_db = load_db_from_slice_mt(&spiral_params, &encdb, num_threads);
|
|
|
let sps_loaddb_us = sps_loaddb_start.elapsed().as_micros();
|
|
|
println!("Server load database {} µs", sps_loaddb_us);
|
|
|
|