poly.rs 22 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743
  1. #[cfg(target_feature = "avx2")]
  2. use std::arch::x86_64::*;
  3. use rand::distributions::Standard;
  4. use rand::Rng;
  5. use std::cell::RefCell;
  6. use std::ops::{Add, Mul, Neg};
  7. use crate::{aligned_memory::*, arith::*, discrete_gaussian::*, ntt::*, params::*, util::*};
  /// Capacity, in `u64` words, of each thread's scratch buffer; `from_ntt` copies
  /// one NTT-form polynomial into it before inverse-transforming.
  const SCRATCH_SPACE: usize = 1 << 13;
  9. thread_local!(static SCRATCH: RefCell<AlignedMemory64> = RefCell::new(AlignedMemory64::new(SCRATCH_SPACE)));
  10. pub trait PolyMatrix<'a> {
  11. fn is_ntt(&self) -> bool;
  12. fn get_rows(&self) -> usize;
  13. fn get_cols(&self) -> usize;
  14. fn get_params(&self) -> &Params;
  15. fn num_words(&self) -> usize;
  16. fn zero(params: &'a Params, rows: usize, cols: usize) -> Self;
  17. fn random(params: &'a Params, rows: usize, cols: usize) -> Self;
  18. fn random_rng<T: Rng>(params: &'a Params, rows: usize, cols: usize, rng: &mut T) -> Self;
  19. fn as_slice(&self) -> &[u64];
  20. fn as_mut_slice(&mut self) -> &mut [u64];
  21. fn zero_out(&mut self) {
  22. for item in self.as_mut_slice() {
  23. *item = 0;
  24. }
  25. }
  26. fn get_poly(&self, row: usize, col: usize) -> &[u64] {
  27. let num_words = self.num_words();
  28. let start = (row * self.get_cols() + col) * num_words;
  29. &self.as_slice()[start..start + num_words]
  30. }
  31. fn get_poly_mut(&mut self, row: usize, col: usize) -> &mut [u64] {
  32. let num_words = self.num_words();
  33. let start = (row * self.get_cols() + col) * num_words;
  34. &mut self.as_mut_slice()[start..start + num_words]
  35. }
  36. fn copy_into(&mut self, p: &Self, target_row: usize, target_col: usize) {
  37. assert!(target_row < self.get_rows());
  38. assert!(target_col < self.get_cols());
  39. assert!(target_row + p.get_rows() <= self.get_rows());
  40. assert!(target_col + p.get_cols() <= self.get_cols());
  41. for r in 0..p.get_rows() {
  42. for c in 0..p.get_cols() {
  43. let pol_src = p.get_poly(r, c);
  44. let pol_dst = self.get_poly_mut(target_row + r, target_col + c);
  45. pol_dst.copy_from_slice(pol_src);
  46. }
  47. }
  48. }
  49. fn submatrix(&self, target_row: usize, target_col: usize, rows: usize, cols: usize) -> Self;
  50. fn pad_top(&self, pad_rows: usize) -> Self;
  51. }
  52. pub struct PolyMatrixRaw<'a> {
  53. pub params: &'a Params,
  54. pub rows: usize,
  55. pub cols: usize,
  56. pub data: AlignedMemory64,
  57. }
  58. pub struct PolyMatrixNTT<'a> {
  59. pub params: &'a Params,
  60. pub rows: usize,
  61. pub cols: usize,
  62. pub data: AlignedMemory64,
  63. }
  64. impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
  65. fn is_ntt(&self) -> bool {
  66. false
  67. }
  68. fn get_rows(&self) -> usize {
  69. self.rows
  70. }
  71. fn get_cols(&self) -> usize {
  72. self.cols
  73. }
  74. fn get_params(&self) -> &Params {
  75. &self.params
  76. }
  77. fn as_slice(&self) -> &[u64] {
  78. self.data.as_slice()
  79. }
  80. fn as_mut_slice(&mut self) -> &mut [u64] {
  81. self.data.as_mut_slice()
  82. }
  83. fn num_words(&self) -> usize {
  84. self.params.poly_len
  85. }
  86. fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
  87. let num_coeffs = rows * cols * params.poly_len;
  88. let data = AlignedMemory64::new(num_coeffs);
  89. PolyMatrixRaw {
  90. params,
  91. rows,
  92. cols,
  93. data,
  94. }
  95. }
  96. fn random_rng<T: Rng>(params: &'a Params, rows: usize, cols: usize, rng: &mut T) -> Self {
  97. let mut iter = rng.sample_iter(&Standard);
  98. let mut out = PolyMatrixRaw::zero(params, rows, cols);
  99. for r in 0..rows {
  100. for c in 0..cols {
  101. for i in 0..params.poly_len {
  102. let val: u64 = iter.next().unwrap();
  103. out.get_poly_mut(r, c)[i] = val % params.modulus;
  104. }
  105. }
  106. }
  107. out
  108. }
  109. fn random(params: &'a Params, rows: usize, cols: usize) -> Self {
  110. let mut rng = rand::thread_rng();
  111. Self::random_rng(params, rows, cols, &mut rng)
  112. }
  113. fn pad_top(&self, pad_rows: usize) -> Self {
  114. let mut padded = Self::zero(self.params, self.rows + pad_rows, self.cols);
  115. padded.copy_into(&self, pad_rows, 0);
  116. padded
  117. }
  118. fn submatrix(&self, target_row: usize, target_col: usize, rows: usize, cols: usize) -> Self {
  119. let mut m = Self::zero(self.params, rows, cols);
  120. assert!(target_row < self.rows);
  121. assert!(target_col < self.cols);
  122. assert!(target_row + rows <= self.rows);
  123. assert!(target_col + cols <= self.cols);
  124. for r in 0..rows {
  125. for c in 0..cols {
  126. let pol_src = self.get_poly(target_row + r, target_col + c);
  127. let pol_dst = m.get_poly_mut(r, c);
  128. pol_dst.copy_from_slice(pol_src);
  129. }
  130. }
  131. m
  132. }
  133. }
  134. impl<'a> Clone for PolyMatrixRaw<'a> {
  135. fn clone(&self) -> Self {
  136. let mut data_clone = AlignedMemory64::new(self.data.len());
  137. data_clone
  138. .as_mut_slice()
  139. .copy_from_slice(self.data.as_slice());
  140. PolyMatrixRaw {
  141. params: self.params,
  142. rows: self.rows,
  143. cols: self.cols,
  144. data: data_clone,
  145. }
  146. }
  147. }
  148. impl<'a> PolyMatrixRaw<'a> {
  149. pub fn identity(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
  150. let num_coeffs = rows * cols * params.poly_len;
  151. let mut data = AlignedMemory::new(num_coeffs);
  152. for r in 0..rows {
  153. let c = r;
  154. let idx = r * cols * params.poly_len + c * params.poly_len;
  155. data[idx] = 1;
  156. }
  157. PolyMatrixRaw {
  158. params,
  159. rows,
  160. cols,
  161. data,
  162. }
  163. }
  164. pub fn noise<T: Rng + Send>(
  165. params: &'a Params,
  166. rows: usize,
  167. cols: usize,
  168. dg: &DiscreteGaussian<T>,
  169. ) -> Self {
  170. let mut out = PolyMatrixRaw::zero(params, rows, cols);
  171. dg.sample_matrix(&mut out);
  172. out
  173. }
  174. pub fn ntt(&self) -> PolyMatrixNTT<'a> {
  175. to_ntt_alloc(&self)
  176. }
  177. pub fn reduce_mod(&mut self, modulus: u64) {
  178. for r in 0..self.rows {
  179. for c in 0..self.cols {
  180. for z in 0..self.params.poly_len {
  181. self.get_poly_mut(r, c)[z] %= modulus;
  182. }
  183. }
  184. }
  185. }
  186. pub fn apply_func<F: Fn(u64) -> u64>(&mut self, func: F) {
  187. for r in 0..self.rows {
  188. for c in 0..self.cols {
  189. let pol_mut = self.get_poly_mut(r, c);
  190. for el in pol_mut {
  191. *el = func(*el);
  192. }
  193. }
  194. }
  195. }
  196. pub fn to_vec(&self, modulus_bits: usize, num_coeffs: usize) -> Vec<u8> {
  197. let sz_bits = self.rows * self.cols * num_coeffs * modulus_bits;
  198. let sz_bytes = f64::ceil((sz_bits as f64) / 8f64) as usize + 32;
  199. let sz_bytes_roundup_16 = ((sz_bytes + 15) / 16) * 16;
  200. let mut data = vec![0u8; sz_bytes_roundup_16];
  201. let mut bit_offs = 0;
  202. for r in 0..self.rows {
  203. for c in 0..self.cols {
  204. for z in 0..num_coeffs {
  205. write_arbitrary_bits(
  206. data.as_mut_slice(),
  207. self.get_poly(r, c)[z],
  208. bit_offs,
  209. modulus_bits,
  210. );
  211. bit_offs += modulus_bits;
  212. }
  213. // round bit_offs down to nearest byte boundary
  214. bit_offs = (bit_offs / 8) * 8
  215. }
  216. }
  217. data
  218. }
  219. pub fn single_value(params: &'a Params, value: u64) -> PolyMatrixRaw<'a> {
  220. let mut out = Self::zero(params, 1, 1);
  221. out.data[0] = value;
  222. out
  223. }
  224. }
  225. impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
  226. fn is_ntt(&self) -> bool {
  227. true
  228. }
  229. fn get_rows(&self) -> usize {
  230. self.rows
  231. }
  232. fn get_cols(&self) -> usize {
  233. self.cols
  234. }
  235. fn get_params(&self) -> &Params {
  236. &self.params
  237. }
  238. fn as_slice(&self) -> &[u64] {
  239. self.data.as_slice()
  240. }
  241. fn as_mut_slice(&mut self) -> &mut [u64] {
  242. self.data.as_mut_slice()
  243. }
  244. fn num_words(&self) -> usize {
  245. self.params.poly_len * self.params.crt_count
  246. }
  247. fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixNTT<'a> {
  248. let num_coeffs = rows * cols * params.poly_len * params.crt_count;
  249. let data = AlignedMemory::new(num_coeffs);
  250. PolyMatrixNTT {
  251. params,
  252. rows,
  253. cols,
  254. data,
  255. }
  256. }
  257. fn random_rng<T: Rng>(params: &'a Params, rows: usize, cols: usize, rng: &mut T) -> Self {
  258. let mut iter = rng.sample_iter(&Standard);
  259. let mut out = PolyMatrixNTT::zero(params, rows, cols);
  260. for r in 0..rows {
  261. for c in 0..cols {
  262. for i in 0..params.crt_count {
  263. for j in 0..params.poly_len {
  264. let idx = calc_index(&[i, j], &[params.crt_count, params.poly_len]);
  265. let val: u64 = iter.next().unwrap();
  266. out.get_poly_mut(r, c)[idx] = val % params.moduli[i];
  267. }
  268. }
  269. }
  270. }
  271. out
  272. }
  273. fn random(params: &'a Params, rows: usize, cols: usize) -> Self {
  274. let mut rng = rand::thread_rng();
  275. Self::random_rng(params, rows, cols, &mut rng)
  276. }
  277. fn pad_top(&self, pad_rows: usize) -> Self {
  278. let mut padded = Self::zero(self.params, self.rows + pad_rows, self.cols);
  279. padded.copy_into(&self, pad_rows, 0);
  280. padded
  281. }
  282. fn submatrix(&self, target_row: usize, target_col: usize, rows: usize, cols: usize) -> Self {
  283. let mut m = Self::zero(self.params, rows, cols);
  284. assert!(target_row < self.rows);
  285. assert!(target_col < self.cols);
  286. assert!(target_row + rows <= self.rows);
  287. assert!(target_col + cols <= self.cols);
  288. for r in 0..rows {
  289. for c in 0..cols {
  290. let pol_src = self.get_poly(target_row + r, target_col + c);
  291. let pol_dst = m.get_poly_mut(r, c);
  292. pol_dst.copy_from_slice(pol_src);
  293. }
  294. }
  295. m
  296. }
  297. }
  298. impl<'a> Clone for PolyMatrixNTT<'a> {
  299. fn clone(&self) -> Self {
  300. let mut data_clone = AlignedMemory64::new(self.data.len());
  301. data_clone
  302. .as_mut_slice()
  303. .copy_from_slice(self.data.as_slice());
  304. PolyMatrixNTT {
  305. params: self.params,
  306. rows: self.rows,
  307. cols: self.cols,
  308. data: data_clone,
  309. }
  310. }
  311. }
  312. impl<'a> PolyMatrixNTT<'a> {
  313. pub fn raw(&self) -> PolyMatrixRaw<'a> {
  314. from_ntt_alloc(&self)
  315. }
  316. }
  317. pub fn multiply_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
  318. for c in 0..params.crt_count {
  319. for i in 0..params.poly_len {
  320. let idx = c * params.poly_len + i;
  321. res[idx] = multiply_modular(params, a[idx], b[idx], c);
  322. }
  323. }
  324. }
  325. pub fn multiply_add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
  326. for c in 0..params.crt_count {
  327. for i in 0..params.poly_len {
  328. let idx = c * params.poly_len + i;
  329. res[idx] = multiply_add_modular(params, a[idx], b[idx], res[idx], c);
  330. }
  331. }
  332. }
  333. pub fn add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
  334. for c in 0..params.crt_count {
  335. for i in 0..params.poly_len {
  336. let idx = c * params.poly_len + i;
  337. res[idx] = add_modular(params, a[idx], b[idx], c);
  338. }
  339. }
  340. }
  341. pub fn add_poly_into(params: &Params, res: &mut [u64], a: &[u64]) {
  342. for c in 0..params.crt_count {
  343. for i in 0..params.poly_len {
  344. let idx = c * params.poly_len + i;
  345. res[idx] = add_modular(params, res[idx], a[idx], c);
  346. }
  347. }
  348. }
  349. pub fn invert_poly(params: &Params, res: &mut [u64], a: &[u64]) {
  350. for i in 0..params.poly_len {
  351. res[i] = params.modulus - a[i];
  352. }
  353. }
  354. pub fn automorph_poly(params: &Params, res: &mut [u64], a: &[u64], t: usize) {
  355. let poly_len = params.poly_len;
  356. for i in 0..poly_len {
  357. let num = (i * t) / poly_len;
  358. let rem = (i * t) % poly_len;
  359. if num % 2 == 0 {
  360. res[rem] = a[i];
  361. } else {
  362. res[rem] = params.modulus - a[i];
  363. }
  364. }
  365. }
  /// Accumulates `res += a * b` lane-wise with AVX2, four words at a time,
  /// **without** modular reduction. Callers must reduce afterwards; the AVX2
  /// `multiply` calls `modular_reduce` after its accumulation loop.
  ///
  /// `_mm256_mul_epu32` multiplies only the low 32 bits of each 64-bit lane. This
  /// presumably relies on every input word being below 2^32, which holds if each
  /// CRT modulus fits in 32 bits (confirm against `Params`). The 64-bit sums are
  /// not checked for overflow.
  ///
  /// # Panics
  /// Panics if `poly_len` is not a multiple of 4, if any slice holds fewer than
  /// `crt_count * poly_len` words, or if any slice is not 32-byte aligned. Under
  /// those conditions the aligned vector loads/stores would be out of bounds or
  /// misaligned, which is undefined behavior.
  #[cfg(target_feature = "avx2")]
  pub fn multiply_add_poly_avx(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
      let total = params.crt_count * params.poly_len;
      if total == 0 {
          return;
      }
      assert_eq!(
          params.poly_len % 4,
          0,
          "multiply_add_poly_avx: poly_len must be a multiple of 4"
      );
      assert!(
          a.len() >= total && b.len() >= total && res.len() >= total,
          "multiply_add_poly_avx: slices must hold crt_count * poly_len words"
      );
      for &ptr in [a.as_ptr(), b.as_ptr(), res.as_ptr()].iter() {
          assert_eq!(
              ptr as usize % 32,
              0,
              "multiply_add_poly_avx: slices must be 32-byte aligned"
          );
      }
      // The CRT runs are contiguous, so walking 0..total in steps of 4 visits the
      // same words, in the same order, as a per-component loop.
      for offset in (0..total).step_by(4) {
          // SAFETY: `total` is a multiple of 4, so `offset + 4 <= total`, and every
          // slice length was checked to be >= `total`. Each base pointer is 32-byte
          // aligned and `offset` advances by 4 words (32 bytes), so every
          // `_mm256_load_si256` / `_mm256_store_si256` address is aligned. AVX2 is
          // guaranteed by the `cfg(target_feature = "avx2")` gate.
          unsafe {
              let x = _mm256_load_si256(a.as_ptr().add(offset) as *const __m256i);
              let y = _mm256_load_si256(b.as_ptr().add(offset) as *const __m256i);
              let p_z = res.as_mut_ptr().add(offset) as *mut __m256i;
              let z = _mm256_load_si256(p_z as *const __m256i);
              _mm256_store_si256(p_z, _mm256_add_epi64(z, _mm256_mul_epu32(x, y)));
          }
      }
  }
  384. pub fn modular_reduce(params: &Params, res: &mut [u64]) {
  385. for c in 0..params.crt_count {
  386. for i in 0..params.poly_len {
  387. let idx = c * params.poly_len + i;
  388. res[idx] = barrett_coeff_u64(params, res[idx], c);
  389. }
  390. }
  391. }
  392. #[cfg(not(target_feature = "avx2"))]
  393. pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
  394. assert!(res.rows == a.rows);
  395. assert!(res.cols == b.cols);
  396. assert!(a.cols == b.rows);
  397. let params = res.params;
  398. for i in 0..a.rows {
  399. for j in 0..b.cols {
  400. for z in 0..params.poly_len * params.crt_count {
  401. res.get_poly_mut(i, j)[z] = 0;
  402. }
  403. for k in 0..a.cols {
  404. let params = res.params;
  405. let res_poly = res.get_poly_mut(i, j);
  406. let pol1 = a.get_poly(i, k);
  407. let pol2 = b.get_poly(k, j);
  408. multiply_add_poly(params, res_poly, pol1, pol2);
  409. }
  410. }
  411. }
  412. }
  /// Matrix product `res = a * b` of NTT-form matrices (AVX2 build).
  ///
  /// Each output polynomial is cleared, accumulated without reduction by
  /// `multiply_add_poly_avx` over the shared dimension, then reduced once with
  /// `modular_reduce`.
  ///
  /// # Panics
  /// Panics if the shapes are incompatible.
  #[cfg(target_feature = "avx2")]
  pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
      assert_eq!(res.rows, a.rows);
      assert_eq!(res.cols, b.cols);
      assert_eq!(a.cols, b.rows);
      let params = res.params;
      let inner = a.cols;
      for i in 0..a.rows {
          for j in 0..b.cols {
              let acc = res.get_poly_mut(i, j);
              acc.fill(0);
              for k in 0..inner {
                  multiply_add_poly_avx(params, acc, a.get_poly(i, k), b.get_poly(k, j));
              }
              modular_reduce(params, acc);
          }
      }
  }
  434. pub fn add(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
  435. assert!(res.rows == a.rows);
  436. assert!(res.cols == a.cols);
  437. assert!(a.rows == b.rows);
  438. assert!(a.cols == b.cols);
  439. let params = res.params;
  440. for i in 0..a.rows {
  441. for j in 0..a.cols {
  442. let res_poly = res.get_poly_mut(i, j);
  443. let pol1 = a.get_poly(i, j);
  444. let pol2 = b.get_poly(i, j);
  445. add_poly(params, res_poly, pol1, pol2);
  446. }
  447. }
  448. }
  449. pub fn add_into(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT) {
  450. assert!(res.rows == a.rows);
  451. assert!(res.cols == a.cols);
  452. let params = res.params;
  453. for i in 0..res.rows {
  454. for j in 0..res.cols {
  455. let res_poly = res.get_poly_mut(i, j);
  456. let pol2 = a.get_poly(i, j);
  457. add_poly_into(params, res_poly, pol2);
  458. }
  459. }
  460. }
  461. pub fn add_into_at(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, t_row: usize, t_col: usize) {
  462. let params = res.params;
  463. for i in 0..a.rows {
  464. for j in 0..a.cols {
  465. let res_poly = res.get_poly_mut(t_row + i, t_col + j);
  466. let pol2 = a.get_poly(i, j);
  467. add_poly_into(params, res_poly, pol2);
  468. }
  469. }
  470. }
  471. pub fn invert(res: &mut PolyMatrixRaw, a: &PolyMatrixRaw) {
  472. assert!(res.rows == a.rows);
  473. assert!(res.cols == a.cols);
  474. let params = res.params;
  475. for i in 0..a.rows {
  476. for j in 0..a.cols {
  477. let res_poly = res.get_poly_mut(i, j);
  478. let pol1 = a.get_poly(i, j);
  479. invert_poly(params, res_poly, pol1);
  480. }
  481. }
  482. }
  483. pub fn automorph<'a>(res: &mut PolyMatrixRaw<'a>, a: &PolyMatrixRaw<'a>, t: usize) {
  484. assert!(res.rows == a.rows);
  485. assert!(res.cols == a.cols);
  486. let params = res.params;
  487. for i in 0..a.rows {
  488. for j in 0..a.cols {
  489. let res_poly = res.get_poly_mut(i, j);
  490. let pol1 = a.get_poly(i, j);
  491. automorph_poly(params, res_poly, pol1, t);
  492. }
  493. }
  494. }
  495. pub fn automorph_alloc<'a>(a: &PolyMatrixRaw<'a>, t: usize) -> PolyMatrixRaw<'a> {
  496. let mut res = PolyMatrixRaw::zero(a.params, a.rows, a.cols);
  497. automorph(&mut res, a, t);
  498. res
  499. }
  500. pub fn stack<'a>(a: &PolyMatrixRaw<'a>, b: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
  501. assert_eq!(a.cols, b.cols);
  502. let mut c = PolyMatrixRaw::zero(a.params, a.rows + b.rows, a.cols);
  503. c.copy_into(a, 0, 0);
  504. c.copy_into(b, a.rows, 0);
  505. c
  506. }
  507. pub fn scalar_multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
  508. assert_eq!(a.rows, 1);
  509. assert_eq!(a.cols, 1);
  510. let params = res.params;
  511. let pol2 = a.get_poly(0, 0);
  512. for i in 0..b.rows {
  513. for j in 0..b.cols {
  514. let res_poly = res.get_poly_mut(i, j);
  515. let pol1 = b.get_poly(i, j);
  516. multiply_poly(params, res_poly, pol1, pol2);
  517. }
  518. }
  519. }
  520. pub fn scalar_multiply_alloc<'a>(
  521. a: &PolyMatrixNTT<'a>,
  522. b: &PolyMatrixNTT<'a>,
  523. ) -> PolyMatrixNTT<'a> {
  524. let mut res = PolyMatrixNTT::zero(b.params, b.rows, b.cols);
  525. scalar_multiply(&mut res, a, b);
  526. res
  527. }
  528. pub fn single_poly<'a>(params: &'a Params, val: u64) -> PolyMatrixRaw<'a> {
  529. let mut res = PolyMatrixRaw::zero(params, 1, 1);
  530. res.get_poly_mut(0, 0)[0] = val;
  531. res
  532. }
  533. fn reduce_copy(params: &Params, out: &mut [u64], inp: &[u64]) {
  534. for n in 0..params.crt_count {
  535. for z in 0..params.poly_len {
  536. out[n * params.poly_len + z] = barrett_coeff_u64(params, inp[z], n);
  537. }
  538. }
  539. }
  540. pub fn to_ntt(a: &mut PolyMatrixNTT, b: &PolyMatrixRaw) {
  541. let params = a.params;
  542. for r in 0..a.rows {
  543. for c in 0..a.cols {
  544. let pol_src = b.get_poly(r, c);
  545. let pol_dst = a.get_poly_mut(r, c);
  546. reduce_copy(params, pol_dst, pol_src);
  547. ntt_forward(params, pol_dst);
  548. }
  549. }
  550. }
  551. pub fn to_ntt_no_reduce(a: &mut PolyMatrixNTT, b: &PolyMatrixRaw) {
  552. let params = a.params;
  553. for r in 0..a.rows {
  554. for c in 0..a.cols {
  555. let pol_src = b.get_poly(r, c);
  556. let pol_dst = a.get_poly_mut(r, c);
  557. for n in 0..params.crt_count {
  558. let idx = n * params.poly_len;
  559. pol_dst[idx..idx + params.poly_len].copy_from_slice(pol_src);
  560. }
  561. ntt_forward(params, pol_dst);
  562. }
  563. }
  564. }
  565. pub fn to_ntt_alloc<'a>(b: &PolyMatrixRaw<'a>) -> PolyMatrixNTT<'a> {
  566. let mut a = PolyMatrixNTT::zero(b.params, b.rows, b.cols);
  567. to_ntt(&mut a, b);
  568. a
  569. }
  570. pub fn from_ntt(a: &mut PolyMatrixRaw, b: &PolyMatrixNTT) {
  571. let params = a.params;
  572. SCRATCH.with(|scratch_cell| {
  573. let scratch_vec = &mut *scratch_cell.borrow_mut();
  574. let scratch = scratch_vec.as_mut_slice();
  575. for r in 0..a.rows {
  576. for c in 0..a.cols {
  577. let pol_src = b.get_poly(r, c);
  578. let pol_dst = a.get_poly_mut(r, c);
  579. scratch[0..pol_src.len()].copy_from_slice(pol_src);
  580. ntt_inverse(params, scratch);
  581. for z in 0..params.poly_len {
  582. pol_dst[z] = params.crt_compose(scratch, z);
  583. }
  584. }
  585. }
  586. });
  587. }
  588. pub fn from_ntt_alloc<'a>(b: &PolyMatrixNTT<'a>) -> PolyMatrixRaw<'a> {
  589. let mut a = PolyMatrixRaw::zero(b.params, b.rows, b.cols);
  590. from_ntt(&mut a, b);
  591. a
  592. }
  593. impl<'a, 'b> Neg for &'b PolyMatrixRaw<'a> {
  594. type Output = PolyMatrixRaw<'a>;
  595. fn neg(self) -> Self::Output {
  596. let mut out = PolyMatrixRaw::zero(self.params, self.rows, self.cols);
  597. invert(&mut out, self);
  598. out
  599. }
  600. }
  601. impl<'a, 'b> Mul for &'b PolyMatrixNTT<'a> {
  602. type Output = PolyMatrixNTT<'a>;
  603. fn mul(self, rhs: Self) -> Self::Output {
  604. let mut out = PolyMatrixNTT::zero(self.params, self.rows, rhs.cols);
  605. multiply(&mut out, self, rhs);
  606. out
  607. }
  608. }
  609. impl<'a, 'b> Add for &'b PolyMatrixNTT<'a> {
  610. type Output = PolyMatrixNTT<'a>;
  611. fn add(self, rhs: Self) -> Self::Output {
  612. let mut out = PolyMatrixNTT::zero(self.params, self.rows, self.cols);
  613. add(&mut out, self, rhs);
  614. out
  615. }
  616. }
  #[cfg(test)]
  mod test {
      use super::*;

      fn get_params() -> Params {
          get_test_params()
      }

      /// Asserts every word is zero, reporting the first offending index.
      fn assert_all_zero(a: &[u64]) {
          for (idx, &word) in a.iter().enumerate() {
              assert_eq!(word, 0, "word {} is non-zero", idx);
          }
      }

      #[test]
      fn sets_all_zeros() {
          let params = get_params();
          let m1 = PolyMatrixNTT::zero(&params, 2, 1);
          assert_all_zero(m1.as_slice());
      }

      #[test]
      fn multiply_correctness() {
          // (3x2 zero) * (2x1 zero) must be the 3x1 zero matrix.
          let params = get_params();
          let m1 = PolyMatrixNTT::zero(&params, 2, 1);
          let m2 = PolyMatrixNTT::zero(&params, 3, 2);
          let m3 = &m2 * &m1;
          assert_all_zero(m3.as_slice());
      }

      #[test]
      fn full_multiply_correctness() {
          // (100 X) * (7 X) = 700 X^2, checked through a full NTT round trip.
          let params = get_params();
          let mut m1 = PolyMatrixRaw::zero(&params, 1, 1);
          let mut m2 = PolyMatrixRaw::zero(&params, 1, 1);
          m1.get_poly_mut(0, 0)[1] = 100;
          m2.get_poly_mut(0, 0)[1] = 7;
          let m1_ntt = to_ntt_alloc(&m1);
          let m2_ntt = to_ntt_alloc(&m2);
          let m3_ntt = &m1_ntt * &m2_ntt;
          let m3 = from_ntt_alloc(&m3_ntt);
          assert_eq!(m3.get_poly(0, 0)[2], 700);
      }

      #[test]
      fn identity_has_unit_diagonal() {
          let params = get_params();
          let id = PolyMatrixRaw::identity(&params, 2, 2);
          for r in 0..2 {
              for c in 0..2 {
                  let poly = id.get_poly(r, c);
                  assert_eq!(poly[0], u64::from(r == c));
                  assert_all_zero(&poly[1..]);
              }
          }
      }

      #[test]
      fn to_vec_correctness() {
          // Round-trips one all-ones polynomial through 9-bit packing.
          let params = get_params();
          let mut m1 = PolyMatrixRaw::zero(&params, 1, 1);
          for i in 0..params.poly_len {
              m1.data[i] = 1;
          }
          let modulus_bits = 9;
          let v = m1.to_vec(modulus_bits, params.poly_len);
          let mut bit_offs = 0;
          for i in 0..params.poly_len {
              let val = read_arbitrary_bits(v.as_slice(), bit_offs, modulus_bits);
              assert_eq!(m1.data[i], val);
              bit_offs += modulus_bits;
          }
      }
  }