123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169 |
- use crate::{arith::*, params::*};
- pub trait PolyMatrix<'a> {
- fn is_ntt(&self) -> bool;
- fn get_rows(&self) -> usize;
- fn get_cols(&self) -> usize;
- fn get_params(&self) -> &Params;
- fn zero(params: &'a Params, rows: usize, cols: usize) -> Self;
- fn as_slice(&self) -> &[u64];
- fn as_mut_slice(&mut self) -> &mut [u64];
- fn zero_out(&mut self) {
- for item in self.as_mut_slice() {
- *item = 0;
- }
- }
- fn get_poly(&self, row: usize, col: usize) -> &[u64] {
- let params = self.get_params();
- let start = (row * self.get_cols() + col) * params.poly_len;
- &self.as_slice()[start..start + params.poly_len]
- }
- fn get_poly_mut(&mut self, row: usize, col: usize) -> &mut [u64] {
- let poly_len = self.get_params().poly_len;
- let start = (row * self.get_cols() + col) * poly_len;
- &mut self.as_mut_slice()[start..start + poly_len]
- }
- }
- pub struct PolyMatrixRaw<'a> {
- params: &'a Params,
- rows: usize,
- cols: usize,
- data: Vec<u64>,
- }
- pub struct PolyMatrixNTT<'a> {
- params: &'a Params,
- rows: usize,
- cols: usize,
- data: Vec<u64>,
- }
- impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
- fn is_ntt(&self) -> bool {
- false
- }
- fn get_rows(&self) -> usize {
- self.rows
- }
- fn get_cols(&self) -> usize {
- self.cols
- }
- fn get_params(&self) -> &Params {
- &self.params
- }
- fn as_slice(&self) -> &[u64] {
- self.data.as_slice()
- }
- fn as_mut_slice(&mut self) -> &mut [u64] {
- self.data.as_mut_slice()
- }
- fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
- let num_coeffs = rows * cols * params.poly_len;
- let data: Vec<u64> = vec![0; num_coeffs];
- PolyMatrixRaw {
- params,
- rows,
- cols,
- data,
- }
- }
- }
- impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
- fn is_ntt(&self) -> bool {
- true
- }
- fn get_rows(&self) -> usize {
- self.rows
- }
- fn get_cols(&self) -> usize {
- self.cols
- }
- fn get_params(&self) -> &Params {
- &self.params
- }
- fn as_slice(&self) -> &[u64] {
- self.data.as_slice()
- }
- fn as_mut_slice(&mut self) -> &mut [u64] {
- self.data.as_mut_slice()
- }
- fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixNTT<'a> {
- let num_coeffs = rows * cols * params.poly_len * params.crt_count;
- let data: Vec<u64> = vec![0; num_coeffs];
- PolyMatrixNTT {
- params,
- rows,
- cols,
- data,
- }
- }
- }
- pub fn multiply_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
- for c in 0..params.crt_count {
- for i in 0..params.poly_len {
- res[i] = multiply_modular(params, a[i], b[i], c);
- }
- }
- }
- pub fn multiply_add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
- for c in 0..params.crt_count {
- for i in 0..params.poly_len {
- res[i] = multiply_add_modular(params, a[i], b[i], res[i], c);
- }
- }
- }
- pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
- assert!(a.cols == b.rows);
- for i in 0..a.rows {
- for j in 0..b.cols {
- for z in 0..res.params.poly_len {
- res.get_poly_mut(i, j)[z] = 0;
- }
- for k in 0..a.cols {
- let params = res.params;
- let res_poly = res.get_poly_mut(i, j);
- let pol1 = a.get_poly(i, k);
- let pol2 = b.get_poly(k, j);
- multiply_add_poly(params, res_poly, pol1, pol2);
- }
- }
- }
- }
#[cfg(test)]
mod test {
    use super::*;

    /// Parameter set shared by the tests below.
    fn get_params() -> Params {
        Params::init(2048, vec![7, 31])
    }

    /// Asserts that every word of `a` is zero, reporting the first offending index.
    fn assert_all_zero(a: &[u64]) {
        for (i, &x) in a.iter().enumerate() {
            assert_eq!(x, 0, "word {} is non-zero", i);
        }
    }

    #[test]
    fn sets_all_zeros() {
        let params = get_params();
        let m1 = PolyMatrixNTT::zero(&params, 2, 1);
        assert_all_zero(m1.as_slice());
    }

    #[test]
    fn zero_allocates_expected_sizes() {
        let params = get_params();
        let raw = PolyMatrixRaw::zero(&params, 2, 3);
        assert_eq!(raw.as_slice().len(), 2 * 3 * params.poly_len);
        let ntt = PolyMatrixNTT::zero(&params, 2, 3);
        assert_eq!(
            ntt.as_slice().len(),
            2 * 3 * params.poly_len * params.crt_count
        );
    }

    /// The product of zero matrices is zero; this exercises shapes, not arithmetic.
    #[test]
    fn multiply_correctness() {
        let params = get_params();
        let m1 = PolyMatrixNTT::zero(&params, 2, 1);
        let m2 = PolyMatrixNTT::zero(&params, 3, 2);
        let mut m3 = PolyMatrixNTT::zero(&params, 3, 1);
        multiply(&mut m3, &m2, &m1);
        assert_all_zero(m3.as_slice());
    }
}
|