server.rs 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143
  1. #[cfg(target_feature = "avx2")]
  2. use std::arch::x86_64::*;
  3. use std::fs::File;
  4. use std::io::BufReader;
  5. use std::io::Read;
  6. use std::io::Seek;
  7. use std::io::SeekFrom;
  8. use crate::aligned_memory::*;
  9. use crate::arith::*;
  10. use crate::client::PublicParameters;
  11. use crate::client::Query;
  12. use crate::gadget::*;
  13. use crate::params::*;
  14. use crate::poly::*;
  15. use crate::util::*;
  16. use rayon::prelude::*;
  /// Expands the query ciphertext stored in `v[0]` into `2^g` ciphertexts, in place.
  ///
  /// Round `r` (for `r` in `0..g`) works on `v[0..2^(r+1)]`:
  /// 1. for every `i < 2^r`, `v[2^r + i]` is overwritten via
  ///    `scalar_multiply(&mut v[2^r + i], &v_neg1[r], &v[i])`;
  /// 2. every ciphertext `v_i` in both halves is then updated in parallel: with
  ///    `ct_auto = automorph(v_i, t)` (`t = poly_len / 2^r + 1`), it adds
  ///    `w * NTT(gadget_invert_rdim(ct_auto, 1))` to both rows and additionally the
  ///    NTT of row 1 of `ct_auto` to row 1, reducing each CRT component with
  ///    `barrett_coeff_u64`.
  ///
  /// # Arguments
  /// * `v` - working buffer; needs at least `2^g` entries, with the input in `v[0]`
  ///   (the last round slices `v[2^(g-1)..2^g]`).
  /// * `g` - number of expansion rounds.
  /// * `stop_round` - when non-zero, odd-index ciphertexts are left untouched in rounds
  ///   after `stop_round`, and in round `stop_round` only those with
  ///   `i / 2 < max_bits_to_gen_right` are updated.
  /// * `params` - scheme parameters (`poly_len`, `crt_count`, `t_exp_left`, `t_exp_right`).
  /// * `v_w_left` / `v_w_right` - per-round matrices `w`: `v_w_left[r]` is used for even `i`
  ///   when `r != 0`, `v_w_right[r]` otherwise.
  /// * `v_neg1` - per-round multipliers used in step 1.
  /// * `max_bits_to_gen_right` - see `stop_round`.
  ///
  /// NOTE(review): `i` is the index *within the half* being processed (`enumerate`
  /// restarts at 0 for the upper half), so the `stop_round` checks use half-relative
  /// indices. For the upper half this skips fewer ciphertexts than an absolute index
  /// would.
  pub fn coefficient_expansion(
      v: &mut Vec<PolyMatrixNTT>,
      g: usize,
      stop_round: usize,
      params: &Params,
      v_w_left: &Vec<PolyMatrixNTT>,
      v_w_right: &Vec<PolyMatrixNTT>,
      v_neg1: &Vec<PolyMatrixNTT>,
      max_bits_to_gen_right: usize,
  ) {
      let poly_len = params.poly_len;
      for r in 0..g {
          let num_in = 1 << r;
          let num_out = 2 * num_in;
          // Automorphism exponent for this round.
          let t = (poly_len / (1 << r)) + 1;
          let neg1 = &v_neg1[r];
          let action_expand = |(i, v_i): (usize, &mut PolyMatrixNTT)| {
              // Skip odd-index ciphertexts that are not needed past / at `stop_round`.
              if (stop_round > 0 && r > stop_round && (i % 2) == 1)
                  || (stop_round > 0
                      && r == stop_round
                      && (i % 2) == 1
                      && (i / 2) >= max_bits_to_gen_right)
              {
                  return;
              }
              // Scratch buffers live per element because this closure runs in parallel.
              let mut ct = PolyMatrixRaw::zero(params, 2, 1);
              let mut ct_auto = PolyMatrixRaw::zero(params, 2, 1);
              let mut ct_auto_1 = PolyMatrixRaw::zero(params, 1, 1);
              let mut ct_auto_1_ntt = PolyMatrixNTT::zero(params, 1, 1);
              let mut w_times_ginv_ct = PolyMatrixNTT::zero(params, 2, 1);
              let mut ginv_ct_left = PolyMatrixRaw::zero(params, params.t_exp_left, 1);
              let mut ginv_ct_left_ntt = PolyMatrixNTT::zero(params, params.t_exp_left, 1);
              let mut ginv_ct_right = PolyMatrixRaw::zero(params, params.t_exp_right, 1);
              let mut ginv_ct_right_ntt = PolyMatrixNTT::zero(params, params.t_exp_right, 1);
              // Even indices in rounds r > 0 use the left key; everything else the right key.
              // The gadget dimension is selected for symmetry but is not used below.
              let (w, _gadget_dim, gi_ct, gi_ct_ntt) = match (r != 0) && (i % 2 == 0) {
                  true => (
                      &v_w_left[r],
                      params.t_exp_left,
                      &mut ginv_ct_left,
                      &mut ginv_ct_left_ntt,
                  ),
                  false => (
                      &v_w_right[r],
                      params.t_exp_right,
                      &mut ginv_ct_right,
                      &mut ginv_ct_right_ntt,
                  ),
              };
              // The upper half is filled with the `neg1` products before the parallel
              // passes run (see the `split_at_mut` below), not per element here.
              from_ntt(&mut ct, &v_i);
              automorph(&mut ct_auto, &ct, t);
              gadget_invert_rdim(gi_ct, &ct_auto, 1);
              to_ntt_no_reduce(gi_ct_ntt, &gi_ct);
              // Row 1 of the automorphed ciphertext, needed separately in NTT form.
              ct_auto_1
                  .data
                  .as_mut_slice()
                  .copy_from_slice(ct_auto.get_poly(1, 0));
              to_ntt(&mut ct_auto_1_ntt, &ct_auto_1);
              multiply(&mut w_times_ginv_ct, w, &gi_ct_ntt);
              // `idx` walks `v_i.data` by row (j), then CRT component (n), then coefficient (z).
              // `j * ...` adds the automorphed row 1 only to output row 1.
              let mut idx = 0;
              for j in 0..2 {
                  for n in 0..params.crt_count {
                      for z in 0..poly_len {
                          let sum = (*v_i).data[idx]
                              + w_times_ginv_ct.data[idx]
                              + j * ct_auto_1_ntt.data[n * poly_len + z];
                          (*v_i).data[idx] = barrett_coeff_u64(params, sum, n);
                          idx += 1;
                      }
                  }
              }
          };
          // Fill the upper half from the lower half before either half is modified.
          let (src, dest) = v.split_at_mut(num_in);
          src.par_iter_mut()
              .zip(dest.par_iter_mut())
              .for_each(|(s, d)| {
                  scalar_multiply(d, neg1, s);
              });
          // Expand each half; `enumerate` yields half-relative indices.
          v[0..num_in]
              .par_iter_mut()
              .enumerate()
              .for_each(action_expand);
          v[num_in..num_out]
              .par_iter_mut()
              .enumerate()
              .for_each(action_expand);
      }
  }
  108. pub fn regev_to_gsw<'a>(
  109. v_gsw: &mut Vec<PolyMatrixNTT<'a>>,
  110. v_inp: &Vec<PolyMatrixNTT<'a>>,
  111. v: &PolyMatrixNTT<'a>,
  112. params: &'a Params,
  113. idx_factor: usize,
  114. idx_offset: usize,
  115. ) {
  116. assert!(v.rows == 2);
  117. assert!(v.cols == 2 * params.t_conv);
  118. v_gsw.par_iter_mut().enumerate().for_each(|(i, ct)| {
  119. let mut ginv_c_inp = PolyMatrixRaw::zero(params, 2 * params.t_conv, 1);
  120. let mut ginv_c_inp_ntt = PolyMatrixNTT::zero(params, 2 * params.t_conv, 1);
  121. let mut tmp_ct_raw = PolyMatrixRaw::zero(params, 2, 1);
  122. let mut tmp_ct = PolyMatrixNTT::zero(params, 2, 1);
  123. for j in 0..params.t_gsw {
  124. let idx_ct = i * params.t_gsw + j;
  125. let idx_inp = idx_factor * (idx_ct) + idx_offset;
  126. ct.copy_into(&v_inp[idx_inp], 0, 2 * j + 1);
  127. from_ntt(&mut tmp_ct_raw, &v_inp[idx_inp]);
  128. gadget_invert(&mut ginv_c_inp, &tmp_ct_raw);
  129. to_ntt(&mut ginv_c_inp_ntt, &ginv_c_inp);
  130. multiply(&mut tmp_ct, v, &ginv_c_inp_ntt);
  131. ct.copy_into(&tmp_ct, 0, 2 * j);
  132. }
  133. });
  134. }
  /// Block size, in `v_firstdim` words, that the AVX2 `multiply_reg_by_database`
  /// accumulates in 64-bit lanes before it performs an intermediate Barrett reduction.
  pub const MAX_SUMMED: usize = 64;
  /// Left shift applied to the second CRT residue when two residues are packed into
  /// one `u64` database word. The first residue occupies the low bits.
  pub const PACKED_OFFSET_2: i32 = 32;
  /// AVX2 implementation: multiplies the reoriented first-dimension query `v_firstdim`
  /// by one database slice `db`, writing `num_per` `2 x 1` NTT-domain results into `out`.
  ///
  /// Each `u64` word packs two CRT residues. `_mm256_mul_epu32` multiplies only the low
  /// 32 bits of each 64-bit lane. The plain words therefore yield CRT component 0
  /// (`sums_out_n0*`), and the words shifted right by `PACKED_OFFSET_2` yield CRT
  /// component 1 (`sums_out_n2*`).
  ///
  /// Lane layout per inner step: `a` holds 4 consecutive `v_firstdim` words (query rows
  /// 0 and 1 for two consecutive database columns), and `b` holds `[b1, b1, b2, b2]`.
  /// Lanes 0 and 2 therefore accumulate output row 0, and lanes 1 and 3 accumulate output
  /// row 1.
  ///
  /// Products are summed in blocks of at most `MAX_SUMMED` words, then folded into
  /// Barrett-reduced accumulators to avoid 64-bit overflow.
  /// NOTE(review): that bound assumes each residue product is small enough for
  /// `MAX_SUMMED / 4` of them to fit in a lane — confirm against `params.moduli`.
  ///
  /// Unchecked preconditions, relied on by the `unsafe` block:
  /// * `db` has at least `poly_len * num_per * dim0` words;
  /// * `v_firstdim` has at least `poly_len * dim0 * 2` words, and every
  ///   `idx_a_base + jm` address is 32-byte aligned (required by `_mm256_load_si256`).
  ///
  /// NOTE(review): if `dim0 * 2 < 4` (i.e. `dim0 == 1`), `inner_limit / 4 == 0`. Nothing
  /// is accumulated and `idx_b_base` does not advance.
  #[cfg(target_feature = "avx2")]
  pub fn multiply_reg_by_database(
      out: &mut Vec<PolyMatrixNTT>,
      db: &[u64],
      v_firstdim: &[u64],
      params: &Params,
      dim0: usize,
      num_per: usize,
  ) {
      let ct_rows = 2;
      let ct_cols = 1;
      let pt_rows = 1;
      let pt_cols = 1;
      // assert!(dim0 * ct_rows >= MAX_SUMMED);
      // (Not enforced: smaller inputs fall back to a single block below.)
      // 32-byte scratch used to spill the 256-bit accumulators.
      let mut sums_out_n0_u64 = AlignedMemory64::new(4);
      let mut sums_out_n2_u64 = AlignedMemory64::new(4);
      for z in 0..params.poly_len {
          let idx_a_base = z * (ct_cols * dim0 * ct_rows);
          // Advanced by one per database word consumed below.
          let mut idx_b_base = z * (num_per * pt_cols * dim0 * pt_rows);
          for i in 0..num_per {
              for c in 0..pt_cols {
                  // Split the `dim0 * ct_rows` query words into blocks of `inner_limit`.
                  let mut inner_limit = MAX_SUMMED;
                  let mut outer_limit = dim0 * ct_rows / inner_limit;
                  if MAX_SUMMED > dim0 * ct_rows {
                      inner_limit = dim0 * ct_rows;
                      outer_limit = 1;
                  }
                  let mut sums_out_n0_u64_acc = [0u64, 0, 0, 0];
                  let mut sums_out_n2_u64_acc = [0u64, 0, 0, 0];
                  for o_jm in 0..outer_limit {
                      // SAFETY (review): depends on the unchecked preconditions listed in
                      // the function docs; nothing here validates them.
                      unsafe {
                          let mut sums_out_n0 = _mm256_setzero_si256();
                          let mut sums_out_n2 = _mm256_setzero_si256();
                          // Each step consumes 4 query words and 2 database words.
                          for i_jm in 0..inner_limit / 4 {
                              let jm = o_jm * inner_limit + (4 * i_jm);
                              let b_inp_1 = *db.get_unchecked(idx_b_base) as i64;
                              idx_b_base += 1;
                              let b_inp_2 = *db.get_unchecked(idx_b_base) as i64;
                              idx_b_base += 1;
                              // Lanes (low to high): b1, b1, b2, b2.
                              let b = _mm256_set_epi64x(b_inp_2, b_inp_2, b_inp_1, b_inp_1);
                              let v_a = v_firstdim.get_unchecked(idx_a_base + jm) as *const u64;
                              let a = _mm256_load_si256(v_a as *const __m256i);
                              let a_lo = a;
                              let a_hi_hi = _mm256_srli_epi64(a, PACKED_OFFSET_2);
                              let b_lo = b;
                              let b_hi_hi = _mm256_srli_epi64(b, PACKED_OFFSET_2);
                              sums_out_n0 =
                                  _mm256_add_epi64(sums_out_n0, _mm256_mul_epu32(a_lo, b_lo));
                              sums_out_n2 =
                                  _mm256_add_epi64(sums_out_n2, _mm256_mul_epu32(a_hi_hi, b_hi_hi));
                          }
                          // reduce here, otherwise we will overflow
                          _mm256_store_si256(
                              sums_out_n0_u64.as_mut_ptr() as *mut __m256i,
                              sums_out_n0,
                          );
                          _mm256_store_si256(
                              sums_out_n2_u64.as_mut_ptr() as *mut __m256i,
                              sums_out_n2,
                          );
                          for idx in 0..4 {
                              let val = sums_out_n0_u64[idx];
                              sums_out_n0_u64_acc[idx] =
                                  barrett_coeff_u64(params, val + sums_out_n0_u64_acc[idx], 0);
                          }
                          for idx in 0..4 {
                              let val = sums_out_n2_u64[idx];
                              sums_out_n2_u64_acc[idx] =
                                  barrett_coeff_u64(params, val + sums_out_n2_u64_acc[idx], 1);
                          }
                      }
                  }
                  for idx in 0..4 {
                      sums_out_n0_u64_acc[idx] =
                          barrett_coeff_u64(params, sums_out_n0_u64_acc[idx], 0);
                      sums_out_n2_u64_acc[idx] =
                          barrett_coeff_u64(params, sums_out_n2_u64_acc[idx], 1);
                  }
                  // output n0 (lanes 0+2 -> row 0, lanes 1+3 -> row 1)
                  let (crt_count, poly_len) = (params.crt_count, params.poly_len);
                  let mut n = 0;
                  let mut idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
                  out[i].data[idx_c] =
                      barrett_coeff_u64(params, sums_out_n0_u64_acc[0] + sums_out_n0_u64_acc[2], 0);
                  idx_c += pt_cols * crt_count * poly_len;
                  out[i].data[idx_c] =
                      barrett_coeff_u64(params, sums_out_n0_u64_acc[1] + sums_out_n0_u64_acc[3], 0);
                  // output n1
                  n = 1;
                  idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
                  out[i].data[idx_c] =
                      barrett_coeff_u64(params, sums_out_n2_u64_acc[0] + sums_out_n2_u64_acc[2], 1);
                  idx_c += pt_cols * crt_count * poly_len;
                  out[i].data[idx_c] =
                      barrett_coeff_u64(params, sums_out_n2_u64_acc[1] + sums_out_n2_u64_acc[3], 1);
              }
          }
      }
  }
  236. #[cfg(not(target_feature = "avx2"))]
  237. pub fn multiply_reg_by_database(
  238. out: &mut Vec<PolyMatrixNTT>,
  239. db: &[u64],
  240. v_firstdim: &[u64],
  241. params: &Params,
  242. dim0: usize,
  243. num_per: usize,
  244. ) {
  245. let ct_rows = 2;
  246. let ct_cols = 1;
  247. let pt_rows = 1;
  248. let pt_cols = 1;
  249. for z in 0..params.poly_len {
  250. let idx_a_base = z * (ct_cols * dim0 * ct_rows);
  251. let mut idx_b_base = z * (num_per * pt_cols * dim0 * pt_rows);
  252. for i in 0..num_per {
  253. for c in 0..pt_cols {
  254. let mut sums_out_n0_0 = 0u128;
  255. let mut sums_out_n0_1 = 0u128;
  256. let mut sums_out_n1_0 = 0u128;
  257. let mut sums_out_n1_1 = 0u128;
  258. for jm in 0..(dim0 * pt_rows) {
  259. let b = db[idx_b_base];
  260. idx_b_base += 1;
  261. let v_a0 = v_firstdim[idx_a_base + jm * ct_rows];
  262. let v_a1 = v_firstdim[idx_a_base + jm * ct_rows + 1];
  263. let b_lo = b as u32;
  264. let b_hi = (b >> 32) as u32;
  265. let v_a0_lo = v_a0 as u32;
  266. let v_a0_hi = (v_a0 >> 32) as u32;
  267. let v_a1_lo = v_a1 as u32;
  268. let v_a1_hi = (v_a1 >> 32) as u32;
  269. // do n0
  270. sums_out_n0_0 += ((v_a0_lo as u64) * (b_lo as u64)) as u128;
  271. sums_out_n0_1 += ((v_a1_lo as u64) * (b_lo as u64)) as u128;
  272. // do n1
  273. sums_out_n1_0 += ((v_a0_hi as u64) * (b_hi as u64)) as u128;
  274. sums_out_n1_1 += ((v_a1_hi as u64) * (b_hi as u64)) as u128;
  275. }
  276. // output n0
  277. let (crt_count, poly_len) = (params.crt_count, params.poly_len);
  278. let mut n = 0;
  279. let mut idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
  280. out[i].data[idx_c] = (sums_out_n0_0 % (params.moduli[0] as u128)) as u64;
  281. idx_c += pt_cols * crt_count * poly_len;
  282. out[i].data[idx_c] = (sums_out_n0_1 % (params.moduli[0] as u128)) as u64;
  283. // output n1
  284. n = 1;
  285. idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
  286. out[i].data[idx_c] = (sums_out_n1_0 % (params.moduli[1] as u128)) as u64;
  287. idx_c += pt_cols * crt_count * poly_len;
  288. out[i].data[idx_c] = (sums_out_n1_1 % (params.moduli[1] as u128)) as u64;
  289. }
  290. }
  291. }
  292. }
  293. pub fn generate_random_db_and_get_item<'a>(
  294. params: &'a Params,
  295. item_idx: usize,
  296. ) -> (PolyMatrixRaw<'a>, AlignedMemory64) {
  297. let mut rng = get_seeded_rng();
  298. let instances = params.instances;
  299. let trials = params.n * params.n;
  300. let dim0 = 1 << params.db_dim_1;
  301. let num_per = 1 << params.db_dim_2;
  302. let num_items = dim0 * num_per;
  303. let db_size_words = instances * trials * num_items * params.poly_len;
  304. let mut v = AlignedMemory64::new(db_size_words);
  305. let mut item = PolyMatrixRaw::zero(params, params.instances * params.n, params.n);
  306. for instance in 0..instances {
  307. for trial in 0..trials {
  308. for i in 0..num_items {
  309. let ii = i % num_per;
  310. let j = i / num_per;
  311. let mut db_item = PolyMatrixRaw::random_rng(params, 1, 1, &mut rng);
  312. db_item.reduce_mod(params.pt_modulus);
  313. if i == item_idx {
  314. item.copy_into(&db_item, instance * params.n + trial / params.n, trial % params.n);
  315. }
  316. for z in 0..params.poly_len {
  317. db_item.data[z] =
  318. recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
  319. }
  320. let db_item_ntt = db_item.ntt();
  321. for z in 0..params.poly_len {
  322. let idx_dst = calc_index(
  323. &[instance, trial, z, ii, j],
  324. &[instances, trials, params.poly_len, num_per, dim0],
  325. );
  326. v[idx_dst] = db_item_ntt.data[z]
  327. | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
  328. }
  329. }
  330. }
  331. }
  332. (item, v)
  333. }
  334. pub fn load_item_from_seek<'a, T: Seek + Read>(
  335. params: &'a Params,
  336. seekable: &mut T,
  337. instance: usize,
  338. trial: usize,
  339. item_idx: usize,
  340. ) -> PolyMatrixRaw<'a> {
  341. let db_item_size = params.db_item_size;
  342. let instances = params.instances;
  343. let trials = params.n * params.n;
  344. let chunks = instances * trials;
  345. let bytes_per_chunk = f64::ceil(db_item_size as f64 / chunks as f64) as usize;
  346. let logp = f64::ceil(f64::log2(params.pt_modulus as f64)) as usize;
  347. let modp_words_per_chunk = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
  348. assert!(modp_words_per_chunk <= params.poly_len);
  349. let idx_item_in_file = item_idx * db_item_size;
  350. let idx_chunk = instance * trials + trial;
  351. let idx_poly_in_file = idx_item_in_file + idx_chunk * bytes_per_chunk;
  352. let mut out = PolyMatrixRaw::zero(params, 1, 1);
  353. let seek_result = seekable.seek(SeekFrom::Start(idx_poly_in_file as u64));
  354. if seek_result.is_err() {
  355. return out;
  356. }
  357. let mut data = vec![0u8; 2 * bytes_per_chunk];
  358. let bytes_read = seekable
  359. .read(&mut data.as_mut_slice()[0..bytes_per_chunk])
  360. .unwrap();
  361. let modp_words_read = f64::ceil((bytes_read * 8) as f64 / logp as f64) as usize;
  362. assert!(modp_words_read <= params.poly_len);
  363. for i in 0..modp_words_read {
  364. out.data[i] = read_arbitrary_bits(&data, i * logp, logp);
  365. assert!(out.data[i] <= params.pt_modulus);
  366. }
  367. out
  368. }
  369. pub fn load_db_from_seek<T: Seek + Read>(params: &Params, seekable: &mut T) -> AlignedMemory64 {
  370. let instances = params.instances;
  371. let trials = params.n * params.n;
  372. let dim0 = 1 << params.db_dim_1;
  373. let num_per = 1 << params.db_dim_2;
  374. let num_items = dim0 * num_per;
  375. let db_size_words = instances * trials * num_items * params.poly_len;
  376. let mut v = AlignedMemory64::new(db_size_words);
  377. for instance in 0..instances {
  378. for trial in 0..trials {
  379. for i in 0..num_items {
  380. let ii = i % num_per;
  381. let j = i / num_per;
  382. let mut db_item = load_item_from_seek(params, seekable, instance, trial, i);
  383. // db_item.reduce_mod(params.pt_modulus);
  384. for z in 0..params.poly_len {
  385. db_item.data[z] =
  386. recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
  387. }
  388. let db_item_ntt = db_item.ntt();
  389. for z in 0..params.poly_len {
  390. let idx_dst = calc_index(
  391. &[instance, trial, z, ii, j],
  392. &[instances, trials, params.poly_len, num_per, dim0],
  393. );
  394. v[idx_dst] = db_item_ntt.data[z]
  395. | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
  396. }
  397. }
  398. }
  399. }
  400. v
  401. }
  /// Fills `data` with raw bytes read straight from `file`. The bytes are interpreted as
  /// native-endian `u64` words, with no conversion.
  ///
  /// # Panics
  /// Panics if the file cannot supply `8 * data.len()` bytes or on any other I/O error.
  pub fn load_file_unsafe(data: &mut [u64], file: &mut File) {
      let num_bytes = data.len() * std::mem::size_of::<u64>();
      // SAFETY: `data` is a live, exclusively borrowed region of exactly `num_bytes` bytes.
      // `u8` has alignment 1, and every bit pattern is a valid `u64`, so viewing the region
      // as bytes and writing to it is sound. `align_to_mut` is not used because it may
      // legally put everything in the prefix and return an empty middle slice, which would
      // silently read nothing.
      let bytes =
          unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr().cast::<u8>(), num_bytes) };
      file.read_exact(bytes).unwrap();
  }
  /// Fills `data` from `file`, decoding each consecutive 8-byte group as a native-endian
  /// `u64`. Reads go through a 16 MiB `BufReader`.
  ///
  /// # Panics
  /// Panics if the file ends before `data` is full, or on any other I/O error.
  pub fn load_file(data: &mut [u64], file: &mut File) {
      let mut reader = BufReader::with_capacity(1 << 24, file);
      let mut buf = [0u8; 8];
      for word in data.iter_mut() {
          // A bare `read` may return fewer than 8 bytes (for example at a buffer refill or
          // at EOF). That would leave stale bytes in `buf`. `read_exact` retries short reads
          // and reports a truncated file as an error instead of producing garbage words.
          reader
              .read_exact(&mut buf)
              .expect("file ended before the requested number of u64 words could be read");
          *word = u64::from_ne_bytes(buf);
      }
  }
  414. pub fn load_preprocessed_db_from_file(params: &Params, file: &mut File) -> AlignedMemory64 {
  415. let instances = params.instances;
  416. let trials = params.n * params.n;
  417. let dim0 = 1 << params.db_dim_1;
  418. let num_per = 1 << params.db_dim_2;
  419. let num_items = dim0 * num_per;
  420. let db_size_words = instances * trials * num_items * params.poly_len;
  421. let mut v = AlignedMemory64::new(db_size_words);
  422. let v_mut_slice = v.as_mut_slice();
  423. load_file(v_mut_slice, file);
  424. v
  425. }
  426. pub fn fold_ciphertexts(
  427. params: &Params,
  428. v_cts: &mut Vec<PolyMatrixRaw>,
  429. v_folding: &Vec<PolyMatrixNTT>,
  430. v_folding_neg: &Vec<PolyMatrixNTT>,
  431. ) {
  432. if v_cts.len() == 1 {
  433. return;
  434. }
  435. let further_dims = log2(v_cts.len() as u64) as usize;
  436. let ell = v_folding[0].cols / 2;
  437. let mut ginv_c = PolyMatrixRaw::zero(&params, 2 * ell, 1);
  438. let mut ginv_c_ntt = PolyMatrixNTT::zero(&params, 2 * ell, 1);
  439. let mut prod = PolyMatrixNTT::zero(&params, 2, 1);
  440. let mut sum = PolyMatrixNTT::zero(&params, 2, 1);
  441. let mut num_per = v_cts.len();
  442. for cur_dim in 0..further_dims {
  443. num_per = num_per / 2;
  444. for i in 0..num_per {
  445. gadget_invert(&mut ginv_c, &v_cts[i]);
  446. to_ntt(&mut ginv_c_ntt, &ginv_c);
  447. multiply(
  448. &mut prod,
  449. &v_folding_neg[further_dims - 1 - cur_dim],
  450. &ginv_c_ntt,
  451. );
  452. gadget_invert(&mut ginv_c, &v_cts[num_per + i]);
  453. to_ntt(&mut ginv_c_ntt, &ginv_c);
  454. multiply(
  455. &mut sum,
  456. &v_folding[further_dims - 1 - cur_dim],
  457. &ginv_c_ntt,
  458. );
  459. add_into(&mut sum, &prod);
  460. from_ntt(&mut v_cts[i], &sum);
  461. }
  462. }
  463. }
  464. pub fn pack<'a>(
  465. params: &'a Params,
  466. v_ct: &Vec<PolyMatrixRaw>,
  467. v_w: &Vec<PolyMatrixNTT>,
  468. ) -> PolyMatrixNTT<'a> {
  469. assert!(v_ct.len() >= params.n * params.n);
  470. assert!(v_w.len() == params.n);
  471. assert!(v_ct[0].rows == 2);
  472. assert!(v_ct[0].cols == 1);
  473. assert!(v_w[0].rows == (params.n + 1));
  474. assert!(v_w[0].cols == params.t_conv);
  475. let mut result = PolyMatrixNTT::zero(params, params.n + 1, params.n);
  476. let mut ginv = PolyMatrixRaw::zero(params, params.t_conv, 1);
  477. let mut ginv_nttd = PolyMatrixNTT::zero(params, params.t_conv, 1);
  478. let mut prod = PolyMatrixNTT::zero(params, params.n + 1, 1);
  479. let mut ct_1 = PolyMatrixRaw::zero(params, 1, 1);
  480. let mut ct_2 = PolyMatrixRaw::zero(params, 1, 1);
  481. let mut ct_2_ntt = PolyMatrixNTT::zero(params, 1, 1);
  482. for c in 0..params.n {
  483. let mut v_int = PolyMatrixNTT::zero(&params, params.n + 1, 1);
  484. for r in 0..params.n {
  485. let w = &v_w[r];
  486. let ct = &v_ct[r * params.n + c];
  487. ct_1.get_poly_mut(0, 0).copy_from_slice(ct.get_poly(0, 0));
  488. ct_2.get_poly_mut(0, 0).copy_from_slice(ct.get_poly(1, 0));
  489. to_ntt(&mut ct_2_ntt, &ct_2);
  490. gadget_invert(&mut ginv, &ct_1);
  491. to_ntt(&mut ginv_nttd, &ginv);
  492. multiply(&mut prod, &w, &ginv_nttd);
  493. add_into_at(&mut v_int, &ct_2_ntt, 1 + r, 0);
  494. add_into(&mut v_int, &prod);
  495. }
  496. result.copy_into(&v_int, 0, c);
  497. }
  498. result
  499. }
  500. pub fn encode(params: &Params, v_packed_ct: &Vec<PolyMatrixRaw>) -> Vec<u8> {
  501. let q1 = 4 * params.pt_modulus;
  502. let q1_bits = log2_ceil(q1) as usize;
  503. let q2 = Q2_VALUES[params.q2_bits as usize];
  504. let q2_bits = params.q2_bits as usize;
  505. let num_bits = params.instances
  506. * ((q2_bits * params.n * params.poly_len)
  507. + (q1_bits * params.n * params.n * params.poly_len));
  508. let round_to = 64;
  509. let num_bytes_rounded_up = ((num_bits + round_to - 1) / round_to) * round_to / 8;
  510. let mut result = vec![0u8; num_bytes_rounded_up];
  511. let mut bit_offs = 0;
  512. for instance in 0..params.instances {
  513. let packed_ct = &v_packed_ct[instance];
  514. let mut first_row = packed_ct.submatrix(0, 0, 1, packed_ct.cols);
  515. let mut rest_rows = packed_ct.submatrix(1, 0, packed_ct.rows - 1, packed_ct.cols);
  516. first_row.apply_func(|x| rescale(x, params.modulus, q2));
  517. rest_rows.apply_func(|x| rescale(x, params.modulus, q1));
  518. let data = result.as_mut_slice();
  519. for i in 0..params.n * params.poly_len {
  520. write_arbitrary_bits(data, first_row.data[i], bit_offs, q2_bits);
  521. bit_offs += q2_bits;
  522. }
  523. for i in 0..params.n * params.n * params.poly_len {
  524. write_arbitrary_bits(data, rest_rows.data[i], bit_offs, q1_bits);
  525. bit_offs += q1_bits;
  526. }
  527. }
  528. result
  529. }
  530. pub fn get_v_folding_neg<'a>(
  531. params: &'a Params,
  532. v_folding: &Vec<PolyMatrixNTT<'a>>,
  533. ) -> Vec<PolyMatrixNTT<'a>> {
  534. let gadget_ntt = build_gadget(params, 2, 2 * params.t_gsw).ntt(); // TODO: make this better
  535. let v_folding_neg = (0..params.db_dim_2)
  536. .into_par_iter()
  537. .map(|i| {
  538. let mut ct_gsw_inv = PolyMatrixRaw::zero(params, 2, 2 * params.t_gsw);
  539. invert(&mut ct_gsw_inv, &v_folding[i].raw());
  540. let mut ct_gsw_neg = PolyMatrixNTT::zero(params, 2, 2 * params.t_gsw);
  541. add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
  542. ct_gsw_neg
  543. })
  544. .collect();
  545. v_folding_neg
  546. }
  547. pub fn expand_query<'a>(
  548. params: &'a Params,
  549. public_params: &PublicParameters<'a>,
  550. query: &Query<'a>,
  551. ) -> (AlignedMemory64, Vec<PolyMatrixNTT<'a>>) {
  552. let dim0 = 1 << params.db_dim_1;
  553. let further_dims = params.db_dim_2;
  554. let mut v_reg_reoriented;
  555. let mut v_folding;
  556. let num_bits_to_gen = params.t_gsw * further_dims + dim0;
  557. let g = log2_ceil_usize(num_bits_to_gen);
  558. let right_expanded = params.t_gsw * further_dims;
  559. let stop_round = log2_ceil_usize(right_expanded);
  560. let mut v = Vec::new();
  561. for _ in 0..(1 << g) {
  562. v.push(PolyMatrixNTT::zero(params, 2, 1));
  563. }
  564. v[0].copy_into(&query.ct.as_ref().unwrap().ntt(), 0, 0);
  565. let v_conversion = &public_params.v_conversion.as_ref().unwrap()[0];
  566. let v_w_left = public_params.v_expansion_left.as_ref().unwrap();
  567. let v_w_right = public_params.v_expansion_right.as_ref().unwrap();
  568. let v_neg1 = params.get_v_neg1();
  569. let mut v_reg_inp = Vec::with_capacity(dim0);
  570. let mut v_gsw_inp = Vec::with_capacity(right_expanded);
  571. if further_dims > 0 {
  572. coefficient_expansion(
  573. &mut v,
  574. g,
  575. stop_round,
  576. params,
  577. &v_w_left,
  578. &v_w_right,
  579. &v_neg1,
  580. params.t_gsw * params.db_dim_2,
  581. );
  582. for i in 0..dim0 {
  583. v_reg_inp.push(v[2 * i].clone());
  584. }
  585. for i in 0..right_expanded {
  586. v_gsw_inp.push(v[2 * i + 1].clone());
  587. }
  588. } else {
  589. coefficient_expansion(
  590. &mut v,
  591. g,
  592. 0,
  593. params,
  594. &v_w_left,
  595. &v_w_left,
  596. &v_neg1,
  597. 0,
  598. );
  599. for i in 0..dim0 {
  600. v_reg_inp.push(v[i].clone());
  601. }
  602. }
  603. let v_reg_sz = dim0 * 2 * params.poly_len;
  604. v_reg_reoriented = AlignedMemory64::new(v_reg_sz);
  605. reorient_reg_ciphertexts(params, v_reg_reoriented.as_mut_slice(), &v_reg_inp);
  606. v_folding = Vec::new();
  607. for _ in 0..params.db_dim_2 {
  608. v_folding.push(PolyMatrixNTT::zero(params, 2, 2 * params.t_gsw));
  609. }
  610. regev_to_gsw(&mut v_folding, &v_gsw_inp, &v_conversion, params, 1, 0);
  611. (v_reg_reoriented, v_folding)
  612. }
  613. pub fn process_query(
  614. params: &Params,
  615. public_params: &PublicParameters,
  616. query: &Query,
  617. db: &[u64],
  618. ) -> Vec<u8> {
  619. let dim0 = 1 << params.db_dim_1;
  620. let num_per = 1 << params.db_dim_2;
  621. let db_slice_sz = dim0 * num_per * params.poly_len;
  622. let v_packing = public_params.v_packing.as_ref();
  623. let mut v_reg_reoriented;
  624. let v_folding;
  625. if params.expand_queries {
  626. (v_reg_reoriented, v_folding) = expand_query(params, public_params, query);
  627. } else {
  628. v_reg_reoriented = AlignedMemory64::new(query.v_buf.as_ref().unwrap().len());
  629. v_reg_reoriented
  630. .as_mut_slice()
  631. .copy_from_slice(query.v_buf.as_ref().unwrap());
  632. v_folding = query
  633. .v_ct
  634. .as_ref()
  635. .unwrap()
  636. .iter()
  637. .map(|x| x.ntt())
  638. .collect();
  639. }
  640. let v_folding_neg = get_v_folding_neg(params, &v_folding);
  641. let v_packed_ct = (0..params.instances)
  642. .into_par_iter()
  643. .map(|instance| {
  644. let mut intermediate = Vec::with_capacity(num_per);
  645. let mut intermediate_raw = Vec::with_capacity(num_per);
  646. for _ in 0..num_per {
  647. intermediate.push(PolyMatrixNTT::zero(params, 2, 1));
  648. intermediate_raw.push(PolyMatrixRaw::zero(params, 2, 1));
  649. }
  650. let mut v_ct = Vec::new();
  651. for trial in 0..(params.n * params.n) {
  652. let idx = (instance * (params.n * params.n) + trial) * db_slice_sz;
  653. let cur_db = &db[idx..(idx + db_slice_sz)];
  654. multiply_reg_by_database(
  655. &mut intermediate,
  656. cur_db,
  657. v_reg_reoriented.as_slice(),
  658. params,
  659. dim0,
  660. num_per,
  661. );
  662. for i in 0..intermediate.len() {
  663. from_ntt(&mut intermediate_raw[i], &intermediate[i]);
  664. }
  665. fold_ciphertexts(params, &mut intermediate_raw, &v_folding, &v_folding_neg);
  666. v_ct.push(intermediate_raw[0].clone());
  667. }
  668. let packed_ct = pack(params, &v_ct, &v_packing);
  669. packed_ct.raw()
  670. })
  671. .collect();
  672. encode(params, &v_packed_ct)
  673. }
  674. #[cfg(test)]
  675. mod test {
  676. use super::*;
  677. use crate::client::*;
  678. use rand::{prelude::SmallRng, Rng};
  679. const TEST_PREPROCESSED_DB_PATH: &'static str = "/home/samir/wiki/enwiki-20220320.dbp";
  680. fn get_params() -> Params {
  681. get_fast_expansion_testing_params()
  682. }
  683. fn dec_reg<'a>(
  684. params: &'a Params,
  685. ct: &PolyMatrixNTT<'a>,
  686. client: &mut Client<'a, SmallRng>,
  687. scale_k: u64,
  688. ) -> u64 {
  689. let dec = client.decrypt_matrix_reg(ct).raw();
  690. let mut val = dec.data[0] as i64;
  691. if val >= (params.modulus / 2) as i64 {
  692. val -= params.modulus as i64;
  693. }
  694. let val_rounded = f64::round(val as f64 / scale_k as f64) as i64;
  695. if val_rounded == 0 {
  696. 0
  697. } else {
  698. 1
  699. }
  700. }
  701. fn dec_gsw<'a>(
  702. params: &'a Params,
  703. ct: &PolyMatrixNTT<'a>,
  704. client: &mut Client<'a, SmallRng>,
  705. ) -> u64 {
  706. let dec = client.decrypt_matrix_reg(ct).raw();
  707. let idx = 2 * (params.t_gsw - 1) * params.poly_len + params.poly_len; // this offset should encode a large value
  708. let mut val = dec.data[idx] as i64;
  709. if val >= (params.modulus / 2) as i64 {
  710. val -= params.modulus as i64;
  711. }
  712. if i64::abs(val) < (1i64 << 10) {
  713. 0
  714. } else {
  715. 1
  716. }
  717. }
  718. #[test]
  719. fn coefficient_expansion_is_correct() {
  720. let params = get_params();
  721. let v_neg1 = params.get_v_neg1();
  722. let mut seeded_rng = get_seeded_rng();
  723. let mut client = Client::init(&params, &mut seeded_rng);
  724. let public_params = client.generate_keys();
  725. let mut v = Vec::new();
  726. for _ in 0..(1 << (params.db_dim_1 + 1)) {
  727. v.push(PolyMatrixNTT::zero(&params, 2, 1));
  728. }
  729. let target = 7;
  730. let scale_k = params.modulus / params.pt_modulus;
  731. let mut sigma = PolyMatrixRaw::zero(&params, 1, 1);
  732. sigma.data[target] = scale_k;
  733. v[0] = client.encrypt_matrix_reg(&sigma.ntt());
  734. let test_ct = client.encrypt_matrix_reg(&sigma.ntt());
  735. let v_w_left = public_params.v_expansion_left.unwrap();
  736. let v_w_right = public_params.v_expansion_right.unwrap();
  737. coefficient_expansion(
  738. &mut v,
  739. params.g(),
  740. params.stop_round(),
  741. &params,
  742. &v_w_left,
  743. &v_w_right,
  744. &v_neg1,
  745. params.t_gsw * params.db_dim_2,
  746. );
  747. assert_eq!(dec_reg(&params, &test_ct, &mut client, scale_k), 0);
  748. for i in 0..v.len() {
  749. if i == target {
  750. assert_eq!(dec_reg(&params, &v[i], &mut client, scale_k), 1);
  751. } else {
  752. assert_eq!(dec_reg(&params, &v[i], &mut client, scale_k), 0);
  753. }
  754. }
  755. }
  /// Checks `regev_to_gsw`: gadget-scaled Regev encryptions of 1 combine into a
  /// GSW ciphertext that decrypts to 1, and encryptions of 0 give one that
  /// decrypts to 0.
  #[test]
  fn regev_to_gsw_is_correct() {
      let mut params = get_params();
      params.db_dim_2 = 1;
      let mut seeded_rng = get_seeded_rng();
      let mut client = Client::init(&params, &mut seeded_rng);
      let public_params = client.generate_keys();

      // Regev-encrypt a constant polynomial whose coefficient 0 is `val`.
      let mut encrypt_constant = |val| {
          let mut plaintext = PolyMatrixRaw::zero(&params, 1, 1);
          plaintext.data[0] = val;
          client.encrypt_matrix_reg(&plaintext.ntt())
      };

      let conversion_key = &public_params.v_conversion.unwrap()[0];
      let bits_per = get_bits_per(&params, params.t_gsw);

      // Both lists are filled in lock-step, alternating one encryption of
      // 2^(bits_per * i) with one encryption of 0.
      let (mut enc_ones, mut enc_zeros) = (Vec::new(), Vec::new());
      for i in 0..params.t_gsw {
          enc_ones.push(encrypt_constant(1u64 << (bits_per * i)));
          enc_zeros.push(encrypt_constant(0));
      }

      let mut v_gsw = vec![PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw)];

      regev_to_gsw(&mut v_gsw, &enc_ones, conversion_key, &params, 1, 0);
      assert_eq!(dec_gsw(&params, &v_gsw[0], &mut client), 1);

      regev_to_gsw(&mut v_gsw, &enc_zeros, conversion_key, &params, 1, 0);
      assert_eq!(dec_gsw(&params, &v_gsw[0], &mut client), 0);
  }
  /// Checks `multiply_reg_by_database`: a one-hot Regev selection vector over
  /// the first dimension is multiplied into a random database, and the output
  /// ciphertext for the target column must decrypt (after rescaling) to exactly
  /// the plaintext item stored at the target index.
  #[test]
  fn multiply_reg_by_database_is_correct() {
      let params = get_params();
      let mut seeded_rng = get_seeded_rng();

      let dim0 = 1 << params.db_dim_1;
      let num_per = 1 << params.db_dim_2;
      let scale_k = params.modulus / params.pt_modulus;

      // Indices range over dim0 * num_per; split the target into its
      // first-dimension position and its position within `num_per`.
      let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
      let target_idx_dim0 = target_idx / num_per;
      let target_idx_num_per = target_idx % num_per;

      let mut client = Client::init(&params, &mut seeded_rng);
      _ = client.generate_keys();

      let (corr_item, db) = generate_random_db_and_get_item(&params, target_idx);

      // One-hot selection over the first dimension: `scale_k` at the target
      // row, 0 elsewhere.
      let v_reg: Vec<_> = (0..dim0)
          .map(|i| {
              let val = if i == target_idx_dim0 { scale_k } else { 0 };
              let sigma = PolyMatrixRaw::single_value(&params, val).ntt();
              client.encrypt_matrix_reg(&sigma)
          })
          .collect();

      let v_reg_sz = dim0 * 2 * params.poly_len;
      let mut v_reg_reoriented = AlignedMemory64::new(v_reg_sz);
      reorient_reg_ciphertexts(&params, v_reg_reoriented.as_mut_slice(), &v_reg);

      // `num_per` is what we pass to `multiply_reg_by_database` and the result
      // is read at `out[target_idx_num_per]` (< num_per), so allocate exactly
      // `num_per` output slots. Sizing by `dim0` only works when dim0 >= num_per.
      let mut out: Vec<_> = (0..num_per)
          .map(|_| PolyMatrixNTT::zero(&params, 2, 1))
          .collect();
      multiply_reg_by_database(
          &mut out,
          db.as_slice(),
          v_reg_reoriented.as_slice(),
          &params,
          dim0,
          num_per,
      );

      // Decrypt the selected output, rescale each coefficient from
      // `params.modulus` to `params.pt_modulus`, and compare with the item.
      let dec = client.decrypt_matrix_reg(&out[target_idx_num_per]).raw();
      for z in 0..params.poly_len {
          let got = rescale(dec.data[z], params.modulus, params.pt_modulus);
          assert_eq!(got, corr_item.data[z], "mismatch at coefficient {}", z);
      }
  }
  /// Checks `fold_ciphertexts`: `num_per` one-hot Regev ciphertexts are folded
  /// using hand-built GSW encryptions of the target column's index bits (plus
  /// their gadget-complemented counterparts); the ciphertext left in slot 0
  /// must decrypt to 1.
  #[test]
  fn fold_ciphertexts_is_correct() {
      let params = get_params();
      let mut seeded_rng = get_seeded_rng();

      let dim0 = 1 << params.db_dim_1;
      let num_per = 1 << params.db_dim_2;
      let scale_k = params.modulus / params.pt_modulus;

      let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
      let target_idx_num_per = target_idx % num_per;

      let mut client = Client::init(&params, &mut seeded_rng);
      _ = client.generate_keys();

      // One-hot Regev encryptions over the folding dimension.
      let v_reg: Vec<_> = (0..num_per)
          .map(|i| {
              let val = if i == target_idx_num_per { scale_k } else { 0 };
              client.encrypt_matrix_reg(&PolyMatrixRaw::single_value(&params, val).ntt())
          })
          .collect();
      let mut v_reg_raw: Vec<_> = v_reg.iter().map(|ct| ct.raw()).collect();

      // One GSW ciphertext per index bit, assembled column pair by column pair.
      let bits_per = get_bits_per(&params, params.t_gsw);
      let mut v_folding = Vec::new();
      for bit_pos in 0..params.db_dim_2 {
          // Bit `bit_pos` of the target column index, as 0 or 1.
          let bit = ((target_idx_num_per >> bit_pos) & 1) as u64;
          let mut ct_gsw = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
          for j in 0..params.t_gsw {
              let value = (1u64 << (bits_per * j)) * bit;
              let sigma_ntt = to_ntt_alloc(&PolyMatrixRaw::single_value(&params, value));

              // Column 2j + 1: Regev encryption of `value`.
              let ct_odd = client.encrypt_matrix_reg(&sigma_ntt);
              ct_gsw.copy_into(&ct_odd, 0, 2 * j + 1);

              // Column 2j: Regev encryption of sk_reg * `value`.
              let prod = &to_ntt_alloc(client.get_sk_reg()) * &sigma_ntt;
              let ct_even = client.encrypt_matrix_reg(&prod);
              ct_gsw.copy_into(&ct_even, 0, 2 * j);
          }
          v_folding.push(ct_gsw);
      }

      // Companion ciphertexts: gadget + invert(C) for each GSW ciphertext C
      // (presumably G - C, assuming `invert` negates — TODO confirm).
      let gadget_ntt = build_gadget(&params, 2, 2 * params.t_gsw).ntt();
      let mut ct_gsw_inv = PolyMatrixRaw::zero(&params, 2, 2 * params.t_gsw);
      let mut v_folding_neg = Vec::new();
      for ct_gsw in &v_folding {
          invert(&mut ct_gsw_inv, &ct_gsw.raw());
          let mut ct_gsw_neg = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
          add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
          v_folding_neg.push(ct_gsw_neg);
      }

      fold_ciphertexts(&params, &mut v_reg_raw, &v_folding, &v_folding_neg);

      let folded = v_reg_raw[0].ntt();
      assert_eq!(dec_reg(&params, &folded, &mut client, scale_k), 1);
  }
  883. fn full_protocol_is_correct_for_params(params: &Params) {
  884. let mut seeded_rng = get_seeded_rng();
  885. let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
  886. let mut client = Client::init(params, &mut seeded_rng);
  887. let public_params = client.generate_keys();
  888. let query = client.generate_query(target_idx);
  889. let (corr_item, db) = generate_random_db_and_get_item(params, target_idx);
  890. let response = process_query(params, &public_params, &query, db.as_slice());
  891. let result = client.decode_response(response.as_slice());
  892. let p_bits = log2_ceil(params.pt_modulus) as usize;
  893. let corr_result = corr_item.to_vec(p_bits, params.modp_words_per_chunk());
  894. assert_eq!(result.len(), corr_result.len());
  895. for z in 0..corr_result.len() {
  896. assert_eq!(result[z], corr_result[z], "at {:?}", z);
  897. }
  898. }
  899. fn full_protocol_is_correct_for_params_real_db(params: &Params) {
  900. let mut seeded_rng = get_seeded_rng();
  901. let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
  902. let mut client = Client::init(params, &mut seeded_rng);
  903. let public_params = client.generate_keys();
  904. let query = client.generate_query(target_idx);
  905. let mut file = File::open(TEST_PREPROCESSED_DB_PATH).unwrap();
  906. let db = load_preprocessed_db_from_file(params, &mut file);
  907. let response = process_query(params, &public_params, &query, db.as_slice());
  908. let result = client.decode_response(response.as_slice());
  909. let corr_result = vec![0x42, 0x5a, 0x68];
  910. for z in 0..corr_result.len() {
  911. assert_eq!(result[z], corr_result[z]);
  912. }
  913. }
  /// Full protocol round trip with the default test parameters.
  #[test]
  fn full_protocol_is_correct() {
      let params = get_params();
      full_protocol_is_correct_for_params(&params);
  }
  /// Full protocol round trip with a larger parameter set (nu_1 = 10, nu_2 = 6,
  /// item size 9000). Runs against both a random database and the preprocessed
  /// test database.
  ///
  /// Marked `#[ignore]`, so it only runs on request (`cargo test -- --ignored`).
  #[test]
  #[ignore]
  fn larger_full_protocol_is_correct() {
      // Single-quoted pseudo-JSON; converted to valid JSON below.
      let cfg_expand = r#"
      {
          'n': 2,
          'nu_1': 10,
          'nu_2': 6,
          'p': 512,
          'q2_bits': 21,
          's_e': 85.83255142749422,
          't_gsw': 10,
          't_conv': 4,
          't_exp_left': 16,
          't_exp_right': 56,
          'instances': 1,
          'db_item_size': 9000 }
      "#;
      let cfg = cfg_expand.replace("'", "\"");
      let params = params_from_json(&cfg);

      full_protocol_is_correct_for_params(&params);
      full_protocol_is_correct_for_params_real_db(&params);
  }
  // #[test]
  // fn full_protocol_is_correct_20_256() {
  //     full_protocol_is_correct_for_params(&params_from_json(&CFG_20_256.replace("'", "\"")));
  // }
  // #[test]
  // fn full_protocol_is_correct_16_100000() {
  //     full_protocol_is_correct_for_params(&params_from_json(&CFG_16_100000.replace("'", "\"")));
  // }
  /// Full protocol against the preprocessed test database using the
  /// `CFG_16_100000` parameter set (single quotes converted to JSON quotes).
  #[test]
  #[ignore]
  fn full_protocol_is_correct_real_db_16_100000() {
      let cfg = CFG_16_100000.replace("'", "\"");
      let params = params_from_json(&cfg);
      full_protocol_is_correct_for_params_real_db(&params);
  }
  957. }