poly.rs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. #[cfg(target_feature = "avx2")]
  2. use std::arch::x86_64::*;
  3. use std::ops::{Add, Mul, Neg};
  4. use std::cell::RefCell;
  5. use rand::Rng;
  6. use rand::distributions::Standard;
  7. use crate::{arith::*, params::*, ntt::*, util::*, discrete_gaussian::*};
  8. const SCRATCH_SPACE: usize = 8192;
  9. thread_local!(static SCRATCH: RefCell<Vec<u64>> = RefCell::new(vec![0u64; SCRATCH_SPACE]));
  10. pub trait PolyMatrix<'a> {
  11. fn is_ntt(&self) -> bool;
  12. fn get_rows(&self) -> usize;
  13. fn get_cols(&self) -> usize;
  14. fn get_params(&self) -> &Params;
  15. fn num_words(&self) -> usize;
  16. fn zero(params: &'a Params, rows: usize, cols: usize) -> Self;
  17. fn random(params: &'a Params, rows: usize, cols: usize) -> Self;
  18. fn as_slice(&self) -> &[u64];
  19. fn as_mut_slice(&mut self) -> &mut [u64];
  20. fn zero_out(&mut self) {
  21. for item in self.as_mut_slice() {
  22. *item = 0;
  23. }
  24. }
  25. fn get_poly(&self, row: usize, col: usize) -> &[u64] {
  26. let num_words = self.num_words();
  27. let start = (row * self.get_cols() + col) * num_words;
  28. &self.as_slice()[start..start + num_words]
  29. }
  30. fn get_poly_mut(&mut self, row: usize, col: usize) -> &mut [u64] {
  31. let num_words = self.num_words();
  32. let start = (row * self.get_cols() + col) * num_words;
  33. &mut self.as_mut_slice()[start..start + num_words]
  34. }
  35. fn copy_into(&mut self, p: &Self, target_row: usize, target_col: usize) {
  36. assert!(target_row < self.get_rows());
  37. assert!(target_col < self.get_cols());
  38. assert!(target_row + p.get_rows() <= self.get_rows());
  39. assert!(target_col + p.get_cols() <= self.get_cols());
  40. for r in 0..p.get_rows() {
  41. for c in 0..p.get_cols() {
  42. let pol_src = p.get_poly(r, c);
  43. let pol_dst = self.get_poly_mut(target_row + r, target_col + c);
  44. pol_dst.copy_from_slice(pol_src);
  45. }
  46. }
  47. }
  48. fn pad_top(&self, pad_rows: usize) -> Self;
  49. }
  50. pub struct PolyMatrixRaw<'a> {
  51. pub params: &'a Params,
  52. pub rows: usize,
  53. pub cols: usize,
  54. pub data: Vec<u64>,
  55. }
  56. pub struct PolyMatrixNTT<'a> {
  57. pub params: &'a Params,
  58. pub rows: usize,
  59. pub cols: usize,
  60. pub data: Vec<u64>,
  61. }
  62. impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
  63. fn is_ntt(&self) -> bool {
  64. false
  65. }
  66. fn get_rows(&self) -> usize {
  67. self.rows
  68. }
  69. fn get_cols(&self) -> usize {
  70. self.cols
  71. }
  72. fn get_params(&self) -> &Params {
  73. &self.params
  74. }
  75. fn as_slice(&self) -> &[u64] {
  76. self.data.as_slice()
  77. }
  78. fn as_mut_slice(&mut self) -> &mut [u64] {
  79. self.data.as_mut_slice()
  80. }
  81. fn num_words(&self) -> usize {
  82. self.params.poly_len
  83. }
  84. fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
  85. let num_coeffs = rows * cols * params.poly_len;
  86. let data: Vec<u64> = vec![0; num_coeffs];
  87. PolyMatrixRaw {
  88. params,
  89. rows,
  90. cols,
  91. data,
  92. }
  93. }
  94. fn random(params: &'a Params, rows: usize, cols: usize) -> Self {
  95. let rng = rand::thread_rng();
  96. let mut iter = rng.sample_iter(&Standard);
  97. let mut out = PolyMatrixRaw::zero(params, rows, cols);
  98. for r in 0..rows {
  99. for c in 0..cols {
  100. for i in 0..params.poly_len {
  101. let val: u64 = iter.next().unwrap();
  102. out.get_poly_mut(r, c)[i] = val % params.modulus;
  103. }
  104. }
  105. }
  106. out
  107. }
  108. fn pad_top(&self, pad_rows: usize) -> Self {
  109. let mut padded = Self::zero(self.params, self.rows + pad_rows, self.cols);
  110. padded.copy_into(&self, pad_rows, 0);
  111. padded
  112. }
  113. }
  114. impl<'a> PolyMatrixRaw<'a> {
  115. pub fn identity(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
  116. let num_coeffs = rows * cols * params.poly_len;
  117. let mut data: Vec<u64> = vec![0; num_coeffs];
  118. for r in 0..rows {
  119. let c = r;
  120. let idx = r * cols * params.poly_len + c * params.poly_len;
  121. data[idx] = 1;
  122. }
  123. PolyMatrixRaw {
  124. params,
  125. rows,
  126. cols,
  127. data,
  128. }
  129. }
  130. pub fn noise(params: &'a Params, rows: usize, cols: usize, dg: &mut DiscreteGaussian) -> Self {
  131. let mut out = PolyMatrixRaw::zero(params, rows, cols);
  132. dg.sample_matrix(&mut out);
  133. out
  134. }
  135. pub fn ntt(&self) -> PolyMatrixNTT<'a> {
  136. to_ntt_alloc(&self)
  137. }
  138. pub fn to_vec(&self, modulus_bits: usize) -> Vec<u8> {
  139. let params = self.params;
  140. let sz_bits = self.rows * self.cols * params.poly_len * modulus_bits;
  141. let sz_bytes = f64::ceil((sz_bits as f64) / 8f64) as usize;
  142. let sz_bytes_roundup_16 = ((sz_bytes + 15) / 16) * 16;
  143. let mut data = vec![0u8; sz_bytes_roundup_16];
  144. let mut bit_offs = 0;
  145. for i in 0..self.rows * self.cols * params.poly_len {
  146. write_arbitrary_bits(data.as_mut_slice(), self.data[i], bit_offs, modulus_bits);
  147. bit_offs += modulus_bits;
  148. }
  149. data
  150. }
  151. pub fn single_value(params: &'a Params, value: u64) -> PolyMatrixRaw<'a> {
  152. let mut out = Self::zero(params, 1, 1);
  153. out.data[0] = value;
  154. out
  155. }
  156. }
  157. impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
  158. fn is_ntt(&self) -> bool {
  159. true
  160. }
  161. fn get_rows(&self) -> usize {
  162. self.rows
  163. }
  164. fn get_cols(&self) -> usize {
  165. self.cols
  166. }
  167. fn get_params(&self) -> &Params {
  168. &self.params
  169. }
  170. fn as_slice(&self) -> &[u64] {
  171. self.data.as_slice()
  172. }
  173. fn as_mut_slice(&mut self) -> &mut [u64] {
  174. self.data.as_mut_slice()
  175. }
  176. fn num_words(&self) -> usize {
  177. self.params.poly_len * self.params.crt_count
  178. }
  179. fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixNTT<'a> {
  180. let num_coeffs = rows * cols * params.poly_len * params.crt_count;
  181. let data: Vec<u64> = vec![0; num_coeffs];
  182. PolyMatrixNTT {
  183. params,
  184. rows,
  185. cols,
  186. data,
  187. }
  188. }
  189. fn random(params: &'a Params, rows: usize, cols: usize) -> Self {
  190. let rng = rand::thread_rng();
  191. let mut iter = rng.sample_iter(&Standard);
  192. let mut out = PolyMatrixNTT::zero(params, rows, cols);
  193. for r in 0..rows {
  194. for c in 0..cols {
  195. for i in 0..params.crt_count {
  196. for j in 0..params.poly_len {
  197. let idx = calc_index(&[i, j], &[params.crt_count, params.poly_len]);
  198. let val: u64 = iter.next().unwrap();
  199. out.get_poly_mut(r, c)[idx] = val % params.moduli[i];
  200. }
  201. }
  202. }
  203. }
  204. out
  205. }
  206. fn pad_top(&self, pad_rows: usize) -> Self {
  207. let mut padded = Self::zero(self.params, self.rows + pad_rows, self.cols);
  208. padded.copy_into(&self, pad_rows, 0);
  209. padded
  210. }
  211. }
  212. impl<'a> PolyMatrixNTT<'a> {
  213. pub fn raw(&self) -> PolyMatrixRaw<'a> {
  214. from_ntt_alloc(&self)
  215. }
  216. }
  217. pub fn multiply_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
  218. for c in 0..params.crt_count {
  219. for i in 0..params.poly_len {
  220. let idx = c * params.poly_len + i;
  221. res[idx] = multiply_modular(params, a[idx], b[idx], c);
  222. }
  223. }
  224. }
  225. pub fn multiply_add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
  226. for c in 0..params.crt_count {
  227. for i in 0..params.poly_len {
  228. let idx = c * params.poly_len + i;
  229. res[idx] = multiply_add_modular(params, a[idx], b[idx], res[idx], c);
  230. }
  231. }
  232. }
  233. pub fn add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
  234. for c in 0..params.crt_count {
  235. for i in 0..params.poly_len {
  236. let idx = c * params.poly_len + i;
  237. res[idx] = add_modular(params, a[idx], b[idx], c);
  238. }
  239. }
  240. }
  241. pub fn invert_poly(params: &Params, res: &mut [u64], a: &[u64]) {
  242. for i in 0..params.poly_len {
  243. res[i] = params.modulus - a[i];
  244. }
  245. }
  246. pub fn automorph_poly(params: &Params, res: &mut [u64], a: &[u64], t: usize) {
  247. let poly_len = params.poly_len;
  248. for i in 0..poly_len {
  249. let num = (i * t) / poly_len;
  250. let rem = (i * t) % poly_len;
  251. if num % 2 == 0 {
  252. res[rem] = a[i];
  253. } else {
  254. res[rem] = params.modulus - a[i];
  255. }
  256. }
  257. }
  258. #[cfg(target_feature = "avx2")]
  259. pub fn multiply_add_poly_avx(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
  260. for c in 0..params.crt_count {
  261. for i in (0..params.poly_len).step_by(4) {
  262. unsafe {
  263. let p_x = &a[c*params.poly_len + i] as *const u64;
  264. let p_y = &b[c*params.poly_len + i] as *const u64;
  265. let p_z = &mut res[c*params.poly_len + i] as *mut u64;
  266. let x = _mm256_loadu_si256(p_x as *const __m256i);
  267. let y = _mm256_loadu_si256(p_y as *const __m256i);
  268. let z = _mm256_loadu_si256(p_z as *const __m256i);
  269. let product = _mm256_mul_epu32(x, y);
  270. let out = _mm256_add_epi64(z, product);
  271. _mm256_storeu_si256(p_z as *mut __m256i, out);
  272. }
  273. }
  274. }
  275. }
  276. pub fn modular_reduce(params: &Params, res: &mut [u64]) {
  277. for c in 0..params.crt_count {
  278. for i in 0..params.poly_len {
  279. res[c*params.poly_len + i] %= params.moduli[c];
  280. }
  281. }
  282. }
  283. #[cfg(not(target_feature = "avx2"))]
  284. pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
  285. assert!(res.rows == a.rows);
  286. assert!(res.cols == b.cols);
  287. assert!(a.cols == b.rows);
  288. let params = res.params;
  289. for i in 0..a.rows {
  290. for j in 0..b.cols {
  291. for z in 0..params.poly_len*params.crt_count {
  292. res.get_poly_mut(i, j)[z] = 0;
  293. }
  294. for k in 0..a.cols {
  295. let params = res.params;
  296. let res_poly = res.get_poly_mut(i, j);
  297. let pol1 = a.get_poly(i, k);
  298. let pol2 = b.get_poly(k, j);
  299. multiply_add_poly(params, res_poly, pol1, pol2);
  300. }
  301. }
  302. }
  303. }
  304. #[cfg(target_feature = "avx2")]
  305. pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
  306. assert!(res.rows == a.rows);
  307. assert!(res.cols == b.cols);
  308. assert!(a.cols == b.rows);
  309. let params = res.params;
  310. for i in 0..a.rows {
  311. for j in 0..b.cols {
  312. for z in 0..params.poly_len*params.crt_count {
  313. res.get_poly_mut(i, j)[z] = 0;
  314. }
  315. let res_poly = res.get_poly_mut(i, j);
  316. for k in 0..a.cols {
  317. let pol1 = a.get_poly(i, k);
  318. let pol2 = b.get_poly(k, j);
  319. multiply_add_poly_avx(params, res_poly, pol1, pol2);
  320. }
  321. modular_reduce(params, res_poly);
  322. }
  323. }
  324. }
  325. pub fn add(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
  326. assert!(res.rows == a.rows);
  327. assert!(res.cols == a.cols);
  328. assert!(a.rows == b.rows);
  329. assert!(a.cols == b.cols);
  330. let params = res.params;
  331. for i in 0..a.rows {
  332. for j in 0..a.cols {
  333. let res_poly = res.get_poly_mut(i, j);
  334. let pol1 = a.get_poly(i, j);
  335. let pol2 = b.get_poly(i, j);
  336. add_poly(params, res_poly, pol1, pol2);
  337. }
  338. }
  339. }
  340. pub fn invert(res: &mut PolyMatrixRaw, a: &PolyMatrixRaw) {
  341. assert!(res.rows == a.rows);
  342. assert!(res.cols == a.cols);
  343. let params = res.params;
  344. for i in 0..a.rows {
  345. for j in 0..a.cols {
  346. let res_poly = res.get_poly_mut(i, j);
  347. let pol1 = a.get_poly(i, j);
  348. invert_poly(params, res_poly, pol1);
  349. }
  350. }
  351. }
  352. pub fn automorph<'a>(res: &mut PolyMatrixRaw<'a>, a: &PolyMatrixRaw<'a>, t: usize) {
  353. assert!(res.rows == a.rows);
  354. assert!(res.cols == a.cols);
  355. let params = res.params;
  356. for i in 0..a.rows {
  357. for j in 0..a.cols {
  358. let res_poly = res.get_poly_mut(i, j);
  359. let pol1 = a.get_poly(i, j);
  360. automorph_poly(params, res_poly, pol1, t);
  361. }
  362. }
  363. }
  364. pub fn automorph_alloc<'a>(a: &PolyMatrixRaw<'a>, t: usize) -> PolyMatrixRaw<'a> {
  365. let mut res = PolyMatrixRaw::zero(a.params, a.rows, a.cols);
  366. automorph(&mut res, a, t);
  367. res
  368. }
  369. pub fn stack<'a>(a: &PolyMatrixRaw<'a>, b: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
  370. assert_eq!(a.cols, b.cols);
  371. let mut c = PolyMatrixRaw::zero(a.params, a.rows + b.rows, a.cols);
  372. c.copy_into(a, 0, 0);
  373. c.copy_into(b, a.rows, 0);
  374. c
  375. }
  376. pub fn scalar_multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
  377. assert_eq!(a.rows, 1);
  378. assert_eq!(a.cols, 1);
  379. let params = res.params;
  380. let pol2 = a.get_poly(0, 0);
  381. for i in 0..b.rows {
  382. for j in 0..b.cols {
  383. let res_poly = res.get_poly_mut(i, j);
  384. let pol1 = b.get_poly(i, j);
  385. multiply_poly(params, res_poly, pol1, pol2);
  386. }
  387. }
  388. }
  389. pub fn scalar_multiply_alloc<'a>(a: &PolyMatrixNTT<'a>, b: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
  390. let mut res = PolyMatrixNTT::zero(b.params, b.rows, b.cols);
  391. scalar_multiply(&mut res, a, b);
  392. res
  393. }
  394. pub fn single_poly<'a>(params: &'a Params, val: u64) -> PolyMatrixRaw<'a> {
  395. let mut res = PolyMatrixRaw::zero(params, 1, 1);
  396. res.get_poly_mut(0, 0)[0] = val;
  397. res
  398. }
  399. pub fn to_ntt(a: &mut PolyMatrixNTT, b: &PolyMatrixRaw) {
  400. let params = a.params;
  401. for r in 0..a.rows {
  402. for c in 0..a.cols {
  403. let pol_src = b.get_poly(r, c);
  404. let pol_dst = a.get_poly_mut(r, c);
  405. for n in 0..params.crt_count {
  406. for z in 0..params.poly_len {
  407. pol_dst[n * params.poly_len + z] = pol_src[z] % params.moduli[n];
  408. }
  409. }
  410. ntt_forward(params, pol_dst);
  411. }
  412. }
  413. }
  414. pub fn to_ntt_alloc<'a>(b: &PolyMatrixRaw<'a>) -> PolyMatrixNTT<'a> {
  415. let mut a = PolyMatrixNTT::zero(b.params, b.rows, b.cols);
  416. to_ntt(&mut a, b);
  417. a
  418. }
  419. pub fn from_ntt(a: &mut PolyMatrixRaw, b: &PolyMatrixNTT) {
  420. let params = a.params;
  421. SCRATCH.with(|scratch_cell| {
  422. let scratch_vec = &mut *scratch_cell.borrow_mut();
  423. let scratch = scratch_vec.as_mut_slice();
  424. for r in 0..a.rows {
  425. for c in 0..a.cols {
  426. let pol_src = b.get_poly(r, c);
  427. let pol_dst = a.get_poly_mut(r, c);
  428. scratch[0..pol_src.len()].copy_from_slice(pol_src);
  429. ntt_inverse(params, scratch);
  430. for z in 0..params.poly_len {
  431. pol_dst[z] = params.crt_compose(scratch, z);
  432. }
  433. }
  434. }
  435. });
  436. }
  437. pub fn from_ntt_alloc<'a>(b: &PolyMatrixNTT<'a>) -> PolyMatrixRaw<'a> {
  438. let mut a = PolyMatrixRaw::zero(b.params, b.rows, b.cols);
  439. from_ntt(&mut a, b);
  440. a
  441. }
  442. impl<'a, 'b> Neg for &'b PolyMatrixRaw<'a> {
  443. type Output = PolyMatrixRaw<'a>;
  444. fn neg(self) -> Self::Output {
  445. let mut out = PolyMatrixRaw::zero(self.params, self.rows, self.cols);
  446. invert(&mut out, self);
  447. out
  448. }
  449. }
  450. impl<'a, 'b> Mul for &'b PolyMatrixNTT<'a> {
  451. type Output = PolyMatrixNTT<'a>;
  452. fn mul(self, rhs: Self) -> Self::Output {
  453. let mut out = PolyMatrixNTT::zero(self.params, self.rows, rhs.cols);
  454. multiply(&mut out, self, rhs);
  455. out
  456. }
  457. }
  458. impl<'a, 'b> Add for &'b PolyMatrixNTT<'a> {
  459. type Output = PolyMatrixNTT<'a>;
  460. fn add(self, rhs: Self) -> Self::Output {
  461. let mut out = PolyMatrixNTT::zero(self.params, self.rows, self.cols);
  462. add(&mut out, self, rhs);
  463. out
  464. }
  465. }
  466. #[cfg(test)]
  467. mod test {
  468. use super::*;
  469. fn get_params() -> Params {
  470. get_test_params()
  471. }
  472. fn assert_all_zero(a: &[u64]) {
  473. for i in a {
  474. assert_eq!(*i, 0);
  475. }
  476. }
  477. #[test]
  478. fn sets_all_zeros() {
  479. let params = get_params();
  480. let m1 = PolyMatrixNTT::zero(&params, 2, 1);
  481. assert_all_zero(m1.as_slice());
  482. }
  483. #[test]
  484. fn multiply_correctness() {
  485. let params = get_params();
  486. let m1 = PolyMatrixNTT::zero(&params, 2, 1);
  487. let m2 = PolyMatrixNTT::zero(&params, 3, 2);
  488. let m3 = &m2 * &m1;
  489. assert_all_zero(m3.as_slice());
  490. }
  491. #[test]
  492. fn full_multiply_correctness() {
  493. let params = get_params();
  494. let mut m1 = PolyMatrixRaw::zero(&params, 1, 1);
  495. let mut m2 = PolyMatrixRaw::zero(&params, 1, 1);
  496. m1.get_poly_mut(0, 0)[1] = 100;
  497. m2.get_poly_mut(0, 0)[1] = 7;
  498. let m1_ntt = to_ntt_alloc(&m1);
  499. let m2_ntt = to_ntt_alloc(&m2);
  500. let m3_ntt = &m1_ntt * &m2_ntt;
  501. let m3 = from_ntt_alloc(&m3_ntt);
  502. assert_eq!(m3.get_poly(0, 0)[2], 700);
  503. }
  504. }