|
@@ -1,4 +1,3 @@
|
|
|
-use spiral_rs::aligned_memory::*;
|
|
|
use spiral_rs::arith::*;
|
|
|
use spiral_rs::params::*;
|
|
|
use spiral_rs::poly::*;
|
|
@@ -7,6 +6,8 @@ use spiral_rs::util::*;
|
|
|
|
|
|
use crossbeam::thread;
|
|
|
|
|
|
+use crate::aligned_memory_mt::*;
|
|
|
+
|
|
|
pub fn load_item_from_slice<'a>(
|
|
|
params: &'a Params,
|
|
|
slice: &[u8],
|
|
@@ -45,44 +46,67 @@ pub fn load_item_from_slice<'a>(
|
|
|
out
|
|
|
}
|
|
|
|
|
|
-pub fn load_db_from_slice_mt(params: &Params, slice: &[u8], numthreads: usize) -> AlignedMemory64 {
|
|
|
+pub fn load_db_from_slice_mt(
|
|
|
+ params: &Params,
|
|
|
+ slice: &[u8],
|
|
|
+ num_threads: usize,
|
|
|
+) -> AlignedMemoryMT64 {
|
|
|
let instances = params.instances;
|
|
|
let trials = params.n * params.n;
|
|
|
let dim0 = 1 << params.db_dim_1;
|
|
|
let num_per = 1 << params.db_dim_2;
|
|
|
let num_items = dim0 * num_per;
|
|
|
let db_size_words = instances * trials * num_items * params.poly_len;
|
|
|
- let mut v = AlignedMemory64::new(db_size_words);
|
|
|
+ let v: AlignedMemoryMT64 = AlignedMemoryMT64::new(db_size_words);
|
|
|
|
|
|
for instance in 0..instances {
|
|
|
for trial in 0..trials {
|
|
|
- let vslice = v.as_mut_slice();
|
|
|
thread::scope(|s| {
|
|
|
- s.spawn(|_| {
|
|
|
- for i in 0..num_items {
|
|
|
- let ii = i % num_per;
|
|
|
- let j = i / num_per;
|
|
|
+ let mut item_thread_start = 0usize;
|
|
|
+ let items_per_thread_base = num_items / num_threads;
|
|
|
+ let items_per_thread_extra = num_items % num_threads;
|
|
|
+ for thr in 0..num_threads {
|
|
|
+ let items_this_thread =
|
|
|
+ items_per_thread_base + if thr < items_per_thread_extra { 1 } else { 0 };
|
|
|
+ let item_thread_end = item_thread_start + items_this_thread;
|
|
|
+ let v = &v;
|
|
|
+ s.spawn(move |_| {
|
|
|
+ let vptr = unsafe { v.as_mut_ptr() };
|
|
|
+ for i in item_thread_start..item_thread_end {
|
|
|
+ let ii = i % num_per;
|
|
|
+ let j = i / num_per;
|
|
|
|
|
|
- let mut db_item = load_item_from_slice(&params, slice, instance, trial, i);
|
|
|
- // db_item.reduce_mod(params.pt_modulus);
|
|
|
+ let mut db_item =
|
|
|
+ load_item_from_slice(&params, slice, instance, trial, i);
|
|
|
+ // db_item.reduce_mod(params.pt_modulus);
|
|
|
|
|
|
- for z in 0..params.poly_len {
|
|
|
- db_item.data[z] =
|
|
|
- recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
|
|
|
- }
|
|
|
+ for z in 0..params.poly_len {
|
|
|
+ db_item.data[z] = recenter_mod(
|
|
|
+ db_item.data[z],
|
|
|
+ params.pt_modulus,
|
|
|
+ params.modulus,
|
|
|
+ );
|
|
|
+ }
|
|
|
|
|
|
- let db_item_ntt = db_item.ntt();
|
|
|
- for z in 0..params.poly_len {
|
|
|
- let idx_dst = calc_index(
|
|
|
- &[instance, trial, z, ii, j],
|
|
|
- &[instances, trials, params.poly_len, num_per, dim0],
|
|
|
- );
|
|
|
+ let db_item_ntt = db_item.ntt();
|
|
|
+ for z in 0..params.poly_len {
|
|
|
+ let idx_dst = calc_index(
|
|
|
+ &[instance, trial, z, ii, j],
|
|
|
+ &[instances, trials, params.poly_len, num_per, dim0],
|
|
|
+ );
|
|
|
|
|
|
- vslice[idx_dst] = db_item_ntt.data[z]
|
|
|
- | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
|
|
|
+ unsafe {
|
|
|
+ vptr.offset(idx_dst as isize).write(
|
|
|
+ db_item_ntt.data[z]
|
|
|
+ | (db_item_ntt.data[params.poly_len + z]
|
|
|
+ << PACKED_OFFSET_2),
|
|
|
+ );
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
- }
|
|
|
- });
|
|
|
+ });
|
|
|
+ item_thread_start = item_thread_end;
|
|
|
+ }
|
|
|
})
|
|
|
.unwrap();
|
|
|
}
|