// server.rs
  1. #[cfg(target_feature = "avx2")]
  2. use std::arch::x86_64::*;
  3. use std::fs::File;
  4. use std::io::BufReader;
  5. use std::io::Read;
  6. use std::io::Seek;
  7. use std::io::SeekFrom;
  8. use std::time::Instant;
  9. use crate::aligned_memory::*;
  10. use crate::arith::*;
  11. use crate::client::PublicParameters;
  12. use crate::client::Query;
  13. use crate::gadget::*;
  14. use crate::params::*;
  15. use crate::poly::*;
  16. use crate::util::*;
  17. use rayon::prelude::*;
  18. pub fn coefficient_expansion(
  19. v: &mut Vec<PolyMatrixNTT>,
  20. g: usize,
  21. stop_round: usize,
  22. params: &Params,
  23. v_w_left: &Vec<PolyMatrixNTT>,
  24. v_w_right: &Vec<PolyMatrixNTT>,
  25. v_neg1: &Vec<PolyMatrixNTT>,
  26. max_bits_to_gen_right: usize,
  27. ) {
  28. let poly_len = params.poly_len;
  29. let mut ct = PolyMatrixRaw::zero(params, 2, 1);
  30. let mut ct_auto = PolyMatrixRaw::zero(params, 2, 1);
  31. let mut ct_auto_1 = PolyMatrixRaw::zero(params, 1, 1);
  32. let mut ct_auto_1_ntt = PolyMatrixNTT::zero(params, 1, 1);
  33. let mut ginv_ct_left = PolyMatrixRaw::zero(params, params.t_exp_left, 1);
  34. let mut ginv_ct_left_ntt = PolyMatrixNTT::zero(params, params.t_exp_left, 1);
  35. let mut ginv_ct_right = PolyMatrixRaw::zero(params, params.t_exp_right, 1);
  36. let mut ginv_ct_right_ntt = PolyMatrixNTT::zero(params, params.t_exp_right, 1);
  37. let mut w_times_ginv_ct = PolyMatrixNTT::zero(params, 2, 1);
  38. for r in 0..g {
  39. let num_in = 1 << r;
  40. let num_out = 2 * num_in;
  41. let t = (poly_len / (1 << r)) + 1;
  42. let neg1 = &v_neg1[r];
  43. for i in 0..num_out {
  44. if (stop_round > 0 && r > stop_round && (i % 2) == 1)
  45. || (stop_round > 0 && r == stop_round && (i % 2) == 1 && (i / 2) >= max_bits_to_gen_right)
  46. {
  47. continue;
  48. }
  49. let (w, _gadget_dim, gi_ct, gi_ct_ntt) = match i % 2 {
  50. 0 => (
  51. &v_w_left[r],
  52. params.t_exp_left,
  53. &mut ginv_ct_left,
  54. &mut ginv_ct_left_ntt,
  55. ),
  56. 1 | _ => (
  57. &v_w_right[r],
  58. params.t_exp_right,
  59. &mut ginv_ct_right,
  60. &mut ginv_ct_right_ntt,
  61. ),
  62. };
  63. if i < num_in {
  64. let (src, dest) = v.split_at_mut(num_in);
  65. scalar_multiply(&mut dest[i], neg1, &src[i]);
  66. }
  67. from_ntt(&mut ct, &v[i]);
  68. automorph(&mut ct_auto, &ct, t);
  69. gadget_invert_rdim(gi_ct, &ct_auto, 1);
  70. to_ntt_no_reduce(gi_ct_ntt, &gi_ct);
  71. ct_auto_1
  72. .data
  73. .as_mut_slice()
  74. .copy_from_slice(ct_auto.get_poly(1, 0));
  75. to_ntt(&mut ct_auto_1_ntt, &ct_auto_1);
  76. multiply(&mut w_times_ginv_ct, w, &gi_ct_ntt);
  77. let mut idx = 0;
  78. for j in 0..2 {
  79. for n in 0..params.crt_count {
  80. for z in 0..poly_len {
  81. let sum = v[i].data[idx]
  82. + w_times_ginv_ct.data[idx]
  83. + j * ct_auto_1_ntt.data[n * poly_len + z];
  84. v[i].data[idx] = barrett_coeff_u64(params, sum, n);
  85. idx += 1;
  86. }
  87. }
  88. }
  89. }
  90. }
  91. }
  92. pub fn regev_to_gsw<'a>(
  93. v_gsw: &mut Vec<PolyMatrixNTT<'a>>,
  94. v_inp: &Vec<PolyMatrixNTT<'a>>,
  95. v: &PolyMatrixNTT<'a>,
  96. params: &'a Params,
  97. idx_factor: usize,
  98. idx_offset: usize,
  99. ) {
  100. assert!(v.rows == 2);
  101. assert!(v.cols == 2 * params.t_conv);
  102. let mut ginv_c_inp = PolyMatrixRaw::zero(params, 2 * params.t_conv, 1);
  103. let mut ginv_c_inp_ntt = PolyMatrixNTT::zero(params, 2 * params.t_conv, 1);
  104. let mut tmp_ct_raw = PolyMatrixRaw::zero(params, 2, 1);
  105. let mut tmp_ct = PolyMatrixNTT::zero(params, 2, 1);
  106. for i in 0..params.db_dim_2 {
  107. let ct = &mut v_gsw[i];
  108. for j in 0..params.t_gsw {
  109. let idx_ct = i * params.t_gsw + j;
  110. let idx_inp = idx_factor * (idx_ct) + idx_offset;
  111. ct.copy_into(&v_inp[idx_inp], 0, 2 * j + 1);
  112. from_ntt(&mut tmp_ct_raw, &v_inp[idx_inp]);
  113. gadget_invert(&mut ginv_c_inp, &tmp_ct_raw);
  114. to_ntt(&mut ginv_c_inp_ntt, &ginv_c_inp);
  115. multiply(&mut tmp_ct, v, &ginv_c_inp_ntt);
  116. ct.copy_into(&tmp_ct, 0, 2 * j);
  117. }
  118. }
  119. }
  120. pub const MAX_SUMMED: usize = 1 << 6;
  121. pub const PACKED_OFFSET_2: i32 = 32;
  122. #[cfg(target_feature = "avx2")]
  123. pub fn multiply_reg_by_database(
  124. out: &mut Vec<PolyMatrixNTT>,
  125. db: &[u64],
  126. v_firstdim: &[u64],
  127. params: &Params,
  128. dim0: usize,
  129. num_per: usize,
  130. ) {
  131. let ct_rows = 2;
  132. let ct_cols = 1;
  133. let pt_rows = 1;
  134. let pt_cols = 1;
  135. assert!(dim0 * ct_rows >= MAX_SUMMED);
  136. let mut sums_out_n0_u64 = AlignedMemory64::new(4);
  137. let mut sums_out_n2_u64 = AlignedMemory64::new(4);
  138. for z in 0..params.poly_len {
  139. let idx_a_base = z * (ct_cols * dim0 * ct_rows);
  140. let mut idx_b_base = z * (num_per * pt_cols * dim0 * pt_rows);
  141. for i in 0..num_per {
  142. for c in 0..pt_cols {
  143. let inner_limit = MAX_SUMMED;
  144. let outer_limit = dim0 * ct_rows / inner_limit;
  145. let mut sums_out_n0_u64_acc = [0u64, 0, 0, 0];
  146. let mut sums_out_n2_u64_acc = [0u64, 0, 0, 0];
  147. for o_jm in 0..outer_limit {
  148. unsafe {
  149. let mut sums_out_n0 = _mm256_setzero_si256();
  150. let mut sums_out_n2 = _mm256_setzero_si256();
  151. for i_jm in 0..inner_limit / 4 {
  152. let jm = o_jm * inner_limit + (4 * i_jm);
  153. let b_inp_1 = *db.get_unchecked(idx_b_base) as i64;
  154. idx_b_base += 1;
  155. let b_inp_2 = *db.get_unchecked(idx_b_base) as i64;
  156. idx_b_base += 1;
  157. let b = _mm256_set_epi64x(b_inp_2, b_inp_2, b_inp_1, b_inp_1);
  158. let v_a = v_firstdim.get_unchecked(idx_a_base + jm) as *const u64;
  159. let a = _mm256_load_si256(v_a as *const __m256i);
  160. let a_lo = a;
  161. let a_hi_hi = _mm256_srli_epi64(a, PACKED_OFFSET_2);
  162. let b_lo = b;
  163. let b_hi_hi = _mm256_srli_epi64(b, PACKED_OFFSET_2);
  164. sums_out_n0 =
  165. _mm256_add_epi64(sums_out_n0, _mm256_mul_epu32(a_lo, b_lo));
  166. sums_out_n2 =
  167. _mm256_add_epi64(sums_out_n2, _mm256_mul_epu32(a_hi_hi, b_hi_hi));
  168. }
  169. // reduce here, otherwise we will overflow
  170. _mm256_store_si256(
  171. sums_out_n0_u64.as_mut_ptr() as *mut __m256i,
  172. sums_out_n0,
  173. );
  174. _mm256_store_si256(
  175. sums_out_n2_u64.as_mut_ptr() as *mut __m256i,
  176. sums_out_n2,
  177. );
  178. for idx in 0..4 {
  179. let val = sums_out_n0_u64[idx];
  180. sums_out_n0_u64_acc[idx] =
  181. barrett_coeff_u64(params, val + sums_out_n0_u64_acc[idx], 0);
  182. }
  183. for idx in 0..4 {
  184. let val = sums_out_n2_u64[idx];
  185. sums_out_n2_u64_acc[idx] =
  186. barrett_coeff_u64(params, val + sums_out_n2_u64_acc[idx], 1);
  187. }
  188. }
  189. }
  190. for idx in 0..4 {
  191. sums_out_n0_u64_acc[idx] =
  192. barrett_coeff_u64(params, sums_out_n0_u64_acc[idx], 0);
  193. sums_out_n2_u64_acc[idx] =
  194. barrett_coeff_u64(params, sums_out_n2_u64_acc[idx], 1);
  195. }
  196. // output n0
  197. let (crt_count, poly_len) = (params.crt_count, params.poly_len);
  198. let mut n = 0;
  199. let mut idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
  200. out[i].data[idx_c] =
  201. barrett_coeff_u64(params, sums_out_n0_u64_acc[0] + sums_out_n0_u64_acc[2], 0);
  202. idx_c += pt_cols * crt_count * poly_len;
  203. out[i].data[idx_c] =
  204. barrett_coeff_u64(params, sums_out_n0_u64_acc[1] + sums_out_n0_u64_acc[3], 0);
  205. // output n1
  206. n = 1;
  207. idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
  208. out[i].data[idx_c] =
  209. barrett_coeff_u64(params, sums_out_n2_u64_acc[0] + sums_out_n2_u64_acc[2], 1);
  210. idx_c += pt_cols * crt_count * poly_len;
  211. out[i].data[idx_c] =
  212. barrett_coeff_u64(params, sums_out_n2_u64_acc[1] + sums_out_n2_u64_acc[3], 1);
  213. }
  214. }
  215. }
  216. }
  217. #[cfg(not(target_feature = "avx2"))]
  218. pub fn multiply_reg_by_database(
  219. out: &mut Vec<PolyMatrixNTT>,
  220. db: &[u64],
  221. v_firstdim: &[u64],
  222. params: &Params,
  223. dim0: usize,
  224. num_per: usize,
  225. ) {
  226. let ct_rows = 2;
  227. let ct_cols = 1;
  228. let pt_rows = 1;
  229. let pt_cols = 1;
  230. for z in 0..params.poly_len {
  231. let idx_a_base = z * (ct_cols * dim0 * ct_rows);
  232. let mut idx_b_base = z * (num_per * pt_cols * dim0 * pt_rows);
  233. for i in 0..num_per {
  234. for c in 0..pt_cols {
  235. let mut sums_out_n0_0 = 0u128;
  236. let mut sums_out_n0_1 = 0u128;
  237. let mut sums_out_n1_0 = 0u128;
  238. let mut sums_out_n1_1 = 0u128;
  239. for jm in 0..(dim0 * pt_rows) {
  240. let b = db[idx_b_base];
  241. idx_b_base += 1;
  242. let v_a0 = v_firstdim[idx_a_base + jm * ct_rows];
  243. let v_a1 = v_firstdim[idx_a_base + jm * ct_rows + 1];
  244. let b_lo = b as u32;
  245. let b_hi = (b >> 32) as u32;
  246. let v_a0_lo = v_a0 as u32;
  247. let v_a0_hi = (v_a0 >> 32) as u32;
  248. let v_a1_lo = v_a1 as u32;
  249. let v_a1_hi = (v_a1 >> 32) as u32;
  250. // do n0
  251. sums_out_n0_0 += ((v_a0_lo as u64) * (b_lo as u64)) as u128;
  252. sums_out_n0_1 += ((v_a1_lo as u64) * (b_lo as u64)) as u128;
  253. // do n1
  254. sums_out_n1_0 += ((v_a0_hi as u64) * (b_hi as u64)) as u128;
  255. sums_out_n1_1 += ((v_a1_hi as u64) * (b_hi as u64)) as u128;
  256. }
  257. // output n0
  258. let (crt_count, poly_len) = (params.crt_count, params.poly_len);
  259. let mut n = 0;
  260. let mut idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
  261. out[i].data[idx_c] = (sums_out_n0_0 % (params.moduli[0] as u128)) as u64;
  262. idx_c += pt_cols * crt_count * poly_len;
  263. out[i].data[idx_c] = (sums_out_n0_1 % (params.moduli[0] as u128)) as u64;
  264. // output n1
  265. n = 1;
  266. idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
  267. out[i].data[idx_c] = (sums_out_n1_0 % (params.moduli[1] as u128)) as u64;
  268. idx_c += pt_cols * crt_count * poly_len;
  269. out[i].data[idx_c] = (sums_out_n1_1 % (params.moduli[1] as u128)) as u64;
  270. }
  271. }
  272. }
  273. }
  274. pub fn generate_random_db_and_get_item<'a>(
  275. params: &'a Params,
  276. item_idx: usize,
  277. ) -> (PolyMatrixRaw<'a>, AlignedMemory64) {
  278. let mut rng = get_seeded_rng();
  279. let instances = params.instances;
  280. let trials = params.n * params.n;
  281. let dim0 = 1 << params.db_dim_1;
  282. let num_per = 1 << params.db_dim_2;
  283. let num_items = dim0 * num_per;
  284. let db_size_words = instances * trials * num_items * params.poly_len;
  285. let mut v = AlignedMemory64::new(db_size_words);
  286. let mut item = PolyMatrixRaw::zero(params, params.n, params.n);
  287. for instance in 0..instances {
  288. println!("Instance {:?}", instance);
  289. for trial in 0..trials {
  290. println!("Trial {:?}", trial);
  291. for i in 0..num_items {
  292. let ii = i % num_per;
  293. let j = i / num_per;
  294. let mut db_item = PolyMatrixRaw::random_rng(params, 1, 1, &mut rng);
  295. db_item.reduce_mod(params.pt_modulus);
  296. if i == item_idx && instance == 0 {
  297. item.copy_into(&db_item, trial / params.n, trial % params.n);
  298. }
  299. for z in 0..params.poly_len {
  300. db_item.data[z] =
  301. recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
  302. }
  303. let db_item_ntt = db_item.ntt();
  304. for z in 0..params.poly_len {
  305. let idx_dst = calc_index(
  306. &[instance, trial, z, ii, j],
  307. &[instances, trials, params.poly_len, num_per, dim0],
  308. );
  309. v[idx_dst] = db_item_ntt.data[z]
  310. | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
  311. }
  312. }
  313. }
  314. }
  315. (item, v)
  316. }
  317. pub fn load_item_from_file<'a>(
  318. params: &'a Params,
  319. file: &mut File,
  320. instance: usize,
  321. trial: usize,
  322. item_idx: usize,
  323. ) -> PolyMatrixRaw<'a> {
  324. let db_item_size = params.db_item_size;
  325. let instances = params.instances;
  326. let trials = params.n * params.n;
  327. let chunks = instances * trials;
  328. let bytes_per_chunk = f64::ceil(db_item_size as f64 / chunks as f64) as usize;
  329. let logp = f64::ceil(f64::log2(params.pt_modulus as f64)) as usize;
  330. let modp_words_per_chunk = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
  331. assert!(modp_words_per_chunk <= params.poly_len);
  332. let idx_item_in_file = item_idx * db_item_size;
  333. let idx_chunk = instance * trials + trial;
  334. let idx_poly_in_file = idx_item_in_file + idx_chunk * bytes_per_chunk;
  335. let mut out = PolyMatrixRaw::zero(params, 1, 1);
  336. let seek_result = file.seek(SeekFrom::Start(idx_poly_in_file as u64));
  337. if seek_result.is_err() {
  338. return out;
  339. }
  340. let mut data = vec![0u8; 2 * bytes_per_chunk];
  341. let bytes_read = file
  342. .read(&mut data.as_mut_slice()[0..bytes_per_chunk])
  343. .unwrap();
  344. let modp_words_read = f64::ceil((bytes_read * 8) as f64 / logp as f64) as usize;
  345. assert!(modp_words_read <= params.poly_len);
  346. for i in 0..modp_words_read {
  347. out.data[i] = read_arbitrary_bits(&data, i * logp, logp);
  348. assert!(out.data[i] <= params.pt_modulus);
  349. }
  350. out
  351. }
  352. pub fn load_db_from_file(params: &Params, file: &mut File) -> AlignedMemory64 {
  353. let instances = params.instances;
  354. let trials = params.n * params.n;
  355. let dim0 = 1 << params.db_dim_1;
  356. let num_per = 1 << params.db_dim_2;
  357. let num_items = dim0 * num_per;
  358. let db_size_words = instances * trials * num_items * params.poly_len;
  359. let mut v = AlignedMemory64::new(db_size_words);
  360. for instance in 0..instances {
  361. println!("Instance {:?}", instance);
  362. for trial in 0..trials {
  363. println!("Trial {:?}", trial);
  364. for i in 0..num_items {
  365. if i % 8192 == 0 {
  366. println!("item {:?}", i);
  367. }
  368. let ii = i % num_per;
  369. let j = i / num_per;
  370. let mut db_item = load_item_from_file(params, file, instance, trial, i);
  371. // db_item.reduce_mod(params.pt_modulus);
  372. for z in 0..params.poly_len {
  373. db_item.data[z] =
  374. recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
  375. }
  376. let db_item_ntt = db_item.ntt();
  377. for z in 0..params.poly_len {
  378. let idx_dst = calc_index(
  379. &[instance, trial, z, ii, j],
  380. &[instances, trials, params.poly_len, num_per, dim0],
  381. );
  382. v[idx_dst] = db_item_ntt.data[z]
  383. | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
  384. }
  385. }
  386. }
  387. }
  388. v
  389. }
  390. pub fn load_file_unsafe(data: &mut[u64], file: &mut File) {
  391. let data_as_u8_mut = unsafe {
  392. data.align_to_mut::<u8>().1
  393. };
  394. file.read_exact(data_as_u8_mut).unwrap();
  395. }
  396. pub fn load_file(data: &mut[u64], file: &mut File) {
  397. let mut reader = BufReader::with_capacity(1 << 24, file);
  398. let mut buf = [0u8; 8];
  399. for i in 0..data.len() {
  400. reader.read(&mut buf).unwrap();
  401. data[i] = u64::from_ne_bytes(buf);
  402. }
  403. }
  404. pub fn load_preprocessed_db_from_file(params: &Params, file: &mut File) -> AlignedMemory64 {
  405. let instances = params.instances;
  406. let trials = params.n * params.n;
  407. let dim0 = 1 << params.db_dim_1;
  408. let num_per = 1 << params.db_dim_2;
  409. let num_items = dim0 * num_per;
  410. let db_size_words = instances * trials * num_items * params.poly_len;
  411. let mut v = AlignedMemory64::new(db_size_words);
  412. let v_mut_slice = v.as_mut_slice();
  413. let now = Instant::now();
  414. load_file(v_mut_slice, file);
  415. println!("Done loading ({} ms).", now.elapsed().as_millis());
  416. v
  417. }
  418. pub fn fold_ciphertexts(
  419. params: &Params,
  420. v_cts: &mut Vec<PolyMatrixRaw>,
  421. v_folding: &Vec<PolyMatrixNTT>,
  422. v_folding_neg: &Vec<PolyMatrixNTT>,
  423. ) {
  424. let further_dims = log2(v_cts.len() as u64) as usize;
  425. let ell = v_folding[0].cols / 2;
  426. let mut ginv_c = PolyMatrixRaw::zero(&params, 2 * ell, 1);
  427. let mut ginv_c_ntt = PolyMatrixNTT::zero(&params, 2 * ell, 1);
  428. let mut prod = PolyMatrixNTT::zero(&params, 2, 1);
  429. let mut sum = PolyMatrixNTT::zero(&params, 2, 1);
  430. let mut num_per = v_cts.len();
  431. for cur_dim in 0..further_dims {
  432. num_per = num_per / 2;
  433. for i in 0..num_per {
  434. gadget_invert(&mut ginv_c, &v_cts[i]);
  435. to_ntt(&mut ginv_c_ntt, &ginv_c);
  436. multiply(
  437. &mut prod,
  438. &v_folding_neg[further_dims - 1 - cur_dim],
  439. &ginv_c_ntt,
  440. );
  441. gadget_invert(&mut ginv_c, &v_cts[num_per + i]);
  442. to_ntt(&mut ginv_c_ntt, &ginv_c);
  443. multiply(
  444. &mut sum,
  445. &v_folding[further_dims - 1 - cur_dim],
  446. &ginv_c_ntt,
  447. );
  448. add_into(&mut sum, &prod);
  449. from_ntt(&mut v_cts[i], &sum);
  450. }
  451. }
  452. }
  453. pub fn pack<'a>(
  454. params: &'a Params,
  455. v_ct: &Vec<PolyMatrixRaw>,
  456. v_w: &Vec<PolyMatrixNTT>,
  457. ) -> PolyMatrixNTT<'a> {
  458. assert!(v_ct.len() >= params.n * params.n);
  459. assert!(v_w.len() == params.n);
  460. assert!(v_ct[0].rows == 2);
  461. assert!(v_ct[0].cols == 1);
  462. assert!(v_w[0].rows == (params.n + 1));
  463. assert!(v_w[0].cols == params.t_conv);
  464. let mut result = PolyMatrixNTT::zero(params, params.n + 1, params.n);
  465. let mut ginv = PolyMatrixRaw::zero(params, params.t_conv, 1);
  466. let mut ginv_nttd = PolyMatrixNTT::zero(params, params.t_conv, 1);
  467. let mut prod = PolyMatrixNTT::zero(params, params.n + 1, 1);
  468. let mut ct_1 = PolyMatrixRaw::zero(params, 1, 1);
  469. let mut ct_2 = PolyMatrixRaw::zero(params, 1, 1);
  470. let mut ct_2_ntt = PolyMatrixNTT::zero(params, 1, 1);
  471. for c in 0..params.n {
  472. let mut v_int = PolyMatrixNTT::zero(&params, params.n + 1, 1);
  473. for r in 0..params.n {
  474. let w = &v_w[r];
  475. let ct = &v_ct[r * params.n + c];
  476. ct_1.get_poly_mut(0, 0).copy_from_slice(ct.get_poly(0, 0));
  477. ct_2.get_poly_mut(0, 0).copy_from_slice(ct.get_poly(1, 0));
  478. to_ntt(&mut ct_2_ntt, &ct_2);
  479. gadget_invert(&mut ginv, &ct_1);
  480. to_ntt(&mut ginv_nttd, &ginv);
  481. multiply(&mut prod, &w, &ginv_nttd);
  482. add_into_at(&mut v_int, &ct_2_ntt, 1 + r, 0);
  483. add_into(&mut v_int, &prod);
  484. }
  485. result.copy_into(&v_int, 0, c);
  486. }
  487. result
  488. }
  489. pub fn encode(params: &Params, v_packed_ct: &Vec<PolyMatrixRaw>) -> Vec<u8> {
  490. let q1 = 4 * params.pt_modulus;
  491. let q1_bits = log2_ceil(q1) as usize;
  492. let q2 = Q2_VALUES[params.q2_bits as usize];
  493. let q2_bits = params.q2_bits as usize;
  494. let num_bits = params.instances
  495. * ((q2_bits * params.n * params.poly_len)
  496. + (q1_bits * params.n * params.n * params.poly_len));
  497. let round_to = 64;
  498. let num_bytes_rounded_up = ((num_bits + round_to - 1) / round_to) * round_to / 8;
  499. let mut result = vec![0u8; num_bytes_rounded_up];
  500. let mut bit_offs = 0;
  501. for instance in 0..params.instances {
  502. let packed_ct = &v_packed_ct[instance];
  503. let mut first_row = packed_ct.submatrix(0, 0, 1, packed_ct.cols);
  504. let mut rest_rows = packed_ct.submatrix(1, 0, packed_ct.rows - 1, packed_ct.cols);
  505. first_row.apply_func(|x| rescale(x, params.modulus, q2));
  506. rest_rows.apply_func(|x| rescale(x, params.modulus, q1));
  507. let data = result.as_mut_slice();
  508. for i in 0..params.n * params.poly_len {
  509. write_arbitrary_bits(data, first_row.data[i], bit_offs, q2_bits);
  510. bit_offs += q2_bits;
  511. }
  512. for i in 0..params.n * params.n * params.poly_len {
  513. write_arbitrary_bits(data, rest_rows.data[i], bit_offs, q1_bits);
  514. bit_offs += q1_bits;
  515. }
  516. }
  517. result
  518. }
  519. pub fn get_v_folding_neg<'a>(
  520. params: &'a Params,
  521. v_folding: &Vec<PolyMatrixNTT>,
  522. ) -> Vec<PolyMatrixNTT<'a>> {
  523. let gadget_ntt = build_gadget(&params, 2, 2 * params.t_gsw).ntt(); // TODO: make this better
  524. let mut v_folding_neg = Vec::new();
  525. let mut ct_gsw_inv = PolyMatrixRaw::zero(&params, 2, 2 * params.t_gsw);
  526. for i in 0..params.db_dim_2 {
  527. invert(&mut ct_gsw_inv, &v_folding[i].raw());
  528. let mut ct_gsw_neg = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
  529. add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
  530. v_folding_neg.push(ct_gsw_neg);
  531. }
  532. v_folding_neg
  533. }
  534. pub fn expand_query<'a>(
  535. params: &'a Params,
  536. public_params: &PublicParameters<'a>,
  537. query: &Query<'a>,
  538. ) -> (AlignedMemory64, Vec<PolyMatrixNTT<'a>>) {
  539. let dim0 = 1 << params.db_dim_1;
  540. let further_dims = params.db_dim_2;
  541. let mut v_reg_reoriented;
  542. let mut v_folding;
  543. let num_bits_to_gen = params.t_gsw * further_dims + dim0;
  544. let g = log2_ceil_usize(num_bits_to_gen);
  545. let right_expanded = params.t_gsw * further_dims;
  546. let stop_round = log2_ceil_usize(right_expanded);
  547. let mut v = Vec::new();
  548. for _ in 0..(1 << g) {
  549. v.push(PolyMatrixNTT::zero(params, 2, 1));
  550. }
  551. v[0].copy_into(&query.ct.as_ref().unwrap().ntt(), 0, 0);
  552. let v_conversion = &public_params.v_conversion.as_ref().unwrap()[0];
  553. let v_w_left = public_params.v_expansion_left.as_ref().unwrap();
  554. let v_w_right = public_params.v_expansion_right.as_ref().unwrap();
  555. let v_neg1 = params.get_v_neg1();
  556. coefficient_expansion(
  557. &mut v,
  558. g,
  559. stop_round,
  560. params,
  561. &v_w_left,
  562. &v_w_right,
  563. &v_neg1,
  564. params.t_gsw * params.db_dim_2,
  565. );
  566. let mut v_reg_inp = Vec::with_capacity(dim0);
  567. for i in 0..dim0 {
  568. v_reg_inp.push(v[2 * i].clone());
  569. }
  570. let mut v_gsw_inp = Vec::with_capacity(right_expanded);
  571. for i in 0..right_expanded {
  572. v_gsw_inp.push(v[2 * i + 1].clone());
  573. }
  574. let v_reg_sz = dim0 * 2 * params.poly_len;
  575. v_reg_reoriented = AlignedMemory64::new(v_reg_sz);
  576. reorient_reg_ciphertexts(params, v_reg_reoriented.as_mut_slice(), &v_reg_inp);
  577. v_folding = Vec::new();
  578. for _ in 0..params.db_dim_2 {
  579. v_folding.push(PolyMatrixNTT::zero(params, 2, 2 * params.t_gsw));
  580. }
  581. regev_to_gsw(&mut v_folding, &v_gsw_inp, &v_conversion, params, 1, 0);
  582. (v_reg_reoriented, v_folding)
  583. }
  584. pub fn process_query(
  585. params: &Params,
  586. public_params: &PublicParameters,
  587. query: &Query,
  588. db: &[u64],
  589. ) -> Vec<u8> {
  590. let dim0 = 1 << params.db_dim_1;
  591. let num_per = 1 << params.db_dim_2;
  592. let db_slice_sz = dim0 * num_per * params.poly_len;
  593. let v_packing = public_params.v_packing.as_ref();
  594. let mut v_reg_reoriented;
  595. let v_folding;
  596. if params.expand_queries {
  597. (v_reg_reoriented, v_folding) = expand_query(params, public_params, query);
  598. } else {
  599. v_reg_reoriented = AlignedMemory64::new(query.v_buf.as_ref().unwrap().len());
  600. v_reg_reoriented
  601. .as_mut_slice()
  602. .copy_from_slice(query.v_buf.as_ref().unwrap());
  603. v_folding = query
  604. .v_ct
  605. .as_ref()
  606. .unwrap()
  607. .iter()
  608. .map(|x| x.ntt())
  609. .collect();
  610. }
  611. let v_folding_neg = get_v_folding_neg(params, &v_folding);
  612. let v_packed_ct = (0..params.instances).into_par_iter().map(|instance| {
  613. let mut intermediate = Vec::with_capacity(num_per);
  614. let mut intermediate_raw = Vec::with_capacity(num_per);
  615. for _ in 0..num_per {
  616. intermediate.push(PolyMatrixNTT::zero(params, 2, 1));
  617. intermediate_raw.push(PolyMatrixRaw::zero(params, 2, 1));
  618. }
  619. let mut v_ct = Vec::new();
  620. for trial in 0..(params.n * params.n) {
  621. let idx = (instance * (params.n * params.n) + trial) * db_slice_sz;
  622. let cur_db = &db[idx..(idx + db_slice_sz)];
  623. multiply_reg_by_database(
  624. &mut intermediate,
  625. cur_db,
  626. v_reg_reoriented.as_slice(),
  627. params,
  628. dim0,
  629. num_per,
  630. );
  631. for i in 0..intermediate.len() {
  632. from_ntt(&mut intermediate_raw[i], &intermediate[i]);
  633. }
  634. fold_ciphertexts(params, &mut intermediate_raw, &v_folding, &v_folding_neg);
  635. v_ct.push(intermediate_raw[0].clone());
  636. }
  637. let packed_ct = pack(params, &v_ct, &v_packing);
  638. packed_ct.raw()
  639. }).collect();
  640. encode(params, &v_packed_ct)
  641. }
  642. #[cfg(test)]
  643. mod test {
  644. use super::*;
  645. use crate::client::*;
  646. use rand::{prelude::SmallRng, Rng};
  647. const TEST_PREPROCESSED_DB_PATH: &'static str = "/home/samir/wiki/enwiki-20220320.dbp";
  648. fn get_params() -> Params {
  649. get_fast_expansion_testing_params()
  650. }
  651. fn dec_reg<'a>(
  652. params: &'a Params,
  653. ct: &PolyMatrixNTT<'a>,
  654. client: &mut Client<'a, SmallRng>,
  655. scale_k: u64,
  656. ) -> u64 {
  657. let dec = client.decrypt_matrix_reg(ct).raw();
  658. let mut val = dec.data[0] as i64;
  659. if val >= (params.modulus / 2) as i64 {
  660. val -= params.modulus as i64;
  661. }
  662. let val_rounded = f64::round(val as f64 / scale_k as f64) as i64;
  663. if val_rounded == 0 {
  664. 0
  665. } else {
  666. 1
  667. }
  668. }
  669. fn dec_gsw<'a>(
  670. params: &'a Params,
  671. ct: &PolyMatrixNTT<'a>,
  672. client: &mut Client<'a, SmallRng>,
  673. ) -> u64 {
  674. let dec = client.decrypt_matrix_reg(ct).raw();
  675. let idx = 2 * (params.t_gsw - 1) * params.poly_len + params.poly_len; // this offset should encode a large value
  676. let mut val = dec.data[idx] as i64;
  677. if val >= (params.modulus / 2) as i64 {
  678. val -= params.modulus as i64;
  679. }
  680. if i64::abs(val) < (1i64 << 10) {
  681. 0
  682. } else {
  683. 1
  684. }
  685. }
  686. #[test]
  687. fn coefficient_expansion_is_correct() {
  688. let params = get_params();
  689. let v_neg1 = params.get_v_neg1();
  690. let mut seeded_rng = get_seeded_rng();
  691. let mut client = Client::init(&params, &mut seeded_rng);
  692. let public_params = client.generate_keys();
  693. let mut v = Vec::new();
  694. for _ in 0..(1 << (params.db_dim_1 + 1)) {
  695. v.push(PolyMatrixNTT::zero(&params, 2, 1));
  696. }
  697. let target = 7;
  698. let scale_k = params.modulus / params.pt_modulus;
  699. let mut sigma = PolyMatrixRaw::zero(&params, 1, 1);
  700. sigma.data[target] = scale_k;
  701. v[0] = client.encrypt_matrix_reg(&sigma.ntt());
  702. let test_ct = client.encrypt_matrix_reg(&sigma.ntt());
  703. let v_w_left = public_params.v_expansion_left.unwrap();
  704. let v_w_right = public_params.v_expansion_right.unwrap();
  705. coefficient_expansion(
  706. &mut v,
  707. client.g,
  708. client.stop_round,
  709. &params,
  710. &v_w_left,
  711. &v_w_right,
  712. &v_neg1,
  713. params.t_gsw * params.db_dim_2,
  714. );
  715. assert_eq!(dec_reg(&params, &test_ct, &mut client, scale_k), 0);
  716. for i in 0..v.len() {
  717. if i == target {
  718. assert_eq!(dec_reg(&params, &v[i], &mut client, scale_k), 1);
  719. } else {
  720. assert_eq!(dec_reg(&params, &v[i], &mut client, scale_k), 0);
  721. }
  722. }
  723. }
  724. #[test]
  725. fn regev_to_gsw_is_correct() {
  726. let mut params = get_params();
  727. params.db_dim_2 = 1;
  728. let mut seeded_rng = get_seeded_rng();
  729. let mut client = Client::init(&params, &mut seeded_rng);
  730. let public_params = client.generate_keys();
  731. let mut enc_constant = |val| {
  732. let mut sigma = PolyMatrixRaw::zero(&params, 1, 1);
  733. sigma.data[0] = val;
  734. client.encrypt_matrix_reg(&sigma.ntt())
  735. };
  736. let v = &public_params.v_conversion.unwrap()[0];
  737. let bits_per = get_bits_per(&params, params.t_gsw);
  738. let mut v_inp_1 = Vec::new();
  739. let mut v_inp_0 = Vec::new();
  740. for i in 0..params.t_gsw {
  741. let val = 1u64 << (bits_per * i);
  742. v_inp_1.push(enc_constant(val));
  743. v_inp_0.push(enc_constant(0));
  744. }
  745. let mut v_gsw = Vec::new();
  746. v_gsw.push(PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw));
  747. regev_to_gsw(&mut v_gsw, &v_inp_1, v, &params, 1, 0);
  748. assert_eq!(dec_gsw(&params, &v_gsw[0], &mut client), 1);
  749. regev_to_gsw(&mut v_gsw, &v_inp_0, v, &params, 1, 0);
  750. assert_eq!(dec_gsw(&params, &v_gsw[0], &mut client), 0);
  751. }
  /// Multiplies a one-hot vector of Regev ciphertexts (encrypting `scale_k`
  /// at the target first-dimension index and 0 elsewhere) by a random
  /// database, then checks that the output slot selected by the target's
  /// second-dimension index decrypts and rescales to the expected item.
  #[test]
  fn multiply_reg_by_database_is_correct() {
      let params = get_params();
      let mut seeded_rng = get_seeded_rng();

      let dim0 = 1 << params.db_dim_1;
      let num_per = 1 << params.db_dim_2;
      let scale_k = params.modulus / params.pt_modulus;

      let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
      let target_idx_dim0 = target_idx / num_per;
      let target_idx_num_per = target_idx % num_per;

      let mut client = Client::init(&params, &mut seeded_rng);
      let _ = client.generate_keys();

      let (corr_item, db) = generate_random_db_and_get_item(&params, target_idx);

      // One ciphertext per first-dimension index; only the target one
      // carries a non-zero plaintext.
      let v_reg: Vec<_> = (0..dim0)
          .map(|i| {
              let val = if i == target_idx_dim0 { scale_k } else { 0 };
              let sigma = PolyMatrixRaw::single_value(&params, val).ntt();
              client.encrypt_matrix_reg(&sigma)
          })
          .collect();

      let v_reg_sz = dim0 * 2 * params.poly_len;
      let mut v_reg_reoriented = AlignedMemory64::new(v_reg_sz);
      reorient_reg_ciphertexts(&params, v_reg_reoriented.as_mut_slice(), &v_reg);

      // The output is read at `target_idx_num_per` (< num_per) below, so it
      // needs exactly `num_per` slots. The previous code pushed `dim0` entries
      // into a vector reserved for `num_per`, which left the target slot
      // missing whenever `dim0 < num_per`.
      let mut out: Vec<_> = (0..num_per)
          .map(|_| PolyMatrixNTT::zero(&params, 2, 1))
          .collect();

      multiply_reg_by_database(
          &mut out,
          db.as_slice(),
          v_reg_reoriented.as_slice(),
          &params,
          dim0,
          num_per,
      );

      // Decrypt the selected slot and compare coefficient by coefficient
      // after rescaling from the ciphertext modulus to the plaintext modulus.
      let dec = client.decrypt_matrix_reg(&out[target_idx_num_per]).raw();
      for z in 0..params.poly_len {
          let rescaled = rescale(dec.data[z], params.modulus, params.pt_modulus);
          assert_eq!(rescaled, corr_item.data[z], "at {:?}", z);
      }
  }
  /// Builds `num_per` Regev ciphertexts (only the target one encrypts
  /// `scale_k`), hand-assembles one GSW ciphertext per bit of the target's
  /// second-dimension index, and checks that after `fold_ciphertexts` the
  /// first ciphertext decrypts to 1.
  #[test]
  fn fold_ciphertexts_is_correct() {
      let params = get_params();
      let mut seeded_rng = get_seeded_rng();

      let dim0 = 1 << params.db_dim_1;
      let num_per = 1 << params.db_dim_2;
      let scale_k = params.modulus / params.pt_modulus;

      let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
      let target_idx_num_per = target_idx % num_per;

      let mut client = Client::init(&params, &mut seeded_rng);
      let _ = client.generate_keys();

      // Encrypt the one-hot vector, keeping each ciphertext in raw form.
      let mut v_reg_raw: Vec<_> = (0..num_per)
          .map(|i| {
              let val = if i == target_idx_num_per { scale_k } else { 0 };
              let sigma = PolyMatrixRaw::single_value(&params, val).ntt();
              client.encrypt_matrix_reg(&sigma).raw()
          })
          .collect();

      let bits_per = get_bits_per(&params, params.t_gsw);

      // One GSW ciphertext per bit of the second-dimension index.
      let mut v_folding = Vec::new();
      for i in 0..params.db_dim_2 {
          let bit = ((target_idx_num_per as u64) >> i) & 1;
          let mut ct_gsw = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);

          for j in 0..params.t_gsw {
              let value = (1u64 << (bits_per * j)) * bit;
              let sigma = PolyMatrixRaw::single_value(&params, value);
              let sigma_ntt = to_ntt_alloc(&sigma);

              // Odd column: encryption of the scaled bit itself.
              let ct_value = client.encrypt_matrix_reg(&sigma_ntt);
              ct_gsw.copy_into(&ct_value, 0, 2 * j + 1);

              // Even column: encryption of (secret key * scaled bit).
              let sk_times_value = &to_ntt_alloc(&client.sk_reg) * &sigma_ntt;
              let ct_sk_value = client.encrypt_matrix_reg(&sk_times_value);
              ct_gsw.copy_into(&ct_sk_value, 0, 2 * j);
          }

          v_folding.push(ct_gsw);
      }

      // Complement ciphertexts: gadget + invert(ct) for every folding
      // ciphertext. The scratch matrix is reused across iterations.
      let gadget_ntt = build_gadget(&params, 2, 2 * params.t_gsw).ntt();
      let mut ct_gsw_inv = PolyMatrixRaw::zero(&params, 2, 2 * params.t_gsw);
      let mut v_folding_neg = Vec::new();
      for ct_gsw in v_folding.iter() {
          invert(&mut ct_gsw_inv, &ct_gsw.raw());
          let mut ct_gsw_neg = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
          add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
          v_folding_neg.push(ct_gsw_neg);
      }

      fold_ciphertexts(&params, &mut v_reg_raw, &v_folding, &v_folding_neg);

      // After folding, the first slot holds the result.
      let folded = v_reg_raw[0].ntt();
      assert_eq!(dec_reg(&params, &folded, &mut client, scale_k), 1);
  }
  851. fn full_protocol_is_correct_for_params(params: &Params) {
  852. let mut seeded_rng = get_seeded_rng();
  853. let target_idx = 22456;//22456;//seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
  854. let mut client = Client::init(params, &mut seeded_rng);
  855. let public_params = client.generate_keys();
  856. let query = client.generate_query(target_idx);
  857. let (corr_item, db) = generate_random_db_and_get_item(params, target_idx);
  858. let response = process_query(params, &public_params, &query, db.as_slice());
  859. let result = client.decode_response(response.as_slice());
  860. let p_bits = log2_ceil(params.pt_modulus) as usize;
  861. let corr_result = corr_item.to_vec(p_bits, params.modp_words_per_chunk());
  862. assert_eq!(result.len(), corr_result.len());
  863. for z in 0..corr_result.len() {
  864. assert_eq!(result[z], corr_result[z], "at {:?}", z);
  865. }
  866. }
  867. fn full_protocol_is_correct_for_params_real_db(params: &Params) {
  868. let mut seeded_rng = get_seeded_rng();
  869. let target_idx = 22456; //seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
  870. let mut client = Client::init(params, &mut seeded_rng);
  871. let public_params = client.generate_keys();
  872. let query = client.generate_query(target_idx);
  873. let mut file = File::open(TEST_PREPROCESSED_DB_PATH).unwrap();
  874. let db = load_preprocessed_db_from_file(params, &mut file);
  875. let response = process_query(params, &public_params, &query, db.as_slice());
  876. let result = client.decode_response(response.as_slice());
  877. let corr_result = vec![0x42, 0x5a, 0x68];
  878. for z in 0..corr_result.len() {
  879. assert_eq!(result[z], corr_result[z]);
  880. }
  881. }
  /// End-to-end protocol check using the default test parameters.
  #[test]
  fn full_protocol_is_correct() {
      let params = get_params();
      full_protocol_is_correct_for_params(&params);
  }
  /// End-to-end protocol check with a larger, inline parameter set, run
  /// first against a random database and then against the preprocessed
  /// database file.
  #[test]
  fn larger_full_protocol_is_correct() {
      // Written with single quotes and converted to JSON below. The literal's
      // content is kept flush-left so its bytes stay exactly as listed.
      let cfg_expand = r#"
{
'n': 2,
'nu_1': 10,
'nu_2': 6,
'p': 512,
'q2_bits': 21,
's_e': 85.83255142749422,
't_gsw': 10,
't_conv': 4,
't_exp_left': 16,
't_exp_right': 56,
'instances': 1,
'db_item_size': 9000 }
"#;
      let cfg_json = cfg_expand.replace("'", "\"");
      let params = params_from_json(&cfg_json);

      full_protocol_is_correct_for_params(&params);
      full_protocol_is_correct_for_params_real_db(&params);
  }
  909. // #[test]
  910. // fn full_protocol_is_correct_20_256() {
  911. // full_protocol_is_correct_for_params(&params_from_json(&CFG_20_256.replace("'", "\"")));
  912. // }
  913. // #[test]
  914. // fn full_protocol_is_correct_16_100000() {
  915. // full_protocol_is_correct_for_params(&params_from_json(&CFG_16_100000.replace("'", "\"")));
  916. // }
  /// Real-database protocol check using the `CFG_16_100000` parameter set.
  /// Ignored by default; run it explicitly with `cargo test -- --ignored`.
  #[test]
  #[ignore]
  fn full_protocol_is_correct_real_db_16_100000() {
      let cfg_json = CFG_16_100000.replace("'", "\"");
      let params = params_from_json(&cfg_json);
      full_protocol_is_correct_for_params_real_db(&params);
  }
  924. }