// lagrange.rs
  1. use curve25519_dalek::scalar::Scalar;
  // Versions that just compute coefficients; these are used if you know
  // all of your input points are correct.
  4. // Compute the Lagrange coefficient for the value at the given x
  5. // coordinate, to interpolate the value at target_x. The x coordinates
  6. // in the coalition are allowed to include x itself, which will be
  7. // ignored.
  8. pub fn lagrange(coalition: &[u32], x: u32, target_x: u32) -> Scalar {
  9. let mut numer = Scalar::one();
  10. let mut denom = Scalar::one();
  11. let xscal = Scalar::from(x);
  12. let target_xscal = Scalar::from(target_x);
  13. for &c in coalition {
  14. if c != x {
  15. let cscal = Scalar::from(c);
  16. numer *= target_xscal - cscal;
  17. denom *= xscal - cscal;
  18. }
  19. }
  20. numer * denom.invert()
  21. }
  22. // Interpolate the given (x,y) coordinates at the target x value.
  23. // target_x must _not_ be in the x slice
  24. pub fn interpolate(x: &[u32], y: &[Scalar], target_x: u32) -> Scalar {
  25. assert!(x.len() == y.len());
  26. let mut res = Scalar::zero();
  27. for i in 0..x.len() {
  28. let lag_coeff = lagrange(x, x[i], target_x);
  29. res += lag_coeff * y[i];
  30. }
  31. res
  32. }
  // Versions that compute the entire Lagrange polynomials; these are used
  // if you need to _check_ that all of your input points are correct.
  35. // A ScalarPoly represents a polynomial whose coefficients are scalars.
  36. // The coeffs vector has length (deg+1), where deg is the degree of the
  37. // polynomial. coeffs[i] is the coefficient on x^i.
  38. #[derive(Clone, Debug)]
  39. pub struct ScalarPoly {
  40. pub coeffs: Vec<Scalar>,
  41. }
  42. impl ScalarPoly {
  43. pub fn zero() -> Self {
  44. Self { coeffs: vec![] }
  45. }
  46. pub fn one() -> Self {
  47. Self {
  48. coeffs: vec![Scalar::one()],
  49. }
  50. }
  51. pub fn rand(degree: usize) -> Self {
  52. let mut rng = rand::thread_rng();
  53. let mut coeffs: Vec<Scalar> = Vec::new();
  54. coeffs.resize_with(degree + 1, || Scalar::random(&mut rng));
  55. Self { coeffs }
  56. }
  57. // Evaluate the polynomial at the given point (using Horner's
  58. // method)
  59. pub fn eval(&self, x: &Scalar) -> Scalar {
  60. let mut res = Scalar::zero();
  61. for coeff in self.coeffs.iter().rev() {
  62. res *= x;
  63. res += coeff;
  64. }
  65. res
  66. }
  67. // Multiply self by the polynomial (x+a), for the given a
  68. pub fn mult_x_plus_a(&mut self, a: &Scalar) {
  69. // The length of coeffs is (deg+1), which is what we want the
  70. // new degree to be
  71. let newdeg = self.coeffs.len();
  72. if newdeg == 0 {
  73. // self is the zero polynomial, so it doesn't change with
  74. // the multiplication by x+a
  75. return;
  76. }
  77. let newcoeffs = (0..newdeg + 1)
  78. .map(|i| {
  79. if i == 0 {
  80. self.coeffs[i] * a
  81. } else if i == newdeg {
  82. self.coeffs[i - 1]
  83. } else {
  84. self.coeffs[i - 1] + self.coeffs[i] * a
  85. }
  86. })
  87. .collect();
  88. self.coeffs = newcoeffs;
  89. }
  90. // Multiply self by the constant c
  91. pub fn mult_scalar(&mut self, c: &Scalar) {
  92. for coeff in self.coeffs.iter_mut() {
  93. *coeff *= c;
  94. }
  95. }
  96. // Add another ScalarPoly to this one
  97. pub fn add(&mut self, other: &Self) {
  98. if other.coeffs.len() > self.coeffs.len() {
  99. self.coeffs.resize(other.coeffs.len(), Scalar::zero());
  100. }
  101. for i in 0..other.coeffs.len() {
  102. self.coeffs[i] += other.coeffs[i];
  103. }
  104. }
  105. }
  106. // Compute the Lagrange polynomial for the value at the given x
  107. // coordinate. The x coordinates in the coalition are allowed to
  108. // include x itself, which will be ignored.
  109. pub fn lagrange_poly(coalition: &[u32], x: u32) -> ScalarPoly {
  110. let mut numer = ScalarPoly::one();
  111. let mut denom = Scalar::one();
  112. let xscal = Scalar::from(x);
  113. for &c in coalition {
  114. if c != x {
  115. let cscal = Scalar::from(c);
  116. numer.mult_x_plus_a(&-cscal);
  117. denom *= xscal - cscal;
  118. }
  119. }
  120. numer.mult_scalar(&denom.invert());
  121. numer
  122. }
  123. // Compute the full set of Lagrange polynomials for the given coalition
  124. pub fn lagrange_polys(coalition: &[u32]) -> Vec<ScalarPoly> {
  125. coalition
  126. .iter()
  127. .map(|&x| lagrange_poly(coalition, x))
  128. .collect()
  129. }
  // A random degree-3 polynomial has four coefficients; the first is
  // expected to be nonzero and different from the second.
  #[test]
  pub fn test_rand_poly() {
      let rpoly = ScalarPoly::rand(3);
      println!("randpoly = {:?}", rpoly);
      assert_eq!(rpoly.coeffs.len(), 4);
      let constant = rpoly.coeffs[0];
      let linear = rpoly.coeffs[1];
      assert_ne!(constant, Scalar::zero());
      assert_ne!(constant, linear);
  }
  // Builds (x+2)*(x+3) = x^2 + 5x + 6 one linear factor at a time, then
  // checks both the coefficients and a few evaluations.
  #[test]
  pub fn test_eval() {
      let mut poly = ScalarPoly::one();
      for &a in &[2u32, 3u32] {
          poly.mult_x_plus_a(&Scalar::from(a));
      }

      // Coefficients are stored lowest degree first: 6 + 5x + x^2
      assert_eq!(poly.coeffs.len(), 3);
      assert_eq!(poly.coeffs[0], Scalar::from(6u32));
      assert_eq!(poly.coeffs[1], Scalar::from(5u32));
      assert_eq!(poly.coeffs[2], Scalar::from(1u32));

      let f0 = poly.eval(&Scalar::zero());
      let f2 = poly.eval(&Scalar::from(2u32));
      let f7 = poly.eval(&Scalar::from(7u32));
      println!("f0 = {:?}", f0);
      println!("f2 = {:?}", f2);
      println!("f7 = {:?}", f7);

      // f(0) = 6, f(2) = 4*5 = 20, f(7) = 9*10 = 90
      assert_eq!(f0, Scalar::from(6u32));
      assert_eq!(f2, Scalar::from(20u32));
      assert_eq!(f7, Scalar::from(90u32));
  }
  // Check that the sum of the given polys is just x^i: the coefficient
  // on x^i must be one and every other coefficient must be zero.
  //
  // Panics (failing the calling test) if the sum has no x^i coefficient
  // at all, or if any coefficient has the wrong value.
  #[cfg(test)]
  fn sum_polys_is_x_to_the_i(polys: &[ScalarPoly], i: usize) {
      let mut sum = ScalarPoly::zero();
      for p in polys {
          sum.add(p);
      }
      println!("sum = {:?}", sum);
      // The loop below only visits coefficients that exist, so a sum with
      // too few coefficients (e.g. from an empty polys slice) would
      // otherwise pass without ever checking that the x^i coefficient is
      // one.
      assert!(
          i < sum.coeffs.len(),
          "sum has {} coefficients, so it has no x^{} term",
          sum.coeffs.len(),
          i
      );
      for (j, coeff) in sum.coeffs.iter().enumerate() {
          let expected = if i == j {
              Scalar::one()
          } else {
              Scalar::zero()
          };
          assert!(*coeff == expected);
      }
  }
  // The Lagrange polynomials of a coalition sum to 1. Scaling the
  // polynomial of each member by that member's x coordinate once makes
  // them sum to x, and scaling a second time makes them sum to x^2.
  #[test]
  pub fn test_lagrange_polys() {
      let coalition: Vec<u32> = vec![1, 2, 5, 8, 12, 14];
      let mut polys = lagrange_polys(&coalition);
      sum_polys_is_x_to_the_i(&polys, 0);
      for power in 1..=2 {
          // Each pass multiplies polys[k] by coalition[k] once more
          for (poly, &x) in polys.iter_mut().zip(coalition.iter()) {
              poly.mult_scalar(&Scalar::from(x));
          }
          sum_polys_is_x_to_the_i(&polys, power);
      }
  }
  191. // Interpolate values at x=0 given the pre-computed Lagrange polynomials
  192. pub fn interpolate_polys_0(lag_polys: &[ScalarPoly], y: &[Scalar]) -> Scalar {
  193. assert!(lag_polys.len() == y.len());
  194. let mut res = Scalar::zero();
  195. for i in 0..y.len() {
  196. res += lag_polys[i].coeffs[0] * y[i];
  197. }
  198. res
  199. }