spiral_mt.rs 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. use spiral_rs::aligned_memory::*;
  2. use spiral_rs::arith::*;
  3. use spiral_rs::params::*;
  4. use spiral_rs::poly::*;
  5. use spiral_rs::server::*;
  6. use spiral_rs::util::*;
  7. use rayon::scope;
  8. pub fn load_item_from_slice<'a>(
  9. params: &'a Params,
  10. slice: &[u8],
  11. instance: usize,
  12. trial: usize,
  13. item_idx: usize,
  14. ) -> PolyMatrixRaw<'a> {
  15. let db_item_size = params.db_item_size;
  16. let instances = params.instances;
  17. let trials = params.n * params.n;
  18. let chunks = instances * trials;
  19. let bytes_per_chunk = f64::ceil(db_item_size as f64 / chunks as f64) as usize;
  20. let logp = f64::ceil(f64::log2(params.pt_modulus as f64)) as usize;
  21. let modp_words_per_chunk = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
  22. assert!(modp_words_per_chunk <= params.poly_len);
  23. let idx_item_in_file = item_idx * db_item_size;
  24. let idx_chunk = instance * trials + trial;
  25. let idx_poly_in_file = idx_item_in_file + idx_chunk * bytes_per_chunk;
  26. let mut out = PolyMatrixRaw::zero(params, 1, 1);
  27. let modp_words_read = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
  28. assert!(modp_words_read <= params.poly_len);
  29. for i in 0..modp_words_read {
  30. out.data[i] = read_arbitrary_bits(
  31. &slice[idx_poly_in_file..idx_poly_in_file + bytes_per_chunk],
  32. i * logp,
  33. logp,
  34. );
  35. assert!(out.data[i] <= params.pt_modulus);
  36. }
  37. out
  38. }
  39. pub fn load_db_from_slice_mt(params: &Params, slice: &[u8], num_threads: usize) -> AlignedMemory64 {
  40. let instances = params.instances;
  41. let trials = params.n * params.n;
  42. let dim0 = 1 << params.db_dim_1;
  43. let num_per = 1 << params.db_dim_2;
  44. let num_items = dim0 * num_per;
  45. let db_size_words = instances * trials * num_items * params.poly_len;
  46. let mut v: AlignedMemory64 = AlignedMemory64::new(db_size_words);
  47. // Get a pointer to the memory pool of the AlignedMemory64. We
  48. // treat it as a usize explicitly so we can pass the same pointer to
  49. // multiple threads, each of which will cast it to a *mut u64, in
  50. // order to *write* into the memory pool concurrently. There is a
  51. // caveat that the threads *must not* try to write into the same
  52. // memory location. In Spiral, each polynomial created from the
  53. // database ends up scattered into noncontiguous words of memory,
  54. // but any one word still only comes from one polynomial. So with
  55. // this mechanism, different threads can read different parts of the
  56. // database to produce different polynomials, and write those
  57. // polynomials into the same memory pool (but *not* the same memory
  58. // locations) at the same time.
  59. let vptrusize = unsafe { v.as_mut_ptr() as usize };
  60. for instance in 0..instances {
  61. for trial in 0..trials {
  62. scope(|s| {
  63. let mut item_thread_start = 0usize;
  64. let items_per_thread_base = num_items / num_threads;
  65. let items_per_thread_extra = num_items % num_threads;
  66. for thr in 0..num_threads {
  67. let items_this_thread =
  68. items_per_thread_base + if thr < items_per_thread_extra { 1 } else { 0 };
  69. let item_thread_end = item_thread_start + items_this_thread;
  70. s.spawn(move |_| {
  71. let vptr = vptrusize as *mut u64;
  72. for i in item_thread_start..item_thread_end {
  73. // Swap the halves of the item index so that
  74. // the polynomials based on the items are
  75. // written to the AlignedMemory64 more
  76. // sequentially
  77. let ii = i / dim0;
  78. let j = i % dim0;
  79. let db_idx = j * num_per + ii;
  80. let mut db_item =
  81. load_item_from_slice(params, slice, instance, trial, db_idx);
  82. // db_item.reduce_mod(params.pt_modulus);
  83. for z in 0..params.poly_len {
  84. db_item.data[z] = recenter_mod(
  85. db_item.data[z],
  86. params.pt_modulus,
  87. params.modulus,
  88. );
  89. }
  90. let db_item_ntt = db_item.ntt();
  91. for z in 0..params.poly_len {
  92. let idx_dst = calc_index(
  93. &[instance, trial, z, ii, j],
  94. &[instances, trials, params.poly_len, num_per, dim0],
  95. );
  96. unsafe {
  97. vptr.add(idx_dst).write(
  98. db_item_ntt.data[z]
  99. | (db_item_ntt.data[params.poly_len + z]
  100. << PACKED_OFFSET_2),
  101. );
  102. }
  103. }
  104. }
  105. });
  106. item_thread_start = item_thread_end;
  107. }
  108. });
  109. }
  110. }
  111. v
  112. }