// params.rs (7.3 KB)
  1. use std::mem::size_of;
  2. use crate::{arith::*, ntt::*, number_theory::*, poly::*};
  3. pub const MAX_MODULI: usize = 4;
  4. pub static MIN_Q2_BITS: u64 = 14;
  5. pub static Q2_VALUES: [u64; 37] = [
  6. 0,
  7. 0,
  8. 0,
  9. 0,
  10. 0,
  11. 0,
  12. 0,
  13. 0,
  14. 0,
  15. 0,
  16. 0,
  17. 0,
  18. 0,
  19. 0,
  20. 12289,
  21. 12289,
  22. 61441,
  23. 65537,
  24. 65537,
  25. 520193,
  26. 786433,
  27. 786433,
  28. 3604481,
  29. 7340033,
  30. 16515073,
  31. 33292289,
  32. 67043329,
  33. 132120577,
  34. 268369921,
  35. 469762049,
  36. 1073479681,
  37. 2013265921,
  38. 4293918721,
  39. 8588886017,
  40. 17175674881,
  41. 34359214081,
  42. 68718428161,
  43. ];
  44. #[derive(Debug, PartialEq, Clone)]
  45. pub struct Params {
  46. pub poly_len: usize,
  47. pub poly_len_log2: usize,
  48. pub ntt_tables: Vec<Vec<Vec<u64>>>,
  49. pub scratch: Vec<u64>,
  50. pub crt_count: usize,
  51. pub barrett_cr_0: [u64; MAX_MODULI],
  52. pub barrett_cr_1: [u64; MAX_MODULI],
  53. pub barrett_cr_0_modulus: u64,
  54. pub barrett_cr_1_modulus: u64,
  55. pub mod0_inv_mod1: u64,
  56. pub mod1_inv_mod0: u64,
  57. pub moduli: [u64; MAX_MODULI],
  58. pub modulus: u64,
  59. pub modulus_log2: u64,
  60. pub noise_width: f64,
  61. pub n: usize,
  62. pub pt_modulus: u64,
  63. pub q2_bits: u64,
  64. pub t_conv: usize,
  65. pub t_exp_left: usize,
  66. pub t_exp_right: usize,
  67. pub t_gsw: usize,
  68. pub expand_queries: bool,
  69. pub db_dim_1: usize,
  70. pub db_dim_2: usize,
  71. pub instances: usize,
  72. pub db_item_size: usize,
  73. }
  74. impl Params {
  75. pub fn get_ntt_forward_table(&self, i: usize) -> &[u64] {
  76. self.ntt_tables[i][0].as_slice()
  77. }
  78. pub fn get_ntt_forward_prime_table(&self, i: usize) -> &[u64] {
  79. self.ntt_tables[i][1].as_slice()
  80. }
  81. pub fn get_ntt_inverse_table(&self, i: usize) -> &[u64] {
  82. self.ntt_tables[i][2].as_slice()
  83. }
  84. pub fn get_ntt_inverse_prime_table(&self, i: usize) -> &[u64] {
  85. self.ntt_tables[i][3].as_slice()
  86. }
  87. pub fn get_v_neg1(&self) -> Vec<PolyMatrixNTT> {
  88. let mut v_neg1 = Vec::new();
  89. for i in 0..self.poly_len_log2 {
  90. let idx = self.poly_len - (1 << i);
  91. let mut ng1 = PolyMatrixRaw::zero(&self, 1, 1);
  92. ng1.data[idx] = 1;
  93. v_neg1.push((-&ng1).ntt());
  94. }
  95. v_neg1
  96. }
  97. pub fn get_sk_gsw(&self) -> (usize, usize) {
  98. (self.n, 1)
  99. }
  100. pub fn get_sk_reg(&self) -> (usize, usize) {
  101. (1, 1)
  102. }
  103. pub fn num_expanded(&self) -> usize {
  104. 1 << self.db_dim_1
  105. }
  106. pub fn num_items(&self) -> usize {
  107. (1 << self.db_dim_1) * (1 << self.db_dim_2)
  108. }
  109. pub fn item_size(&self) -> usize {
  110. let logp = log2(self.pt_modulus) as usize;
  111. self.instances * self.n * self.n * self.poly_len * logp / 8
  112. }
  113. pub fn g(&self) -> usize {
  114. let num_bits_to_gen = self.t_gsw * self.db_dim_2 + self.num_expanded();
  115. log2_ceil_usize(num_bits_to_gen)
  116. }
  117. pub fn stop_round(&self) -> usize {
  118. log2_ceil_usize(self.t_gsw * self.db_dim_2)
  119. }
  120. pub fn factor_on_first_dim(&self) -> usize {
  121. if self.db_dim_2 == 0 {
  122. 1
  123. } else {
  124. 2
  125. }
  126. }
  127. pub fn setup_bytes(&self) -> usize {
  128. let mut sz_polys = 0;
  129. let packing_sz = ((self.n + 1) - 1) * self.t_conv;
  130. sz_polys += self.n * packing_sz;
  131. if self.expand_queries {
  132. let expansion_left_sz = self.g() * self.t_exp_left;
  133. let expansion_right_sz = (self.stop_round() + 1) * self.t_exp_right;
  134. let conversion_sz = 2 * self.t_conv;
  135. sz_polys += expansion_left_sz + expansion_right_sz + conversion_sz;
  136. }
  137. let sz_bytes = sz_polys * self.poly_len * size_of::<u64>();
  138. sz_bytes
  139. }
  140. pub fn query_bytes(&self) -> usize {
  141. let sz_polys;
  142. if self.expand_queries {
  143. sz_polys = 1;
  144. } else {
  145. let first_dimension_sz = self.num_expanded();
  146. let further_dimension_sz = self.db_dim_2 * (2 * self.t_gsw);
  147. sz_polys = first_dimension_sz + further_dimension_sz;
  148. }
  149. let sz_bytes = sz_polys * self.poly_len * size_of::<u64>();
  150. sz_bytes
  151. }
  152. pub fn query_v_buf_bytes(&self) -> usize {
  153. self.num_expanded() * self.poly_len * size_of::<u64>()
  154. }
  155. pub fn bytes_per_chunk(&self) -> usize {
  156. let trials = self.n * self.n;
  157. let chunks = self.instances * trials;
  158. let bytes_per_chunk = f64::ceil(self.db_item_size as f64 / chunks as f64) as usize;
  159. bytes_per_chunk
  160. }
  161. pub fn modp_words_per_chunk(&self) -> usize {
  162. let bytes_per_chunk = self.bytes_per_chunk();
  163. let logp = log2(self.pt_modulus);
  164. let modp_words_per_chunk = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
  165. modp_words_per_chunk
  166. }
  167. pub fn crt_compose_1(&self, x: u64) -> u64 {
  168. assert_eq!(self.crt_count, 1);
  169. x
  170. }
  171. pub fn crt_compose_2(&self, x: u64, y: u64) -> u64 {
  172. assert_eq!(self.crt_count, 2);
  173. let mut val = (x as u128) * (self.mod1_inv_mod0 as u128);
  174. val += (y as u128) * (self.mod0_inv_mod1 as u128);
  175. barrett_reduction_u128(self, val)
  176. }
  177. pub fn crt_compose(&self, a: &[u64], idx: usize) -> u64 {
  178. if self.crt_count == 1 {
  179. self.crt_compose_1(a[idx])
  180. } else {
  181. self.crt_compose_2(a[idx], a[idx + self.poly_len])
  182. }
  183. }
  184. pub fn init(
  185. poly_len: usize,
  186. moduli: &[u64],
  187. noise_width: f64,
  188. n: usize,
  189. pt_modulus: u64,
  190. q2_bits: u64,
  191. t_conv: usize,
  192. t_exp_left: usize,
  193. t_exp_right: usize,
  194. t_gsw: usize,
  195. expand_queries: bool,
  196. db_dim_1: usize,
  197. db_dim_2: usize,
  198. instances: usize,
  199. db_item_size: usize,
  200. ) -> Self {
  201. assert!(q2_bits >= MIN_Q2_BITS);
  202. let poly_len_log2 = log2(poly_len as u64) as usize;
  203. let crt_count = moduli.len();
  204. assert!(crt_count <= MAX_MODULI);
  205. let mut moduli_array = [0; MAX_MODULI];
  206. for i in 0..crt_count {
  207. moduli_array[i] = moduli[i];
  208. }
  209. let ntt_tables = build_ntt_tables(poly_len, moduli);
  210. let scratch = vec![0u64; crt_count * poly_len];
  211. let mut modulus = 1;
  212. for m in moduli {
  213. modulus *= m;
  214. }
  215. let modulus_log2 = log2_ceil(modulus);
  216. let (barrett_cr_0, barrett_cr_1) = get_barrett(moduli);
  217. let (barrett_cr_0_modulus, barrett_cr_1_modulus) = get_barrett_crs(modulus);
  218. let mut mod0_inv_mod1 = 0;
  219. let mut mod1_inv_mod0 = 0;
  220. if crt_count == 2 {
  221. mod0_inv_mod1 = moduli[0] * invert_uint_mod(moduli[0], moduli[1]).unwrap();
  222. mod1_inv_mod0 = moduli[1] * invert_uint_mod(moduli[1], moduli[0]).unwrap();
  223. }
  224. Self {
  225. poly_len,
  226. poly_len_log2,
  227. ntt_tables,
  228. scratch,
  229. crt_count,
  230. barrett_cr_0,
  231. barrett_cr_1,
  232. barrett_cr_0_modulus,
  233. barrett_cr_1_modulus,
  234. mod0_inv_mod1,
  235. mod1_inv_mod0,
  236. moduli: moduli_array,
  237. modulus,
  238. modulus_log2,
  239. noise_width,
  240. n,
  241. pt_modulus,
  242. q2_bits,
  243. t_conv,
  244. t_exp_left,
  245. t_exp_right,
  246. t_gsw,
  247. expand_queries,
  248. db_dim_1,
  249. db_dim_2,
  250. instances,
  251. db_item_size,
  252. }
  253. }
  254. }