lagrange.rs 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. use curve25519_dalek::scalar::Scalar;
  2. // Versions that just compute coefficients; these are used if you know
  3. // all of your input points are correct
  4. // Compute the Lagrange coefficient for the value at the given x
  5. // coordinate, to interpolate the value at target_x. The x coordinates
  6. // in the coalition are allowed to include x itself, which will be
  7. // ignored.
  8. pub fn lagrange(coalition: &[u32], x: u32, target_x: u32) -> Scalar {
  9. let mut numer = Scalar::one();
  10. let mut denom = Scalar::one();
  11. let xscal = Scalar::from(x);
  12. let target_xscal = Scalar::from(target_x);
  13. for &c in coalition {
  14. if c != x {
  15. let cscal = Scalar::from(c);
  16. numer *= target_xscal - cscal;
  17. denom *= xscal - cscal;
  18. }
  19. }
  20. numer * denom.invert()
  21. }
  22. // Interpolate the given (x,y) coordinates at the target x value.
  23. // target_x must _not_ be in the x slice
  24. pub fn interpolate(x: &[u32], y: &[Scalar], target_x: u32) -> Scalar {
  25. assert!(x.len() == y.len());
  26. (0..x.len())
  27. .map(|i| lagrange(x, x[i], target_x) * y[i])
  28. .sum()
  29. }
// Versions that compute the entire Lagrange polynomials; these are used
// if you need to _check_ that all of your input points are correct.
  32. // A ScalarPoly represents a polynomial whose coefficients are scalars.
  33. // The coeffs vector has length (deg+1), where deg is the degree of the
  34. // polynomial. coeffs[i] is the coefficient on x^i.
  35. #[derive(Clone, Debug)]
  36. pub struct ScalarPoly {
  37. pub coeffs: Vec<Scalar>,
  38. }
  39. impl ScalarPoly {
  40. pub fn zero() -> Self {
  41. Self { coeffs: vec![] }
  42. }
  43. pub fn one() -> Self {
  44. Self {
  45. coeffs: vec![Scalar::one()],
  46. }
  47. }
  48. pub fn rand(degree: usize) -> Self {
  49. let mut rng = rand::thread_rng();
  50. let mut coeffs: Vec<Scalar> = Vec::new();
  51. coeffs.resize_with(degree + 1, || Scalar::random(&mut rng));
  52. Self { coeffs }
  53. }
  54. // Evaluate the polynomial at the given point (using Horner's
  55. // method)
  56. pub fn eval(&self, x: &Scalar) -> Scalar {
  57. let mut res = Scalar::zero();
  58. for coeff in self.coeffs.iter().rev() {
  59. res *= x;
  60. res += coeff;
  61. }
  62. res
  63. }
  64. // Multiply self by the polynomial (x+a), for the given a
  65. pub fn mult_x_plus_a(&mut self, a: &Scalar) {
  66. // The length of coeffs is (deg+1), which is what we want the
  67. // new degree to be
  68. let newdeg = self.coeffs.len();
  69. if newdeg == 0 {
  70. // self is the zero polynomial, so it doesn't change with
  71. // the multiplication by x+a
  72. return;
  73. }
  74. let newcoeffs = (0..newdeg + 1)
  75. .map(|i| {
  76. if i == 0 {
  77. self.coeffs[i] * a
  78. } else if i == newdeg {
  79. self.coeffs[i - 1]
  80. } else {
  81. self.coeffs[i - 1] + self.coeffs[i] * a
  82. }
  83. })
  84. .collect();
  85. self.coeffs = newcoeffs;
  86. }
  87. // Multiply self by the constant c
  88. pub fn mult_scalar(&mut self, c: &Scalar) {
  89. for coeff in self.coeffs.iter_mut() {
  90. *coeff *= c;
  91. }
  92. }
  93. // Add another ScalarPoly to this one
  94. pub fn add(&mut self, other: &Self) {
  95. if other.coeffs.len() > self.coeffs.len() {
  96. self.coeffs.resize(other.coeffs.len(), Scalar::zero());
  97. }
  98. for i in 0..other.coeffs.len() {
  99. self.coeffs[i] += other.coeffs[i];
  100. }
  101. }
  102. }
  103. // Compute the Lagrange polynomial for the value at the given x
  104. // coordinate. The x coordinates in the coalition are allowed to
  105. // include x itself, which will be ignored.
  106. pub fn lagrange_poly(coalition: &[u32], x: u32) -> ScalarPoly {
  107. let mut numer = ScalarPoly::one();
  108. let mut denom = Scalar::one();
  109. let xscal = Scalar::from(x);
  110. for &c in coalition {
  111. if c != x {
  112. let cscal = Scalar::from(c);
  113. numer.mult_x_plus_a(&-cscal);
  114. denom *= xscal - cscal;
  115. }
  116. }
  117. numer.mult_scalar(&denom.invert());
  118. numer
  119. }
  120. // Compute the full set of Lagrange polynomials for the given coalition
  121. pub fn lagrange_polys(coalition: &[u32]) -> Vec<ScalarPoly> {
  122. coalition
  123. .iter()
  124. .map(|&x| lagrange_poly(coalition, x))
  125. .collect()
  126. }
  127. #[test]
  128. pub fn test_rand_poly() {
  129. let rpoly = ScalarPoly::rand(3);
  130. println!("randpoly = {:?}", rpoly);
  131. assert!(rpoly.coeffs.len() == 4);
  132. assert!(rpoly.coeffs[0] != Scalar::zero());
  133. assert!(rpoly.coeffs[0] != rpoly.coeffs[1]);
  134. }
  135. #[test]
  136. pub fn test_eval() {
  137. let mut poly = ScalarPoly::one();
  138. poly.mult_x_plus_a(&Scalar::from(2u32));
  139. poly.mult_x_plus_a(&Scalar::from(3u32));
  140. // poly should now be (x+2)*(x+3) = x^2 + 5x + 6
  141. assert!(poly.coeffs.len() == 3);
  142. assert!(poly.coeffs[0] == Scalar::from(6u32));
  143. assert!(poly.coeffs[1] == Scalar::from(5u32));
  144. assert!(poly.coeffs[2] == Scalar::from(1u32));
  145. let f0 = poly.eval(&Scalar::zero());
  146. let f2 = poly.eval(&Scalar::from(2u32));
  147. let f7 = poly.eval(&Scalar::from(7u32));
  148. println!("f0 = {:?}", f0);
  149. println!("f2 = {:?}", f2);
  150. println!("f7 = {:?}", f7);
  151. assert!(f0 == Scalar::from(6u32));
  152. assert!(f2 == Scalar::from(20u32));
  153. assert!(f7 == Scalar::from(90u32));
  154. }
  155. // Check that the sum of the given polys is just x^i
  156. #[cfg(test)]
  157. fn sum_polys_is_x_to_the_i(polys: &[ScalarPoly], i: usize) {
  158. let mut sum = ScalarPoly::zero();
  159. for p in polys.iter() {
  160. sum.add(p);
  161. }
  162. println!("sum = {:?}", sum);
  163. for j in 0..sum.coeffs.len() {
  164. assert!(
  165. sum.coeffs[j]
  166. == if i == j {
  167. Scalar::one()
  168. } else {
  169. Scalar::zero()
  170. }
  171. );
  172. }
  173. }
  174. #[test]
  175. pub fn test_lagrange_polys() {
  176. let coalition: Vec<u32> = vec![1, 2, 5, 8, 12, 14];
  177. let mut polys = lagrange_polys(&coalition);
  178. sum_polys_is_x_to_the_i(&polys, 0);
  179. for i in 0..coalition.len() {
  180. polys[i].mult_scalar(&Scalar::from(coalition[i]));
  181. }
  182. sum_polys_is_x_to_the_i(&polys, 1);
  183. for i in 0..coalition.len() {
  184. polys[i].mult_scalar(&Scalar::from(coalition[i]));
  185. }
  186. sum_polys_is_x_to_the_i(&polys, 2);
  187. }
  188. // Interpolate values at x=0 given the pre-computed Lagrange polynomials
  189. pub fn interpolate_polys_0(lag_polys: &[ScalarPoly], y: &[Scalar]) -> Scalar {
  190. assert!(lag_polys.len() == y.len());
  191. (0..y.len()).map(|i| lag_polys[i].coeffs[0] * y[i]).sum()
  192. }