params.rs 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. use crate::{arith::*, ntt::*};
  /// Parameter set shared by the polynomial / NTT routines of this crate.
  ///
  /// Instances are normally built with `Params::init`, which derives every
  /// field from a polynomial length and a list of moduli.
  #[derive(Debug, Clone, PartialEq, Eq)]
  pub struct Params {
      /// Polynomial length, as passed to `Params::init`.
      // NOTE(review): presumably a power of two, since `init` stores its
      // base-2 logarithm as an integer — confirm against `arith::log2`.
      pub poly_len: usize,
      /// Base-2 logarithm of `poly_len`, as computed by `Params::init`.
      pub poly_len_log2: usize,
      /// NTT tables, indexed as `ntt_tables[i][k]`.
      ///
      /// The accessors on `Params` read `k = 0` as the forward table,
      /// `k = 1` as the forward "prime" table, `k = 2` as the inverse table
      /// and `k = 3` as the inverse "prime" table.
      // NOTE(review): `i` presumably selects the CRT modulus — confirm
      // against `ntt::build_ntt_tables`.
      pub ntt_tables: Vec<Vec<Vec<u64>>>,
      /// Number of moduli (`moduli.len()` at construction time).
      pub crt_count: usize,
      /// The individual moduli.
      pub moduli: Vec<u64>,
      /// Product of all entries of `moduli`.
      pub modulus: u64,
  }
  10. impl Params {
  11. pub fn num_words(&self) -> usize {
  12. self.poly_len * self.crt_count
  13. }
  14. pub fn get_ntt_forward_table(&self, i: usize) -> &[u64] {
  15. self.ntt_tables[i][0].as_slice()
  16. }
  17. pub fn get_ntt_forward_prime_table(&self, i: usize) -> &[u64] {
  18. self.ntt_tables[i][1].as_slice()
  19. }
  20. pub fn get_ntt_inverse_table(&self, i: usize) -> &[u64] {
  21. self.ntt_tables[i][2].as_slice()
  22. }
  23. pub fn get_ntt_inverse_prime_table(&self, i: usize) -> &[u64] {
  24. self.ntt_tables[i][3].as_slice()
  25. }
  26. pub fn init(poly_len: usize, moduli: &Vec<u64>) -> Self {
  27. let poly_len_log2 = log2(poly_len as u64) as usize;
  28. let crt_count = moduli.len();
  29. let ntt_tables = build_ntt_tables(poly_len, moduli.as_slice());
  30. let mut modulus = 1;
  31. for m in moduli {
  32. modulus *= m;
  33. }
  34. Self {
  35. poly_len,
  36. poly_len_log2,
  37. ntt_tables,
  38. crt_count,
  39. moduli: moduli.clone(),
  40. modulus
  41. }
  42. }
  43. }