Browse Source

Enforce the disjunction invariant

A _disjunction node_ is an Or or Thresh node in the StatementTree.  The
_disjunction invariant_ is that a private variable (which is necessarily
a Scalar since there are no private Point variables) that appears in the
subtree rooted at a child of a disjunction node cannot also appear
outside of that subtree.
Ian Goldberg 4 months ago
parent
commit
d7ba2a9350

+ 1 - 1
sigma_compiler_core/Cargo.toml

@@ -4,7 +4,7 @@ version = "0.1.0"
 edition = "2021"
 
 [dependencies]
-syn = { version = "2.0", features = ["extra-traits", "visit-mut", "full"] }
+syn = { version = "2.0", features = ["extra-traits", "visit", "visit-mut", "full"] }
 quote = "1.0"
 proc-macro2 = "1.0"
 

+ 292 - 1
sigma_compiler_core/src/sigma/combiners.rs

@@ -1,9 +1,53 @@
 //! This module creates and manipulates trees of basic statements
 //! combined with `AND`, `OR`, and `THRESH`.
 
+use super::types::*;
+use std::collections::HashMap;
 use syn::parse::Result;
+use syn::visit::Visit;
 use syn::Expr;
 
+/// For each [`Ident`](struct@syn::Ident) representing a private
+/// `Scalar` (as listed in a [`VarDict`]) that appears in an [`Expr`],
+/// call a given closure.
+struct PrivScalarMap<'a> {
+    /// The [`VarDict`] that maps variable names to their types
+    pub vars: &'a VarDict,
+
+    /// The closure that is called for each [`Ident`](struct@syn::Ident)
+    /// found in the [`Expr`] (provided in the call to
+    /// [`visit_expr`](PrivScalarMap::visit_expr)) that represents a
+    /// private `Scalar`
+    pub closure: &'a mut dyn FnMut(&syn::Ident) -> Result<()>,
+
+    /// The accumulated result.  This will be the first
+    /// [`Err`](Result::Err) returned from the closure, or
+    /// [`Ok(())`](Result::Ok) if all calls to the closure succeeded.
+    pub result: Result<()>,
+}
+
+impl<'a> Visit<'a> for PrivScalarMap<'a> {
+    fn visit_path(&mut self, path: &'a syn::Path) {
+        // Whenever we see a `Path`, check first if it's just a bare
+        // `Ident`
+        let Some(id) = path.get_ident() else {
+            return;
+        };
+        // Then check if that `Ident` appears in the `VarDict`
+        let Some(vartype) = self.vars.get(&id.to_string()) else {
+            return;
+        };
+        // If so, and the `Ident` represents a private Scalar,
+        // call the closure if we haven't seen an `Err` returned from
+        // the closure yet.
+        if let AExprType::Scalar { is_pub: false, .. } = vartype {
+            if self.result.is_ok() {
+                self.result = (self.closure)(id);
+            }
+        }
+    }
+}
+
 /// The statements in the ZKP form a tree.  The leaves are basic
 /// statements of various kinds; for example, equations or inequalities
 /// about Scalars and Points.  The interior nodes are combiners: `And`,
@@ -22,12 +66,13 @@ pub enum StatementTree {
 }
 
 impl StatementTree {
+    #[cfg(not(doctest))]
     /// Parse an [`Expr`] (which may contain nested `AND`, `OR`, or
     /// `THRESH`) into a [`StatementTree`].  For example, the
     /// [`Expr`] obtained from:
     /// ```
     /// parse_quote! {
-    ///    AND(
+    ///    AND (
     ///        C = c*B + r*A,
     ///        D = d*B + s*A,
     ///        OR (
@@ -123,6 +168,158 @@ impl StatementTree {
             }
         }
     }
+
+    #[cfg(not(doctest))]
+    /// Verify whether the [`StatementTree`] satisfies the disjunction
+    /// invariant.
+    ///
+    /// A _disjunction node_ is an [`Or`](StatementTree::Or) or
+    /// [`Thresh`](StatementTree::Thresh) node in the [`StatementTree`].
+    /// The _disjunction invariant_ is that a private variable (which is
+    /// necessarily a `Scalar` since there are no private `Point`
+    /// variables) that appears in the subtree rooted at a child of a
+    /// disjunction node cannot also appear outside of that subtree.
+    ///
+    /// For example, if all of the lowercase variables are private
+    /// `Scalar`s, the [`StatementTree`] created from:
+    ///
+    /// ```
+    ///    AND (
+    ///        C = c*B + r*A,
+    ///        D = d*B + s*A,
+    ///        OR (
+    ///            AND (
+    ///                C = c0*B + r0*A,
+    ///                D = d0*B + s0*A,
+    ///                c0 = d0,
+    ///            ),
+    ///            AND (
+    ///                C = c1*B + r1*A,
+    ///                D = d1*B + s1*A,
+    ///                c1 = d1 + 1,
+    ///            ),
+    ///        )
+    ///    )
+    /// ```
+    ///
+    /// satisfies the disjunction invariant, but
+    ///
+    /// ```
+    ///    AND (
+    ///        C = c*B + r*A,
+    ///        D = d*B + s*A,
+    ///        OR (
+    ///            AND (
+    ///                D = d0*B + s0*A,
+    ///                c = d0,
+    ///            ),
+    ///            AND (
+    ///                C = c1*B + r1*A,
+    ///                D = d1*B + s1*A,
+    ///                c1 = d1 + 1,
+    ///            ),
+    ///        )
+    ///    )
+    /// ```
+    ///
+    /// does not, because `c` appears in the first child of the `OR` and
+    /// also outside of the `OR` entirely.  Indeed, the reason to write
+    /// the first expression above rather than the more natural
+    ///
+    /// ```
+    ///    AND (
+    ///        C = c*B + r*A,
+    ///        D = d*B + s*A,
+    ///        OR (
+    ///            c = d,
+    ///            c = d + 1,
+    ///        )
+    ///    )
+    /// ```
+    ///
+    /// is exactly that the invariant must be satisfied.
+    ///
+    /// (In the future, it is possible we may provide a transformer that
+    /// will automatically convert [`StatementTree`]s to ones that
+    /// satisfy the invariant, but for now, the user of the macro must
+    /// manually write the statements in a form that satisfies the
+    /// disjunction invariant.
+    pub fn check_disjunction_invariant(&self, vars: &VarDict) -> Result<()> {
+        let mut disjunct_map: HashMap<String, usize> = HashMap::new();
+
+        // If the recursive call returns Err, return that Err.
+        // Otherwise, we don't care about the Ok(usize) returned, so
+        // just return Ok(())
+        self.check_disjunction_invariant_rec(vars, &mut disjunct_map, 0, 0)?;
+        Ok(())
+    }
+
+    /// Internal recursive helper for
+    /// [`check_disjunction_invariant`](StatementTree::check_disjunction_invariant).
+    ///
+    /// The `disjunct_map` is a [`HashMap`] that maps the names of
+    /// variables to an identifier of which child of a disjunction node
+    /// the variable appears in (or the root if none).  In the case of
+    /// nested disjunction node, the closest one to the leaf is what
+    /// matters.  Nodes are numbered in pre-order fashion, starting at 0
+    /// for the root, 1 for the first child of the root, 2 for the first
+    /// child of node 1, etc.  `cur_node` is the node id of `self`, and
+    /// `cur_disjunct_child` is the node id of the closest child of a
+    /// disjunction node (or 0 for the root if none).  Returns the next
+    /// node id to use in the preorder traversal.
+    fn check_disjunction_invariant_rec(
+        &self,
+        vars: &VarDict,
+        disjunct_map: &mut HashMap<String, usize>,
+        cur_node: usize,
+        cur_disjunct_child: usize,
+    ) -> Result<usize> {
+        let mut next_node = cur_node;
+        match self {
+            Self::And(v) => {
+                for st in v {
+                    next_node = st.check_disjunction_invariant_rec(
+                        vars,
+                        disjunct_map,
+                        next_node + 1,
+                        cur_disjunct_child,
+                    )?;
+                }
+            }
+            Self::Or(v) | Self::Thresh(_, v) => {
+                for st in v {
+                    next_node = st.check_disjunction_invariant_rec(
+                        vars,
+                        disjunct_map,
+                        next_node + 1,
+                        next_node + 1,
+                    )?;
+                }
+            }
+            Self::Leaf(e) => {
+                let mut psmap = PrivScalarMap {
+                    vars,
+                    closure: &mut |ident| {
+                        let varname = ident.to_string();
+                        if let Some(dis_id) = disjunct_map.get(&varname) {
+                            if *dis_id != cur_disjunct_child {
+                                return Err(syn::Error::new(
+                                    ident.span(),
+                                    "Disjunction invariant violation: a private variable cannot appear both inside and outside a single term of an OR or THRESH"));
+                            }
+                        } else {
+                            disjunct_map.insert(varname, cur_disjunct_child);
+                        }
+                        Ok(())
+                    },
+                    result: Ok(()),
+                };
+                psmap.visit_expr(e);
+                psmap.result?;
+            }
+        }
+        Ok(next_node)
+    }
 }
 
 #[cfg(test)]
@@ -269,4 +466,98 @@ mod test {
 
         StatementTree::parse_andlist(&exprlist).unwrap();
     }
+
+    #[test]
+    // Test the disjunction invariant checker
+    fn disjunction_invariant_test() {
+        let vars: VarDict = vardict_from_strs(&[
+            ("c", "S"),
+            ("d", "S"),
+            ("c0", "S"),
+            ("c1", "S"),
+            ("d0", "S"),
+            ("d1", "S"),
+            ("A", "pP"),
+            ("B", "pP"),
+            ("C", "pP"),
+            ("D", "pP"),
+        ]);
+        // This one is OK
+        let st_ok = StatementTree::parse(&parse_quote! {
+           AND (
+               C = c*B + r*A,
+               D = d*B + s*A,
+               OR (
+                   AND (
+                       C = c0*B + r0*A,
+                       D = d0*B + s0*A,
+                       c0 = d0,
+                   ),
+                   AND (
+                       C = c1*B + r1*A,
+                       D = d1*B + s1*A,
+                       c1 = d1 + 1,
+                   ),
+               )
+           )
+        })
+        .unwrap();
+        // not OK: c0 appears in two branches of the OR
+        let st_nok1 = StatementTree::parse(&parse_quote! {
+           AND (
+               C = c*B + r*A,
+               D = d*B + s*A,
+               OR (
+                   AND (
+                       C = c0*B + r0*A,
+                       D = d0*B + s0*A,
+                       c0 = d0,
+                   ),
+                   AND (
+                       C = c0*B + r0*A,
+                       D = d1*B + s1*A,
+                       c0 = d1 + 1,
+                   ),
+               )
+           )
+        })
+        .unwrap();
+        // not OK: c appears in one branch of the OR and also outside
+        // the OR
+        let st_nok2 = StatementTree::parse(&parse_quote! {
+           AND (
+               C = c*B + r*A,
+               D = d*B + s*A,
+               OR (
+                   AND (
+                       D = d0*B + s0*A,
+                       c = d0,
+                   ),
+                   AND (
+                       C = c1*B + r1*A,
+                       D = d1*B + s1*A,
+                       c1 = d1 + 1,
+                   ),
+               )
+           )
+        })
+        .unwrap();
+        // not OK: c and d appear in both branches of the OR, and also
+        // outside it
+        let st_nok3 = StatementTree::parse(&parse_quote! {
+           AND (
+               C = c*B + r*A,
+               D = d*B + s*A,
+               OR (
+                   c = d,
+                   c = d + 1,
+               )
+           )
+        })
+        .unwrap();
+        st_ok.check_disjunction_invariant(&vars).unwrap();
+        st_nok1.check_disjunction_invariant(&vars).unwrap_err();
+        st_nok2.check_disjunction_invariant(&vars).unwrap_err();
+        st_nok3.check_disjunction_invariant(&vars).unwrap_err();
+    }
 }

+ 2 - 0
sigma_compiler_core/src/syntax.rs

@@ -226,6 +226,8 @@ impl Parse for SigmaCompSpec {
             input.parse_terminated(Expr::parse, Token![,])?;
         let statementlist: Vec<Expr> = statementpunc.into_iter().collect();
         let statements = StatementTree::parse_andlist(&statementlist)?;
+        let vardict = taggedvardict_to_vardict(&vars);
+        statements.check_disjunction_invariant(&vardict)?;
 
         Ok(SigmaCompSpec {
             proto_name,