notequals.rs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. //! A module to transform not-equals statements about a private
  2. //! `Scalar` into statements about linear combinations of `Point`s.
  3. //!
  //! A not-equals statement looks like `x - 8 != a`, where `x` is a
  //! private `Scalar`, the `- 8` optionally offsets that private `Scalar` by
  //! a public `Scalar` or constant, and `a` is a public `Scalar` or
  //! constant (or an [arithmetic expression] that evaluates to a public
  //! `Scalar`). More generally, the left side may be a linear function of
  //! a single private `Scalar`, such as `2*x + 12`.
  9. //!
  10. //! [arithmetic expression]: super::sigma::types::expr_type
  11. use super::codegen::CodeGen;
  12. use super::pedersen::{
  13. convert_commitment, convert_randomness, recognize_linscalar, recognize_pedersen_assignment,
  14. recognize_pubscalar, unique_random_scalars, LinScalar, PedersenAssignment,
  15. };
  16. use super::sigma::combiners::*;
  17. use super::sigma::types::{expr_type_tokens, VarDict};
  18. use super::syntax::{collect_cind_points, taggedvardict_to_vardict};
  19. use super::transform::paren_if_needed;
  20. use super::TaggedVarDict;
  21. use quote::{format_ident, quote};
  22. use std::collections::HashMap;
  23. use syn::{parse_quote, Error, Expr, Ident, Result};
  24. /// Subtract the Expr `subexpr` (with constant value `subval`, if
  25. /// present) from the `LinScalar` `linscalar`. Return the resulting
  26. /// `LinScalar`.
  27. fn subtract_expr(linscalar: LinScalar, subexpr: &Expr, subval: Option<i128>) -> LinScalar {
  28. if subval != Some(0) {
  29. let paren_sub = paren_if_needed(subexpr.clone());
  30. if let Some(expr) = linscalar.pub_scalar_expr {
  31. return LinScalar {
  32. pub_scalar_expr: Some(parse_quote! {
  33. #expr - #paren_sub
  34. }),
  35. ..linscalar
  36. };
  37. } else {
  38. return LinScalar {
  39. pub_scalar_expr: Some(parse_quote! {
  40. -#paren_sub
  41. }),
  42. ..linscalar
  43. };
  44. }
  45. }
  46. linscalar
  47. }
  48. /// Try to parse the given `Expr` as a not-equals statement. The
  49. /// resulting `LinScalar` is the left side minus the right side.
  50. fn parse(vars: &TaggedVarDict, vardict: &VarDict, expr: &Expr) -> Option<LinScalar> {
  51. let Expr::Binary(syn::ExprBinary {
  52. left,
  53. op: syn::BinOp::Ne(_),
  54. right,
  55. ..
  56. }) = expr
  57. else {
  58. return None;
  59. };
  60. let linscalar = recognize_linscalar(vars, vardict, left)?;
  61. let (subexpr_is_vec, subval) = recognize_pubscalar(vars, vardict, right)?;
  62. // We don't support vector variables
  63. if linscalar.is_vec || subexpr_is_vec {
  64. return None;
  65. }
  66. Some(subtract_expr(linscalar, right, subval))
  67. }
  68. /// Look for, and transform, not-equals statements specified in the
  69. /// [`StatementTree`] into basic statements about linear combinations of
  70. /// `Point`s.
  71. #[allow(non_snake_case)] // so that Points can be capital letters
  72. pub fn transform(
  73. codegen: &mut CodeGen,
  74. st: &mut StatementTree,
  75. vars: &mut TaggedVarDict,
  76. ) -> Result<()> {
  77. // Make the VarDict version of the variable dictionary
  78. let mut vardict = taggedvardict_to_vardict(vars);
  79. // A HashSet of the unique random Scalars in the macro input
  80. let mut randoms = unique_random_scalars(vars, st);
  81. // Gather mutable references to all of the leaves of the
  82. // StatementTree. Note that this ignores the combiner structure in
  83. // the StatementTree, but that's fine.
  84. let mut leaves = st.leaves_st_mut();
  85. // A list of the computationally independent (non-vector) Points in
  86. // the macro input. There must be at least two of them in order to
  87. // handle not-equals statements, so that we can make Pedersen
  88. // commitments.
  89. let cind_points = collect_cind_points(vars);
  90. // Find any statements that look like Pedersen commitments in the
  91. // StatementTree, and make a HashMap mapping the committed private
  92. // variable to the parsed commitment.
  93. let pedersens: HashMap<Ident, PedersenAssignment> = leaves
  94. .iter()
  95. .filter_map(|leaf| {
  96. // See if we recognize this leaf expression as a
  97. // PedersenAssignment, and if so, make a pair mapping its
  98. // variable to the PedersenAssignment. (The "collect()"
  99. // will turn the list of pairs into a HashMap.)
  100. if let StatementTree::Leaf(leafexpr) = leaf {
  101. recognize_pedersen_assignment(vars, &randoms, &vardict, leafexpr)
  102. .map(|ped_assign| (ped_assign.var(), ped_assign))
  103. } else {
  104. None
  105. }
  106. })
  107. .collect();
  108. // Count how many not-equals statements we've seen
  109. let mut neq_stmt_index = 0usize;
  110. // The generated variable name for the rng
  111. let rng_var = codegen.gen_ident(&format_ident!("rng"));
  112. for leaf in leaves.iter_mut() {
  113. // For each leaf expression, see if it looks like a not-equals statement
  114. let StatementTree::Leaf(leafexpr) = leaf else {
  115. continue;
  116. };
  117. let Some(neq_linscalar) = parse(vars, &vardict, leafexpr) else {
  118. continue;
  119. };
  120. neq_stmt_index += 1;
  121. // We will transform the not-equals statement into a list of
  122. // basic linear combination statements that will be ANDed
  123. // together to replace the not-equals statement in the
  124. // StatementTree. This vector holds the list of basic
  125. // statements.
  126. let mut basic_statements: Vec<Expr> = Vec::new();
  127. // We'll need a Pedersen commitment to the variable in the
  128. // not-equals statement. See if there already is one.
  129. let neq_id = &neq_linscalar.id;
  130. let ped_assign = if let Some(ped_assign) = pedersens.get(neq_id) {
  131. ped_assign.clone()
  132. } else {
  133. // We'll need to create a new one. First find two
  134. // computationally independent Points.
  135. if cind_points.len() < 2 {
  136. return Err(Error::new(
  137. proc_macro2::Span::call_site(),
  138. "At least two cind Points must be declared to support not-equals statements",
  139. ));
  140. }
  141. let cind_A = &cind_points[0];
  142. let cind_B = &cind_points[1];
  143. // Create new variables for the Pedersen commitment and its
  144. // random Scalar.
  145. let commitment_var = codegen.gen_point(
  146. vars,
  147. &format_ident!("neq{}_{}_genC", neq_stmt_index, neq_id),
  148. false, // is_vec
  149. true, // send_to_verifier
  150. );
  151. let rand_var = codegen.gen_scalar(
  152. vars,
  153. &format_ident!("neq{}_{}_genr", neq_stmt_index, neq_id),
  154. true, // is_rand
  155. false, // is_vec
  156. );
  157. // Update vardict and randoms with the new vars
  158. vardict = taggedvardict_to_vardict(vars);
  159. randoms.insert(rand_var.to_string());
  160. let ped_assign_expr: Expr = parse_quote! {
  161. #commitment_var = #neq_id * #cind_A + #rand_var * #cind_B
  162. };
  163. let ped_assign =
  164. recognize_pedersen_assignment(vars, &randoms, &vardict, &ped_assign_expr).unwrap();
  165. codegen.prove_append(quote! {
  166. let #rand_var = Scalar::random(#rng_var);
  167. let #ped_assign_expr;
  168. });
  169. basic_statements.push(ped_assign_expr);
  170. ped_assign
  171. };
  172. // At this point, we have a Pedersen commitment for some linear
  173. // function of neq_id (given by
  174. // ped_assign.pedersen.var_term.coeff), using some linear
  175. // function of rand_var (given by
  176. // ped_assign.pedersen.rand_term.coeff) as the randomness. But
  177. // what we need is a Pedersen commitment for a possibly
  178. // different linear function of neq_id (given by
  179. // neq_linscalar). So we output runtime code for both the
  180. // prover and the verifier that converts the commitment, and
  181. // code for just the prover that converts the randomness.
  182. // Make a new runtime variable to hold the converted commitment
  183. let commitment_var = codegen.gen_point(
  184. vars,
  185. &format_ident!("neq{}_{}_C", neq_stmt_index, neq_id),
  186. false, // is_vec
  187. false, // send_to_verifier
  188. );
  189. let rand_var = codegen.gen_ident(&format_ident!("neq{}_{}_r", neq_stmt_index, neq_id));
  190. // Update vardict and randoms with the new vars
  191. vardict = taggedvardict_to_vardict(vars);
  192. randoms.insert(rand_var.to_string());
  193. codegen.prove_verify_append(convert_commitment(
  194. &commitment_var,
  195. &ped_assign,
  196. &neq_linscalar,
  197. &vardict,
  198. )?);
  199. codegen.prove_append(convert_randomness(
  200. &rand_var,
  201. &ped_assign,
  202. &neq_linscalar,
  203. &vardict,
  204. )?);
  205. // Now commitment_var is a Pedersen commitment to the LinScalar
  206. // we want to prove is not 0, using the randomness rand_var.
  207. // That is, commitment_var = L(x)*A + rand_var*B, where L(x) is
  208. // a linear function of x, and we want to show that L(x) != 0.
  209. // So we compute j = L(x).invert(), and s = -rand_var*j as new
  210. // private Scalars, and show that A = j*commitment_var + s*B.
  211. let Lx_var = codegen.gen_ident(&format_ident!("neq{}_{}_var", neq_stmt_index, neq_id));
  212. let Lx_code = expr_type_tokens(&vardict, &neq_linscalar.to_expr())?.1;
  213. let j_var = codegen.gen_scalar(
  214. vars,
  215. &format_ident!("neq{}_{}_j", neq_stmt_index, neq_id),
  216. false, // is_rand
  217. false, // is_vec
  218. );
  219. let s_var = codegen.gen_scalar(
  220. vars,
  221. &format_ident!("neq{}_{}_s", neq_stmt_index, neq_id),
  222. false, // is_rand
  223. false, // is_vec
  224. );
  225. // Update vardict with the new vars
  226. vardict = taggedvardict_to_vardict(vars);
  227. // The generators used in the Pedersen commitment
  228. let commit_generator = &ped_assign.pedersen.var_term.id;
  229. let rand_generator = &ped_assign.pedersen.rand_term.id;
  230. // The prover code
  231. codegen.prove_append(quote! {
  232. let #Lx_var = #Lx_code;
  233. let #j_var = <Scalar as Field>::invert(&#Lx_var)
  234. .into_option()
  235. .ok_or(SigmaError::VerificationFailure)?;
  236. let #s_var = -#rand_var * #j_var;
  237. });
  238. basic_statements.push(parse_quote! {
  239. #commit_generator = #j_var * #commitment_var
  240. + #s_var * #rand_generator
  241. });
  242. // Now replace the not-equals statement with an And of the
  243. // basic_statements
  244. let neq_st = StatementTree::And(
  245. basic_statements
  246. .into_iter()
  247. .map(StatementTree::Leaf)
  248. .collect(),
  249. );
  250. **leaf = neq_st;
  251. }
  252. Ok(())
  253. }
  #[cfg(test)]
  mod tests {
      use super::super::syntax::taggedvardict_from_strs;
      use super::*;

      /// Parse `expr` as a not-equals statement against a variable
      /// dictionary built from `vars`, and assert that the result equals
      /// `expect`.
      fn parse_tester(vars: (&[&str], &[&str]), expr: Expr, expect: Option<LinScalar>) {
          let taggedvardict = taggedvardict_from_strs(vars);
          let vardict = taggedvardict_to_vardict(&taggedvardict);
          let output = parse(&taggedvardict, &vardict, &expr);
          assert_eq!(output, expect);
      }

      #[test]
      fn parse_test() {
          let vars = (
              [
                  "x", "y", "z", "pub a", "pub b", "pub c", "rand r", "rand s", "rand t",
              ]
              .as_slice(),
              ["C", "cind A", "cind B"].as_slice(),
          );
          parse_tester(
              vars,
              parse_quote! {
                  x != 0
              },
              Some(LinScalar {
                  coeff: 1,
                  pub_scalar_expr: None,
                  id: parse_quote! {x},
                  is_vec: false,
              }),
          );
          parse_tester(
              vars,
              parse_quote! {
                  x != 5
              },
              Some(LinScalar {
                  coeff: 1,
                  pub_scalar_expr: Some(parse_quote! {-5}),
                  id: parse_quote! {x},
                  is_vec: false,
              }),
          );
          parse_tester(
              vars,
              parse_quote! {
                  2*x != 5
              },
              Some(LinScalar {
                  coeff: 2,
                  pub_scalar_expr: Some(parse_quote! {-5}),
                  id: parse_quote! {x},
                  is_vec: false,
              }),
          );
          parse_tester(
              vars,
              parse_quote! {
                  2*x + 12 != 5
              },
              Some(LinScalar {
                  coeff: 2,
                  pub_scalar_expr: Some(parse_quote! {12i128-5}),
                  id: parse_quote! {x},
                  is_vec: false,
              }),
          );
          parse_tester(
              vars,
              parse_quote! {
                  2*x + a*a != 0
              },
              Some(LinScalar {
                  coeff: 2,
                  pub_scalar_expr: Some(parse_quote! {a*a}),
                  id: parse_quote! {x},
                  is_vec: false,
              }),
          );
          parse_tester(
              vars,
              parse_quote! {
                  2*x + a*a != b*c + c
              },
              Some(LinScalar {
                  coeff: 2,
                  pub_scalar_expr: Some(parse_quote! {a*a-(b*c+c)}),
                  id: parse_quote! {x},
                  is_vec: false,
              }),
          );
      }

      /// Expressions that are not `!=` comparisons must be rejected.
      #[test]
      fn parse_rejects_non_ne_test() {
          let vars = (
              [
                  "x", "y", "z", "pub a", "pub b", "pub c", "rand r", "rand s", "rand t",
              ]
              .as_slice(),
              ["C", "cind A", "cind B"].as_slice(),
          );
          // Equality is not a not-equals statement.
          parse_tester(
              vars,
              parse_quote! {
                  x == 0
              },
              None,
          );
          // Other comparison operators are not accepted either.
          parse_tester(
              vars,
              parse_quote! {
                  x < a
              },
              None,
          );
          // A non-binary expression is never a not-equals statement.
          parse_tester(
              vars,
              parse_quote! {
                  x
              },
              None,
          );
      }
  }