server.rs 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. use crate::arith;
  2. use crate::gadget::gadget_invert;
  3. use crate::params::*;
  4. use crate::poly::*;
  5. pub fn coefficient_expansion(
  6. v: &mut Vec<PolyMatrixNTT>,
  7. g: usize,
  8. stopround: usize,
  9. params: &Params,
  10. v_w_left: &Vec<PolyMatrixNTT>,
  11. v_w_right: &Vec<PolyMatrixNTT>,
  12. v_neg1: &Vec<PolyMatrixNTT>,
  13. max_bits_to_gen_right: usize,
  14. ) {
  15. let poly_len = params.poly_len;
  16. for r in 0..g {
  17. let num_in = 1 << r;
  18. let num_out = 2 * num_in;
  19. let t = (poly_len / (1 << r)) + 1;
  20. let neg1 = &v_neg1[r];
  21. for i in 0..num_out {
  22. if stopround > 0 && i % 2 == 1 && r > stopround
  23. || (r == stopround && i / 2 >= max_bits_to_gen_right)
  24. {
  25. continue;
  26. }
  27. let (w, gadget_dim) = match i % 2 {
  28. 0 => (&v_w_left[r], params.t_exp_left),
  29. 1 | _ => (&v_w_right[r], params.t_exp_right),
  30. };
  31. if i < num_in {
  32. let (src, dest) = v.split_at_mut(num_in);
  33. scalar_multiply(&mut dest[i], neg1, &src[i]);
  34. }
  35. let ct = from_ntt_alloc(&v[i]);
  36. let ct_auto = automorph_alloc(&ct, t);
  37. let ct_auto_0 = ct_auto.submatrix(0, 0, 1, 1);
  38. let ct_auto_1_ntt = ct_auto.submatrix(1, 0, 1, 1).ntt();
  39. let ginv_ct = gadget_invert(gadget_dim, &ct_auto_0);
  40. let ginv_ct_ntt = ginv_ct.ntt();
  41. let w_times_ginv_ct = w * &ginv_ct_ntt;
  42. let mut idx = 0;
  43. for j in 0..2 {
  44. for n in 0..params.crt_count {
  45. for z in 0..poly_len {
  46. let sum = v[i].data[idx]
  47. + w_times_ginv_ct.data[idx]
  48. + j * ct_auto_1_ntt.data[n * poly_len + z];
  49. v[i].data[idx] = arith::modular_reduce(params, sum, n);
  50. idx += 1;
  51. }
  52. }
  53. }
  54. }
  55. }
  56. }
  #[cfg(test)]
  mod test {
      use super::*;
      use crate::{client::*, util::*};

      fn get_params() -> Params {
          get_short_keygen_params()
      }

      /// Runs `coefficient_expansion` on a freshly encrypted single-coefficient query.
      ///
      /// NOTE(review): despite its name, this test makes no assertions about the expanded
      /// entries of `v`; it only checks that the expansion completes without panicking.
      /// TODO: decrypt the outputs and check them.
      #[test]
      fn coefficient_expansion_is_correct() {
          let params = get_params();
          let v_neg1 = params.get_v_neg1();
          let mut rng = get_seeded_rng();
          let mut client = Client::init(&params, &mut rng);
          let public_params = client.generate_keys();

          // `poly_len` zero 2x1 NTT-form matrices; entry 0 is overwritten below.
          let mut v: Vec<PolyMatrixNTT> = (0..params.poly_len)
              .map(|_| PolyMatrixNTT::zero(&params, 2, 1))
              .collect();

          // Encrypt a plaintext with a single scaled coefficient at index 7 into entry 0.
          let scale_k = params.modulus / params.pt_modulus;
          let mut sigma = PolyMatrixRaw::zero(&params, 1, 1);
          sigma.data[7] = scale_k;
          v[0] = client.encrypt_matrix_reg(&sigma.ntt());

          let left_keys = public_params.v_expansion_left.unwrap();
          let right_keys = public_params.v_expansion_right.unwrap();
          let right_bits = params.t_gsw * params.db_dim_2;

          coefficient_expansion(
              &mut v,
              client.g,
              client.stop_round,
              &params,
              &left_keys,
              &right_keys,
              &v_neg1,
              right_bits,
          );
      }
  }