server.rs 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109
  1. #[cfg(target_feature = "avx2")]
  2. use std::arch::x86_64::*;
  3. use std::fs::File;
  4. use std::io::BufReader;
  5. use std::io::Read;
  6. use std::io::Seek;
  7. use std::io::SeekFrom;
  8. use std::time::Instant;
  9. use crate::aligned_memory::*;
  10. use crate::arith::*;
  11. use crate::client::PublicParameters;
  12. use crate::client::Query;
  13. use crate::gadget::*;
  14. use crate::params::*;
  15. use crate::poly::*;
  16. use crate::util::*;
  17. use rayon::prelude::*;
  18. pub fn coefficient_expansion(
  19. v: &mut Vec<PolyMatrixNTT>,
  20. g: usize,
  21. stop_round: usize,
  22. params: &Params,
  23. v_w_left: &Vec<PolyMatrixNTT>,
  24. v_w_right: &Vec<PolyMatrixNTT>,
  25. v_neg1: &Vec<PolyMatrixNTT>,
  26. max_bits_to_gen_right: usize,
  27. ) {
  28. let poly_len = params.poly_len;
  29. let mut ct = PolyMatrixRaw::zero(params, 2, 1);
  30. let mut ct_auto = PolyMatrixRaw::zero(params, 2, 1);
  31. let mut ct_auto_1 = PolyMatrixRaw::zero(params, 1, 1);
  32. let mut ct_auto_1_ntt = PolyMatrixNTT::zero(params, 1, 1);
  33. let mut ginv_ct_left = PolyMatrixRaw::zero(params, params.t_exp_left, 1);
  34. let mut ginv_ct_left_ntt = PolyMatrixNTT::zero(params, params.t_exp_left, 1);
  35. let mut ginv_ct_right = PolyMatrixRaw::zero(params, params.t_exp_right, 1);
  36. let mut ginv_ct_right_ntt = PolyMatrixNTT::zero(params, params.t_exp_right, 1);
  37. let mut w_times_ginv_ct = PolyMatrixNTT::zero(params, 2, 1);
  38. for r in 0..g {
  39. let num_in = 1 << r;
  40. let num_out = 2 * num_in;
  41. let t = (poly_len / (1 << r)) + 1;
  42. let neg1 = &v_neg1[r];
  43. for i in 0..num_out {
  44. if (stop_round > 0 && r > stop_round && (i % 2) == 1)
  45. || (stop_round > 0
  46. && r == stop_round
  47. && (i % 2) == 1
  48. && (i / 2) >= max_bits_to_gen_right)
  49. {
  50. continue;
  51. }
  52. let (w, _gadget_dim, gi_ct, gi_ct_ntt) = match i % 2 {
  53. 0 => (
  54. &v_w_left[r],
  55. params.t_exp_left,
  56. &mut ginv_ct_left,
  57. &mut ginv_ct_left_ntt,
  58. ),
  59. 1 | _ => (
  60. &v_w_right[r],
  61. params.t_exp_right,
  62. &mut ginv_ct_right,
  63. &mut ginv_ct_right_ntt,
  64. ),
  65. };
  66. if i < num_in {
  67. let (src, dest) = v.split_at_mut(num_in);
  68. scalar_multiply(&mut dest[i], neg1, &src[i]);
  69. }
  70. from_ntt(&mut ct, &v[i]);
  71. automorph(&mut ct_auto, &ct, t);
  72. gadget_invert_rdim(gi_ct, &ct_auto, 1);
  73. to_ntt_no_reduce(gi_ct_ntt, &gi_ct);
  74. ct_auto_1
  75. .data
  76. .as_mut_slice()
  77. .copy_from_slice(ct_auto.get_poly(1, 0));
  78. to_ntt(&mut ct_auto_1_ntt, &ct_auto_1);
  79. multiply(&mut w_times_ginv_ct, w, &gi_ct_ntt);
  80. let mut idx = 0;
  81. for j in 0..2 {
  82. for n in 0..params.crt_count {
  83. for z in 0..poly_len {
  84. let sum = v[i].data[idx]
  85. + w_times_ginv_ct.data[idx]
  86. + j * ct_auto_1_ntt.data[n * poly_len + z];
  87. v[i].data[idx] = barrett_coeff_u64(params, sum, n);
  88. idx += 1;
  89. }
  90. }
  91. }
  92. }
  93. }
  94. }
  95. pub fn regev_to_gsw<'a>(
  96. v_gsw: &mut Vec<PolyMatrixNTT<'a>>,
  97. v_inp: &Vec<PolyMatrixNTT<'a>>,
  98. v: &PolyMatrixNTT<'a>,
  99. params: &'a Params,
  100. idx_factor: usize,
  101. idx_offset: usize,
  102. ) {
  103. assert!(v.rows == 2);
  104. assert!(v.cols == 2 * params.t_conv);
  105. let mut ginv_c_inp = PolyMatrixRaw::zero(params, 2 * params.t_conv, 1);
  106. let mut ginv_c_inp_ntt = PolyMatrixNTT::zero(params, 2 * params.t_conv, 1);
  107. let mut tmp_ct_raw = PolyMatrixRaw::zero(params, 2, 1);
  108. let mut tmp_ct = PolyMatrixNTT::zero(params, 2, 1);
  109. for i in 0..params.db_dim_2 {
  110. let ct = &mut v_gsw[i];
  111. for j in 0..params.t_gsw {
  112. let idx_ct = i * params.t_gsw + j;
  113. let idx_inp = idx_factor * (idx_ct) + idx_offset;
  114. ct.copy_into(&v_inp[idx_inp], 0, 2 * j + 1);
  115. from_ntt(&mut tmp_ct_raw, &v_inp[idx_inp]);
  116. gadget_invert(&mut ginv_c_inp, &tmp_ct_raw);
  117. to_ntt(&mut ginv_c_inp_ntt, &ginv_c_inp);
  118. multiply(&mut tmp_ct, v, &ginv_c_inp_ntt);
  119. ct.copy_into(&tmp_ct, 0, 2 * j);
  120. }
  121. }
  122. }
  123. pub const MAX_SUMMED: usize = 1 << 6;
  124. pub const PACKED_OFFSET_2: i32 = 32;
/// Multiplies the reoriented first-dimension query by the database (AVX2 code path).
///
/// For every coefficient index `z`, output slot `i` in `0..num_per` and plaintext column `c`,
/// this sums, over the database rows, the products of `v_firstdim` words with `db` words.
/// The sum is computed separately for the low 32 bits of each word (reduced with CRT index 0)
/// and for the bits from `PACKED_OFFSET_2` upward (reduced with CRT index 1). The
/// Barrett-reduced results go into the two rows of `out[i]`.
///
/// Layout, as implied by the indexing below (NOTE(review): confirm against
/// `reorient_reg_ciphertexts` and the database loaders):
/// - `v_firstdim` holds `dim0 * ct_rows` words per `z`, starting at `z * dim0 * ct_rows`;
///   even-offset words feed output row 0 and odd-offset words feed output row 1.
/// - `db` holds `num_per * dim0` words per `z` and is consumed sequentially.
///
/// Only `outer_limit = dim0 * ct_rows / MAX_SUMMED` full groups are processed (integer
/// division), so any trailing `(dim0 * ct_rows) % MAX_SUMMED` words would be ignored.
///
/// # Panics
/// Panics if `dim0 * 2 < MAX_SUMMED`.
#[cfg(target_feature = "avx2")]
pub fn multiply_reg_by_database(
    out: &mut Vec<PolyMatrixNTT>,
    db: &[u64],
    v_firstdim: &[u64],
    params: &Params,
    dim0: usize,
    num_per: usize,
) {
    let ct_rows = 2;
    let ct_cols = 1;
    let pt_rows = 1;
    let pt_cols = 1;
    assert!(dim0 * ct_rows >= MAX_SUMMED);
    // Spill targets for the 4-lane accumulators; `_mm256_store_si256` requires 32-byte
    // alignment, presumably provided by `AlignedMemory64` (TODO confirm).
    let mut sums_out_n0_u64 = AlignedMemory64::new(4);
    let mut sums_out_n2_u64 = AlignedMemory64::new(4);
    for z in 0..params.poly_len {
        let idx_a_base = z * (ct_cols * dim0 * ct_rows);
        let mut idx_b_base = z * (num_per * pt_cols * dim0 * pt_rows);
        for i in 0..num_per {
            for c in 0..pt_cols {
                let inner_limit = MAX_SUMMED;
                let outer_limit = dim0 * ct_rows / inner_limit;
                // Running per-lane sums, Barrett-reduced after each group of MAX_SUMMED words.
                let mut sums_out_n0_u64_acc = [0u64, 0, 0, 0];
                let mut sums_out_n2_u64_acc = [0u64, 0, 0, 0];
                for o_jm in 0..outer_limit {
                    // SAFETY: relies on (unchecked here) `db` having at least
                    // `poly_len * num_per * dim0` words, `v_firstdim` having at least
                    // `poly_len * dim0 * 2` words, and every `v_firstdim` address passed to
                    // `_mm256_load_si256` being 32-byte aligned. `jm` is a multiple of 4 and
                    // `idx_a_base` a multiple of `dim0 * 2`. NOTE(review): confirm that callers
                    // always pass `AlignedMemory64`-backed buffers of these sizes.
                    unsafe {
                        let mut sums_out_n0 = _mm256_setzero_si256();
                        let mut sums_out_n2 = _mm256_setzero_si256();
                        for i_jm in 0..inner_limit / 4 {
                            let jm = o_jm * inner_limit + (4 * i_jm);
                            // Two database words per step; the lanes become [b1, b1, b2, b2],
                            // each paired with 4 consecutive `v_firstdim` words.
                            let b_inp_1 = *db.get_unchecked(idx_b_base) as i64;
                            idx_b_base += 1;
                            let b_inp_2 = *db.get_unchecked(idx_b_base) as i64;
                            idx_b_base += 1;
                            let b = _mm256_set_epi64x(b_inp_2, b_inp_2, b_inp_1, b_inp_1);
                            let v_a = v_firstdim.get_unchecked(idx_a_base + jm) as *const u64;
                            let a = _mm256_load_si256(v_a as *const __m256i);
                            // `_mm256_mul_epu32` multiplies the low 32 bits of each 64-bit lane;
                            // shifting right by PACKED_OFFSET_2 exposes the packed upper half.
                            let a_lo = a;
                            let a_hi_hi = _mm256_srli_epi64(a, PACKED_OFFSET_2);
                            let b_lo = b;
                            let b_hi_hi = _mm256_srli_epi64(b, PACKED_OFFSET_2);
                            sums_out_n0 =
                                _mm256_add_epi64(sums_out_n0, _mm256_mul_epu32(a_lo, b_lo));
                            sums_out_n2 =
                                _mm256_add_epi64(sums_out_n2, _mm256_mul_epu32(a_hi_hi, b_hi_hi));
                        }
                        // reduce here, otherwise we will overflow
                        // (each lane holds the sum of MAX_SUMMED / 4 products; whether that
                        // fits in 64 bits depends on the modulus sizes — NOTE(review)).
                        _mm256_store_si256(
                            sums_out_n0_u64.as_mut_ptr() as *mut __m256i,
                            sums_out_n0,
                        );
                        _mm256_store_si256(
                            sums_out_n2_u64.as_mut_ptr() as *mut __m256i,
                            sums_out_n2,
                        );
                        for idx in 0..4 {
                            let val = sums_out_n0_u64[idx];
                            sums_out_n0_u64_acc[idx] =
                                barrett_coeff_u64(params, val + sums_out_n0_u64_acc[idx], 0);
                        }
                        for idx in 0..4 {
                            let val = sums_out_n2_u64[idx];
                            sums_out_n2_u64_acc[idx] =
                                barrett_coeff_u64(params, val + sums_out_n2_u64_acc[idx], 1);
                        }
                    }
                }
                for idx in 0..4 {
                    sums_out_n0_u64_acc[idx] =
                        barrett_coeff_u64(params, sums_out_n0_u64_acc[idx], 0);
                    sums_out_n2_u64_acc[idx] =
                        barrett_coeff_u64(params, sums_out_n2_u64_acc[idx], 1);
                }
                // Lanes 0 and 2 carry even-offset `v_firstdim` words (output row 0); lanes 1
                // and 3 carry odd-offset words (output row 1).
                // output n0
                let (crt_count, poly_len) = (params.crt_count, params.poly_len);
                let mut n = 0;
                let mut idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
                out[i].data[idx_c] =
                    barrett_coeff_u64(params, sums_out_n0_u64_acc[0] + sums_out_n0_u64_acc[2], 0);
                idx_c += pt_cols * crt_count * poly_len;
                out[i].data[idx_c] =
                    barrett_coeff_u64(params, sums_out_n0_u64_acc[1] + sums_out_n0_u64_acc[3], 0);
                // output n1
                n = 1;
                idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
                out[i].data[idx_c] =
                    barrett_coeff_u64(params, sums_out_n2_u64_acc[0] + sums_out_n2_u64_acc[2], 1);
                idx_c += pt_cols * crt_count * poly_len;
                out[i].data[idx_c] =
                    barrett_coeff_u64(params, sums_out_n2_u64_acc[1] + sums_out_n2_u64_acc[3], 1);
            }
        }
    }
}
  220. #[cfg(not(target_feature = "avx2"))]
  221. pub fn multiply_reg_by_database(
  222. out: &mut Vec<PolyMatrixNTT>,
  223. db: &[u64],
  224. v_firstdim: &[u64],
  225. params: &Params,
  226. dim0: usize,
  227. num_per: usize,
  228. ) {
  229. let ct_rows = 2;
  230. let ct_cols = 1;
  231. let pt_rows = 1;
  232. let pt_cols = 1;
  233. for z in 0..params.poly_len {
  234. let idx_a_base = z * (ct_cols * dim0 * ct_rows);
  235. let mut idx_b_base = z * (num_per * pt_cols * dim0 * pt_rows);
  236. for i in 0..num_per {
  237. for c in 0..pt_cols {
  238. let mut sums_out_n0_0 = 0u128;
  239. let mut sums_out_n0_1 = 0u128;
  240. let mut sums_out_n1_0 = 0u128;
  241. let mut sums_out_n1_1 = 0u128;
  242. for jm in 0..(dim0 * pt_rows) {
  243. let b = db[idx_b_base];
  244. idx_b_base += 1;
  245. let v_a0 = v_firstdim[idx_a_base + jm * ct_rows];
  246. let v_a1 = v_firstdim[idx_a_base + jm * ct_rows + 1];
  247. let b_lo = b as u32;
  248. let b_hi = (b >> 32) as u32;
  249. let v_a0_lo = v_a0 as u32;
  250. let v_a0_hi = (v_a0 >> 32) as u32;
  251. let v_a1_lo = v_a1 as u32;
  252. let v_a1_hi = (v_a1 >> 32) as u32;
  253. // do n0
  254. sums_out_n0_0 += ((v_a0_lo as u64) * (b_lo as u64)) as u128;
  255. sums_out_n0_1 += ((v_a1_lo as u64) * (b_lo as u64)) as u128;
  256. // do n1
  257. sums_out_n1_0 += ((v_a0_hi as u64) * (b_hi as u64)) as u128;
  258. sums_out_n1_1 += ((v_a1_hi as u64) * (b_hi as u64)) as u128;
  259. }
  260. // output n0
  261. let (crt_count, poly_len) = (params.crt_count, params.poly_len);
  262. let mut n = 0;
  263. let mut idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
  264. out[i].data[idx_c] = (sums_out_n0_0 % (params.moduli[0] as u128)) as u64;
  265. idx_c += pt_cols * crt_count * poly_len;
  266. out[i].data[idx_c] = (sums_out_n0_1 % (params.moduli[0] as u128)) as u64;
  267. // output n1
  268. n = 1;
  269. idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
  270. out[i].data[idx_c] = (sums_out_n1_0 % (params.moduli[1] as u128)) as u64;
  271. idx_c += pt_cols * crt_count * poly_len;
  272. out[i].data[idx_c] = (sums_out_n1_1 % (params.moduli[1] as u128)) as u64;
  273. }
  274. }
  275. }
  276. }
  277. pub fn generate_random_db_and_get_item<'a>(
  278. params: &'a Params,
  279. item_idx: usize,
  280. ) -> (PolyMatrixRaw<'a>, AlignedMemory64) {
  281. let mut rng = get_seeded_rng();
  282. let instances = params.instances;
  283. let trials = params.n * params.n;
  284. let dim0 = 1 << params.db_dim_1;
  285. let num_per = 1 << params.db_dim_2;
  286. let num_items = dim0 * num_per;
  287. let db_size_words = instances * trials * num_items * params.poly_len;
  288. let mut v = AlignedMemory64::new(db_size_words);
  289. let mut item = PolyMatrixRaw::zero(params, params.n, params.n);
  290. for instance in 0..instances {
  291. println!("Instance {:?}", instance);
  292. for trial in 0..trials {
  293. println!("Trial {:?}", trial);
  294. for i in 0..num_items {
  295. let ii = i % num_per;
  296. let j = i / num_per;
  297. let mut db_item = PolyMatrixRaw::random_rng(params, 1, 1, &mut rng);
  298. db_item.reduce_mod(params.pt_modulus);
  299. if i == item_idx && instance == 0 {
  300. item.copy_into(&db_item, trial / params.n, trial % params.n);
  301. }
  302. for z in 0..params.poly_len {
  303. db_item.data[z] =
  304. recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
  305. }
  306. let db_item_ntt = db_item.ntt();
  307. for z in 0..params.poly_len {
  308. let idx_dst = calc_index(
  309. &[instance, trial, z, ii, j],
  310. &[instances, trials, params.poly_len, num_per, dim0],
  311. );
  312. v[idx_dst] = db_item_ntt.data[z]
  313. | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
  314. }
  315. }
  316. }
  317. }
  318. (item, v)
  319. }
  320. pub fn load_item_from_file<'a>(
  321. params: &'a Params,
  322. file: &mut File,
  323. instance: usize,
  324. trial: usize,
  325. item_idx: usize,
  326. ) -> PolyMatrixRaw<'a> {
  327. let db_item_size = params.db_item_size;
  328. let instances = params.instances;
  329. let trials = params.n * params.n;
  330. let chunks = instances * trials;
  331. let bytes_per_chunk = f64::ceil(db_item_size as f64 / chunks as f64) as usize;
  332. let logp = f64::ceil(f64::log2(params.pt_modulus as f64)) as usize;
  333. let modp_words_per_chunk = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
  334. assert!(modp_words_per_chunk <= params.poly_len);
  335. let idx_item_in_file = item_idx * db_item_size;
  336. let idx_chunk = instance * trials + trial;
  337. let idx_poly_in_file = idx_item_in_file + idx_chunk * bytes_per_chunk;
  338. let mut out = PolyMatrixRaw::zero(params, 1, 1);
  339. let seek_result = file.seek(SeekFrom::Start(idx_poly_in_file as u64));
  340. if seek_result.is_err() {
  341. return out;
  342. }
  343. let mut data = vec![0u8; 2 * bytes_per_chunk];
  344. let bytes_read = file
  345. .read(&mut data.as_mut_slice()[0..bytes_per_chunk])
  346. .unwrap();
  347. let modp_words_read = f64::ceil((bytes_read * 8) as f64 / logp as f64) as usize;
  348. assert!(modp_words_read <= params.poly_len);
  349. for i in 0..modp_words_read {
  350. out.data[i] = read_arbitrary_bits(&data, i * logp, logp);
  351. assert!(out.data[i] <= params.pt_modulus);
  352. }
  353. out
  354. }
  355. pub fn load_db_from_file(params: &Params, file: &mut File) -> AlignedMemory64 {
  356. let instances = params.instances;
  357. let trials = params.n * params.n;
  358. let dim0 = 1 << params.db_dim_1;
  359. let num_per = 1 << params.db_dim_2;
  360. let num_items = dim0 * num_per;
  361. let db_size_words = instances * trials * num_items * params.poly_len;
  362. let mut v = AlignedMemory64::new(db_size_words);
  363. for instance in 0..instances {
  364. println!("Instance {:?}", instance);
  365. for trial in 0..trials {
  366. println!("Trial {:?}", trial);
  367. for i in 0..num_items {
  368. if i % 8192 == 0 {
  369. println!("item {:?}", i);
  370. }
  371. let ii = i % num_per;
  372. let j = i / num_per;
  373. let mut db_item = load_item_from_file(params, file, instance, trial, i);
  374. // db_item.reduce_mod(params.pt_modulus);
  375. for z in 0..params.poly_len {
  376. db_item.data[z] =
  377. recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
  378. }
  379. let db_item_ntt = db_item.ntt();
  380. for z in 0..params.poly_len {
  381. let idx_dst = calc_index(
  382. &[instance, trial, z, ii, j],
  383. &[instances, trials, params.poly_len, num_per, dim0],
  384. );
  385. v[idx_dst] = db_item_ntt.data[z]
  386. | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
  387. }
  388. }
  389. }
  390. }
  391. v
  392. }
  393. pub fn load_file_unsafe(data: &mut [u64], file: &mut File) {
  394. let data_as_u8_mut = unsafe { data.align_to_mut::<u8>().1 };
  395. file.read_exact(data_as_u8_mut).unwrap();
  396. }
  397. pub fn load_file(data: &mut [u64], file: &mut File) {
  398. let mut reader = BufReader::with_capacity(1 << 24, file);
  399. let mut buf = [0u8; 8];
  400. for i in 0..data.len() {
  401. reader.read(&mut buf).unwrap();
  402. data[i] = u64::from_ne_bytes(buf);
  403. }
  404. }
  405. pub fn load_preprocessed_db_from_file(params: &Params, file: &mut File) -> AlignedMemory64 {
  406. let instances = params.instances;
  407. let trials = params.n * params.n;
  408. let dim0 = 1 << params.db_dim_1;
  409. let num_per = 1 << params.db_dim_2;
  410. let num_items = dim0 * num_per;
  411. let db_size_words = instances * trials * num_items * params.poly_len;
  412. let mut v = AlignedMemory64::new(db_size_words);
  413. let v_mut_slice = v.as_mut_slice();
  414. let now = Instant::now();
  415. load_file(v_mut_slice, file);
  416. println!("Done loading ({} ms).", now.elapsed().as_millis());
  417. v
  418. }
  419. pub fn fold_ciphertexts(
  420. params: &Params,
  421. v_cts: &mut Vec<PolyMatrixRaw>,
  422. v_folding: &Vec<PolyMatrixNTT>,
  423. v_folding_neg: &Vec<PolyMatrixNTT>,
  424. ) {
  425. let further_dims = log2(v_cts.len() as u64) as usize;
  426. let ell = v_folding[0].cols / 2;
  427. let mut ginv_c = PolyMatrixRaw::zero(&params, 2 * ell, 1);
  428. let mut ginv_c_ntt = PolyMatrixNTT::zero(&params, 2 * ell, 1);
  429. let mut prod = PolyMatrixNTT::zero(&params, 2, 1);
  430. let mut sum = PolyMatrixNTT::zero(&params, 2, 1);
  431. let mut num_per = v_cts.len();
  432. for cur_dim in 0..further_dims {
  433. num_per = num_per / 2;
  434. for i in 0..num_per {
  435. gadget_invert(&mut ginv_c, &v_cts[i]);
  436. to_ntt(&mut ginv_c_ntt, &ginv_c);
  437. multiply(
  438. &mut prod,
  439. &v_folding_neg[further_dims - 1 - cur_dim],
  440. &ginv_c_ntt,
  441. );
  442. gadget_invert(&mut ginv_c, &v_cts[num_per + i]);
  443. to_ntt(&mut ginv_c_ntt, &ginv_c);
  444. multiply(
  445. &mut sum,
  446. &v_folding[further_dims - 1 - cur_dim],
  447. &ginv_c_ntt,
  448. );
  449. add_into(&mut sum, &prod);
  450. from_ntt(&mut v_cts[i], &sum);
  451. }
  452. }
  453. }
  454. pub fn pack<'a>(
  455. params: &'a Params,
  456. v_ct: &Vec<PolyMatrixRaw>,
  457. v_w: &Vec<PolyMatrixNTT>,
  458. ) -> PolyMatrixNTT<'a> {
  459. assert!(v_ct.len() >= params.n * params.n);
  460. assert!(v_w.len() == params.n);
  461. assert!(v_ct[0].rows == 2);
  462. assert!(v_ct[0].cols == 1);
  463. assert!(v_w[0].rows == (params.n + 1));
  464. assert!(v_w[0].cols == params.t_conv);
  465. let mut result = PolyMatrixNTT::zero(params, params.n + 1, params.n);
  466. let mut ginv = PolyMatrixRaw::zero(params, params.t_conv, 1);
  467. let mut ginv_nttd = PolyMatrixNTT::zero(params, params.t_conv, 1);
  468. let mut prod = PolyMatrixNTT::zero(params, params.n + 1, 1);
  469. let mut ct_1 = PolyMatrixRaw::zero(params, 1, 1);
  470. let mut ct_2 = PolyMatrixRaw::zero(params, 1, 1);
  471. let mut ct_2_ntt = PolyMatrixNTT::zero(params, 1, 1);
  472. for c in 0..params.n {
  473. let mut v_int = PolyMatrixNTT::zero(&params, params.n + 1, 1);
  474. for r in 0..params.n {
  475. let w = &v_w[r];
  476. let ct = &v_ct[r * params.n + c];
  477. ct_1.get_poly_mut(0, 0).copy_from_slice(ct.get_poly(0, 0));
  478. ct_2.get_poly_mut(0, 0).copy_from_slice(ct.get_poly(1, 0));
  479. to_ntt(&mut ct_2_ntt, &ct_2);
  480. gadget_invert(&mut ginv, &ct_1);
  481. to_ntt(&mut ginv_nttd, &ginv);
  482. multiply(&mut prod, &w, &ginv_nttd);
  483. add_into_at(&mut v_int, &ct_2_ntt, 1 + r, 0);
  484. add_into(&mut v_int, &prod);
  485. }
  486. result.copy_into(&v_int, 0, c);
  487. }
  488. result
  489. }
  490. pub fn encode(params: &Params, v_packed_ct: &Vec<PolyMatrixRaw>) -> Vec<u8> {
  491. let q1 = 4 * params.pt_modulus;
  492. let q1_bits = log2_ceil(q1) as usize;
  493. let q2 = Q2_VALUES[params.q2_bits as usize];
  494. let q2_bits = params.q2_bits as usize;
  495. let num_bits = params.instances
  496. * ((q2_bits * params.n * params.poly_len)
  497. + (q1_bits * params.n * params.n * params.poly_len));
  498. let round_to = 64;
  499. let num_bytes_rounded_up = ((num_bits + round_to - 1) / round_to) * round_to / 8;
  500. let mut result = vec![0u8; num_bytes_rounded_up];
  501. let mut bit_offs = 0;
  502. for instance in 0..params.instances {
  503. let packed_ct = &v_packed_ct[instance];
  504. let mut first_row = packed_ct.submatrix(0, 0, 1, packed_ct.cols);
  505. let mut rest_rows = packed_ct.submatrix(1, 0, packed_ct.rows - 1, packed_ct.cols);
  506. first_row.apply_func(|x| rescale(x, params.modulus, q2));
  507. rest_rows.apply_func(|x| rescale(x, params.modulus, q1));
  508. let data = result.as_mut_slice();
  509. for i in 0..params.n * params.poly_len {
  510. write_arbitrary_bits(data, first_row.data[i], bit_offs, q2_bits);
  511. bit_offs += q2_bits;
  512. }
  513. for i in 0..params.n * params.n * params.poly_len {
  514. write_arbitrary_bits(data, rest_rows.data[i], bit_offs, q1_bits);
  515. bit_offs += q1_bits;
  516. }
  517. }
  518. result
  519. }
  520. pub fn get_v_folding_neg<'a>(
  521. params: &'a Params,
  522. v_folding: &Vec<PolyMatrixNTT>,
  523. ) -> Vec<PolyMatrixNTT<'a>> {
  524. let gadget_ntt = build_gadget(&params, 2, 2 * params.t_gsw).ntt(); // TODO: make this better
  525. let mut v_folding_neg = Vec::new();
  526. let mut ct_gsw_inv = PolyMatrixRaw::zero(&params, 2, 2 * params.t_gsw);
  527. for i in 0..params.db_dim_2 {
  528. invert(&mut ct_gsw_inv, &v_folding[i].raw());
  529. let mut ct_gsw_neg = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
  530. add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
  531. v_folding_neg.push(ct_gsw_neg);
  532. }
  533. v_folding_neg
  534. }
  535. pub fn expand_query<'a>(
  536. params: &'a Params,
  537. public_params: &PublicParameters<'a>,
  538. query: &Query<'a>,
  539. ) -> (AlignedMemory64, Vec<PolyMatrixNTT<'a>>) {
  540. let dim0 = 1 << params.db_dim_1;
  541. let further_dims = params.db_dim_2;
  542. let mut v_reg_reoriented;
  543. let mut v_folding;
  544. let num_bits_to_gen = params.t_gsw * further_dims + dim0;
  545. let g = log2_ceil_usize(num_bits_to_gen);
  546. let right_expanded = params.t_gsw * further_dims;
  547. let stop_round = log2_ceil_usize(right_expanded);
  548. let mut v = Vec::new();
  549. for _ in 0..(1 << g) {
  550. v.push(PolyMatrixNTT::zero(params, 2, 1));
  551. }
  552. v[0].copy_into(&query.ct.as_ref().unwrap().ntt(), 0, 0);
  553. let v_conversion = &public_params.v_conversion.as_ref().unwrap()[0];
  554. let v_w_left = public_params.v_expansion_left.as_ref().unwrap();
  555. let v_w_right = public_params.v_expansion_right.as_ref().unwrap();
  556. let v_neg1 = params.get_v_neg1();
  557. coefficient_expansion(
  558. &mut v,
  559. g,
  560. stop_round,
  561. params,
  562. &v_w_left,
  563. &v_w_right,
  564. &v_neg1,
  565. params.t_gsw * params.db_dim_2,
  566. );
  567. let mut v_reg_inp = Vec::with_capacity(dim0);
  568. for i in 0..dim0 {
  569. v_reg_inp.push(v[2 * i].clone());
  570. }
  571. let mut v_gsw_inp = Vec::with_capacity(right_expanded);
  572. for i in 0..right_expanded {
  573. v_gsw_inp.push(v[2 * i + 1].clone());
  574. }
  575. let v_reg_sz = dim0 * 2 * params.poly_len;
  576. v_reg_reoriented = AlignedMemory64::new(v_reg_sz);
  577. reorient_reg_ciphertexts(params, v_reg_reoriented.as_mut_slice(), &v_reg_inp);
  578. v_folding = Vec::new();
  579. for _ in 0..params.db_dim_2 {
  580. v_folding.push(PolyMatrixNTT::zero(params, 2, 2 * params.t_gsw));
  581. }
  582. regev_to_gsw(&mut v_folding, &v_gsw_inp, &v_conversion, params, 1, 0);
  583. (v_reg_reoriented, v_folding)
  584. }
  585. pub fn process_query(
  586. params: &Params,
  587. public_params: &PublicParameters,
  588. query: &Query,
  589. db: &[u64],
  590. ) -> Vec<u8> {
  591. let dim0 = 1 << params.db_dim_1;
  592. let num_per = 1 << params.db_dim_2;
  593. let db_slice_sz = dim0 * num_per * params.poly_len;
  594. let v_packing = public_params.v_packing.as_ref();
  595. let mut v_reg_reoriented;
  596. let v_folding;
  597. if params.expand_queries {
  598. (v_reg_reoriented, v_folding) = expand_query(params, public_params, query);
  599. } else {
  600. v_reg_reoriented = AlignedMemory64::new(query.v_buf.as_ref().unwrap().len());
  601. v_reg_reoriented
  602. .as_mut_slice()
  603. .copy_from_slice(query.v_buf.as_ref().unwrap());
  604. v_folding = query
  605. .v_ct
  606. .as_ref()
  607. .unwrap()
  608. .iter()
  609. .map(|x| x.ntt())
  610. .collect();
  611. }
  612. let v_folding_neg = get_v_folding_neg(params, &v_folding);
  613. let v_packed_ct = (0..params.instances)
  614. .into_par_iter()
  615. .map(|instance| {
  616. let mut intermediate = Vec::with_capacity(num_per);
  617. let mut intermediate_raw = Vec::with_capacity(num_per);
  618. for _ in 0..num_per {
  619. intermediate.push(PolyMatrixNTT::zero(params, 2, 1));
  620. intermediate_raw.push(PolyMatrixRaw::zero(params, 2, 1));
  621. }
  622. let mut v_ct = Vec::new();
  623. for trial in 0..(params.n * params.n) {
  624. let idx = (instance * (params.n * params.n) + trial) * db_slice_sz;
  625. let cur_db = &db[idx..(idx + db_slice_sz)];
  626. multiply_reg_by_database(
  627. &mut intermediate,
  628. cur_db,
  629. v_reg_reoriented.as_slice(),
  630. params,
  631. dim0,
  632. num_per,
  633. );
  634. for i in 0..intermediate.len() {
  635. from_ntt(&mut intermediate_raw[i], &intermediate[i]);
  636. }
  637. fold_ciphertexts(params, &mut intermediate_raw, &v_folding, &v_folding_neg);
  638. v_ct.push(intermediate_raw[0].clone());
  639. }
  640. let packed_ct = pack(params, &v_ct, &v_packing);
  641. packed_ct.raw()
  642. })
  643. .collect();
  644. encode(params, &v_packed_ct)
  645. }
  646. #[cfg(test)]
  647. mod test {
  648. use super::*;
  649. use crate::client::*;
  650. use rand::{prelude::SmallRng, Rng};
  651. const TEST_PREPROCESSED_DB_PATH: &'static str = "/home/samir/wiki/enwiki-20220320.dbp";
  652. fn get_params() -> Params {
  653. get_fast_expansion_testing_params()
  654. }
  655. fn dec_reg<'a>(
  656. params: &'a Params,
  657. ct: &PolyMatrixNTT<'a>,
  658. client: &mut Client<'a, SmallRng>,
  659. scale_k: u64,
  660. ) -> u64 {
  661. let dec = client.decrypt_matrix_reg(ct).raw();
  662. let mut val = dec.data[0] as i64;
  663. if val >= (params.modulus / 2) as i64 {
  664. val -= params.modulus as i64;
  665. }
  666. let val_rounded = f64::round(val as f64 / scale_k as f64) as i64;
  667. if val_rounded == 0 {
  668. 0
  669. } else {
  670. 1
  671. }
  672. }
  673. fn dec_gsw<'a>(
  674. params: &'a Params,
  675. ct: &PolyMatrixNTT<'a>,
  676. client: &mut Client<'a, SmallRng>,
  677. ) -> u64 {
  678. let dec = client.decrypt_matrix_reg(ct).raw();
  679. let idx = 2 * (params.t_gsw - 1) * params.poly_len + params.poly_len; // this offset should encode a large value
  680. let mut val = dec.data[idx] as i64;
  681. if val >= (params.modulus / 2) as i64 {
  682. val -= params.modulus as i64;
  683. }
  684. if i64::abs(val) < (1i64 << 10) {
  685. 0
  686. } else {
  687. 1
  688. }
  689. }
  690. #[test]
  691. fn coefficient_expansion_is_correct() {
  692. let params = get_params();
  693. let v_neg1 = params.get_v_neg1();
  694. let mut seeded_rng = get_seeded_rng();
  695. let mut client = Client::init(&params, &mut seeded_rng);
  696. let public_params = client.generate_keys();
  697. let mut v = Vec::new();
  698. for _ in 0..(1 << (params.db_dim_1 + 1)) {
  699. v.push(PolyMatrixNTT::zero(&params, 2, 1));
  700. }
  701. let target = 7;
  702. let scale_k = params.modulus / params.pt_modulus;
  703. let mut sigma = PolyMatrixRaw::zero(&params, 1, 1);
  704. sigma.data[target] = scale_k;
  705. v[0] = client.encrypt_matrix_reg(&sigma.ntt());
  706. let test_ct = client.encrypt_matrix_reg(&sigma.ntt());
  707. let v_w_left = public_params.v_expansion_left.unwrap();
  708. let v_w_right = public_params.v_expansion_right.unwrap();
  709. coefficient_expansion(
  710. &mut v,
  711. params.g(),
  712. params.stop_round(),
  713. &params,
  714. &v_w_left,
  715. &v_w_right,
  716. &v_neg1,
  717. params.t_gsw * params.db_dim_2,
  718. );
  719. assert_eq!(dec_reg(&params, &test_ct, &mut client, scale_k), 0);
  720. for i in 0..v.len() {
  721. if i == target {
  722. assert_eq!(dec_reg(&params, &v[i], &mut client, scale_k), 1);
  723. } else {
  724. assert_eq!(dec_reg(&params, &v[i], &mut client, scale_k), 0);
  725. }
  726. }
  727. }
  728. #[test]
  729. fn regev_to_gsw_is_correct() {
  730. let mut params = get_params();
  731. params.db_dim_2 = 1;
  732. let mut seeded_rng = get_seeded_rng();
  733. let mut client = Client::init(&params, &mut seeded_rng);
  734. let public_params = client.generate_keys();
  735. let mut enc_constant = |val| {
  736. let mut sigma = PolyMatrixRaw::zero(&params, 1, 1);
  737. sigma.data[0] = val;
  738. client.encrypt_matrix_reg(&sigma.ntt())
  739. };
  740. let v = &public_params.v_conversion.unwrap()[0];
  741. let bits_per = get_bits_per(&params, params.t_gsw);
  742. let mut v_inp_1 = Vec::new();
  743. let mut v_inp_0 = Vec::new();
  744. for i in 0..params.t_gsw {
  745. let val = 1u64 << (bits_per * i);
  746. v_inp_1.push(enc_constant(val));
  747. v_inp_0.push(enc_constant(0));
  748. }
  749. let mut v_gsw = Vec::new();
  750. v_gsw.push(PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw));
  751. regev_to_gsw(&mut v_gsw, &v_inp_1, v, &params, 1, 0);
  752. assert_eq!(dec_gsw(&params, &v_gsw[0], &mut client), 1);
  753. regev_to_gsw(&mut v_gsw, &v_inp_0, v, &params, 1, 0);
  754. assert_eq!(dec_gsw(&params, &v_gsw[0], &mut client), 0);
  755. }
  /// Checks `multiply_reg_by_database`.
  ///
  /// Encrypts a one-hot selector over the first dimension (`scale_k` at the
  /// target's first-dimension index, 0 elsewhere) and multiplies it into a random
  /// database. It then verifies that the output ciphertext for the target's
  /// second-dimension index decrypts, after rescaling to the plaintext modulus,
  /// to the expected item.
  #[test]
  fn multiply_reg_by_database_is_correct() {
      let params = get_params();
      let mut seeded_rng = get_seeded_rng();

      let dim0 = 1 << params.db_dim_1;
      let num_per = 1 << params.db_dim_2;
      let scale_k = params.modulus / params.pt_modulus;

      let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
      let target_idx_dim0 = target_idx / num_per;
      let target_idx_num_per = target_idx % num_per;

      let mut client = Client::init(&params, &mut seeded_rng);
      _ = client.generate_keys();

      let (corr_item, db) = generate_random_db_and_get_item(&params, target_idx);

      let mut v_reg = Vec::with_capacity(dim0);
      for i in 0..dim0 {
          let val = if i == target_idx_dim0 { scale_k } else { 0 };
          let sigma = PolyMatrixRaw::single_value(&params, val).ntt();
          v_reg.push(client.encrypt_matrix_reg(&sigma));
      }

      let v_reg_sz = dim0 * 2 * params.poly_len;
      let mut v_reg_reoriented = AlignedMemory64::new(v_reg_sz);
      reorient_reg_ciphertexts(&params, v_reg_reoriented.as_mut_slice(), &v_reg);

      // `out` is read back at `target_idx_num_per`, which ranges over
      // 0..num_per. It therefore needs `num_per` slots, not `dim0`; sizing it by
      // `dim0` only worked while `dim0 >= num_per`.
      let mut out: Vec<_> = (0..num_per)
          .map(|_| PolyMatrixNTT::zero(&params, 2, 1))
          .collect();
      multiply_reg_by_database(
          &mut out,
          db.as_slice(),
          v_reg_reoriented.as_slice(),
          &params,
          dim0,
          num_per,
      );

      // Decrypt the selected output and compare coefficient-by-coefficient after
      // rescaling from `modulus` down to `pt_modulus`.
      let dec = client.decrypt_matrix_reg(&out[target_idx_num_per]).raw();
      for z in 0..params.poly_len {
          let got = rescale(dec.data[z], params.modulus, params.pt_modulus);
          assert_eq!(got, corr_item.data[z], "at {:?}", z);
      }
  }
  /// Checks `fold_ciphertexts`.
  ///
  /// Builds `num_per` Regev ciphertexts in which only the one at the target's
  /// second-dimension index encrypts `scale_k`. It then hand-builds one GSW-style
  /// ciphertext per bit of that index, plus its `gadget + invert(ct)` companion,
  /// folds, and checks that the surviving ciphertext decrypts to 1.
  #[test]
  fn fold_ciphertexts_is_correct() {
      let params = get_params();
      let mut seeded_rng = get_seeded_rng();

      let dim0 = 1 << params.db_dim_1;
      let num_per = 1 << params.db_dim_2;
      let scale_k = params.modulus / params.pt_modulus;

      let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
      let target_idx_num_per = target_idx % num_per;

      let mut client = Client::init(&params, &mut seeded_rng);
      _ = client.generate_keys();

      let v_reg: Vec<_> = (0..num_per)
          .map(|slot| {
              let val = if slot == target_idx_num_per { scale_k } else { 0 };
              let sigma = PolyMatrixRaw::single_value(&params, val).ntt();
              client.encrypt_matrix_reg(&sigma)
          })
          .collect();
      let mut v_reg_raw: Vec<_> = v_reg.iter().map(|ct| ct.raw()).collect();

      let bits_per = get_bits_per(&params, params.t_gsw);
      let mut v_folding = Vec::with_capacity(params.db_dim_2);
      for level in 0..params.db_dim_2 {
          // Bit `level` of the target's second-dimension index, as 0 or 1.
          let bit = ((target_idx_num_per >> level) & 1) as u64;
          let mut ct_gsw = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
          for row in 0..params.t_gsw {
              let scaled = (1u64 << (bits_per * row)) * bit;
              let sigma = PolyMatrixRaw::single_value(&params, scaled);
              let sigma_ntt = to_ntt_alloc(&sigma);

              // Column 2*row + 1: Regev encryption of the scaled bit.
              let ct_plain = client.encrypt_matrix_reg(&sigma_ntt);
              ct_gsw.copy_into(&ct_plain, 0, 2 * row + 1);

              // Column 2*row: Regev encryption of sk * (scaled bit).
              let sk_times_sigma = &to_ntt_alloc(client.get_sk_reg()) * &sigma_ntt;
              let ct_sk = client.encrypt_matrix_reg(&sk_times_sigma);
              ct_gsw.copy_into(&ct_sk, 0, 2 * row);
          }
          v_folding.push(ct_gsw);
      }

      // Companion ciphertexts: gadget + invert(ct). NOTE(review): presumably
      // `invert` negates, so these encrypt the complementary bits — confirm.
      let gadget_ntt = build_gadget(&params, 2, 2 * params.t_gsw).ntt();
      let mut ct_gsw_inv = PolyMatrixRaw::zero(&params, 2, 2 * params.t_gsw);
      let v_folding_neg: Vec<_> = v_folding
          .iter()
          .map(|ct_gsw| {
              invert(&mut ct_gsw_inv, &ct_gsw.raw());
              let mut ct_gsw_neg = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
              add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
              ct_gsw_neg
          })
          .collect();

      fold_ciphertexts(&params, &mut v_reg_raw, &v_folding, &v_folding_neg);

      // After folding, slot 0 should hold the target ciphertext.
      assert_eq!(
          dec_reg(&params, &v_reg_raw[0].ntt(), &mut client, scale_k),
          1
      );
  }
  855. fn full_protocol_is_correct_for_params(params: &Params) {
  856. let mut seeded_rng = get_seeded_rng();
  857. let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
  858. let mut client = Client::init(params, &mut seeded_rng);
  859. let public_params = client.generate_keys();
  860. let query = client.generate_query(target_idx);
  861. let (corr_item, db) = generate_random_db_and_get_item(params, target_idx);
  862. let response = process_query(params, &public_params, &query, db.as_slice());
  863. let result = client.decode_response(response.as_slice());
  864. let p_bits = log2_ceil(params.pt_modulus) as usize;
  865. let corr_result = corr_item.to_vec(p_bits, params.modp_words_per_chunk());
  866. assert_eq!(result.len(), corr_result.len());
  867. for z in 0..corr_result.len() {
  868. assert_eq!(result[z], corr_result[z], "at {:?}", z);
  869. }
  870. }
  871. fn full_protocol_is_correct_for_params_real_db(params: &Params) {
  872. let mut seeded_rng = get_seeded_rng();
  873. let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
  874. let mut client = Client::init(params, &mut seeded_rng);
  875. let public_params = client.generate_keys();
  876. let query = client.generate_query(target_idx);
  877. let mut file = File::open(TEST_PREPROCESSED_DB_PATH).unwrap();
  878. let db = load_preprocessed_db_from_file(params, &mut file);
  879. let response = process_query(params, &public_params, &query, db.as_slice());
  880. let result = client.decode_response(response.as_slice());
  881. let corr_result = vec![0x42, 0x5a, 0x68];
  882. for z in 0..corr_result.len() {
  883. assert_eq!(result[z], corr_result[z]);
  884. }
  885. }
  /// Runs the end-to-end protocol check with the default test parameters.
  #[test]
  fn full_protocol_is_correct() {
      let params = get_params();
      full_protocol_is_correct_for_params(&params);
  }
  /// End-to-end checks (synthetic and on-disk database) with a larger parameter
  /// set: `nu_1 = 10`, `nu_2 = 6`, `db_item_size = 9000`.
  #[test]
  fn larger_full_protocol_is_correct() {
      // Written with single quotes for readability; they are swapped for double
      // quotes before JSON parsing.
      let cfg_expand = r#"
{
'n': 2,
'nu_1': 10,
'nu_2': 6,
'p': 512,
'q2_bits': 21,
's_e': 85.83255142749422,
't_gsw': 10,
't_conv': 4,
't_exp_left': 16,
't_exp_right': 56,
'instances': 1,
'db_item_size': 9000 }
"#;
      let params = params_from_json(&cfg_expand.replace("'", "\""));
      full_protocol_is_correct_for_params(&params);
      full_protocol_is_correct_for_params_real_db(&params);
  }
  913. // #[test]
  914. // fn full_protocol_is_correct_20_256() {
  915. // full_protocol_is_correct_for_params(&params_from_json(&CFG_20_256.replace("'", "\"")));
  916. // }
  917. // #[test]
  918. // fn full_protocol_is_correct_16_100000() {
  919. // full_protocol_is_correct_for_params(&params_from_json(&CFG_16_100000.replace("'", "\"")));
  920. // }
  /// End-to-end check of the `CFG_16_100000` parameter set against the on-disk
  /// preprocessed database. Ignored by default; run with `cargo test -- --ignored`.
  #[test]
  #[ignore]
  fn full_protocol_is_correct_real_db_16_100000() {
      let json = CFG_16_100000.replace("'", "\"");
      let params = params_from_json(&json);
      full_protocol_is_correct_for_params_real_db(&params);
  }
  928. }