Browse Source

add barrett reduction

Samir Menon 2 years ago
parent
commit
153e33844b

+ 8 - 8
client/src/lib.rs

@@ -6,14 +6,14 @@ use wasm_bindgen::prelude::*;
 const UUID_V4_LEN: usize = 36;
 
 // console_log! macro
-#[wasm_bindgen]
-extern "C" {
-    #[wasm_bindgen(js_namespace = console)]
-    fn log(s: &str);
-}
-macro_rules! console_log {
-    ($($t:tt)*) => (log(&format_args!($($t)*).to_string()))
-}
+// #[wasm_bindgen]
+// extern "C" {
+//     #[wasm_bindgen(js_namespace = console)]
+//     fn log(s: &str);
+// }
+// macro_rules! console_log {
+//     ($($t:tt)*) => (log(&format_args!($($t)*).to_string()))
+// }
 
 // Container class for a static lifetime Client
 // Avoids a lifetime in the return signature of bound Rust functions

+ 1 - 1
spiral-rs/.cargo/config.toml

@@ -1,3 +1,3 @@
 [build]
-target = "x86_64-unknown-linux-gnu"
+# target = "x86_64-unknown-linux-gnu"
 rustflags = ["-C", "target-feature=+avx2"]

+ 5 - 1
spiral-rs/Cargo.toml

@@ -10,7 +10,7 @@ reqwest = { version = "0.11", features = ["blocking"] }
 serde_json = "1.0"
 
 [dev-dependencies]
-criterion = "0.3"
+criterion = { version = "0.3", features = ["html_reports"] }
 pprof = { version = "0.4", features = ["flamegraph", "criterion"] }
 
 [[bench]]
@@ -21,6 +21,10 @@ harness = false
 name = "server"
 harness = false
 
+[[bench]]
+name = "poly"
+harness = false
+
 [profile.release]
 lto = "fat"
 codegen-units = 1

+ 4 - 3
spiral-rs/benches/ntt.rs

@@ -1,11 +1,12 @@
 use criterion::{black_box, criterion_group, criterion_main, Criterion};
 use rand::Rng;
+use spiral_rs::aligned_memory::*;
 use spiral_rs::ntt::*;
 use spiral_rs::util::*;
 
 fn criterion_benchmark(c: &mut Criterion) {
     let params = get_test_params();
-    let mut v1 = vec![0; params.crt_count * params.poly_len];
+    let mut v1 = AlignedMemory64::new(params.crt_count * params.poly_len);
     let mut rng = rand::thread_rng();
     for i in 0..params.crt_count {
         for j in 0..params.poly_len {
@@ -15,10 +16,10 @@ fn criterion_benchmark(c: &mut Criterion) {
         }
     }
     c.bench_function("nttf 2048", |b| {
-        b.iter(|| ntt_forward(black_box(&params), black_box(&mut v1)))
+        b.iter(|| ntt_forward(black_box(&params), black_box(v1.as_mut_slice())))
     });
     c.bench_function("ntti 2048", |b| {
-        b.iter(|| ntt_inverse(black_box(&params), black_box(&mut v1)))
+        b.iter(|| ntt_inverse(black_box(&params), black_box(v1.as_mut_slice())))
     });
 }
 

+ 19 - 5
spiral-rs/benches/poly.rs

@@ -4,11 +4,25 @@ use spiral_rs::util::*;
 
 fn criterion_benchmark(c: &mut Criterion) {
     let params = get_test_params();
-    let m1 = PolyMatrixNTT::random(&params, 2, 1);
-    let m2 = PolyMatrixNTT::random(&params, 3, 2);
-    let mut m3 = PolyMatrixNTT::zero(&params, 2, 2);
-    c.bench_function("nttf 2048", |b| {
-        b.iter(|| multiply(black_box(&mut m3), black_box(&m1), black_box(&m2)))
+    let mut m1 = PolyMatrixRaw::random(&params, 10, 10);
+    let mut m2 = PolyMatrixNTT::random(&params, 10, 10);
+    let m3 = PolyMatrixNTT::random(&params, 10, 10);
+    let mut m4 = PolyMatrixNTT::random(&params, 10, 10);
+
+    // c.bench_function("nttf_noreduce 2048", |b| {
+    //     b.iter(|| to_ntt_no_reduce(black_box(&mut m2), black_box(&m1)))
+    // });
+
+    c.bench_function("multiply", |b| {
+        b.iter(|| multiply(black_box(&mut m4), black_box(&m2), black_box(&m3)))
+    });
+
+    c.bench_function("nttf_full 2048", |b| {
+        b.iter(|| to_ntt(black_box(&mut m2), black_box(&m1)))
+    });
+
+    c.bench_function("ntti_full 2048", |b| {
+        b.iter(|| from_ntt(black_box(&mut m1), black_box(&m2)))
     });
 }
 

+ 1 - 1
spiral-rs/benches/server.rs

@@ -32,7 +32,7 @@ fn criterion_benchmark(c: &mut Criterion) {
     let v_w_right = public_params.v_expansion_right.unwrap();
 
     // note: the benchmark on AVX2 is 545ms for the c++ impl
-    group.bench_function("coeff exp", |b| {
+    group.bench_function("coeff_exp", |b| {
         b.iter(|| {
             coefficient_expansion(
                 black_box(&mut v),

+ 24 - 16
spiral-rs/src/aligned_memory.rs

@@ -1,5 +1,9 @@
-use std::{alloc::{alloc_zeroed, dealloc, Layout}, slice::{from_raw_parts, from_raw_parts_mut}, ops::{Index, IndexMut}, mem::size_of};
-
+use std::{
+    alloc::{alloc_zeroed, dealloc, Layout},
+    mem::size_of,
+    ops::{Index, IndexMut},
+    slice::{from_raw_parts, from_raw_parts_mut},
+};
 
 const ALIGN_SIMD: usize = 64; // enough to support AVX-512
 pub type AlignedMemory64 = AlignedMemory<ALIGN_SIMD>;
@@ -7,10 +11,10 @@ pub type AlignedMemory64 = AlignedMemory<ALIGN_SIMD>;
 pub struct AlignedMemory<const ALIGN: usize> {
     p: *mut u64,
     sz_u64: usize,
-    layout: Layout
+    layout: Layout,
 }
 
-impl<const ALIGN: usize> AlignedMemory<{ALIGN}> {
+impl<const ALIGN: usize> AlignedMemory<{ ALIGN }> {
     pub fn new(sz_u64: usize) -> Self {
         let sz_bytes = sz_u64 * size_of::<u64>();
         let layout = Layout::from_size_align(sz_bytes, ALIGN).unwrap();
@@ -23,20 +27,24 @@ impl<const ALIGN: usize> AlignedMemory<{ALIGN}> {
         Self {
             p: ptr as *mut u64,
             sz_u64,
-            layout
+            layout,
         }
     }
 
     pub fn as_slice(&self) -> &[u64] {
-        unsafe {
-            from_raw_parts(self.p, self.sz_u64)
-        }
+        unsafe { from_raw_parts(self.p, self.sz_u64) }
     }
 
     pub fn as_mut_slice(&mut self) -> &mut [u64] {
-        unsafe {
-            from_raw_parts_mut(self.p, self.sz_u64)
-        }
+        unsafe { from_raw_parts_mut(self.p, self.sz_u64) }
+    }
+
+    pub unsafe fn as_ptr(&self) -> *const u64 {
+        self.p
+    }
+
+    pub unsafe fn as_mut_ptr(&mut self) -> *mut u64 {
+        self.p
     }
 
     pub fn len(&self) -> usize {
@@ -44,7 +52,7 @@ impl<const ALIGN: usize> AlignedMemory<{ALIGN}> {
     }
 }
 
-impl<const ALIGN: usize> Drop for AlignedMemory<{ALIGN}> {
+impl<const ALIGN: usize> Drop for AlignedMemory<{ ALIGN }> {
     fn drop(&mut self) {
         unsafe {
             dealloc(self.p as *mut u8, self.layout);
@@ -52,7 +60,7 @@ impl<const ALIGN: usize> Drop for AlignedMemory<{ALIGN}> {
     }
 }
 
-impl<const ALIGN: usize> Index<usize> for AlignedMemory<{ALIGN}> {
+impl<const ALIGN: usize> Index<usize> for AlignedMemory<{ ALIGN }> {
     type Output = u64;
 
     fn index(&self, index: usize) -> &Self::Output {
@@ -60,16 +68,16 @@ impl<const ALIGN: usize> Index<usize> for AlignedMemory<{ALIGN}> {
     }
 }
 
-impl<const ALIGN: usize> IndexMut<usize> for AlignedMemory<{ALIGN}> {
+impl<const ALIGN: usize> IndexMut<usize> for AlignedMemory<{ ALIGN }> {
     fn index_mut(&mut self, index: usize) -> &mut Self::Output {
         &mut self.as_mut_slice()[index]
     }
 }
 
-impl<const ALIGN: usize> Clone for AlignedMemory<{ALIGN}> {
+impl<const ALIGN: usize> Clone for AlignedMemory<{ ALIGN }> {
     fn clone(&self) -> Self {
         let mut out = Self::new(self.sz_u64);
         out.as_mut_slice().copy_from_slice(self.as_slice());
         out
     }
-}
+}

+ 377 - 4
spiral-rs/src/arith.rs

@@ -1,5 +1,6 @@
 use crate::params::*;
 use std::mem;
+use std::slice;
 
 pub fn multiply_uint_mod(a: u64, b: u64, modulus: u64) -> u64 {
     (((a as u128) * (b as u128)) % (modulus as u128)) as u64
@@ -18,15 +19,15 @@ pub fn log2_ceil_usize(a: usize) -> usize {
 }
 
 pub fn multiply_modular(params: &Params, a: u64, b: u64, c: usize) -> u64 {
-    (a * b) % params.moduli[c]
+    barrett_coeff_u64(params, a * b, c)
 }
 
 pub fn multiply_add_modular(params: &Params, a: u64, b: u64, x: u64, c: usize) -> u64 {
-    (a * b + x) % params.moduli[c]
+    barrett_coeff_u64(params, a * b + x, c)
 }
 
 pub fn add_modular(params: &Params, a: u64, b: u64, c: usize) -> u64 {
-    (a + b) % params.moduli[c]
+    barrett_coeff_u64(params, a + b, c)
 }
 
 pub fn invert_modular(params: &Params, a: u64, c: usize) -> u64 {
@@ -34,7 +35,7 @@ pub fn invert_modular(params: &Params, a: u64, c: usize) -> u64 {
 }
 
 pub fn modular_reduce(params: &Params, x: u64, c: usize) -> u64 {
-    (x) % params.moduli[c]
+    barrett_coeff_u64(params, x, c)
 }
 
 pub fn exponentiate_uint_mod(operand: u64, mut exponent: u64, modulus: u64) -> u64 {
@@ -102,12 +103,384 @@ pub fn recenter(val: u64, from_modulus: u64, to_modulus: u64) -> u64 {
     a_val as u64
 }
 
+pub fn get_barrett_crs(modulus: u64) -> (u64, u64) {
+    let numerator = [0, 0, 1];
+    let (_, quotient) = divide_uint192_inplace(numerator, modulus);
+
+    (quotient[0], quotient[1])
+}
+
+pub fn get_barrett(moduli: &[u64]) -> ([u64; MAX_MODULI], [u64; MAX_MODULI]) {
+    let mut cr0 = [0u64; MAX_MODULI];
+    let mut cr1 = [0u64; MAX_MODULI];
+    for i in 0..moduli.len() {
+        (cr0[i], cr1[i]) = get_barrett_crs(moduli[i]);
+    }
+    (cr0, cr1)
+}
+
+pub fn barrett_raw_u64(input: u64, const_ratio_1: u64, modulus: u64) -> u64 {
+    let tmp = (((input as u128) * (const_ratio_1 as u128)) >> 64) as u64;
+
+    // Barrett subtraction
+    let res = input - tmp * modulus;
+
+    // One more subtraction is enough
+    if res >= modulus {
+        res - modulus
+    } else {
+        res
+    }
+}
+
+pub fn barrett_coeff_u64(params: &Params, val: u64, n: usize) -> u64 {
+    barrett_raw_u64(val, params.barrett_cr_1[n], params.moduli[n])
+}
+
+fn split(x: u128) -> (u64, u64) {
+    let lo = x & ((1u128 << 64) - 1);
+    let hi = x >> 64;
+    (lo as u64, hi as u64)
+}
+
+fn mul_u128(a: u64, b: u64) -> (u64, u64) {
+    let prod = (a as u128) * (b as u128);
+    split(prod)
+}
+
+fn add_u64(op1: u64, op2: u64, out: &mut u64) -> u64 {
+    // Wrapping add with carry-out, matching the C `cpu_add_u64` this ports.
+    // NOTE(review): the previous `checked_add` version returned 1 on overflow
+    // WITHOUT writing *out, leaving the caller's accumulator stale; *out must
+    // receive the wrapped sum in both cases.
+    let (sum, overflow) = op1.overflowing_add(op2);
+    *out = sum;
+    overflow as u64
+}
+
+fn barrett_raw_u128(val: u128, cr0: u64, cr1: u64, modulus: u64) -> u64 {
+    let (zx, zy) = split(val);
+
+    let mut tmp1 = 0;
+    let mut tmp3;
+    let mut carry;
+    let (_, prody) = mul_u128(zx, cr0);
+    carry = prody;
+    let (mut tmp2x, mut tmp2y) = mul_u128(zx, cr1);
+    tmp3 = tmp2y + add_u64(tmp2x, carry, &mut tmp1);
+    (tmp2x, tmp2y) = mul_u128(zy, cr0);
+    carry = tmp2y + add_u64(tmp1, tmp2x, &mut tmp1);
+    tmp1 = zy.wrapping_mul(cr1).wrapping_add(tmp3).wrapping_add(carry); // u64 wrap, as in the C original
+    tmp3 = zx.wrapping_sub(tmp1.wrapping_mul(modulus));
+
+    tmp3
+
+    // uint64_t zx = val & (((__uint128_t)1 << 64) - 1);
+    // uint64_t zy = val >> 64;
+
+    // uint64_t tmp1, tmp3, carry;
+    // ulonglong2_h prod = umul64wide(zx, const_ratio_0);
+    // carry = prod.y;
+    // ulonglong2_h tmp2 = umul64wide(zx, const_ratio_1);
+    // tmp3 = tmp2.y + cpu_add_u64(tmp2.x, carry, &tmp1);
+    // tmp2 = umul64wide(zy, const_ratio_0);
+    // carry = tmp2.y + cpu_add_u64(tmp1, tmp2.x, &tmp1);
+    // tmp1 = zy * const_ratio_1 + tmp3 + carry;
+    // tmp3 = zx - tmp1 * modulus;
+
+    // return tmp3;
+}
+
+fn barrett_reduction_u128_raw(modulus: u64, cr0: u64, cr1: u64, val: u128) -> u64 {
+    let mut reduced_val = barrett_raw_u128(val, cr0, cr1, modulus);
+    reduced_val -= (modulus) * ((reduced_val >= modulus) as u64);
+    reduced_val
+}
+
+pub fn barrett_reduction_u128(params: &Params, val: u128) -> u64 {
+    let modulus = params.modulus;
+    let cr0 = params.barrett_cr_0_modulus;
+    let cr1 = params.barrett_cr_1_modulus;
+    barrett_reduction_u128_raw(modulus, cr0, cr1, val)
+}
+
+// Following code is ported from SEAL (github.com/microsoft/SEAL)
+
+pub fn get_significant_bit_count(val: &[u64]) -> usize {
+    for i in (0..val.len()).rev() {
+        for j in (0..64).rev() {
+            if (val[i] & (1u64 << j)) != 0 {
+                return i * 64 + j + 1;
+            }
+        }
+    }
+    0
+}
+
+fn divide_round_up(num: usize, denom: usize) -> usize {
+    (num + (denom - 1)) / denom
+}
+
+const BITS_PER_U64: usize = u64::BITS as usize;
+
+fn left_shift_uint192(operand: [u64; 3], shift_amount: usize) -> [u64; 3] {
+    let mut result = [0u64; 3];
+    if (shift_amount & (BITS_PER_U64 << 1)) != 0 {
+        result[2] = operand[0];
+        result[1] = 0;
+        result[0] = 0;
+    } else if (shift_amount & BITS_PER_U64) != 0 {
+        result[2] = operand[1];
+        result[1] = operand[0];
+        result[0] = 0;
+    } else {
+        result[2] = operand[2];
+        result[1] = operand[1];
+        result[0] = operand[0];
+    }
+
+    let bit_shift_amount = shift_amount & (BITS_PER_U64 - 1);
+
+    if bit_shift_amount != 0 {
+        let neg_bit_shift_amount = BITS_PER_U64 - bit_shift_amount;
+
+        result[2] = (result[2] << bit_shift_amount) | (result[1] >> neg_bit_shift_amount);
+        result[1] = (result[1] << bit_shift_amount) | (result[0] >> neg_bit_shift_amount);
+        result[0] = result[0] << bit_shift_amount;
+    }
+
+    result
+}
+
+fn right_shift_uint192(operand: [u64; 3], shift_amount: usize) -> [u64; 3] {
+    let mut result = [0u64; 3];
+
+    if (shift_amount & (BITS_PER_U64 << 1)) != 0 {
+        result[0] = operand[2];
+        result[1] = 0;
+        result[2] = 0;
+    } else if (shift_amount & BITS_PER_U64) != 0 {
+        result[0] = operand[1];
+        result[1] = operand[2];
+        result[2] = 0;
+    } else {
+        result[2] = operand[2];
+        result[1] = operand[1];
+        result[0] = operand[0];
+    }
+
+    let bit_shift_amount = shift_amount & (BITS_PER_U64 - 1);
+
+    if bit_shift_amount != 0 {
+        let neg_bit_shift_amount = BITS_PER_U64 - bit_shift_amount;
+
+        result[0] = (result[0] >> bit_shift_amount) | (result[1] << neg_bit_shift_amount);
+        result[1] = (result[1] >> bit_shift_amount) | (result[2] << neg_bit_shift_amount);
+        result[2] = result[2] >> bit_shift_amount;
+    }
+
+    result
+}
+
+fn add_uint64(operand1: u64, operand2: u64, result: &mut u64) -> u8 {
+    *result = operand1.wrapping_add(operand2);
+    (*result < operand1) as u8
+}
+
+fn add_uint64_carry(operand1: u64, operand2: u64, carry: u8, result: &mut u64) -> u8 {
+    let operand1 = operand1.wrapping_add(operand2);
+    *result = operand1.wrapping_add(carry as u64);
+    ((operand1 < operand2) || (!operand1 < (carry as u64))) as u8
+}
+
+fn sub_uint64(operand1: u64, operand2: u64, result: &mut u64) -> u8 {
+    *result = operand1.wrapping_sub(operand2);
+    (operand2 > operand1) as u8
+}
+
+fn sub_uint64_borrow(operand1: u64, operand2: u64, borrow: u8, result: &mut u64) -> u8 {
+    let diff = operand1.wrapping_sub(operand2);
+    *result = diff.wrapping_sub((borrow != 0) as u64);
+    ((diff > operand1) || (diff < (borrow as u64))) as u8
+}
+
+pub fn sub_uint(operand1: &[u64], operand2: &[u64], uint64_count: usize, result: &mut [u64]) -> u8 {
+    let mut borrow = sub_uint64(operand1[0], operand2[0], &mut result[0]);
+
+    for i in 0..uint64_count - 1 {
+        let mut temp_result = 0u64;
+        borrow = sub_uint64_borrow(operand1[1 + i], operand2[1 + i], borrow, &mut temp_result);
+        result[1 + i] = temp_result;
+    }
+
+    borrow
+}
+
+pub fn add_uint(operand1: &[u64], operand2: &[u64], uint64_count: usize, result: &mut [u64]) -> u8 {
+    let mut carry = add_uint64(operand1[0], operand2[0], &mut result[0]);
+
+    for i in 0..uint64_count - 1 {
+        let mut temp_result = 0u64;
+        carry = add_uint64_carry(operand1[1 + i], operand2[1 + i], carry, &mut temp_result);
+        result[1 + i] = temp_result;
+    }
+
+    carry
+}
+
+pub fn divide_uint192_inplace(mut numerator: [u64; 3], denominator: u64) -> ([u64; 3], [u64; 3]) {
+    let mut numerator_bits = get_significant_bit_count(&numerator);
+    let mut denominator_bits = get_significant_bit_count(slice::from_ref(&denominator));
+
+    let mut quotient = [0u64; 3];
+
+    if numerator_bits < denominator_bits {
+        return (numerator, quotient);
+    }
+
+    let uint64_count = divide_round_up(numerator_bits, BITS_PER_U64);
+
+    if uint64_count == 1 {
+        quotient[0] = numerator[0] / denominator;
+        numerator[0] -= quotient[0] * denominator;
+        return (numerator, quotient);
+    }
+
+    let mut shifted_denominator = [0u64; 3];
+    shifted_denominator[0] = denominator;
+
+    let mut difference = [0u64; 3];
+
+    let denominator_shift = numerator_bits - denominator_bits;
+
+    let shifted_denominator = left_shift_uint192(shifted_denominator, denominator_shift);
+    denominator_bits += denominator_shift;
+
+    let mut remaining_shifts = denominator_shift;
+    while numerator_bits == denominator_bits {
+        if (sub_uint(
+            &numerator,
+            &shifted_denominator,
+            uint64_count,
+            &mut difference,
+        )) != 0
+        {
+            if remaining_shifts == 0 {
+                break;
+            }
+
+            add_uint(
+                &difference.clone(),
+                &numerator,
+                uint64_count,
+                &mut difference,
+            );
+
+            quotient = left_shift_uint192(quotient, 1);
+            remaining_shifts -= 1;
+        }
+
+        quotient[0] |= 1;
+
+        numerator_bits = get_significant_bit_count(&difference);
+        let mut numerator_shift = denominator_bits - numerator_bits;
+        if numerator_shift > remaining_shifts {
+            numerator_shift = remaining_shifts;
+        }
+
+        if numerator_bits > 0 {
+            numerator = left_shift_uint192(difference, numerator_shift);
+            numerator_bits += numerator_shift;
+        } else {
+            for w in 0..uint64_count {
+                numerator[w] = 0;
+            }
+        }
+
+        quotient = left_shift_uint192(quotient, numerator_shift);
+        remaining_shifts -= numerator_shift;
+    }
+
+    if numerator_bits > 0 {
+        numerator = right_shift_uint192(numerator, denominator_shift);
+    }
+
+    (numerator, quotient)
+}
+
 #[cfg(test)]
 mod test {
     use super::*;
+    use crate::util::get_seeded_rng;
+    use rand::Rng;
+
+    fn combine(lo: u64, hi: u64) -> u128 {
+        (lo as u128) | ((hi as u128) << 64) // `|`, not `&`: `&` always yields 0, making the randomized test vacuous
+    }
 
     #[test]
     fn div2_uint_mod_correct() {
         assert_eq!(div2_uint_mod(3, 7), 5);
     }
+
+    #[test]
+    fn divide_uint192_inplace_correct() {
+        assert_eq!(
+            divide_uint192_inplace([35, 0, 0], 7),
+            ([0, 0, 0], [5, 0, 0])
+        );
+        assert_eq!(
+            divide_uint192_inplace([0x10101010, 0x2B2B2B2B, 0xF1F1F1F1], 0x1000),
+            (
+                [0x10, 0, 0],
+                [0xB2B0000000010101, 0x1F1000000002B2B2, 0xF1F1F]
+            )
+        );
+    }
+
+    #[test]
+    fn get_barrett_crs_correct() {
+        assert_eq!(
+            get_barrett_crs(268369921u64),
+            (16144578669088582089u64, 68736257792u64)
+        );
+        assert_eq!(
+            get_barrett_crs(249561089u64),
+            (10966983149909726427u64, 73916747789u64)
+        );
+        assert_eq!(
+            get_barrett_crs(66974689739603969u64),
+            (7906011006380390721u64, 275u64)
+        );
+    }
+
+    #[test]
+    fn barrett_reduction_u128_raw_correct() {
+        let modulus = 66974689739603969u64;
+        let modulus_u128 = modulus as u128;
+        let exec = |val| {
+            barrett_reduction_u128_raw(66974689739603969u64, 7906011006380390721u64, 275u64, val)
+        };
+        assert_eq!(exec(modulus_u128), 0);
+        assert_eq!(exec(modulus_u128 + 1), 1);
+        assert_eq!(exec(modulus_u128 * 7 + 5), 5);
+
+        let mut rng = get_seeded_rng();
+        for _ in 0..100 {
+            let val = combine(rng.gen(), rng.gen());
+            assert_eq!(exec(val), (val % modulus_u128) as u64);
+        }
+    }
+
+    #[test]
+    fn barrett_raw_u64_correct() {
+        let modulus = 66974689739603969u64;
+        let cr1 = 275u64;
+
+        let mut rng = get_seeded_rng();
+        for _ in 0..100 {
+            let val = rng.gen();
+            assert_eq!(barrett_raw_u64(val, cr1, modulus), val % modulus);
+        }
+    }
 }

+ 1 - 1
spiral-rs/src/client.rs

@@ -1,7 +1,7 @@
 use crate::{
     arith::*, discrete_gaussian::*, gadget::*, number_theory::*, params::*, poly::*, util::*,
 };
-use rand::{Rng};
+use rand::Rng;
 use std::iter::once;
 
 fn serialize_polymatrix(vec: &mut Vec<u8>, a: &PolyMatrixRaw) {

+ 2 - 2
spiral-rs/src/gadget.rs

@@ -31,7 +31,7 @@ pub fn build_gadget(params: &Params, rows: usize, cols: usize) -> PolyMatrixRaw
     g
 }
 
-pub fn gadget_invert_rdim<'a>(out: &mut PolyMatrixRaw<'a>, inp: &PolyMatrixRaw<'a>, rdim: usize)  {
+pub fn gadget_invert_rdim<'a>(out: &mut PolyMatrixRaw<'a>, inp: &PolyMatrixRaw<'a>, rdim: usize) {
     assert_eq!(out.cols, inp.cols);
 
     let params = inp.params;
@@ -59,7 +59,7 @@ pub fn gadget_invert_rdim<'a>(out: &mut PolyMatrixRaw<'a>, inp: &PolyMatrixRaw<'
     }
 }
 
-pub fn gadget_invert<'a>(out: &mut PolyMatrixRaw<'a>, inp: &PolyMatrixRaw<'a>)  {
+pub fn gadget_invert<'a>(out: &mut PolyMatrixRaw<'a>, inp: &PolyMatrixRaw<'a>) {
     gadget_invert_rdim(out, inp, inp.rows);
 }
 

+ 1 - 1
spiral-rs/src/lib.rs

@@ -1,8 +1,8 @@
+pub mod aligned_memory;
 pub mod arith;
 pub mod discrete_gaussian;
 pub mod number_theory;
 pub mod util;
-pub mod aligned_memory;
 
 pub mod gadget;
 pub mod ntt;

+ 1 - 1
spiral-rs/src/ntt.rs

@@ -367,7 +367,7 @@ pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
 #[cfg(test)]
 mod test {
     use super::*;
-    use crate::{util::*, aligned_memory::AlignedMemory64};
+    use crate::{aligned_memory::AlignedMemory64, util::*};
     use rand::Rng;
 
     fn get_params() -> Params {

+ 31 - 11
spiral-rs/src/params.rs

@@ -1,5 +1,7 @@
 use crate::{arith::*, ntt::*, number_theory::*, poly::*};
 
+pub const MAX_MODULI: usize = 4;
+
 pub static Q2_VALUES: [u64; 37] = [
     0,
     0,
@@ -48,7 +50,13 @@ pub struct Params {
     pub scratch: Vec<u64>,
 
     pub crt_count: usize,
-    pub moduli: Vec<u64>,
+    pub barrett_cr_0: [u64; MAX_MODULI],
+    pub barrett_cr_1: [u64; MAX_MODULI],
+    pub barrett_cr_0_modulus: u64,
+    pub barrett_cr_1_modulus: u64,
+    pub mod0_inv_mod1: u64,
+    pub mod1_inv_mod0: u64,
+    pub moduli: [u64; MAX_MODULI],
     pub modulus: u64,
     pub modulus_log2: u64,
     pub noise_width: f64,
@@ -112,13 +120,10 @@ impl Params {
     pub fn crt_compose_2(&self, x: u64, y: u64) -> u64 {
         assert_eq!(self.crt_count, 2);
 
-        let a = self.moduli[0];
-        let b = self.moduli[1];
-        let a_inv_mod_b = invert_uint_mod(a, b).unwrap();
-        let b_inv_mod_a = invert_uint_mod(b, a).unwrap();
-        let mut val = (x as u128) * (b_inv_mod_a as u128) * (b as u128);
-        val += (y as u128) * (a_inv_mod_b as u128) * (a as u128);
-        (val % (self.modulus as u128)) as u64 // FIXME: use barrett
+        let mut val = (x as u128) * (self.mod1_inv_mod0 as u128);
+        val += (y as u128) * (self.mod0_inv_mod1 as u128);
+
+        barrett_reduction_u128(self, val)
     }
 
     pub fn crt_compose(&self, a: &[u64], idx: usize) -> u64 {
@@ -131,7 +136,7 @@ impl Params {
 
     pub fn init(
         poly_len: usize,
-        moduli: &Vec<u64>,
+        moduli: &[u64],
         noise_width: f64,
         n: usize,
         pt_modulus: u64,
@@ -148,20 +153,35 @@ impl Params {
     ) -> Self {
         let poly_len_log2 = log2(poly_len as u64) as usize;
         let crt_count = moduli.len();
-        let ntt_tables = build_ntt_tables(poly_len, moduli.as_slice());
+        assert!(crt_count <= MAX_MODULI);
+        let mut moduli_array = [0; MAX_MODULI];
+        for i in 0..crt_count {
+            moduli_array[i] = moduli[i];
+        }
+        let ntt_tables = build_ntt_tables(poly_len, moduli);
         let scratch = vec![0u64; crt_count * poly_len];
         let mut modulus = 1;
         for m in moduli {
             modulus *= m;
         }
         let modulus_log2 = log2_ceil(modulus);
+        let (barrett_cr_0, barrett_cr_1) = get_barrett(moduli);
+        let (barrett_cr_0_modulus, barrett_cr_1_modulus) = get_barrett_crs(modulus);
+        let mod0_inv_mod1 = moduli[0] * invert_uint_mod(moduli[0], moduli[1]).unwrap();
+        let mod1_inv_mod0 = moduli[1] * invert_uint_mod(moduli[1], moduli[0]).unwrap();
         Self {
             poly_len,
             poly_len_log2,
             ntt_tables,
             scratch,
             crt_count,
-            moduli: moduli.clone(),
+            barrett_cr_0,
+            barrett_cr_1,
+            barrett_cr_0_modulus,
+            barrett_cr_1_modulus,
+            mod0_inv_mod1,
+            mod1_inv_mod0,
+            moduli: moduli_array,
             modulus,
             modulus_log2,
             noise_width,

+ 12 - 7
spiral-rs/src/poly.rs

@@ -6,7 +6,7 @@ use rand::Rng;
 use std::cell::RefCell;
 use std::ops::{Add, Mul, Neg};
 
-use crate::{arith::*, discrete_gaussian::*, ntt::*, params::*, util::*, aligned_memory::*};
+use crate::{aligned_memory::*, arith::*, discrete_gaussian::*, ntt::*, params::*, util::*};
 
 const SCRATCH_SPACE: usize = 8192;
 thread_local!(static SCRATCH: RefCell<AlignedMemory64> = RefCell::new(AlignedMemory64::new(SCRATCH_SPACE)));
@@ -355,7 +355,8 @@ pub fn multiply_add_poly_avx(params: &Params, res: &mut [u64], a: &[u64], b: &[u
 pub fn modular_reduce(params: &Params, res: &mut [u64]) {
     for c in 0..params.crt_count {
         for i in 0..params.poly_len {
-            res[c * params.poly_len + i] %= params.moduli[c];
+            let idx = c * params.poly_len + i;
+            res[idx] = barrett_coeff_u64(params, res[idx], c);
         }
     }
 }
@@ -495,17 +496,21 @@ pub fn single_poly<'a>(params: &'a Params, val: u64) -> PolyMatrixRaw<'a> {
     res
 }
 
+fn reduce_copy(params: &Params, out: &mut [u64], inp: &[u64]) {
+    for n in 0..params.crt_count {
+        for z in 0..params.poly_len {
+            out[n * params.poly_len + z] = barrett_coeff_u64(params, inp[z], n);
+        }
+    }
+}
+
 pub fn to_ntt(a: &mut PolyMatrixNTT, b: &PolyMatrixRaw) {
     let params = a.params;
     for r in 0..a.rows {
         for c in 0..a.cols {
             let pol_src = b.get_poly(r, c);
             let pol_dst = a.get_poly_mut(r, c);
-            for n in 0..params.crt_count {
-                for z in 0..params.poly_len {
-                    pol_dst[n * params.poly_len + z] = pol_src[z] % params.moduli[n];
-                }
-            }
+            reduce_copy(params, pol_dst, pol_src);
             ntt_forward(params, pol_dst);
         }
     }

+ 18 - 18
spiral-rs/src/server.rs

@@ -1,4 +1,4 @@
-use crate::arith;
+use crate::arith::*;
 use crate::gadget::*;
 use crate::params::*;
 use crate::poly::*;
@@ -41,33 +41,33 @@ pub fn coefficient_expansion(
             }
 
             let (w, _gadget_dim, gi_ct, gi_ct_ntt) = match i % 2 {
-                0 => (&v_w_left[r], params.t_exp_left, &mut ginv_ct_left, &mut ginv_ct_left_ntt),
-                1 | _ => (&v_w_right[r], params.t_exp_right, &mut ginv_ct_right, &mut ginv_ct_right_ntt),
+                0 => (
+                    &v_w_left[r],
+                    params.t_exp_left,
+                    &mut ginv_ct_left,
+                    &mut ginv_ct_left_ntt,
+                ),
+                1 | _ => (
+                    &v_w_right[r],
+                    params.t_exp_right,
+                    &mut ginv_ct_right,
+                    &mut ginv_ct_right_ntt,
+                ),
             };
-            // let (w, gadget_dim) = match i % 2 {
-            //     0 => (&v_w_left[r], params.t_exp_left),
-            //     1 | _ => (&v_w_right[r], params.t_exp_right),
-            // };
-
 
             if i < num_in {
                 let (src, dest) = v.split_at_mut(num_in);
                 scalar_multiply(&mut dest[i], neg1, &src[i]);
             }
 
-            // let ct = from_ntt_alloc(&v[i]);
-            // let ct_auto = automorph_alloc(&ct, t);
-            // let ct_auto_0 = ct_auto.submatrix(0, 0, 1, 1);
-            // let ct_auto_1_ntt = ct_auto.submatrix(1, 0, 1, 1).ntt();
-            // let ginv_ct = gadget_invert_alloc(gadget_dim, &ct_auto_0);
-            // let ginv_ct_ntt = ginv_ct.ntt();
-            // let w_times_ginv_ct = w * &ginv_ct_ntt;
-
             from_ntt(&mut ct, &v[i]);
             automorph(&mut ct_auto, &ct, t);
             gadget_invert_rdim(gi_ct, &ct_auto, 1);
             to_ntt_no_reduce(gi_ct_ntt, &gi_ct);
-            ct_auto_1.data.as_mut_slice().copy_from_slice(ct_auto.get_poly(1, 0));
+            ct_auto_1
+                .data
+                .as_mut_slice()
+                .copy_from_slice(ct_auto.get_poly(1, 0));
             to_ntt(&mut ct_auto_1_ntt, &ct_auto_1);
             multiply(&mut w_times_ginv_ct, w, &gi_ct_ntt);
 
@@ -78,7 +78,7 @@ pub fn coefficient_expansion(
                         let sum = v[i].data[idx]
                             + w_times_ginv_ct.data[idx]
                             + j * ct_auto_1_ntt.data[n * poly_len + z];
-                        v[i].data[idx] = arith::modular_reduce(params, sum, n);
+                        v[i].data[idx] = barrett_coeff_u64(params, sum, n);
                         idx += 1;
                     }
                 }

+ 7 - 1
spiral-rs/src/util.rs

@@ -90,7 +90,13 @@ pub const fn get_empty_params() -> Params {
         ntt_tables: Vec::new(),
         scratch: Vec::new(),
         crt_count: 0,
-        moduli: Vec::new(),
+        barrett_cr_0_modulus: 0,
+        barrett_cr_1_modulus: 0,
+        barrett_cr_0: [0u64; MAX_MODULI],
+        barrett_cr_1: [0u64; MAX_MODULI],
+        mod0_inv_mod1: 0,
+        mod1_inv_mod0: 0,
+        moduli: [0u64; MAX_MODULI],
         modulus: 0,
         modulus_log2: 0,
         noise_width: 0f64,