spiral_mt.rs 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. use spiral_rs::arith::*;
  2. use spiral_rs::params::*;
  3. use spiral_rs::poly::*;
  4. use spiral_rs::server::*;
  5. use spiral_rs::util::*;
  6. use rayon::scope;
  7. use crate::aligned_memory_mt::*;
  8. pub fn load_item_from_slice<'a>(
  9. params: &'a Params,
  10. slice: &[u8],
  11. instance: usize,
  12. trial: usize,
  13. item_idx: usize,
  14. ) -> PolyMatrixRaw<'a> {
  15. let db_item_size = params.db_item_size;
  16. let instances = params.instances;
  17. let trials = params.n * params.n;
  18. let chunks = instances * trials;
  19. let bytes_per_chunk = f64::ceil(db_item_size as f64 / chunks as f64) as usize;
  20. let logp = f64::ceil(f64::log2(params.pt_modulus as f64)) as usize;
  21. let modp_words_per_chunk = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
  22. assert!(modp_words_per_chunk <= params.poly_len);
  23. let idx_item_in_file = item_idx * db_item_size;
  24. let idx_chunk = instance * trials + trial;
  25. let idx_poly_in_file = idx_item_in_file + idx_chunk * bytes_per_chunk;
  26. let mut out = PolyMatrixRaw::zero(params, 1, 1);
  27. let modp_words_read = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
  28. assert!(modp_words_read <= params.poly_len);
  29. for i in 0..modp_words_read {
  30. out.data[i] = read_arbitrary_bits(
  31. &slice[idx_poly_in_file..idx_poly_in_file + bytes_per_chunk],
  32. i * logp,
  33. logp,
  34. );
  35. assert!(out.data[i] <= params.pt_modulus);
  36. }
  37. out
  38. }
  39. pub fn load_db_from_slice_mt(
  40. params: &Params,
  41. slice: &[u8],
  42. num_threads: usize,
  43. ) -> AlignedMemoryMT64 {
  44. let instances = params.instances;
  45. let trials = params.n * params.n;
  46. let dim0 = 1 << params.db_dim_1;
  47. let num_per = 1 << params.db_dim_2;
  48. let num_items = dim0 * num_per;
  49. let db_size_words = instances * trials * num_items * params.poly_len;
  50. let v: AlignedMemoryMT64 = AlignedMemoryMT64::new(db_size_words);
  51. for instance in 0..instances {
  52. for trial in 0..trials {
  53. scope(|s| {
  54. let mut item_thread_start = 0usize;
  55. let items_per_thread_base = num_items / num_threads;
  56. let items_per_thread_extra = num_items % num_threads;
  57. for thr in 0..num_threads {
  58. let items_this_thread =
  59. items_per_thread_base + if thr < items_per_thread_extra { 1 } else { 0 };
  60. let item_thread_end = item_thread_start + items_this_thread;
  61. let v = &v;
  62. s.spawn(move |_| {
  63. let vptr = unsafe { v.as_mut_ptr() };
  64. for i in item_thread_start..item_thread_end {
  65. // Swap the halves of the item index so that
  66. // the polynomials based on the items are
  67. // written to the AlignedMemoryMT64 more
  68. // sequentially
  69. let ii = i / dim0;
  70. let j = i % dim0;
  71. let db_idx = j * num_per + ii;
  72. let mut db_item =
  73. load_item_from_slice(params, slice, instance, trial, db_idx);
  74. // db_item.reduce_mod(params.pt_modulus);
  75. for z in 0..params.poly_len {
  76. db_item.data[z] = recenter_mod(
  77. db_item.data[z],
  78. params.pt_modulus,
  79. params.modulus,
  80. );
  81. }
  82. let db_item_ntt = db_item.ntt();
  83. for z in 0..params.poly_len {
  84. let idx_dst = calc_index(
  85. &[instance, trial, z, ii, j],
  86. &[instances, trials, params.poly_len, num_per, dim0],
  87. );
  88. unsafe {
  89. vptr.add(idx_dst).write(
  90. db_item_ntt.data[z]
  91. | (db_item_ntt.data[params.poly_len + z]
  92. << PACKED_OFFSET_2),
  93. );
  94. }
  95. }
  96. }
  97. });
  98. item_thread_start = item_thread_end;
  99. }
  100. });
  101. }
  102. }
  103. v
  104. }