|
@@ -1,15 +1,22 @@
|
|
|
use std::arch::x86_64::*;
|
|
|
-use std::ops::Mul;
|
|
|
+use std::ops::{Add, Mul, Neg};
|
|
|
+use std::cell::RefCell;
|
|
|
+use rand::Rng;
|
|
|
+use rand::distributions::Standard;
|
|
|
|
|
|
-use crate::{arith::*, params::*, ntt::*, util::calc_index};
|
|
|
+use crate::{arith::*, params::*, ntt::*, util::*, discrete_gaussian::*};
|
|
|
+
|
|
|
+const SCRATCH_SPACE: usize = 8192;
|
|
|
+thread_local!(static SCRATCH: RefCell<Vec<u64>> = RefCell::new(vec![0u64; SCRATCH_SPACE]));
|
|
|
|
|
|
pub trait PolyMatrix<'a> {
|
|
|
fn is_ntt(&self) -> bool;
|
|
|
fn get_rows(&self) -> usize;
|
|
|
fn get_cols(&self) -> usize;
|
|
|
fn get_params(&self) -> &Params;
|
|
|
+ fn num_words(&self) -> usize;
|
|
|
fn zero(params: &'a Params, rows: usize, cols: usize) -> Self;
|
|
|
- fn random(params: &'a Params, rows: usize, cols: usize, rng: &mut dyn Iterator<Item=u64>) -> Self;
|
|
|
+ fn random(params: &'a Params, rows: usize, cols: usize) -> Self;
|
|
|
fn as_slice(&self) -> &[u64];
|
|
|
fn as_mut_slice(&mut self) -> &mut [u64];
|
|
|
fn zero_out(&mut self) {
|
|
@@ -18,12 +25,12 @@ pub trait PolyMatrix<'a> {
|
|
|
}
|
|
|
}
|
|
|
fn get_poly(&self, row: usize, col: usize) -> &[u64] {
|
|
|
- let num_words = self.get_params().num_words();
|
|
|
+ let num_words = self.num_words();
|
|
|
let start = (row * self.get_cols() + col) * num_words;
|
|
|
&self.as_slice()[start..start + num_words]
|
|
|
}
|
|
|
fn get_poly_mut(&mut self, row: usize, col: usize) -> &mut [u64] {
|
|
|
- let num_words = self.get_params().num_words();
|
|
|
+ let num_words = self.num_words();
|
|
|
let start = (row * self.get_cols() + col) * num_words;
|
|
|
&mut self.as_mut_slice()[start..start + num_words]
|
|
|
}
|
|
@@ -40,6 +47,7 @@ pub trait PolyMatrix<'a> {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+ fn pad_top(&self, pad_rows: usize) -> Self;
|
|
|
}
|
|
|
|
|
|
pub struct PolyMatrixRaw<'a> {
|
|
@@ -75,6 +83,9 @@ impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
|
|
|
fn as_mut_slice(&mut self) -> &mut [u64] {
|
|
|
self.data.as_mut_slice()
|
|
|
}
|
|
|
+ fn num_words(&self) -> usize {
|
|
|
+ self.params.poly_len
|
|
|
+ }
|
|
|
fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
|
|
|
let num_coeffs = rows * cols * params.poly_len;
|
|
|
let data: Vec<u64> = vec![0; num_coeffs];
|
|
@@ -85,18 +96,25 @@ impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
|
|
|
data,
|
|
|
}
|
|
|
}
|
|
|
- fn random(params: &'a Params, rows: usize, cols: usize, rng: &mut dyn Iterator<Item=u64>) -> Self {
|
|
|
+ fn random(params: &'a Params, rows: usize, cols: usize) -> Self {
|
|
|
+ let rng = rand::thread_rng();
|
|
|
+ let mut iter = rng.sample_iter(&Standard);
|
|
|
let mut out = PolyMatrixRaw::zero(params, rows, cols);
|
|
|
for r in 0..rows {
|
|
|
for c in 0..cols {
|
|
|
for i in 0..params.poly_len {
|
|
|
- let val: u64 = rng.next().unwrap();
|
|
|
+ let val: u64 = iter.next().unwrap();
|
|
|
out.get_poly_mut(r, c)[i] = val % params.modulus;
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
out
|
|
|
}
|
|
|
+ fn pad_top(&self, pad_rows: usize) -> Self {
|
|
|
+ let mut padded = Self::zero(self.params, self.rows + pad_rows, self.cols);
|
|
|
+ padded.copy_into(&self, pad_rows, 0);
|
|
|
+ padded
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
impl<'a> PolyMatrixRaw<'a> {
|
|
@@ -115,6 +133,16 @@ impl<'a> PolyMatrixRaw<'a> {
|
|
|
data,
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ pub fn noise(params: &'a Params, rows: usize, cols: usize, dg: &mut DiscreteGaussian) -> Self {
|
|
|
+ let mut out = PolyMatrixRaw::zero(params, rows, cols);
|
|
|
+ dg.sample_matrix(&mut out);
|
|
|
+ out
|
|
|
+ }
|
|
|
+
|
|
|
+ pub fn ntt(&self) -> PolyMatrixNTT<'a> {
|
|
|
+ to_ntt_alloc(&self)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
|
|
@@ -136,6 +164,9 @@ impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
|
|
|
fn as_mut_slice(&mut self) -> &mut [u64] {
|
|
|
self.data.as_mut_slice()
|
|
|
}
|
|
|
+ fn num_words(&self) -> usize {
|
|
|
+ self.params.poly_len * self.params.crt_count
|
|
|
+ }
|
|
|
fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixNTT<'a> {
|
|
|
let num_coeffs = rows * cols * params.poly_len * params.crt_count;
|
|
|
let data: Vec<u64> = vec![0; num_coeffs];
|
|
@@ -146,14 +177,16 @@ impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
|
|
|
data,
|
|
|
}
|
|
|
}
|
|
|
- fn random(params: &'a Params, rows: usize, cols: usize, rng: &mut dyn Iterator<Item=u64>) -> Self {
|
|
|
+ fn random(params: &'a Params, rows: usize, cols: usize) -> Self {
|
|
|
+ let rng = rand::thread_rng();
|
|
|
+ let mut iter = rng.sample_iter(&Standard);
|
|
|
let mut out = PolyMatrixNTT::zero(params, rows, cols);
|
|
|
for r in 0..rows {
|
|
|
for c in 0..cols {
|
|
|
for i in 0..params.crt_count {
|
|
|
for j in 0..params.poly_len {
|
|
|
let idx = calc_index(&[i, j], &[params.crt_count, params.poly_len]);
|
|
|
- let val: u64 = rng.next().unwrap();
|
|
|
+ let val: u64 = iter.next().unwrap();
|
|
|
out.get_poly_mut(r, c)[idx] = val % params.moduli[i];
|
|
|
}
|
|
|
}
|
|
@@ -161,6 +194,17 @@ impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
|
|
|
}
|
|
|
out
|
|
|
}
|
|
|
+ fn pad_top(&self, pad_rows: usize) -> Self {
|
|
|
+ let mut padded = Self::zero(self.params, self.rows + pad_rows, self.cols);
|
|
|
+ padded.copy_into(&self, pad_rows, 0);
|
|
|
+ padded
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+impl<'a> PolyMatrixNTT<'a> {
|
|
|
+ pub fn raw(&self) -> PolyMatrixRaw<'a> {
|
|
|
+ from_ntt_alloc(&self)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
pub fn multiply_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
|
|
@@ -179,6 +223,22 @@ pub fn multiply_add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64])
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+pub fn add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
|
|
|
+ for c in 0..params.crt_count {
|
|
|
+ for i in 0..params.poly_len {
|
|
|
+ res[i] = add_modular(params, a[i], b[i], c);
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+pub fn invert_poly(params: &Params, res: &mut [u64], a: &[u64]) {
|
|
|
+ for c in 0..params.crt_count {
|
|
|
+ for i in 0..params.poly_len {
|
|
|
+ res[i] = invert_modular(params, a[i], c);
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
pub fn multiply_add_poly_avx(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
|
|
|
for c in 0..params.crt_count {
|
|
|
for i in (0..params.poly_len).step_by(4) {
|
|
@@ -229,6 +289,8 @@ pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
|
|
|
|
|
|
#[cfg(target_feature = "avx2")]
|
|
|
pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
|
|
|
+ assert!(res.rows == a.rows);
|
|
|
+ assert!(res.cols == b.cols);
|
|
|
assert!(a.cols == b.rows);
|
|
|
|
|
|
let params = res.params;
|
|
@@ -248,58 +310,140 @@ pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-pub fn scalar_multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
|
|
|
- assert_eq!(b.rows, 1);
|
|
|
- assert_eq!(b.cols, 1);
|
|
|
+pub fn add(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
|
|
|
+ assert!(res.rows == a.rows);
|
|
|
+ assert!(res.cols == a.cols);
|
|
|
+ assert!(a.rows == b.rows);
|
|
|
+ assert!(a.cols == b.cols);
|
|
|
|
|
|
let params = res.params;
|
|
|
- let pol2 = b.get_poly(0, 0);
|
|
|
for i in 0..a.rows {
|
|
|
for j in 0..a.cols {
|
|
|
let res_poly = res.get_poly_mut(i, j);
|
|
|
let pol1 = a.get_poly(i, j);
|
|
|
+ let pol2 = b.get_poly(i, j);
|
|
|
+ add_poly(params, res_poly, pol1, pol2);
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+pub fn invert(res: &mut PolyMatrixRaw, a: &PolyMatrixRaw) {
|
|
|
+ assert!(res.rows == a.rows);
|
|
|
+ assert!(res.cols == a.cols);
|
|
|
+
|
|
|
+ let params = res.params;
|
|
|
+ for i in 0..a.rows {
|
|
|
+ for j in 0..a.cols {
|
|
|
+ let res_poly = res.get_poly_mut(i, j);
|
|
|
+ let pol1 = a.get_poly(i, j);
|
|
|
+ invert_poly(params, res_poly, pol1);
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+pub fn stack<'a>(a: &PolyMatrixRaw<'a>, b: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
|
|
|
+ assert_eq!(a.cols, b.cols);
|
|
|
+ let mut c = PolyMatrixRaw::zero(a.params, a.rows + b.rows, a.cols);
|
|
|
+ c.copy_into(a, 0, 0);
|
|
|
+ c.copy_into(b, a.rows, 0);
|
|
|
+ c
|
|
|
+}
|
|
|
+
|
|
|
+pub fn scalar_multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
|
|
|
+ assert_eq!(a.rows, 1);
|
|
|
+ assert_eq!(a.cols, 1);
|
|
|
+
|
|
|
+ let params = res.params;
|
|
|
+ let pol2 = a.get_poly(0, 0);
|
|
|
+ for i in 0..b.rows {
|
|
|
+ for j in 0..b.cols {
|
|
|
+ let res_poly = res.get_poly_mut(i, j);
|
|
|
+ let pol1 = b.get_poly(i, j);
|
|
|
multiply_poly(params, res_poly, pol1, pol2);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-pub fn from_scalar_multiply<'a>(a: &PolyMatrixNTT<'a>, b: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
|
|
|
- let mut res = PolyMatrixNTT::zero(a.params, a.rows, a.cols);
|
|
|
- scalar_multiply(res, a, b);
|
|
|
+pub fn scalar_multiply_alloc<'a>(a: &PolyMatrixNTT<'a>, b: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
|
|
|
+ let mut res = PolyMatrixNTT::zero(b.params, b.rows, b.cols);
|
|
|
+ scalar_multiply(&mut res, a, b);
|
|
|
res
|
|
|
}
|
|
|
|
|
|
|
|
|
pub fn to_ntt(a: &mut PolyMatrixNTT, b: &PolyMatrixRaw) {
|
|
|
+ let params = a.params;
|
|
|
for r in 0..a.rows {
|
|
|
for c in 0..a.cols {
|
|
|
- let pol_src = a.get_poly_mut(r, c);
|
|
|
- let pol_dst = b.get_poly_mut(r, c);
|
|
|
- for n in 0..a.params.crt_count {
|
|
|
- for z in 0..a.params.poly_len {
|
|
|
- pol_dst[n * a.params.poly_len + z] = pol_src[z] % a.params.moduli[n];
|
|
|
+ let pol_src = b.get_poly(r, c);
|
|
|
+ let pol_dst = a.get_poly_mut(r, c);
|
|
|
+ for n in 0..params.crt_count {
|
|
|
+ for z in 0..params.poly_len {
|
|
|
+ pol_dst[n * params.poly_len + z] = pol_src[z] % params.moduli[n];
|
|
|
}
|
|
|
}
|
|
|
- ntt_forward(a.params, pol_dst);
|
|
|
+ ntt_forward(params, pol_dst);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-pub fn from_ntt(params: &Params, res: &mut [u64]) {
|
|
|
- for c in 0..params.crt_count {
|
|
|
- for i in 0..params.poly_len {
|
|
|
- res[c*params.poly_len + i] %= params.moduli[c];
|
|
|
+pub fn to_ntt_alloc<'a>(b: &PolyMatrixRaw<'a>) -> PolyMatrixNTT<'a> {
|
|
|
+ let mut a = PolyMatrixNTT::zero(b.params, b.rows, b.cols);
|
|
|
+ to_ntt(&mut a, b);
|
|
|
+ a
|
|
|
+}
|
|
|
+
|
|
|
+pub fn from_ntt(a: &mut PolyMatrixRaw, b: &PolyMatrixNTT) {
|
|
|
+ let params = a.params;
|
|
|
+ SCRATCH.with(|scratch_cell| {
|
|
|
+ let scratch_vec = &mut *scratch_cell.borrow_mut();
|
|
|
+ let scratch = scratch_vec.as_mut_slice();
|
|
|
+ for r in 0..a.rows {
|
|
|
+ for c in 0..a.cols {
|
|
|
+ let pol_src = b.get_poly(r, c);
|
|
|
+ let pol_dst = a.get_poly_mut(r, c);
|
|
|
+ scratch[0..pol_src.len()].copy_from_slice(pol_src);
|
|
|
+ ntt_inverse(params, scratch);
|
|
|
+ for z in 0..params.poly_len {
|
|
|
+ pol_dst[z] = params.crt_compose_2(scratch[z], scratch[params.poly_len + z]);
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
- }
|
|
|
+ });
|
|
|
}
|
|
|
|
|
|
+pub fn from_ntt_alloc<'a>(b: &PolyMatrixNTT<'a>) -> PolyMatrixRaw<'a> {
|
|
|
+ let mut a = PolyMatrixRaw::zero(b.params, b.rows, b.cols);
|
|
|
+ from_ntt(&mut a, b);
|
|
|
+ a
|
|
|
+}
|
|
|
|
|
|
-impl<'a> Mul for PolyMatrixNTT<'a> {
|
|
|
- type Output = Self;
|
|
|
+impl<'a, 'b> Neg for &'b PolyMatrixRaw<'a> {
|
|
|
+ type Output = PolyMatrixRaw<'a>;
|
|
|
+
|
|
|
+ fn neg(self) -> Self::Output {
|
|
|
+ let mut out = PolyMatrixRaw::zero(self.params, self.rows, self.cols);
|
|
|
+ invert(&mut out, self);
|
|
|
+ out
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+impl<'a, 'b> Mul for &'b PolyMatrixNTT<'a> {
|
|
|
+ type Output = PolyMatrixNTT<'a>;
|
|
|
|
|
|
fn mul(self, rhs: Self) -> Self::Output {
|
|
|
let mut out = PolyMatrixNTT::zero(self.params, self.rows, rhs.cols);
|
|
|
- multiply(&mut out, &self, &rhs);
|
|
|
+ multiply(&mut out, self, rhs);
|
|
|
+ out
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+impl<'a, 'b> Add for &'b PolyMatrixNTT<'a> {
|
|
|
+ type Output = PolyMatrixNTT<'a>;
|
|
|
+
|
|
|
+ fn add(self, rhs: Self) -> Self::Output {
|
|
|
+ let mut out = PolyMatrixNTT::zero(self.params, self.rows, self.cols);
|
|
|
+ add(&mut out, self, rhs);
|
|
|
out
|
|
|
}
|
|
|
}
|
|
@@ -309,7 +453,7 @@ mod test {
|
|
|
use super::*;
|
|
|
|
|
|
fn get_params() -> Params {
|
|
|
- Params::init(2048, &vec![268369921u64, 249561089u64], 2, 6.4)
|
|
|
+ get_test_params()
|
|
|
}
|
|
|
|
|
|
fn assert_all_zero(a: &[u64]) {
|
|
@@ -330,7 +474,21 @@ mod test {
|
|
|
let params = get_params();
|
|
|
let m1 = PolyMatrixNTT::zero(¶ms, 2, 1);
|
|
|
let m2 = PolyMatrixNTT::zero(¶ms, 3, 2);
|
|
|
- let m3 = m2 * m1;
|
|
|
+ let m3 = &m2 * &m1;
|
|
|
assert_all_zero(m3.as_slice());
|
|
|
}
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn full_multiply_correctness() {
|
|
|
+ let params = get_params();
|
|
|
+ let mut m1 = PolyMatrixRaw::zero(¶ms, 1, 1);
|
|
|
+ let mut m2 = PolyMatrixRaw::zero(¶ms, 1, 1);
|
|
|
+ m1.get_poly_mut(0, 0)[1] = 100;
|
|
|
+ m2.get_poly_mut(0, 0)[1] = 7;
|
|
|
+ let m1_ntt = to_ntt_alloc(&m1);
|
|
|
+ let m2_ntt = to_ntt_alloc(&m2);
|
|
|
+ let m3_ntt = &m1_ntt * &m2_ntt;
|
|
|
+ let m3 = from_ntt_alloc(&m3_ntt);
|
|
|
+ assert_eq!(m3.get_poly(0, 0)[2], 700);
|
|
|
+ }
|
|
|
}
|