client.rs 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. use std::collections::HashMap;
  2. use crate::{poly::*, params::*, discrete_gaussian::*, gadget::*, arith::*};
  3. pub struct PublicParameters<'a> {
  4. v_packing: Vec<PolyMatrixNTT<'a>>, // Ws
  5. v_expansion_left: Vec<PolyMatrixNTT<'a>>,
  6. v_expansion_right: Vec<PolyMatrixNTT<'a>>,
  7. conversion: PolyMatrixNTT<'a>, // V
  8. }
  9. impl<'a> PublicParameters<'a> {
  10. fn init(params: &'a Params) -> Self {
  11. PublicParameters {
  12. v_packing: Vec::new(),
  13. v_expansion_left: Vec::new(),
  14. v_expansion_right: Vec::new(),
  15. conversion: PolyMatrixNTT::zero(params, 2, 2 * params.m_conv())
  16. }
  17. }
  18. }
  19. pub struct Query<'a> {
  20. ct: PolyMatrixNTT<'a>,
  21. v_ct: Vec<PolyMatrixNTT<'a>>,
  22. }
  23. pub struct Client<'a> {
  24. params: &'a Params,
  25. sk_gsw: PolyMatrixRaw<'a>,
  26. sk_reg: PolyMatrixRaw<'a>,
  27. sk_gsw_full: PolyMatrixRaw<'a>,
  28. sk_reg_full: PolyMatrixRaw<'a>,
  29. dg: DiscreteGaussian,
  30. }
  31. fn matrix_with_identity<'a> (p: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
  32. assert_eq!(p.cols, 1);
  33. let mut r = PolyMatrixRaw::zero(p.params, p.rows, p.rows + 1);
  34. r.copy_into(p, 0, 0);
  35. r.copy_into(&PolyMatrixRaw::identity(p.params, p.rows, p.rows), 0, 1);
  36. r
  37. }
  38. impl<'a> Client<'a> {
  39. pub fn init(params: &'a Params) -> Self {
  40. let sk_gsw_dims = params.get_sk_gsw();
  41. let sk_reg_dims = params.get_sk_reg();
  42. let sk_gsw = PolyMatrixRaw::zero(params, sk_gsw_dims.0, sk_gsw_dims.1);
  43. let sk_reg = PolyMatrixRaw::zero(params, sk_reg_dims.0, sk_reg_dims.1);
  44. let sk_gsw_full = matrix_with_identity(&sk_gsw);
  45. let sk_reg_full = matrix_with_identity(&sk_reg);
  46. let dg = DiscreteGaussian::init(params);
  47. Self {
  48. params,
  49. sk_gsw,
  50. sk_reg,
  51. sk_gsw_full,
  52. sk_reg_full,
  53. dg,
  54. }
  55. }
  56. fn get_fresh_gsw_public_key(&mut self, m: usize) -> PolyMatrixRaw<'a> {
  57. let params = self.params;
  58. let n = params.n;
  59. let a = PolyMatrixRaw::random(params, 1, m);
  60. let e = PolyMatrixRaw::noise(params, n, m, &mut self.dg);
  61. let a_inv = -&a;
  62. let b_p = &self.sk_gsw.ntt() * &a.ntt();
  63. let b = &e.ntt() + &b_p;
  64. let p = stack(&a_inv, &b.raw());
  65. p
  66. }
  67. fn get_regev_sample(&mut self) -> PolyMatrixNTT<'a> {
  68. let params = self.params;
  69. let a = PolyMatrixRaw::random(params, 1, 1);
  70. let e = PolyMatrixRaw::noise(params, 1, 1, &mut self.dg);
  71. let b_p = &self.sk_reg.ntt() * &a.ntt();
  72. let b = &e.ntt() + &b_p;
  73. let mut p = PolyMatrixNTT::zero(params, 2, 1);
  74. p.copy_into(&(-&a).ntt(), 0, 0);
  75. p.copy_into(&b, 1, 0);
  76. p
  77. }
  78. fn get_fresh_reg_public_key(&mut self, m: usize) -> PolyMatrixNTT<'a> {
  79. let params = self.params;
  80. let mut p = PolyMatrixNTT::zero(params, 2, m);
  81. for i in 0..m {
  82. p.copy_into(&self.get_regev_sample(), 0, i);
  83. }
  84. p
  85. }
  86. fn encrypt_matrix_gsw(&mut self, ag: PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
  87. let mx = ag.cols;
  88. let p = self.get_fresh_gsw_public_key(mx);
  89. let res = &(p.ntt()) + &(ag.pad_top(1));
  90. res
  91. }
  92. fn encrypt_matrix_reg(&mut self, a: PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
  93. let m = a.cols;
  94. let p = self.get_fresh_reg_public_key(m);
  95. &p + &a.pad_top(1)
  96. }
  97. fn generate_expansion_params(&mut self, num_exp: usize, m_exp: usize) -> Vec<PolyMatrixNTT<'a>> {
  98. let params = self.params;
  99. let g_exp = build_gadget(params, 1, m_exp);
  100. let g_exp_ntt = g_exp.ntt();
  101. let mut res = Vec::new();
  102. for i in 0..num_exp {
  103. let t = (params.poly_len / (1 << i)) + 1;
  104. let tau_sk_reg = automorph_alloc(&self.sk_reg, t);
  105. let prod = &tau_sk_reg.ntt() * &g_exp_ntt;
  106. let w_exp_i = self.encrypt_matrix_reg(prod);
  107. res.push(w_exp_i);
  108. }
  109. res
  110. }
  111. pub fn generate_keys(&mut self) -> PublicParameters {
  112. let params = self.params;
  113. self.dg.sample_matrix(&mut self.sk_gsw);
  114. self.dg.sample_matrix(&mut self.sk_reg);
  115. self.sk_gsw_full = matrix_with_identity(&self.sk_gsw);
  116. self.sk_reg_full = matrix_with_identity(&self.sk_reg);
  117. let sk_reg_ntt = to_ntt_alloc(&self.sk_reg);
  118. let m_conv = params.m_conv();
  119. let mut pp = PublicParameters::init(params);
  120. // Params for packing
  121. let gadget_conv = build_gadget(params, 1, m_conv);
  122. let gadget_conv_ntt = to_ntt_alloc(&gadget_conv);
  123. for i in 0..params.n {
  124. let scaled = scalar_multiply_alloc(&sk_reg_ntt, &gadget_conv_ntt);
  125. let mut ag = PolyMatrixNTT::zero(params, params.n, m_conv);
  126. ag.copy_into(&scaled, i, 0);
  127. let w = self.encrypt_matrix_gsw(ag);
  128. pp.v_packing.push(w);
  129. }
  130. if params.expand_queries {
  131. // Params for expansion
  132. let further_dims = 1usize << params.db_dim_2;
  133. let num_expanded = 1usize << params.db_dim_1;
  134. let num_bits_to_gen = params.t_gsw * further_dims + num_expanded;
  135. let g = log2(num_bits_to_gen as u64) as usize;
  136. let stop_round = log2((params.t_gsw * further_dims) as u64) as usize;
  137. pp.v_expansion_left = self.generate_expansion_params(g, params.t_exp_left);
  138. pp.v_expansion_right = self.generate_expansion_params(stop_round + 1, params.t_exp_right);
  139. // Params for converison
  140. let g_conv = build_gadget(params, 2, 2 * m_conv);
  141. let sk_reg_squared_ntt = &self.sk_reg.ntt() * &self.sk_reg.ntt();
  142. pp.conversion = PolyMatrixNTT::zero(params, 2, 2 * m_conv);
  143. for i in 0..2*m_conv {
  144. if i % 2 == 0 {
  145. let val = g_conv.get_poly(0, i)[0];
  146. let sigma = &sk_reg_squared_ntt * &single_poly(params, val).ntt();
  147. let ct = self.encrypt_matrix_reg(sigma);
  148. pp.conversion.copy_into(&ct, 0, i);
  149. }
  150. }
  151. }
  152. pp
  153. }
  154. // fn generate_query(&self) -> Query<'a> {
  155. // let params = self.params;
  156. // let mut query = Query { ct: PolyMatrixNTT::zero(params, 1, 1), v_ct: Vec::new() }
  157. // if params.expand_queries {
  158. // } else {
  159. // }
  160. // }
  161. }