|
|
@@ -46,7 +46,7 @@ use std::collections::{HashSet, VecDeque};
|
|
|
use syn::spanned::Spanned;
|
|
|
use syn::visit::Visit;
|
|
|
use syn::visit_mut::{self, VisitMut};
|
|
|
-use syn::{parse_quote, Error, Expr, Ident, Result};
|
|
|
+use syn::{Error, Expr, Ident, Result};
|
|
|
|
|
|
/// Produce a [`HashSet`] of the private `Scalar`s appearing in the
|
|
|
/// provided [`Expr`], as specified in the provided [`TaggedVarDict`].
|
|
|
@@ -105,18 +105,18 @@ pub fn transform(
|
|
|
// Construct the VarDict corresponding to vars
|
|
|
let vardict = taggedvardict_to_vardict(vars);
|
|
|
|
|
|
- // Gather mutable references to all Exprs in the leaves of the
|
|
|
- // StatementTree. Note that this ignores the combiner structure in
|
|
|
- // the StatementTree, but that's fine.
|
|
|
- let mut leaves = st.leaves_mut();
|
|
|
+ let mut subs: VecDeque<(Ident, Expr, HashSet<String>)> = VecDeque::new();
|
|
|
+ let mut subs_vars: HashSet<String> = HashSet::new();
|
|
|
+
|
|
|
+ st.for_each_disjunction_branch(&mut |branch, path| {
|
|
|
+ // Are we in the root disjunction branch? (path is empty)
|
|
|
+ let in_root_disjunction_branch = path.is_empty();
|
|
|
|
|
|
// For each leaf expression, see if it looks like a substitution of
|
|
|
// a private Scalar
|
|
|
- let mut subs: VecDeque<(Ident, Expr, HashSet<String>)> = VecDeque::new();
|
|
|
- let mut subs_vars: HashSet<String> = HashSet::new();
|
|
|
- for leafexpr in leaves.iter_mut() {
|
|
|
+ branch.for_each_disjunction_branch_leaf(&mut |leaf| {
|
|
|
let mut is_subs = None;
|
|
|
- if let Expr::Assign(syn::ExprAssign { left, .. }) = *leafexpr {
|
|
|
+ if let StatementTree::Leaf(Expr::Assign(syn::ExprAssign { left, .. })) = leaf {
|
|
|
if let Expr::Path(syn::ExprPath { path, .. }) = left.as_ref() {
|
|
|
if let Some(id) = path.get_ident() {
|
|
|
let idstr = id.to_string();
|
|
|
@@ -133,15 +133,22 @@ pub fn transform(
|
|
|
// it to subs, replace it in the StatementTree with the
|
|
|
// constant true, and generate some code for `prove` to
|
|
|
// check the statement.
|
|
|
- let mut expr: Expr = parse_quote! { true };
|
|
|
- std::mem::swap(&mut expr, *leafexpr);
|
|
|
+ let old_leaf = std::mem::replace(leaf, StatementTree::leaf_true());
|
|
|
// This "if let" is guaranteed to succeed
|
|
|
- if let Expr::Assign(syn::ExprAssign { right, .. }) = expr {
|
|
|
+ if let StatementTree::Leaf(Expr::Assign(syn::ExprAssign {
|
|
|
+ right, .. })) = old_leaf {
|
|
|
if let Ok((_, right_tokens)) = expr_type_tokens(&vardict, &right) {
|
|
|
let used_priv_scalars = priv_scalar_set(&right, vars);
|
|
|
if !subs_vars.insert(id.to_string()) {
|
|
|
return Err(Error::new(id.span(), "variable substituted multiple times"));
|
|
|
}
|
|
|
+ // Only if we're in the root disjunction branch,
|
|
|
+ // check whether the substituted Witness value
|
|
|
+ // actually equals the value it's being substituted
|
|
|
+ // for. We can't do this for substitutions in other
|
|
|
+ // disjunction branches, since it may not be true
|
|
|
+ // there.
|
|
|
+ if in_root_disjunction_branch {
|
|
|
codegen.prove_append(quote! {
|
|
|
// It's OK to have a test that observably fails
|
|
|
// for illegal inputs (but is constant time for
|
|
|
@@ -150,6 +157,7 @@ pub fn transform(
|
|
|
return Err(SigmaError::VerificationFailure);
|
|
|
}
|
|
|
});
|
|
|
+ }
|
|
|
let right = paren_if_needed(*right);
|
|
|
subs.push_back((id, right, used_priv_scalars));
|
|
|
} else {
|
|
|
@@ -164,7 +172,9 @@ pub fn transform(
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+ Ok(())
|
|
|
}
|
|
|
+ )})?;
|
|
|
|
|
|
// Now apply each substitution to both the StatementTree and also
|
|
|
// the remaining substitutions
|
|
|
@@ -187,7 +197,7 @@ pub fn transform(
|
|
|
}
|
|
|
}
|
|
|
// Do the substitution on each leaf Expr in the StatementTree
|
|
|
- for leafexpr in leaves.iter_mut() {
|
|
|
+ for leafexpr in st.leaves_mut().iter_mut() {
|
|
|
do_substitution(leafexpr, &idstr, &expr);
|
|
|
}
|
|
|
// Remove the substituted variable from the TaggedVarDict
|
|
|
@@ -204,6 +214,7 @@ pub fn transform(
|
|
|
mod tests {
|
|
|
use super::super::syntax::taggedvardict_from_strs;
|
|
|
use super::*;
|
|
|
+ use syn::parse_quote;
|
|
|
|
|
|
fn substitution_tester(
|
|
|
vars: (&[&str], &[&str]),
|