pubscalareq.rs 3.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. //! A module to look for, and apply, any statement involving the
  2. //! equality of _public_ `Scalar`s.
  3. //!
//! Such a statement is of the form `a = 2*(c+1)` where `a` and `c` are
//! public `Scalar`s. That is, it is a single variable name (which must
//! be a public, non-vector `Scalar`, as specified in the provided
//! [`TaggedVarDict`]), an equal sign, and an [arithmetic expression]
//! involving other public `Scalar` variables, constants, parentheses,
//! and the operators `+`, `-`, and `*`.
  10. //!
  11. //! The statement is simply removed from the list of statements to be
  12. //! proven in the zero-knowledge sigma protocol, and code is emitted for
  13. //! the prover and verifier to each just check that the statement is
  14. //! satisfied.
  15. //!
  16. //! [arithmetic expression]: super::sigma::types::expr_type
  17. use super::codegen::CodeGen;
  18. use super::sigma::combiners::*;
  19. use super::sigma::types::{expr_type_tokens, AExprType};
  20. use super::syntax::taggedvardict_to_vardict;
  21. use super::transform::prune_statement_tree;
  22. use super::{TaggedIdent, TaggedScalar, TaggedVarDict};
  23. use quote::quote;
  24. use syn::{parse_quote, Expr, Result};
  25. /// Look for, and apply, all of the public scalar equality statements
  26. /// specified in leaves of the [`StatementTree`].
  27. pub fn transform(
  28. codegen: &mut CodeGen,
  29. st: &mut StatementTree,
  30. vars: &mut TaggedVarDict,
  31. ) -> Result<()> {
  32. // Construct the VarDict corresponding to vars
  33. let vardict = taggedvardict_to_vardict(vars);
  34. // Gather mutable references to all Exprs in the leaves of the
  35. // StatementTree. Note that this ignores the combiner structure in
  36. // the StatementTree, but that's fine.
  37. let mut leaves = st.leaves_mut();
  38. // For each leaf expression, see if it looks like a public Scalar
  39. // equality statement
  40. for leafexpr in leaves.iter_mut() {
  41. if let Expr::Assign(syn::ExprAssign { left, right, .. }) = *leafexpr {
  42. if let Expr::Path(syn::ExprPath { path, .. }) = left.as_ref() {
  43. if let Some(id) = path.get_ident() {
  44. let idstr = id.to_string();
  45. if let Some(TaggedIdent::Scalar(TaggedScalar {
  46. is_pub: true,
  47. is_vec: false,
  48. ..
  49. })) = vars.get(&idstr)
  50. {
  51. if let (
  52. AExprType::Scalar {
  53. is_pub: true,
  54. is_vec: false,
  55. ..
  56. },
  57. right_tokens,
  58. ) = expr_type_tokens(&vardict, right)?
  59. {
  60. // We found a public Scalar equality
  61. // statement. Add code to both the prover
  62. // and the verifier to check the statement.
  63. codegen.prove_append(quote! {
  64. if #id != #right_tokens {
  65. return Err(SigmaError::VerificationFailure);
  66. }
  67. });
  68. codegen.verify_append(quote! {
  69. if #id != #right_tokens {
  70. return Err(SigmaError::VerificationFailure);
  71. }
  72. });
  73. // Remove the statement from the
  74. // [`StatementTree`] by replacing it with
  75. // leaf_true (which will be pruned below).
  76. let mut expr: Expr = parse_quote! { true };
  77. std::mem::swap(&mut expr, *leafexpr);
  78. }
  79. }
  80. }
  81. }
  82. }
  83. }
  84. // Now prune the StatementTree
  85. prune_statement_tree(st);
  86. Ok(())
  87. }