Browse Source

Look for, and apply, all of the substitutions specified in leaves of a StatementTree

Ian Goldberg 3 months ago
parent
commit
2a45909392
1 changed files with 390 additions and 1 deletions
  1. 390 1
      sigma_compiler_core/src/transform.rs

+ 390 - 1
sigma_compiler_core/src/transform.rs

@@ -4,7 +4,12 @@
 //! [disjunction invariant]: StatementTree::check_disjunction_invariant
 
 use super::sigma::combiners::*;
-use syn::{parse_quote, Expr};
+use super::syntax::taggedvardict_to_vardict;
+use super::{TaggedIdent, TaggedScalar, TaggedVarDict};
+use std::collections::{HashSet, VecDeque};
+use syn::visit::Visit;
+use syn::visit_mut::{self, VisitMut};
+use syn::{parse_quote, Error, Expr, Ident, Result};
 
 /// Produce a [`StatementTree`] that represents the constant `true`
 fn leaf_true() -> StatementTree {
@@ -109,9 +114,182 @@ pub fn prune_statement_tree(st: &mut StatementTree) {
     }
 }
 
+/// Collect the names of all private `Scalar`s that occur in the given
+/// [`Expr`], where privacy is determined by the given [`TaggedVarDict`].
+fn priv_scalar_set(e: &Expr, taggedvardict: &TaggedVarDict) -> HashSet<String> {
+    let vardict = taggedvardict_to_vardict(taggedvardict);
+    let mut found: HashSet<String> = HashSet::new();
+    let mut collector = PrivScalarMap {
+        vars: &vardict,
+        closure: &mut |ident| {
+            found.insert(ident.to_string());
+            Ok(())
+        },
+        result: Ok(()),
+    };
+    collector.visit_expr(e);
+    found
+}
+
+/// Wrap an [`Expr`] (which represents an [arithmetic expression]) in
+/// parentheses, unless it can safely be left bare.
+///
+/// Only unary and binary operator expressions are wrapped, so that they
+/// stay grouped when spliced into a larger expression: `a+b` becomes
+/// `(a+b)`, while `c` stays `c` and `(a+b)` stays `(a+b)`.
+///
+/// [arithmetic expression]: super::sigma::types::expr_type
+fn paren_if_needed(expr: Expr) -> Expr {
+    if matches!(expr, Expr::Unary(_) | Expr::Binary(_)) {
+        return parse_quote! { (#expr) };
+    }
+    expr
+}
+
+/// Apply a single substitution on an [`Expr`].
+///
+/// Every occurrence in `expr` of the bare identifier named `idstr` is
+/// replaced with a clone of `replacement`; nothing else is modified.
+fn do_substitution(expr: &mut Expr, idstr: &str, replacement: &Expr) {
+    struct Subs<'a> {
+        idstr: &'a str,
+        replacement: &'a Expr,
+    }
+
+    impl VisitMut for Subs<'_> {
+        fn visit_expr_mut(&mut self, node: &mut Expr) {
+            // `is_ident` checks "single bare identifier" and the name in one call
+            let is_target = match &*node {
+                Expr::Path(expath) => expath.path.is_ident(self.idstr),
+                _ => false,
+            };
+            if is_target {
+                // Swap in the replacement; it is deliberately not traversed
+                *node = self.replacement.clone();
+            } else {
+                visit_mut::visit_expr_mut(self, node);
+            }
+        }
+    }
+
+    let mut subs = Subs { idstr, replacement };
+    subs.visit_expr_mut(expr);
+}
+
+/// Look for, and apply, all of the _substitutions_ specified in leaves
+/// of the [`StatementTree`].
+///
+/// A _substitution_ is a statement of the form `a = b` or `b = 2*(c + 1)`.
+/// That is, it is a single variable name (which must be a private
+/// `Scalar`, as specified in the provided [`TaggedVarDict`]), an equal
+/// sign, and an [arithmetic expression] involving other `Scalar`
+/// variables, constants, parens, and the operators `+`, `-`, and `*`.
+///
+/// Applying a substitution means replacing the variable to the left of
+/// the `=` with the expression on the right of the `=` everywhere it
+/// appears in the [`StatementTree`].  Substituting the same variable
+/// more than once in a [`StatementTree`] is an error.
+///
+/// It is also an error for the expression on the right to contain the
+/// variable on the left, either directly or after other substitutions.
+/// For example, `a = a + b` is not allowed, nor is the combination of
+/// substitutions `a = b + 1, b = c + 2, c = 2*a`.
+///
+/// After a substitution is applied, the substituted variable no longer
+/// appears anywhere in the [`StatementTree`], and is removed from the
+/// [`TaggedVarDict`].  The leaves holding the substitution statements
+/// themselves are turned into the constant `true` and then pruned using
+/// [`prune_statement_tree`].  If an error is returned, both arguments
+/// may be left partially modified.
+///
+/// It is the case that if the [disjunction invariant] is satisfied
+/// before this function is called (and the caller must ensure that it
+/// is), then it will be satisfied after the substitutions are applied,
+/// and then also after the [`StatementTree`] is pruned.
+///
+/// [arithmetic expression]: super::sigma::types::expr_type
+/// [disjunction invariant]: StatementTree::check_disjunction_invariant
+pub fn apply_substitutions(st: &mut StatementTree, vars: &mut TaggedVarDict) -> Result<()> {
+    // Gather mutable references to all Exprs in the leaves of the
+    // StatementTree.  Note that this ignores the combiner structure in
+    // the StatementTree, but that's fine.
+    let mut leaves = st.leaves_mut();
+
+    // For each leaf expression, see if it looks like a substitution of
+    // a private Scalar.  Each entry of subs holds the variable, its
+    // replacement, and the private Scalars the replacement uses.
+    let mut subs: VecDeque<(Ident, Expr, HashSet<String>)> = VecDeque::new();
+    let mut subs_vars: HashSet<String> = HashSet::new();
+    for leafexpr in leaves.iter_mut() {
+        let mut is_subs = None;
+        if let Expr::Assign(syn::ExprAssign { left, .. }) = *leafexpr {
+            if let Expr::Path(syn::ExprPath { path, .. }) = left.as_ref() {
+                if let Some(id) = path.get_ident() {
+                    let idstr = id.to_string();
+                    if let Some(TaggedIdent::Scalar(TaggedScalar { is_pub: false, .. })) =
+                        vars.get(&idstr)
+                    {
+                        is_subs = Some(id.clone());
+                    }
+                }
+            }
+        }
+        if let Some(id) = is_subs {
+            // If this leaf is a substitution of a private Scalar, add
+            // it to subs and replace it in the StatementTree with the
+            // constant true.
+            let expr = std::mem::replace(&mut **leafexpr, parse_quote! { true });
+            // This "if let" is guaranteed to succeed
+            if let Expr::Assign(syn::ExprAssign { right, .. }) = expr {
+                // Reject a repeat before analysing the right-hand side
+                if !subs_vars.insert(id.to_string()) {
+                    return Err(Error::new(id.span(), "variable substituted multiple times"));
+                }
+                let used_priv_scalars = priv_scalar_set(&right, vars);
+                subs.push_back((id, paren_if_needed(*right), used_priv_scalars));
+            }
+        }
+    }
+
+    // Now apply each substitution to both the StatementTree and also
+    // the remaining substitutions
+    while let Some((id, expr, priv_vars)) = subs.pop_front() {
+        let idstr = id.to_string();
+        // Substitutions applied earlier have merged their private
+        // Scalars into priv_vars, so this check also catches cycles
+        if priv_vars.contains(&idstr) {
+            return Err(Error::new(
+                id.span(),
+                "variable appears in its own substitution",
+            ));
+        }
+        // Do the substitution on each remaining substitution in the
+        // list
+        for (_sid, sexpr, spriv_vars) in subs.iter_mut() {
+            if spriv_vars.remove(&idstr) {
+                do_substitution(sexpr, &idstr, &expr);
+                spriv_vars.extend(priv_vars.iter().cloned());
+            }
+        }
+        // Do the substitution on each leaf Expr in the StatementTree
+        for leafexpr in leaves.iter_mut() {
+            do_substitution(leafexpr, &idstr, &expr);
+        }
+        // Remove the substituted variable from the TaggedVarDict
+        vars.remove(&idstr);
+    }
+
+    // Now prune the StatementTree
+    prune_statement_tree(st);
+
+    Ok(())
+}
+
 #[cfg(test)]
 mod tests {
+    use super::super::TaggedPoint;
     use super::*;
+    use std::collections::HashMap;
 
     #[test]
     fn leaf_true_test() {
@@ -237,4 +415,215 @@ mod tests {
             },
         );
     }
+
+    fn taggedvardict_from_strs((scalar_strs, point_strs): (&[&str], &[&str])) -> TaggedVarDict {
+        // Parse each declaration and key it by the name of its identifier
+        let scalars = scalar_strs.iter().map(|decl| {
+            let ts: TaggedScalar = syn::parse_str(decl).unwrap();
+            (ts.id.to_string(), TaggedIdent::Scalar(ts))
+        });
+        let points = point_strs.iter().map(|decl| {
+            let tp: TaggedPoint = syn::parse_str(decl).unwrap();
+            (tp.id.to_string(), TaggedIdent::Point(tp))
+        });
+        // Scalars are inserted first, then Points, as in a pair of loops
+        scalars.chain(points).collect::<HashMap<_, _>>()
+    }
+
+    fn substitution_tester(
+        vars: (&[&str], &[&str]),
+        e: Expr,
+        subbed_vars: (&[&str], &[&str]),
+        subbed_e: Expr,
+    ) -> Result<()> {
+        let mut actual_vars = taggedvardict_from_strs(vars);
+        let mut actual_st = StatementTree::parse(&e).unwrap();
+        apply_substitutions(&mut actual_st, &mut actual_vars)?;
+        // Both the resulting tree and the resulting dictionary must match
+        let expected_vars = taggedvardict_from_strs(subbed_vars);
+        assert_eq!(actual_st, StatementTree::parse(&subbed_e).unwrap());
+        assert_eq!(actual_vars, expected_vars);
+        Ok(())
+    }
+
+    #[test]
+    fn apply_substitutions_test() {
+        // Covers the no-op cases, each error case, and successful substitutions
+        let vars_a = (["a", "b", "pub c"].as_slice(), ["A", "B", "C"].as_slice());
+
+        // No substitutions (left side of = is a Point, not a Scalar)
+        substitution_tester(
+            vars_a,
+            parse_quote! {
+                A = b*B + c*C
+            },
+            vars_a,
+            parse_quote! {
+                A = b*B + c*C
+            },
+        )
+        .unwrap();
+        substitution_tester(
+            vars_a,
+            parse_quote! {
+                AND (
+                    A = b*B + c*C,
+                    B = a*A + c*C,
+                )
+            },
+            vars_a,
+            parse_quote! {
+                AND (
+                    A = b*B + c*C,
+                    B = a*A + c*C,
+                )
+            },
+        )
+        .unwrap();
+
+        // No substitutions (the left side of the = is public, not private)
+        substitution_tester(
+            vars_a,
+            parse_quote! {
+                AND (
+                    A = b*B + c*C,
+                    c = a,
+                )
+            },
+            vars_a,
+            parse_quote! {
+                AND (
+                    A = b*B + c*C,
+                    c = a,
+                )
+            },
+        )
+        .unwrap();
+
+        // Error: same variable substituted more than once
+        substitution_tester(
+            vars_a,
+            parse_quote! {
+                AND (
+                    A = b*B + c*C,
+                    a = c,
+                    a = b,
+                )
+            },
+            vars_a,
+            parse_quote! { true },
+        )
+        .unwrap_err();
+
+        // Error: same variable substituted twice, even with identical right sides
+        substitution_tester(
+            vars_a,
+            parse_quote! {
+                AND (
+                    A = b*B + c*C,
+                    a = c,
+                    a = c,
+                )
+            },
+            vars_a,
+            parse_quote! { true },
+        )
+        .unwrap_err();
+
+        // Error: variable appears in its own substitution (directly)
+        substitution_tester(
+            vars_a,
+            parse_quote! {
+                AND (
+                    A = b*B + c*C,
+                    a = 2*a + 1,
+                )
+            },
+            vars_a,
+            parse_quote! { true },
+        )
+        .unwrap_err();
+
+        // Error: variable appears in its own substitution (indirectly)
+        substitution_tester(
+            vars_a,
+            parse_quote! {
+                AND (
+                    A = b*B + c*C,
+                    a = 2*b + 1,
+                    b = a + 4,
+                )
+            },
+            vars_a,
+            parse_quote! { true },
+        )
+        .unwrap_err();
+
+        // Successful substitutions
+
+        let vars_nob = (["a", "pub c"].as_slice(), ["A", "B", "C"].as_slice());
+        substitution_tester(
+            vars_a,
+            parse_quote! {
+                AND (
+                    A = b*B + c*C,
+                    b = c,
+                )
+            },
+            vars_nob,
+            parse_quote! { A = c*B + c*C },
+        )
+        .unwrap();
+
+        let vars_cd = (
+            [
+                "c", "d", "r", "s", "c0", "d0", "r0", "s0", "c1", "d1", "r1", "s1",
+            ]
+            .as_slice(),
+            ["A", "B", "C", "D"].as_slice(),
+        );
+        let vars_cd_noc01 = (
+            ["c", "d", "r", "s", "d0", "r0", "s0", "d1", "r1", "s1"].as_slice(),
+            ["A", "B", "C", "D"].as_slice(),
+        );
+        substitution_tester(
+            vars_cd,
+            parse_quote! {
+                AND (
+                    C = c*B + r*A,
+                    D = d*B + s*A,
+                    OR (
+                        AND (
+                            C = c0*B + r0*A,
+                            D = d0*B + s0*A,
+                            c0 = d0,
+                        ),
+                        AND (
+                            C = c1*B + r1*A,
+                            D = d1*B + s1*A,
+                            c1 = d1 + 1,
+                        ),
+                    )
+                )
+            },
+            vars_cd_noc01,
+            parse_quote! {
+                AND (
+                    C = c*B + r*A,
+                    D = d*B + s*A,
+                    OR (
+                        AND (
+                            C = d0*B + r0*A,
+                            D = d0*B + s0*A,
+                        ),
+                        AND (
+                            C = (d1+1)*B + r1*A,
+                            D = d1*B + s1*A,
+                        ),
+                    )
+                )
+            },
+        )
+        .unwrap();
+    }
 }