Browse Source

Only emit code to check consistency of substitutions if we're in the root disjunction branch

For example, if the statement contains "y = x + 1" _not_ inside an "Or",
then both substitute "y" with "x+1" in the StatementTree, and also emit
code that returns VerificationFailure if the Witness contains x and y
that does not satisfy y = x + 1.

On the other hand, if the statement contains "y = x + 1" _inside_ an
"Or", then do the substitution, but do not emit the checking code, since
it may not be the case that that part of the StatementTree is intended
to be true.

This commit is not formatted so as to minimize the diff.  The next
commit will run "cargo fmt" with no code changes.
Ian Goldberg 3 months ago
parent
commit
0c7e495a76
1 changed files with 24 additions and 13 deletions
  1. 24 13
      sigma_compiler_core/src/substitution.rs

+ 24 - 13
sigma_compiler_core/src/substitution.rs

@@ -46,7 +46,7 @@ use std::collections::{HashSet, VecDeque};
 use syn::spanned::Spanned;
 use syn::visit::Visit;
 use syn::visit_mut::{self, VisitMut};
-use syn::{parse_quote, Error, Expr, Ident, Result};
+use syn::{Error, Expr, Ident, Result};
 
 /// Produce a [`HashSet`] of the private `Scalar`s appearing in the
 /// provided [`Expr`], as specified in the provided [`TaggedVarDict`].
@@ -105,18 +105,18 @@ pub fn transform(
     // Construct the VarDict corresponding to vars
     let vardict = taggedvardict_to_vardict(vars);
 
-    // Gather mutable references to all Exprs in the leaves of the
-    // StatementTree.  Note that this ignores the combiner structure in
-    // the StatementTree, but that's fine.
-    let mut leaves = st.leaves_mut();
+    let mut subs: VecDeque<(Ident, Expr, HashSet<String>)> = VecDeque::new();
+    let mut subs_vars: HashSet<String> = HashSet::new();
+
+    st.for_each_disjunction_branch(&mut |branch, path| {
+        // Are we in the root disjunction branch?  (path is empty)
+        let in_root_disjunction_branch = path.is_empty();
 
     // For each leaf expression, see if it looks like a substitution of
     // a private Scalar
-    let mut subs: VecDeque<(Ident, Expr, HashSet<String>)> = VecDeque::new();
-    let mut subs_vars: HashSet<String> = HashSet::new();
-    for leafexpr in leaves.iter_mut() {
+        branch.for_each_disjunction_branch_leaf(&mut |leaf| {
         let mut is_subs = None;
-        if let Expr::Assign(syn::ExprAssign { left, .. }) = *leafexpr {
+        if let StatementTree::Leaf(Expr::Assign(syn::ExprAssign { left, .. })) = leaf {
             if let Expr::Path(syn::ExprPath { path, .. }) = left.as_ref() {
                 if let Some(id) = path.get_ident() {
                     let idstr = id.to_string();
@@ -133,15 +133,22 @@ pub fn transform(
             // it to subs, replace it in the StatementTree with the
             // constant true, and generate some code for `prove` to
             // check the statement.
-            let mut expr: Expr = parse_quote! { true };
-            std::mem::swap(&mut expr, *leafexpr);
+            let old_leaf = std::mem::replace(leaf, StatementTree::leaf_true());
             // This "if let" is guaranteed to succeed
-            if let Expr::Assign(syn::ExprAssign { right, .. }) = expr {
+            if let StatementTree::Leaf(Expr::Assign(syn::ExprAssign {
+            right, .. })) = old_leaf {
                 if let Ok((_, right_tokens)) = expr_type_tokens(&vardict, &right) {
                     let used_priv_scalars = priv_scalar_set(&right, vars);
                     if !subs_vars.insert(id.to_string()) {
                         return Err(Error::new(id.span(), "variable substituted multiple times"));
                     }
+                    // Only if we're in the root disjunction branch,
+                    // check whether the substituted Witness value
+                    // actually equals the value it's being substituted
+                    // for.  We can't do this for substitutions in other
+                    // disjunction branches, since it may not be true
+                    // there.
+                    if in_root_disjunction_branch {
                     codegen.prove_append(quote! {
                         // It's OK to have a test that observably fails
                         // for illegal inputs (but is constant time for
@@ -150,6 +157,7 @@ pub fn transform(
                             return Err(SigmaError::VerificationFailure);
                         }
                     });
+                    }
                     let right = paren_if_needed(*right);
                     subs.push_back((id, right, used_priv_scalars));
                 } else {
@@ -164,7 +172,9 @@ pub fn transform(
                 }
             }
         }
+    Ok(())
     }
+    )})?;
 
     // Now apply each substitution to both the StatementTree and also
     // the remaining substitutions
@@ -187,7 +197,7 @@ pub fn transform(
             }
         }
         // Do the substitution on each leaf Expr in the StatementTree
-        for leafexpr in leaves.iter_mut() {
+        for leafexpr in st.leaves_mut().iter_mut() {
             do_substitution(leafexpr, &idstr, &expr);
         }
         // Remove the substituted variable from the TaggedVarDict
@@ -204,6 +214,7 @@ pub fn transform(
 mod tests {
     use super::super::syntax::taggedvardict_from_strs;
     use super::*;
+    use syn::parse_quote;
 
     fn substitution_tester(
         vars: (&[&str], &[&str]),