| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116 |
//! A module to look for, and apply, any statement involving the
//! equality of _public_ `Scalar`s.
//!
//! Such a statement is of the form `a = 2*(c+1)` where `a` and `c` are
//! public `Scalar`s. That is, it is a single variable name (which must
//! be a public `Scalar`, as specified in the provided
//! [`TaggedVarDict`]), an equal sign, and an [arithmetic expression]
//! involving other public `Scalar` variables, constants, parens, and
//! the operators `+`, `-`, and `*`.
//!
//! The statement is simply removed from the list of statements to be
//! proven in the zero-knowledge sigma protocol, and code is emitted for
//! the prover and verifier to each just check that the statement is
//! satisfied.
//!
//! [arithmetic expression]: super::sigma::types::expr_type
- use super::codegen::CodeGen;
- use super::sigma::combiners::*;
- use super::sigma::types::{expr_type_tokens, AExprType};
- use super::syntax::{collect_cind_points, taggedvardict_to_vardict};
- use super::transform::prune_statement_tree;
- use super::{TaggedIdent, TaggedScalar, TaggedVarDict};
- use quote::quote;
- use syn::{parse_quote, Error, Expr, Result};
- /// Look for, and apply, all of the public scalar equality statements
- /// specified in leaves of the [`StatementTree`].
- #[allow(non_snake_case)] // so that Points can be capital letters
- pub fn transform(
- codegen: &mut CodeGen,
- st: &mut StatementTree,
- vars: &mut TaggedVarDict,
- ) -> Result<()> {
- // Construct the VarDict corresponding to vars
- let vardict = taggedvardict_to_vardict(vars);
- // A list of the computationally independent (non-vector) Points in
- // the macro input. There must be at least one of them in order to
- // handle public scalar equality statements inside disjunctions.
- let cind_points = collect_cind_points(vars);
- st.for_each_disjunction_branch(&mut |branch, path| {
- // Are we in the root disjunction branch? (path is empty)
- let in_root_disjunction_branch = path.is_empty();
- // For each leaf expression, see if it looks like a public Scalar
- // equality statement
- branch.for_each_disjunction_branch_leaf(&mut |leaf| {
- if let StatementTree::Leaf(Expr::Assign(syn::ExprAssign { left, right, .. })) = leaf {
- if let Expr::Path(syn::ExprPath { path, .. }) = left.as_ref() {
- if let Some(id) = path.get_ident() {
- let idstr = id.to_string();
- if let Some(TaggedIdent::Scalar(TaggedScalar {
- is_pub: true,
- is_vec: false,
- ..
- })) = vars.get(&idstr)
- {
- if let (
- AExprType::Scalar {
- is_pub: true,
- is_vec: false,
- ..
- },
- right_tokens,
- ) = expr_type_tokens(&vardict, right)?
- {
- // We found a public Scalar equality
- // statement.
- if in_root_disjunction_branch {
- // If we're in the root disjunction branch,
- // add code to both the prover and the
- // verifier to directly check the statement.
- codegen.prove_verify_append(quote! {
- if #id != #right_tokens {
- return Err(SigmaError::VerificationFailure);
- }
- });
- // Remove the statement from the
- // [`StatementTree`] by replacing it with
- // leaf_true (which will be pruned below).
- *leaf = StatementTree::leaf_true();
- } else {
- // If we're not in the root disjunction
- // branch, replace the statement
- // `left_id = right_side` with the
- // statement `left_id*A =
- // (right_side)*A` for a cind Point A.
- if cind_points.is_empty() {
- return Err(Error::new(
- proc_macro2::Span::call_site(),
- "At least one cind Point must be declared to support public Scalar equality statements inside disjunctions",
- ));
- }
- let cind_A = &cind_points[0];
- *leaf = StatementTree::Leaf(parse_quote! {
- #id * #cind_A = (#right) * #cind_A
- });
- }
- }
- }
- }
- }
- }
- Ok(())
- })
- })?;
- // Now prune the StatementTree
- prune_statement_tree(st);
- Ok(())
- }
|