poly.rs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587
  1. use core::num;
  2. #[cfg(target_feature = "avx2")]
  3. use std::arch::x86_64::*;
  4. use std::ops::{Add, Mul, Neg};
  5. use std::cell::RefCell;
  6. use rand::Rng;
  7. use rand::distributions::Standard;
  8. use crate::{arith::*, params::*, ntt::*, util::*, discrete_gaussian::*};
  9. const SCRATCH_SPACE: usize = 8192;
  10. thread_local!(static SCRATCH: RefCell<Vec<u64>> = RefCell::new(vec![0u64; SCRATCH_SPACE]));
  11. pub trait PolyMatrix<'a> {
  12. fn is_ntt(&self) -> bool;
  13. fn get_rows(&self) -> usize;
  14. fn get_cols(&self) -> usize;
  15. fn get_params(&self) -> &Params;
  16. fn num_words(&self) -> usize;
  17. fn zero(params: &'a Params, rows: usize, cols: usize) -> Self;
  18. fn random(params: &'a Params, rows: usize, cols: usize) -> Self;
  19. fn as_slice(&self) -> &[u64];
  20. fn as_mut_slice(&mut self) -> &mut [u64];
  21. fn zero_out(&mut self) {
  22. for item in self.as_mut_slice() {
  23. *item = 0;
  24. }
  25. }
  26. fn get_poly(&self, row: usize, col: usize) -> &[u64] {
  27. let num_words = self.num_words();
  28. let start = (row * self.get_cols() + col) * num_words;
  29. &self.as_slice()[start..start + num_words]
  30. }
  31. fn get_poly_mut(&mut self, row: usize, col: usize) -> &mut [u64] {
  32. let num_words = self.num_words();
  33. let start = (row * self.get_cols() + col) * num_words;
  34. &mut self.as_mut_slice()[start..start + num_words]
  35. }
  36. fn copy_into(&mut self, p: &Self, target_row: usize, target_col: usize) {
  37. assert!(target_row < self.get_rows());
  38. assert!(target_col < self.get_cols());
  39. assert!(target_row + p.get_rows() <= self.get_rows());
  40. assert!(target_col + p.get_cols() <= self.get_cols());
  41. for r in 0..p.get_rows() {
  42. for c in 0..p.get_cols() {
  43. let pol_src = p.get_poly(r, c);
  44. let pol_dst = self.get_poly_mut(target_row + r, target_col + c);
  45. pol_dst.copy_from_slice(pol_src);
  46. }
  47. }
  48. }
  49. fn pad_top(&self, pad_rows: usize) -> Self;
  50. }
  51. pub struct PolyMatrixRaw<'a> {
  52. pub params: &'a Params,
  53. pub rows: usize,
  54. pub cols: usize,
  55. pub data: Vec<u64>,
  56. }
  57. pub struct PolyMatrixNTT<'a> {
  58. pub params: &'a Params,
  59. pub rows: usize,
  60. pub cols: usize,
  61. pub data: Vec<u64>,
  62. }
  63. impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
  64. fn is_ntt(&self) -> bool {
  65. false
  66. }
  67. fn get_rows(&self) -> usize {
  68. self.rows
  69. }
  70. fn get_cols(&self) -> usize {
  71. self.cols
  72. }
  73. fn get_params(&self) -> &Params {
  74. &self.params
  75. }
  76. fn as_slice(&self) -> &[u64] {
  77. self.data.as_slice()
  78. }
  79. fn as_mut_slice(&mut self) -> &mut [u64] {
  80. self.data.as_mut_slice()
  81. }
  82. fn num_words(&self) -> usize {
  83. self.params.poly_len
  84. }
  85. fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
  86. let num_coeffs = rows * cols * params.poly_len;
  87. let data: Vec<u64> = vec![0; num_coeffs];
  88. PolyMatrixRaw {
  89. params,
  90. rows,
  91. cols,
  92. data,
  93. }
  94. }
  95. fn random(params: &'a Params, rows: usize, cols: usize) -> Self {
  96. let rng = rand::thread_rng();
  97. let mut iter = rng.sample_iter(&Standard);
  98. let mut out = PolyMatrixRaw::zero(params, rows, cols);
  99. for r in 0..rows {
  100. for c in 0..cols {
  101. for i in 0..params.poly_len {
  102. let val: u64 = iter.next().unwrap();
  103. out.get_poly_mut(r, c)[i] = val % params.modulus;
  104. }
  105. }
  106. }
  107. out
  108. }
  109. fn pad_top(&self, pad_rows: usize) -> Self {
  110. let mut padded = Self::zero(self.params, self.rows + pad_rows, self.cols);
  111. padded.copy_into(&self, pad_rows, 0);
  112. padded
  113. }
  114. }
  115. impl<'a> PolyMatrixRaw<'a> {
  116. pub fn identity(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
  117. let num_coeffs = rows * cols * params.poly_len;
  118. let mut data: Vec<u64> = vec![0; num_coeffs];
  119. for r in 0..rows {
  120. let c = r;
  121. let idx = r * cols * params.poly_len + c * params.poly_len;
  122. data[idx] = 1;
  123. }
  124. PolyMatrixRaw {
  125. params,
  126. rows,
  127. cols,
  128. data,
  129. }
  130. }
  131. pub fn noise(params: &'a Params, rows: usize, cols: usize, dg: &mut DiscreteGaussian) -> Self {
  132. let mut out = PolyMatrixRaw::zero(params, rows, cols);
  133. dg.sample_matrix(&mut out);
  134. out
  135. }
  136. pub fn ntt(&self) -> PolyMatrixNTT<'a> {
  137. to_ntt_alloc(&self)
  138. }
  139. pub fn to_vec(&self, modulus_bits: usize, num_coeffs: usize) -> Vec<u8> {
  140. let sz_bits = self.rows * self.cols * num_coeffs * modulus_bits;
  141. let sz_bytes = f64::ceil((sz_bits as f64) / 8f64) as usize + 32;
  142. let sz_bytes_roundup_16 = ((sz_bytes + 15) / 16) * 16;
  143. let mut data = vec![0u8; sz_bytes_roundup_16];
  144. let mut bit_offs = 0;
  145. for r in 0..self.rows {
  146. for c in 0..self.cols {
  147. for z in 0..num_coeffs {
  148. write_arbitrary_bits(data.as_mut_slice(), self.get_poly(r,c)[z], bit_offs, modulus_bits);
  149. bit_offs += modulus_bits;
  150. }
  151. // round bit_offs down to nearest byte boundary
  152. bit_offs = (bit_offs / 8) * 8
  153. }
  154. }
  155. data
  156. }
  157. pub fn single_value(params: &'a Params, value: u64) -> PolyMatrixRaw<'a> {
  158. let mut out = Self::zero(params, 1, 1);
  159. out.data[0] = value;
  160. out
  161. }
  162. }
  163. impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
  164. fn is_ntt(&self) -> bool {
  165. true
  166. }
  167. fn get_rows(&self) -> usize {
  168. self.rows
  169. }
  170. fn get_cols(&self) -> usize {
  171. self.cols
  172. }
  173. fn get_params(&self) -> &Params {
  174. &self.params
  175. }
  176. fn as_slice(&self) -> &[u64] {
  177. self.data.as_slice()
  178. }
  179. fn as_mut_slice(&mut self) -> &mut [u64] {
  180. self.data.as_mut_slice()
  181. }
  182. fn num_words(&self) -> usize {
  183. self.params.poly_len * self.params.crt_count
  184. }
  185. fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixNTT<'a> {
  186. let num_coeffs = rows * cols * params.poly_len * params.crt_count;
  187. let data: Vec<u64> = vec![0; num_coeffs];
  188. PolyMatrixNTT {
  189. params,
  190. rows,
  191. cols,
  192. data,
  193. }
  194. }
  195. fn random(params: &'a Params, rows: usize, cols: usize) -> Self {
  196. let rng = rand::thread_rng();
  197. let mut iter = rng.sample_iter(&Standard);
  198. let mut out = PolyMatrixNTT::zero(params, rows, cols);
  199. for r in 0..rows {
  200. for c in 0..cols {
  201. for i in 0..params.crt_count {
  202. for j in 0..params.poly_len {
  203. let idx = calc_index(&[i, j], &[params.crt_count, params.poly_len]);
  204. let val: u64 = iter.next().unwrap();
  205. out.get_poly_mut(r, c)[idx] = val % params.moduli[i];
  206. }
  207. }
  208. }
  209. }
  210. out
  211. }
  212. fn pad_top(&self, pad_rows: usize) -> Self {
  213. let mut padded = Self::zero(self.params, self.rows + pad_rows, self.cols);
  214. padded.copy_into(&self, pad_rows, 0);
  215. padded
  216. }
  217. }
  218. impl<'a> PolyMatrixNTT<'a> {
  219. pub fn raw(&self) -> PolyMatrixRaw<'a> {
  220. from_ntt_alloc(&self)
  221. }
  222. }
  223. pub fn multiply_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
  224. for c in 0..params.crt_count {
  225. for i in 0..params.poly_len {
  226. let idx = c * params.poly_len + i;
  227. res[idx] = multiply_modular(params, a[idx], b[idx], c);
  228. }
  229. }
  230. }
  231. pub fn multiply_add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
  232. for c in 0..params.crt_count {
  233. for i in 0..params.poly_len {
  234. let idx = c * params.poly_len + i;
  235. res[idx] = multiply_add_modular(params, a[idx], b[idx], res[idx], c);
  236. }
  237. }
  238. }
  239. pub fn add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
  240. for c in 0..params.crt_count {
  241. for i in 0..params.poly_len {
  242. let idx = c * params.poly_len + i;
  243. res[idx] = add_modular(params, a[idx], b[idx], c);
  244. }
  245. }
  246. }
  247. pub fn invert_poly(params: &Params, res: &mut [u64], a: &[u64]) {
  248. for i in 0..params.poly_len {
  249. res[i] = params.modulus - a[i];
  250. }
  251. }
  252. pub fn automorph_poly(params: &Params, res: &mut [u64], a: &[u64], t: usize) {
  253. let poly_len = params.poly_len;
  254. for i in 0..poly_len {
  255. let num = (i * t) / poly_len;
  256. let rem = (i * t) % poly_len;
  257. if num % 2 == 0 {
  258. res[rem] = a[i];
  259. } else {
  260. res[rem] = params.modulus - a[i];
  261. }
  262. }
  263. }
  264. #[cfg(target_feature = "avx2")]
  265. pub fn multiply_add_poly_avx(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
  266. for c in 0..params.crt_count {
  267. for i in (0..params.poly_len).step_by(4) {
  268. unsafe {
  269. let p_x = &a[c*params.poly_len + i] as *const u64;
  270. let p_y = &b[c*params.poly_len + i] as *const u64;
  271. let p_z = &mut res[c*params.poly_len + i] as *mut u64;
  272. let x = _mm256_loadu_si256(p_x as *const __m256i);
  273. let y = _mm256_loadu_si256(p_y as *const __m256i);
  274. let z = _mm256_loadu_si256(p_z as *const __m256i);
  275. let product = _mm256_mul_epu32(x, y);
  276. let out = _mm256_add_epi64(z, product);
  277. _mm256_storeu_si256(p_z as *mut __m256i, out);
  278. }
  279. }
  280. }
  281. }
  282. pub fn modular_reduce(params: &Params, res: &mut [u64]) {
  283. for c in 0..params.crt_count {
  284. for i in 0..params.poly_len {
  285. res[c*params.poly_len + i] %= params.moduli[c];
  286. }
  287. }
  288. }
  289. #[cfg(not(target_feature = "avx2"))]
  290. pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
  291. assert!(res.rows == a.rows);
  292. assert!(res.cols == b.cols);
  293. assert!(a.cols == b.rows);
  294. let params = res.params;
  295. for i in 0..a.rows {
  296. for j in 0..b.cols {
  297. for z in 0..params.poly_len*params.crt_count {
  298. res.get_poly_mut(i, j)[z] = 0;
  299. }
  300. for k in 0..a.cols {
  301. let params = res.params;
  302. let res_poly = res.get_poly_mut(i, j);
  303. let pol1 = a.get_poly(i, k);
  304. let pol2 = b.get_poly(k, j);
  305. multiply_add_poly(params, res_poly, pol1, pol2);
  306. }
  307. }
  308. }
  309. }
  310. #[cfg(target_feature = "avx2")]
  311. pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
  312. assert!(res.rows == a.rows);
  313. assert!(res.cols == b.cols);
  314. assert!(a.cols == b.rows);
  315. let params = res.params;
  316. for i in 0..a.rows {
  317. for j in 0..b.cols {
  318. for z in 0..params.poly_len*params.crt_count {
  319. res.get_poly_mut(i, j)[z] = 0;
  320. }
  321. let res_poly = res.get_poly_mut(i, j);
  322. for k in 0..a.cols {
  323. let pol1 = a.get_poly(i, k);
  324. let pol2 = b.get_poly(k, j);
  325. multiply_add_poly_avx(params, res_poly, pol1, pol2);
  326. }
  327. modular_reduce(params, res_poly);
  328. }
  329. }
  330. }
  331. pub fn add(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
  332. assert!(res.rows == a.rows);
  333. assert!(res.cols == a.cols);
  334. assert!(a.rows == b.rows);
  335. assert!(a.cols == b.cols);
  336. let params = res.params;
  337. for i in 0..a.rows {
  338. for j in 0..a.cols {
  339. let res_poly = res.get_poly_mut(i, j);
  340. let pol1 = a.get_poly(i, j);
  341. let pol2 = b.get_poly(i, j);
  342. add_poly(params, res_poly, pol1, pol2);
  343. }
  344. }
  345. }
  346. pub fn invert(res: &mut PolyMatrixRaw, a: &PolyMatrixRaw) {
  347. assert!(res.rows == a.rows);
  348. assert!(res.cols == a.cols);
  349. let params = res.params;
  350. for i in 0..a.rows {
  351. for j in 0..a.cols {
  352. let res_poly = res.get_poly_mut(i, j);
  353. let pol1 = a.get_poly(i, j);
  354. invert_poly(params, res_poly, pol1);
  355. }
  356. }
  357. }
  358. pub fn automorph<'a>(res: &mut PolyMatrixRaw<'a>, a: &PolyMatrixRaw<'a>, t: usize) {
  359. assert!(res.rows == a.rows);
  360. assert!(res.cols == a.cols);
  361. let params = res.params;
  362. for i in 0..a.rows {
  363. for j in 0..a.cols {
  364. let res_poly = res.get_poly_mut(i, j);
  365. let pol1 = a.get_poly(i, j);
  366. automorph_poly(params, res_poly, pol1, t);
  367. }
  368. }
  369. }
  370. pub fn automorph_alloc<'a>(a: &PolyMatrixRaw<'a>, t: usize) -> PolyMatrixRaw<'a> {
  371. let mut res = PolyMatrixRaw::zero(a.params, a.rows, a.cols);
  372. automorph(&mut res, a, t);
  373. res
  374. }
  375. pub fn stack<'a>(a: &PolyMatrixRaw<'a>, b: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
  376. assert_eq!(a.cols, b.cols);
  377. let mut c = PolyMatrixRaw::zero(a.params, a.rows + b.rows, a.cols);
  378. c.copy_into(a, 0, 0);
  379. c.copy_into(b, a.rows, 0);
  380. c
  381. }
  382. pub fn scalar_multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
  383. assert_eq!(a.rows, 1);
  384. assert_eq!(a.cols, 1);
  385. let params = res.params;
  386. let pol2 = a.get_poly(0, 0);
  387. for i in 0..b.rows {
  388. for j in 0..b.cols {
  389. let res_poly = res.get_poly_mut(i, j);
  390. let pol1 = b.get_poly(i, j);
  391. multiply_poly(params, res_poly, pol1, pol2);
  392. }
  393. }
  394. }
  395. pub fn scalar_multiply_alloc<'a>(a: &PolyMatrixNTT<'a>, b: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
  396. let mut res = PolyMatrixNTT::zero(b.params, b.rows, b.cols);
  397. scalar_multiply(&mut res, a, b);
  398. res
  399. }
  400. pub fn single_poly<'a>(params: &'a Params, val: u64) -> PolyMatrixRaw<'a> {
  401. let mut res = PolyMatrixRaw::zero(params, 1, 1);
  402. res.get_poly_mut(0, 0)[0] = val;
  403. res
  404. }
  405. pub fn to_ntt(a: &mut PolyMatrixNTT, b: &PolyMatrixRaw) {
  406. let params = a.params;
  407. for r in 0..a.rows {
  408. for c in 0..a.cols {
  409. let pol_src = b.get_poly(r, c);
  410. let pol_dst = a.get_poly_mut(r, c);
  411. for n in 0..params.crt_count {
  412. for z in 0..params.poly_len {
  413. pol_dst[n * params.poly_len + z] = pol_src[z] % params.moduli[n];
  414. }
  415. }
  416. ntt_forward(params, pol_dst);
  417. }
  418. }
  419. }
  420. pub fn to_ntt_alloc<'a>(b: &PolyMatrixRaw<'a>) -> PolyMatrixNTT<'a> {
  421. let mut a = PolyMatrixNTT::zero(b.params, b.rows, b.cols);
  422. to_ntt(&mut a, b);
  423. a
  424. }
  425. pub fn from_ntt(a: &mut PolyMatrixRaw, b: &PolyMatrixNTT) {
  426. let params = a.params;
  427. SCRATCH.with(|scratch_cell| {
  428. let scratch_vec = &mut *scratch_cell.borrow_mut();
  429. let scratch = scratch_vec.as_mut_slice();
  430. for r in 0..a.rows {
  431. for c in 0..a.cols {
  432. let pol_src = b.get_poly(r, c);
  433. let pol_dst = a.get_poly_mut(r, c);
  434. scratch[0..pol_src.len()].copy_from_slice(pol_src);
  435. ntt_inverse(params, scratch);
  436. for z in 0..params.poly_len {
  437. pol_dst[z] = params.crt_compose(scratch, z);
  438. }
  439. }
  440. }
  441. });
  442. }
  443. pub fn from_ntt_alloc<'a>(b: &PolyMatrixNTT<'a>) -> PolyMatrixRaw<'a> {
  444. let mut a = PolyMatrixRaw::zero(b.params, b.rows, b.cols);
  445. from_ntt(&mut a, b);
  446. a
  447. }
  448. impl<'a, 'b> Neg for &'b PolyMatrixRaw<'a> {
  449. type Output = PolyMatrixRaw<'a>;
  450. fn neg(self) -> Self::Output {
  451. let mut out = PolyMatrixRaw::zero(self.params, self.rows, self.cols);
  452. invert(&mut out, self);
  453. out
  454. }
  455. }
  456. impl<'a, 'b> Mul for &'b PolyMatrixNTT<'a> {
  457. type Output = PolyMatrixNTT<'a>;
  458. fn mul(self, rhs: Self) -> Self::Output {
  459. let mut out = PolyMatrixNTT::zero(self.params, self.rows, rhs.cols);
  460. multiply(&mut out, self, rhs);
  461. out
  462. }
  463. }
  464. impl<'a, 'b> Add for &'b PolyMatrixNTT<'a> {
  465. type Output = PolyMatrixNTT<'a>;
  466. fn add(self, rhs: Self) -> Self::Output {
  467. let mut out = PolyMatrixNTT::zero(self.params, self.rows, self.cols);
  468. add(&mut out, self, rhs);
  469. out
  470. }
  471. }
  472. #[cfg(test)]
  473. mod test {
  474. use super::*;
  475. fn get_params() -> Params {
  476. get_test_params()
  477. }
  478. fn assert_all_zero(a: &[u64]) {
  479. for i in a {
  480. assert_eq!(*i, 0);
  481. }
  482. }
  483. #[test]
  484. fn sets_all_zeros() {
  485. let params = get_params();
  486. let m1 = PolyMatrixNTT::zero(&params, 2, 1);
  487. assert_all_zero(m1.as_slice());
  488. }
  489. #[test]
  490. fn multiply_correctness() {
  491. let params = get_params();
  492. let m1 = PolyMatrixNTT::zero(&params, 2, 1);
  493. let m2 = PolyMatrixNTT::zero(&params, 3, 2);
  494. let m3 = &m2 * &m1;
  495. assert_all_zero(m3.as_slice());
  496. }
  497. #[test]
  498. fn full_multiply_correctness() {
  499. let params = get_params();
  500. let mut m1 = PolyMatrixRaw::zero(&params, 1, 1);
  501. let mut m2 = PolyMatrixRaw::zero(&params, 1, 1);
  502. m1.get_poly_mut(0, 0)[1] = 100;
  503. m2.get_poly_mut(0, 0)[1] = 7;
  504. let m1_ntt = to_ntt_alloc(&m1);
  505. let m2_ntt = to_ntt_alloc(&m2);
  506. let m3_ntt = &m1_ntt * &m2_ntt;
  507. let m3 = from_ntt_alloc(&m3_ntt);
  508. assert_eq!(m3.get_poly(0, 0)[2], 700);
  509. }
  510. #[test]
  511. fn to_vec_correctness() {
  512. let params = get_params();
  513. let mut m1 = PolyMatrixRaw::zero(&params, 1, 1);
  514. for i in 0..params.poly_len {
  515. m1.data[i] = 1;
  516. }
  517. let modulus_bits = 9;
  518. let v = m1.to_vec(modulus_bits, params.poly_len);
  519. for i in 0..v.len() {
  520. println!("{:?}", v[i]);
  521. }
  522. let mut bit_offs = 0;
  523. for i in 0..params.poly_len {
  524. let val = read_arbitrary_bits(v.as_slice(), bit_offs, modulus_bits);
  525. assert_eq!(m1.data[i], val);
  526. bit_offs += modulus_bits;
  527. }
  528. }
  529. }