Browse Source

Multithreaded database loading

Ian Goldberg 1 year ago
parent
commit
d870e1c058
3 changed files with 175 additions and 27 deletions
  1. 119 0
      src/aligned_memory_mt.rs
  2. 8 3
      src/main.rs
  3. 48 24
      src/spiral_mt.rs

+ 119 - 0
src/aligned_memory_mt.rs

@@ -0,0 +1,119 @@
+/* This file is almost identical to the aligned_memory.rs file in the
+   spiral-rs crate.  The name is modified from AlignedMemory to
+   AlignedMemoryMT, and there is one (unsafe!) change to the API:
+
+    pub unsafe fn as_mut_ptr(&mut self) -> *mut u64
+
+   has changed to:
+
+    pub unsafe fn as_mut_ptr(&self) -> *mut u64
+
+   The reason for this change is explicitly to allow multiple threads to
+   *write* into the memory pool concurrently, with the caveat that the
+   threads *must not* try to write into the same memory location.  In
+   Spiral, each polynomial created from the database ends up scattered
+   into noncontiguous words of memory, but any one word still only comes
+   from one polynomial.  So with this change, different threads can read
+   different parts of the database to produce different polynomials, and
+   write those polynomials into the same memory pool (but *not* the same
+   memory locations) at the same time.
+*/
+
+use std::{
+    alloc::{alloc_zeroed, dealloc, Layout},
+    mem::size_of,
+    ops::{Index, IndexMut},
+    slice::{from_raw_parts, from_raw_parts_mut},
+};
+
+const ALIGN_SIMD: usize = 64; // enough to support AVX-512
+pub type AlignedMemoryMT64 = AlignedMemoryMT<ALIGN_SIMD>;
+
+pub struct AlignedMemoryMT<const ALIGN: usize> {
+    p: *mut u64,
+    sz_u64: usize,
+    layout: Layout,
+}
+
+impl<const ALIGN: usize> AlignedMemoryMT<{ ALIGN }> {
+    pub fn new(sz_u64: usize) -> Self {
+        let sz_bytes = sz_u64 * size_of::<u64>();
+        let layout = Layout::from_size_align(sz_bytes, ALIGN).unwrap();
+
+        let ptr;
+        unsafe {
+            ptr = alloc_zeroed(layout);
+        }
+
+        Self {
+            p: ptr as *mut u64,
+            sz_u64,
+            layout,
+        }
+    }
+
+    // pub fn from(data: &[u8]) -> Self {
+    //     let sz_u64 = (data.len() + size_of::<u64>() - 1) / size_of::<u64>();
+    //     let mut out = Self::new(sz_u64);
+    //     let out_slice = out.as_mut_slice();
+    //     let mut i = 0;
+    //     for chunk in data.chunks(size_of::<u64>()) {
+    //         out_slice[i] = u64::from_ne_bytes(chunk);
+    //         i += 1;
+    //     }
+    //     out
+    // }
+
+    pub fn as_slice(&self) -> &[u64] {
+        unsafe { from_raw_parts(self.p, self.sz_u64) }
+    }
+
+    pub fn as_mut_slice(&mut self) -> &mut [u64] {
+        unsafe { from_raw_parts_mut(self.p, self.sz_u64) }
+    }
+
+    pub unsafe fn as_ptr(&self) -> *const u64 {
+        self.p
+    }
+
+    pub unsafe fn as_mut_ptr(&self) -> *mut u64 {
+        self.p
+    }
+
+    pub fn len(&self) -> usize {
+        self.sz_u64
+    }
+}
+
+unsafe impl<const ALIGN: usize> Send for AlignedMemoryMT<{ ALIGN }> {}
+unsafe impl<const ALIGN: usize> Sync for AlignedMemoryMT<{ ALIGN }> {}
+
+impl<const ALIGN: usize> Drop for AlignedMemoryMT<{ ALIGN }> {
+    fn drop(&mut self) {
+        unsafe {
+            dealloc(self.p as *mut u8, self.layout);
+        }
+    }
+}
+
+impl<const ALIGN: usize> Index<usize> for AlignedMemoryMT<{ ALIGN }> {
+    type Output = u64;
+
+    fn index(&self, index: usize) -> &Self::Output {
+        &self.as_slice()[index]
+    }
+}
+
+impl<const ALIGN: usize> IndexMut<usize> for AlignedMemoryMT<{ ALIGN }> {
+    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
+        &mut self.as_mut_slice()[index]
+    }
+}
+
+impl<const ALIGN: usize> Clone for AlignedMemoryMT<{ ALIGN }> {
+    fn clone(&self) -> Self {
+        let mut out = Self::new(self.sz_u64);
+        out.as_mut_slice().copy_from_slice(self.as_slice());
+        out
+    }
+}

+ 8 - 3
src/main.rs

@@ -2,6 +2,7 @@
 // lowercase letters
 #![allow(non_snake_case)]
 
+pub mod aligned_memory_mt;
 pub mod params;
 pub mod spiral_mt;
 
@@ -216,11 +217,15 @@ fn print_params_summary(params: &Params) {
 
 fn main() {
     let args: Vec<String> = env::args().collect();
-    if args.len() != 2 {
-        println!("Usage: {} r\nr = log_2(num_records)", args[0]);
+    if args.len() != 2 && args.len() != 3 {
+        println!("Usage: {} r [num_threads]\nr = log_2(num_records)", args[0]);
         return;
     }
     let r: usize = args[1].parse().unwrap();
+    let mut num_threads = 1usize;
+    if args.len() == 3 {
+        num_threads = args[2].parse().unwrap();
+    }
     let num_records = 1 << r;
 
     println!("===== ONE-TIME SETUP =====\n");
@@ -310,7 +315,7 @@ fn main() {
 
     // Load the encrypted database into Spiral
     let sps_loaddb_start = Instant::now();
-    let sps_db = load_db_from_slice_mt(&spiral_params, &encdb, 1);
+    let sps_db = load_db_from_slice_mt(&spiral_params, &encdb, num_threads);
     let sps_loaddb_us = sps_loaddb_start.elapsed().as_micros();
     println!("Server load database {} µs", sps_loaddb_us);
 

+ 48 - 24
src/spiral_mt.rs

@@ -1,4 +1,3 @@
-use spiral_rs::aligned_memory::*;
 use spiral_rs::arith::*;
 use spiral_rs::params::*;
 use spiral_rs::poly::*;
@@ -7,6 +6,8 @@ use spiral_rs::util::*;
 
 use crossbeam::thread;
 
+use crate::aligned_memory_mt::*;
+
 pub fn load_item_from_slice<'a>(
     params: &'a Params,
     slice: &[u8],
@@ -45,44 +46,67 @@ pub fn load_item_from_slice<'a>(
     out
 }
 
-pub fn load_db_from_slice_mt(params: &Params, slice: &[u8], numthreads: usize) -> AlignedMemory64 {
+pub fn load_db_from_slice_mt(
+    params: &Params,
+    slice: &[u8],
+    num_threads: usize,
+) -> AlignedMemoryMT64 {
     let instances = params.instances;
     let trials = params.n * params.n;
     let dim0 = 1 << params.db_dim_1;
     let num_per = 1 << params.db_dim_2;
     let num_items = dim0 * num_per;
     let db_size_words = instances * trials * num_items * params.poly_len;
-    let mut v = AlignedMemory64::new(db_size_words);
+    let v: AlignedMemoryMT64 = AlignedMemoryMT64::new(db_size_words);
 
     for instance in 0..instances {
         for trial in 0..trials {
-            let vslice = v.as_mut_slice();
             thread::scope(|s| {
-                s.spawn(|_| {
-                    for i in 0..num_items {
-                        let ii = i % num_per;
-                        let j = i / num_per;
+                let mut item_thread_start = 0usize;
+                let items_per_thread_base = num_items / num_threads;
+                let items_per_thread_extra = num_items % num_threads;
+                for thr in 0..num_threads {
+                    let items_this_thread =
+                        items_per_thread_base + if thr < items_per_thread_extra { 1 } else { 0 };
+                    let item_thread_end = item_thread_start + items_this_thread;
+                    let v = &v;
+                    s.spawn(move |_| {
+                        let vptr = unsafe { v.as_mut_ptr() };
+                        for i in item_thread_start..item_thread_end {
+                            let ii = i % num_per;
+                            let j = i / num_per;
 
-                        let mut db_item = load_item_from_slice(&params, slice, instance, trial, i);
-                        // db_item.reduce_mod(params.pt_modulus);
+                            let mut db_item =
+                                load_item_from_slice(&params, slice, instance, trial, i);
+                            // db_item.reduce_mod(params.pt_modulus);
 
-                        for z in 0..params.poly_len {
-                            db_item.data[z] =
-                                recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
-                        }
+                            for z in 0..params.poly_len {
+                                db_item.data[z] = recenter_mod(
+                                    db_item.data[z],
+                                    params.pt_modulus,
+                                    params.modulus,
+                                );
+                            }
 
-                        let db_item_ntt = db_item.ntt();
-                        for z in 0..params.poly_len {
-                            let idx_dst = calc_index(
-                                &[instance, trial, z, ii, j],
-                                &[instances, trials, params.poly_len, num_per, dim0],
-                            );
+                            let db_item_ntt = db_item.ntt();
+                            for z in 0..params.poly_len {
+                                let idx_dst = calc_index(
+                                    &[instance, trial, z, ii, j],
+                                    &[instances, trials, params.poly_len, num_per, dim0],
+                                );
 
-                            vslice[idx_dst] = db_item_ntt.data[z]
-                                | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
+                                unsafe {
+                                    vptr.offset(idx_dst as isize).write(
+                                        db_item_ntt.data[z]
+                                            | (db_item_ntt.data[params.poly_len + z]
+                                                << PACKED_OFFSET_2),
+                                    );
+                                }
+                            }
                         }
-                    }
-                });
+                    });
+                    item_thread_start = item_thread_end;
+                }
             })
             .unwrap();
         }