params.rs 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. use crate::{arith::*, ntt::*, number_theory::*};
  2. pub static Q2_VALUES: [u64; 37] = [
  3. 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 12289, 12289, 61441, 65537, 65537, 520193, 786433, 786433, 3604481, 7340033, 16515073, 33292289, 67043329, 132120577, 268369921, 469762049, 1073479681, 2013265921, 4293918721, 8588886017, 17175674881, 34359214081, 68718428161
  4. ];
  5. #[derive(Debug)]
  6. #[derive(PartialEq)]
  7. pub struct Params {
  8. pub poly_len: usize,
  9. pub poly_len_log2: usize,
  10. pub ntt_tables: Vec<Vec<Vec<u64>>>,
  11. pub scratch: Vec<u64>,
  12. pub crt_count: usize,
  13. pub moduli: Vec<u64>,
  14. pub modulus: u64,
  15. pub modulus_log2: u64,
  16. pub noise_width: f64,
  17. pub n: usize,
  18. pub pt_modulus: u64,
  19. pub q2_bits: u64,
  20. pub t_conv: usize,
  21. pub t_exp_left: usize,
  22. pub t_exp_right: usize,
  23. pub t_gsw: usize,
  24. pub expand_queries: bool,
  25. pub db_dim_1: usize,
  26. pub db_dim_2: usize,
  27. pub instances: usize,
  28. pub db_item_size: usize,
  29. }
  30. impl Params {
  31. pub fn get_ntt_forward_table(&self, i: usize) -> &[u64] {
  32. self.ntt_tables[i][0].as_slice()
  33. }
  34. pub fn get_ntt_forward_prime_table(&self, i: usize) -> &[u64] {
  35. self.ntt_tables[i][1].as_slice()
  36. }
  37. pub fn get_ntt_inverse_table(&self, i: usize) -> &[u64] {
  38. self.ntt_tables[i][2].as_slice()
  39. }
  40. pub fn get_ntt_inverse_prime_table(&self, i: usize) -> &[u64] {
  41. self.ntt_tables[i][3].as_slice()
  42. }
  43. pub fn get_sk_gsw(&self) -> (usize, usize) {
  44. (self.n, 1)
  45. }
  46. pub fn get_sk_reg(&self) -> (usize, usize) {
  47. (1, 1)
  48. }
  49. pub fn m_conv(&self) -> usize {
  50. self.t_conv
  51. }
  52. pub fn crt_compose_1(&self, x: u64) -> u64 {
  53. assert_eq!(self.crt_count, 1);
  54. x
  55. }
  56. pub fn crt_compose_2(&self, x: u64, y: u64) -> u64 {
  57. assert_eq!(self.crt_count, 2);
  58. let a = self.moduli[0];
  59. let b = self.moduli[1];
  60. let a_inv_mod_b = invert_uint_mod(a, b).unwrap();
  61. let b_inv_mod_a = invert_uint_mod(b, a).unwrap();
  62. let mut val = (x as u128) * (b_inv_mod_a as u128) * (b as u128);
  63. val += (y as u128) * (a_inv_mod_b as u128) * (a as u128);
  64. (val % (self.modulus as u128)) as u64 // FIXME: use barrett
  65. }
  66. pub fn crt_compose(&self, a: &[u64], idx: usize) -> u64 {
  67. if self.crt_count == 1 {
  68. self.crt_compose_1(a[idx])
  69. } else {
  70. self.crt_compose_2(a[idx], a[idx + self.poly_len])
  71. }
  72. }
  73. pub fn init(
  74. poly_len: usize,
  75. moduli: &Vec<u64>,
  76. noise_width: f64,
  77. n: usize,
  78. pt_modulus: u64,
  79. q2_bits: u64,
  80. t_conv: usize,
  81. t_exp_left: usize,
  82. t_exp_right: usize,
  83. t_gsw: usize,
  84. expand_queries: bool,
  85. db_dim_1: usize,
  86. db_dim_2: usize,
  87. instances: usize,
  88. db_item_size: usize,
  89. ) -> Self {
  90. let poly_len_log2 = log2(poly_len as u64) as usize;
  91. let crt_count = moduli.len();
  92. let ntt_tables = build_ntt_tables(poly_len, moduli.as_slice());
  93. let scratch = vec![0u64; crt_count * poly_len];
  94. let mut modulus = 1;
  95. for m in moduli {
  96. modulus *= m;
  97. }
  98. let modulus_log2 = log2_ceil(modulus);
  99. Self {
  100. poly_len,
  101. poly_len_log2,
  102. ntt_tables,
  103. scratch,
  104. crt_count,
  105. moduli: moduli.clone(),
  106. modulus,
  107. modulus_log2,
  108. noise_width,
  109. n,
  110. pt_modulus,
  111. q2_bits,
  112. t_conv,
  113. t_exp_left,
  114. t_exp_right,
  115. t_gsw,
  116. expand_queries,
  117. db_dim_1,
  118. db_dim_2,
  119. instances,
  120. db_item_size,
  121. }
  122. }
  123. }