spiral_mt.rs 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. use spiral_rs::aligned_memory::*;
  2. use spiral_rs::arith::*;
  3. use spiral_rs::params::*;
  4. use spiral_rs::poly::*;
  5. use spiral_rs::server::*;
  6. use spiral_rs::util::*;
  7. use crossbeam::thread;
  8. pub fn load_item_from_slice<'a>(
  9. params: &'a Params,
  10. slice: &[u8],
  11. instance: usize,
  12. trial: usize,
  13. item_idx: usize,
  14. ) -> PolyMatrixRaw<'a> {
  15. let db_item_size = params.db_item_size;
  16. let instances = params.instances;
  17. let trials = params.n * params.n;
  18. let chunks = instances * trials;
  19. let bytes_per_chunk = f64::ceil(db_item_size as f64 / chunks as f64) as usize;
  20. let logp = f64::ceil(f64::log2(params.pt_modulus as f64)) as usize;
  21. let modp_words_per_chunk = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
  22. assert!(modp_words_per_chunk <= params.poly_len);
  23. let idx_item_in_file = item_idx * db_item_size;
  24. let idx_chunk = instance * trials + trial;
  25. let idx_poly_in_file = idx_item_in_file + idx_chunk * bytes_per_chunk;
  26. let mut out = PolyMatrixRaw::zero(params, 1, 1);
  27. let modp_words_read = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
  28. assert!(modp_words_read <= params.poly_len);
  29. for i in 0..modp_words_read {
  30. out.data[i] = read_arbitrary_bits(
  31. &slice[idx_poly_in_file..idx_poly_in_file + bytes_per_chunk],
  32. i * logp,
  33. logp,
  34. );
  35. assert!(out.data[i] <= params.pt_modulus);
  36. }
  37. out
  38. }
  39. pub fn load_db_from_slice_mt(params: &Params, slice: &[u8], numthreads: usize) -> AlignedMemory64 {
  40. let instances = params.instances;
  41. let trials = params.n * params.n;
  42. let dim0 = 1 << params.db_dim_1;
  43. let num_per = 1 << params.db_dim_2;
  44. let num_items = dim0 * num_per;
  45. let db_size_words = instances * trials * num_items * params.poly_len;
  46. let mut v = AlignedMemory64::new(db_size_words);
  47. for instance in 0..instances {
  48. for trial in 0..trials {
  49. let vslice = v.as_mut_slice();
  50. thread::scope(|s| {
  51. s.spawn(|_| {
  52. for i in 0..num_items {
  53. let ii = i % num_per;
  54. let j = i / num_per;
  55. let mut db_item = load_item_from_slice(&params, slice, instance, trial, i);
  56. // db_item.reduce_mod(params.pt_modulus);
  57. for z in 0..params.poly_len {
  58. db_item.data[z] =
  59. recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
  60. }
  61. let db_item_ntt = db_item.ntt();
  62. for z in 0..params.poly_len {
  63. let idx_dst = calc_index(
  64. &[instance, trial, z, ii, j],
  65. &[instances, trials, params.poly_len, num_per, dim0],
  66. );
  67. vslice[idx_dst] = db_item_ntt.data[z]
  68. | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
  69. }
  70. }
  71. });
  72. })
  73. .unwrap();
  74. }
  75. }
  76. v
  77. }