server.rs 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806
  1. #[cfg(target_feature = "avx2")]
  2. use std::arch::x86_64::*;
  3. #[cfg(target_feature = "avx2")]
  4. use crate::aligned_memory::*;
  5. use crate::arith::*;
  6. use crate::aligned_memory::*;
  7. use crate::client::PublicParameters;
  8. use crate::client::Query;
  9. use crate::gadget::*;
  10. use crate::params::*;
  11. use crate::poly::*;
  12. use crate::util::*;
/// Coefficient expansion ("query expansion"): starting from the single packed
/// ciphertext in `v[0]`, each round `r` doubles the number of populated slots
/// by applying the automorphism x -> x^t (t = poly_len / 2^r + 1) and
/// key-switching with that round's expansion keys, so that after `g` rounds
/// each requested slot of `v` encrypts one coefficient of the query polynomial.
///
/// * `v` — in/out ciphertext slots; length must be at least 2^g
/// * `g` — number of doubling rounds
/// * `stop_round` — round at which odd-indexed ("right"/GSW) outputs are
///   capped; 0 disables the cap
/// * `v_w_left`, `v_w_right` — per-round key-switch matrices for even/odd
///   outputs respectively
/// * `v_neg1` — per-round NTT'd monomial matrices used to shift a slot into
///   its sibling slot
/// * `max_bits_to_gen_right` — number of odd-indexed outputs still needed at
///   `stop_round`
pub fn coefficient_expansion(
    v: &mut Vec<PolyMatrixNTT>,
    g: usize,
    stop_round: usize,
    params: &Params,
    v_w_left: &Vec<PolyMatrixNTT>,
    v_w_right: &Vec<PolyMatrixNTT>,
    v_neg1: &Vec<PolyMatrixNTT>,
    max_bits_to_gen_right: usize,
) {
    let poly_len = params.poly_len;

    // Scratch buffers reused across all rounds to avoid reallocation.
    let mut ct = PolyMatrixRaw::zero(params, 2, 1);
    let mut ct_auto = PolyMatrixRaw::zero(params, 2, 1);
    let mut ct_auto_1 = PolyMatrixRaw::zero(params, 1, 1);
    let mut ct_auto_1_ntt = PolyMatrixNTT::zero(params, 1, 1);
    let mut ginv_ct_left = PolyMatrixRaw::zero(params, params.t_exp_left, 1);
    let mut ginv_ct_left_ntt = PolyMatrixNTT::zero(params, params.t_exp_left, 1);
    let mut ginv_ct_right = PolyMatrixRaw::zero(params, params.t_exp_right, 1);
    let mut ginv_ct_right_ntt = PolyMatrixNTT::zero(params, params.t_exp_right, 1);
    let mut w_times_ginv_ct = PolyMatrixNTT::zero(params, 2, 1);

    for r in 0..g {
        let num_in = 1 << r;
        let num_out = 2 * num_in;

        // Automorphism power for this round.
        let t = (poly_len / (1 << r)) + 1;

        let neg1 = &v_neg1[r];

        for i in 0..num_out {
            // Skip right-expansion outputs that are no longer needed.
            // NOTE(review): `&&` binds tighter than `||`, so this parses as
            // `(stop_round > 0 && i % 2 == 1 && r > stop_round) || (r == stop_round && ...)`;
            // confirm the second clause is intended to apply to even-indexed
            // outputs and to the stop_round == 0 case as well.
            if stop_round > 0 && i % 2 == 1 && r > stop_round
                || (r == stop_round && i / 2 >= max_bits_to_gen_right)
            {
                continue;
            }

            // Even outputs use the "left" expansion key; odd outputs the "right".
            let (w, _gadget_dim, gi_ct, gi_ct_ntt) = match i % 2 {
                0 => (
                    &v_w_left[r],
                    params.t_exp_left,
                    &mut ginv_ct_left,
                    &mut ginv_ct_left_ntt,
                ),
                1 | _ => (
                    &v_w_right[r],
                    params.t_exp_right,
                    &mut ginv_ct_right,
                    &mut ginv_ct_right_ntt,
                ),
            };

            // The upper half of the outputs is derived from the lower half
            // multiplied by this round's monomial.
            if i < num_in {
                let (src, dest) = v.split_at_mut(num_in);
                scalar_multiply(&mut dest[i], neg1, &src[i]);
            }

            // v[i] <- v[i] + W * G^{-1}(automorph(v[i])) (+ second-row term below).
            from_ntt(&mut ct, &v[i]);
            automorph(&mut ct_auto, &ct, t);
            gadget_invert_rdim(gi_ct, &ct_auto, 1);
            to_ntt_no_reduce(gi_ct_ntt, &gi_ct);
            ct_auto_1
                .data
                .as_mut_slice()
                .copy_from_slice(ct_auto.get_poly(1, 0));
            to_ntt(&mut ct_auto_1_ntt, &ct_auto_1);
            multiply(&mut w_times_ginv_ct, w, &gi_ct_ntt);

            // Accumulate coefficient-wise with Barrett reduction; the factor
            // `j` adds the automorphed second row only into row 1 (j == 1).
            let mut idx = 0;
            for j in 0..2 {
                for n in 0..params.crt_count {
                    for z in 0..poly_len {
                        let sum = v[i].data[idx]
                            + w_times_ginv_ct.data[idx]
                            + j * ct_auto_1_ntt.data[n * poly_len + z];
                        v[i].data[idx] = barrett_coeff_u64(params, sum, n);
                        idx += 1;
                    }
                }
            }
        }
    }
}
  87. pub fn regev_to_gsw<'a>(
  88. v_gsw: &mut Vec<PolyMatrixNTT<'a>>,
  89. v_inp: &Vec<PolyMatrixNTT<'a>>,
  90. v: &PolyMatrixNTT<'a>,
  91. params: &'a Params,
  92. idx_factor: usize,
  93. idx_offset: usize,
  94. ) {
  95. assert!(v.rows == 2);
  96. assert!(v.cols == 2 * params.t_conv);
  97. let mut ginv_c_inp = PolyMatrixRaw::zero(params, 2 * params.t_conv, 1);
  98. let mut ginv_c_inp_ntt = PolyMatrixNTT::zero(params, 2 * params.t_conv, 1);
  99. let mut tmp_ct_raw = PolyMatrixRaw::zero(params, 2, 1);
  100. let mut tmp_ct = PolyMatrixNTT::zero(params, 2, 1);
  101. for i in 0..params.db_dim_2 {
  102. let ct = &mut v_gsw[i];
  103. for j in 0..params.t_gsw {
  104. let idx_ct = i * params.t_gsw + j;
  105. let idx_inp = idx_factor * (idx_ct) + idx_offset;
  106. ct.copy_into(&v_inp[idx_inp], 0, 2 * j + 1);
  107. from_ntt(&mut tmp_ct_raw, &v_inp[idx_inp]);
  108. gadget_invert(&mut ginv_c_inp, &tmp_ct_raw);
  109. to_ntt(&mut ginv_c_inp_ntt, &ginv_c_inp);
  110. multiply(&mut tmp_ct, v, &ginv_c_inp_ntt);
  111. ct.copy_into(&tmp_ct, 0, 2 * j);
  112. }
  113. }
  114. }
/// Maximum number of packed products accumulated in the AVX2 database multiply
/// before a Barrett reduction is performed (see the "reduce here, otherwise we
/// will overflow" step in `multiply_reg_by_database`).
pub const MAX_SUMMED: usize = 1 << 6;
/// Bit position splitting the two packed residues inside each database u64
/// (the second NTT residue is stored shifted left by this amount — see
/// `generate_random_db_and_get_item`).
pub const PACKED_OFFSET_2: i32 = 32;
/// AVX2 kernel for the first-dimension multiply: for each of `num_per`
/// second-dimension indices, accumulates the inner product of the reoriented
/// query ciphertexts `v_firstdim` with the corresponding database words into
/// `out[i]`, per polynomial coefficient `z` and per CRT component.
///
/// Each u64 of `db` and `v_firstdim` packs two residues split at bit
/// `PACKED_OFFSET_2`; the low halves and the shifted high halves are
/// multiplied separately with `_mm256_mul_epu32` and reduced modulo CRT
/// components 0 and 1 respectively.
///
/// NOTE(review): `out` is indexed up to `num_per`, and `v_firstdim` must be
/// 32-byte aligned for `_mm256_load_si256` — confirm callers guarantee both
/// (the aligned `AlignedMemory64` buffers used in this file satisfy the latter).
#[cfg(target_feature = "avx2")]
pub fn multiply_reg_by_database(
    out: &mut Vec<PolyMatrixNTT>,
    db: &[u64],
    v_firstdim: &[u64],
    params: &Params,
    dim0: usize,
    num_per: usize,
) {
    let ct_rows = 2;
    let ct_cols = 1;
    let pt_rows = 1;
    let pt_cols = 1;

    // The inner accumulation loop runs MAX_SUMMED lanes per reduction; the
    // row dimension must be large enough to fill at least one batch.
    assert!(dim0 * ct_rows >= MAX_SUMMED);

    // 32-byte-aligned staging buffers for spilling the SIMD accumulators.
    let mut sums_out_n0_u64 = AlignedMemory64::new(4);
    let mut sums_out_n2_u64 = AlignedMemory64::new(4);

    for z in 0..params.poly_len {
        let idx_a_base = z * (ct_cols * dim0 * ct_rows);
        let mut idx_b_base = z * (num_per * pt_cols * dim0 * pt_rows);

        for i in 0..num_per {
            for c in 0..pt_cols {
                let inner_limit = MAX_SUMMED;
                let outer_limit = dim0 * ct_rows / inner_limit;

                // Running per-lane sums, kept reduced between batches.
                let mut sums_out_n0_u64_acc = [0u64, 0, 0, 0];
                let mut sums_out_n2_u64_acc = [0u64, 0, 0, 0];

                for o_jm in 0..outer_limit {
                    unsafe {
                        let mut sums_out_n0 = _mm256_setzero_si256();
                        let mut sums_out_n2 = _mm256_setzero_si256();

                        // Process 4 query words per iteration; each db word is
                        // broadcast across two consecutive query words.
                        for i_jm in 0..inner_limit / 4 {
                            let jm = o_jm * inner_limit + (4 * i_jm);

                            let b_inp_1 = *db.get_unchecked(idx_b_base) as i64;
                            idx_b_base += 1;
                            let b_inp_2 = *db.get_unchecked(idx_b_base) as i64;
                            idx_b_base += 1;
                            let b = _mm256_set_epi64x(b_inp_2, b_inp_2, b_inp_1, b_inp_1);

                            let v_a = v_firstdim.get_unchecked(idx_a_base + jm) as *const u64;
                            let a = _mm256_load_si256(v_a as *const __m256i);

                            // Split the packed residues: low 32 bits vs high.
                            let a_lo = a;
                            let a_hi_hi = _mm256_srli_epi64(a, PACKED_OFFSET_2);
                            let b_lo = b;
                            let b_hi_hi = _mm256_srli_epi64(b, PACKED_OFFSET_2);

                            sums_out_n0 =
                                _mm256_add_epi64(sums_out_n0, _mm256_mul_epu32(a_lo, b_lo));
                            sums_out_n2 =
                                _mm256_add_epi64(sums_out_n2, _mm256_mul_epu32(a_hi_hi, b_hi_hi));
                        }

                        // reduce here, otherwise we will overflow
                        _mm256_store_si256(
                            sums_out_n0_u64.as_mut_ptr() as *mut __m256i,
                            sums_out_n0,
                        );
                        _mm256_store_si256(
                            sums_out_n2_u64.as_mut_ptr() as *mut __m256i,
                            sums_out_n2,
                        );

                        // Fold this batch into the running sums, reducing each
                        // lane modulo the matching CRT component.
                        for idx in 0..4 {
                            let val = sums_out_n0_u64[idx];
                            sums_out_n0_u64_acc[idx] =
                                barrett_coeff_u64(params, val + sums_out_n0_u64_acc[idx], 0);
                        }
                        for idx in 0..4 {
                            let val = sums_out_n2_u64[idx];
                            sums_out_n2_u64_acc[idx] =
                                barrett_coeff_u64(params, val + sums_out_n2_u64_acc[idx], 1);
                        }
                    }
                }

                // Final reduction of the per-lane accumulators.
                for idx in 0..4 {
                    sums_out_n0_u64_acc[idx] =
                        barrett_coeff_u64(params, sums_out_n0_u64_acc[idx], 0);
                    sums_out_n2_u64_acc[idx] =
                        barrett_coeff_u64(params, sums_out_n2_u64_acc[idx], 1);
                }

                // output n0
                let (crt_count, poly_len) = (params.crt_count, params.poly_len);
                let mut n = 0;
                let mut idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
                out[i].data[idx_c] =
                    barrett_coeff_u64(params, sums_out_n0_u64_acc[0] + sums_out_n0_u64_acc[2], 0);
                idx_c += pt_cols * crt_count * poly_len;
                out[i].data[idx_c] =
                    barrett_coeff_u64(params, sums_out_n0_u64_acc[1] + sums_out_n0_u64_acc[3], 0);

                // output n1
                n = 1;
                idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
                out[i].data[idx_c] =
                    barrett_coeff_u64(params, sums_out_n2_u64_acc[0] + sums_out_n2_u64_acc[2], 1);
                idx_c += pt_cols * crt_count * poly_len;
                out[i].data[idx_c] =
                    barrett_coeff_u64(params, sums_out_n2_u64_acc[1] + sums_out_n2_u64_acc[3], 1);
            }
        }
    }
}
/// Builds a random database in the packed layout consumed by
/// `multiply_reg_by_database`, and also returns the plaintext item stored at
/// `item_idx` as an n x n matrix of polynomials.
///
/// Each output u64 packs two NTT residues of one coefficient, the second
/// shifted left by `PACKED_OFFSET_2`. Words are indexed via `calc_index` over
/// the dimensions [trial, coefficient, second-dim index, first-dim index].
pub fn generate_random_db_and_get_item<'a>(
    params: &'a Params,
    item_idx: usize,
) -> (PolyMatrixRaw<'a>, Vec<u64>) {
    let mut rng = get_seeded_rng();

    let trials = params.n * params.n;
    let dim0 = 1 << params.db_dim_1;
    let num_per = 1 << params.db_dim_2;
    let num_items = dim0 * num_per;
    let db_size_words = trials * num_items * params.poly_len;
    let mut v = vec![0u64; db_size_words];

    let mut item = PolyMatrixRaw::zero(params, params.n, params.n);

    for trial in 0..trials {
        for i in 0..num_items {
            // Decompose the flat item index into (second-dim, first-dim) parts.
            let ii = i % num_per;
            let j = i / num_per;

            // Fresh random plaintext polynomial, reduced into the pt modulus.
            let mut db_item = PolyMatrixRaw::random_rng(params, 1, 1, &mut rng);
            db_item.reduce_mod(params.pt_modulus);

            // Remember the target item's plaintext (before recentering).
            if i == item_idx {
                item.copy_into(&db_item, trial / params.n, trial % params.n);
            }

            // Recenter coefficients from the pt modulus into the full modulus.
            for z in 0..params.poly_len {
                db_item.data[z] =
                    recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
            }

            let db_item_ntt = db_item.ntt();
            for z in 0..params.poly_len {
                let idx_dst = calc_index(
                    &[trial, z, ii, j],
                    &[trials, params.poly_len, num_per, dim0],
                );
                // Pack the coefficient's two NTT residues into one word.
                v[idx_dst] = db_item_ntt.data[z]
                    | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
            }
        }
    }
    (item, v)
}
  245. pub fn fold_ciphertexts(
  246. params: &Params,
  247. v_cts: &mut Vec<PolyMatrixRaw>,
  248. v_folding: &Vec<PolyMatrixNTT>,
  249. v_folding_neg: &Vec<PolyMatrixNTT>
  250. ) {
  251. let further_dims = log2(v_cts.len() as u64) as usize;
  252. let ell = v_folding[0].cols / 2;
  253. let mut ginv_c = PolyMatrixRaw::zero(&params, 2 * ell, 1);
  254. let mut ginv_c_ntt = PolyMatrixNTT::zero(&params, 2 * ell, 1);
  255. let mut prod = PolyMatrixNTT::zero(&params, 2, 1);
  256. let mut sum = PolyMatrixNTT::zero(&params, 2, 1);
  257. let mut num_per = v_cts.len();
  258. for cur_dim in 0..further_dims {
  259. num_per = num_per / 2;
  260. for i in 0..num_per {
  261. gadget_invert(&mut ginv_c, &v_cts[i]);
  262. to_ntt(&mut ginv_c_ntt, &ginv_c);
  263. multiply(&mut prod, &v_folding_neg[further_dims - 1 - cur_dim], &ginv_c_ntt);
  264. gadget_invert(&mut ginv_c, &v_cts[num_per + i]);
  265. to_ntt(&mut ginv_c_ntt, &ginv_c);
  266. multiply(&mut sum, &v_folding[further_dims - 1 - cur_dim], &ginv_c_ntt);
  267. add_into(&mut sum, &prod);
  268. from_ntt(&mut v_cts[i], &sum);
  269. }
  270. }
  271. }
  272. pub fn pack<'a>(
  273. params: &'a Params,
  274. v_ct: &Vec<PolyMatrixRaw>,
  275. v_w: &Vec<PolyMatrixNTT>
  276. ) -> PolyMatrixNTT<'a> {
  277. assert!(v_ct.len() >= params.n * params.n);
  278. assert!(v_w.len() == params.n);
  279. assert!(v_ct[0].rows == 2);
  280. assert!(v_ct[0].cols == 1);
  281. assert!(v_w[0].rows == (params.n + 1));
  282. assert!(v_w[0].cols == params.t_conv);
  283. let mut result = PolyMatrixNTT::zero(params, params.n + 1, params.n);
  284. let mut ginv = PolyMatrixRaw::zero(params, params.t_conv, 1);
  285. let mut ginv_nttd = PolyMatrixNTT::zero(params, params.t_conv, 1);
  286. let mut prod = PolyMatrixNTT::zero(params, params.n + 1, 1);
  287. let mut ct_1 = PolyMatrixRaw::zero(params, 1, 1);
  288. let mut ct_2 = PolyMatrixRaw::zero(params, 1, 1);
  289. let mut ct_2_ntt = PolyMatrixNTT::zero(params, 1, 1);
  290. for c in 0..params.n {
  291. let mut v_int = PolyMatrixNTT::zero(&params, params.n + 1, 1);
  292. for r in 0..params.n {
  293. let w = &v_w[r];
  294. let ct = &v_ct[r * params.n + c];
  295. ct_1.copy_into(ct, 0, 0);
  296. ct_2.copy_into(ct, 1, 0);
  297. to_ntt(&mut ct_2_ntt, &ct_2);
  298. gadget_invert(&mut ginv, &ct_1);
  299. to_ntt(&mut ginv_nttd, &ginv);
  300. multiply(&mut prod, &w, &ginv_nttd);
  301. add_into_at(&mut v_int, &ct_2_ntt, 1 + r, 0);
  302. add_into(&mut v_int, &prod);
  303. }
  304. result.copy_into(&v_int, 0, c);
  305. }
  306. result
  307. }
  308. pub fn encode(
  309. params: &Params,
  310. v_packed_ct: &Vec<PolyMatrixRaw>
  311. ) -> Vec<u8> {
  312. let q1 = 4 * params.pt_modulus;
  313. let q1_bits = log2_ceil(q1) as usize;
  314. let q2 = Q2_VALUES[params.q2_bits as usize];
  315. let q2_bits = params.q2_bits as usize;
  316. let num_bits = params.instances *
  317. (
  318. (q2_bits * params.n * params.poly_len) +
  319. (q1_bits * params.n * params.n * params.poly_len)
  320. );
  321. let round_to = 64;
  322. let num_bytes_rounded_up = ((num_bits + round_to - 1) / round_to) * round_to / 8;
  323. let mut result = vec![0u8; num_bytes_rounded_up];
  324. let mut bit_offs = 0;
  325. for instance in 0..params.instances {
  326. let packed_ct = &v_packed_ct[instance];
  327. let mut first_row = packed_ct.submatrix(0, 0, 1, packed_ct.cols);
  328. let mut rest_rows = packed_ct.submatrix(1, 0, packed_ct.rows - 1, packed_ct.cols);
  329. first_row.apply_func(|x| { rescale(x, params.modulus, q2) });
  330. rest_rows.apply_func(|x| { rescale(x, params.modulus, q1) });
  331. let data = result.as_mut_slice();
  332. for i in 0..params.n * params.poly_len {
  333. write_arbitrary_bits(data, first_row.data[i], bit_offs, q2_bits);
  334. bit_offs += q2_bits;
  335. }
  336. for i in 0..params.n * params.n * params.poly_len {
  337. write_arbitrary_bits(data, rest_rows.data[i], bit_offs, q1_bits);
  338. bit_offs += q1_bits;
  339. }
  340. }
  341. result
  342. }
/// Expands a client's packed query into (a) reoriented first-dimension Regev
/// ciphertexts ready for `multiply_reg_by_database` and (b) the GSW folding
/// ciphertexts (plus their complements) consumed by `fold_ciphertexts`.
pub fn expand_query<'a>(
    params: &'a Params,
    public_params: &PublicParameters<'a>,
    query: &Query<'a>,
) -> (AlignedMemory64, Vec<PolyMatrixNTT<'a>>, Vec<PolyMatrixNTT<'a>>) {
    let dim0 = 1 << params.db_dim_1;
    let further_dims = params.db_dim_2;

    let mut v_reg_reoriented;
    let mut v_folding;
    let mut v_folding_neg;

    // Total ciphertexts to expand: t_gsw per further dimension (odd slots)
    // plus dim0 first-dimension ciphertexts (even slots).
    let num_bits_to_gen = params.t_gsw * further_dims + dim0;
    let g = log2_ceil_usize(num_bits_to_gen);
    let right_expanded = params.t_gsw * further_dims;
    let stop_round = log2_ceil_usize(right_expanded);

    let mut v = Vec::new();
    for _ in 0..(1 << g) {
        v.push(PolyMatrixNTT::zero(params, 2, 1));
    }
    // Seed slot 0 with the client's packed query ciphertext.
    v[0].copy_into(&query.ct.as_ref().unwrap().ntt(), 0, 0);

    let v_conversion = &public_params.v_conversion.as_ref().unwrap()[0];
    let v_w_left = public_params.v_expansion_left.as_ref().unwrap();
    let v_w_right = public_params.v_expansion_right.as_ref().unwrap();
    let v_neg1 = params.get_v_neg1();

    coefficient_expansion(
        &mut v,
        g,
        stop_round,
        params,
        &v_w_left,
        &v_w_right,
        &v_neg1,
        params.t_gsw * params.db_dim_2,
    );

    // Even slots hold the first-dimension (Regev) ciphertexts.
    let mut v_reg_inp = Vec::with_capacity(dim0);
    for i in 0..dim0 {
        v_reg_inp.push(v[2 * i].clone());
    }
    // Odd slots feed the Regev -> GSW conversion.
    let mut v_gsw_inp = Vec::with_capacity(right_expanded);
    for i in 0..right_expanded {
        v_gsw_inp.push(v[2 * i + 1].clone());
    }

    // Reorient the Regev ciphertexts into the packed layout used by the
    // AVX2 first-dimension multiply.
    let v_reg_sz = dim0 * 2 * params.poly_len;
    v_reg_reoriented = AlignedMemory64::new(v_reg_sz);
    reorient_reg_ciphertexts(params, v_reg_reoriented.as_mut_slice(), &v_reg_inp);

    v_folding = Vec::new();
    for _ in 0..params.db_dim_2 {
        v_folding.push(PolyMatrixNTT::zero(params, 2, 2 * params.t_gsw));
    }
    regev_to_gsw(&mut v_folding, &v_gsw_inp, &v_conversion, params, 1, 0);

    // Build the complements: gadget matrix plus the inverted (negated)
    // folding ciphertext, used for the complementary branch when folding.
    // NOTE(review): assumes `invert` negates modulo q — confirm in arith.rs.
    let gadget_ntt = build_gadget(&params, 2, 2 * params.t_gsw).ntt();
    v_folding_neg = Vec::new();
    let mut ct_gsw_inv = PolyMatrixRaw::zero(&params, 2, 2 * params.t_gsw);
    for i in 0..params.db_dim_2 {
        invert(&mut ct_gsw_inv, &v_folding[i].raw());
        let mut ct_gsw_neg = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
        add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
        v_folding_neg.push(ct_gsw_neg);
    }

    (v_reg_reoriented, v_folding, v_folding_neg)
}
  403. #[cfg(target_feature = "avx2")]
  404. pub fn process_query(
  405. params: &Params,
  406. public_params: &PublicParameters,
  407. query: &Query,
  408. db: &[u64],
  409. ) -> Vec<u8> {
  410. let dim0 = 1 << params.db_dim_1;
  411. let num_per = 1 << params.db_dim_2;
  412. let further_dims = params.db_dim_2;
  413. let db_slice_sz = dim0 * num_per * params.poly_len;
  414. let v_packing = public_params.v_packing.as_ref();
  415. if params.expand_queries {
  416. }
  417. let mut intermediate = Vec::with_capacity(num_per);
  418. let mut intermediate_raw = Vec::with_capacity(num_per);
  419. for _ in 0..dim0 {
  420. intermediate.push(PolyMatrixNTT::zero(params, 2, 1));
  421. intermediate_raw.push(PolyMatrixRaw::zero(params, 2, 1));
  422. }
  423. let mut v_ct = Vec::new();
  424. for trial in 0..(params.n * params.n) {
  425. let cur_db = &db[(db_slice_sz * trial)..(db_slice_sz * trial + db_slice_sz)];
  426. multiply_reg_by_database(&mut intermediate, db, v_reg_reoriented.as_slice(), params, dim0, num_per);
  427. for i in 0..intermediate.len() {
  428. from_ntt(&mut intermediate_raw[i], &intermediate[i]);
  429. }
  430. fold_ciphertexts(
  431. params,
  432. &mut intermediate_raw,
  433. &v_folding,
  434. &v_folding_neg
  435. );
  436. v_ct.push(intermediate_raw[0]);
  437. }
  438. let packed_ct = pack(
  439. params,
  440. &v_ct,
  441. &v_packing,
  442. );
  443. let mut v_packed_ct = Vec::new();
  444. v_packed_ct.push(packed_ct.raw());
  445. encode(params, &v_packed_ct)
  446. }
  447. #[cfg(test)]
  448. mod test {
  449. use super::*;
  450. use crate::{client::*};
  451. use rand::{prelude::StdRng, Rng};
  452. fn get_params() -> Params {
  453. let mut params = get_expansion_testing_params();
  454. params.db_dim_1 = 6;
  455. params.db_dim_2 = 2;
  456. params.t_exp_right = 8;
  457. params
  458. }
  459. fn dec_reg<'a>(
  460. params: &'a Params,
  461. ct: &PolyMatrixNTT<'a>,
  462. client: &mut Client<'a, StdRng>,
  463. scale_k: u64,
  464. ) -> u64 {
  465. let dec = client.decrypt_matrix_reg(ct).raw();
  466. let mut val = dec.data[0] as i64;
  467. if val >= (params.modulus / 2) as i64 {
  468. val -= params.modulus as i64;
  469. }
  470. let val_rounded = f64::round(val as f64 / scale_k as f64) as i64;
  471. if val_rounded == 0 {
  472. 0
  473. } else {
  474. 1
  475. }
  476. }
  477. fn dec_gsw<'a>(
  478. params: &'a Params,
  479. ct: &PolyMatrixNTT<'a>,
  480. client: &mut Client<'a, StdRng>,
  481. ) -> u64 {
  482. let dec = client.decrypt_matrix_reg(ct).raw();
  483. let idx = 2 * (params.t_gsw - 1) * params.poly_len + params.poly_len; // this offset should encode a large value
  484. let mut val = dec.data[idx] as i64;
  485. if val >= (params.modulus / 2) as i64 {
  486. val -= params.modulus as i64;
  487. }
  488. if i64::abs(val) < (1i64 << 10) {
  489. 0
  490. } else {
  491. 1
  492. }
  493. }
    /// Encrypts a polynomial with a single scaled coefficient, expands it,
    /// and checks that exactly the `target` output slot decrypts to 1.
    #[test]
    fn coefficient_expansion_is_correct() {
        let params = get_params();
        let v_neg1 = params.get_v_neg1();
        let mut seeded_rng = get_seeded_rng();
        let mut client = Client::init(&params, &mut seeded_rng);
        let public_params = client.generate_keys();

        // One output slot per expandable coefficient.
        let mut v = Vec::new();
        for _ in 0..(1 << (params.db_dim_1 + 1)) {
            v.push(PolyMatrixNTT::zero(&params, 2, 1));
        }

        // Plaintext with a single scaled coefficient at `target`.
        let target = 7;
        let scale_k = params.modulus / params.pt_modulus;
        let mut sigma = PolyMatrixRaw::zero(&params, 1, 1);
        sigma.data[target] = scale_k;
        v[0] = client.encrypt_matrix_reg(&sigma.ntt());
        // Control ciphertext, untouched by the expansion.
        let test_ct = client.encrypt_matrix_reg(&sigma.ntt());

        let v_w_left = public_params.v_expansion_left.unwrap();
        let v_w_right = public_params.v_expansion_right.unwrap();
        coefficient_expansion(
            &mut v,
            client.g,
            client.stop_round,
            &params,
            &v_w_left,
            &v_w_right,
            &v_neg1,
            params.t_gsw * params.db_dim_2,
        );

        // The control ciphertext decrypts to 0 at coefficient 0 (its payload
        // sits at index `target`, not 0).
        assert_eq!(dec_reg(&params, &test_ct, &mut client, scale_k), 0);

        // Only the `target` slot should carry the coefficient after expansion.
        for i in 0..v.len() {
            if i == target {
                assert_eq!(dec_reg(&params, &v[i], &mut client, scale_k), 1);
            } else {
                assert_eq!(dec_reg(&params, &v[i], &mut client, scale_k), 0);
            }
        }
    }
    /// Converts gadget-scaled Regev encryptions of 1 (and of 0) into a GSW
    /// ciphertext and checks the decoded bit round-trips.
    #[test]
    fn regev_to_gsw_is_correct() {
        let mut params = get_params();
        params.db_dim_2 = 1;
        let mut seeded_rng = get_seeded_rng();
        let mut client = Client::init(&params, &mut seeded_rng);
        let public_params = client.generate_keys();

        // Helper: encrypt a constant polynomial under the Regev scheme.
        let mut enc_constant = |val| {
            let mut sigma = PolyMatrixRaw::zero(&params, 1, 1);
            sigma.data[0] = val;
            client.encrypt_matrix_reg(&sigma.ntt())
        };

        let v = &public_params.v_conversion.unwrap()[0];

        // Inputs scaled by successive gadget powers: encodes bit 1; all-zero
        // inputs encode bit 0.
        let bits_per = get_bits_per(&params, params.t_gsw);
        let mut v_inp_1 = Vec::new();
        let mut v_inp_0 = Vec::new();
        for i in 0..params.t_gsw {
            let val = 1u64 << (bits_per * i);
            v_inp_1.push(enc_constant(val));
            v_inp_0.push(enc_constant(0));
        }

        let mut v_gsw = Vec::new();
        v_gsw.push(PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw));

        regev_to_gsw(&mut v_gsw, &v_inp_1, v, &params, 1, 0);
        assert_eq!(dec_gsw(&params, &v_gsw[0], &mut client), 1);

        regev_to_gsw(&mut v_gsw, &v_inp_0, v, &params, 1, 0);
        assert_eq!(dec_gsw(&params, &v_gsw[0], &mut client), 0);
    }
    /// Selects a random database item via a first-dimension one-hot query and
    /// checks the selected output ciphertext decrypts to that item.
    #[test]
    fn multiply_reg_by_database_is_correct() {
        let params = get_params();
        let mut seeded_rng = get_seeded_rng();

        let dim0 = 1 << params.db_dim_1;
        let num_per = 1 << params.db_dim_2;
        let scale_k = params.modulus / params.pt_modulus;

        // Random target item, split into first/second dimension coordinates.
        let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
        let target_idx_dim0 = target_idx / num_per;
        let target_idx_num_per = target_idx % num_per;

        let mut client = Client::init(&params, &mut seeded_rng);
        _ = client.generate_keys();

        let (corr_item, db) = generate_random_db_and_get_item(&params, target_idx);

        // One-hot vector of Regev encryptions over the first dimension.
        let mut v_reg = Vec::new();
        for i in 0..dim0 {
            let val = if i == target_idx_dim0 { scale_k } else { 0 };
            let sigma = PolyMatrixRaw::single_value(&params, val).ntt();
            v_reg.push(client.encrypt_matrix_reg(&sigma));
        }

        // Reorient into the packed layout the AVX2 kernel expects.
        let v_reg_sz = dim0 * 2 * params.poly_len;
        let mut v_reg_reoriented = AlignedMemory64::new(v_reg_sz);
        reorient_reg_ciphertexts(&params, v_reg_reoriented.as_mut_slice(), &v_reg);

        let mut out = Vec::with_capacity(num_per);
        for _ in 0..dim0 {
            out.push(PolyMatrixNTT::zero(&params, 2, 1));
        }
        multiply_reg_by_database(&mut out, db.as_slice(), v_reg_reoriented.as_slice(), &params, dim0, num_per);

        // decrypt
        let dec = client.decrypt_matrix_reg(&out[target_idx_num_per]).raw();
        let mut dec_rescaled = PolyMatrixRaw::zero(&params, 1, 1);
        for z in 0..params.poly_len {
            dec_rescaled.data[z] = rescale(dec.data[z], params.modulus, params.pt_modulus);
        }

        // The rescaled decryption must match the plaintext item exactly.
        for z in 0..params.poly_len {
            // println!("{:?} {:?}", dec_rescaled.data[z], corr_item.data[z]);
            assert_eq!(dec_rescaled.data[z], corr_item.data[z]);
        }
    }
    /// Builds GSW folding ciphertexts encoding the bits of a target index and
    /// checks that folding a one-hot vector of Regev encryptions selects it.
    #[test]
    fn fold_ciphertexts_is_correct() {
        let params = get_params();
        let mut seeded_rng = get_seeded_rng();

        let dim0 = 1 << params.db_dim_1;
        let num_per = 1 << params.db_dim_2;
        let scale_k = params.modulus / params.pt_modulus;

        let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
        let target_idx_num_per = target_idx % num_per;

        let mut client = Client::init(&params, &mut seeded_rng);
        _ = client.generate_keys();

        // One-hot vector over the second dimension.
        let mut v_reg = Vec::new();
        for i in 0..num_per {
            let val = if i == target_idx_num_per { scale_k } else { 0 };
            let sigma = PolyMatrixRaw::single_value(&params, val).ntt();
            v_reg.push(client.encrypt_matrix_reg(&sigma));
        }

        let mut v_reg_raw = Vec::new();
        for i in 0..num_per {
            v_reg_raw.push(v_reg[i].raw());
        }

        // Manually build a GSW encryption of each bit of the target index:
        // odd columns encrypt bit * gadget power, even columns encrypt the
        // secret-key product (as regev_to_gsw would produce).
        let bits_per = get_bits_per(&params, params.t_gsw);
        let mut v_folding = Vec::new();
        for i in 0..params.db_dim_2 {
            let bit = ((target_idx_num_per as u64) & (1 << (i as u64))) >> (i as u64);
            let mut ct_gsw = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
            for j in 0..params.t_gsw {
                let value = (1u64 << (bits_per * j)) * bit;
                let sigma = PolyMatrixRaw::single_value(&params, value);
                let sigma_ntt = to_ntt_alloc(&sigma);
                let ct = client.encrypt_matrix_reg(&sigma_ntt);
                ct_gsw.copy_into(&ct, 0, 2 * j + 1);
                let prod = &to_ntt_alloc(&client.sk_reg) * &sigma_ntt;
                let ct = &client.encrypt_matrix_reg(&prod);
                ct_gsw.copy_into(&ct, 0, 2 * j);
            }
            v_folding.push(ct_gsw);
        }

        // Complements: gadget matrix plus the inverted folding ciphertext
        // (same construction as in expand_query).
        let gadget_ntt = build_gadget(&params, 2, 2 * params.t_gsw).ntt();
        let mut v_folding_neg = Vec::new();
        let mut ct_gsw_inv = PolyMatrixRaw::zero(&params, 2, 2 * params.t_gsw);
        for i in 0..params.db_dim_2 {
            invert(&mut ct_gsw_inv, &v_folding[i].raw());
            let mut ct_gsw_neg = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
            add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
            v_folding_neg.push(ct_gsw_neg);
        }

        fold_ciphertexts(
            &params,
            &mut v_reg_raw,
            &v_folding,
            &v_folding_neg
        );

        // decrypt
        assert_eq!(dec_reg(&params, &v_reg_raw[0].ntt(), &mut client, scale_k), 1);
    }
    /// End-to-end protocol test.
    ///
    /// NOTE(review): as written this test only performs setup — it generates
    /// keys, a query, a database, and a one-hot Regev vector, but never calls
    /// `process_query` (or any decryption) and contains no assertions. The
    /// body appears to have been truncated; confirm against the upstream
    /// version and restore the response-processing and verification steps.
    #[test]
    fn full_protocol_is_correct() {
        let params = get_params();
        let mut seeded_rng = get_seeded_rng();

        let dim0 = 1 << params.db_dim_1;
        let num_per = 1 << params.db_dim_2;
        let scale_k = params.modulus / params.pt_modulus;

        let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
        let target_idx_dim0 = target_idx / num_per;
        let target_idx_num_per = target_idx % num_per;

        let mut client = Client::init(&params, &mut seeded_rng);
        let public_parameters = client.generate_keys();
        let query = client.generate_query(target_idx);

        let (corr_item, db) = generate_random_db_and_get_item(&params, target_idx);

        let mut v_reg = Vec::new();
        for i in 0..dim0 {
            let val = if i == target_idx_dim0 { scale_k } else { 0 };
            let sigma = PolyMatrixRaw::single_value(&params, val).ntt();
            v_reg.push(client.encrypt_matrix_reg(&sigma));
        }
    }
  675. }