params.rs 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. use crate::{arith::*, ntt::*, number_theory::*, poly::*};
  2. pub static Q2_VALUES: [u64; 37] = [
  3. 0,
  4. 0,
  5. 0,
  6. 0,
  7. 0,
  8. 0,
  9. 0,
  10. 0,
  11. 0,
  12. 0,
  13. 0,
  14. 0,
  15. 0,
  16. 0,
  17. 12289,
  18. 12289,
  19. 61441,
  20. 65537,
  21. 65537,
  22. 520193,
  23. 786433,
  24. 786433,
  25. 3604481,
  26. 7340033,
  27. 16515073,
  28. 33292289,
  29. 67043329,
  30. 132120577,
  31. 268369921,
  32. 469762049,
  33. 1073479681,
  34. 2013265921,
  35. 4293918721,
  36. 8588886017,
  37. 17175674881,
  38. 34359214081,
  39. 68718428161,
  40. ];
  41. #[derive(Debug, PartialEq)]
  42. pub struct Params {
  43. pub poly_len: usize,
  44. pub poly_len_log2: usize,
  45. pub ntt_tables: Vec<Vec<Vec<u64>>>,
  46. pub scratch: Vec<u64>,
  47. pub crt_count: usize,
  48. pub moduli: Vec<u64>,
  49. pub modulus: u64,
  50. pub modulus_log2: u64,
  51. pub noise_width: f64,
  52. pub n: usize,
  53. pub pt_modulus: u64,
  54. pub q2_bits: u64,
  55. pub t_conv: usize,
  56. pub t_exp_left: usize,
  57. pub t_exp_right: usize,
  58. pub t_gsw: usize,
  59. pub expand_queries: bool,
  60. pub db_dim_1: usize,
  61. pub db_dim_2: usize,
  62. pub instances: usize,
  63. pub db_item_size: usize,
  64. }
  65. impl Params {
  66. pub fn get_ntt_forward_table(&self, i: usize) -> &[u64] {
  67. self.ntt_tables[i][0].as_slice()
  68. }
  69. pub fn get_ntt_forward_prime_table(&self, i: usize) -> &[u64] {
  70. self.ntt_tables[i][1].as_slice()
  71. }
  72. pub fn get_ntt_inverse_table(&self, i: usize) -> &[u64] {
  73. self.ntt_tables[i][2].as_slice()
  74. }
  75. pub fn get_ntt_inverse_prime_table(&self, i: usize) -> &[u64] {
  76. self.ntt_tables[i][3].as_slice()
  77. }
  78. pub fn get_v_neg1(&self) -> Vec<PolyMatrixNTT> {
  79. let mut v_neg1 = Vec::new();
  80. for i in 0..self.poly_len_log2 {
  81. let idx = self.poly_len - (1 << i);
  82. let mut ng1 = PolyMatrixRaw::zero(&self, 1, 1);
  83. ng1.data[idx] = 1;
  84. v_neg1.push((-&ng1).ntt());
  85. }
  86. v_neg1
  87. }
  88. pub fn get_sk_gsw(&self) -> (usize, usize) {
  89. (self.n, 1)
  90. }
  91. pub fn get_sk_reg(&self) -> (usize, usize) {
  92. (1, 1)
  93. }
  94. pub fn m_conv(&self) -> usize {
  95. self.t_conv
  96. }
  97. pub fn crt_compose_1(&self, x: u64) -> u64 {
  98. assert_eq!(self.crt_count, 1);
  99. x
  100. }
  101. pub fn crt_compose_2(&self, x: u64, y: u64) -> u64 {
  102. assert_eq!(self.crt_count, 2);
  103. let a = self.moduli[0];
  104. let b = self.moduli[1];
  105. let a_inv_mod_b = invert_uint_mod(a, b).unwrap();
  106. let b_inv_mod_a = invert_uint_mod(b, a).unwrap();
  107. let mut val = (x as u128) * (b_inv_mod_a as u128) * (b as u128);
  108. val += (y as u128) * (a_inv_mod_b as u128) * (a as u128);
  109. (val % (self.modulus as u128)) as u64 // FIXME: use barrett
  110. }
  111. pub fn crt_compose(&self, a: &[u64], idx: usize) -> u64 {
  112. if self.crt_count == 1 {
  113. self.crt_compose_1(a[idx])
  114. } else {
  115. self.crt_compose_2(a[idx], a[idx + self.poly_len])
  116. }
  117. }
  118. pub fn init(
  119. poly_len: usize,
  120. moduli: &Vec<u64>,
  121. noise_width: f64,
  122. n: usize,
  123. pt_modulus: u64,
  124. q2_bits: u64,
  125. t_conv: usize,
  126. t_exp_left: usize,
  127. t_exp_right: usize,
  128. t_gsw: usize,
  129. expand_queries: bool,
  130. db_dim_1: usize,
  131. db_dim_2: usize,
  132. instances: usize,
  133. db_item_size: usize,
  134. ) -> Self {
  135. let poly_len_log2 = log2(poly_len as u64) as usize;
  136. let crt_count = moduli.len();
  137. let ntt_tables = build_ntt_tables(poly_len, moduli.as_slice());
  138. let scratch = vec![0u64; crt_count * poly_len];
  139. let mut modulus = 1;
  140. for m in moduli {
  141. modulus *= m;
  142. }
  143. let modulus_log2 = log2_ceil(modulus);
  144. Self {
  145. poly_len,
  146. poly_len_log2,
  147. ntt_tables,
  148. scratch,
  149. crt_count,
  150. moduli: moduli.clone(),
  151. modulus,
  152. modulus_log2,
  153. noise_width,
  154. n,
  155. pt_modulus,
  156. q2_bits,
  157. t_conv,
  158. t_exp_left,
  159. t_exp_right,
  160. t_gsw,
  161. expand_queries,
  162. db_dim_1,
  163. db_dim_2,
  164. instances,
  165. db_item_size,
  166. }
  167. }
  168. }