Samir Menon 2 years ago
parent
commit
7b3d65eaf0
8 changed files with 472 additions and 0 deletions
  1. 75 0
      Cargo.lock
  2. 1 0
      Cargo.toml
  3. 50 0
      src/arith.rs
  4. 17 0
      src/main.rs
  5. 30 0
      src/ntt.rs
  6. 96 0
      src/number_theory.rs
  7. 34 0
      src/params.rs
  8. 169 0
      src/poly.rs

+ 75 - 0
Cargo.lock

@@ -0,0 +1,75 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 3
+
+[[package]]
+name = "cfg-if"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
+
+[[package]]
+name = "getrandom"
+version = "0.2.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "418d37c8b1d42553c93648be529cb70f920d3baf8ef469b74b9638df426e0b4c"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "wasi",
+]
+
+[[package]]
+name = "libc"
+version = "0.2.119"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1bf2e165bb3457c8e098ea76f3e3bc9db55f87aa90d52d0e6be741470916aaa4"
+
+[[package]]
+name = "ppv-lite86"
+version = "0.2.16"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
+
+[[package]]
+name = "rand"
+version = "0.8.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
+dependencies = [
+ "libc",
+ "rand_chacha",
+ "rand_core",
+]
+
+[[package]]
+name = "rand_chacha"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
+dependencies = [
+ "ppv-lite86",
+ "rand_core",
+]
+
+[[package]]
+name = "rand_core"
+version = "0.6.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d34f1408f55294453790c48b2f1ebbb1c5b4b7563eb1f418bcfcfdbb06ebb4e7"
+dependencies = [
+ "getrandom",
+]
+
+[[package]]
+name = "spiral-rs"
+version = "0.1.0"
+dependencies = [
+ "rand",
+]
+
+[[package]]
+name = "wasi"
+version = "0.10.2+wasi-snapshot-preview1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6"

+ 1 - 0
Cargo.toml

@@ -6,3 +6,4 @@ edition = "2021"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
+rand = "0.8.5"

+ 50 - 0
src/arith.rs

@@ -0,0 +1,50 @@
+use crate::params::*;
+use std::mem;
+
+pub fn multiply_uint_mod(a: u64, b: u64, modulus: u64) -> u64 {
+    (((a as u128) * (b as u128)) % (modulus as u128)) as u64
+}
+
+pub const fn log2(a: u64) -> u64 {
+    std::mem::size_of::<u64>() as u64 * 8 - a.leading_zeros() as u64 - 1
+}
+
+pub fn multiply_modular(params: &Params, a: u64, b: u64, c: usize) -> u64 {
+    (a * b) % params.moduli[c]
+}
+
+pub fn multiply_add_modular(params: &Params, a: u64, b: u64, x: u64, c: usize) -> u64 {
+    (a * b + x) % params.moduli[c]
+}
+
+fn swap(a: u64, b: u64) -> (u64, u64) {
+    (b, a)
+}
+
+pub fn exponentiate_uint_mod(operand: u64, mut exponent: u64, modulus: u64) -> u64 {
+    if exponent == 0 {
+        return 1;
+    }
+
+    if exponent == 1 {
+        return operand;
+    }
+
+    let mut power = operand;
+    let mut product = 0u64;
+    let mut intermediate = 0u64;
+
+    loop {
+        if (exponent & 1) == 1 {
+            product = multiply_uint_mod(power, intermediate, modulus);
+            mem::swap(&mut product, &mut intermediate);
+        }
+        exponent >>= 1;
+        if exponent == 0 {
+            break;
+        }
+        product = multiply_uint_mod(power, power, modulus);
+        mem::swap(&mut product, &mut power);
+    }
+    intermediate
+}

+ 17 - 0
src/main.rs

@@ -1,3 +1,20 @@
+mod arith;
+mod ntt;
+mod number_theory;
+mod params;
+mod poly;
+
+use crate::params::*;
+use crate::poly::*;
+
 fn main() {
     println!("Hello, world!");
+    let params = Params::init(2048, vec![7, 31]);
+    let m1 = poly::PolyMatrixNTT::zero(&params, 2, 1);
+    println!("{}", m1.is_ntt());
+    let m2 = poly::PolyMatrixNTT::zero(&params, 3, 2);
+    let mut m3 = poly::PolyMatrixNTT::zero(&params, 3, 1);
+    println!("{}", m1.is_ntt());
+    multiply(&mut m3, &m2, &m1);
+    println!("{}", m3.is_ntt());
 }

+ 30 - 0
src/ntt.rs

@@ -0,0 +1,30 @@
+use std::usize;
+
+use crate::{number_theory::*, params::*, poly::*};
+use rand::Rng;
+
+pub fn build_ntt_tables(poly_len: usize, moduli: &[u64]) -> Vec<Vec<Vec<u64>>> {
+    let mut v: Vec<Vec<Vec<u64>>> = Vec::new();
+    for coeff_mod in 0..moduli.len() {
+        let modulus = moduli[coeff_mod];
+        let root = get_minimal_primitive_root(2 * poly_len as u64, modulus).unwrap();
+        let inv_root = invert_uint_mod(root, modulus);
+    }
+    v
+}
+
+pub fn ntt_forward(params: Params, out: &mut PolyMatrixRaw, inp: &PolyMatrixRaw) {
+    for coeff_mod in 0..params.crt_count {
+        let mut n = params.poly_len;
+
+        for mm in 0..params.poly_len_log2 {
+            let m = 1 << mm;
+            let t = n >> (mm + 1);
+
+            for i in 0..m {
+                let w = params.get_ntt_forward_table(coeff_mod);
+                let wprime = params.get_ntt_forward_prime_table(coeff_mod);
+            }
+        }
+    }
+}

+ 96 - 0
src/number_theory.rs

@@ -0,0 +1,96 @@
+use crate::arith::*;
+use rand::Rng;
+
+const ATTEMPT_MAX: usize = 100;
+
+pub fn is_primitive_root(root: u64, degree: u64, modulus: u64) -> bool {
+    if root == 0 {
+        return false;
+    }
+
+    exponentiate_uint_mod(root, degree >> 1, modulus) == modulus - 1
+}
+
+pub fn get_primitive_root(degree: u64, modulus: u64) -> Option<u64> {
+    assert!(modulus > 1);
+    assert!(degree >= 2);
+    let size_entire_group = degree - 1;
+    let size_quotient_group = size_entire_group / degree;
+    if size_entire_group - size_quotient_group * degree != 0 {
+        return None;
+    }
+
+    let mut root = 0u64;
+    for trial in 0..ATTEMPT_MAX {
+        let mut rng = rand::thread_rng();
+        let r1: u64 = rng.gen();
+        let r2: u64 = rng.gen();
+        let r3 = ((r1 << 32) | r2) % modulus;
+        root = exponentiate_uint_mod(r3, size_quotient_group, modulus);
+        if is_primitive_root(root, degree, modulus) {
+            break;
+        }
+        if trial != ATTEMPT_MAX - 1 {
+            return None;
+        }
+    }
+
+    Some(root)
+}
+
+pub fn get_minimal_primitive_root(degree: u64, modulus: u64) -> Option<u64> {
+    let mut root = get_primitive_root(degree, modulus)?;
+    let generator_sq = multiply_uint_mod(root, root, modulus);
+    let mut current_generator = root;
+
+    for _ in 0..degree {
+        if current_generator < root {
+            root = current_generator;
+        }
+
+        current_generator = multiply_uint_mod(current_generator, generator_sq, modulus);
+    }
+
+    Some(current_generator)
+}
+
+pub fn extended_gcd(mut x: u64, mut y: u64) -> (u64, i64, i64) {
+    assert!(x != 0);
+    assert!(y != 0);
+
+    let mut prev_a = 1;
+    let mut a = 0;
+    let mut prev_b = 0;
+    let mut b = 1;
+
+    while y != 0 {
+        let q: i64 = (x / y) as i64;
+        let mut temp = (x % y) as i64;
+        x = y;
+        y = temp as u64;
+
+        temp = a;
+        a = prev_a - (q * a);
+        prev_a = temp;
+
+        temp = b;
+        b = prev_b - (q * b);
+        prev_b = temp;
+    }
+
+    (x, prev_a, prev_b)
+}
+
+pub fn invert_uint_mod(value: u64, modulus: u64) -> Option<u64> {
+    if value == 0 {
+        return None;
+    }
+    let gcd_tuple = extended_gcd(value, modulus);
+    if gcd_tuple.0 != 1 {
+        return None;
+    } else if gcd_tuple.1 < 0 {
+        return Some(gcd_tuple.1 as u64 + modulus);
+    } else {
+        return Some(gcd_tuple.1 as u64);
+    }
+}

+ 34 - 0
src/params.rs

@@ -0,0 +1,34 @@
+use crate::{arith::*, ntt::*};
+
+pub struct Params {
+    pub poly_len: usize,
+    pub poly_len_log2: usize,
+    pub ntt_tables: Vec<Vec<Vec<u64>>>,
+    pub crt_count: usize,
+    pub moduli: Vec<u64>,
+}
+
+impl Params {
+    pub fn num_words(&self) -> usize {
+        self.poly_len * self.crt_count
+    }
+    pub fn get_ntt_forward_table(&self, i: usize) -> &[u64] {
+        self.ntt_tables[i][0].as_slice()
+    }
+    pub fn get_ntt_forward_prime_table(&self, i: usize) -> &[u64] {
+        self.ntt_tables[i][1].as_slice()
+    }
+
+    pub fn init(poly_len: usize, moduli: Vec<u64>) -> Self {
+        let poly_len_log2 = log2(poly_len as u64) as usize;
+        let crt_count = moduli.len();
+        let ntt_tables = build_ntt_tables(poly_len, moduli.as_slice());
+        Self {
+            poly_len,
+            poly_len_log2,
+            ntt_tables,
+            crt_count,
+            moduli,
+        }
+    }
+}

+ 169 - 0
src/poly.rs

@@ -0,0 +1,169 @@
+use crate::{arith::*, params::*};
+
+pub trait PolyMatrix<'a> {
+    fn is_ntt(&self) -> bool;
+    fn get_rows(&self) -> usize;
+    fn get_cols(&self) -> usize;
+    fn get_params(&self) -> &Params;
+    fn zero(params: &'a Params, rows: usize, cols: usize) -> Self;
+    fn as_slice(&self) -> &[u64];
+    fn as_mut_slice(&mut self) -> &mut [u64];
+    fn zero_out(&mut self) {
+        for item in self.as_mut_slice() {
+            *item = 0;
+        }
+    }
+    fn get_poly(&self, row: usize, col: usize) -> &[u64] {
+        let params = self.get_params();
+        let start = (row * self.get_cols() + col) * params.poly_len;
+        &self.as_slice()[start..start + params.poly_len]
+    }
+    fn get_poly_mut(&mut self, row: usize, col: usize) -> &mut [u64] {
+        let poly_len = self.get_params().poly_len;
+        let start = (row * self.get_cols() + col) * poly_len;
+        &mut self.as_mut_slice()[start..start + poly_len]
+    }
+}
+
+pub struct PolyMatrixRaw<'a> {
+    params: &'a Params,
+    rows: usize,
+    cols: usize,
+    data: Vec<u64>,
+}
+
+pub struct PolyMatrixNTT<'a> {
+    params: &'a Params,
+    rows: usize,
+    cols: usize,
+    data: Vec<u64>,
+}
+
+impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
+    fn is_ntt(&self) -> bool {
+        false
+    }
+    fn get_rows(&self) -> usize {
+        self.rows
+    }
+    fn get_cols(&self) -> usize {
+        self.cols
+    }
+    fn get_params(&self) -> &Params {
+        &self.params
+    }
+    fn as_slice(&self) -> &[u64] {
+        self.data.as_slice()
+    }
+    fn as_mut_slice(&mut self) -> &mut [u64] {
+        self.data.as_mut_slice()
+    }
+    fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
+        let num_coeffs = rows * cols * params.poly_len;
+        let data: Vec<u64> = vec![0; num_coeffs];
+        PolyMatrixRaw {
+            params,
+            rows,
+            cols,
+            data,
+        }
+    }
+}
+
+impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
+    fn is_ntt(&self) -> bool {
+        true
+    }
+    fn get_rows(&self) -> usize {
+        self.rows
+    }
+    fn get_cols(&self) -> usize {
+        self.cols
+    }
+    fn get_params(&self) -> &Params {
+        &self.params
+    }
+    fn as_slice(&self) -> &[u64] {
+        self.data.as_slice()
+    }
+    fn as_mut_slice(&mut self) -> &mut [u64] {
+        self.data.as_mut_slice()
+    }
+    fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixNTT<'a> {
+        let num_coeffs = rows * cols * params.poly_len * params.crt_count;
+        let data: Vec<u64> = vec![0; num_coeffs];
+        PolyMatrixNTT {
+            params,
+            rows,
+            cols,
+            data,
+        }
+    }
+}
+
+pub fn multiply_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
+    for c in 0..params.crt_count {
+        for i in 0..params.poly_len {
+            res[i] = multiply_modular(params, a[i], b[i], c);
+        }
+    }
+}
+
+pub fn multiply_add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
+    for c in 0..params.crt_count {
+        for i in 0..params.poly_len {
+            res[i] = multiply_add_modular(params, a[i], b[i], res[i], c);
+        }
+    }
+}
+
+pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
+    assert!(a.cols == b.rows);
+
+    for i in 0..a.rows {
+        for j in 0..b.cols {
+            for z in 0..res.params.poly_len {
+                res.get_poly_mut(i, j)[z] = 0;
+            }
+            for k in 0..a.cols {
+                let params = res.params;
+                let res_poly = res.get_poly_mut(i, j);
+                let pol1 = a.get_poly(i, k);
+                let pol2 = b.get_poly(k, j);
+                multiply_add_poly(params, res_poly, pol1, pol2);
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    fn get_params() -> Params {
+        Params::init(2048, vec![7, 31])
+    }
+
+    fn assert_all_zero(a: &[u64]) {
+        for i in a {
+            assert_eq!(*i, 0);
+        }
+    }
+
+    #[test]
+    fn sets_all_zeros() {
+        let params = get_params();
+        let m1 = PolyMatrixNTT::zero(&params, 2, 1);
+        assert_all_zero(m1.as_slice());
+    }
+
+    #[test]
+    fn multiply_correctness() {
+        let params = get_params();
+        let m1 = PolyMatrixNTT::zero(&params, 2, 1);
+        let m2 = PolyMatrixNTT::zero(&params, 3, 2);
+        let mut m3 = PolyMatrixNTT::zero(&params, 3, 1);
+        multiply(&mut m3, &m2, &m1);
+        assert_all_zero(m3.as_slice());
+    }
+}