params.rs 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. use crate::{arith::*, ntt::*, number_theory::*};
  2. pub struct Params {
  3. pub poly_len: usize,
  4. pub poly_len_log2: usize,
  5. pub ntt_tables: Vec<Vec<Vec<u64>>>,
  6. pub scratch: Vec<u64>,
  7. pub crt_count: usize,
  8. pub moduli: Vec<u64>,
  9. pub modulus: u64,
  10. pub modulus_log2: u64,
  11. pub noise_width: f64,
  12. pub n: usize,
  13. pub t_conv: usize,
  14. pub t_exp_left: usize,
  15. pub t_exp_right: usize,
  16. pub t_gsw: usize,
  17. pub expand_queries: bool,
  18. }
  19. impl Params {
  20. pub fn get_ntt_forward_table(&self, i: usize) -> &[u64] {
  21. self.ntt_tables[i][0].as_slice()
  22. }
  23. pub fn get_ntt_forward_prime_table(&self, i: usize) -> &[u64] {
  24. self.ntt_tables[i][1].as_slice()
  25. }
  26. pub fn get_ntt_inverse_table(&self, i: usize) -> &[u64] {
  27. self.ntt_tables[i][2].as_slice()
  28. }
  29. pub fn get_ntt_inverse_prime_table(&self, i: usize) -> &[u64] {
  30. self.ntt_tables[i][3].as_slice()
  31. }
  32. pub fn get_sk_gsw(&self) -> (usize, usize) {
  33. (self.n, 1)
  34. }
  35. pub fn get_sk_reg(&self) -> (usize, usize) {
  36. (1, 1)
  37. }
  38. pub fn m_conv(&self) -> usize {
  39. 2 * self.t_conv
  40. }
  41. pub fn crt_compose_2(&self, x: u64, y: u64) -> u64 {
  42. assert_eq!(self.crt_count, 2);
  43. let a = self.moduli[0];
  44. let b = self.moduli[1];
  45. let a_inv_mod_b = invert_uint_mod(a, b).unwrap();
  46. let b_inv_mod_a = invert_uint_mod(b, a).unwrap();
  47. let mut val = (x as u128) * (b_inv_mod_a as u128) * (b as u128);
  48. val += (y as u128) * (a_inv_mod_b as u128) * (a as u128);
  49. (val % (self.modulus as u128)) as u64 // FIXME: use barrett
  50. }
  51. pub fn init(
  52. poly_len: usize,
  53. moduli: &Vec<u64>,
  54. noise_width: f64,
  55. n: usize,
  56. t_conv: usize,
  57. t_exp_left: usize,
  58. t_exp_right: usize,
  59. t_gsw: usize,
  60. expand_queries: bool,
  61. ) -> Self {
  62. let poly_len_log2 = log2(poly_len as u64) as usize;
  63. let crt_count = moduli.len();
  64. let ntt_tables = build_ntt_tables(poly_len, moduli.as_slice());
  65. let scratch = vec![0u64; crt_count * poly_len];
  66. let mut modulus = 1;
  67. for m in moduli {
  68. modulus *= m;
  69. }
  70. let modulus_log2 = log2(modulus);
  71. Self {
  72. poly_len,
  73. poly_len_log2,
  74. ntt_tables,
  75. scratch,
  76. crt_count,
  77. moduli: moduli.clone(),
  78. modulus,
  79. modulus_log2,
  80. noise_width,
  81. n,
  82. t_conv,
  83. t_exp_left,
  84. t_exp_right,
  85. t_gsw,
  86. expand_queries,
  87. }
  88. }
  89. }