poly.rs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534
  1. use std::arch::x86_64::*;
  2. use std::ops::{Add, Mul, Neg};
  3. use std::cell::RefCell;
  4. use rand::Rng;
  5. use rand::distributions::Standard;
  6. use crate::{arith::*, params::*, ntt::*, util::*, discrete_gaussian::*};
  /// Length, in `u64` words, of each thread's scratch buffer.
  const SCRATCH_SPACE: usize = 8192;

  thread_local! {
      /// Per-thread, zero-initialised scratch buffer of `SCRATCH_SPACE` words
      /// (used by `from_ntt` as working space for the inverse NTT).
      static SCRATCH: RefCell<Vec<u64>> =
          RefCell::new(std::iter::repeat(0u64).take(SCRATCH_SPACE).collect());
  }
  9. pub trait PolyMatrix<'a> {
  10. fn is_ntt(&self) -> bool;
  11. fn get_rows(&self) -> usize;
  12. fn get_cols(&self) -> usize;
  13. fn get_params(&self) -> &Params;
  14. fn num_words(&self) -> usize;
  15. fn zero(params: &'a Params, rows: usize, cols: usize) -> Self;
  16. fn random(params: &'a Params, rows: usize, cols: usize) -> Self;
  17. fn as_slice(&self) -> &[u64];
  18. fn as_mut_slice(&mut self) -> &mut [u64];
  19. fn zero_out(&mut self) {
  20. for item in self.as_mut_slice() {
  21. *item = 0;
  22. }
  23. }
  24. fn get_poly(&self, row: usize, col: usize) -> &[u64] {
  25. let num_words = self.num_words();
  26. let start = (row * self.get_cols() + col) * num_words;
  27. &self.as_slice()[start..start + num_words]
  28. }
  29. fn get_poly_mut(&mut self, row: usize, col: usize) -> &mut [u64] {
  30. let num_words = self.num_words();
  31. let start = (row * self.get_cols() + col) * num_words;
  32. &mut self.as_mut_slice()[start..start + num_words]
  33. }
  34. fn copy_into(&mut self, p: &Self, target_row: usize, target_col: usize) {
  35. assert!(target_row < self.get_rows());
  36. assert!(target_col < self.get_cols());
  37. assert!(target_row + p.get_rows() < self.get_rows());
  38. assert!(target_col + p.get_cols() < self.get_cols());
  39. for r in 0..p.get_rows() {
  40. for c in 0..p.get_cols() {
  41. let pol_src = p.get_poly(r, c);
  42. let pol_dst = self.get_poly_mut(target_row + r, target_col + c);
  43. pol_dst.copy_from_slice(pol_src);
  44. }
  45. }
  46. }
  47. fn pad_top(&self, pad_rows: usize) -> Self;
  48. }
  49. pub struct PolyMatrixRaw<'a> {
  50. pub params: &'a Params,
  51. pub rows: usize,
  52. pub cols: usize,
  53. pub data: Vec<u64>,
  54. }
  55. pub struct PolyMatrixNTT<'a> {
  56. pub params: &'a Params,
  57. pub rows: usize,
  58. pub cols: usize,
  59. pub data: Vec<u64>,
  60. }
  61. impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
  62. fn is_ntt(&self) -> bool {
  63. false
  64. }
  65. fn get_rows(&self) -> usize {
  66. self.rows
  67. }
  68. fn get_cols(&self) -> usize {
  69. self.cols
  70. }
  71. fn get_params(&self) -> &Params {
  72. &self.params
  73. }
  74. fn as_slice(&self) -> &[u64] {
  75. self.data.as_slice()
  76. }
  77. fn as_mut_slice(&mut self) -> &mut [u64] {
  78. self.data.as_mut_slice()
  79. }
  80. fn num_words(&self) -> usize {
  81. self.params.poly_len
  82. }
  83. fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
  84. let num_coeffs = rows * cols * params.poly_len;
  85. let data: Vec<u64> = vec![0; num_coeffs];
  86. PolyMatrixRaw {
  87. params,
  88. rows,
  89. cols,
  90. data,
  91. }
  92. }
  93. fn random(params: &'a Params, rows: usize, cols: usize) -> Self {
  94. let rng = rand::thread_rng();
  95. let mut iter = rng.sample_iter(&Standard);
  96. let mut out = PolyMatrixRaw::zero(params, rows, cols);
  97. for r in 0..rows {
  98. for c in 0..cols {
  99. for i in 0..params.poly_len {
  100. let val: u64 = iter.next().unwrap();
  101. out.get_poly_mut(r, c)[i] = val % params.modulus;
  102. }
  103. }
  104. }
  105. out
  106. }
  107. fn pad_top(&self, pad_rows: usize) -> Self {
  108. let mut padded = Self::zero(self.params, self.rows + pad_rows, self.cols);
  109. padded.copy_into(&self, pad_rows, 0);
  110. padded
  111. }
  112. }
  113. impl<'a> PolyMatrixRaw<'a> {
  114. pub fn identity(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
  115. let num_coeffs = rows * cols * params.poly_len;
  116. let mut data: Vec<u64> = vec![0; num_coeffs];
  117. for r in 0..rows {
  118. let c = r;
  119. let idx = r * cols * params.poly_len + c * params.poly_len;
  120. data[idx] = 1;
  121. }
  122. PolyMatrixRaw {
  123. params,
  124. rows,
  125. cols,
  126. data,
  127. }
  128. }
  129. pub fn noise(params: &'a Params, rows: usize, cols: usize, dg: &mut DiscreteGaussian) -> Self {
  130. let mut out = PolyMatrixRaw::zero(params, rows, cols);
  131. dg.sample_matrix(&mut out);
  132. out
  133. }
  134. pub fn ntt(&self) -> PolyMatrixNTT<'a> {
  135. to_ntt_alloc(&self)
  136. }
  137. }
  138. impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
  139. fn is_ntt(&self) -> bool {
  140. true
  141. }
  142. fn get_rows(&self) -> usize {
  143. self.rows
  144. }
  145. fn get_cols(&self) -> usize {
  146. self.cols
  147. }
  148. fn get_params(&self) -> &Params {
  149. &self.params
  150. }
  151. fn as_slice(&self) -> &[u64] {
  152. self.data.as_slice()
  153. }
  154. fn as_mut_slice(&mut self) -> &mut [u64] {
  155. self.data.as_mut_slice()
  156. }
  157. fn num_words(&self) -> usize {
  158. self.params.poly_len * self.params.crt_count
  159. }
  160. fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixNTT<'a> {
  161. let num_coeffs = rows * cols * params.poly_len * params.crt_count;
  162. let data: Vec<u64> = vec![0; num_coeffs];
  163. PolyMatrixNTT {
  164. params,
  165. rows,
  166. cols,
  167. data,
  168. }
  169. }
  170. fn random(params: &'a Params, rows: usize, cols: usize) -> Self {
  171. let rng = rand::thread_rng();
  172. let mut iter = rng.sample_iter(&Standard);
  173. let mut out = PolyMatrixNTT::zero(params, rows, cols);
  174. for r in 0..rows {
  175. for c in 0..cols {
  176. for i in 0..params.crt_count {
  177. for j in 0..params.poly_len {
  178. let idx = calc_index(&[i, j], &[params.crt_count, params.poly_len]);
  179. let val: u64 = iter.next().unwrap();
  180. out.get_poly_mut(r, c)[idx] = val % params.moduli[i];
  181. }
  182. }
  183. }
  184. }
  185. out
  186. }
  187. fn pad_top(&self, pad_rows: usize) -> Self {
  188. let mut padded = Self::zero(self.params, self.rows + pad_rows, self.cols);
  189. padded.copy_into(&self, pad_rows, 0);
  190. padded
  191. }
  192. }
  193. impl<'a> PolyMatrixNTT<'a> {
  194. pub fn raw(&self) -> PolyMatrixRaw<'a> {
  195. from_ntt_alloc(&self)
  196. }
  197. }
  198. pub fn multiply_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
  199. for c in 0..params.crt_count {
  200. for i in 0..params.poly_len {
  201. res[i] = multiply_modular(params, a[i], b[i], c);
  202. }
  203. }
  204. }
  205. pub fn multiply_add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
  206. for c in 0..params.crt_count {
  207. for i in 0..params.poly_len {
  208. res[i] = multiply_add_modular(params, a[i], b[i], res[i], c);
  209. }
  210. }
  211. }
  212. pub fn add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
  213. for c in 0..params.crt_count {
  214. for i in 0..params.poly_len {
  215. res[i] = add_modular(params, a[i], b[i], c);
  216. }
  217. }
  218. }
  219. pub fn invert_poly(params: &Params, res: &mut [u64], a: &[u64]) {
  220. for c in 0..params.crt_count {
  221. for i in 0..params.poly_len {
  222. res[i] = invert_modular(params, a[i], c);
  223. }
  224. }
  225. }
  226. pub fn automorph_poly(params: &Params, res: &mut [u64], a: &[u64], t: usize) {
  227. let poly_len = params.poly_len;
  228. for i in 0..poly_len {
  229. let num = (i * t) / poly_len;
  230. let rem = (i * t) % poly_len;
  231. if num % 2 == 0 {
  232. res[rem] = a[i];
  233. } else {
  234. res[rem] = params.modulus - a[i];
  235. }
  236. }
  237. }
  238. pub fn multiply_add_poly_avx(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
  239. for c in 0..params.crt_count {
  240. for i in (0..params.poly_len).step_by(4) {
  241. unsafe {
  242. let p_x = &a[c*params.poly_len + i] as *const u64;
  243. let p_y = &b[c*params.poly_len + i] as *const u64;
  244. let p_z = &mut res[c*params.poly_len + i] as *mut u64;
  245. let x = _mm256_loadu_si256(p_x as *const __m256i);
  246. let y = _mm256_loadu_si256(p_y as *const __m256i);
  247. let z = _mm256_loadu_si256(p_z as *const __m256i);
  248. let product = _mm256_mul_epu32(x, y);
  249. let out = _mm256_add_epi64(z, product);
  250. _mm256_storeu_si256(p_z as *mut __m256i, out);
  251. }
  252. }
  253. }
  254. }
  255. pub fn modular_reduce(params: &Params, res: &mut [u64]) {
  256. for c in 0..params.crt_count {
  257. for i in 0..params.poly_len {
  258. res[c*params.poly_len + i] %= params.moduli[c];
  259. }
  260. }
  261. }
  262. #[cfg(not(target_feature = "avx2"))]
  263. pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
  264. assert!(a.cols == b.rows);
  265. for i in 0..a.rows {
  266. for j in 0..b.cols {
  267. for z in 0..res.params.poly_len {
  268. res.get_poly_mut(i, j)[z] = 0;
  269. }
  270. for k in 0..a.cols {
  271. let params = res.params;
  272. let res_poly = res.get_poly_mut(i, j);
  273. let pol1 = a.get_poly(i, k);
  274. let pol2 = b.get_poly(k, j);
  275. multiply_add_poly(params, res_poly, pol1, pol2);
  276. }
  277. }
  278. }
  279. }
  /// Matrix product of two NTT-form matrices, `res = a * b` (AVX2 version).
  ///
  /// Each output polynomial is cleared, accumulated without reduction by
  /// `multiply_add_poly_avx` over the shared dimension `k`, then reduced once
  /// per CRT block by `modular_reduce`.
  ///
  /// NOTE(review): the unreduced sum of `a.cols` products must fit in a `u64`,
  /// and the kernel only multiplies the low 32 bits of each word, so this
  /// presumably relies on every `params.moduli[c]` being well below `2^32`
  /// — confirm against the parameter sets in use.
  ///
  /// # Panics
  /// If `a.cols != b.rows`, or if `res` is not `a.rows x b.cols`.
  #[cfg(target_feature = "avx2")]
  pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
      assert!(res.rows == a.rows);
      assert!(res.cols == b.cols);
      assert!(a.cols == b.rows);
      let params = res.params;
      for i in 0..a.rows {
          for j in 0..b.cols {
              let res_poly = res.get_poly_mut(i, j);
              // The AVX kernel accumulates into all `crt_count * poly_len` words,
              // so every word must start at zero, not just the first block.
              for word in res_poly.iter_mut() {
                  *word = 0;
              }
              for k in 0..a.cols {
                  multiply_add_poly_avx(params, res_poly, a.get_poly(i, k), b.get_poly(k, j));
              }
              modular_reduce(params, res_poly);
          }
      }
  }
  301. pub fn add(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
  302. assert!(res.rows == a.rows);
  303. assert!(res.cols == a.cols);
  304. assert!(a.rows == b.rows);
  305. assert!(a.cols == b.cols);
  306. let params = res.params;
  307. for i in 0..a.rows {
  308. for j in 0..a.cols {
  309. let res_poly = res.get_poly_mut(i, j);
  310. let pol1 = a.get_poly(i, j);
  311. let pol2 = b.get_poly(i, j);
  312. add_poly(params, res_poly, pol1, pol2);
  313. }
  314. }
  315. }
  316. pub fn invert(res: &mut PolyMatrixRaw, a: &PolyMatrixRaw) {
  317. assert!(res.rows == a.rows);
  318. assert!(res.cols == a.cols);
  319. let params = res.params;
  320. for i in 0..a.rows {
  321. for j in 0..a.cols {
  322. let res_poly = res.get_poly_mut(i, j);
  323. let pol1 = a.get_poly(i, j);
  324. invert_poly(params, res_poly, pol1);
  325. }
  326. }
  327. }
  328. pub fn automorph<'a>(res: &mut PolyMatrixRaw<'a>, a: &PolyMatrixRaw<'a>, t: usize) {
  329. assert!(res.rows == a.rows);
  330. assert!(res.cols == a.cols);
  331. let params = res.params;
  332. for i in 0..a.rows {
  333. for j in 0..a.cols {
  334. let res_poly = res.get_poly_mut(i, j);
  335. let pol1 = a.get_poly(i, j);
  336. automorph_poly(params, res_poly, pol1, t);
  337. }
  338. }
  339. }
  340. pub fn automorph_alloc<'a>(a: &PolyMatrixRaw<'a>, t: usize) -> PolyMatrixRaw<'a> {
  341. let mut res = PolyMatrixRaw::zero(a.params, a.rows, a.cols);
  342. automorph(&mut res, a, t);
  343. res
  344. }
  345. pub fn stack<'a>(a: &PolyMatrixRaw<'a>, b: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
  346. assert_eq!(a.cols, b.cols);
  347. let mut c = PolyMatrixRaw::zero(a.params, a.rows + b.rows, a.cols);
  348. c.copy_into(a, 0, 0);
  349. c.copy_into(b, a.rows, 0);
  350. c
  351. }
  352. pub fn scalar_multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
  353. assert_eq!(a.rows, 1);
  354. assert_eq!(a.cols, 1);
  355. let params = res.params;
  356. let pol2 = a.get_poly(0, 0);
  357. for i in 0..b.rows {
  358. for j in 0..b.cols {
  359. let res_poly = res.get_poly_mut(i, j);
  360. let pol1 = b.get_poly(i, j);
  361. multiply_poly(params, res_poly, pol1, pol2);
  362. }
  363. }
  364. }
  365. pub fn scalar_multiply_alloc<'a>(a: &PolyMatrixNTT<'a>, b: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
  366. let mut res = PolyMatrixNTT::zero(b.params, b.rows, b.cols);
  367. scalar_multiply(&mut res, a, b);
  368. res
  369. }
  370. pub fn single_poly<'a>(params: &'a Params, val: u64) -> PolyMatrixRaw<'a> {
  371. let mut res = PolyMatrixRaw::zero(params, 1, 1);
  372. res.get_poly_mut(0, 0)[0] = val;
  373. res
  374. }
  375. pub fn to_ntt(a: &mut PolyMatrixNTT, b: &PolyMatrixRaw) {
  376. let params = a.params;
  377. for r in 0..a.rows {
  378. for c in 0..a.cols {
  379. let pol_src = b.get_poly(r, c);
  380. let pol_dst = a.get_poly_mut(r, c);
  381. for n in 0..params.crt_count {
  382. for z in 0..params.poly_len {
  383. pol_dst[n * params.poly_len + z] = pol_src[z] % params.moduli[n];
  384. }
  385. }
  386. ntt_forward(params, pol_dst);
  387. }
  388. }
  389. }
  390. pub fn to_ntt_alloc<'a>(b: &PolyMatrixRaw<'a>) -> PolyMatrixNTT<'a> {
  391. let mut a = PolyMatrixNTT::zero(b.params, b.rows, b.cols);
  392. to_ntt(&mut a, b);
  393. a
  394. }
  395. pub fn from_ntt(a: &mut PolyMatrixRaw, b: &PolyMatrixNTT) {
  396. let params = a.params;
  397. SCRATCH.with(|scratch_cell| {
  398. let scratch_vec = &mut *scratch_cell.borrow_mut();
  399. let scratch = scratch_vec.as_mut_slice();
  400. for r in 0..a.rows {
  401. for c in 0..a.cols {
  402. let pol_src = b.get_poly(r, c);
  403. let pol_dst = a.get_poly_mut(r, c);
  404. scratch[0..pol_src.len()].copy_from_slice(pol_src);
  405. ntt_inverse(params, scratch);
  406. for z in 0..params.poly_len {
  407. pol_dst[z] = params.crt_compose_2(scratch[z], scratch[params.poly_len + z]);
  408. }
  409. }
  410. }
  411. });
  412. }
  413. pub fn from_ntt_alloc<'a>(b: &PolyMatrixNTT<'a>) -> PolyMatrixRaw<'a> {
  414. let mut a = PolyMatrixRaw::zero(b.params, b.rows, b.cols);
  415. from_ntt(&mut a, b);
  416. a
  417. }
  418. impl<'a, 'b> Neg for &'b PolyMatrixRaw<'a> {
  419. type Output = PolyMatrixRaw<'a>;
  420. fn neg(self) -> Self::Output {
  421. let mut out = PolyMatrixRaw::zero(self.params, self.rows, self.cols);
  422. invert(&mut out, self);
  423. out
  424. }
  425. }
  426. impl<'a, 'b> Mul for &'b PolyMatrixNTT<'a> {
  427. type Output = PolyMatrixNTT<'a>;
  428. fn mul(self, rhs: Self) -> Self::Output {
  429. let mut out = PolyMatrixNTT::zero(self.params, self.rows, rhs.cols);
  430. multiply(&mut out, self, rhs);
  431. out
  432. }
  433. }
  434. impl<'a, 'b> Add for &'b PolyMatrixNTT<'a> {
  435. type Output = PolyMatrixNTT<'a>;
  436. fn add(self, rhs: Self) -> Self::Output {
  437. let mut out = PolyMatrixNTT::zero(self.params, self.rows, self.cols);
  438. add(&mut out, self, rhs);
  439. out
  440. }
  441. }
  #[cfg(test)]
  mod test {
      use super::*;

      /// Parameter set shared by every test below.
      fn get_params() -> Params {
          get_test_params()
      }

      /// Fails the test unless every word of `words` is zero.
      fn assert_all_zero(words: &[u64]) {
          for &word in words {
              assert_eq!(word, 0);
          }
      }

      #[test]
      fn sets_all_zeros() {
          let params = get_params();
          let zeros = PolyMatrixNTT::zero(&params, 2, 1);
          assert_all_zero(zeros.as_slice());
      }

      #[test]
      fn multiply_correctness() {
          let params = get_params();
          let rhs = PolyMatrixNTT::zero(&params, 2, 1);
          let lhs = PolyMatrixNTT::zero(&params, 3, 2);
          let product = &lhs * &rhs;
          assert_all_zero(product.as_slice());
      }

      #[test]
      fn full_multiply_correctness() {
          let params = get_params();
          // (100 X) * (7 X) should give 700 X^2 after a round trip through NTT form.
          let mut lhs = PolyMatrixRaw::zero(&params, 1, 1);
          let mut rhs = PolyMatrixRaw::zero(&params, 1, 1);
          lhs.get_poly_mut(0, 0)[1] = 100;
          rhs.get_poly_mut(0, 0)[1] = 7;
          let product_ntt = &to_ntt_alloc(&lhs) * &to_ntt_alloc(&rhs);
          let product = from_ntt_alloc(&product_ntt);
          assert_eq!(product.get_poly(0, 0)[2], 700);
      }
  }