|
|
@@ -18,14 +18,15 @@
|
|
|
use super::codegen::CodeGen;
|
|
|
use super::sigma::combiners::*;
|
|
|
use super::sigma::types::{expr_type_tokens, AExprType};
|
|
|
-use super::syntax::taggedvardict_to_vardict;
|
|
|
+use super::syntax::{collect_cind_points, taggedvardict_to_vardict};
|
|
|
use super::transform::prune_statement_tree;
|
|
|
use super::{TaggedIdent, TaggedScalar, TaggedVarDict};
|
|
|
use quote::quote;
|
|
|
-use syn::{parse_quote, Expr, Result};
|
|
|
+use syn::{parse_quote, Error, Expr, Result};
|
|
|
|
|
|
/// Look for, and apply, all of the public scalar equality statements
|
|
|
/// specified in leaves of the [`StatementTree`].
|
|
|
+#[allow(non_snake_case)] // so that Points can be capital letters
|
|
|
pub fn transform(
|
|
|
codegen: &mut CodeGen,
|
|
|
st: &mut StatementTree,
|
|
|
@@ -34,15 +35,19 @@ pub fn transform(
|
|
|
// Construct the VarDict corresponding to vars
|
|
|
let vardict = taggedvardict_to_vardict(vars);
|
|
|
|
|
|
- // Gather mutable references to all Exprs in the leaves of the
|
|
|
- // StatementTree. Note that this ignores the combiner structure in
|
|
|
- // the StatementTree, but that's fine.
|
|
|
- let mut leaves = st.leaves_mut();
|
|
|
+ // A list of the computationally independent (non-vector) Points in
|
|
|
+ // the macro input. There must be at least one of them in order to
|
|
|
+ // handle public scalar equality statements inside disjunctions.
|
|
|
+ let cind_points = collect_cind_points(vars);
|
|
|
+
|
|
|
+ st.for_each_disjunction_branch(&mut |branch, path| {
|
|
|
+ // Are we in the root disjunction branch? (path is empty)
|
|
|
+ let in_root_disjunction_branch = path.is_empty();
|
|
|
|
|
|
// For each leaf expression, see if it looks like a public Scalar
|
|
|
// equality statement
|
|
|
- for leafexpr in leaves.iter_mut() {
|
|
|
- if let Expr::Assign(syn::ExprAssign { left, right, .. }) = *leafexpr {
|
|
|
+ branch.for_each_disjunction_branch_leaf(&mut |leaf| {
|
|
|
+ if let StatementTree::Leaf(Expr::Assign(syn::ExprAssign { left, right, .. })) = leaf {
|
|
|
if let Expr::Path(syn::ExprPath { path, .. }) = left.as_ref() {
|
|
|
if let Some(id) = path.get_ident() {
|
|
|
let idstr = id.to_string();
|
|
|
@@ -62,8 +67,11 @@ pub fn transform(
|
|
|
) = expr_type_tokens(&vardict, right)?
|
|
|
{
|
|
|
// We found a public Scalar equality
|
|
|
- // statement. Add code to both the prover
|
|
|
- // and the verifier to check the statement.
|
|
|
+ // statement.
|
|
|
+ if in_root_disjunction_branch {
|
|
|
+ // If we're in the root disjunction branch,
|
|
|
+ // add code to both the prover and the
|
|
|
+ // verifier to directly check the statement.
|
|
|
codegen.prove_verify_append(quote! {
|
|
|
if #id != #right_tokens {
|
|
|
return Err(SigmaError::VerificationFailure);
|
|
|
@@ -73,14 +81,33 @@ pub fn transform(
|
|
|
// Remove the statement from the
|
|
|
// [`StatementTree`] by replacing it with
|
|
|
// leaf_true (which will be pruned below).
|
|
|
- let mut expr: Expr = parse_quote! { true };
|
|
|
- std::mem::swap(&mut expr, *leafexpr);
|
|
|
+ *leaf = StatementTree::leaf_true();
|
|
|
+ } else {
|
|
|
+ // If we're not in the root disjunction
|
|
|
+ // branch, replace the statement
|
|
|
+ // `left_id = right_side` with the
|
|
|
+ // statement `left_id*A =
|
|
|
+ // (right_side)*A` for a cind Point A.
|
|
|
+ if cind_points.is_empty() {
|
|
|
+ return Err(Error::new(
|
|
|
+ proc_macro2::Span::call_site(),
|
|
|
+ "At least one cind Point must be declared to support public Scalar equality statements inside disjunctions",
|
|
|
+ ));
|
|
|
+ }
|
|
|
+ let cind_A = &cind_points[0];
|
|
|
+
|
|
|
+ *leaf = StatementTree::Leaf(parse_quote! {
|
|
|
+ #id * #cind_A = (#right) * #cind_A
|
|
|
+ });
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- }
|
|
|
+ Ok(())
|
|
|
+ })
|
|
|
+ })?;
|
|
|
|
|
|
// Now prune the StatementTree
|
|
|
prune_statement_tree(st);
|