123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128 |
- use spiral_rs::aligned_memory::*;
- use spiral_rs::arith::*;
- use spiral_rs::params::*;
- use spiral_rs::poly::*;
- use spiral_rs::server::*;
- use spiral_rs::util::*;
- use rayon::scope;
- pub fn load_item_from_slice<'a>(
- params: &'a Params,
- slice: &[u8],
- instance: usize,
- trial: usize,
- item_idx: usize,
- ) -> PolyMatrixRaw<'a> {
- let db_item_size = params.db_item_size;
- let instances = params.instances;
- let trials = params.n * params.n;
- let chunks = instances * trials;
- let bytes_per_chunk = f64::ceil(db_item_size as f64 / chunks as f64) as usize;
- let logp = f64::ceil(f64::log2(params.pt_modulus as f64)) as usize;
- let modp_words_per_chunk = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
- assert!(modp_words_per_chunk <= params.poly_len);
- let idx_item_in_file = item_idx * db_item_size;
- let idx_chunk = instance * trials + trial;
- let idx_poly_in_file = idx_item_in_file + idx_chunk * bytes_per_chunk;
- let mut out = PolyMatrixRaw::zero(params, 1, 1);
- let modp_words_read = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
- assert!(modp_words_read <= params.poly_len);
- for i in 0..modp_words_read {
- out.data[i] = read_arbitrary_bits(
- &slice[idx_poly_in_file..idx_poly_in_file + bytes_per_chunk],
- i * logp,
- logp,
- );
- assert!(out.data[i] <= params.pt_modulus);
- }
- out
- }
- pub fn load_db_from_slice_mt(params: &Params, slice: &[u8], num_threads: usize) -> AlignedMemory64 {
- let instances = params.instances;
- let trials = params.n * params.n;
- let dim0 = 1 << params.db_dim_1;
- let num_per = 1 << params.db_dim_2;
- let num_items = dim0 * num_per;
- let db_size_words = instances * trials * num_items * params.poly_len;
- let mut v: AlignedMemory64 = AlignedMemory64::new(db_size_words);
- // Get a pointer to the memory pool of the AlignedMemory64. We
- // treat it as a usize explicitly so we can pass the same pointer to
- // multiple threads, each of which will cast it to a *mut u64, in
- // order to *write* into the memory pool concurrently. There is a
- // caveat that the threads *must not* try to write into the same
- // memory location. In Spiral, each polynomial created from the
- // database ends up scattered into noncontiguous words of memory,
- // but any one word still only comes from one polynomial. So with
- // this mechanism, different threads can read different parts of the
- // database to produce different polynomials, and write those
- // polynomials into the same memory pool (but *not* the same memory
- // locations) at the same time.
- let vptrusize = unsafe { v.as_mut_ptr() as usize };
- for instance in 0..instances {
- for trial in 0..trials {
- scope(|s| {
- let mut item_thread_start = 0usize;
- let items_per_thread_base = num_items / num_threads;
- let items_per_thread_extra = num_items % num_threads;
- for thr in 0..num_threads {
- let items_this_thread =
- items_per_thread_base + if thr < items_per_thread_extra { 1 } else { 0 };
- let item_thread_end = item_thread_start + items_this_thread;
- s.spawn(move |_| {
- let vptr = vptrusize as *mut u64;
- for i in item_thread_start..item_thread_end {
- // Swap the halves of the item index so that
- // the polynomials based on the items are
- // written to the AlignedMemory64 more
- // sequentially
- let ii = i / dim0;
- let j = i % dim0;
- let db_idx = j * num_per + ii;
- let mut db_item =
- load_item_from_slice(params, slice, instance, trial, db_idx);
- // db_item.reduce_mod(params.pt_modulus);
- for z in 0..params.poly_len {
- db_item.data[z] = recenter_mod(
- db_item.data[z],
- params.pt_modulus,
- params.modulus,
- );
- }
- let db_item_ntt = db_item.ntt();
- for z in 0..params.poly_len {
- let idx_dst = calc_index(
- &[instance, trial, z, ii, j],
- &[instances, trials, params.poly_len, num_per, dim0],
- );
- unsafe {
- vptr.add(idx_dst).write(
- db_item_ntt.data[z]
- | (db_item_ntt.data[params.poly_len + z]
- << PACKED_OFFSET_2),
- );
- }
- }
- }
- });
- item_thread_start = item_thread_end;
- }
- });
- }
- }
- v
- }
|