6 Commits 62601eb0e3 ... 4dfd2d9390

Author SHA1 Message Date
  Ian Goldberg 4dfd2d9390 Update the README 1 year ago
  Ian Goldberg 4bf304da4f Switch from crossbeam to rayon 1 year ago
  Ian Goldberg 4126f187ef Multithreaded db encryption 1 year ago
  Ian Goldberg 101d09589e Swap the halves of the item index in load_db_from_slice_mt so that the polynomials based on the items are written to the AlignedMemoryMT64 more sequentially 1 year ago
  Ian Goldberg d870e1c058 Multithreaded database loading 1 year ago
  Ian Goldberg f7d917ca8e Work towards multithreading the loading of the db 1 year ago
5 changed files with 317 additions and 36 deletions
  1. 1 0
      Cargo.toml
  2. 21 17
      README.md
  3. 119 0
      src/aligned_memory_mt.rs
  4. 57 19
      src/main.rs
  5. 119 0
      src/spiral_mt.rs

+ 1 - 0
Cargo.toml

@@ -17,6 +17,7 @@ lazy_static = "1"
 sha2 = "0.9"
 subtle = { package = "subtle-ng", version = "2.4" }
 spiral-rs = { git = "https://github.com/menonsamir/spiral-rs/", rev = "0f9bdc157" }
+rayon = "1.5"
 
 [features]
 default = ["u64_backend"]

+ 21 - 17
README.md

@@ -57,38 +57,42 @@ We slightly optimize the protocol in that instead of the client sending both β0
 
 ## Running the code
 
+To build the code:
+
+`RUSTFLAGS="-C target-cpu=native" cargo build --release`
+
 To run the code:
 
-`cargo run --release 20`
+`./target/release/spiral-spir 20 4`
 
-Where `20` is the value of r (that is, the database will have N=2^20 entries).  Each entry is 8 bytes.  There are three phases of execution: a one-time Spiral public key generation (this only has to be done once, regardless of how many SPIR queries you do), a preprocessing phase per SPIR query (this can be done _before_ knowing the contents of the database on the server side, or the desired index on the client side), and the runtime phase per SPIR query (once those two things are known).
+Where `20` is the value of r (that is, the database will have N=2^20 entries), and `4` is the number of threads to use (this argument is optional and defaults to 1).  Each entry is 8 bytes.  There are three phases of execution: a one-time Spiral public key generation (this only has to be done once, regardless of how many SPIR queries you do), a preprocessing phase per SPIR query (this can be done _before_ knowing the contents of the database on the server side, or the desired index on the client side), and the runtime phase per SPIR query (once those two things are known).
 
-A sample output (for r=20):
+A sample output (for r=20, 4 threads):
 
 ```
 ===== ONE-TIME SETUP =====
 
 Using a 2048 x 4096 byte database (8388608 bytes total)
-OT one-time setup: 3064 µs
-Spiral client one-time setup: 224935 µs, 10878976 bytes
+OT one-time setup: 3637 µs
+Spiral client one-time setup: 157144 µs, 10878976 bytes
 
 ===== PREPROCESSING =====
 
-rand_idx = 425489 rand_pir_idx = 831
-Spiral query: 691 µs, 32768 bytes
-key OT query in 365 µs, 640 bytes
-key OT serve in 1860 µs, 1280 bytes
-key OT receive in 1148 µs
+rand_idx = 146516 rand_pir_idx = 286
+Spiral query: 457 µs, 32768 bytes
+key OT query in 324 µs, 640 bytes
+key OT serve in 1653 µs, 1280 bytes
+key OT receive in 1029 µs
 
 ===== RUNTIME =====
 
 Send to server 8 bytes
-Server encrypt database 89036 µs
-Server load database 974249 µs
-expansion (took 166596 us).
-Server compute response 344613 µs, 14336 bytes (*including* the above expansion time)
-Client decode response 888 µs
-index = 948810, Response = 9488100948830
+Server encrypt database 29738 µs
+Server load database 248825 µs
+expansion (took 101920 us).
+Server compute response 181293 µs, 14336 bytes (*including* the above expansion time)
+Client decode response 790 µs
+index = 919657, Response = 9196570919677
 ```
 
-The various lines show the amount of compute time taken and the amount of data transferred between the client and the server.  The last line shows the random index that was looked up, and the database value the client retrieved.  The value for index i should be (10000001*i+20).
+The various lines show the amount of compute time taken and the amount of data transferred between the client and the server.  The last line shows the random index that was looked up, and the database value the client retrieved.  The value for index i should be (10000001*i+20).

+ 119 - 0
src/aligned_memory_mt.rs

@@ -0,0 +1,119 @@
+/* This file is almost identical to the aligned_memory.rs file in the
+   spiral-rs crate.  The name is modified from AlignedMemory to
+   AlignedMemoryMT, and there is one (unsafe!) change to the API:
+
+    pub unsafe fn as_mut_ptr(&mut self) -> *mut u64
+
+   has changed to:
+
+    pub unsafe fn as_mut_ptr(&self) -> *mut u64
+
+   The reason for this change is explicitly to allow multiple threads to
+   *write* into the memory pool concurrently, with the caveat that the
+   threads *must not* try to write into the same memory location.  In
+   Spiral, each polynomial created from the database ends up scattered
+   into noncontiguous words of memory, but any one word still only comes
+   from one polynomial.  So with this change, different threads can read
+   different parts of the database to produce different polynomials, and
+   write those polynomials into the same memory pool (but *not* the same
+   memory locations) at the same time.
+*/
+
+use std::{
+    alloc::{alloc_zeroed, dealloc, Layout},
+    mem::size_of,
+    ops::{Index, IndexMut},
+    slice::{from_raw_parts, from_raw_parts_mut},
+};
+
+const ALIGN_SIMD: usize = 64; // enough to support AVX-512
+pub type AlignedMemoryMT64 = AlignedMemoryMT<ALIGN_SIMD>; // the instantiation aligned to ALIGN_SIMD (64) bytes
+
+pub struct AlignedMemoryMT<const ALIGN: usize> { // heap buffer of u64 words whose layout is built with alignment ALIGN
+    p: *mut u64, // start of the allocation obtained in `new`
+    sz_u64: usize, // buffer length, counted in u64 words (not bytes)
+    layout: Layout, // layout used for the allocation; handed back to `dealloc` on drop
+}
+
+impl<const ALIGN: usize> AlignedMemoryMT<{ ALIGN }> {
+    /// Allocates a zero-initialized buffer of `sz_u64` 64-bit words whose
+    /// start address is aligned to at least `ALIGN` bytes.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the size in bytes overflows `usize` or if the resulting
+    /// size/alignment pair is not a valid `Layout`.  If the allocator
+    /// reports failure, `std::alloc::handle_alloc_error` is invoked.
+    pub fn new(sz_u64: usize) -> Self {
+        // A wrapping multiplication here would silently under-allocate and
+        // make `as_slice` hand out memory we do not own.
+        let sz_bytes = sz_u64
+            .checked_mul(size_of::<u64>())
+            .expect("AlignedMemoryMT: size in bytes overflows usize");
+
+        // Never give the allocator a zero-sized layout (undefined behaviour
+        // for `alloc_zeroed`), and never request less than the natural
+        // alignment of a `u64`.  The adjusted layout is what gets stored,
+        // so `dealloc` later sees exactly what was allocated.
+        let layout = Layout::from_size_align(
+            sz_bytes.max(size_of::<u64>()),
+            ALIGN.max(std::mem::align_of::<u64>()),
+        )
+        .expect("AlignedMemoryMT: invalid size/alignment");
+
+        // SAFETY: `layout` has a non-zero size (see above).
+        let ptr = unsafe { alloc_zeroed(layout) };
+        if ptr.is_null() {
+            // Building slices from a null pointer would be undefined
+            // behaviour; report the allocation failure instead.
+            std::alloc::handle_alloc_error(layout);
+        }
+
+        Self {
+            p: ptr as *mut u64,
+            sz_u64,
+            layout,
+        }
+    }
+
+    // pub fn from(data: &[u8]) -> Self {
+    //     let sz_u64 = (data.len() + size_of::<u64>() - 1) / size_of::<u64>();
+    //     let mut out = Self::new(sz_u64);
+    //     let out_slice = out.as_mut_slice();
+    //     let mut i = 0;
+    //     for chunk in data.chunks(size_of::<u64>()) {
+    //         out_slice[i] = u64::from_ne_bytes(chunk);
+    //         i += 1;
+    //     }
+    //     out
+    // }
+
+    /// Views the whole buffer as a shared slice of `len()` words.
+    pub fn as_slice(&self) -> &[u64] {
+        // SAFETY: `p` is non-null, suitably aligned for `u64`, and points
+        // to an allocation of at least `sz_u64` zero-initialized words
+        // (established in `new`).
+        unsafe { from_raw_parts(self.p, self.sz_u64) }
+    }
+
+    /// Views the whole buffer as an exclusive slice of `len()` words.
+    pub fn as_mut_slice(&mut self) -> &mut [u64] {
+        // SAFETY: same allocation invariants as `as_slice`; `&mut self`
+        // rules out other borrows obtained through the safe API.
+        unsafe { from_raw_parts_mut(self.p, self.sz_u64) }
+    }
+
+    /// Returns a read-only raw pointer to the first word.
+    ///
+    /// # Safety
+    ///
+    /// The pointer must not be used after `self` is dropped, and reads
+    /// through it must stay within `len()` words.
+    pub unsafe fn as_ptr(&self) -> *const u64 {
+        self.p
+    }
+
+    /// Returns a writable raw pointer to the first word from a *shared*
+    /// reference, so that several threads can fill the buffer at once.
+    ///
+    /// # Safety
+    ///
+    /// The pointer must not be used after `self` is dropped, accesses must
+    /// stay within `len()` words, and no word may be written by one thread
+    /// while any other thread reads or writes that same word.
+    pub unsafe fn as_mut_ptr(&self) -> *mut u64 {
+        self.p
+    }
+
+    /// Number of `u64` words in the buffer.
+    pub fn len(&self) -> usize {
+        self.sz_u64
+    }
+}
+
+unsafe impl<const ALIGN: usize> Send for AlignedMemoryMT<{ ALIGN }> {} // opt in by hand: the `*mut u64` field suppresses the automatic impl
+unsafe impl<const ALIGN: usize> Sync for AlignedMemoryMT<{ ALIGN }> {} // relies on the file-header rule: concurrent writers via `as_mut_ptr` must target distinct words
+
+impl<const ALIGN: usize> Drop for AlignedMemoryMT<{ ALIGN }> {
+    /// Hands the buffer back to the global allocator.
+    fn drop(&mut self) {
+        let raw = self.p as *mut u8;
+        // SAFETY: `p` and `layout` are the pointer/layout pair recorded
+        // when the buffer was allocated in `new`.
+        unsafe { dealloc(raw, self.layout) }
+    }
+}
+
+impl<const ALIGN: usize> Index<usize> for AlignedMemoryMT<{ ALIGN }> {
+    type Output = u64;
+
+    /// Borrows the word at position `index`; panics when `index` is out
+    /// of range, exactly like slice indexing.
+    fn index(&self, index: usize) -> &u64 {
+        let words = self.as_slice();
+        &words[index]
+    }
+}
+
+impl<const ALIGN: usize> IndexMut<usize> for AlignedMemoryMT<{ ALIGN }> {
+    /// Mutably borrows the word at position `index`; panics when `index`
+    /// is out of range, exactly like slice indexing.
+    fn index_mut(&mut self, index: usize) -> &mut u64 {
+        let words = self.as_mut_slice();
+        &mut words[index]
+    }
+}
+
+impl<const ALIGN: usize> Clone for AlignedMemoryMT<{ ALIGN }> {
+    /// Deep copy: allocates a second buffer of the same length and copies
+    /// every word into it.
+    fn clone(&self) -> Self {
+        let mut duplicate = Self::new(self.len());
+        duplicate.as_mut_slice().copy_from_slice(self.as_slice());
+        duplicate
+    }
+}

+ 57 - 19
src/main.rs

@@ -2,13 +2,14 @@
 // lowercase letters
 #![allow(non_snake_case)]
 
+pub mod aligned_memory_mt;
 pub mod params;
+pub mod spiral_mt;
 
 use aes::cipher::{BlockEncrypt, KeyInit};
 use aes::Aes128Enc;
 use aes::Block;
 use std::env;
-use std::io::Cursor;
 use std::mem;
 use std::time::Instant;
 use subtle::Choice;
@@ -26,10 +27,15 @@ use curve25519_dalek::ristretto::RistrettoBasepointTable;
 use curve25519_dalek::ristretto::RistrettoPoint;
 use curve25519_dalek::scalar::Scalar;
 
+use rayon::scope;
+use rayon::ThreadPoolBuilder;
+
 use spiral_rs::client::*;
 use spiral_rs::params::*;
 use spiral_rs::server::*;
 
+use crate::spiral_mt::*;
+
 use lazy_static::lazy_static;
 
 type DbEntry = u64;
@@ -60,23 +66,50 @@ fn xor16(outar: &mut Block, inar: &[u8; 16]) {
 // as the XOR of r of the provided keys, one from each pair, according
 // to the bits of the element number.  Outputs a byte vector containing
 // the encrypted database.
-fn encdb_xor_keys(db: &[DbEntry], keys: &[[u8; 16]], r: usize, blind: DbEntry) -> Vec<u8> {
+fn encdb_xor_keys(
+    db: &[DbEntry],
+    keys: &[[u8; 16]],
+    r: usize,
+    blind: DbEntry,
+    num_threads: usize,
+) -> Vec<u8> {
     let num_records: usize = 1 << r;
     let mut ret = Vec::<u8>::with_capacity(num_records * mem::size_of::<DbEntry>());
-    for j in 0..num_records {
-        let mut key = Block::from([0u8; 16]);
-        for i in 0..r {
-            let bit = if (j & (1 << i)) == 0 { 0 } else { 1 };
-            xor16(&mut key, &keys[2 * i + bit]);
+    ret.resize(num_records * mem::size_of::<DbEntry>(), 0);
+    scope(|s| {
+        let mut record_thread_start = 0usize;
+        let records_per_thread_base = num_records / num_threads;
+        let records_per_thread_extra = num_records % num_threads;
+        let mut retslice = ret.as_mut_slice();
+        for thr in 0..num_threads {
+            let records_this_thread =
+                records_per_thread_base + if thr < records_per_thread_extra { 1 } else { 0 };
+            let record_thread_end = record_thread_start + records_this_thread;
+            let (thread_ret, retslice_) =
+                retslice.split_at_mut(records_this_thread * mem::size_of::<DbEntry>());
+            retslice = retslice_;
+            s.spawn(move |_| {
+                let mut offset = 0usize;
+                for j in record_thread_start..record_thread_end {
+                    let mut key = Block::from([0u8; 16]);
+                    for i in 0..r {
+                        let bit = if (j & (1 << i)) == 0 { 0 } else { 1 };
+                        xor16(&mut key, &keys[2 * i + bit]);
+                    }
+                    let aes = Aes128Enc::new(&key);
+                    let mut block = Block::from([0u8; 16]);
+                    block[0..8].copy_from_slice(&j.to_le_bytes());
+                    aes.encrypt_block(&mut block);
+                    let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
+                    let encelem = (db[j].wrapping_add(blind)) ^ aeskeystream;
+                    thread_ret[offset..offset + mem::size_of::<DbEntry>()]
+                        .copy_from_slice(&encelem.to_le_bytes());
+                    offset += mem::size_of::<DbEntry>();
+                }
+            });
+            record_thread_start = record_thread_end;
         }
-        let aes = Aes128Enc::new(&key);
-        let mut block = Block::from([0u8; 16]);
-        block[0..8].copy_from_slice(&j.to_le_bytes());
-        aes.encrypt_block(&mut block);
-        let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
-        let encelem = (db[j].wrapping_add(blind)) ^ aeskeystream;
-        ret.extend(encelem.to_le_bytes());
-    }
+    });
     ret
 }
 
@@ -214,11 +247,15 @@ fn print_params_summary(params: &Params) {
 
 fn main() {
     let args: Vec<String> = env::args().collect();
-    if args.len() != 2 {
-        println!("Usage: {} r\nr = log_2(num_records)", args[0]);
+    if args.len() != 2 && args.len() != 3 {
+        println!("Usage: {} r [num_threads]\nr = log_2(num_records)", args[0]);
         return;
     }
     let r: usize = args[1].parse().unwrap();
+    let mut num_threads = 1usize;
+    if args.len() == 3 {
+        num_threads = args[2].parse().unwrap();
+    }
     let num_records = 1 << r;
 
     println!("===== ONE-TIME SETUP =====\n");
@@ -226,6 +263,7 @@ fn main() {
     let otsetup_start = Instant::now();
     let spiral_params = params::get_spiral_params(r);
     let mut rng = rand::thread_rng();
+    ThreadPoolBuilder::new().num_threads(num_threads).build_global().unwrap();
     one_time_setup();
     let otsetup_us = otsetup_start.elapsed().as_micros();
     print_params_summary(&spiral_params);
@@ -302,13 +340,13 @@ fn main() {
     let blind: DbEntry = 20;
     let encdb_start = Instant::now();
     db.rotate_right(idx_offset);
-    let encdb = encdb_xor_keys(&db, &dbkeys, r, blind);
+    let encdb = encdb_xor_keys(&db, &dbkeys, r, blind, num_threads);
     let encdb_us = encdb_start.elapsed().as_micros();
     println!("Server encrypt database {} µs", encdb_us);
 
     // Load the encrypted database into Spiral
     let sps_loaddb_start = Instant::now();
-    let sps_db = load_db_from_seek(&spiral_params, &mut Cursor::new(encdb));
+    let sps_db = load_db_from_slice_mt(&spiral_params, &encdb, num_threads);
     let sps_loaddb_us = sps_loaddb_start.elapsed().as_micros();
     println!("Server load database {} µs", sps_loaddb_us);
 

+ 119 - 0
src/spiral_mt.rs

@@ -0,0 +1,119 @@
+use spiral_rs::arith::*;
+use spiral_rs::params::*;
+use spiral_rs::poly::*;
+use spiral_rs::server::*;
+use spiral_rs::util::*;
+
+use rayon::scope;
+
+use crate::aligned_memory_mt::*;
+
+/// Builds the plaintext polynomial for one chunk of one database item.
+///
+/// Every item occupies `params.db_item_size` bytes of `slice` and is
+/// divided into `params.instances * params.n * params.n` chunks of
+/// `ceil(db_item_size / chunks)` bytes.  The chunk selected by
+/// (`instance`, `trial`) of item `item_idx` is passed to
+/// `read_arbitrary_bits` once per coefficient, with bit offset
+/// `i * logp` and width `logp = ceil(log2(pt_modulus))`, and the results
+/// fill the leading coefficients of a fresh 1x1 `PolyMatrixRaw`.
+///
+/// # Panics
+///
+/// Panics if the chunk needs more than `params.poly_len` coefficients, if
+/// the chunk's byte range lies outside `slice`, or if a value that was
+/// read is greater than `params.pt_modulus`.
+pub fn load_item_from_slice<'a>(
+    params: &'a Params,
+    slice: &[u8],
+    instance: usize,
+    trial: usize,
+    item_idx: usize,
+) -> PolyMatrixRaw<'a> {
+    let trials = params.n * params.n;
+    let num_chunks = params.instances * trials;
+
+    // Sizes are derived with floating-point ceilings, as in spiral-rs.
+    let chunk_bytes = (params.db_item_size as f64 / num_chunks as f64).ceil() as usize;
+    let bits_per_word = (params.pt_modulus as f64).log2().ceil() as usize;
+    let num_words = ((chunk_bytes * 8) as f64 / bits_per_word as f64).ceil() as usize;
+    assert!(num_words <= params.poly_len);
+
+    // Byte range of the requested chunk inside `slice`.
+    let chunk_no = instance * trials + trial;
+    let chunk_start = item_idx * params.db_item_size + chunk_no * chunk_bytes;
+    let chunk_end = chunk_start + chunk_bytes;
+
+    let mut out = PolyMatrixRaw::zero(params, 1, 1);
+    for word in 0..num_words {
+        out.data[word] = read_arbitrary_bits(
+            &slice[chunk_start..chunk_end],
+            word * bits_per_word,
+            bits_per_word,
+        );
+        // NOTE(review): `<=` mirrors the original check; confirm whether a
+        // value equal to the modulus is really meant to be accepted.
+        assert!(out.data[word] <= params.pt_modulus);
+    }
+
+    out
+}
+
+/// Loads the database held in `slice` into a freshly allocated
+/// `AlignedMemoryMT64`, splitting the work over `num_threads` rayon tasks.
+///
+/// For every (`instance`, `trial`) pair the `2^db_dim_1 * 2^db_dim_2`
+/// items are divided into `num_threads` contiguous ranges (the first
+/// `num_items % num_threads` ranges get one extra item).  Each task reads
+/// its items with `load_item_from_slice`, recenters the coefficients from
+/// `pt_modulus` to `modulus`, applies the NTT, and stores one packed word
+/// per coefficient at the position computed by `calc_index`.
+///
+/// A `num_threads` of 0 is treated as 1 rather than dividing by zero.
+///
+/// # Panics
+///
+/// Panics whenever `load_item_from_slice` panics, e.g. when `slice` is
+/// too short for the database shape described by `params`.
+pub fn load_db_from_slice_mt(
+    params: &Params,
+    slice: &[u8],
+    num_threads: usize,
+) -> AlignedMemoryMT64 {
+    let instances = params.instances;
+    let trials = params.n * params.n;
+    let dim0 = 1 << params.db_dim_1;
+    let num_per = 1 << params.db_dim_2;
+    let num_items = dim0 * num_per;
+    let db_size_words = instances * trials * num_items * params.poly_len;
+    let v: AlignedMemoryMT64 = AlignedMemoryMT64::new(db_size_words);
+
+    // Guard against a zero thread count, which would otherwise make the
+    // division and remainder below panic.
+    let num_threads = num_threads.max(1);
+    // The split of items over threads is the same for every
+    // (instance, trial) pair, so compute it once.
+    let items_per_thread_base = num_items / num_threads;
+    let items_per_thread_extra = num_items % num_threads;
+
+    for instance in 0..instances {
+        for trial in 0..trials {
+            scope(|s| {
+                let mut item_thread_start = 0usize;
+                for thr in 0..num_threads {
+                    let items_this_thread =
+                        items_per_thread_base + usize::from(thr < items_per_thread_extra);
+                    let item_thread_end = item_thread_start + items_this_thread;
+                    let v = &v;
+                    s.spawn(move |_| {
+                        // SAFETY: `v` outlives the enclosing `scope`, and
+                        // every write made through this pointer is
+                        // justified at the write site below.
+                        let vptr = unsafe { v.as_mut_ptr() };
+                        for i in item_thread_start..item_thread_end {
+                            // Swap the halves of the item index so that
+                            // the polynomials based on the items are
+                            // written to the AlignedMemoryMT64 more
+                            // sequentially
+                            let ii = i / dim0;
+                            let j = i % dim0;
+                            let db_idx = j * num_per + ii;
+
+                            let mut db_item =
+                                load_item_from_slice(params, slice, instance, trial, db_idx);
+                            // db_item.reduce_mod(params.pt_modulus);
+
+                            for z in 0..params.poly_len {
+                                db_item.data[z] = recenter_mod(
+                                    db_item.data[z],
+                                    params.pt_modulus,
+                                    params.modulus,
+                                );
+                            }
+
+                            let db_item_ntt = db_item.ntt();
+                            for z in 0..params.poly_len {
+                                let idx_dst = calc_index(
+                                    &[instance, trial, z, ii, j],
+                                    &[instances, trials, params.poly_len, num_per, dim0],
+                                );
+                                debug_assert!(idx_dst < v.len());
+
+                                // SAFETY: each index component is below
+                                // its matching length and the lengths
+                                // multiply to `db_size_words`, so the
+                                // destination lies inside `v` (assuming
+                                // `calc_index` flattens the indices
+                                // without collisions -- TODO confirm).
+                                // Threads own disjoint ranges of `i`, so
+                                // under that same assumption no two
+                                // threads write the same word.
+                                unsafe {
+                                    vptr.add(idx_dst).write(
+                                        db_item_ntt.data[z]
+                                            | (db_item_ntt.data[params.poly_len + z]
+                                                << PACKED_OFFSET_2),
+                                    );
+                                }
+                            }
+                        }
+                    });
+                    item_thread_start = item_thread_end;
+                }
+            });
+        }
+    }
+    v
+}