12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- use std::primitive;
- use crate::{params::*, poly::*};
- pub fn get_bits_per(params: &Params, dim: usize) -> usize {
- let modulus_log2 = params.modulus_log2;
- if dim as u64 == modulus_log2 {
- return 1;
- }
- ((modulus_log2 as f64) / (dim as f64)).floor() as usize + 1
- }
- pub fn build_gadget(params: &Params, rows: usize, cols: usize) -> PolyMatrixRaw {
- let mut g = PolyMatrixRaw::zero(params, rows, cols);
- let nx = g.rows;
- let m = g.cols;
- assert_eq!(m % nx, 0);
- let num_elems = m / nx;
- let params = g.params;
- let bits_per = get_bits_per(params, num_elems);
- for i in 0..nx {
- for j in 0..num_elems {
- if bits_per * j >= 64 {
- continue;
- }
- let poly = g.get_poly_mut(i, i + j * nx);
- poly[0] = 1u64 << (bits_per * j);
- }
- }
- g
- }
- pub fn gadget_invert<'a>(mx: usize, inp: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
- let params = inp.params;
- let num_elems = mx / inp.rows;
- let bits_per = get_bits_per(params, num_elems);
- let mask = (1u64 << bits_per) - 1;
- let mut out = PolyMatrixRaw::zero(params, mx, inp.cols);
- for i in 0..inp.cols {
- for j in 0..inp.rows {
- for z in 0..params.poly_len {
- let val = inp.get_poly(j, i)[z];
- for k in 0..num_elems {
- let bit_offs = usize::min(k * bits_per, 64) as u64;
- let shifted = val.checked_shr(bit_offs as u32);
- let piece = match shifted {
- Some(x) => x & mask,
- None => 0,
- };
- out.get_poly_mut(j + k * inp.rows, i)[z] = piece;
- }
- }
- }
- }
- out
- }
- #[cfg(test)]
- mod test {
- use crate::util::get_test_params;
- use super::*;
- #[test]
- fn gadget_invert_is_correct() {
- let params = get_test_params();
- let mut mat = PolyMatrixRaw::zero(¶ms, 2, 1);
- mat.get_poly_mut(0, 0)[37] = 3;
- mat.get_poly_mut(1, 0)[37] = 6;
- let log_q = params.modulus_log2 as usize;
- let result = gadget_invert(2 * log_q, &mat);
- assert_eq!(result.get_poly(0, 0)[37], 1);
- assert_eq!(result.get_poly(2, 0)[37], 1);
- assert_eq!(result.get_poly(4, 0)[37], 0); // binary for '3'
- assert_eq!(result.get_poly(1, 0)[37], 0);
- assert_eq!(result.get_poly(3, 0)[37], 1);
- assert_eq!(result.get_poly(5, 0)[37], 1);
- assert_eq!(result.get_poly(7, 0)[37], 0); // binary for '6'
- }
- }
|