Browse Source

working ntt

Samir Menon 2 years ago
parent
commit
62283c0892
11 changed files with 931 additions and 27 deletions
  1. 45 0
      .vscode/launch.json
  2. 584 0
      Cargo.lock
  3. 8 1
      Cargo.toml
  4. 23 0
      benches/ntt.rs
  5. 35 3
      src/arith.rs
  6. 6 0
      src/lib.rs
  7. 3 8
      src/main.rs
  8. 208 11
      src/ntt.rs
  9. 4 4
      src/number_theory.rs
  10. 6 0
      src/params.rs
  11. 9 0
      src/util.rs

+ 45 - 0
.vscode/launch.json

@@ -0,0 +1,45 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "type": "lldb",
+            "request": "launch",
+            "name": "Debug executable 'spiral-rs'",
+            "cargo": {
+                "args": [
+                    "build",
+                    "--bin=spiral-rs",
+                    "--package=spiral-rs"
+                ],
+                "filter": {
+                    "name": "spiral-rs",
+                    "kind": "bin"
+                }
+            },
+            "args": [],
+            "cwd": "${workspaceFolder}"
+        },
+        {
+            "type": "lldb",
+            "request": "launch",
+            "name": "Debug unit tests in executable 'spiral-rs'",
+            "cargo": {
+                "args": [
+                    "test",
+                    "--no-run",
+                    "--bin=spiral-rs",
+                    "--package=spiral-rs"
+                ],
+                "filter": {
+                    "name": "spiral-rs",
+                    "kind": "bin"
+                }
+            },
+            "args": [],
+            "cwd": "${workspaceFolder}"
+        }
+    ]
+}

+ 584 - 0
Cargo.lock

@@ -2,12 +2,181 @@
 # It is not intended for manual editing.
 version = 3
 
+[[package]]
+name = "atty"
+version = "0.2.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
+dependencies = [
+ "hermit-abi",
+ "libc",
+ "winapi",
+]
+
+[[package]]
+name = "autocfg"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
+
+[[package]]
+name = "bitflags"
+version = "1.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
+
+[[package]]
+name = "bstr"
+version = "0.2.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ba3569f383e8f1598449f1a423e72e99569137b47740b1da11ef19af3d5c3223"
+dependencies = [
+ "lazy_static",
+ "memchr",
+ "regex-automata",
+ "serde",
+]
+
+[[package]]
+name = "bumpalo"
+version = "3.9.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a4a45a46ab1f2412e53d3a0ade76ffad2025804294569aae387231a0cd6e0899"
+
+[[package]]
+name = "cast"
+version = "0.2.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4c24dab4283a142afa2fdca129b80ad2c6284e073930f964c3a1293c225ee39a"
+dependencies = [
+ "rustc_version",
+]
+
 [[package]]
 name = "cfg-if"
 version = "1.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
 
+[[package]]
+name = "clap"
+version = "2.34.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c"
+dependencies = [
+ "bitflags",
+ "textwrap",
+ "unicode-width",
+]
+
+[[package]]
+name = "criterion"
+version = "0.3.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1604dafd25fba2fe2d5895a9da139f8dc9b319a5fe5354ca137cbbce4e178d10"
+dependencies = [
+ "atty",
+ "cast",
+ "clap",
+ "criterion-plot",
+ "csv",
+ "itertools",
+ "lazy_static",
+ "num-traits",
+ "oorandom",
+ "plotters",
+ "rayon",
+ "regex",
+ "serde",
+ "serde_cbor",
+ "serde_derive",
+ "serde_json",
+ "tinytemplate",
+ "walkdir",
+]
+
+[[package]]
+name = "criterion-plot"
+version = "0.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d00996de9f2f7559f7f4dc286073197f83e92256a59ed395f9aac01fe717da57"
+dependencies = [
+ "cast",
+ "itertools",
+]
+
+[[package]]
+name = "crossbeam-channel"
+version = "0.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e54ea8bc3fb1ee042f5aace6e3c6e025d3874866da222930f70ce62aceba0bfa"
+dependencies = [
+ "cfg-if",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-deque"
+version = "0.8.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6455c0ca19f0d2fbf751b908d5c55c1f5cbc65e03c4225427254b46890bdde1e"
+dependencies = [
+ "cfg-if",
+ "crossbeam-epoch",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-epoch"
+version = "0.9.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c00d6d2ea26e8b151d99093005cb442fb9a37aeaca582a03ec70946f49ab5ed9"
+dependencies = [
+ "cfg-if",
+ "crossbeam-utils",
+ "lazy_static",
+ "memoffset",
+ "scopeguard",
+]
+
+[[package]]
+name = "crossbeam-utils"
+version = "0.8.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b5e5bed1f1c269533fa816a0a5492b3545209a205ca1a54842be180eb63a16a6"
+dependencies = [
+ "cfg-if",
+ "lazy_static",
+]
+
+[[package]]
+name = "csv"
+version = "1.1.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "22813a6dc45b335f9bade10bf7271dc477e81113e89eb251a0bc2a8a81c536e1"
+dependencies = [
+ "bstr",
+ "csv-core",
+ "itoa 0.4.8",
+ "ryu",
+ "serde",
+]
+
+[[package]]
+name = "csv-core"
+version = "0.1.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2b2466559f260f48ad25fe6317b3c8dac77b5bdb5763ac7d9d6103530663bc90"
+dependencies = [
+ "memchr",
+]
+
+[[package]]
+name = "either"
+version = "1.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457"
+
 [[package]]
 name = "getrandom"
 version = "0.2.4"
@@ -19,18 +188,164 @@ dependencies = [
  "wasi",
 ]
 
+[[package]]
+name = "half"
+version = "1.8.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7"
+
+[[package]]
+name = "hermit-abi"
+version = "0.1.19"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
+dependencies = [
+ "libc",
+]
+
+[[package]]
+name = "itertools"
+version = "0.10.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a9a9d19fa1e79b6215ff29b9d6880b706147f16e9b1dbb1e4e5947b5b02bc5e3"
+dependencies = [
+ "either",
+]
+
+[[package]]
+name = "itoa"
+version = "0.4.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b71991ff56294aa922b450139ee08b3bfc70982c6b2c7562771375cf73542dd4"
+
+[[package]]
+name = "itoa"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1aab8fc367588b89dcee83ab0fd66b72b50b72fa1904d7095045ace2b0c81c35"
+
+[[package]]
+name = "js-sys"
+version = "0.3.56"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a38fc24e30fd564ce974c02bf1d337caddff65be6cc4735a1f7eab22a7440f04"
+dependencies = [
+ "wasm-bindgen",
+]
+
+[[package]]
+name = "lazy_static"
+version = "1.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
+
 [[package]]
 name = "libc"
 version = "0.2.119"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "1bf2e165bb3457c8e098ea76f3e3bc9db55f87aa90d52d0e6be741470916aaa4"
 
+[[package]]
+name = "log"
+version = "0.4.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710"
+dependencies = [
+ "cfg-if",
+]
+
+[[package]]
+name = "memchr"
+version = "2.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a"
+
+[[package]]
+name = "memoffset"
+version = "0.6.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce"
+dependencies = [
+ "autocfg",
+]
+
+[[package]]
+name = "num-traits"
+version = "0.2.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290"
+dependencies = [
+ "autocfg",
+]
+
+[[package]]
+name = "num_cpus"
+version = "1.13.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "19e64526ebdee182341572e50e9ad03965aa510cd94427a4549448f285e957a1"
+dependencies = [
+ "hermit-abi",
+ "libc",
+]
+
+[[package]]
+name = "oorandom"
+version = "11.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575"
+
+[[package]]
+name = "plotters"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "32a3fd9ec30b9749ce28cd91f255d569591cdf937fe280c312143e3c4bad6f2a"
+dependencies = [
+ "num-traits",
+ "plotters-backend",
+ "plotters-svg",
+ "wasm-bindgen",
+ "web-sys",
+]
+
+[[package]]
+name = "plotters-backend"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d88417318da0eaf0fdcdb51a0ee6c3bed624333bff8f946733049380be67ac1c"
+
+[[package]]
+name = "plotters-svg"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "521fa9638fa597e1dc53e9412a4f9cefb01187ee1f7413076f9e6749e2885ba9"
+dependencies = [
+ "plotters-backend",
+]
+
 [[package]]
 name = "ppv-lite86"
 version = "0.2.16"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
 
+[[package]]
+name = "proc-macro2"
+version = "1.0.36"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c7342d5883fbccae1cc37a2353b09c87c9b0f3afd73f5fb9bba687a1f733b029"
+dependencies = [
+ "unicode-xid",
+]
+
+[[package]]
+name = "quote"
+version = "1.0.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "864d3e96a899863136fc6e99f3d7cae289dafe43bf2c5ac19b70df7210c0a145"
+dependencies = [
+ "proc-macro2",
+]
+
 [[package]]
 name = "rand"
 version = "0.8.5"
@@ -61,15 +376,284 @@ dependencies = [
  "getrandom",
 ]
 
+[[package]]
+name = "rayon"
+version = "1.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c06aca804d41dbc8ba42dfd964f0d01334eceb64314b9ecf7c5fad5188a06d90"
+dependencies = [
+ "autocfg",
+ "crossbeam-deque",
+ "either",
+ "rayon-core",
+]
+
+[[package]]
+name = "rayon-core"
+version = "1.9.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d78120e2c850279833f1dd3582f730c4ab53ed95aeaaaa862a2a5c71b1656d8e"
+dependencies = [
+ "crossbeam-channel",
+ "crossbeam-deque",
+ "crossbeam-utils",
+ "lazy_static",
+ "num_cpus",
+]
+
+[[package]]
+name = "regex"
+version = "1.5.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d07a8629359eb56f1e2fb1652bb04212c072a87ba68546a04065d525673ac461"
+dependencies = [
+ "regex-syntax",
+]
+
+[[package]]
+name = "regex-automata"
+version = "0.1.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
+
+[[package]]
+name = "regex-syntax"
+version = "0.6.25"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
+
+[[package]]
+name = "rustc_version"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366"
+dependencies = [
+ "semver",
+]
+
+[[package]]
+name = "ryu"
+version = "1.0.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "73b4b750c782965c211b42f022f59af1fbceabdd026623714f104152f1ec149f"
+
+[[package]]
+name = "same-file"
+version = "1.0.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
+dependencies = [
+ "winapi-util",
+]
+
+[[package]]
+name = "scopeguard"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
+
+[[package]]
+name = "semver"
+version = "1.0.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a4a3381e03edd24287172047536f20cabde766e2cd3e65e6b00fb3af51c4f38d"
+
+[[package]]
+name = "serde"
+version = "1.0.136"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ce31e24b01e1e524df96f1c2fdd054405f8d7376249a5110886fb4b658484789"
+
+[[package]]
+name = "serde_cbor"
+version = "0.11.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5"
+dependencies = [
+ "half",
+ "serde",
+]
+
+[[package]]
+name = "serde_derive"
+version = "1.0.136"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "08597e7152fcd306f41838ed3e37be9eaeed2b61c42e2117266a554fab4662f9"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "serde_json"
+version = "1.0.79"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8e8d9fa5c3b304765ce1fd9c4c8a3de2c8db365a5b91be52f186efc675681d95"
+dependencies = [
+ "itoa 1.0.1",
+ "ryu",
+ "serde",
+]
+
 [[package]]
 name = "spiral-rs"
 version = "0.1.0"
 dependencies = [
+ "criterion",
  "rand",
 ]
 
+[[package]]
+name = "syn"
+version = "1.0.86"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8a65b3f4ffa0092e9887669db0eae07941f023991ab58ea44da8fe8e2d511c6b"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "unicode-xid",
+]
+
+[[package]]
+name = "textwrap"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
+dependencies = [
+ "unicode-width",
+]
+
+[[package]]
+name = "tinytemplate"
+version = "1.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
+dependencies = [
+ "serde",
+ "serde_json",
+]
+
+[[package]]
+name = "unicode-width"
+version = "0.1.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3ed742d4ea2bd1176e236172c8429aaf54486e7ac098db29ffe6529e0ce50973"
+
+[[package]]
+name = "unicode-xid"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3"
+
+[[package]]
+name = "walkdir"
+version = "2.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56"
+dependencies = [
+ "same-file",
+ "winapi",
+ "winapi-util",
+]
+
 [[package]]
 name = "wasi"
 version = "0.10.2+wasi-snapshot-preview1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6"
+
+[[package]]
+name = "wasm-bindgen"
+version = "0.2.79"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "25f1af7423d8588a3d840681122e72e6a24ddbcb3f0ec385cac0d12d24256c06"
+dependencies = [
+ "cfg-if",
+ "wasm-bindgen-macro",
+]
+
+[[package]]
+name = "wasm-bindgen-backend"
+version = "0.2.79"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8b21c0df030f5a177f3cba22e9bc4322695ec43e7257d865302900290bcdedca"
+dependencies = [
+ "bumpalo",
+ "lazy_static",
+ "log",
+ "proc-macro2",
+ "quote",
+ "syn",
+ "wasm-bindgen-shared",
+]
+
+[[package]]
+name = "wasm-bindgen-macro"
+version = "0.2.79"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2f4203d69e40a52ee523b2529a773d5ffc1dc0071801c87b3d270b471b80ed01"
+dependencies = [
+ "quote",
+ "wasm-bindgen-macro-support",
+]
+
+[[package]]
+name = "wasm-bindgen-macro-support"
+version = "0.2.79"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bfa8a30d46208db204854cadbb5d4baf5fcf8071ba5bf48190c3e59937962ebc"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+ "wasm-bindgen-backend",
+ "wasm-bindgen-shared",
+]
+
+[[package]]
+name = "wasm-bindgen-shared"
+version = "0.2.79"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3d958d035c4438e28c70e4321a2911302f10135ce78a9c7834c0cab4123d06a2"
+
+[[package]]
+name = "web-sys"
+version = "0.3.56"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c060b319f29dd25724f09a2ba1418f142f539b2be99fbf4d2d5a8f7330afb8eb"
+dependencies = [
+ "js-sys",
+ "wasm-bindgen",
+]
+
+[[package]]
+name = "winapi"
+version = "0.3.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
+dependencies = [
+ "winapi-i686-pc-windows-gnu",
+ "winapi-x86_64-pc-windows-gnu",
+]
+
+[[package]]
+name = "winapi-i686-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
+
+[[package]]
+name = "winapi-util"
+version = "0.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178"
+dependencies = [
+ "winapi",
+]
+
+[[package]]
+name = "winapi-x86_64-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"

+ 8 - 1
Cargo.toml

@@ -6,4 +6,11 @@ edition = "2021"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-rand = "0.8.5"
+rand = "0.8.5"
+
+[dev-dependencies]
+criterion = "0.3"
+
+[[bench]]
+name = "ntt"
+harness = false

+ 23 - 0
benches/ntt.rs

@@ -0,0 +1,23 @@
+use spiral_rs::ntt::*;
+use spiral_rs::params::*;
+use spiral_rs::util::*;
+use rand::Rng;
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let params = Params::init(2048, vec![268369921u64, 249561089u64]);
+    let mut v1 = vec![0; params.crt_count * params.poly_len];
+    let mut rng = rand::thread_rng();
+    for i in 0..params.crt_count {
+        for j in 0..params.poly_len {
+            let idx = calc_index(&[i, j], &[params.crt_count, params.poly_len]);
+            let val: u64 = rng.gen();
+            v1[idx] = val % params.moduli[i];
+        }
+    }
+    c.bench_function("nttf 2048", |b| b.iter(|| ntt_forward(black_box(&params), black_box(&mut v1))));
+    c.bench_function("ntti 2048", |b| b.iter(|| ntt_inverse(black_box(&params), black_box(&mut v1))));
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);

+ 35 - 3
src/arith.rs

@@ -31,11 +31,11 @@ pub fn exponentiate_uint_mod(operand: u64, mut exponent: u64, modulus: u64) -> u
     }
 
     let mut power = operand;
-    let mut product = 0u64;
-    let mut intermediate = 0u64;
+    let mut product;
+    let mut intermediate = 1u64;
 
     loop {
-        if (exponent & 1) == 1 {
+        if (exponent % 2) == 1 {
             product = multiply_uint_mod(power, intermediate, modulus);
             mem::swap(&mut product, &mut intermediate);
         }
@@ -48,3 +48,35 @@ pub fn exponentiate_uint_mod(operand: u64, mut exponent: u64, modulus: u64) -> u
     }
     intermediate
 }
+
+pub fn reverse_bits(x: u64, bit_count: usize) -> u64 {
+    if bit_count == 0 {
+        return 0;
+    }
+
+    let r = x.reverse_bits();
+    r >> (mem::size_of::<u64>() * 8 - bit_count)
+}
+
+pub fn div2_uint_mod(operand: u64, modulus: u64) -> u64 {
+    if operand & 1 == 1 {
+        let res = operand.overflowing_add(modulus);
+        if res.1 {
+            return (res.0 >> 1) | (1u64 << 63);
+        } else {
+            return res.0 >> 1;
+        }
+    } else {
+        return operand >> 1;
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn div2_uint_mod_correct() {
+        assert_eq!(div2_uint_mod(3, 7), 5);
+    }
+}

+ 6 - 0
src/lib.rs

@@ -0,0 +1,6 @@
+pub mod arith;
+pub mod ntt;
+pub mod number_theory;
+pub mod params;
+pub mod poly;
+pub mod util;

+ 3 - 8
src/main.rs

@@ -1,11 +1,6 @@
-mod arith;
-mod ntt;
-mod number_theory;
-mod params;
-mod poly;
-
-use crate::params::*;
-use crate::poly::*;
+use spiral_rs::poly::*;
+use spiral_rs::params::*;
+use spiral_rs::*;
 
 fn main() {
     println!("Hello, world!");

+ 208 - 11
src/ntt.rs

@@ -1,30 +1,227 @@
-use std::usize;
+use crate::{
+    arith::*,
+    number_theory::*,
+    params::*,
+    poly::*,
+    util::*,
+};
 
-use crate::{number_theory::*, params::*, poly::*};
-use rand::Rng;
+pub fn powers_of_primitive_root(root: u64, modulus: u64, poly_len_log2: usize) -> Vec<u64> {
+    let poly_len = 1usize << poly_len_log2;
+    let mut root_powers = vec![0u64; poly_len];
+    let mut power = root;
+    for i in 1..poly_len {
+        let idx = reverse_bits(i as u64, poly_len_log2) as usize;
+        root_powers[idx] = power;
+        power = multiply_uint_mod(power, root, modulus);
+    }
+    root_powers[0] = 1;
+    root_powers
+}
+
+pub fn scale_powers_u64(modulus: u64, poly_len: usize, inp: &[u64]) -> Vec<u64> {
+    let mut scaled_powers = vec![0; poly_len];
+    for i in 0..poly_len {
+        let wide_val = (inp[i] as u128) << 64u128;
+        let quotient = wide_val / (modulus as u128);
+        scaled_powers[i] = quotient as u64;
+    }
+    scaled_powers
+}
+
+pub fn scale_powers_u32(modulus: u32, poly_len: usize, inp: &[u64]) -> Vec<u64> {
+    let mut scaled_powers = vec![0; poly_len];
+    for i in 0..poly_len {
+        let wide_val = inp[i] << 32;
+        let quotient = wide_val / (modulus as u64);
+        scaled_powers[i] = (quotient as u32) as u64;
+    }
+    scaled_powers
+}
 
 pub fn build_ntt_tables(poly_len: usize, moduli: &[u64]) -> Vec<Vec<Vec<u64>>> {
-    let mut v: Vec<Vec<Vec<u64>>> = Vec::new();
+    let poly_len_log2 = log2(poly_len as u64) as usize;
+    let mut output: Vec<Vec<Vec<u64>>> = vec![Vec::new(); moduli.len()];
     for coeff_mod in 0..moduli.len() {
         let modulus = moduli[coeff_mod];
+        let modulus_as_u32 = modulus.try_into().unwrap();
         let root = get_minimal_primitive_root(2 * poly_len as u64, modulus).unwrap();
-        let inv_root = invert_uint_mod(root, modulus);
+        let inv_root = invert_uint_mod(root, modulus).unwrap();
+
+        let root_powers = powers_of_primitive_root(root, modulus, poly_len_log2);
+        let scaled_root_powers = scale_powers_u32(modulus_as_u32, poly_len, root_powers.as_slice());
+        let mut inv_root_powers = powers_of_primitive_root(inv_root, modulus, poly_len_log2);        
+        for i in 0..poly_len {
+            inv_root_powers[i] = div2_uint_mod(inv_root_powers[i], modulus);
+        }
+        let mut scaled_inv_root_powers =
+            scale_powers_u32(modulus_as_u32, poly_len, inv_root_powers.as_slice());
+
+        output[coeff_mod] = vec![
+            root_powers,
+            scaled_root_powers,
+            inv_root_powers,
+            scaled_inv_root_powers,
+        ];
     }
-    v
+    output
 }
 
-pub fn ntt_forward(params: Params, out: &mut PolyMatrixRaw, inp: &PolyMatrixRaw) {
+pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
+    let log_n = params.poly_len_log2;
+    let n = 1 << log_n;
+
     for coeff_mod in 0..params.crt_count {
-        let mut n = params.poly_len;
+        let operand = &mut operand_overall[coeff_mod*n..coeff_mod*n+n];
+
+        let forward_table = params.get_ntt_forward_table(coeff_mod);
+        let forward_table_prime = params.get_ntt_forward_prime_table(coeff_mod);
+        let modulus_small = params.moduli[coeff_mod] as u32;
+        let two_times_modulus_small: u32 = 2 * modulus_small;
 
-        for mm in 0..params.poly_len_log2 {
+        for mm in 0..log_n {
             let m = 1 << mm;
             let t = n >> (mm + 1);
 
+            let mut it = operand.chunks_exact_mut(2 * t);
+
             for i in 0..m {
-                let w = params.get_ntt_forward_table(coeff_mod);
-                let wprime = params.get_ntt_forward_prime_table(coeff_mod);
+                let w = forward_table[m+i];
+                let w_prime = forward_table_prime[m+i];
+                
+                let op = it.next().unwrap();
+
+                for j in 0..t {
+                    let x: u32 = op[j] as u32;
+                    let y: u32 = op[t + j] as u32;
+
+                    let currX: u32 = x - (two_times_modulus_small * ((x >= two_times_modulus_small) as u32));
+
+                    let Q: u64 = ((y as u64) * (w_prime as u64)) >> 32u64;
+
+                    let new_Q = w * (y as u64) - Q * (modulus_small as u64);
+
+                    op[j] = currX as u64 + new_Q;
+                    op[t + j] = currX as u64 +  ((two_times_modulus_small as u64) - new_Q);
+                }
+            }
+        }
+
+        for i in 0..n {
+            operand[i] -= ((operand[i] >= two_times_modulus_small as u64) as u64) * two_times_modulus_small as u64;
+            operand[i] -= ((operand[i] >= modulus_small as u64) as u64) * modulus_small as u64;
+        }
+    }
+}
+
+pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
+    for coeff_mod in 0..params.crt_count {
+        let mut n = params.poly_len;
+
+        let operand = &mut operand_overall[coeff_mod*n..coeff_mod*n+n];
+
+        let inverse_table = params.get_ntt_inverse_table(coeff_mod);
+        let inverse_table_prime = params.get_ntt_inverse_prime_table(coeff_mod);
+        let modulus = params.moduli[coeff_mod];
+        let two_times_modulus: u64 = 2 * modulus;
+
+        for mm in (0..params.poly_len_log2).rev() {
+            let h = 1 << mm;
+            let t = n >> (mm + 1);
+
+            for i in 0..h {
+                let w = inverse_table[h+i];
+                let w_prime = inverse_table_prime[h+i];
+
+                for j in 0..t {
+                    let x = operand[2 * i * t + j];
+                    let y = operand[2 * i * t + t + j];
+                    
+                    let T = two_times_modulus - y + x;
+                    let currU = x + y - (two_times_modulus * (((x << 1) >= T) as u64));
+
+                    let resX= (currU + (modulus * ((T & 1) as u64))) >> 1;
+                    let H = (T * w_prime) >> 32;
+
+                    operand[2 * i * t + j] = resX;
+                    operand[2 * i * t + t + j] = w * T - H * modulus;
+                }
             }
         }
+
+        for i in 0..n {
+            operand[i] -= ((operand[i] >= two_times_modulus) as u64) * two_times_modulus;
+            operand[i] -= ((operand[i] >= modulus) as u64) * modulus;
+        }
     }
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use rand::Rng;
+
+    fn get_params() -> Params {
+        Params::init(2048, vec![268369921u64, 249561089u64])
+    }
+
+    const REF_VAL: u64 = 519370102;
+
+    #[test]
+    fn build_ntt_tables_correct() {
+        let moduli = [268369921u64, 249561089u64];
+        let poly_len = 2048usize;
+        let res = build_ntt_tables(poly_len, moduli.as_slice());
+        assert_eq!(res.len(), 2);
+        assert_eq!(res[0].len(), 4);
+        assert_eq!(res[0][0].len(), poly_len);
+        assert_eq!(res[0][2][0], 134184961u64);
+        assert_eq!(res[0][2][1], 96647580u64);
+        let mut x1 = 0u64;
+        for i in 0..res.len() {
+            for j in 0..res[0].len() {
+                for k in 0..res[0][0].len() {
+                    x1 ^= res[i][j][k];
+                }
+            }
+        }
+        assert_eq!(x1, REF_VAL);
+    }
+
+    #[test]
+    fn ntt_forward_correct() {
+        let params = get_params();
+        let mut v1 = vec![0; 2*2048];
+        v1[0] = 100;
+        v1[2048] = 100;
+        ntt_forward(&params, v1.as_mut_slice());
+        assert_eq!(v1[50], 100);
+        assert_eq!(v1[2048 + 50], 100);
+    }
+
+    #[test]
+    fn ntt_correct() {
+        let params = get_params();
+        let mut v1 = vec![0; params.crt_count * params.poly_len];
+        let mut rng = rand::thread_rng();
+        for i in 0..params.crt_count {
+            for j in 0..params.poly_len {
+                let idx = calc_index(&[i, j], &[params.crt_count, params.poly_len]);
+                let val: u64 = rng.gen();
+                v1[idx] = val % params.moduli[i];
+            }
+        }
+        let mut v2 = v1.clone();
+        ntt_forward(&params, v2.as_mut_slice());
+        ntt_inverse(&params, v2.as_mut_slice());
+        for i in 0..params.crt_count*params.poly_len {
+            assert_eq!(v1[i], v2[i]);
+        }
+    }
+
+    #[test]
+    fn calc_index_correct() {
+        assert_eq!(calc_index(&[2, 3, 4], &[10, 10, 100]), 2304);
+        assert_eq!(calc_index(&[2, 3, 4], &[3, 5, 7]), 95);
+    }
+}

+ 4 - 4
src/number_theory.rs

@@ -14,7 +14,7 @@ pub fn is_primitive_root(root: u64, degree: u64, modulus: u64) -> bool {
 pub fn get_primitive_root(degree: u64, modulus: u64) -> Option<u64> {
     assert!(modulus > 1);
     assert!(degree >= 2);
-    let size_entire_group = degree - 1;
+    let size_entire_group = modulus - 1;
     let size_quotient_group = size_entire_group / degree;
     if size_entire_group - size_quotient_group * degree != 0 {
         return None;
@@ -30,7 +30,7 @@ pub fn get_primitive_root(degree: u64, modulus: u64) -> Option<u64> {
         if is_primitive_root(root, degree, modulus) {
             break;
         }
-        if trial != ATTEMPT_MAX - 1 {
+        if trial == ATTEMPT_MAX - 1 {
             return None;
         }
     }
@@ -51,7 +51,7 @@ pub fn get_minimal_primitive_root(degree: u64, modulus: u64) -> Option<u64> {
         current_generator = multiply_uint_mod(current_generator, generator_sq, modulus);
     }
 
-    Some(current_generator)
+    Some(root)
 }
 
 pub fn extended_gcd(mut x: u64, mut y: u64) -> (u64, i64, i64) {
@@ -89,7 +89,7 @@ pub fn invert_uint_mod(value: u64, modulus: u64) -> Option<u64> {
     if gcd_tuple.0 != 1 {
         return None;
     } else if gcd_tuple.1 < 0 {
-        return Some(gcd_tuple.1 as u64 + modulus);
+        return Some((gcd_tuple.1 as u64).overflowing_add(modulus).0);
     } else {
         return Some(gcd_tuple.1 as u64);
     }

+ 6 - 0
src/params.rs

@@ -18,6 +18,12 @@ impl Params {
     pub fn get_ntt_forward_prime_table(&self, i: usize) -> &[u64] {
         self.ntt_tables[i][1].as_slice()
     }
+    pub fn get_ntt_inverse_table(&self, i: usize) -> &[u64] {
+        self.ntt_tables[i][2].as_slice()
+    }
+    pub fn get_ntt_inverse_prime_table(&self, i: usize) -> &[u64] {
+        self.ntt_tables[i][3].as_slice()
+    }
 
     pub fn init(poly_len: usize, moduli: Vec<u64>) -> Self {
         let poly_len_log2 = log2(poly_len as u64) as usize;

+ 9 - 0
src/util.rs

@@ -0,0 +1,9 @@
+pub fn calc_index(indices: &[usize], lengths: &[usize]) -> usize {
+    let mut idx = 0usize;
+    let mut prod = 1usize;
+    for i in (0..indices.len()).rev() {
+        idx += indices[i] * prod;
+        prod *= lengths[i];
+    }
+    idx
+}