// rangeproof.rs
//! A module to transform range statements about `Scalar`s into
//! statements about linear combinations of `Point`s.
//!
//! A range statement looks like `(a..b).contains(x-8)`, where `a` and
//! `b` are expressions involving only _public_ `Scalar`s and constants,
//! and `x-8` is a private `Scalar`, possibly offset by a public
//! `Scalar` or constant. At this time, none of the variables can be
//! vector variables.
//!
//! As usual for Rust notation, the range `a..b` includes `a` but
//! _excludes_ `b`. You can also write `a..=b` to include both
//! endpoints. It is allowed for the range to "wrap around" 0, so
//! that `L-50..100` is a valid range, and equivalent to `-50..100`,
//! where `L` is the order of the group you are using.
//!
//! The size of the range (`b-a`) will be known at run time, but not
//! necessarily at compile time. The size must fit in an [`i128`] and
//! must be strictly greater than 1. Note that the range (and its size)
//! are public, but the value you are asserting to be in the range will
//! be private.
  21. use super::codegen::CodeGen;
  22. use super::pedersen::{recognize_linscalar, recognize_pubscalar, LinScalar};
  23. use super::sigma::combiners::*;
  24. use super::sigma::types::VarDict;
  25. use super::syntax::taggedvardict_to_vardict;
  26. use super::transform::paren_if_needed;
  27. use super::TaggedVarDict;
  28. use syn::{parse_quote, Expr, Result};
  29. /// A struct representing a normalized parsed range statement.
  30. ///
  31. /// Here, "normalized" means that the range is adjusted so that the
  32. /// lower bound is 0. This is accomplished by subtracting the stated
  33. /// lower bound from both the upper bound and the expression that is
  34. /// being asserting that it is in the range.
  35. #[derive(Clone, Debug, PartialEq, Eq)]
  36. struct RangeStatement {
  37. /// The upper bound of the range (exclusive). This must evaluate to
  38. /// a public Scalar.
  39. upper: Expr,
  40. /// The expression that is being asserted that it is in the range.
  41. /// This must be a [`LinScalar`]
  42. expr: LinScalar,
  43. }
  44. /// Subtract the Expr `lower` (with constant value `lowerval`, if
  45. /// present) from the Expr `expr` (with constant value `exprval`, if
  46. /// present). Return the resulting expression, as well as its constant
  47. /// value, if there is one. Do the subtraction numerically if possible,
  48. /// but otherwise symbolically.
  49. fn subtract_expr(
  50. expr: Option<&Expr>,
  51. exprval: Option<i128>,
  52. lower: &Expr,
  53. lowerval: Option<i128>,
  54. ) -> (Expr, Option<i128>) {
  55. // Note that if expr is None, then exprval is Some(0)
  56. if let (Some(ev), Some(lv)) = (exprval, lowerval) {
  57. if let Some(diffv) = ev.checked_sub(lv) {
  58. // We can do the subtraction numerically
  59. return (parse_quote! { #diffv }, Some(diffv));
  60. }
  61. }
  62. let paren_lower = paren_if_needed(lower.clone());
  63. // Return the difference symbolically
  64. (
  65. if let Some(e) = expr {
  66. parse_quote! { #e - #paren_lower }
  67. } else {
  68. parse_quote! { -#paren_lower }
  69. },
  70. None,
  71. )
  72. }
  73. /// Try to parse the given `Expr` as a range statement
  74. fn parse(vars: &TaggedVarDict, vardict: &VarDict, expr: &Expr) -> Option<RangeStatement> {
  75. // The expression needs to be of the form
  76. // (lower..upper).contains(expr)
  77. // The "top level" must be the method call ".contains"
  78. if let Expr::MethodCall(syn::ExprMethodCall {
  79. receiver,
  80. method,
  81. turbofish: None,
  82. args,
  83. ..
  84. }) = expr
  85. {
  86. if &method.to_string() != "contains" {
  87. // Wasn't ".contains"
  88. return None;
  89. }
  90. // Remove parens around the range, if present
  91. let mut range_expr = receiver.as_ref();
  92. if let Expr::Paren(syn::ExprParen {
  93. expr: parened_expr, ..
  94. }) = range_expr
  95. {
  96. range_expr = parened_expr;
  97. }
  98. // Parse the range
  99. if let Expr::Range(syn::ExprRange {
  100. start, limits, end, ..
  101. }) = range_expr
  102. {
  103. // The endpoints of the range need to be non-vector public
  104. // Scalar expressions
  105. // The first as_ref() turns &Option<Box<Expr>> into
  106. // Option<&Box<Expr>>. The ? removes the Option, and the
  107. // second as_ref() turns &Box<Expr> into &Expr.
  108. let lower = start.as_ref()?.as_ref().clone();
  109. let mut upper = end.as_ref()?.as_ref().clone();
  110. let Some((false, lowerval)) = recognize_pubscalar(vars, vardict, &lower) else {
  111. return None;
  112. };
  113. let Some((false, mut upperval)) = recognize_pubscalar(vars, vardict, &upper) else {
  114. return None;
  115. };
  116. let inclusive_upper = matches!(limits, syn::RangeLimits::Closed(_));
  117. // There needs to be exactly one argument of .contains()
  118. if args.len() != 1 {
  119. return None;
  120. }
  121. // The private expression needs to be a LinScalar
  122. let priv_expr = args.first().unwrap();
  123. let mut linscalar = recognize_linscalar(vars, vardict, priv_expr)?;
  124. // It is. See if the pub_scalar_expr in the LinScalar has a
  125. // constant value
  126. let linscalar_pubscalar_val = if let Some(ref pse) = linscalar.pub_scalar_expr {
  127. let Some((false, pubscalar_val)) = recognize_pubscalar(vars, vardict, pse) else {
  128. return None;
  129. };
  130. pubscalar_val
  131. } else {
  132. Some(0)
  133. };
  134. // We have a valid range statement. Normalize it by forcing
  135. // the upper bound to be exclusive, and the lower bound to
  136. // be 0.
  137. // If the range was inclusive of the upper bound (e.g.,
  138. // `0..=100`), add 1 to the upper bound to make it exclusive
  139. // (e.g., `0..101`).
  140. if inclusive_upper {
  141. // Add 1 to the upper bound, numerically if possible,
  142. // but otherwise symbolically
  143. let mut added_numerically = false;
  144. if let Some(uv) = upperval {
  145. if let Some(new_uv) = uv.checked_add(1) {
  146. upper = parse_quote! { #new_uv };
  147. upperval = Some(new_uv);
  148. added_numerically = true;
  149. }
  150. }
  151. if !added_numerically {
  152. upper = parse_quote! { #upper + 1 };
  153. upperval = None;
  154. }
  155. }
  156. // If the lower bound is not 0, subtract it from both the
  157. // upper bound and the pubscalar_expr in the LinScalar. Do
  158. // this numericaly if possibly, but otherwise symbolically.
  159. if lowerval != Some(0) {
  160. (upper, _) = subtract_expr(Some(&upper), upperval, &lower, lowerval);
  161. let pubscalar_expr;
  162. (pubscalar_expr, _) = subtract_expr(
  163. linscalar.pub_scalar_expr.as_ref(),
  164. linscalar_pubscalar_val,
  165. &lower,
  166. lowerval,
  167. );
  168. linscalar.pub_scalar_expr = Some(pubscalar_expr);
  169. }
  170. return Some(RangeStatement {
  171. upper,
  172. expr: linscalar,
  173. });
  174. }
  175. }
  176. None
  177. }
  178. /// Look for, and transform, range statements specified in the
  179. /// [`StatementTree`] into basic statements about linear combinations of
  180. /// `Point`s.
  181. pub fn transform(
  182. codegen: &mut CodeGen,
  183. st: &mut StatementTree,
  184. vars: &mut TaggedVarDict,
  185. ) -> Result<()> {
  186. // Make the VarDict version of the variable dictionary
  187. let vardict = taggedvardict_to_vardict(vars);
  188. // Gather mutable references to all Exprs in the leaves of the
  189. // StatementTree. Note that this ignores the combiner structure in
  190. // the StatementTree, but that's fine.
  191. let mut leaves = st.leaves_mut();
  192. // For each leaf expression, see if it looks like a range statement
  193. for leafexpr in leaves.iter_mut() {
  194. let is_range = parse(vars, &vardict, leafexpr);
  195. }
  196. Ok(())
  197. }
  198. #[cfg(test)]
  199. mod tests {
  200. use super::super::syntax::taggedvardict_from_strs;
  201. use super::*;
  202. fn parse_tester(vars: (&[&str], &[&str]), expr: Expr, expect: Option<RangeStatement>) {
  203. let taggedvardict = taggedvardict_from_strs(vars);
  204. let vardict = taggedvardict_to_vardict(&taggedvardict);
  205. let output = parse(&taggedvardict, &vardict, &expr);
  206. assert_eq!(output, expect);
  207. }
  208. #[test]
  209. fn parse_test() {
  210. let vars = (
  211. [
  212. "x", "y", "z", "pub a", "pub b", "pub c", "rand r", "rand s", "rand t",
  213. ]
  214. .as_slice(),
  215. ["C", "cind A", "cind B"].as_slice(),
  216. );
  217. parse_tester(
  218. vars,
  219. parse_quote! {
  220. (0..100).contains(x)
  221. },
  222. Some(RangeStatement {
  223. upper: parse_quote! { 100 },
  224. expr: LinScalar {
  225. coeff: 1,
  226. pub_scalar_expr: None,
  227. id: parse_quote! {x},
  228. is_vec: false,
  229. },
  230. }),
  231. );
  232. parse_tester(
  233. vars,
  234. parse_quote! {
  235. (0..=100).contains(x)
  236. },
  237. Some(RangeStatement {
  238. upper: parse_quote! { 101i128 },
  239. expr: LinScalar {
  240. coeff: 1,
  241. pub_scalar_expr: None,
  242. id: parse_quote! {x},
  243. is_vec: false,
  244. },
  245. }),
  246. );
  247. parse_tester(
  248. vars,
  249. parse_quote! {
  250. (-12..100).contains(x)
  251. },
  252. Some(RangeStatement {
  253. upper: parse_quote! { 112i128 },
  254. expr: LinScalar {
  255. coeff: 1,
  256. pub_scalar_expr: Some(parse_quote! { 12i128 }),
  257. id: parse_quote! {x},
  258. is_vec: false,
  259. },
  260. }),
  261. );
  262. parse_tester(
  263. vars,
  264. parse_quote! {
  265. (-12..(1<<20)).contains(x)
  266. },
  267. Some(RangeStatement {
  268. upper: parse_quote! { 1048588i128 },
  269. expr: LinScalar {
  270. coeff: 1,
  271. pub_scalar_expr: Some(parse_quote! { 12i128 }),
  272. id: parse_quote! {x},
  273. is_vec: false,
  274. },
  275. }),
  276. );
  277. parse_tester(
  278. vars,
  279. parse_quote! {
  280. (12..(1<<20)).contains(x+7)
  281. },
  282. Some(RangeStatement {
  283. upper: parse_quote! { 1048564i128 },
  284. expr: LinScalar {
  285. coeff: 1,
  286. pub_scalar_expr: Some(parse_quote! { -5i128 }),
  287. id: parse_quote! {x},
  288. is_vec: false,
  289. },
  290. }),
  291. );
  292. parse_tester(
  293. vars,
  294. parse_quote! {
  295. (12..(1<<20)).contains(2*x+7)
  296. },
  297. Some(RangeStatement {
  298. upper: parse_quote! { 1048564i128 },
  299. expr: LinScalar {
  300. coeff: 2,
  301. pub_scalar_expr: Some(parse_quote! { -5i128 }),
  302. id: parse_quote! {x},
  303. is_vec: false,
  304. },
  305. }),
  306. );
  307. parse_tester(
  308. vars,
  309. parse_quote! {
  310. (-1..(((1<<126)-1)*2)).contains(x)
  311. },
  312. Some(RangeStatement {
  313. upper: parse_quote! { 170141183460469231731687303715884105727i128 },
  314. expr: LinScalar {
  315. coeff: 1,
  316. pub_scalar_expr: Some(parse_quote! { 1i128 }),
  317. id: parse_quote! {x},
  318. is_vec: false,
  319. },
  320. }),
  321. );
  322. parse_tester(
  323. vars,
  324. parse_quote! {
  325. (-2..(((1<<126)-1)*2)).contains(x)
  326. },
  327. Some(RangeStatement {
  328. upper: parse_quote! { (((1<<126)-1)*2)-(-2) },
  329. expr: LinScalar {
  330. coeff: 1,
  331. pub_scalar_expr: Some(parse_quote! { 2i128 }),
  332. id: parse_quote! {x},
  333. is_vec: false,
  334. },
  335. }),
  336. );
  337. parse_tester(
  338. vars,
  339. parse_quote! {
  340. (a*b..b+c*c+7).contains(3*x+c*(a+b+2))
  341. },
  342. Some(RangeStatement {
  343. upper: parse_quote! { b+c*c+7-(a*b) },
  344. expr: LinScalar {
  345. coeff: 3,
  346. pub_scalar_expr: Some(parse_quote! { c*(a+b+2i128)-(a*b) }),
  347. id: parse_quote! {x},
  348. is_vec: false,
  349. },
  350. }),
  351. );
  352. }
  353. }