server.rs 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. use crate::arith::*;
  2. use crate::gadget::*;
  3. use crate::params::*;
  4. use crate::poly::*;
  5. pub fn coefficient_expansion(
  6. v: &mut Vec<PolyMatrixNTT>,
  7. g: usize,
  8. stopround: usize,
  9. params: &Params,
  10. v_w_left: &Vec<PolyMatrixNTT>,
  11. v_w_right: &Vec<PolyMatrixNTT>,
  12. v_neg1: &Vec<PolyMatrixNTT>,
  13. max_bits_to_gen_right: usize,
  14. ) {
  15. let poly_len = params.poly_len;
  16. let mut ct = PolyMatrixRaw::zero(params, 2, 1);
  17. let mut ct_auto = PolyMatrixRaw::zero(params, 2, 1);
  18. let mut ct_auto_1 = PolyMatrixRaw::zero(params, 1, 1);
  19. let mut ct_auto_1_ntt = PolyMatrixNTT::zero(params, 1, 1);
  20. let mut ginv_ct_left = PolyMatrixRaw::zero(params, params.t_exp_left, 1);
  21. let mut ginv_ct_left_ntt = PolyMatrixNTT::zero(params, params.t_exp_left, 1);
  22. let mut ginv_ct_right = PolyMatrixRaw::zero(params, params.t_exp_right, 1);
  23. let mut ginv_ct_right_ntt = PolyMatrixNTT::zero(params, params.t_exp_right, 1);
  24. let mut w_times_ginv_ct = PolyMatrixNTT::zero(params, 2, 1);
  25. for r in 0..g {
  26. let num_in = 1 << r;
  27. let num_out = 2 * num_in;
  28. let t = (poly_len / (1 << r)) + 1;
  29. let neg1 = &v_neg1[r];
  30. for i in 0..num_out {
  31. if stopround > 0 && i % 2 == 1 && r > stopround
  32. || (r == stopround && i / 2 >= max_bits_to_gen_right)
  33. {
  34. continue;
  35. }
  36. let (w, _gadget_dim, gi_ct, gi_ct_ntt) = match i % 2 {
  37. 0 => (
  38. &v_w_left[r],
  39. params.t_exp_left,
  40. &mut ginv_ct_left,
  41. &mut ginv_ct_left_ntt,
  42. ),
  43. 1 | _ => (
  44. &v_w_right[r],
  45. params.t_exp_right,
  46. &mut ginv_ct_right,
  47. &mut ginv_ct_right_ntt,
  48. ),
  49. };
  50. if i < num_in {
  51. let (src, dest) = v.split_at_mut(num_in);
  52. scalar_multiply(&mut dest[i], neg1, &src[i]);
  53. }
  54. from_ntt(&mut ct, &v[i]);
  55. automorph(&mut ct_auto, &ct, t);
  56. gadget_invert_rdim(gi_ct, &ct_auto, 1);
  57. to_ntt_no_reduce(gi_ct_ntt, &gi_ct);
  58. ct_auto_1
  59. .data
  60. .as_mut_slice()
  61. .copy_from_slice(ct_auto.get_poly(1, 0));
  62. to_ntt(&mut ct_auto_1_ntt, &ct_auto_1);
  63. multiply(&mut w_times_ginv_ct, w, &gi_ct_ntt);
  64. let mut idx = 0;
  65. for j in 0..2 {
  66. for n in 0..params.crt_count {
  67. for z in 0..poly_len {
  68. let sum = v[i].data[idx]
  69. + w_times_ginv_ct.data[idx]
  70. + j * ct_auto_1_ntt.data[n * poly_len + z];
  71. v[i].data[idx] = barrett_coeff_u64(params, sum, n);
  72. idx += 1;
  73. }
  74. }
  75. }
  76. }
  77. }
  78. }
  #[cfg(test)]
  mod test {
      use super::*;
      use crate::{client::*, util::*};

      /// Parameter set shared by the expansion tests.
      fn get_params() -> Params {
          get_expansion_testing_params()
      }

      /// Encrypts a plaintext with one scaled coefficient into `v[0]`, then runs
      /// `coefficient_expansion` over it using freshly generated client keys.
      ///
      /// NOTE(review): despite its name, this test makes no assertions on the outputs. It
      /// only checks that expansion completes without panicking for these parameters.
      #[test]
      fn coefficient_expansion_is_correct() {
          const TARGET_COEFF: usize = 7;

          let params = get_params();
          let v_neg1 = params.get_v_neg1();
          let mut seeded_rng = get_seeded_rng();
          let mut client = Client::init(&params, &mut seeded_rng);
          let public_params = client.generate_keys();

          // One zeroed 2x1 ciphertext slot per polynomial coefficient.
          let mut v: Vec<_> = (0..params.poly_len)
              .map(|_| PolyMatrixNTT::zero(&params, 2, 1))
              .collect();

          let scale_k = params.modulus / params.pt_modulus;
          let mut sigma = PolyMatrixRaw::zero(&params, 1, 1);
          sigma.data[TARGET_COEFF] = scale_k;
          v[0] = client.encrypt_matrix_reg(&sigma.ntt());

          let v_w_left = public_params.v_expansion_left.unwrap();
          let v_w_right = public_params.v_expansion_right.unwrap();
          let max_bits_to_gen_right = params.t_gsw * params.db_dim_2;

          coefficient_expansion(
              &mut v,
              client.g,
              client.stop_round,
              &params,
              &v_w_left,
              &v_w_right,
              &v_neg1,
              max_bits_to_gen_right,
          );
      }
  }