//! gadget.rs: construction of gadget matrices (`build_gadget`) and gadget
//! decomposition of polynomial matrices (`gadget_invert`).
  1. use std::primitive;
  2. use crate::{params::*, poly::*};
  3. pub fn get_bits_per(params: &Params, dim: usize) -> usize {
  4. let modulus_log2 = params.modulus_log2;
  5. if dim as u64 == modulus_log2 {
  6. return 1;
  7. }
  8. ((modulus_log2 as f64) / (dim as f64)).floor() as usize + 1
  9. }
  10. pub fn build_gadget(params: &Params, rows: usize, cols: usize) -> PolyMatrixRaw {
  11. let mut g = PolyMatrixRaw::zero(params, rows, cols);
  12. let nx = g.rows;
  13. let m = g.cols;
  14. assert_eq!(m % nx, 0);
  15. let num_elems = m / nx;
  16. let params = g.params;
  17. let bits_per = get_bits_per(params, num_elems);
  18. for i in 0..nx {
  19. for j in 0..num_elems {
  20. if bits_per * j >= 64 {
  21. continue;
  22. }
  23. let poly = g.get_poly_mut(i, i + j * nx);
  24. poly[0] = 1u64 << (bits_per * j);
  25. }
  26. }
  27. g
  28. }
  29. pub fn gadget_invert<'a>(mx: usize, inp: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
  30. let params = inp.params;
  31. let num_elems = mx / inp.rows;
  32. let bits_per = get_bits_per(params, num_elems);
  33. let mask = (1u64 << bits_per) - 1;
  34. let mut out = PolyMatrixRaw::zero(params, mx, inp.cols);
  35. for i in 0..inp.cols {
  36. for j in 0..inp.rows {
  37. for z in 0..params.poly_len {
  38. let val = inp.get_poly(j, i)[z];
  39. for k in 0..num_elems {
  40. let bit_offs = usize::min(k * bits_per, 64) as u64;
  41. let shifted = val.checked_shr(bit_offs as u32);
  42. let piece = match shifted {
  43. Some(x) => x & mask,
  44. None => 0,
  45. };
  46. out.get_poly_mut(j + k * inp.rows, i)[z] = piece;
  47. }
  48. }
  49. }
  50. }
  51. out
  52. }
  #[cfg(test)]
  mod test {
      use super::*;
      use crate::util::get_test_params;

      #[test]
      fn gadget_invert_is_correct() {
          let params = get_test_params();
          let mut mat = PolyMatrixRaw::zero(&params, 2, 1);
          mat.get_poly_mut(0, 0)[37] = 3;
          mat.get_poly_mut(1, 0)[37] = 6;
          let log_q = params.modulus_log2 as usize;
          // With `log_q` digits per input row, every digit is one bit.
          // Bit `k` of input row `j` lands in output row `j + 2 * k`.
          let result = gadget_invert(2 * log_q, &mat);

          // 3 = 0b011: output rows 0, 2 and 4 hold its bits 0, 1 and 2.
          assert_eq!(result.get_poly(0, 0)[37], 1);
          assert_eq!(result.get_poly(2, 0)[37], 1);
          assert_eq!(result.get_poly(4, 0)[37], 0);

          // 6 = 0b0110: output rows 1, 3, 5 and 7 hold its bits 0, 1, 2 and 3.
          assert_eq!(result.get_poly(1, 0)[37], 0);
          assert_eq!(result.get_poly(3, 0)[37], 1);
          assert_eq!(result.get_poly(5, 0)[37], 1);
          assert_eq!(result.get_poly(7, 0)[37], 0);
      }

      #[test]
      fn gadget_invert_digits_recompose() {
          let params = get_test_params();
          let log_q = params.modulus_log2 as usize;
          let vals = [12_345u64, 0xA_BCDE];
          let mut mat = PolyMatrixRaw::zero(&params, 2, 1);
          for (row, &val) in vals.iter().enumerate() {
              mat.get_poly_mut(row, 0)[37] = val;
          }

          for digits in [2, 3, 4, log_q] {
              let bits_per = get_bits_per(&params, digits);
              let out = gadget_invert(2 * digits, &mat);
              for (row, &val) in vals.iter().enumerate() {
                  // Shift each digit back to its bit offset and merge the digits.
                  let recomposed = (0..digits).fold(0u64, |acc, k| {
                      let digit = out.get_poly(row + 2 * k, 0)[37];
                      acc | digit.checked_shl((k * bits_per) as u32).unwrap_or(0)
                  });
                  assert_eq!(recomposed, val, "digits per row = {}", digits);
              }
          }
      }

      #[test]
      fn build_gadget_places_powers_of_two() {
          let params = get_test_params();
          let (rows, digits) = (2, 3);
          let g = build_gadget(&params, rows, rows * digits);
          let bits_per = get_bits_per(&params, digits);
          for i in 0..rows {
              for j in 0..digits {
                  // Within column `i + j * rows`, only row `i` is non-zero.
                  for r in 0..rows {
                      let expected = if r == i { 1u64 << (bits_per * j) } else { 0 };
                      assert_eq!(g.get_poly(r, i + j * rows)[0], expected);
                  }
              }
          }
      }
  }