gadget.rs 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. use crate::{params::*, poly::*};
  2. pub fn get_bits_per(params: &Params, dim: usize) -> usize {
  3. let modulus_log2 = params.modulus_log2;
  4. if dim as u64 == modulus_log2 {
  5. return 1;
  6. }
  7. ((modulus_log2 as f64) / (dim as f64)).floor() as usize + 1
  8. }
  9. pub fn build_gadget(params: &Params, rows: usize, cols: usize) -> PolyMatrixRaw {
  10. let mut g = PolyMatrixRaw::zero(params, rows, cols);
  11. let nx = g.rows;
  12. let m = g.cols;
  13. assert_eq!(m % nx, 0);
  14. let num_elems = m / nx;
  15. let params = g.params;
  16. let bits_per = get_bits_per(params, num_elems);
  17. for i in 0..nx {
  18. for j in 0..num_elems {
  19. if bits_per * j >= 64 {
  20. continue;
  21. }
  22. let poly = g.get_poly_mut(i, i + j * nx);
  23. poly[0] = 1u64 << (bits_per * j);
  24. }
  25. }
  26. g
  27. }
  28. pub fn gadget_invert_rdim<'a>(out: &mut PolyMatrixRaw<'a>, inp: &PolyMatrixRaw<'a>, rdim: usize) {
  29. assert_eq!(out.cols, inp.cols);
  30. let params = inp.params;
  31. let mx = out.rows;
  32. let num_elems = mx / rdim;
  33. let bits_per = get_bits_per(params, num_elems);
  34. let mask = (1u64 << bits_per) - 1;
  35. for i in 0..inp.cols {
  36. for j in 0..rdim {
  37. for z in 0..params.poly_len {
  38. let val = inp.get_poly(j, i)[z];
  39. for k in 0..num_elems {
  40. let bit_offs = usize::min(k * bits_per, 64) as u64;
  41. let shifted = val.checked_shr(bit_offs as u32);
  42. let piece = match shifted {
  43. Some(x) => x & mask,
  44. None => 0,
  45. };
  46. out.get_poly_mut(j + k * rdim, i)[z] = piece;
  47. }
  48. }
  49. }
  50. }
  51. }
  52. pub fn gadget_invert<'a>(out: &mut PolyMatrixRaw<'a>, inp: &PolyMatrixRaw<'a>) {
  53. gadget_invert_rdim(out, inp, inp.rows);
  54. }
  55. pub fn gadget_invert_alloc<'a>(mx: usize, inp: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
  56. let mut out = PolyMatrixRaw::zero(inp.params, mx, inp.cols);
  57. gadget_invert(&mut out, inp);
  58. out
  59. }
  60. #[cfg(test)]
  61. mod test {
  62. use crate::util::get_test_params;
  63. use super::*;
  64. #[test]
  65. fn gadget_invert_is_correct() {
  66. let params = get_test_params();
  67. let mut mat = PolyMatrixRaw::zero(&params, 2, 1);
  68. mat.get_poly_mut(0, 0)[37] = 3;
  69. mat.get_poly_mut(1, 0)[37] = 6;
  70. let log_q = params.modulus_log2 as usize;
  71. let result = gadget_invert_alloc(2 * log_q, &mat);
  72. assert_eq!(result.get_poly(0, 0)[37], 1);
  73. assert_eq!(result.get_poly(2, 0)[37], 1);
  74. assert_eq!(result.get_poly(4, 0)[37], 0); // binary for '3'
  75. assert_eq!(result.get_poly(1, 0)[37], 0);
  76. assert_eq!(result.get_poly(3, 0)[37], 1);
  77. assert_eq!(result.get_poly(5, 0)[37], 1);
  78. assert_eq!(result.get_poly(7, 0)[37], 0); // binary for '6'
  79. }
  80. }