server.rs 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077
  1. #[cfg(target_feature = "avx2")]
  2. use std::arch::x86_64::*;
  3. use std::fs::File;
  4. use std::io::BufReader;
  5. use std::io::Read;
  6. use std::io::Seek;
  7. use std::io::SeekFrom;
  8. use std::mem::size_of;
  9. use crate::aligned_memory::*;
  10. use crate::arith::*;
  11. use crate::client::PublicParameters;
  12. use crate::client::Query;
  13. use crate::gadget::*;
  14. use crate::params::*;
  15. use crate::poly::*;
  16. use crate::util::*;
  17. pub fn coefficient_expansion(
  18. v: &mut Vec<PolyMatrixNTT>,
  19. g: usize,
  20. stop_round: usize,
  21. params: &Params,
  22. v_w_left: &Vec<PolyMatrixNTT>,
  23. v_w_right: &Vec<PolyMatrixNTT>,
  24. v_neg1: &Vec<PolyMatrixNTT>,
  25. max_bits_to_gen_right: usize,
  26. ) {
  27. let poly_len = params.poly_len;
  28. let mut ct = PolyMatrixRaw::zero(params, 2, 1);
  29. let mut ct_auto = PolyMatrixRaw::zero(params, 2, 1);
  30. let mut ct_auto_1 = PolyMatrixRaw::zero(params, 1, 1);
  31. let mut ct_auto_1_ntt = PolyMatrixNTT::zero(params, 1, 1);
  32. let mut ginv_ct_left = PolyMatrixRaw::zero(params, params.t_exp_left, 1);
  33. let mut ginv_ct_left_ntt = PolyMatrixNTT::zero(params, params.t_exp_left, 1);
  34. let mut ginv_ct_right = PolyMatrixRaw::zero(params, params.t_exp_right, 1);
  35. let mut ginv_ct_right_ntt = PolyMatrixNTT::zero(params, params.t_exp_right, 1);
  36. let mut w_times_ginv_ct = PolyMatrixNTT::zero(params, 2, 1);
  37. for r in 0..g {
  38. let num_in = 1 << r;
  39. let num_out = 2 * num_in;
  40. let t = (poly_len / (1 << r)) + 1;
  41. let neg1 = &v_neg1[r];
  42. for i in 0..num_out {
  43. if stop_round > 0 && i % 2 == 1 && r > stop_round
  44. || (r == stop_round && i / 2 >= max_bits_to_gen_right)
  45. {
  46. continue;
  47. }
  48. let (w, _gadget_dim, gi_ct, gi_ct_ntt) = match i % 2 {
  49. 0 => (
  50. &v_w_left[r],
  51. params.t_exp_left,
  52. &mut ginv_ct_left,
  53. &mut ginv_ct_left_ntt,
  54. ),
  55. 1 | _ => (
  56. &v_w_right[r],
  57. params.t_exp_right,
  58. &mut ginv_ct_right,
  59. &mut ginv_ct_right_ntt,
  60. ),
  61. };
  62. if i < num_in {
  63. let (src, dest) = v.split_at_mut(num_in);
  64. scalar_multiply(&mut dest[i], neg1, &src[i]);
  65. }
  66. from_ntt(&mut ct, &v[i]);
  67. automorph(&mut ct_auto, &ct, t);
  68. gadget_invert_rdim(gi_ct, &ct_auto, 1);
  69. to_ntt_no_reduce(gi_ct_ntt, &gi_ct);
  70. ct_auto_1
  71. .data
  72. .as_mut_slice()
  73. .copy_from_slice(ct_auto.get_poly(1, 0));
  74. to_ntt(&mut ct_auto_1_ntt, &ct_auto_1);
  75. multiply(&mut w_times_ginv_ct, w, &gi_ct_ntt);
  76. let mut idx = 0;
  77. for j in 0..2 {
  78. for n in 0..params.crt_count {
  79. for z in 0..poly_len {
  80. let sum = v[i].data[idx]
  81. + w_times_ginv_ct.data[idx]
  82. + j * ct_auto_1_ntt.data[n * poly_len + z];
  83. v[i].data[idx] = barrett_coeff_u64(params, sum, n);
  84. idx += 1;
  85. }
  86. }
  87. }
  88. }
  89. }
  90. }
  91. pub fn regev_to_gsw<'a>(
  92. v_gsw: &mut Vec<PolyMatrixNTT<'a>>,
  93. v_inp: &Vec<PolyMatrixNTT<'a>>,
  94. v: &PolyMatrixNTT<'a>,
  95. params: &'a Params,
  96. idx_factor: usize,
  97. idx_offset: usize,
  98. ) {
  99. assert!(v.rows == 2);
  100. assert!(v.cols == 2 * params.t_conv);
  101. let mut ginv_c_inp = PolyMatrixRaw::zero(params, 2 * params.t_conv, 1);
  102. let mut ginv_c_inp_ntt = PolyMatrixNTT::zero(params, 2 * params.t_conv, 1);
  103. let mut tmp_ct_raw = PolyMatrixRaw::zero(params, 2, 1);
  104. let mut tmp_ct = PolyMatrixNTT::zero(params, 2, 1);
  105. for i in 0..params.db_dim_2 {
  106. let ct = &mut v_gsw[i];
  107. for j in 0..params.t_gsw {
  108. let idx_ct = i * params.t_gsw + j;
  109. let idx_inp = idx_factor * (idx_ct) + idx_offset;
  110. ct.copy_into(&v_inp[idx_inp], 0, 2 * j + 1);
  111. from_ntt(&mut tmp_ct_raw, &v_inp[idx_inp]);
  112. gadget_invert(&mut ginv_c_inp, &tmp_ct_raw);
  113. to_ntt(&mut ginv_c_inp_ntt, &ginv_c_inp);
  114. multiply(&mut tmp_ct, v, &ginv_c_inp_ntt);
  115. ct.copy_into(&tmp_ct, 0, 2 * j);
  116. }
  117. }
  118. }
  119. pub const MAX_SUMMED: usize = 1 << 6;
  120. pub const PACKED_OFFSET_2: i32 = 32;
  121. #[cfg(target_feature = "avx2")]
  122. pub fn multiply_reg_by_database(
  123. out: &mut Vec<PolyMatrixNTT>,
  124. db: &[u64],
  125. v_firstdim: &[u64],
  126. params: &Params,
  127. dim0: usize,
  128. num_per: usize,
  129. ) {
  130. let ct_rows = 2;
  131. let ct_cols = 1;
  132. let pt_rows = 1;
  133. let pt_cols = 1;
  134. assert!(dim0 * ct_rows >= MAX_SUMMED);
  135. let mut sums_out_n0_u64 = AlignedMemory64::new(4);
  136. let mut sums_out_n2_u64 = AlignedMemory64::new(4);
  137. for z in 0..params.poly_len {
  138. let idx_a_base = z * (ct_cols * dim0 * ct_rows);
  139. let mut idx_b_base = z * (num_per * pt_cols * dim0 * pt_rows);
  140. for i in 0..num_per {
  141. for c in 0..pt_cols {
  142. let inner_limit = MAX_SUMMED;
  143. let outer_limit = dim0 * ct_rows / inner_limit;
  144. let mut sums_out_n0_u64_acc = [0u64, 0, 0, 0];
  145. let mut sums_out_n2_u64_acc = [0u64, 0, 0, 0];
  146. for o_jm in 0..outer_limit {
  147. unsafe {
  148. let mut sums_out_n0 = _mm256_setzero_si256();
  149. let mut sums_out_n2 = _mm256_setzero_si256();
  150. for i_jm in 0..inner_limit / 4 {
  151. let jm = o_jm * inner_limit + (4 * i_jm);
  152. let b_inp_1 = *db.get_unchecked(idx_b_base) as i64;
  153. idx_b_base += 1;
  154. let b_inp_2 = *db.get_unchecked(idx_b_base) as i64;
  155. idx_b_base += 1;
  156. let b = _mm256_set_epi64x(b_inp_2, b_inp_2, b_inp_1, b_inp_1);
  157. let v_a = v_firstdim.get_unchecked(idx_a_base + jm) as *const u64;
  158. let a = _mm256_load_si256(v_a as *const __m256i);
  159. let a_lo = a;
  160. let a_hi_hi = _mm256_srli_epi64(a, PACKED_OFFSET_2);
  161. let b_lo = b;
  162. let b_hi_hi = _mm256_srli_epi64(b, PACKED_OFFSET_2);
  163. sums_out_n0 =
  164. _mm256_add_epi64(sums_out_n0, _mm256_mul_epu32(a_lo, b_lo));
  165. sums_out_n2 =
  166. _mm256_add_epi64(sums_out_n2, _mm256_mul_epu32(a_hi_hi, b_hi_hi));
  167. }
  168. // reduce here, otherwise we will overflow
  169. _mm256_store_si256(
  170. sums_out_n0_u64.as_mut_ptr() as *mut __m256i,
  171. sums_out_n0,
  172. );
  173. _mm256_store_si256(
  174. sums_out_n2_u64.as_mut_ptr() as *mut __m256i,
  175. sums_out_n2,
  176. );
  177. for idx in 0..4 {
  178. let val = sums_out_n0_u64[idx];
  179. sums_out_n0_u64_acc[idx] =
  180. barrett_coeff_u64(params, val + sums_out_n0_u64_acc[idx], 0);
  181. }
  182. for idx in 0..4 {
  183. let val = sums_out_n2_u64[idx];
  184. sums_out_n2_u64_acc[idx] =
  185. barrett_coeff_u64(params, val + sums_out_n2_u64_acc[idx], 1);
  186. }
  187. }
  188. }
  189. for idx in 0..4 {
  190. sums_out_n0_u64_acc[idx] =
  191. barrett_coeff_u64(params, sums_out_n0_u64_acc[idx], 0);
  192. sums_out_n2_u64_acc[idx] =
  193. barrett_coeff_u64(params, sums_out_n2_u64_acc[idx], 1);
  194. }
  195. // output n0
  196. let (crt_count, poly_len) = (params.crt_count, params.poly_len);
  197. let mut n = 0;
  198. let mut idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
  199. out[i].data[idx_c] =
  200. barrett_coeff_u64(params, sums_out_n0_u64_acc[0] + sums_out_n0_u64_acc[2], 0);
  201. idx_c += pt_cols * crt_count * poly_len;
  202. out[i].data[idx_c] =
  203. barrett_coeff_u64(params, sums_out_n0_u64_acc[1] + sums_out_n0_u64_acc[3], 0);
  204. // output n1
  205. n = 1;
  206. idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
  207. out[i].data[idx_c] =
  208. barrett_coeff_u64(params, sums_out_n2_u64_acc[0] + sums_out_n2_u64_acc[2], 1);
  209. idx_c += pt_cols * crt_count * poly_len;
  210. out[i].data[idx_c] =
  211. barrett_coeff_u64(params, sums_out_n2_u64_acc[1] + sums_out_n2_u64_acc[3], 1);
  212. }
  213. }
  214. }
  215. }
  216. #[cfg(not(target_feature = "avx2"))]
  217. pub fn multiply_reg_by_database(
  218. out: &mut Vec<PolyMatrixNTT>,
  219. db: &[u64],
  220. v_firstdim: &[u64],
  221. params: &Params,
  222. dim0: usize,
  223. num_per: usize,
  224. ) {
  225. let ct_rows = 2;
  226. let ct_cols = 1;
  227. let pt_rows = 1;
  228. let pt_cols = 1;
  229. for z in 0..params.poly_len {
  230. let idx_a_base = z * (ct_cols * dim0 * ct_rows);
  231. let mut idx_b_base = z * (num_per * pt_cols * dim0 * pt_rows);
  232. for i in 0..num_per {
  233. for c in 0..pt_cols {
  234. let mut sums_out_n0_0 = 0u128;
  235. let mut sums_out_n0_1 = 0u128;
  236. let mut sums_out_n1_0 = 0u128;
  237. let mut sums_out_n1_1 = 0u128;
  238. for jm in 0..(dim0 * pt_rows) {
  239. let b = db[idx_b_base];
  240. idx_b_base += 1;
  241. let v_a0 = v_firstdim[idx_a_base + jm * ct_rows];
  242. let v_a1 = v_firstdim[idx_a_base + jm * ct_rows + 1];
  243. let b_lo = b as u32;
  244. let b_hi = (b >> 32) as u32;
  245. let v_a0_lo = v_a0 as u32;
  246. let v_a0_hi = (v_a0 >> 32) as u32;
  247. let v_a1_lo = v_a1 as u32;
  248. let v_a1_hi = (v_a1 >> 32) as u32;
  249. // do n0
  250. sums_out_n0_0 += ((v_a0_lo as u64) * (b_lo as u64)) as u128;
  251. sums_out_n0_1 += ((v_a1_lo as u64) * (b_lo as u64)) as u128;
  252. // do n1
  253. sums_out_n1_0 += ((v_a0_hi as u64) * (b_hi as u64)) as u128;
  254. sums_out_n1_1 += ((v_a1_hi as u64) * (b_hi as u64)) as u128;
  255. }
  256. // output n0
  257. let (crt_count, poly_len) = (params.crt_count, params.poly_len);
  258. let mut n = 0;
  259. let mut idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
  260. out[i].data[idx_c] = (sums_out_n0_0 % (params.moduli[0] as u128)) as u64;
  261. idx_c += pt_cols * crt_count * poly_len;
  262. out[i].data[idx_c] = (sums_out_n0_1 % (params.moduli[0] as u128)) as u64;
  263. // output n1
  264. n = 1;
  265. idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
  266. out[i].data[idx_c] = (sums_out_n1_0 % (params.moduli[1] as u128)) as u64;
  267. idx_c += pt_cols * crt_count * poly_len;
  268. out[i].data[idx_c] = (sums_out_n1_1 % (params.moduli[1] as u128)) as u64;
  269. }
  270. }
  271. }
  272. }
  273. pub fn generate_random_db_and_get_item<'a>(
  274. params: &'a Params,
  275. item_idx: usize,
  276. ) -> (PolyMatrixRaw<'a>, AlignedMemory64) {
  277. let mut rng = get_seeded_rng();
  278. let instances = params.instances;
  279. let trials = params.n * params.n;
  280. let dim0 = 1 << params.db_dim_1;
  281. let num_per = 1 << params.db_dim_2;
  282. let num_items = dim0 * num_per;
  283. let db_size_words = instances * trials * num_items * params.poly_len;
  284. let mut v = AlignedMemory64::new(db_size_words);
  285. let mut tmp_item_ntt = PolyMatrixNTT::zero(params, 1, 1);
  286. let mut item = PolyMatrixRaw::zero(params, params.n, params.n);
  287. for instance in 0..instances {
  288. println!("Instance {:?}", instance);
  289. for trial in 0..trials {
  290. println!("Trial {:?}", trial);
  291. for i in 0..num_items {
  292. let ii = i % num_per;
  293. let j = i / num_per;
  294. let mut db_item = PolyMatrixRaw::random_rng(params, 1, 1, &mut rng);
  295. db_item.reduce_mod(params.pt_modulus);
  296. if i == item_idx && instance == 0 {
  297. item.copy_into(&db_item, trial / params.n, trial % params.n);
  298. }
  299. for z in 0..params.poly_len {
  300. db_item.data[z] =
  301. recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
  302. }
  303. let db_item_ntt = db_item.ntt();
  304. for z in 0..params.poly_len {
  305. let idx_dst = calc_index(
  306. &[instance, trial, z, ii, j],
  307. &[instances, trials, params.poly_len, num_per, dim0],
  308. );
  309. v[idx_dst] = db_item_ntt.data[z]
  310. | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
  311. }
  312. }
  313. }
  314. }
  315. (item, v)
  316. }
  317. pub fn load_item_from_file<'a>(
  318. params: &'a Params,
  319. file: &mut File,
  320. instance: usize,
  321. trial: usize,
  322. item_idx: usize,
  323. ) -> PolyMatrixRaw<'a> {
  324. let db_item_size = params.db_item_size;
  325. let instances = params.instances;
  326. let trials = params.n * params.n;
  327. let dim0 = 1 << params.db_dim_1;
  328. let num_per = 1 << params.db_dim_2;
  329. let num_items = dim0 * num_per;
  330. let chunks = instances * trials;
  331. let bytes_per_chunk = f64::ceil(db_item_size as f64 / chunks as f64) as usize;
  332. let logp = f64::ceil(f64::log2(params.pt_modulus as f64)) as usize;
  333. let modp_words_per_chunk = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
  334. assert!(modp_words_per_chunk <= params.poly_len);
  335. let idx_item_in_file = item_idx * db_item_size;
  336. let idx_chunk = instance * trials + trial;
  337. let idx_poly_in_file = idx_item_in_file + idx_chunk * bytes_per_chunk;
  338. let mut out = PolyMatrixRaw::zero(params, 1, 1);
  339. let seek_result = file.seek(SeekFrom::Start(idx_poly_in_file as u64));
  340. if seek_result.is_err() {
  341. return out;
  342. }
  343. let mut data = vec![0u8; 2 * bytes_per_chunk];
  344. let bytes_read = file
  345. .read(&mut data.as_mut_slice()[0..bytes_per_chunk])
  346. .unwrap();
  347. let modp_words_read = f64::ceil((bytes_read * 8) as f64 / logp as f64) as usize;
  348. assert!(modp_words_read <= params.poly_len);
  349. for i in 0..modp_words_read {
  350. out.data[i] = read_arbitrary_bits(&data, i * logp, logp);
  351. assert!(out.data[i] <= params.pt_modulus);
  352. }
  353. out
  354. }
  355. pub fn load_db_from_file(params: &Params, file: &mut File) -> AlignedMemory64 {
  356. let instances = params.instances;
  357. let trials = params.n * params.n;
  358. let dim0 = 1 << params.db_dim_1;
  359. let num_per = 1 << params.db_dim_2;
  360. let num_items = dim0 * num_per;
  361. let db_size_words = instances * trials * num_items * params.poly_len;
  362. let mut v = AlignedMemory64::new(db_size_words);
  363. for instance in 0..instances {
  364. println!("Instance {:?}", instance);
  365. for trial in 0..trials {
  366. println!("Trial {:?}", trial);
  367. for i in 0..num_items {
  368. if i % 8192 == 0 {
  369. println!("item {:?}", i);
  370. }
  371. let ii = i % num_per;
  372. let j = i / num_per;
  373. let mut db_item = load_item_from_file(params, file, instance, trial, i);
  374. // db_item.reduce_mod(params.pt_modulus);
  375. for z in 0..params.poly_len {
  376. db_item.data[z] =
  377. recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
  378. }
  379. let db_item_ntt = db_item.ntt();
  380. for z in 0..params.poly_len {
  381. let idx_dst = calc_index(
  382. &[instance, trial, z, ii, j],
  383. &[instances, trials, params.poly_len, num_per, dim0],
  384. );
  385. v[idx_dst] = db_item_ntt.data[z]
  386. | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
  387. }
  388. }
  389. }
  390. }
  391. v
  392. }
  393. pub fn load_preprocessed_db_from_file(params: &Params, file: &mut File) -> AlignedMemory64 {
  394. let instances = params.instances;
  395. let trials = params.n * params.n;
  396. let dim0 = 1 << params.db_dim_1;
  397. let num_per = 1 << params.db_dim_2;
  398. let num_items = dim0 * num_per;
  399. let db_size_words = instances * trials * num_items * params.poly_len;
  400. let mut v = AlignedMemory64::new(db_size_words);
  401. let v_mut_slice = v.as_mut_slice();
  402. let mut reader = BufReader::new(file);
  403. let mut buf = [0u8; 8];
  404. for i in 0..db_size_words {
  405. if i % 1000000000 == 0 {
  406. println!("{} GB loaded", i);
  407. }
  408. reader.read(&mut buf).unwrap();
  409. v_mut_slice[i] = u64::from_ne_bytes(buf);
  410. }
  411. v
  412. }
  413. pub fn fold_ciphertexts(
  414. params: &Params,
  415. v_cts: &mut Vec<PolyMatrixRaw>,
  416. v_folding: &Vec<PolyMatrixNTT>,
  417. v_folding_neg: &Vec<PolyMatrixNTT>,
  418. ) {
  419. let further_dims = log2(v_cts.len() as u64) as usize;
  420. let ell = v_folding[0].cols / 2;
  421. let mut ginv_c = PolyMatrixRaw::zero(&params, 2 * ell, 1);
  422. let mut ginv_c_ntt = PolyMatrixNTT::zero(&params, 2 * ell, 1);
  423. let mut prod = PolyMatrixNTT::zero(&params, 2, 1);
  424. let mut sum = PolyMatrixNTT::zero(&params, 2, 1);
  425. let mut num_per = v_cts.len();
  426. for cur_dim in 0..further_dims {
  427. num_per = num_per / 2;
  428. for i in 0..num_per {
  429. gadget_invert(&mut ginv_c, &v_cts[i]);
  430. to_ntt(&mut ginv_c_ntt, &ginv_c);
  431. multiply(
  432. &mut prod,
  433. &v_folding_neg[further_dims - 1 - cur_dim],
  434. &ginv_c_ntt,
  435. );
  436. gadget_invert(&mut ginv_c, &v_cts[num_per + i]);
  437. to_ntt(&mut ginv_c_ntt, &ginv_c);
  438. multiply(
  439. &mut sum,
  440. &v_folding[further_dims - 1 - cur_dim],
  441. &ginv_c_ntt,
  442. );
  443. add_into(&mut sum, &prod);
  444. from_ntt(&mut v_cts[i], &sum);
  445. }
  446. }
  447. }
  448. pub fn pack<'a>(
  449. params: &'a Params,
  450. v_ct: &Vec<PolyMatrixRaw>,
  451. v_w: &Vec<PolyMatrixNTT>,
  452. ) -> PolyMatrixNTT<'a> {
  453. assert!(v_ct.len() >= params.n * params.n);
  454. assert!(v_w.len() == params.n);
  455. assert!(v_ct[0].rows == 2);
  456. assert!(v_ct[0].cols == 1);
  457. assert!(v_w[0].rows == (params.n + 1));
  458. assert!(v_w[0].cols == params.t_conv);
  459. let mut result = PolyMatrixNTT::zero(params, params.n + 1, params.n);
  460. let mut ginv = PolyMatrixRaw::zero(params, params.t_conv, 1);
  461. let mut ginv_nttd = PolyMatrixNTT::zero(params, params.t_conv, 1);
  462. let mut prod = PolyMatrixNTT::zero(params, params.n + 1, 1);
  463. let mut ct_1 = PolyMatrixRaw::zero(params, 1, 1);
  464. let mut ct_2 = PolyMatrixRaw::zero(params, 1, 1);
  465. let mut ct_2_ntt = PolyMatrixNTT::zero(params, 1, 1);
  466. for c in 0..params.n {
  467. let mut v_int = PolyMatrixNTT::zero(&params, params.n + 1, 1);
  468. for r in 0..params.n {
  469. let w = &v_w[r];
  470. let ct = &v_ct[r * params.n + c];
  471. ct_1.get_poly_mut(0, 0).copy_from_slice(ct.get_poly(0, 0));
  472. ct_2.get_poly_mut(0, 0).copy_from_slice(ct.get_poly(1, 0));
  473. to_ntt(&mut ct_2_ntt, &ct_2);
  474. gadget_invert(&mut ginv, &ct_1);
  475. to_ntt(&mut ginv_nttd, &ginv);
  476. multiply(&mut prod, &w, &ginv_nttd);
  477. add_into_at(&mut v_int, &ct_2_ntt, 1 + r, 0);
  478. add_into(&mut v_int, &prod);
  479. }
  480. result.copy_into(&v_int, 0, c);
  481. }
  482. result
  483. }
  484. pub fn encode(params: &Params, v_packed_ct: &Vec<PolyMatrixRaw>) -> Vec<u8> {
  485. let q1 = 4 * params.pt_modulus;
  486. let q1_bits = log2_ceil(q1) as usize;
  487. let q2 = Q2_VALUES[params.q2_bits as usize];
  488. let q2_bits = params.q2_bits as usize;
  489. let num_bits = params.instances
  490. * ((q2_bits * params.n * params.poly_len)
  491. + (q1_bits * params.n * params.n * params.poly_len));
  492. let round_to = 64;
  493. let num_bytes_rounded_up = ((num_bits + round_to - 1) / round_to) * round_to / 8;
  494. let mut result = vec![0u8; num_bytes_rounded_up];
  495. let mut bit_offs = 0;
  496. for instance in 0..params.instances {
  497. let packed_ct = &v_packed_ct[instance];
  498. let mut first_row = packed_ct.submatrix(0, 0, 1, packed_ct.cols);
  499. let mut rest_rows = packed_ct.submatrix(1, 0, packed_ct.rows - 1, packed_ct.cols);
  500. first_row.apply_func(|x| rescale(x, params.modulus, q2));
  501. rest_rows.apply_func(|x| rescale(x, params.modulus, q1));
  502. let data = result.as_mut_slice();
  503. for i in 0..params.n * params.poly_len {
  504. write_arbitrary_bits(data, first_row.data[i], bit_offs, q2_bits);
  505. bit_offs += q2_bits;
  506. }
  507. for i in 0..params.n * params.n * params.poly_len {
  508. write_arbitrary_bits(data, rest_rows.data[i], bit_offs, q1_bits);
  509. bit_offs += q1_bits;
  510. }
  511. }
  512. result
  513. }
  514. pub fn get_v_folding_neg<'a>(
  515. params: &'a Params,
  516. v_folding: &Vec<PolyMatrixNTT>,
  517. ) -> Vec<PolyMatrixNTT<'a>> {
  518. let gadget_ntt = build_gadget(&params, 2, 2 * params.t_gsw).ntt(); // TODO: make this better
  519. let mut v_folding_neg = Vec::new();
  520. let mut ct_gsw_inv = PolyMatrixRaw::zero(&params, 2, 2 * params.t_gsw);
  521. for i in 0..params.db_dim_2 {
  522. invert(&mut ct_gsw_inv, &v_folding[i].raw());
  523. let mut ct_gsw_neg = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
  524. add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
  525. v_folding_neg.push(ct_gsw_neg);
  526. }
  527. v_folding_neg
  528. }
  529. pub fn expand_query<'a>(
  530. params: &'a Params,
  531. public_params: &PublicParameters<'a>,
  532. query: &Query<'a>,
  533. ) -> (AlignedMemory64, Vec<PolyMatrixNTT<'a>>) {
  534. let dim0 = 1 << params.db_dim_1;
  535. let further_dims = params.db_dim_2;
  536. let mut v_reg_reoriented;
  537. let mut v_folding;
  538. let num_bits_to_gen = params.t_gsw * further_dims + dim0;
  539. let g = log2_ceil_usize(num_bits_to_gen);
  540. let right_expanded = params.t_gsw * further_dims;
  541. let stop_round = log2_ceil_usize(right_expanded);
  542. let mut v = Vec::new();
  543. for _ in 0..(1 << g) {
  544. v.push(PolyMatrixNTT::zero(params, 2, 1));
  545. }
  546. v[0].copy_into(&query.ct.as_ref().unwrap().ntt(), 0, 0);
  547. let v_conversion = &public_params.v_conversion.as_ref().unwrap()[0];
  548. let v_w_left = public_params.v_expansion_left.as_ref().unwrap();
  549. let v_w_right = public_params.v_expansion_right.as_ref().unwrap();
  550. let v_neg1 = params.get_v_neg1();
  551. coefficient_expansion(
  552. &mut v,
  553. g,
  554. stop_round,
  555. params,
  556. &v_w_left,
  557. &v_w_right,
  558. &v_neg1,
  559. params.t_gsw * params.db_dim_2,
  560. );
  561. let mut v_reg_inp = Vec::with_capacity(dim0);
  562. for i in 0..dim0 {
  563. v_reg_inp.push(v[2 * i].clone());
  564. }
  565. let mut v_gsw_inp = Vec::with_capacity(right_expanded);
  566. for i in 0..right_expanded {
  567. v_gsw_inp.push(v[2 * i + 1].clone());
  568. }
  569. let v_reg_sz = dim0 * 2 * params.poly_len;
  570. v_reg_reoriented = AlignedMemory64::new(v_reg_sz);
  571. reorient_reg_ciphertexts(params, v_reg_reoriented.as_mut_slice(), &v_reg_inp);
  572. v_folding = Vec::new();
  573. for _ in 0..params.db_dim_2 {
  574. v_folding.push(PolyMatrixNTT::zero(params, 2, 2 * params.t_gsw));
  575. }
  576. regev_to_gsw(&mut v_folding, &v_gsw_inp, &v_conversion, params, 1, 0);
  577. (v_reg_reoriented, v_folding)
  578. }
  579. pub fn process_query(
  580. params: &Params,
  581. public_params: &PublicParameters,
  582. query: &Query,
  583. db: &[u64],
  584. ) -> Vec<u8> {
  585. let dim0 = 1 << params.db_dim_1;
  586. let num_per = 1 << params.db_dim_2;
  587. let db_slice_sz = dim0 * num_per * params.poly_len;
  588. let v_packing = public_params.v_packing.as_ref();
  589. let mut v_reg_reoriented;
  590. let v_folding;
  591. if params.expand_queries {
  592. (v_reg_reoriented, v_folding) = expand_query(params, public_params, query);
  593. } else {
  594. v_reg_reoriented = AlignedMemory64::new(query.v_buf.as_ref().unwrap().len());
  595. v_reg_reoriented
  596. .as_mut_slice()
  597. .copy_from_slice(query.v_buf.as_ref().unwrap());
  598. v_folding = query
  599. .v_ct
  600. .as_ref()
  601. .unwrap()
  602. .clone()
  603. .iter()
  604. .map(|x| x.ntt())
  605. .collect();
  606. }
  607. let v_folding_neg = get_v_folding_neg(params, &v_folding);
  608. let mut intermediate = Vec::with_capacity(num_per);
  609. let mut intermediate_raw = Vec::with_capacity(num_per);
  610. for _ in 0..num_per {
  611. intermediate.push(PolyMatrixNTT::zero(params, 2, 1));
  612. intermediate_raw.push(PolyMatrixRaw::zero(params, 2, 1));
  613. }
  614. let mut v_packed_ct = Vec::new();
  615. for instance in 0..params.instances {
  616. let mut v_ct = Vec::new();
  617. for trial in 0..(params.n * params.n) {
  618. let idx = (instance * (params.n * params.n) + trial) * db_slice_sz;
  619. let cur_db = &db[idx..(idx + db_slice_sz)];
  620. multiply_reg_by_database(
  621. &mut intermediate,
  622. cur_db,
  623. v_reg_reoriented.as_slice(),
  624. params,
  625. dim0,
  626. num_per,
  627. );
  628. for i in 0..intermediate.len() {
  629. from_ntt(&mut intermediate_raw[i], &intermediate[i]);
  630. }
  631. fold_ciphertexts(params, &mut intermediate_raw, &v_folding, &v_folding_neg);
  632. v_ct.push(intermediate_raw[0].clone());
  633. }
  634. let packed_ct = pack(params, &v_ct, &v_packing);
  635. v_packed_ct.push(packed_ct.raw());
  636. }
  637. encode(params, &v_packed_ct)
  638. }
  639. #[cfg(test)]
  640. mod test {
  641. use super::*;
  642. use crate::client::*;
  643. use rand::{prelude::SmallRng, Rng};
  644. const TEST_PREPROCESSED_DB_PATH: &'static str = "/home/samir/wiki/enwiki-20220320.dbp";
  645. fn get_params() -> Params {
  646. let mut params = get_expansion_testing_params();
  647. params.db_dim_1 = 6;
  648. params.db_dim_2 = 2;
  649. params.t_exp_right = 8;
  650. params
  651. }
  652. fn dec_reg<'a>(
  653. params: &'a Params,
  654. ct: &PolyMatrixNTT<'a>,
  655. client: &mut Client<'a, SmallRng>,
  656. scale_k: u64,
  657. ) -> u64 {
  658. let dec = client.decrypt_matrix_reg(ct).raw();
  659. let mut val = dec.data[0] as i64;
  660. if val >= (params.modulus / 2) as i64 {
  661. val -= params.modulus as i64;
  662. }
  663. let val_rounded = f64::round(val as f64 / scale_k as f64) as i64;
  664. if val_rounded == 0 {
  665. 0
  666. } else {
  667. 1
  668. }
  669. }
  670. fn dec_gsw<'a>(
  671. params: &'a Params,
  672. ct: &PolyMatrixNTT<'a>,
  673. client: &mut Client<'a, SmallRng>,
  674. ) -> u64 {
  675. let dec = client.decrypt_matrix_reg(ct).raw();
  676. let idx = 2 * (params.t_gsw - 1) * params.poly_len + params.poly_len; // this offset should encode a large value
  677. let mut val = dec.data[idx] as i64;
  678. if val >= (params.modulus / 2) as i64 {
  679. val -= params.modulus as i64;
  680. }
  681. if i64::abs(val) < (1i64 << 10) {
  682. 0
  683. } else {
  684. 1
  685. }
  686. }
  687. #[test]
  688. fn coefficient_expansion_is_correct() {
  689. let params = get_params();
  690. let v_neg1 = params.get_v_neg1();
  691. let mut seeded_rng = get_seeded_rng();
  692. let mut client = Client::init(&params, &mut seeded_rng);
  693. let public_params = client.generate_keys();
  694. let mut v = Vec::new();
  695. for _ in 0..(1 << (params.db_dim_1 + 1)) {
  696. v.push(PolyMatrixNTT::zero(&params, 2, 1));
  697. }
  698. let target = 7;
  699. let scale_k = params.modulus / params.pt_modulus;
  700. let mut sigma = PolyMatrixRaw::zero(&params, 1, 1);
  701. sigma.data[target] = scale_k;
  702. v[0] = client.encrypt_matrix_reg(&sigma.ntt());
  703. let test_ct = client.encrypt_matrix_reg(&sigma.ntt());
  704. let v_w_left = public_params.v_expansion_left.unwrap();
  705. let v_w_right = public_params.v_expansion_right.unwrap();
  706. coefficient_expansion(
  707. &mut v,
  708. client.g,
  709. client.stop_round,
  710. &params,
  711. &v_w_left,
  712. &v_w_right,
  713. &v_neg1,
  714. params.t_gsw * params.db_dim_2,
  715. );
  716. assert_eq!(dec_reg(&params, &test_ct, &mut client, scale_k), 0);
  717. for i in 0..v.len() {
  718. if i == target {
  719. assert_eq!(dec_reg(&params, &v[i], &mut client, scale_k), 1);
  720. } else {
  721. assert_eq!(dec_reg(&params, &v[i], &mut client, scale_k), 0);
  722. }
  723. }
  724. }
  725. #[test]
  726. fn regev_to_gsw_is_correct() {
  727. let mut params = get_params();
  728. params.db_dim_2 = 1;
  729. let mut seeded_rng = get_seeded_rng();
  730. let mut client = Client::init(&params, &mut seeded_rng);
  731. let public_params = client.generate_keys();
  732. let mut enc_constant = |val| {
  733. let mut sigma = PolyMatrixRaw::zero(&params, 1, 1);
  734. sigma.data[0] = val;
  735. client.encrypt_matrix_reg(&sigma.ntt())
  736. };
  737. let v = &public_params.v_conversion.unwrap()[0];
  738. let bits_per = get_bits_per(&params, params.t_gsw);
  739. let mut v_inp_1 = Vec::new();
  740. let mut v_inp_0 = Vec::new();
  741. for i in 0..params.t_gsw {
  742. let val = 1u64 << (bits_per * i);
  743. v_inp_1.push(enc_constant(val));
  744. v_inp_0.push(enc_constant(0));
  745. }
  746. let mut v_gsw = Vec::new();
  747. v_gsw.push(PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw));
  748. regev_to_gsw(&mut v_gsw, &v_inp_1, v, &params, 1, 0);
  749. assert_eq!(dec_gsw(&params, &v_gsw[0], &mut client), 1);
  750. regev_to_gsw(&mut v_gsw, &v_inp_0, v, &params, 1, 0);
  751. assert_eq!(dec_gsw(&params, &v_gsw[0], &mut client), 0);
  752. }
  /// Encrypts a one-hot selector over the first database dimension (`scale_k` at the
  /// target row, 0 elsewhere), multiplies it against a random database with
  /// `multiply_reg_by_database`, and checks that output `target_idx_num_per` decrypts
  /// (after rescaling to the plaintext modulus) to the target item.
  #[test]
  fn multiply_reg_by_database_is_correct() {
      let params = get_params();
      let mut seeded_rng = get_seeded_rng();

      let dim0 = 1 << params.db_dim_1;
      let num_per = 1 << params.db_dim_2;
      let scale_k = params.modulus / params.pt_modulus;

      let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
      let target_idx_dim0 = target_idx / num_per;
      let target_idx_num_per = target_idx % num_per;

      let mut client = Client::init(&params, &mut seeded_rng);
      _ = client.generate_keys();

      let (corr_item, db) = generate_random_db_and_get_item(&params, target_idx);

      // One-hot selection vector over the first dimension.
      let mut v_reg = Vec::with_capacity(dim0);
      for i in 0..dim0 {
          let val = if i == target_idx_dim0 { scale_k } else { 0 };
          let sigma = PolyMatrixRaw::single_value(&params, val).ntt();
          v_reg.push(client.encrypt_matrix_reg(&sigma));
      }

      let v_reg_sz = dim0 * 2 * params.poly_len;
      let mut v_reg_reoriented = AlignedMemory64::new(v_reg_sz);
      reorient_reg_ciphertexts(&params, v_reg_reoriented.as_mut_slice(), &v_reg);

      // The outputs are indexed by `target_idx_num_per` (< `num_per`) and `num_per` is what
      // gets passed to `multiply_reg_by_database`, so allocate exactly `num_per` 2x1 outputs.
      // Sizing this by `dim0` left the vector too short whenever `num_per > dim0`.
      let mut out: Vec<_> = (0..num_per)
          .map(|_| PolyMatrixNTT::zero(&params, 2, 1))
          .collect();
      multiply_reg_by_database(
          &mut out,
          db.as_slice(),
          v_reg_reoriented.as_slice(),
          &params,
          dim0,
          num_per,
      );

      // Decrypt the selected output and compare each rescaled coefficient with the target item.
      let dec = client.decrypt_matrix_reg(&out[target_idx_num_per]).raw();
      for z in 0..params.poly_len {
          let got = rescale(dec.data[z], params.modulus, params.pt_modulus);
          assert_eq!(got, corr_item.data[z], "coefficient {} mismatch", z);
      }
  }
  /// Builds Regev encryptions of a one-hot vector over the `num_per` folded slots, hand-builds
  /// GSW-style selector ciphertexts for each bit of the target slot index (plus their
  /// "gadget minus ciphertext" counterparts), runs `fold_ciphertexts`, and checks that
  /// `v_reg_raw[0]` then decrypts to 1.
  #[test]
  fn fold_ciphertexts_is_correct() {
      let params = get_params();
      let mut seeded_rng = get_seeded_rng();

      let dim0 = 1 << params.db_dim_1;
      let num_per = 1 << params.db_dim_2;
      let scale_k = params.modulus / params.pt_modulus;

      let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
      let target_slot = target_idx % num_per;

      let mut client = Client::init(&params, &mut seeded_rng);
      _ = client.generate_keys();

      // One-hot Regev encryptions: `scale_k` at the target slot, 0 elsewhere.
      let v_reg: Vec<_> = (0..num_per)
          .map(|slot| {
              let val = if slot == target_slot { scale_k } else { 0 };
              client.encrypt_matrix_reg(&PolyMatrixRaw::single_value(&params, val).ntt())
          })
          .collect();
      let mut v_reg_raw: Vec<_> = v_reg.iter().map(|ct| ct.raw()).collect();

      let bits_per = get_bits_per(&params, params.t_gsw);
      let mut v_folding = Vec::with_capacity(params.db_dim_2);
      for i in 0..params.db_dim_2 {
          // Bit `i` of the target slot index, as 0 or 1.
          let bit = ((target_slot >> i) & 1) as u64;
          let mut ct_gsw = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
          for j in 0..params.t_gsw {
              let scaled_bit = (1u64 << (bits_per * j)) * bit;
              let sigma_ntt = to_ntt_alloc(&PolyMatrixRaw::single_value(&params, scaled_bit));

              // Odd column 2j+1: encryption of the scaled bit itself.
              let ct_odd = client.encrypt_matrix_reg(&sigma_ntt);
              ct_gsw.copy_into(&ct_odd, 0, 2 * j + 1);

              // Even column 2j: encryption of `sk_reg` times the scaled bit.
              let sk_times_sigma = &to_ntt_alloc(&client.sk_reg) * &sigma_ntt;
              let ct_even = client.encrypt_matrix_reg(&sk_times_sigma);
              ct_gsw.copy_into(&ct_even, 0, 2 * j);
          }
          v_folding.push(ct_gsw);
      }

      // For each selector C, build gadget + invert(C); `ct_gsw_inv` is a reused scratch buffer.
      let gadget_ntt = build_gadget(&params, 2, 2 * params.t_gsw).ntt();
      let mut ct_gsw_inv = PolyMatrixRaw::zero(&params, 2, 2 * params.t_gsw);
      let mut v_folding_neg = Vec::with_capacity(params.db_dim_2);
      for ct_gsw in &v_folding {
          invert(&mut ct_gsw_inv, &ct_gsw.raw());
          let mut ct_gsw_neg = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
          add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
          v_folding_neg.push(ct_gsw_neg);
      }

      fold_ciphertexts(&params, &mut v_reg_raw, &v_folding, &v_folding_neg);

      // The folded result is read from slot 0 and must decrypt to the one-hot value 1.
      let folded = v_reg_raw[0].ntt();
      assert_eq!(dec_reg(&params, &folded, &mut client, scale_k), 1);
  }
  852. fn full_protocol_is_correct_for_params(params: &Params) {
  853. let mut seeded_rng = get_seeded_rng();
  854. let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
  855. let mut client = Client::init(params, &mut seeded_rng);
  856. let public_params = client.generate_keys();
  857. let query = client.generate_query(target_idx);
  858. let (corr_item, db) = generate_random_db_and_get_item(params, target_idx);
  859. let response = process_query(params, &public_params, &query, db.as_slice());
  860. let result = client.decode_response(response.as_slice());
  861. let p_bits = log2_ceil(params.pt_modulus) as usize;
  862. let corr_result = corr_item.to_vec(p_bits, params.poly_len);
  863. for z in 0..corr_result.len() {
  864. assert_eq!(result[z], corr_result[z]);
  865. }
  866. }
  867. fn full_protocol_is_correct_for_params_real_db(params: &Params) {
  868. let mut seeded_rng = get_seeded_rng();
  869. let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
  870. let mut client = Client::init(params, &mut seeded_rng);
  871. let public_params = client.generate_keys();
  872. let query = client.generate_query(target_idx);
  873. let mut file = File::open(TEST_PREPROCESSED_DB_PATH).unwrap();
  874. let db = load_preprocessed_db_from_file(params, &mut file);
  875. let response = process_query(params, &public_params, &query, db.as_slice());
  876. let result = client.decode_response(response.as_slice());
  877. let corr_result = vec![0x42, 0x5a, 0x68];
  878. for z in 0..corr_result.len() {
  879. assert_eq!(result[z], corr_result[z]);
  880. }
  881. }
  /// End-to-end protocol check using the default test parameters from `get_params()`.
  #[test]
  fn full_protocol_is_correct() {
      let params = get_params();
      full_protocol_is_correct_for_params(&params);
  }
  886. // #[test]
  887. // fn full_protocol_is_correct_20_256() {
  888. // full_protocol_is_correct_for_params(&params_from_json(&CFG_20_256.replace("'", "\"")));
  889. // }
  890. // #[test]
  891. // fn full_protocol_is_correct_16_100000() {
  892. // full_protocol_is_correct_for_params(&params_from_json(&CFG_16_100000.replace("'", "\"")));
  893. // }
  /// End-to-end check against the preprocessed real database using the `CFG_16_100000`
  /// parameter set. Ignored by default because it reads the preprocessed DB from disk.
  #[test]
  #[ignore]
  fn full_protocol_is_correct_real_db_16_100000() {
      // The config text uses single quotes; swap them for double quotes before parsing.
      let cfg_json = CFG_16_100000.replace("'", "\"");
      let params = params_from_json(&cfg_json);
      full_protocol_is_correct_for_params_real_db(&params);
  }
  901. }