params.rs 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. use crate::{arith::*, ntt::*, number_theory::*, poly::*};
  2. pub const MAX_MODULI: usize = 4;
  3. pub static Q2_VALUES: [u64; 37] = [
  4. 0,
  5. 0,
  6. 0,
  7. 0,
  8. 0,
  9. 0,
  10. 0,
  11. 0,
  12. 0,
  13. 0,
  14. 0,
  15. 0,
  16. 0,
  17. 0,
  18. 12289,
  19. 12289,
  20. 61441,
  21. 65537,
  22. 65537,
  23. 520193,
  24. 786433,
  25. 786433,
  26. 3604481,
  27. 7340033,
  28. 16515073,
  29. 33292289,
  30. 67043329,
  31. 132120577,
  32. 268369921,
  33. 469762049,
  34. 1073479681,
  35. 2013265921,
  36. 4293918721,
  37. 8588886017,
  38. 17175674881,
  39. 34359214081,
  40. 68718428161,
  41. ];
  42. #[derive(Debug, PartialEq)]
  43. pub struct Params {
  44. pub poly_len: usize,
  45. pub poly_len_log2: usize,
  46. pub ntt_tables: Vec<Vec<Vec<u64>>>,
  47. pub scratch: Vec<u64>,
  48. pub crt_count: usize,
  49. pub barrett_cr_0: [u64; MAX_MODULI],
  50. pub barrett_cr_1: [u64; MAX_MODULI],
  51. pub barrett_cr_0_modulus: u64,
  52. pub barrett_cr_1_modulus: u64,
  53. pub mod0_inv_mod1: u64,
  54. pub mod1_inv_mod0: u64,
  55. pub moduli: [u64; MAX_MODULI],
  56. pub modulus: u64,
  57. pub modulus_log2: u64,
  58. pub noise_width: f64,
  59. pub n: usize,
  60. pub pt_modulus: u64,
  61. pub q2_bits: u64,
  62. pub t_conv: usize,
  63. pub t_exp_left: usize,
  64. pub t_exp_right: usize,
  65. pub t_gsw: usize,
  66. pub expand_queries: bool,
  67. pub db_dim_1: usize,
  68. pub db_dim_2: usize,
  69. pub instances: usize,
  70. pub db_item_size: usize,
  71. }
  72. impl Params {
  73. pub fn get_ntt_forward_table(&self, i: usize) -> &[u64] {
  74. self.ntt_tables[i][0].as_slice()
  75. }
  76. pub fn get_ntt_forward_prime_table(&self, i: usize) -> &[u64] {
  77. self.ntt_tables[i][1].as_slice()
  78. }
  79. pub fn get_ntt_inverse_table(&self, i: usize) -> &[u64] {
  80. self.ntt_tables[i][2].as_slice()
  81. }
  82. pub fn get_ntt_inverse_prime_table(&self, i: usize) -> &[u64] {
  83. self.ntt_tables[i][3].as_slice()
  84. }
  85. pub fn get_v_neg1(&self) -> Vec<PolyMatrixNTT> {
  86. let mut v_neg1 = Vec::new();
  87. for i in 0..self.poly_len_log2 {
  88. let idx = self.poly_len - (1 << i);
  89. let mut ng1 = PolyMatrixRaw::zero(&self, 1, 1);
  90. ng1.data[idx] = 1;
  91. v_neg1.push((-&ng1).ntt());
  92. }
  93. v_neg1
  94. }
  95. pub fn get_sk_gsw(&self) -> (usize, usize) {
  96. (self.n, 1)
  97. }
  98. pub fn get_sk_reg(&self) -> (usize, usize) {
  99. (1, 1)
  100. }
  101. pub fn m_conv(&self) -> usize {
  102. self.t_conv
  103. }
  104. pub fn crt_compose_1(&self, x: u64) -> u64 {
  105. assert_eq!(self.crt_count, 1);
  106. x
  107. }
  108. pub fn crt_compose_2(&self, x: u64, y: u64) -> u64 {
  109. assert_eq!(self.crt_count, 2);
  110. let mut val = (x as u128) * (self.mod1_inv_mod0 as u128);
  111. val += (y as u128) * (self.mod0_inv_mod1 as u128);
  112. barrett_reduction_u128(self, val)
  113. }
  114. pub fn crt_compose(&self, a: &[u64], idx: usize) -> u64 {
  115. if self.crt_count == 1 {
  116. self.crt_compose_1(a[idx])
  117. } else {
  118. self.crt_compose_2(a[idx], a[idx + self.poly_len])
  119. }
  120. }
  121. pub fn init(
  122. poly_len: usize,
  123. moduli: &[u64],
  124. noise_width: f64,
  125. n: usize,
  126. pt_modulus: u64,
  127. q2_bits: u64,
  128. t_conv: usize,
  129. t_exp_left: usize,
  130. t_exp_right: usize,
  131. t_gsw: usize,
  132. expand_queries: bool,
  133. db_dim_1: usize,
  134. db_dim_2: usize,
  135. instances: usize,
  136. db_item_size: usize,
  137. ) -> Self {
  138. let poly_len_log2 = log2(poly_len as u64) as usize;
  139. let crt_count = moduli.len();
  140. assert!(crt_count <= MAX_MODULI);
  141. let mut moduli_array = [0; MAX_MODULI];
  142. for i in 0..crt_count {
  143. moduli_array[i] = moduli[i];
  144. }
  145. let ntt_tables = build_ntt_tables(poly_len, moduli);
  146. let scratch = vec![0u64; crt_count * poly_len];
  147. let mut modulus = 1;
  148. for m in moduli {
  149. modulus *= m;
  150. }
  151. let modulus_log2 = log2_ceil(modulus);
  152. let (barrett_cr_0, barrett_cr_1) = get_barrett(moduli);
  153. let (barrett_cr_0_modulus, barrett_cr_1_modulus) = get_barrett_crs(modulus);
  154. let mut mod0_inv_mod1 = 0;
  155. let mut mod1_inv_mod0 = 0;
  156. if crt_count == 2 {
  157. mod0_inv_mod1 = moduli[0] * invert_uint_mod(moduli[0], moduli[1]).unwrap();
  158. mod1_inv_mod0 = moduli[1] * invert_uint_mod(moduli[1], moduli[0]).unwrap();
  159. }
  160. Self {
  161. poly_len,
  162. poly_len_log2,
  163. ntt_tables,
  164. scratch,
  165. crt_count,
  166. barrett_cr_0,
  167. barrett_cr_1,
  168. barrett_cr_0_modulus,
  169. barrett_cr_1_modulus,
  170. mod0_inv_mod1,
  171. mod1_inv_mod0,
  172. moduli: moduli_array,
  173. modulus,
  174. modulus_log2,
  175. noise_width,
  176. n,
  177. pt_modulus,
  178. q2_bits,
  179. t_conv,
  180. t_exp_left,
  181. t_exp_right,
  182. t_gsw,
  183. expand_queries,
  184. db_dim_1,
  185. db_dim_2,
  186. instances,
  187. db_item_size,
  188. }
  189. }
  190. }