lagrange.rs 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. use curve25519_dalek::scalar::Scalar;
  2. // Versions that just compute coefficients; these are used if you know
  3. // all of your input points are correct
  4. // Compute the Lagrange coefficient for the value at the given x
  5. // coordinate, to interpolate the value at target_x. The x coordinates
  6. // in the coalition are allowed to include x itself, which will be
  7. // ignored.
  8. pub fn lagrange(coalition: &[u32], x: u32, target_x: u32) -> Scalar {
  9. let mut numer = Scalar::one();
  10. let mut denom = Scalar::one();
  11. let xscal = Scalar::from(x);
  12. let target_xscal = Scalar::from(target_x);
  13. for &c in coalition {
  14. if c != x {
  15. let cscal = Scalar::from(c);
  16. numer *= target_xscal - cscal;
  17. denom *= xscal - cscal;
  18. }
  19. }
  20. numer * denom.invert()
  21. }
  22. // Interpolate the given (x,y) coordinates at the target x value.
  23. // target_x must _not_ be in the x slice
  24. // We never actually use this function, but it's here for completeness
  25. #[allow(dead_code)]
  26. pub fn interpolate(x: &[u32], y: &[Scalar], target_x: u32) -> Scalar {
  27. assert!(x.len() == y.len());
  28. (0..x.len())
  29. .map(|i| lagrange(x, x[i], target_x) * y[i])
  30. .sum()
  31. }
  32. // Versions that compute the entire Lagrange polynomials; these are used
  33. // if need to _check_ that all of your input points are correct.
  34. // A ScalarPoly represents a polynomial whose coefficients are scalars.
  35. // The coeffs vector has length (deg+1), where deg is the degree of the
  36. // polynomial. coeffs[i] is the coefficient on x^i.
  37. #[derive(Clone, Debug)]
  38. pub struct ScalarPoly {
  39. pub coeffs: Vec<Scalar>,
  40. }
  41. impl ScalarPoly {
  42. pub fn zero() -> Self {
  43. Self { coeffs: vec![] }
  44. }
  45. pub fn one() -> Self {
  46. Self {
  47. coeffs: vec![Scalar::one()],
  48. }
  49. }
  50. pub fn rand(degree: usize) -> Self {
  51. let mut rng = rand::thread_rng();
  52. let mut coeffs: Vec<Scalar> = Vec::new();
  53. coeffs.resize_with(degree + 1, || Scalar::random(&mut rng));
  54. Self { coeffs }
  55. }
  56. // Evaluate the polynomial at the given point (using Horner's
  57. // method)
  58. pub fn eval(&self, x: &Scalar) -> Scalar {
  59. let mut res = Scalar::zero();
  60. for coeff in self.coeffs.iter().rev() {
  61. res *= x;
  62. res += coeff;
  63. }
  64. res
  65. }
  66. // Multiply self by the polynomial (x+a), for the given a
  67. pub fn mult_x_plus_a(&mut self, a: &Scalar) {
  68. // The length of coeffs is (deg+1), which is what we want the
  69. // new degree to be
  70. let newdeg = self.coeffs.len();
  71. if newdeg == 0 {
  72. // self is the zero polynomial, so it doesn't change with
  73. // the multiplication by x+a
  74. return;
  75. }
  76. let newcoeffs = (0..newdeg + 1)
  77. .map(|i| {
  78. if i == 0 {
  79. self.coeffs[i] * a
  80. } else if i == newdeg {
  81. self.coeffs[i - 1]
  82. } else {
  83. self.coeffs[i - 1] + self.coeffs[i] * a
  84. }
  85. })
  86. .collect();
  87. self.coeffs = newcoeffs;
  88. }
  89. // Multiply self by the constant c
  90. pub fn mult_scalar(&mut self, c: &Scalar) {
  91. for coeff in self.coeffs.iter_mut() {
  92. *coeff *= c;
  93. }
  94. }
  95. // Add another ScalarPoly to this one
  96. pub fn add(&mut self, other: &Self) {
  97. if other.coeffs.len() > self.coeffs.len() {
  98. self.coeffs.resize(other.coeffs.len(), Scalar::zero());
  99. }
  100. for i in 0..other.coeffs.len() {
  101. self.coeffs[i] += other.coeffs[i];
  102. }
  103. }
  104. }
  105. // Compute the Lagrange polynomial for the value at the given x
  106. // coordinate. The x coordinates in the coalition are allowed to
  107. // include x itself, which will be ignored.
  108. pub fn lagrange_poly(coalition: &[u32], x: u32) -> ScalarPoly {
  109. let mut numer = ScalarPoly::one();
  110. let mut denom = Scalar::one();
  111. let xscal = Scalar::from(x);
  112. for &c in coalition {
  113. if c != x {
  114. let cscal = Scalar::from(c);
  115. numer.mult_x_plus_a(&-cscal);
  116. denom *= xscal - cscal;
  117. }
  118. }
  119. numer.mult_scalar(&denom.invert());
  120. numer
  121. }
  122. // Compute the full set of Lagrange polynomials for the given coalition
  123. pub fn lagrange_polys(coalition: &[u32]) -> Vec<ScalarPoly> {
  124. coalition
  125. .iter()
  126. .map(|&x| lagrange_poly(coalition, x))
  127. .collect()
  128. }
  /// A random degree-3 polynomial has four coefficients, and (except with
  /// negligible probability) they are nonzero and distinct.
  #[test]
  pub fn test_rand_poly() {
      let rpoly = ScalarPoly::rand(3);
      println!("randpoly = {:?}", rpoly);
      assert_eq!(rpoly.coeffs.len(), 4);
      assert_ne!(rpoly.coeffs[0], Scalar::zero());
      assert_ne!(rpoly.coeffs[0], rpoly.coeffs[1]);
  }
  /// Build (x+2)(x+3) with `mult_x_plus_a`, check its coefficients and a few
  /// evaluations, and check that the zero polynomial stays zero.
  #[test]
  pub fn test_eval() {
      let mut poly = ScalarPoly::one();
      poly.mult_x_plus_a(&Scalar::from(2u32));
      poly.mult_x_plus_a(&Scalar::from(3u32));
      // poly should now be (x+2)*(x+3) = x^2 + 5x + 6
      assert_eq!(
          poly.coeffs,
          vec![Scalar::from(6u32), Scalar::from(5u32), Scalar::from(1u32)]
      );
      let f0 = poly.eval(&Scalar::zero());
      let f2 = poly.eval(&Scalar::from(2u32));
      let f7 = poly.eval(&Scalar::from(7u32));
      println!("f0 = {:?}", f0);
      println!("f2 = {:?}", f2);
      println!("f7 = {:?}", f7);
      assert_eq!(f0, Scalar::from(6u32));
      assert_eq!(f2, Scalar::from(20u32));
      assert_eq!(f7, Scalar::from(90u32));

      // Multiplying the zero polynomial by (x+a) leaves it zero.
      let mut zero = ScalarPoly::zero();
      zero.mult_x_plus_a(&Scalar::from(4u32));
      assert!(zero.coeffs.is_empty());
      assert_eq!(zero.eval(&Scalar::from(5u32)), Scalar::zero());
  }
  /// Assert that the sum of `polys` is exactly the monomial `x^i`.
  ///
  /// Coefficients beyond the stored vector are treated as zero. The check
  /// always covers indices `0..=i`, so a sum that is too short to contain an
  /// `x^i` term fails instead of passing vacuously.
  #[cfg(test)]
  fn sum_polys_is_x_to_the_i(polys: &[ScalarPoly], i: usize) {
      let mut sum = ScalarPoly::zero();
      for p in polys {
          sum.add(p);
      }
      println!("sum = {:?}", sum);
      let len = sum.coeffs.len().max(i + 1);
      for j in 0..len {
          let actual = sum.coeffs.get(j).copied().unwrap_or_else(Scalar::zero);
          let expected = if j == i { Scalar::one() } else { Scalar::zero() };
          assert_eq!(actual, expected, "coefficient of x^{} in the sum", j);
      }
  }
  /// For distinct points x_i and their Lagrange polynomials L_i,
  /// `sum_i x_i^k * L_i(X) = X^k` for every k up to (number of points - 1).
  /// Check this for every such power.
  #[test]
  pub fn test_lagrange_polys() {
      let coalition = [1u32, 2, 5, 8, 12, 14];
      let mut polys = lagrange_polys(&coalition);
      sum_polys_is_x_to_the_i(&polys, 0);
      for power in 1..coalition.len() {
          // Scale each L_i by one more factor of x_i, giving x_i^power * L_i.
          for (poly, &c) in polys.iter_mut().zip(coalition.iter()) {
              poly.mult_scalar(&Scalar::from(c));
          }
          sum_polys_is_x_to_the_i(&polys, power);
      }
  }
  190. // Interpolate values at x=0 given the pre-computed Lagrange polynomials
  191. pub fn interpolate_polys_0(lag_polys: &[ScalarPoly], y: &[Scalar]) -> Scalar {
  192. assert!(lag_polys.len() == y.len());
  193. (0..y.len()).map(|i| lag_polys[i].coeffs[0] * y[i]).sum()
  194. }