Browse Source

add alignment

Samir Menon 2 years ago
parent
commit
94b40205af

+ 1 - 0
client/Cargo.toml

@@ -12,6 +12,7 @@ default = []
 
 [dependencies]
 spiral-rs = { path = "../spiral-rs" }
+rand = { version = "0.8.5" }
 wasm-bindgen = "0.2.74"
 
 # The `console_error_panic_hook` crate provides better debugging of panics by

+ 7 - 4
client/src/lib.rs

@@ -1,5 +1,5 @@
 mod utils;
-
+use rand::{rngs::ThreadRng, thread_rng};
 use spiral_rs::{client::*, discrete_gaussian::*, params::*, util::*};
 use wasm_bindgen::prelude::*;
 
@@ -19,17 +19,19 @@ macro_rules! console_log {
 // Avoids a lifetime in the return signature of bound Rust functions
 #[wasm_bindgen]
 pub struct WrappedClient {
-    client: Client<'static>,
+    client: Client<'static, ThreadRng>,
 }
 
 // Unsafe global with a static lifetime
 // Accessed unsafely only once, at load / setup
 static mut PARAMS: Params = get_empty_params();
+static mut RNG: Option<ThreadRng> = None;
 
 // Very simple test to ensure random generation is not obviously biased.
 fn dg_seems_okay() {
     let params = get_test_params();
-    let mut dg = DiscreteGaussian::init(&params);
+    let mut rng = thread_rng();
+    let mut dg = DiscreteGaussian::init(&params, &mut rng);
     let mut v = Vec::new();
     let trials = 10000;
     let mut sum = 0;
@@ -72,7 +74,8 @@ pub fn initialize(json_params: Option<String>) -> WrappedClient {
     // this minimal unsafe operation is needed to initialize state
     unsafe {
         PARAMS = params_from_json(&cfg);
-        client = Client::init(&PARAMS);
+        RNG = Some(thread_rng());
+        client = Client::init(&PARAMS, RNG.as_mut().unwrap());
     }
 
     WrappedClient { client }

+ 2 - 2
spiral-rs/.cargo/config.toml

@@ -1,3 +1,3 @@
 [build]
-# target = "x86_64-unknown-linux-gnu"
-# rustflags = ["-C", "target-feature=+avx2"]
+target = "x86_64-unknown-linux-gnu"
+rustflags = ["-C", "target-feature=+avx2"]

+ 311 - 4
spiral-rs/Cargo.lock

@@ -2,6 +2,41 @@
 # It is not intended for manual editing.
 version = 3
 
+[[package]]
+name = "addr2line"
+version = "0.17.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b9ecd88a8c8378ca913a680cd98f0f13ac67383d35993f86c90a70e3f137816b"
+dependencies = [
+ "gimli",
+]
+
+[[package]]
+name = "adler"
+version = "1.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
+
+[[package]]
+name = "ahash"
+version = "0.7.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
+dependencies = [
+ "getrandom",
+ "once_cell",
+ "version_check",
+]
+
+[[package]]
+name = "arrayvec"
+version = "0.4.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cd9fd44efafa8690358b7408d253adf110036b88f55672a933f01d616ad9b1b9"
+dependencies = [
+ "nodrop",
+]
+
 [[package]]
 name = "atty"
 version = "0.2.14"
@@ -19,6 +54,21 @@ version = "1.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
 
+[[package]]
+name = "backtrace"
+version = "0.3.64"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5e121dee8023ce33ab248d9ce1493df03c3b38a659b240096fcbd7048ff9c31f"
+dependencies = [
+ "addr2line",
+ "cc",
+ "cfg-if",
+ "libc",
+ "miniz_oxide",
+ "object",
+ "rustc-demangle",
+]
+
 [[package]]
 name = "base64"
 version = "0.13.0"
@@ -27,9 +77,9 @@ checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd"
 
 [[package]]
 name = "bitflags"
-version = "1.3.2"
+version = "1.2.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
+checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693"
 
 [[package]]
 name = "bstr"
@@ -49,6 +99,12 @@ version = "3.9.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a4a45a46ab1f2412e53d3a0ade76ffad2025804294569aae387231a0cd6e0899"
 
+[[package]]
+name = "bytemuck"
+version = "1.9.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cdead85bdec19c194affaeeb670c0e41fe23de31459efd1c174d049269cf02cc"
+
 [[package]]
 name = "bytes"
 version = "1.1.0"
@@ -103,6 +159,15 @@ version = "0.8.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc"
 
+[[package]]
+name = "cpp_demangle"
+version = "0.3.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "eeaa953eaad386a53111e47172c2fedba671e5684c8dd601a5f474f4f118710f"
+dependencies = [
+ "cfg-if",
+]
+
 [[package]]
 name = "criterion"
 version = "0.3.5"
@@ -205,6 +270,15 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "debugid"
+version = "0.7.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d6ee87af31d84ef885378aebca32be3d682b0e0dc119d5b4860a2c5bb5046730"
+dependencies = [
+ "uuid",
+]
+
 [[package]]
 name = "either"
 version = "1.6.1"
@@ -321,6 +395,12 @@ dependencies = [
  "wasm-bindgen",
 ]
 
+[[package]]
+name = "gimli"
+version = "0.26.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "78cc372d058dcf6d5ecd98510e7fbc9e5aec4d21de70f65fea8fecebcd881bd4"
+
 [[package]]
 name = "h2"
 version = "0.3.12"
@@ -453,6 +533,24 @@ dependencies = [
  "hashbrown",
 ]
 
+[[package]]
+name = "inferno"
+version = "0.10.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "de3886428c6400486522cf44b8626e7b94ad794c14390290f2a274dcf728a58f"
+dependencies = [
+ "ahash",
+ "atty",
+ "indexmap",
+ "itoa 1.0.1",
+ "lazy_static",
+ "log",
+ "num-format",
+ "quick-xml",
+ "rgb",
+ "str_stack",
+]
+
 [[package]]
 name = "instant"
 version = "0.1.12"
@@ -510,6 +608,16 @@ version = "0.2.122"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ec647867e2bf0772e28c8bcde4f0d19a9216916e890543b5a03ed8ef27b8f259"
 
+[[package]]
+name = "lock_api"
+version = "0.4.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "327fa5b6a6940e4699ec49a9beae1ea4845c6bab9314e4f84ac68742139d8c53"
+dependencies = [
+ "autocfg",
+ "scopeguard",
+]
+
 [[package]]
 name = "log"
 version = "0.4.14"
@@ -531,6 +639,15 @@ version = "2.4.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a"
 
+[[package]]
+name = "memmap2"
+version = "0.5.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "057a3db23999c867821a7a59feb06a578fcb03685e983dff90daf9e7d24ac08f"
+dependencies = [
+ "libc",
+]
+
 [[package]]
 name = "memoffset"
 version = "0.6.5"
@@ -546,6 +663,16 @@ version = "0.3.16"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d"
 
+[[package]]
+name = "miniz_oxide"
+version = "0.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a92518e98c078586bc6c934028adcca4c92a53d6a958196de835170a01d84e4b"
+dependencies = [
+ "adler",
+ "autocfg",
+]
+
 [[package]]
 name = "mio"
 version = "0.8.2"
@@ -587,6 +714,25 @@ dependencies = [
  "tempfile",
 ]
 
+[[package]]
+name = "nix"
+version = "0.20.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f5e06129fb611568ef4e868c14b326274959aa70ff7776e9d55323531c374945"
+dependencies = [
+ "bitflags",
+ "cc",
+ "cfg-if",
+ "libc",
+ "memoffset",
+]
+
+[[package]]
+name = "nodrop"
+version = "0.1.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb"
+
 [[package]]
 name = "ntapi"
 version = "0.3.7"
@@ -596,6 +742,16 @@ dependencies = [
  "winapi",
 ]
 
+[[package]]
+name = "num-format"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bafe4179722c2894288ee77a9f044f02811c86af699344c498b0840c698a2465"
+dependencies = [
+ "arrayvec",
+ "itoa 0.4.8",
+]
+
 [[package]]
 name = "num-traits"
 version = "0.2.14"
@@ -615,6 +771,15 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "object"
+version = "0.27.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "67ac1d3f9a1d3616fd9a60c8d74296f22406a238b6a72f5cc1e6f314df4ffbf9"
+dependencies = [
+ "memchr",
+]
+
 [[package]]
 name = "once_cell"
 version = "1.10.0"
@@ -660,6 +825,31 @@ dependencies = [
  "vcpkg",
 ]
 
+[[package]]
+name = "parking_lot"
+version = "0.11.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99"
+dependencies = [
+ "instant",
+ "lock_api",
+ "parking_lot_core",
+]
+
+[[package]]
+name = "parking_lot_core"
+version = "0.8.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d76e8e1493bcac0d2766c42737f34458f1c8c50c0d23bcb24ea953affb273216"
+dependencies = [
+ "cfg-if",
+ "instant",
+ "libc",
+ "redox_syscall",
+ "smallvec",
+ "winapi",
+]
+
 [[package]]
 name = "percent-encoding"
 version = "2.1.0"
@@ -712,6 +902,25 @@ dependencies = [
  "plotters-backend",
 ]
 
+[[package]]
+name = "pprof"
+version = "0.4.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d78fcdebc1569625891b4fefed7ece660af53082529d03d9c6e8d01b3880ab92"
+dependencies = [
+ "backtrace",
+ "criterion",
+ "inferno",
+ "lazy_static",
+ "libc",
+ "log",
+ "nix",
+ "parking_lot",
+ "symbolic-demangle",
+ "tempfile",
+ "thiserror",
+]
+
 [[package]]
 name = "ppv-lite86"
 version = "0.2.16"
@@ -727,6 +936,15 @@ dependencies = [
  "unicode-xid",
 ]
 
+[[package]]
+name = "quick-xml"
+version = "0.22.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8533f14c8382aaad0d592c812ac3b826162128b65662331e1127b45c3d18536b"
+dependencies = [
+ "memchr",
+]
+
 [[package]]
 name = "quote"
 version = "1.0.15"
@@ -866,6 +1084,21 @@ dependencies = [
  "winreg",
 ]
 
+[[package]]
+name = "rgb"
+version = "0.8.32"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e74fdc210d8f24a7dbfedc13b04ba5764f5232754ccebfdf5fff1bad791ccbc6"
+dependencies = [
+ "bytemuck",
+]
+
+[[package]]
+name = "rustc-demangle"
+version = "0.1.21"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342"
+
 [[package]]
 name = "rustc_version"
 version = "0.4.0"
@@ -908,9 +1141,9 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
 
 [[package]]
 name = "security-framework"
-version = "2.6.1"
+version = "2.3.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2dc14f172faf8a0194a3aded622712b0de276821addc574fa54fc0a1167e10dc"
+checksum = "23a2ac85147a3a11d77ecf1bc7166ec0b92febfa4461c37944e180f319ece467"
 dependencies = [
  "bitflags",
  "core-foundation",
@@ -991,6 +1224,12 @@ version = "0.4.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5"
 
+[[package]]
+name = "smallvec"
+version = "1.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f2dd574626839106c320a323308629dcb1acfc96e32a8cba364ddc61ac23ee83"
+
 [[package]]
 name = "socket2"
 version = "0.4.4"
@@ -1007,11 +1246,47 @@ version = "0.1.0"
 dependencies = [
  "criterion",
  "getrandom",
+ "pprof",
  "rand",
  "reqwest",
  "serde_json",
 ]
 
+[[package]]
+name = "stable_deref_trait"
+version = "1.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
+
+[[package]]
+name = "str_stack"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9091b6114800a5f2141aee1d1b9d6ca3592ac062dc5decb3764ec5895a47b4eb"
+
+[[package]]
+name = "symbolic-common"
+version = "8.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac6aac7b803adc9ee75344af7681969f76d4b38e4723c6eaacf3b28f5f1d87ff"
+dependencies = [
+ "debugid",
+ "memmap2",
+ "stable_deref_trait",
+ "uuid",
+]
+
+[[package]]
+name = "symbolic-demangle"
+version = "8.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8143ea5aa546f86c64f9b9aafdd14223ffad4ecd2d58575c63c21335909c99a7"
+dependencies = [
+ "cpp_demangle",
+ "rustc-demangle",
+ "symbolic-common",
+]
+
 [[package]]
 name = "syn"
 version = "1.0.86"
@@ -1046,6 +1321,26 @@ dependencies = [
  "unicode-width",
 ]
 
+[[package]]
+name = "thiserror"
+version = "1.0.30"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417"
+dependencies = [
+ "thiserror-impl",
+]
+
+[[package]]
+name = "thiserror-impl"
+version = "1.0.30"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
 [[package]]
 name = "tinytemplate"
 version = "1.2.1"
@@ -1182,12 +1477,24 @@ dependencies = [
  "percent-encoding",
 ]
 
+[[package]]
+name = "uuid"
+version = "0.8.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7"
+
 [[package]]
 name = "vcpkg"
 version = "0.2.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
 
+[[package]]
+name = "version_check"
+version = "0.9.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
+
 [[package]]
 name = "walkdir"
 version = "2.3.2"

+ 1 - 0
spiral-rs/Cargo.toml

@@ -11,6 +11,7 @@ serde_json = "1.0"
 
 [dev-dependencies]
 criterion = "0.3"
+pprof = { version = "0.4", features = ["flamegraph", "criterion"] }
 
 [[bench]]
 name = "ntt"

+ 11 - 3
spiral-rs/benches/server.rs

@@ -1,4 +1,6 @@
 use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use pprof::criterion::{Output, PProfProfiler};
+
 use spiral_rs::client::*;
 use spiral_rs::poly::*;
 use spiral_rs::server::*;
@@ -9,9 +11,9 @@ fn criterion_benchmark(c: &mut Criterion) {
     let mut group = c.benchmark_group("sample-size");
     group
         .sample_size(10)
-        .measurement_time(Duration::from_secs(10));
+        .measurement_time(Duration::from_secs(30));
 
-    let params = get_short_keygen_params();
+    let params = get_expansion_testing_params();
     let v_neg1 = params.get_v_neg1();
     let mut seeded_rng = get_seeded_rng();
     let mut client = Client::init(&params, &mut seeded_rng);
@@ -29,6 +31,7 @@ fn criterion_benchmark(c: &mut Criterion) {
     let v_w_left = public_params.v_expansion_left.unwrap();
     let v_w_right = public_params.v_expansion_right.unwrap();
 
+    // note: the benchmark on AVX2 is 545ms for the c++ impl
     group.bench_function("coeff exp", |b| {
         b.iter(|| {
             coefficient_expansion(
@@ -46,5 +49,10 @@ fn criterion_benchmark(c: &mut Criterion) {
     group.finish();
 }
 
-criterion_group!(benches, criterion_benchmark);
+// criterion_group!(benches, criterion_benchmark);
+// Same `benches` group as the plain form above, but built from a Criterion
+// configuration with a pprof profiler attached that is asked for flamegraph
+// output. NOTE(review): the `100` passed to `PProfProfiler::new` is presumably
+// the sampling frequency -- confirm against the pprof documentation.
+criterion_group! {
+    name = benches;
+    config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)));
+    targets = criterion_benchmark
+}
 criterion_main!(benches);

+ 75 - 0
spiral-rs/src/aligned_memory.rs

@@ -0,0 +1,75 @@
+use std::{alloc::{alloc_zeroed, dealloc, Layout}, slice::{from_raw_parts, from_raw_parts_mut}, ops::{Index, IndexMut}, mem::size_of};
+
+
+/// Alignment, in bytes, requested for SIMD-friendly buffers.
+const ALIGN_SIMD: usize = 64; // enough to support AVX-512
+/// `AlignedMemory` instantiated with the `ALIGN_SIMD` (64-byte) alignment.
+pub type AlignedMemory64 = AlignedMemory<ALIGN_SIMD>;
+
+/// A heap buffer of `u64` words whose start address is aligned to `ALIGN` bytes.
+///
+/// The buffer is zero-filled on creation. `ALIGN` must be a power of two and
+/// at least the alignment of `u64`.
+pub struct AlignedMemory<const ALIGN: usize> {
+    /// Start of the allocation; never null once `new` has returned.
+    p: *mut u64,
+    /// Number of `u64` words exposed through the slice accessors.
+    sz_u64: usize,
+    /// Exact layout the allocation was made with; it must be passed back
+    /// unchanged when the memory is deallocated.
+    layout: Layout,
+}
+
+impl<const ALIGN: usize> AlignedMemory<{ ALIGN }> {
+    /// Allocates a zero-filled buffer of `sz_u64` words aligned to `ALIGN` bytes.
+    ///
+    /// A request for zero words is valid and yields an empty buffer.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `ALIGN` is smaller than the alignment of `u64`, if `ALIGN` is
+    /// not a power of two, or if the requested size in bytes overflows.
+    /// Allocation failure is reported through `std::alloc::handle_alloc_error`.
+    pub fn new(sz_u64: usize) -> Self {
+        // The slice accessors reinterpret the allocation as `[u64]`, so the
+        // pointer must be at least `u64`-aligned.
+        assert!(
+            ALIGN >= std::mem::align_of::<u64>(),
+            "ALIGN must be at least the alignment of u64"
+        );
+
+        let sz_bytes = sz_u64
+            .checked_mul(size_of::<u64>())
+            .expect("AlignedMemory size in bytes overflows usize");
+        // The global allocator must never be asked for zero bytes, so an empty
+        // buffer still reserves one word. `layout` records the size actually
+        // allocated, which keeps a later `dealloc(p, layout)` consistent.
+        let alloc_bytes = sz_bytes.max(size_of::<u64>());
+        let layout = Layout::from_size_align(alloc_bytes, ALIGN)
+            .expect("ALIGN must be a power of two and the size must fit in isize");
+
+        // SAFETY: `layout` has a non-zero size (at least one `u64`).
+        let ptr = unsafe { alloc_zeroed(layout) };
+        if ptr.is_null() {
+            std::alloc::handle_alloc_error(layout);
+        }
+
+        Self {
+            p: ptr as *mut u64,
+            sz_u64,
+            layout,
+        }
+    }
+
+    /// Borrows the whole buffer as a shared slice of `len()` words.
+    pub fn as_slice(&self) -> &[u64] {
+        // SAFETY: `p` is non-null, aligned for `u64`, and points to at least
+        // `sz_u64` zero-initialised words allocated in `new`.
+        unsafe { from_raw_parts(self.p, self.sz_u64) }
+    }
+
+    /// Borrows the whole buffer as a mutable slice of `len()` words.
+    pub fn as_mut_slice(&mut self) -> &mut [u64] {
+        // SAFETY: same invariants as `as_slice`; `&mut self` guarantees the
+        // returned slice is the only live reference into the buffer.
+        unsafe { from_raw_parts_mut(self.p, self.sz_u64) }
+    }
+
+    /// Number of `u64` words in the buffer.
+    pub fn len(&self) -> usize {
+        self.sz_u64
+    }
+
+    /// Returns `true` when the buffer holds no words.
+    pub fn is_empty(&self) -> bool {
+        self.sz_u64 == 0
+    }
+}
+
+impl<const ALIGN: usize> Drop for AlignedMemory<ALIGN> {
+    /// Hands the buffer back to the global allocator.
+    fn drop(&mut self) {
+        let raw = self.p.cast::<u8>();
+        // SAFETY: `p` was returned by `alloc_zeroed` for exactly `self.layout`
+        // in `new`, and `drop` runs at most once per value.
+        unsafe { dealloc(raw, self.layout) }
+    }
+}
+
+impl<const ALIGN: usize> Index<usize> for AlignedMemory<ALIGN> {
+    type Output = u64;
+
+    /// Returns a reference to the word at `index`; panics when `index` is out
+    /// of bounds, exactly like slice indexing.
+    fn index(&self, index: usize) -> &u64 {
+        let words = self.as_slice();
+        &words[index]
+    }
+}
+
+impl<const ALIGN: usize> IndexMut<usize> for AlignedMemory<ALIGN> {
+    /// Returns a mutable reference to the word at `index`; panics when `index`
+    /// is out of bounds, exactly like slice indexing.
+    fn index_mut(&mut self, index: usize) -> &mut u64 {
+        let words = self.as_mut_slice();
+        &mut words[index]
+    }
+}
+
+impl<const ALIGN: usize> Clone for AlignedMemory<ALIGN> {
+    /// Deep copy: allocates a fresh buffer of the same length and alignment
+    /// and copies every word into it.
+    fn clone(&self) -> Self {
+        let src = self.as_slice();
+        let mut copy = Self::new(src.len());
+        copy.as_mut_slice().copy_from_slice(src);
+        copy
+    }
+}

+ 6 - 5
spiral-rs/src/client.rs

@@ -1,8 +1,7 @@
 use crate::{
     arith::*, discrete_gaussian::*, gadget::*, number_theory::*, params::*, poly::*, util::*,
 };
-use rand::rngs::StdRng;
-use rand::{thread_rng, Rng};
+use rand::{Rng};
 use std::iter::once;
 
 fn serialize_polymatrix(vec: &mut Vec<u8>, a: &PolyMatrixRaw) {
@@ -500,6 +499,8 @@ impl<'a, TRng: Rng> Client<'a, TRng> {
 
 #[cfg(test)]
 mod test {
+    use rand::thread_rng;
+
     use super::*;
 
     fn assert_first8(m: &[u64], gold: [u64; 8]) {
@@ -517,7 +518,7 @@ mod test {
         let mut rng = thread_rng();
         let client = Client::init(&params, &mut rng);
 
-        assert_eq!(client.stop_round, 6);
+        assert_eq!(client.stop_round, 5);
         assert_eq!(client.g, 10);
         assert_eq!(*client.params, params);
     }
@@ -531,7 +532,7 @@ mod test {
         let public_params = client.generate_keys();
 
         assert_first8(
-            &public_params.v_conversion.unwrap()[0].data,
+            public_params.v_conversion.unwrap()[0].data.as_slice(),
             [
                 253586619, 247235120, 141892996, 163163429, 15531298, 200914775, 125109567,
                 75889562,
@@ -539,7 +540,7 @@ mod test {
         );
 
         assert_first8(
-            &client.sk_gsw.data,
+            client.sk_gsw.data.as_slice(),
             [1, 5, 0, 3, 1, 3, 66974689739603967, 3],
         );
     }

+ 2 - 1
spiral-rs/src/discrete_gaussian.rs

@@ -1,7 +1,6 @@
 use rand::distributions::WeightedIndex;
 use rand::prelude::Distribution;
 use rand::Rng;
-use rand::{rngs::ThreadRng, thread_rng};
 
 use crate::params::*;
 use crate::poly::*;
@@ -53,6 +52,8 @@ impl<'a, T: Rng> DiscreteGaussian<'a, T> {
 
 #[cfg(test)]
 mod test {
+    use rand::thread_rng;
+
     use super::*;
     use crate::util::*;
 

+ 17 - 9
spiral-rs/src/gadget.rs

@@ -1,5 +1,3 @@
-use std::primitive;
-
 use crate::{params::*, poly::*};
 
 pub fn get_bits_per(params: &Params, dim: usize) -> usize {
@@ -33,16 +31,17 @@ pub fn build_gadget(params: &Params, rows: usize, cols: usize) -> PolyMatrixRaw
     g
 }
 
-pub fn gadget_invert<'a>(mx: usize, inp: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
-    let params = inp.params;
+pub fn gadget_invert_rdim<'a>(out: &mut PolyMatrixRaw<'a>, inp: &PolyMatrixRaw<'a>, rdim: usize)  {
+    assert_eq!(out.cols, inp.cols);
 
-    let num_elems = mx / inp.rows;
+    let params = inp.params;
+    let mx = out.rows;
+    let num_elems = mx / rdim;
     let bits_per = get_bits_per(params, num_elems);
     let mask = (1u64 << bits_per) - 1;
 
-    let mut out = PolyMatrixRaw::zero(params, mx, inp.cols);
     for i in 0..inp.cols {
-        for j in 0..inp.rows {
+        for j in 0..rdim {
             for z in 0..params.poly_len {
                 let val = inp.get_poly(j, i)[z];
                 for k in 0..num_elems {
@@ -53,11 +52,20 @@ pub fn gadget_invert<'a>(mx: usize, inp: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a
                         None => 0,
                     };
 
-                    out.get_poly_mut(j + k * inp.rows, i)[z] = piece;
+                    out.get_poly_mut(j + k * rdim, i)[z] = piece;
                 }
             }
         }
     }
+}
+
+/// Writes the gadget inversion of `inp` into the caller-provided `out`.
+///
+/// Forwards to `gadget_invert_rdim`, passing `inp.rows` as `rdim`.
+pub fn gadget_invert<'a>(out: &mut PolyMatrixRaw<'a>, inp: &PolyMatrixRaw<'a>)  {
+    gadget_invert_rdim(out, inp, inp.rows);
+}
+
+/// Allocating variant of `gadget_invert`.
+///
+/// Creates a zeroed `mx` x `inp.cols` matrix over `inp.params`, fills it via
+/// `gadget_invert`, and returns it.
+pub fn gadget_invert_alloc<'a>(mx: usize, inp: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
+    let mut out = PolyMatrixRaw::zero(inp.params, mx, inp.cols);
+    gadget_invert(&mut out, inp);
     out
 }
 
@@ -74,7 +82,7 @@ mod test {
         mat.get_poly_mut(0, 0)[37] = 3;
         mat.get_poly_mut(1, 0)[37] = 6;
         let log_q = params.modulus_log2 as usize;
-        let result = gadget_invert(2 * log_q, &mat);
+        let result = gadget_invert_alloc(2 * log_q, &mat);
 
         assert_eq!(result.get_poly(0, 0)[37], 1);
         assert_eq!(result.get_poly(2, 0)[37], 1);

+ 1 - 0
spiral-rs/src/lib.rs

@@ -2,6 +2,7 @@ pub mod arith;
 pub mod discrete_gaussian;
 pub mod number_theory;
 pub mod util;
+pub mod aligned_memory;
 
 pub mod gadget;
 pub mod ntt;

+ 35 - 14
spiral-rs/src/ntt.rs

@@ -156,8 +156,8 @@ pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
                             // Use AVX2 here
                             let p_x = &mut op[j] as *mut u64;
                             let p_y = &mut op[j + t] as *mut u64;
-                            let x = _mm256_loadu_si256(p_x as *const __m256i);
-                            let y = _mm256_loadu_si256(p_y as *const __m256i);
+                            let x = _mm256_load_si256(p_x as *const __m256i);
+                            let y = _mm256_load_si256(p_y as *const __m256i);
 
                             let cmp_val = _mm256_set1_epi64x(two_times_modulus_small as i64);
                             let gt_mask = _mm256_cmpgt_epi64(x, cmp_val);
@@ -181,8 +181,8 @@ pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
                             let q_final_inverted = _mm256_sub_epi64(cmp_val, q_final);
                             let new_y = _mm256_add_epi64(curr_x, q_final_inverted);
 
-                            _mm256_storeu_si256(p_x as *mut __m256i, new_x);
-                            _mm256_storeu_si256(p_y as *mut __m256i, new_y);
+                            _mm256_store_si256(p_x as *mut __m256i, new_x);
+                            _mm256_store_si256(p_y as *mut __m256i, new_y);
                         }
                     }
                 }
@@ -194,7 +194,7 @@ pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
                 let p_x = &mut operand[i] as *mut u64;
 
                 let cmp_val1 = _mm256_set1_epi64x(two_times_modulus_small as i64);
-                let mut x = _mm256_loadu_si256(p_x as *const __m256i);
+                let mut x = _mm256_load_si256(p_x as *const __m256i);
                 let mut gt_mask = _mm256_cmpgt_epi64(x, cmp_val1);
                 let mut to_subtract = _mm256_and_si256(gt_mask, cmp_val1);
                 x = _mm256_sub_epi64(x, to_subtract);
@@ -203,7 +203,7 @@ pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
                 gt_mask = _mm256_cmpgt_epi64(x, cmp_val2);
                 to_subtract = _mm256_and_si256(gt_mask, cmp_val2);
                 x = _mm256_sub_epi64(x, to_subtract);
-                _mm256_storeu_si256(p_x as *mut __m256i, x);
+                _mm256_store_si256(p_x as *mut __m256i, x);
             }
         }
     }
@@ -301,8 +301,8 @@ pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
                             // Use AVX2 here
                             let p_x = &mut op[j] as *mut u64;
                             let p_y = &mut op[j + t] as *mut u64;
-                            let x = _mm256_loadu_si256(p_x as *const __m256i);
-                            let y = _mm256_loadu_si256(p_y as *const __m256i);
+                            let x = _mm256_load_si256(p_x as *const __m256i);
+                            let y = _mm256_load_si256(p_y as *const __m256i);
 
                             let modulus_vec = _mm256_set1_epi64x(modulus as i64);
                             let two_times_modulus_vec =
@@ -331,8 +331,8 @@ pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
                             let h_tmp_times_modulus = _mm256_mul_epu32(h_tmp, modulus_vec);
                             let new_y = _mm256_sub_epi64(w_times_t_tmp, h_tmp_times_modulus);
 
-                            _mm256_storeu_si256(p_x as *mut __m256i, new_x);
-                            _mm256_storeu_si256(p_y as *mut __m256i, new_y);
+                            _mm256_store_si256(p_x as *mut __m256i, new_x);
+                            _mm256_store_si256(p_y as *mut __m256i, new_y);
                         }
                     }
                 }
@@ -343,13 +343,31 @@ pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
             operand[i] -= ((operand[i] >= two_times_modulus) as u64) * two_times_modulus;
             operand[i] -= ((operand[i] >= modulus) as u64) * modulus;
         }
+
+        // for i in (0..n).step_by(4) {
+        //     unsafe {
+        //         let p_x = &mut operand[i] as *mut u64;
+
+        //         let cmp_val1 = _mm256_set1_epi64x(two_times_modulus as i64);
+        //         let mut x = _mm256_load_si256(p_x as *const __m256i);
+        //         let mut gt_mask = _mm256_cmpgt_epi64(x, cmp_val1);
+        //         let mut to_subtract = _mm256_and_si256(gt_mask, cmp_val1);
+        //         x = _mm256_sub_epi64(x, to_subtract);
+
+        //         let cmp_val2 = _mm256_set1_epi64x(modulus as i64);
+        //         gt_mask = _mm256_cmpgt_epi64(x, cmp_val2);
+        //         to_subtract = _mm256_and_si256(gt_mask, cmp_val2);
+        //         x = _mm256_sub_epi64(x, to_subtract);
+        //         _mm256_store_si256(p_x as *mut __m256i, x);
+        //     }
+        // }
     }
 }
 
 #[cfg(test)]
 mod test {
     use super::*;
-    use crate::util::*;
+    use crate::{util::*, aligned_memory::AlignedMemory64};
     use rand::Rng;
 
     fn get_params() -> Params {
@@ -382,7 +400,7 @@ mod test {
     #[test]
     fn ntt_forward_correct() {
         let params = get_params();
-        let mut v1 = vec![0; 2 * 2048];
+        let mut v1 = AlignedMemory64::new(2 * 2048);
         v1[0] = 100;
         v1[2048] = 100;
         ntt_forward(&params, v1.as_mut_slice());
@@ -393,7 +411,10 @@ mod test {
     #[test]
     fn ntt_inverse_correct() {
         let params = get_params();
-        let mut v1 = vec![100; 2 * 2048];
+        let mut v1 = AlignedMemory64::new(2 * 2048);
+        for i in 0..v1.len() {
+            v1[i] = 100;
+        }
         ntt_inverse(&params, v1.as_mut_slice());
         assert_eq!(v1[0], 100);
         assert_eq!(v1[2048], 100);
@@ -404,7 +425,7 @@ mod test {
     #[test]
     fn ntt_correct() {
         let params = get_params();
-        let mut v1 = vec![0; params.crt_count * params.poly_len];
+        let mut v1 = AlignedMemory64::new(params.crt_count * params.poly_len);
         let mut rng = rand::thread_rng();
         for i in 0..params.crt_count {
             for j in 0..params.poly_len {

+ 26 - 11
spiral-rs/src/poly.rs

@@ -6,10 +6,10 @@ use rand::Rng;
 use std::cell::RefCell;
 use std::ops::{Add, Mul, Neg};
 
-use crate::{arith::*, discrete_gaussian::*, ntt::*, params::*, util::*};
+use crate::{arith::*, discrete_gaussian::*, ntt::*, params::*, util::*, aligned_memory::*};
 
 const SCRATCH_SPACE: usize = 8192;
-thread_local!(static SCRATCH: RefCell<Vec<u64>> = RefCell::new(vec![0u64; SCRATCH_SPACE]));
+thread_local!(static SCRATCH: RefCell<AlignedMemory64> = RefCell::new(AlignedMemory64::new(SCRATCH_SPACE)));
 
 pub trait PolyMatrix<'a> {
     fn is_ntt(&self) -> bool;
@@ -59,14 +59,14 @@ pub struct PolyMatrixRaw<'a> {
     pub params: &'a Params,
     pub rows: usize,
     pub cols: usize,
-    pub data: Vec<u64>,
+    pub data: AlignedMemory64,
 }
 
 pub struct PolyMatrixNTT<'a> {
     pub params: &'a Params,
     pub rows: usize,
     pub cols: usize,
-    pub data: Vec<u64>,
+    pub data: AlignedMemory64,
 }
 
 impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
@@ -93,7 +93,7 @@ impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
     }
     fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
         let num_coeffs = rows * cols * params.poly_len;
-        let data: Vec<u64> = vec![0; num_coeffs];
+        let data = AlignedMemory64::new(num_coeffs);
         PolyMatrixRaw {
             params,
             rows,
@@ -143,7 +143,7 @@ impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
 impl<'a> PolyMatrixRaw<'a> {
     pub fn identity(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
         let num_coeffs = rows * cols * params.poly_len;
-        let mut data: Vec<u64> = vec![0; num_coeffs];
+        let mut data = AlignedMemory::new(num_coeffs);
         for r in 0..rows {
             let c = r;
             let idx = r * cols * params.poly_len + c * params.poly_len;
@@ -227,7 +227,7 @@ impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
     }
     fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixNTT<'a> {
         let num_coeffs = rows * cols * params.poly_len * params.crt_count;
-        let data: Vec<u64> = vec![0; num_coeffs];
+        let data = AlignedMemory::new(num_coeffs);
         PolyMatrixNTT {
             params,
             rows,
@@ -339,14 +339,14 @@ pub fn multiply_add_poly_avx(params: &Params, res: &mut [u64], a: &[u64], b: &[u
                 let p_x = &a[c * params.poly_len + i] as *const u64;
                 let p_y = &b[c * params.poly_len + i] as *const u64;
                 let p_z = &mut res[c * params.poly_len + i] as *mut u64;
-                let x = _mm256_loadu_si256(p_x as *const __m256i);
-                let y = _mm256_loadu_si256(p_y as *const __m256i);
-                let z = _mm256_loadu_si256(p_z as *const __m256i);
+                let x = _mm256_load_si256(p_x as *const __m256i);
+                let y = _mm256_load_si256(p_y as *const __m256i);
+                let z = _mm256_load_si256(p_z as *const __m256i);
 
                 let product = _mm256_mul_epu32(x, y);
                 let out = _mm256_add_epi64(z, product);
 
-                _mm256_storeu_si256(p_z as *mut __m256i, out);
+                _mm256_store_si256(p_z as *mut __m256i, out);
             }
         }
     }
@@ -511,6 +511,21 @@ pub fn to_ntt(a: &mut PolyMatrixNTT, b: &PolyMatrixRaw) {
     }
 }
 
+/// Converts `b` into NTT form, writing the result into `a`.
+///
+/// For every entry of `a`, the matching raw polynomial of `b` is copied
+/// verbatim into each of the `crt_count` consecutive `poly_len`-sized slots of
+/// the destination polynomial (no per-modulus reduction is applied to the
+/// copied values), and `ntt_forward` is then run on that destination.
+pub fn to_ntt_no_reduce(a: &mut PolyMatrixNTT, b: &PolyMatrixRaw) {
+    let params = a.params;
+    let poly_len = params.poly_len;
+    for row in 0..a.rows {
+        for col in 0..a.cols {
+            let raw_poly = b.get_poly(row, col);
+            let ntt_poly = a.get_poly_mut(row, col);
+            for crt in 0..params.crt_count {
+                let start = crt * poly_len;
+                let slot = &mut ntt_poly[start..start + poly_len];
+                slot.copy_from_slice(raw_poly);
+            }
+            ntt_forward(params, ntt_poly);
+        }
+    }
+}
+
 pub fn to_ntt_alloc<'a>(b: &PolyMatrixRaw<'a>) -> PolyMatrixNTT<'a> {
     let mut a = PolyMatrixNTT::zero(b.params, b.rows, b.cols);
     to_ntt(&mut a, b);

+ 35 - 12
spiral-rs/src/server.rs

@@ -1,5 +1,5 @@
 use crate::arith;
-use crate::gadget::gadget_invert;
+use crate::gadget::*;
 use crate::params::*;
 use crate::poly::*;
 
@@ -15,6 +15,16 @@ pub fn coefficient_expansion(
 ) {
     let poly_len = params.poly_len;
 
+    let mut ct = PolyMatrixRaw::zero(params, 2, 1);
+    let mut ct_auto = PolyMatrixRaw::zero(params, 2, 1);
+    let mut ct_auto_1 = PolyMatrixRaw::zero(params, 1, 1);
+    let mut ct_auto_1_ntt = PolyMatrixNTT::zero(params, 1, 1);
+    let mut ginv_ct_left = PolyMatrixRaw::zero(params, params.t_exp_left, 1);
+    let mut ginv_ct_left_ntt = PolyMatrixNTT::zero(params, params.t_exp_left, 1);
+    let mut ginv_ct_right = PolyMatrixRaw::zero(params, params.t_exp_right, 1);
+    let mut ginv_ct_right_ntt = PolyMatrixNTT::zero(params, params.t_exp_right, 1);
+    let mut w_times_ginv_ct = PolyMatrixNTT::zero(params, 2, 1);
+
     for r in 0..g {
         let num_in = 1 << r;
         let num_out = 2 * num_in;
@@ -30,23 +40,36 @@ pub fn coefficient_expansion(
                 continue;
             }
 
-            let (w, gadget_dim) = match i % 2 {
-                0 => (&v_w_left[r], params.t_exp_left),
-                1 | _ => (&v_w_right[r], params.t_exp_right),
+            let (w, _gadget_dim, gi_ct, gi_ct_ntt) = match i % 2 {
+                0 => (&v_w_left[r], params.t_exp_left, &mut ginv_ct_left, &mut ginv_ct_left_ntt),
+                1 | _ => (&v_w_right[r], params.t_exp_right, &mut ginv_ct_right, &mut ginv_ct_right_ntt),
             };
+            // let (w, gadget_dim) = match i % 2 {
+            //     0 => (&v_w_left[r], params.t_exp_left),
+            //     1 | _ => (&v_w_right[r], params.t_exp_right),
+            // };
+
 
             if i < num_in {
                 let (src, dest) = v.split_at_mut(num_in);
                 scalar_multiply(&mut dest[i], neg1, &src[i]);
             }
 
-            let ct = from_ntt_alloc(&v[i]);
-            let ct_auto = automorph_alloc(&ct, t);
-            let ct_auto_0 = ct_auto.submatrix(0, 0, 1, 1);
-            let ct_auto_1_ntt = ct_auto.submatrix(1, 0, 1, 1).ntt();
-            let ginv_ct = gadget_invert(gadget_dim, &ct_auto_0);
-            let ginv_ct_ntt = ginv_ct.ntt();
-            let w_times_ginv_ct = w * &ginv_ct_ntt;
+            // let ct = from_ntt_alloc(&v[i]);
+            // let ct_auto = automorph_alloc(&ct, t);
+            // let ct_auto_0 = ct_auto.submatrix(0, 0, 1, 1);
+            // let ct_auto_1_ntt = ct_auto.submatrix(1, 0, 1, 1).ntt();
+            // let ginv_ct = gadget_invert_alloc(gadget_dim, &ct_auto_0);
+            // let ginv_ct_ntt = ginv_ct.ntt();
+            // let w_times_ginv_ct = w * &ginv_ct_ntt;
+
+            from_ntt(&mut ct, &v[i]);
+            automorph(&mut ct_auto, &ct, t);
+            gadget_invert_rdim(gi_ct, &ct_auto, 1);
+            to_ntt_no_reduce(gi_ct_ntt, &gi_ct);
+            ct_auto_1.data.as_mut_slice().copy_from_slice(ct_auto.get_poly(1, 0));
+            to_ntt(&mut ct_auto_1_ntt, &ct_auto_1);
+            multiply(&mut w_times_ginv_ct, w, &gi_ct_ntt);
 
             let mut idx = 0;
             for j in 0..2 {
@@ -71,7 +94,7 @@ mod test {
     use super::*;
 
     fn get_params() -> Params {
-        get_short_keygen_params()
+        get_expansion_testing_params()
     }
 
     #[test]

+ 20 - 0
spiral-rs/src/util.rs

@@ -52,6 +52,26 @@ pub fn get_short_keygen_params() -> Params {
     )
 }
 
+pub fn get_expansion_testing_params() -> Params { // Returns the `Params` parsed from the fixed configuration embedded below.
+    let cfg = r#"
+        {'n': 2,
+        'nu_1': 9,
+        'nu_2': 6,
+        'p': 256,
+        'q_prime_bits': 20,
+        's_e': 87.62938774292914,
+        't_GSW': 8,
+        't_conv': 4,
+        't_exp': 8,
+        't_exp_right': 56,
+        'instances': 1,
+        'db_item_size': 256 }
+    "#;
+    let cfg = cfg.replace("'", "\""); // the literal uses single quotes; convert to double quotes before passing it to params_from_json
+    let b = params_from_json(&cfg);
+    b
+}
+
 pub fn get_seed() -> [u8; 32] {
     [
         1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6,