client.rs 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. use crate::{
  2. arith::*, discrete_gaussian::*, gadget::*, number_theory::*, params::*, poly::*, util::*,
  3. };
  4. use rand::Rng;
  5. use std::iter::once;
  6. fn serialize_polymatrix(vec: &mut Vec<u8>, a: &PolyMatrixRaw) {
  7. for i in 0..a.rows * a.cols * a.params.poly_len {
  8. vec.extend_from_slice(&u64::to_ne_bytes(a.data[i]));
  9. }
  10. }
  11. fn serialize_vec_polymatrix(vec: &mut Vec<u8>, a: &Vec<PolyMatrixRaw>) {
  12. for i in 0..a.len() {
  13. serialize_polymatrix(vec, &a[i]);
  14. }
  15. }
  16. pub struct PublicParameters<'a> {
  17. pub v_packing: Vec<PolyMatrixNTT<'a>>, // Ws
  18. pub v_expansion_left: Option<Vec<PolyMatrixNTT<'a>>>,
  19. pub v_expansion_right: Option<Vec<PolyMatrixNTT<'a>>>,
  20. pub v_conversion: Option<Vec<PolyMatrixNTT<'a>>>, // V
  21. }
  22. impl<'a> PublicParameters<'a> {
  23. pub fn init(params: &'a Params) -> Self {
  24. if params.expand_queries {
  25. PublicParameters {
  26. v_packing: Vec::new(),
  27. v_expansion_left: Some(Vec::new()),
  28. v_expansion_right: Some(Vec::new()),
  29. v_conversion: Some(Vec::new()),
  30. }
  31. } else {
  32. PublicParameters {
  33. v_packing: Vec::new(),
  34. v_expansion_left: None,
  35. v_expansion_right: None,
  36. v_conversion: None,
  37. }
  38. }
  39. }
  40. fn from_ntt_alloc_vec(v: &Vec<PolyMatrixNTT<'a>>) -> Option<Vec<PolyMatrixRaw<'a>>> {
  41. Some(v.iter().map(from_ntt_alloc).collect())
  42. }
  43. fn from_ntt_alloc_opt_vec(
  44. v: &Option<Vec<PolyMatrixNTT<'a>>>,
  45. ) -> Option<Vec<PolyMatrixRaw<'a>>> {
  46. Some(v.as_ref()?.iter().map(from_ntt_alloc).collect())
  47. }
  48. pub fn to_raw(&self) -> Vec<Option<Vec<PolyMatrixRaw>>> {
  49. vec![
  50. Self::from_ntt_alloc_vec(&self.v_packing),
  51. Self::from_ntt_alloc_opt_vec(&self.v_expansion_left),
  52. Self::from_ntt_alloc_opt_vec(&self.v_expansion_right),
  53. Self::from_ntt_alloc_opt_vec(&self.v_conversion),
  54. ]
  55. }
  56. pub fn serialize(&self) -> Vec<u8> {
  57. let mut data = Vec::new();
  58. for v in self.to_raw().iter() {
  59. if v.is_some() {
  60. serialize_vec_polymatrix(&mut data, v.as_ref().unwrap());
  61. }
  62. }
  63. data
  64. }
  65. }
  66. pub struct Query<'a> {
  67. pub ct: Option<PolyMatrixRaw<'a>>,
  68. pub v_buf: Option<Vec<u64>>,
  69. pub v_ct: Option<Vec<PolyMatrixRaw<'a>>>,
  70. }
  71. impl<'a> Query<'a> {
  72. pub fn empty() -> Self {
  73. Query {
  74. ct: None,
  75. v_ct: None,
  76. v_buf: None,
  77. }
  78. }
  79. pub fn serialize(&self) -> Vec<u8> {
  80. let mut data = Vec::new();
  81. if self.ct.is_some() {
  82. let ct = self.ct.as_ref().unwrap();
  83. serialize_polymatrix(&mut data, &ct);
  84. }
  85. if self.v_buf.is_some() {
  86. let v_buf = self.v_buf.as_ref().unwrap();
  87. data.extend(v_buf.iter().map(|x| x.to_ne_bytes()).flatten());
  88. }
  89. if self.v_ct.is_some() {
  90. let v_ct = self.v_ct.as_ref().unwrap();
  91. for x in v_ct {
  92. serialize_polymatrix(&mut data, x);
  93. }
  94. }
  95. data
  96. }
  97. }
  98. pub struct Client<'a, TRng: Rng> {
  99. params: &'a Params,
  100. sk_gsw: PolyMatrixRaw<'a>,
  101. pub sk_reg: PolyMatrixRaw<'a>,
  102. sk_gsw_full: PolyMatrixRaw<'a>,
  103. sk_reg_full: PolyMatrixRaw<'a>,
  104. dg: DiscreteGaussian<'a, TRng>,
  105. pub g: usize,
  106. pub stop_round: usize,
  107. }
  108. fn matrix_with_identity<'a>(p: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
  109. assert_eq!(p.cols, 1);
  110. let mut r = PolyMatrixRaw::zero(p.params, p.rows, p.rows + 1);
  111. r.copy_into(p, 0, 0);
  112. r.copy_into(&PolyMatrixRaw::identity(p.params, p.rows, p.rows), 0, 1);
  113. r
  114. }
  115. fn params_with_moduli(params: &Params, moduli: &Vec<u64>) -> Params {
  116. Params::init(
  117. params.poly_len,
  118. moduli,
  119. params.noise_width,
  120. params.n,
  121. params.pt_modulus,
  122. params.q2_bits,
  123. params.t_conv,
  124. params.t_exp_left,
  125. params.t_exp_right,
  126. params.t_gsw,
  127. params.expand_queries,
  128. params.db_dim_1,
  129. params.db_dim_2,
  130. params.instances,
  131. params.db_item_size,
  132. )
  133. }
  134. impl<'a, TRng: Rng> Client<'a, TRng> {
  135. pub fn init(params: &'a Params, rng: &'a mut TRng) -> Self {
  136. let sk_gsw_dims = params.get_sk_gsw();
  137. let sk_reg_dims = params.get_sk_reg();
  138. let sk_gsw = PolyMatrixRaw::zero(params, sk_gsw_dims.0, sk_gsw_dims.1);
  139. let sk_reg = PolyMatrixRaw::zero(params, sk_reg_dims.0, sk_reg_dims.1);
  140. let sk_gsw_full = matrix_with_identity(&sk_gsw);
  141. let sk_reg_full = matrix_with_identity(&sk_reg);
  142. let dg = DiscreteGaussian::init(params, rng);
  143. let further_dims = params.db_dim_2;
  144. let num_expanded = 1usize << params.db_dim_1;
  145. let num_bits_to_gen = params.t_gsw * further_dims + num_expanded;
  146. let g = log2_ceil_usize(num_bits_to_gen);
  147. let stop_round = log2_ceil_usize(params.t_gsw * further_dims);
  148. Self {
  149. params,
  150. sk_gsw,
  151. sk_reg,
  152. sk_gsw_full,
  153. sk_reg_full,
  154. dg,
  155. g,
  156. stop_round,
  157. }
  158. }
  159. pub fn get_rng(&mut self) -> &mut TRng {
  160. &mut self.dg.rng
  161. }
  162. fn get_fresh_gsw_public_key(&mut self, m: usize) -> PolyMatrixRaw<'a> {
  163. let params = self.params;
  164. let n = params.n;
  165. let a = PolyMatrixRaw::random_rng(params, 1, m, self.get_rng());
  166. let e = PolyMatrixRaw::noise(params, n, m, &mut self.dg);
  167. let a_inv = -&a;
  168. let b_p = &self.sk_gsw.ntt() * &a.ntt();
  169. let b = &e.ntt() + &b_p;
  170. let p = stack(&a_inv, &b.raw());
  171. p
  172. }
  173. fn get_regev_sample(&mut self) -> PolyMatrixNTT<'a> {
  174. let params = self.params;
  175. let a = PolyMatrixRaw::random_rng(params, 1, 1, self.get_rng());
  176. let e = PolyMatrixRaw::noise(params, 1, 1, &mut self.dg);
  177. let b_p = &self.sk_reg.ntt() * &a.ntt();
  178. let b = &e.ntt() + &b_p;
  179. let mut p = PolyMatrixNTT::zero(params, 2, 1);
  180. p.copy_into(&(-&a).ntt(), 0, 0);
  181. p.copy_into(&b, 1, 0);
  182. p
  183. }
  184. fn get_fresh_reg_public_key(&mut self, m: usize) -> PolyMatrixNTT<'a> {
  185. let params = self.params;
  186. let mut p = PolyMatrixNTT::zero(params, 2, m);
  187. for i in 0..m {
  188. p.copy_into(&self.get_regev_sample(), 0, i);
  189. }
  190. p
  191. }
  192. fn encrypt_matrix_gsw(&mut self, ag: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
  193. let mx = ag.cols;
  194. let p = self.get_fresh_gsw_public_key(mx);
  195. let res = &(p.ntt()) + &(ag.pad_top(1));
  196. res
  197. }
  198. pub fn encrypt_matrix_reg(&mut self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
  199. let m = a.cols;
  200. let p = self.get_fresh_reg_public_key(m);
  201. &p + &a.pad_top(1)
  202. }
  203. pub fn decrypt_matrix_reg(&mut self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
  204. &self.sk_reg_full.ntt() * a
  205. }
  206. pub fn decrypt_matrix_gsw(&mut self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
  207. &self.sk_gsw_full.ntt() * a
  208. }
  209. fn generate_expansion_params(
  210. &mut self,
  211. num_exp: usize,
  212. m_exp: usize,
  213. ) -> Vec<PolyMatrixNTT<'a>> {
  214. let params = self.params;
  215. let g_exp = build_gadget(params, 1, m_exp);
  216. let g_exp_ntt = g_exp.ntt();
  217. let mut res = Vec::new();
  218. for i in 0..num_exp {
  219. let t = (params.poly_len / (1 << i)) + 1;
  220. let tau_sk_reg = automorph_alloc(&self.sk_reg, t);
  221. let prod = &tau_sk_reg.ntt() * &g_exp_ntt;
  222. let w_exp_i = self.encrypt_matrix_reg(&prod);
  223. res.push(w_exp_i);
  224. }
  225. res
  226. }
  227. pub fn generate_keys(&mut self) -> PublicParameters<'a> {
  228. let params = self.params;
  229. self.dg.sample_matrix(&mut self.sk_gsw);
  230. self.dg.sample_matrix(&mut self.sk_reg);
  231. self.sk_gsw_full = matrix_with_identity(&self.sk_gsw);
  232. self.sk_reg_full = matrix_with_identity(&self.sk_reg);
  233. let sk_reg_ntt = to_ntt_alloc(&self.sk_reg);
  234. let m_conv = params.m_conv();
  235. let mut pp = PublicParameters::init(params);
  236. // Params for packing
  237. let gadget_conv = build_gadget(params, 1, m_conv);
  238. let gadget_conv_ntt = to_ntt_alloc(&gadget_conv);
  239. for i in 0..params.n {
  240. let scaled = scalar_multiply_alloc(&sk_reg_ntt, &gadget_conv_ntt);
  241. let mut ag = PolyMatrixNTT::zero(params, params.n, m_conv);
  242. ag.copy_into(&scaled, i, 0);
  243. let w = self.encrypt_matrix_gsw(&ag);
  244. pp.v_packing.push(w);
  245. }
  246. if params.expand_queries {
  247. // Params for expansion
  248. pp.v_expansion_left = Some(self.generate_expansion_params(self.g, params.t_exp_left));
  249. pp.v_expansion_right =
  250. Some(self.generate_expansion_params(self.stop_round + 1, params.t_exp_right));
  251. // Params for converison
  252. let g_conv = build_gadget(params, 2, 2 * m_conv);
  253. let sk_reg_ntt = self.sk_reg.ntt();
  254. let sk_reg_squared_ntt = &sk_reg_ntt * &sk_reg_ntt;
  255. pp.v_conversion = Some(Vec::from_iter(once(PolyMatrixNTT::zero(
  256. params,
  257. 2,
  258. 2 * m_conv,
  259. ))));
  260. for i in 0..2 * m_conv {
  261. let sigma;
  262. if i % 2 == 0 {
  263. let val = g_conv.get_poly(0, i)[0];
  264. sigma = &sk_reg_squared_ntt * &single_poly(params, val).ntt();
  265. } else {
  266. let val = g_conv.get_poly(1, i)[0];
  267. sigma = &sk_reg_ntt * &single_poly(params, val).ntt();
  268. }
  269. let ct = self.encrypt_matrix_reg(&sigma);
  270. pp.v_conversion.as_mut().unwrap()[0].copy_into(&ct, 0, i);
  271. }
  272. }
  273. pp
  274. }
  275. pub fn generate_query(&mut self, idx_target: usize) -> Query<'a> {
  276. let params = self.params;
  277. let further_dims = params.db_dim_2;
  278. let idx_dim0 = idx_target / (1 << further_dims);
  279. let idx_further = idx_target % (1 << further_dims);
  280. let scale_k = params.modulus / params.pt_modulus;
  281. let bits_per = get_bits_per(params, params.t_gsw);
  282. let mut query = Query::empty();
  283. if params.expand_queries {
  284. // pack query into single ciphertext
  285. let mut sigma = PolyMatrixRaw::zero(params, 1, 1);
  286. sigma.data[2 * idx_dim0] = scale_k;
  287. for i in 0..further_dims as u64 {
  288. let bit: u64 = ((idx_further as u64) & (1 << i)) >> i;
  289. for j in 0..params.t_gsw {
  290. let val = (1u64 << (bits_per * j)) * bit;
  291. let idx = (i as usize) * params.t_gsw + (j as usize);
  292. sigma.data[2 * idx + 1] = val;
  293. }
  294. }
  295. let inv_2_g_first = invert_uint_mod(1 << self.g, params.modulus).unwrap();
  296. let inv_2_g_rest = invert_uint_mod(1 << (self.stop_round + 1), params.modulus).unwrap();
  297. for i in 0..params.poly_len / 2 {
  298. sigma.data[2 * i] =
  299. multiply_uint_mod(sigma.data[2 * i], inv_2_g_first, params.modulus);
  300. sigma.data[2 * i + 1] =
  301. multiply_uint_mod(sigma.data[2 * i + 1], inv_2_g_rest, params.modulus);
  302. }
  303. query.ct = Some(from_ntt_alloc(
  304. &self.encrypt_matrix_reg(&to_ntt_alloc(&sigma)),
  305. ));
  306. } else {
  307. let num_expanded = 1 << params.db_dim_1;
  308. let mut sigma_v = Vec::<PolyMatrixNTT>::new();
  309. // generate regev ciphertexts
  310. let reg_cts_buf_words = num_expanded * 2 * params.poly_len;
  311. let mut reg_cts_buf = vec![0u64; reg_cts_buf_words];
  312. let mut reg_cts = Vec::<PolyMatrixNTT>::new();
  313. for i in 0..num_expanded {
  314. let value = ((i == idx_dim0) as u64) * scale_k;
  315. let sigma = PolyMatrixRaw::single_value(&params, value);
  316. reg_cts.push(self.encrypt_matrix_reg(&to_ntt_alloc(&sigma)));
  317. }
  318. // reorient into server's preferred indexing
  319. reorient_reg_ciphertexts(self.params, reg_cts_buf.as_mut_slice(), &reg_cts);
  320. // generate GSW ciphertexts
  321. for i in 0..further_dims {
  322. let bit = ((idx_further as u64) & (1 << (i as u64))) >> (i as u64);
  323. let mut ct_gsw = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
  324. for j in 0..params.t_gsw {
  325. let value = (1u64 << (bits_per * j)) * bit;
  326. let sigma = PolyMatrixRaw::single_value(&params, value);
  327. let sigma_ntt = to_ntt_alloc(&sigma);
  328. let ct = &self.encrypt_matrix_reg(&sigma_ntt);
  329. ct_gsw.copy_into(ct, 0, 2 * j + 1);
  330. let prod = &to_ntt_alloc(&self.sk_reg) * &sigma_ntt;
  331. let ct = &self.encrypt_matrix_reg(&prod);
  332. ct_gsw.copy_into(ct, 0, 2 * j);
  333. }
  334. sigma_v.push(ct_gsw);
  335. }
  336. query.v_buf = Some(reg_cts_buf);
  337. query.v_ct = Some(sigma_v.iter().map(|x| from_ntt_alloc(x)).collect());
  338. }
  339. query
  340. }
  341. pub fn decode_response(&self, data: &[u8]) -> Vec<u8> {
  342. /*
  343. 0. NTT over q2 the secret key
  344. 1. read first row in q2_bit chunks
  345. 2. read rest in q1_bit chunks
  346. 3. NTT over q2 the first row
  347. 4. Multiply the results of (0) and (3)
  348. 5. Divide and round correctly
  349. */
  350. let params = self.params;
  351. let p = params.pt_modulus;
  352. let p_bits = log2_ceil(params.pt_modulus);
  353. let q1 = 4 * params.pt_modulus;
  354. let q1_bits = log2_ceil(q1) as usize;
  355. let q2 = Q2_VALUES[params.q2_bits as usize];
  356. let q2_bits = params.q2_bits as usize;
  357. let q2_params = params_with_moduli(params, &vec![q2]);
  358. // this only needs to be done during keygen
  359. let mut sk_gsw_q2 = PolyMatrixRaw::zero(&q2_params, params.n, 1);
  360. for i in 0..params.poly_len * params.n {
  361. sk_gsw_q2.data[i] = recenter(self.sk_gsw.data[i], params.modulus, q2);
  362. }
  363. let mut sk_gsw_q2_ntt = PolyMatrixNTT::zero(&q2_params, params.n, 1);
  364. to_ntt(&mut sk_gsw_q2_ntt, &sk_gsw_q2);
  365. let mut result = PolyMatrixRaw::zero(&params, params.instances * params.n, params.n);
  366. let mut bit_offs = 0;
  367. for instance in 0..params.instances {
  368. // this must be done during decoding
  369. let mut first_row = PolyMatrixRaw::zero(&q2_params, 1, params.n);
  370. let mut rest_rows = PolyMatrixRaw::zero(&params, params.n, params.n);
  371. for i in 0..params.n * params.poly_len {
  372. first_row.data[i] = read_arbitrary_bits(data, bit_offs, q2_bits);
  373. bit_offs += q2_bits;
  374. }
  375. for i in 0..params.n * params.n * params.poly_len {
  376. rest_rows.data[i] = read_arbitrary_bits(data, bit_offs, q1_bits);
  377. bit_offs += q1_bits;
  378. }
  379. let mut first_row_q2 = PolyMatrixNTT::zero(&q2_params, 1, params.n);
  380. to_ntt(&mut first_row_q2, &first_row);
  381. let sk_prod = (&sk_gsw_q2_ntt * &first_row_q2).raw();
  382. let q1_i64 = q1 as i64;
  383. let q2_i64 = q2 as i64;
  384. let p_i128 = p as i128;
  385. for i in 0..params.n * params.n * params.poly_len {
  386. let mut val_first = sk_prod.data[i] as i64;
  387. if val_first >= q2_i64 / 2 {
  388. val_first -= q2_i64;
  389. }
  390. let mut val_rest = rest_rows.data[i] as i64;
  391. if val_rest >= q1_i64 / 2 {
  392. val_rest -= q1_i64;
  393. }
  394. let denom = (q2 * (q1 / p)) as i64;
  395. let mut r = val_first * q1_i64;
  396. r += val_rest * q2_i64;
  397. // divide r by q2, rounding
  398. let sign: i64 = if r >= 0 { 1 } else { -1 };
  399. let mut res = ((r + sign * (denom / 2)) as i128) / (denom as i128);
  400. res = (res + (denom as i128 / p_i128) * (p_i128) + 2 * (p_i128)) % (p_i128);
  401. let idx = instance * params.n * params.n * params.poly_len + i;
  402. result.data[idx] = res as u64;
  403. }
  404. }
  405. // println!("{:?}", result.data);
  406. let trials = params.n * params.n;
  407. let chunks = params.instances * trials;
  408. let bytes_per_chunk = f64::ceil(params.db_item_size as f64 / chunks as f64) as usize;
  409. let logp = log2(params.pt_modulus);
  410. let modp_words_per_chunk = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
  411. println!("modp_words_per_chunk {:?}", modp_words_per_chunk);
  412. result.to_vec(p_bits as usize, modp_words_per_chunk)
  413. }
  414. }
  415. #[cfg(test)]
  416. mod test {
  417. use rand::thread_rng;
  418. use super::*;
  419. fn assert_first8(m: &[u64], gold: [u64; 8]) {
  420. let got: [u64; 8] = m[0..8].try_into().unwrap();
  421. assert_eq!(got, gold);
  422. }
  423. fn get_params() -> Params {
  424. get_short_keygen_params()
  425. }
  426. #[test]
  427. fn init_is_correct() {
  428. let params = get_params();
  429. let mut rng = thread_rng();
  430. let client = Client::init(&params, &mut rng);
  431. assert_eq!(client.stop_round, 5);
  432. assert_eq!(client.g, 10);
  433. assert_eq!(*client.params, params);
  434. }
  435. #[test]
  436. fn keygen_is_correct() {
  437. let params = get_params();
  438. let mut seeded_rng = get_static_seeded_rng();
  439. let mut client = Client::init(&params, &mut seeded_rng);
  440. let public_params = client.generate_keys();
  441. assert_first8(
  442. public_params.v_conversion.unwrap()[0].data.as_slice(),
  443. [
  444. 253586619, 247235120, 141892996, 163163429, 15531298, 200914775, 125109567,
  445. 75889562,
  446. ],
  447. );
  448. assert_first8(
  449. client.sk_gsw.data.as_slice(),
  450. [1, 5, 0, 3, 1, 3, 66974689739603967, 3],
  451. );
  452. }
  453. }