poly.rs 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. use std::arch::x86_64::*;
  2. use std::ops::Mul;
  3. use crate::{arith::*, params::*, ntt::*, util::calc_index};
  4. pub trait PolyMatrix<'a> {
  5. fn is_ntt(&self) -> bool;
  6. fn get_rows(&self) -> usize;
  7. fn get_cols(&self) -> usize;
  8. fn get_params(&self) -> &Params;
  9. fn zero(params: &'a Params, rows: usize, cols: usize) -> Self;
  10. fn random(params: &'a Params, rows: usize, cols: usize, rng: &mut dyn Iterator<Item=u64>) -> Self;
  11. fn as_slice(&self) -> &[u64];
  12. fn as_mut_slice(&mut self) -> &mut [u64];
  13. fn zero_out(&mut self) {
  14. for item in self.as_mut_slice() {
  15. *item = 0;
  16. }
  17. }
  18. fn get_poly(&self, row: usize, col: usize) -> &[u64] {
  19. let num_words = self.get_params().num_words();
  20. let start = (row * self.get_cols() + col) * num_words;
  21. &self.as_slice()[start..start + num_words]
  22. }
  23. fn get_poly_mut(&mut self, row: usize, col: usize) -> &mut [u64] {
  24. let num_words = self.get_params().num_words();
  25. let start = (row * self.get_cols() + col) * num_words;
  26. &mut self.as_mut_slice()[start..start + num_words]
  27. }
  28. fn copy_into(&mut self, p: &Self, target_row: usize, target_col: usize) {
  29. assert!(target_row < self.get_rows());
  30. assert!(target_col < self.get_cols());
  31. assert!(target_row + p.get_rows() < self.get_rows());
  32. assert!(target_col + p.get_cols() < self.get_cols());
  33. for r in 0..p.get_rows() {
  34. for c in 0..p.get_cols() {
  35. let pol_src = p.get_poly(r, c);
  36. let pol_dst = self.get_poly_mut(target_row + r, target_col + c);
  37. pol_dst.copy_from_slice(pol_src);
  38. }
  39. }
  40. }
  41. }
  42. pub struct PolyMatrixRaw<'a> {
  43. pub params: &'a Params,
  44. pub rows: usize,
  45. pub cols: usize,
  46. pub data: Vec<u64>,
  47. }
  48. pub struct PolyMatrixNTT<'a> {
  49. pub params: &'a Params,
  50. pub rows: usize,
  51. pub cols: usize,
  52. pub data: Vec<u64>,
  53. }
  54. impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
  55. fn is_ntt(&self) -> bool {
  56. false
  57. }
  58. fn get_rows(&self) -> usize {
  59. self.rows
  60. }
  61. fn get_cols(&self) -> usize {
  62. self.cols
  63. }
  64. fn get_params(&self) -> &Params {
  65. &self.params
  66. }
  67. fn as_slice(&self) -> &[u64] {
  68. self.data.as_slice()
  69. }
  70. fn as_mut_slice(&mut self) -> &mut [u64] {
  71. self.data.as_mut_slice()
  72. }
  73. fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
  74. let num_coeffs = rows * cols * params.poly_len;
  75. let data: Vec<u64> = vec![0; num_coeffs];
  76. PolyMatrixRaw {
  77. params,
  78. rows,
  79. cols,
  80. data,
  81. }
  82. }
  83. fn random(params: &'a Params, rows: usize, cols: usize, rng: &mut dyn Iterator<Item=u64>) -> Self {
  84. let mut out = PolyMatrixRaw::zero(params, rows, cols);
  85. for r in 0..rows {
  86. for c in 0..cols {
  87. for i in 0..params.poly_len {
  88. let val: u64 = rng.next().unwrap();
  89. out.get_poly_mut(r, c)[i] = val % params.modulus;
  90. }
  91. }
  92. }
  93. out
  94. }
  95. }
  96. impl<'a> PolyMatrixRaw<'a> {
  97. pub fn identity(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
  98. let num_coeffs = rows * cols * params.poly_len;
  99. let mut data: Vec<u64> = vec![0; num_coeffs];
  100. for r in 0..rows {
  101. let c = r;
  102. let idx = r * cols * params.poly_len + c * params.poly_len;
  103. data[idx] = 1;
  104. }
  105. PolyMatrixRaw {
  106. params,
  107. rows,
  108. cols,
  109. data,
  110. }
  111. }
  112. }
  113. impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
  114. fn is_ntt(&self) -> bool {
  115. true
  116. }
  117. fn get_rows(&self) -> usize {
  118. self.rows
  119. }
  120. fn get_cols(&self) -> usize {
  121. self.cols
  122. }
  123. fn get_params(&self) -> &Params {
  124. &self.params
  125. }
  126. fn as_slice(&self) -> &[u64] {
  127. self.data.as_slice()
  128. }
  129. fn as_mut_slice(&mut self) -> &mut [u64] {
  130. self.data.as_mut_slice()
  131. }
  132. fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixNTT<'a> {
  133. let num_coeffs = rows * cols * params.poly_len * params.crt_count;
  134. let data: Vec<u64> = vec![0; num_coeffs];
  135. PolyMatrixNTT {
  136. params,
  137. rows,
  138. cols,
  139. data,
  140. }
  141. }
  142. fn random(params: &'a Params, rows: usize, cols: usize, rng: &mut dyn Iterator<Item=u64>) -> Self {
  143. let mut out = PolyMatrixNTT::zero(params, rows, cols);
  144. for r in 0..rows {
  145. for c in 0..cols {
  146. for i in 0..params.crt_count {
  147. for j in 0..params.poly_len {
  148. let idx = calc_index(&[i, j], &[params.crt_count, params.poly_len]);
  149. let val: u64 = rng.next().unwrap();
  150. out.get_poly_mut(r, c)[idx] = val % params.moduli[i];
  151. }
  152. }
  153. }
  154. }
  155. out
  156. }
  157. }
  158. pub fn multiply_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
  159. for c in 0..params.crt_count {
  160. for i in 0..params.poly_len {
  161. res[i] = multiply_modular(params, a[i], b[i], c);
  162. }
  163. }
  164. }
  165. pub fn multiply_add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
  166. for c in 0..params.crt_count {
  167. for i in 0..params.poly_len {
  168. res[i] = multiply_add_modular(params, a[i], b[i], res[i], c);
  169. }
  170. }
  171. }
  172. pub fn multiply_add_poly_avx(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
  173. for c in 0..params.crt_count {
  174. for i in (0..params.poly_len).step_by(4) {
  175. unsafe {
  176. let p_x = &a[c*params.poly_len + i] as *const u64;
  177. let p_y = &b[c*params.poly_len + i] as *const u64;
  178. let p_z = &mut res[c*params.poly_len + i] as *mut u64;
  179. let x = _mm256_loadu_si256(p_x as *const __m256i);
  180. let y = _mm256_loadu_si256(p_y as *const __m256i);
  181. let z = _mm256_loadu_si256(p_z as *const __m256i);
  182. let product = _mm256_mul_epu32(x, y);
  183. let out = _mm256_add_epi64(z, product);
  184. _mm256_storeu_si256(p_z as *mut __m256i, out);
  185. }
  186. }
  187. }
  188. }
  189. pub fn modular_reduce(params: &Params, res: &mut [u64]) {
  190. for c in 0..params.crt_count {
  191. for i in 0..params.poly_len {
  192. res[c*params.poly_len + i] %= params.moduli[c];
  193. }
  194. }
  195. }
  196. #[cfg(not(target_feature = "avx2"))]
  197. pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
  198. assert!(a.cols == b.rows);
  199. for i in 0..a.rows {
  200. for j in 0..b.cols {
  201. for z in 0..res.params.poly_len {
  202. res.get_poly_mut(i, j)[z] = 0;
  203. }
  204. for k in 0..a.cols {
  205. let params = res.params;
  206. let res_poly = res.get_poly_mut(i, j);
  207. let pol1 = a.get_poly(i, k);
  208. let pol2 = b.get_poly(k, j);
  209. multiply_add_poly(params, res_poly, pol1, pol2);
  210. }
  211. }
  212. }
  213. }
  /// Matrix product `res = a * b` over NTT-form polynomials (build with AVX2
  /// enabled at compile time).
  ///
  /// Each result entry is cleared, accumulated with `multiply_add_poly_avx`
  /// over the shared dimension, and reduced once with `modular_reduce`.
  ///
  /// # Panics
  /// Panics if `a.cols != b.rows` or if `res` is not `a.rows x b.cols`.
  #[cfg(target_feature = "avx2")]
  pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
      assert!(a.cols == b.rows);
      assert!(res.rows == a.rows);
      assert!(res.cols == b.cols);
      let params = res.params;
      for i in 0..a.rows {
          for j in 0..b.cols {
              let res_poly = res.get_poly_mut(i, j);
              // Clear the whole polynomial slot: the AVX kernel accumulates
              // into every limb, so stale upper limbs would corrupt the sum.
              for word in res_poly.iter_mut() {
                  *word = 0;
              }
              for k in 0..a.cols {
                  let pol1 = a.get_poly(i, k);
                  let pol2 = b.get_poly(k, j);
                  multiply_add_poly_avx(params, res_poly, pol1, pol2);
              }
              // NOTE(review): reduction happens only once per entry; this
              // relies on the unreduced sums not overflowing 64 bits for the
              // moduli and inner dimension in use. Confirm against callers.
              modular_reduce(params, res_poly);
          }
      }
  }
  233. pub fn scalar_multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
  234. assert_eq!(b.rows, 1);
  235. assert_eq!(b.cols, 1);
  236. let params = res.params;
  237. let pol2 = b.get_poly(0, 0);
  238. for i in 0..a.rows {
  239. for j in 0..a.cols {
  240. let res_poly = res.get_poly_mut(i, j);
  241. let pol1 = a.get_poly(i, j);
  242. multiply_poly(params, res_poly, pol1, pol2);
  243. }
  244. }
  245. }
  246. pub fn from_scalar_multiply<'a>(a: &PolyMatrixNTT<'a>, b: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
  247. let mut res = PolyMatrixNTT::zero(a.params, a.rows, a.cols);
  248. scalar_multiply(res, a, b);
  249. res
  250. }
  251. pub fn to_ntt(a: &mut PolyMatrixNTT, b: &PolyMatrixRaw) {
  252. for r in 0..a.rows {
  253. for c in 0..a.cols {
  254. let pol_src = a.get_poly_mut(r, c);
  255. let pol_dst = b.get_poly_mut(r, c);
  256. for n in 0..a.params.crt_count {
  257. for z in 0..a.params.poly_len {
  258. pol_dst[n * a.params.poly_len + z] = pol_src[z] % a.params.moduli[n];
  259. }
  260. }
  261. ntt_forward(a.params, pol_dst);
  262. }
  263. }
  264. }
  265. pub fn from_ntt(params: &Params, res: &mut [u64]) {
  266. for c in 0..params.crt_count {
  267. for i in 0..params.poly_len {
  268. res[c*params.poly_len + i] %= params.moduli[c];
  269. }
  270. }
  271. }
  272. impl<'a> Mul for PolyMatrixNTT<'a> {
  273. type Output = Self;
  274. fn mul(self, rhs: Self) -> Self::Output {
  275. let mut out = PolyMatrixNTT::zero(self.params, self.rows, rhs.cols);
  276. multiply(&mut out, &self, &rhs);
  277. out
  278. }
  279. }
  #[cfg(test)]
  mod test {
      use super::*;

      /// Parameter set shared by the tests in this module.
      fn get_params() -> Params {
          let moduli = vec![268369921u64, 249561089u64];
          Params::init(2048, &moduli, 2, 6.4)
      }

      /// Fails at the first non-zero word of `a`.
      fn assert_all_zero(a: &[u64]) {
          a.iter().for_each(|word| assert_eq!(*word, 0));
      }

      /// A freshly zeroed NTT matrix contains only zero words.
      #[test]
      fn sets_all_zeros() {
          let params = get_params();
          let zeros = PolyMatrixNTT::zero(&params, 2, 1);
          assert_all_zero(zeros.as_slice());
      }

      /// The product of a zero `3 x 2` and a zero `2 x 1` matrix is all zero.
      #[test]
      fn multiply_correctness() {
          let params = get_params();
          let rhs = PolyMatrixNTT::zero(&params, 2, 1);
          let lhs = PolyMatrixNTT::zero(&params, 3, 2);
          let product = lhs * rhs;
          assert_all_zero(product.as_slice());
      }
  }