
Work towards multithreading the loading of the db

Ian Goldberg, 1 year ago
parent commit f7d917ca8e
3 changed files with 96 additions and 2 deletions
  1. Cargo.toml  (+1 -0)
  2. src/main.rs  (+4 -2)
  3. src/spiral_mt.rs  (+91 -0)

Cargo.toml  (+1 -0)

@@ -17,6 +17,7 @@ lazy_static = "1"
 sha2 = "0.9"
 subtle = { package = "subtle-ng", version = "2.4" }
 spiral-rs = { git = "https://github.com/menonsamir/spiral-rs/", rev = "0f9bdc157" }
+crossbeam = "0.8"
 
 [features]
 default = ["u64_backend"]
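
The new crossbeam dependency is pulled in for its scoped threads (crossbeam::thread::scope), which the new src/spiral_mt.rs uses: scoped threads may borrow stack-local data such as the Spiral parameters and the encrypted database slice without 'static bounds, because the scope does not return until every spawned thread has joined. A minimal sketch of that API follows; it is illustration only (the function name, the chunking scheme, and the assumption that numthreads >= 1 are mine, not from this commit):

    // Minimal sketch of crossbeam 0.8 scoped threads (illustration only).
    // Each spawned thread borrows its own chunk of the caller's slice;
    // `scope` blocks until all spawned threads have finished.
    fn sum_in_parallel(data: &[u64], numthreads: usize) -> u64 {
        let chunk_len = ((data.len() + numthreads - 1) / numthreads).max(1);
        crossbeam::thread::scope(|s| {
            let mut handles = Vec::new();
            for chunk in data.chunks(chunk_len) {
                handles.push(s.spawn(move |_| chunk.iter().sum::<u64>()));
            }
            handles.into_iter().map(|h| h.join().unwrap()).sum::<u64>()
        })
        .unwrap()
    }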

src/main.rs  (+4 -2)

@@ -3,12 +3,12 @@
 #![allow(non_snake_case)]
 
 pub mod params;
+pub mod spiral_mt;
 
 use aes::cipher::{BlockEncrypt, KeyInit};
 use aes::Aes128Enc;
 use aes::Block;
 use std::env;
-use std::io::Cursor;
 use std::mem;
 use std::time::Instant;
 use subtle::Choice;
@@ -30,6 +30,8 @@ use spiral_rs::client::*;
 use spiral_rs::params::*;
 use spiral_rs::server::*;
 
+use crate::spiral_mt::*;
+
 use lazy_static::lazy_static;
 
 type DbEntry = u64;
@@ -308,7 +310,7 @@ fn main() {
 
     // Load the encrypted database into Spiral
     let sps_loaddb_start = Instant::now();
-    let sps_db = load_db_from_seek(&spiral_params, &mut Cursor::new(encdb));
+    let sps_db = load_db_from_slice_mt(&spiral_params, &encdb, 1);
     let sps_loaddb_us = sps_loaddb_start.elapsed().as_micros();
     println!("Server load database {} µs", sps_loaddb_us);
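
Note that the call site now goes through the new load_db_from_slice_mt wrapper but passes a hard-coded thread count of 1, consistent with the commit message: the actual fan-out across threads is still in progress. If the thread count were later chosen at runtime, one option is the standard library's std::thread::available_parallelism; the helper below is a hypothetical sketch, not code from this repository:

    // Hypothetical helper (not in the commit): pick a default thread count,
    // falling back to 1 if the parallelism query fails.
    fn default_num_threads() -> usize {
        std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(1)
    }

with a call such as load_db_from_slice_mt(&spiral_params, &encdb, default_num_threads()).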
 

src/spiral_mt.rs  (+91 -0)

@@ -0,0 +1,91 @@
+use spiral_rs::aligned_memory::*;
+use spiral_rs::arith::*;
+use spiral_rs::params::*;
+use spiral_rs::poly::*;
+use spiral_rs::server::*;
+use spiral_rs::util::*;
+
+use crossbeam::thread;
+
+pub fn load_item_from_slice<'a>(
+    params: &'a Params,
+    slice: &[u8],
+    instance: usize,
+    trial: usize,
+    item_idx: usize,
+) -> PolyMatrixRaw<'a> {
+    let db_item_size = params.db_item_size;
+    let instances = params.instances;
+    let trials = params.n * params.n;
+
+    let chunks = instances * trials;
+    let bytes_per_chunk = f64::ceil(db_item_size as f64 / chunks as f64) as usize;
+    let logp = f64::ceil(f64::log2(params.pt_modulus as f64)) as usize;
+    let modp_words_per_chunk = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
+    assert!(modp_words_per_chunk <= params.poly_len);
+
+    let idx_item_in_file = item_idx * db_item_size;
+    let idx_chunk = instance * trials + trial;
+    let idx_poly_in_file = idx_item_in_file + idx_chunk * bytes_per_chunk;
+
+    let mut out = PolyMatrixRaw::zero(params, 1, 1);
+
+    let modp_words_read = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
+    assert!(modp_words_read <= params.poly_len);
+
+    for i in 0..modp_words_read {
+        out.data[i] = read_arbitrary_bits(
+            &slice[idx_poly_in_file..idx_poly_in_file + bytes_per_chunk],
+            i * logp,
+            logp,
+        );
+        assert!(out.data[i] <= params.pt_modulus);
+    }
+
+    out
+}
+
+pub fn load_db_from_slice_mt(params: &Params, slice: &[u8], numthreads: usize) -> AlignedMemory64 {
+    let instances = params.instances;
+    let trials = params.n * params.n;
+    let dim0 = 1 << params.db_dim_1;
+    let num_per = 1 << params.db_dim_2;
+    let num_items = dim0 * num_per;
+    let db_size_words = instances * trials * num_items * params.poly_len;
+    let mut v = AlignedMemory64::new(db_size_words);
+
+    for instance in 0..instances {
+        for trial in 0..trials {
+            let vslice = v.as_mut_slice();
+            thread::scope(|s| {
+                s.spawn(|_| {
+                    for i in 0..num_items {
+                        let ii = i % num_per;
+                        let j = i / num_per;
+
+                        let mut db_item = load_item_from_slice(&params, slice, instance, trial, i);
+                        // db_item.reduce_mod(params.pt_modulus);
+
+                        for z in 0..params.poly_len {
+                            db_item.data[z] =
+                                recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
+                        }
+
+                        let db_item_ntt = db_item.ntt();
+                        for z in 0..params.poly_len {
+                            let idx_dst = calc_index(
+                                &[instance, trial, z, ii, j],
+                                &[instances, trials, params.poly_len, num_per, dim0],
+                            );
+
+                            vslice[idx_dst] = db_item_ntt.data[z]
+                                | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
+                        }
+                    }
+                });
+            })
+            .unwrap();
+        }
+    }
+    v
+}
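
As committed, load_db_from_slice_mt still ignores numthreads: each (instance, trial) pair gets exactly one spawned thread inside its scope, so the loop runs effectively sequentially. One way the per-item work could be split across numthreads is sketched below. It assumes the same use declarations as spiral_mt.rs; load_chunk_mt and its return-and-scatter structure are mine, not part of the commit. Each worker handles a contiguous range of item indices and returns its packed NTT coefficients, and the calc_index scatter stays on the calling thread so the mutable output slice never has to be shared:

    // Sketch only (not from this commit): fan the item loop for one
    // (instance, trial) pair out over `numthreads` scoped threads.
    // Assumes the same `use` declarations as spiral_mt.rs.
    fn load_chunk_mt(
        params: &Params,
        slice: &[u8],
        vslice: &mut [u64],
        instance: usize,
        trial: usize,
        numthreads: usize,
    ) {
        assert!(numthreads >= 1);
        let instances = params.instances;
        let trials = params.n * params.n;
        let dim0 = 1 << params.db_dim_1;
        let num_per = 1 << params.db_dim_2;
        let num_items = dim0 * num_per;
        let items_per_thread = (num_items + numthreads - 1) / numthreads;

        // Each worker returns (item index, packed NTT coefficients) pairs.
        let results = thread::scope(|s| {
            let mut handles = Vec::new();
            for t in 0..numthreads {
                let lo = t * items_per_thread;
                let hi = std::cmp::min(lo + items_per_thread, num_items);
                handles.push(s.spawn(move |_| {
                    let mut out = Vec::new();
                    for i in lo..hi {
                        let mut db_item =
                            load_item_from_slice(params, slice, instance, trial, i);
                        for z in 0..params.poly_len {
                            db_item.data[z] = recenter_mod(
                                db_item.data[z],
                                params.pt_modulus,
                                params.modulus,
                            );
                        }
                        let db_item_ntt = db_item.ntt();
                        let mut packed = vec![0u64; params.poly_len];
                        for z in 0..params.poly_len {
                            packed[z] = db_item_ntt.data[z]
                                | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
                        }
                        out.push((i, packed));
                    }
                    out
                }));
            }
            handles
                .into_iter()
                .flat_map(|h| h.join().unwrap())
                .collect::<Vec<(usize, Vec<u64>)>>()
        })
        .unwrap();

        // Scatter on the calling thread, exactly as the committed loop does.
        for (i, packed) in results {
            let ii = i % num_per;
            let j = i / num_per;
            for z in 0..params.poly_len {
                let idx_dst = calc_index(
                    &[instance, trial, z, ii, j],
                    &[instances, trials, params.poly_len, num_per, dim0],
                );
                vslice[idx_dst] = packed[z];
            }
        }
    }

Keeping the calc_index writes single-threaded trades some extra copying for safety: consecutive items do not map to contiguous ranges of vslice, so letting the workers write directly would require partitioning the output by destination index (or unsafe code) rather than a simple split_at_mut.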