lagrange.rs 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. use curve25519_dalek::scalar::Scalar;
  2. // Versions that just compute coefficients; these are used if you know
  3. // all of your input points are correct
  4. // Compute the Lagrange coefficient for the value at the given x
  5. // coordinate, to interpolate the value at target_x. The x coordinates
  6. // in the coalition are allowed to include x itself, which will be
  7. // ignored.
  8. pub fn lagrange(coalition: &[u32], x: u32, target_x: u32) -> Scalar {
  9. let mut numer = Scalar::one();
  10. let mut denom = Scalar::one();
  11. let xscal = Scalar::from(x);
  12. let target_xscal = Scalar::from(target_x);
  13. for &c in coalition {
  14. if c != x {
  15. let cscal = Scalar::from(c);
  16. numer *= target_xscal - cscal;
  17. denom *= xscal - cscal;
  18. }
  19. }
  20. numer * denom.invert()
  21. }
  22. // Interpolate the given (x,y) coordinates at the target x value.
  23. // target_x must _not_ be in the x slice
  24. pub fn interpolate(x: &[u32], y: &[Scalar], target_x: u32) -> Scalar {
  25. assert!(x.len() == y.len());
  26. let mut res = Scalar::zero();
  27. for i in 0..x.len() {
  28. let lag_coeff = lagrange(x, x[i], target_x);
  29. res += lag_coeff * y[i];
  30. }
  31. res
  32. }
  // Versions that compute the entire Lagrange polynomials; these are used
  // if you need to _check_ that all of your input points are correct.
  35. // A ScalarPoly represents a polynomial whose coefficients are scalars.
  36. // The coeffs vector has length (deg+1), where deg is the degree of the
  37. // polynomial. coeffs[i] is the coefficient on x^i.
  38. #[derive(Clone, Debug)]
  39. pub struct ScalarPoly {
  40. pub coeffs: Vec<Scalar>,
  41. }
  42. impl ScalarPoly {
  43. pub fn zero() -> Self {
  44. Self { coeffs: vec![] }
  45. }
  46. pub fn one() -> Self {
  47. Self {
  48. coeffs: vec![Scalar::one()],
  49. }
  50. }
  51. // Multiply self by the polynomial (x+a), for the given a
  52. pub fn mult_x_plus_a(&mut self, a: &Scalar) {
  53. // The length of coeffs is (deg+1), which is what we want the
  54. // new degree to be
  55. let newdeg = self.coeffs.len();
  56. if newdeg == 0 {
  57. // self is the zero polynomial, so it doesn't change with
  58. // the multiplication by x+a
  59. return;
  60. }
  61. let newcoeffs = (0..newdeg + 1)
  62. .map(|i| {
  63. if i == 0 {
  64. self.coeffs[i] * a
  65. } else if i == newdeg {
  66. self.coeffs[i - 1]
  67. } else {
  68. self.coeffs[i - 1] + self.coeffs[i] * a
  69. }
  70. })
  71. .collect();
  72. self.coeffs = newcoeffs;
  73. }
  74. // Multiply self by the constant c
  75. pub fn mult_scalar(&mut self, c: &Scalar) {
  76. for coeff in self.coeffs.iter_mut() {
  77. *coeff *= c;
  78. }
  79. }
  80. // Add another ScalarPoly to this one
  81. pub fn add(&mut self, other: &Self) {
  82. if other.coeffs.len() > self.coeffs.len() {
  83. self.coeffs.resize(other.coeffs.len(), Scalar::zero());
  84. }
  85. for i in 0..other.coeffs.len() {
  86. self.coeffs[i] += other.coeffs[i];
  87. }
  88. }
  89. }
  90. // Compute the Lagrange polynomial for the value at the given x
  91. // coordinate. The x coordinates in the coalition are allowed to
  92. // include x itself, which will be ignored.
  93. pub fn lagrange_poly(coalition: &[u32], x: u32) -> ScalarPoly {
  94. let mut numer = ScalarPoly::one();
  95. let mut denom = Scalar::one();
  96. let xscal = Scalar::from(x);
  97. for &c in coalition {
  98. if c != x {
  99. let cscal = Scalar::from(c);
  100. numer.mult_x_plus_a(&-cscal);
  101. denom *= xscal - cscal;
  102. }
  103. }
  104. numer.mult_scalar(&denom.invert());
  105. numer
  106. }
  107. // Compute the full set of Lagrange polynomials for the given coalition
  108. pub fn lagrange_polys(coalition: &[u32]) -> Vec<ScalarPoly> {
  109. coalition
  110. .iter()
  111. .map(|&x| lagrange_poly(coalition, x))
  112. .collect()
  113. }
  /// Test helper: assert that the sum of `polys` is exactly the monomial `x^i`.
  ///
  /// Coefficients beyond the end of the summed vector are treated as zero,
  /// so the check also fails when the sum has no `x^i` term at all (for
  /// example when `polys` is empty).
  #[cfg(test)]
  fn sum_polys_is_x_to_the_i(polys: &[ScalarPoly], i: usize) {
      let mut sum = ScalarPoly::zero();
      for p in polys {
          sum.add(p);
      }
      println!("sum = {:?}", sum);
      // Always inspect index i, even if the sum is shorter, so a missing x^i
      // term is reported instead of being silently skipped.
      let len = sum.coeffs.len().max(i + 1);
      for j in 0..len {
          let actual = sum.coeffs.get(j).copied().unwrap_or_else(Scalar::zero);
          let expected = if j == i { Scalar::one() } else { Scalar::zero() };
          assert_eq!(actual, expected, "coefficient of x^{} in the sum", j);
      }
  }
  /// Lagrange basis polynomials reproduce every monomial of degree below the
  /// coalition size n: `sum_j x_j^i * L_j(X) == X^i` for each `i` in `0..n`.
  #[test]
  pub fn test_lagrange_polys() {
      let coalition: Vec<u32> = vec![1, 2, 5, 8, 12, 14];
      let mut polys = lagrange_polys(&coalition);
      for i in 0..coalition.len() {
          sum_polys_is_x_to_the_i(&polys, i);
          // Scale each L_j by one more power of its own x_j for the next degree.
          for (poly, &xj) in polys.iter_mut().zip(&coalition) {
              poly.mult_scalar(&Scalar::from(xj));
          }
      }
  }