poly.rs 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. use crate::{arith::*, params::*};
  2. pub trait PolyMatrix<'a> {
  3. fn is_ntt(&self) -> bool;
  4. fn get_rows(&self) -> usize;
  5. fn get_cols(&self) -> usize;
  6. fn get_params(&self) -> &Params;
  7. fn zero(params: &'a Params, rows: usize, cols: usize) -> Self;
  8. fn as_slice(&self) -> &[u64];
  9. fn as_mut_slice(&mut self) -> &mut [u64];
  10. fn zero_out(&mut self) {
  11. for item in self.as_mut_slice() {
  12. *item = 0;
  13. }
  14. }
  15. fn get_poly(&self, row: usize, col: usize) -> &[u64] {
  16. let params = self.get_params();
  17. let start = (row * self.get_cols() + col) * params.poly_len;
  18. &self.as_slice()[start..start + params.poly_len]
  19. }
  20. fn get_poly_mut(&mut self, row: usize, col: usize) -> &mut [u64] {
  21. let poly_len = self.get_params().poly_len;
  22. let start = (row * self.get_cols() + col) * poly_len;
  23. &mut self.as_mut_slice()[start..start + poly_len]
  24. }
  25. }
  26. pub struct PolyMatrixRaw<'a> {
  27. params: &'a Params,
  28. rows: usize,
  29. cols: usize,
  30. data: Vec<u64>,
  31. }
  32. pub struct PolyMatrixNTT<'a> {
  33. params: &'a Params,
  34. rows: usize,
  35. cols: usize,
  36. data: Vec<u64>,
  37. }
  38. impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
  39. fn is_ntt(&self) -> bool {
  40. false
  41. }
  42. fn get_rows(&self) -> usize {
  43. self.rows
  44. }
  45. fn get_cols(&self) -> usize {
  46. self.cols
  47. }
  48. fn get_params(&self) -> &Params {
  49. &self.params
  50. }
  51. fn as_slice(&self) -> &[u64] {
  52. self.data.as_slice()
  53. }
  54. fn as_mut_slice(&mut self) -> &mut [u64] {
  55. self.data.as_mut_slice()
  56. }
  57. fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
  58. let num_coeffs = rows * cols * params.poly_len;
  59. let data: Vec<u64> = vec![0; num_coeffs];
  60. PolyMatrixRaw {
  61. params,
  62. rows,
  63. cols,
  64. data,
  65. }
  66. }
  67. }
  68. impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
  69. fn is_ntt(&self) -> bool {
  70. true
  71. }
  72. fn get_rows(&self) -> usize {
  73. self.rows
  74. }
  75. fn get_cols(&self) -> usize {
  76. self.cols
  77. }
  78. fn get_params(&self) -> &Params {
  79. &self.params
  80. }
  81. fn as_slice(&self) -> &[u64] {
  82. self.data.as_slice()
  83. }
  84. fn as_mut_slice(&mut self) -> &mut [u64] {
  85. self.data.as_mut_slice()
  86. }
  87. fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixNTT<'a> {
  88. let num_coeffs = rows * cols * params.poly_len * params.crt_count;
  89. let data: Vec<u64> = vec![0; num_coeffs];
  90. PolyMatrixNTT {
  91. params,
  92. rows,
  93. cols,
  94. data,
  95. }
  96. }
  97. }
  98. pub fn multiply_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
  99. for c in 0..params.crt_count {
  100. for i in 0..params.poly_len {
  101. res[i] = multiply_modular(params, a[i], b[i], c);
  102. }
  103. }
  104. }
  105. pub fn multiply_add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
  106. for c in 0..params.crt_count {
  107. for i in 0..params.poly_len {
  108. res[i] = multiply_add_modular(params, a[i], b[i], res[i], c);
  109. }
  110. }
  111. }
  112. pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
  113. assert!(a.cols == b.rows);
  114. for i in 0..a.rows {
  115. for j in 0..b.cols {
  116. for z in 0..res.params.poly_len {
  117. res.get_poly_mut(i, j)[z] = 0;
  118. }
  119. for k in 0..a.cols {
  120. let params = res.params;
  121. let res_poly = res.get_poly_mut(i, j);
  122. let pol1 = a.get_poly(i, k);
  123. let pol2 = b.get_poly(k, j);
  124. multiply_add_poly(params, res_poly, pol1, pol2);
  125. }
  126. }
  127. }
  128. }
  #[cfg(test)]
  mod test {
      use super::*;

      /// Builds the parameter set shared by every test in this module.
      fn get_params() -> Params {
          Params::init(2048, vec![268369921u64, 249561089u64])
      }

      /// Panics if any word in `a` is non-zero.
      fn assert_all_zero(a: &[u64]) {
          a.iter().for_each(|&word| assert_eq!(word, 0));
      }

      #[test]
      fn sets_all_zeros() {
          let params = get_params();
          let mat = PolyMatrixNTT::zero(&params, 2, 1);
          assert_all_zero(mat.as_slice());
      }

      #[test]
      fn multiply_correctness() {
          let params = get_params();
          let rhs = PolyMatrixNTT::zero(&params, 2, 1);
          let lhs = PolyMatrixNTT::zero(&params, 3, 2);
          let mut out = PolyMatrixNTT::zero(&params, 3, 1);
          // (3 x 2) * (2 x 1) -> (3 x 1); the product of zero matrices is zero.
          multiply(&mut out, &lhs, &rhs);
          assert_all_zero(out.as_slice());
      }
  }