|
@@ -1,9 +1,53 @@
|
|
|
//! This module creates and manipulates trees of basic statements
|
|
|
//! combined with `AND`, `OR`, and `THRESH`.
|
|
|
|
|
|
+use super::types::*;
|
|
|
+use std::collections::HashMap;
|
|
|
use syn::parse::Result;
|
|
|
+use syn::visit::Visit;
|
|
|
use syn::Expr;
|
|
|
|
|
|
+/// For each [`Ident`](struct@syn::Ident) representing a private
|
|
|
+/// `Scalar` (as listed in a [`VarDict`]) that appears in an [`Expr`],
|
|
|
+/// call a given closure.
|
|
|
+struct PrivScalarMap<'a> {
|
|
|
+ /// The [`VarDict`] that maps variable names to their types
|
|
|
+ pub vars: &'a VarDict,
|
|
|
+
|
|
|
+ /// The closure that is called for each [`Ident`](struct@syn::Ident)
|
|
|
+ /// found in the [`Expr`] (provided in the call to
|
|
|
+ /// [`visit_expr`](PrivScalarMap::visit_expr)) that represents a
|
|
|
+ /// private `Scalar`
|
|
|
+ pub closure: &'a mut dyn FnMut(&syn::Ident) -> Result<()>,
|
|
|
+
|
|
|
+ /// The accumulated result. This will be the first
|
|
|
+ /// [`Err`](Result::Err) returned from the closure, or
|
|
|
+ /// [`Ok(())`](Result::Ok) if all calls to the closure succeeded.
|
|
|
+ pub result: Result<()>,
|
|
|
+}
|
|
|
+
|
|
|
+impl<'a> Visit<'a> for PrivScalarMap<'a> {
|
|
|
+ fn visit_path(&mut self, path: &'a syn::Path) {
|
|
|
+ // Whenever we see a `Path`, check first if it's just a bare
|
|
|
+ // `Ident`
|
|
|
+ let Some(id) = path.get_ident() else {
|
|
|
+ return;
|
|
|
+ };
|
|
|
+ // Then check if that `Ident` appears in the `VarDict`
|
|
|
+ let Some(vartype) = self.vars.get(&id.to_string()) else {
|
|
|
+ return;
|
|
|
+ };
|
|
|
+ // If so, and the `Ident` represents a private Scalar,
|
|
|
+ // call the closure if we haven't seen an `Err` returned from
|
|
|
+ // the closure yet.
|
|
|
+ if let AExprType::Scalar { is_pub: false, .. } = vartype {
|
|
|
+ if self.result.is_ok() {
|
|
|
+ self.result = (self.closure)(id);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
/// The statements in the ZKP form a tree. The leaves are basic
|
|
|
/// statements of various kinds; for example, equations or inequalities
|
|
|
/// about Scalars and Points. The interior nodes are combiners: `And`,
|
|
@@ -22,12 +66,13 @@ pub enum StatementTree {
|
|
|
}
|
|
|
|
|
|
impl StatementTree {
|
|
|
+ #[cfg(not(doctest))]
|
|
|
/// Parse an [`Expr`] (which may contain nested `AND`, `OR`, or
|
|
|
/// `THRESH`) into a [`StatementTree`]. For example, the
|
|
|
/// [`Expr`] obtained from:
|
|
|
/// ```
|
|
|
/// parse_quote! {
|
|
|
- /// AND(
|
|
|
+ /// AND (
|
|
|
/// C = c*B + r*A,
|
|
|
/// D = d*B + s*A,
|
|
|
/// OR (
|
|
@@ -123,6 +168,158 @@ impl StatementTree {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ #[cfg(not(doctest))]
|
|
|
+ /// Verify whether the [`StatementTree`] satisfies the disjunction
|
|
|
+ /// invariant.
|
|
|
+ ///
|
|
|
+ /// A _disjunction node_ is an [`Or`](StatementTree::Or) or
|
|
|
+ /// [`Thresh`](StatementTree::Thresh) node in the [`StatementTree`].
|
|
|
+ /// The _disjunction invariant_ is that a private variable (which is
|
|
|
+ /// necessarily a `Scalar` since there are no private `Point`
|
|
|
+ /// variables) that appears in the subtree rooted at a child of a
|
|
|
+ /// disjunction node cannot also appear outside of that subtree.
|
|
|
+ ///
|
|
|
+ /// For example, if all of the lowercase variables are private
|
|
|
+ /// `Scalar`s, the [`StatementTree`] created from:
|
|
|
+ ///
|
|
|
+ /// ```
|
|
|
+ /// AND (
|
|
|
+ /// C = c*B + r*A,
|
|
|
+ /// D = d*B + s*A,
|
|
|
+ /// OR (
|
|
|
+ /// AND (
|
|
|
+ /// C = c0*B + r0*A,
|
|
|
+ /// D = d0*B + s0*A,
|
|
|
+ /// c0 = d0,
|
|
|
+ /// ),
|
|
|
+ /// AND (
|
|
|
+ /// C = c1*B + r1*A,
|
|
|
+ /// D = d1*B + s1*A,
|
|
|
+ /// c1 = d1 + 1,
|
|
|
+ /// ),
|
|
|
+ /// )
|
|
|
+ /// )
|
|
|
+ /// ```
|
|
|
+ ///
|
|
|
+ /// satisfies the disjunction invariant, but
|
|
|
+ ///
|
|
|
+ /// ```
|
|
|
+ /// AND (
|
|
|
+ /// C = c*B + r*A,
|
|
|
+ /// D = d*B + s*A,
|
|
|
+ /// OR (
|
|
|
+ /// AND (
|
|
|
+ /// D = d0*B + s0*A,
|
|
|
+ /// c = d0,
|
|
|
+ /// ),
|
|
|
+ /// AND (
|
|
|
+ /// C = c1*B + r1*A,
|
|
|
+ /// D = d1*B + s1*A,
|
|
|
+ /// c1 = d1 + 1,
|
|
|
+ /// ),
|
|
|
+ /// )
|
|
|
+ /// )
|
|
|
+ /// ```
|
|
|
+ ///
|
|
|
+ /// does not, because `c` appears in the first child of the `OR` and
|
|
|
+ /// also outside of the `OR` entirely. Indeed, the reason to write
|
|
|
+ /// the first expression above rather than the more natural
|
|
|
+ ///
|
|
|
+ /// ```
|
|
|
+ /// AND (
|
|
|
+ /// C = c*B + r*A,
|
|
|
+ /// D = d*B + s*A,
|
|
|
+ /// OR (
|
|
|
+ /// c = d,
|
|
|
+ /// c = d + 1,
|
|
|
+ /// )
|
|
|
+ /// )
|
|
|
+ /// ```
|
|
|
+ ///
|
|
|
+ /// is exactly that the invariant must be satisfied.
|
|
|
+ ///
|
|
|
+ /// (In the future, it is possible we may provide a transformer that
|
|
|
+ /// will automatically convert [`StatementTree`]s to ones that
|
|
|
+ /// satisfy the invariant, but for now, the user of the macro must
|
|
|
+ /// manually write the statements in a form that satisfies the
|
|
|
+ /// disjunction invariant.
|
|
|
+ pub fn check_disjunction_invariant(&self, vars: &VarDict) -> Result<()> {
|
|
|
+ let mut disjunct_map: HashMap<String, usize> = HashMap::new();
|
|
|
+
|
|
|
+ // If the recursive call returns Err, return that Err.
|
|
|
+ // Otherwise, we don't care about the Ok(usize) returned, so
|
|
|
+ // just return Ok(())
|
|
|
+ self.check_disjunction_invariant_rec(vars, &mut disjunct_map, 0, 0)?;
|
|
|
+ Ok(())
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Internal recursive helper for
|
|
|
+ /// [`check_disjunction_invariant`](StatementTree::check_disjunction_invariant).
|
|
|
+ ///
|
|
|
+ /// The `disjunct_map` is a [`HashMap`] that maps the names of
|
|
|
+ /// variables to an identifier of which child of a disjunction node
|
|
|
+ /// the variable appears in (or the root if none). In the case of
|
|
|
+ /// nested disjunction node, the closest one to the leaf is what
|
|
|
+ /// matters. Nodes are numbered in pre-order fashion, starting at 0
|
|
|
+ /// for the root, 1 for the first child of the root, 2 for the first
|
|
|
+ /// child of node 1, etc. `cur_node` is the node id of `self`, and
|
|
|
+ /// `cur_disjunct_child` is the node id of the closest child of a
|
|
|
+ /// disjunction node (or 0 for the root if none). Returns the next
|
|
|
+ /// node id to use in the preorder traversal.
|
|
|
+ fn check_disjunction_invariant_rec(
|
|
|
+ &self,
|
|
|
+ vars: &VarDict,
|
|
|
+ disjunct_map: &mut HashMap<String, usize>,
|
|
|
+ cur_node: usize,
|
|
|
+ cur_disjunct_child: usize,
|
|
|
+ ) -> Result<usize> {
|
|
|
+ let mut next_node = cur_node;
|
|
|
+ match self {
|
|
|
+ Self::And(v) => {
|
|
|
+ for st in v {
|
|
|
+ next_node = st.check_disjunction_invariant_rec(
|
|
|
+ vars,
|
|
|
+ disjunct_map,
|
|
|
+ next_node + 1,
|
|
|
+ cur_disjunct_child,
|
|
|
+ )?;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ Self::Or(v) | Self::Thresh(_, v) => {
|
|
|
+ for st in v {
|
|
|
+ next_node = st.check_disjunction_invariant_rec(
|
|
|
+ vars,
|
|
|
+ disjunct_map,
|
|
|
+ next_node + 1,
|
|
|
+ next_node + 1,
|
|
|
+ )?;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ Self::Leaf(e) => {
|
|
|
+ let mut psmap = PrivScalarMap {
|
|
|
+ vars,
|
|
|
+ closure: &mut |ident| {
|
|
|
+ let varname = ident.to_string();
|
|
|
+ if let Some(dis_id) = disjunct_map.get(&varname) {
|
|
|
+ if *dis_id != cur_disjunct_child {
|
|
|
+ return Err(syn::Error::new(
|
|
|
+ ident.span(),
|
|
|
+ "Disjunction invariant violation: a private variable cannot appear both inside and outside a single term of an OR or THRESH"));
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ disjunct_map.insert(varname, cur_disjunct_child);
|
|
|
+ }
|
|
|
+ Ok(())
|
|
|
+ },
|
|
|
+ result: Ok(()),
|
|
|
+ };
|
|
|
+ psmap.visit_expr(e);
|
|
|
+ psmap.result?;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ Ok(next_node)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
#[cfg(test)]
|
|
@@ -269,4 +466,98 @@ mod test {
|
|
|
|
|
|
StatementTree::parse_andlist(&exprlist).unwrap();
|
|
|
}
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ // Test the disjunction invariant checker
|
|
|
+ fn disjunction_invariant_test() {
|
|
|
+ let vars: VarDict = vardict_from_strs(&[
|
|
|
+ ("c", "S"),
|
|
|
+ ("d", "S"),
|
|
|
+ ("c0", "S"),
|
|
|
+ ("c1", "S"),
|
|
|
+ ("d0", "S"),
|
|
|
+ ("d1", "S"),
|
|
|
+ ("A", "pP"),
|
|
|
+ ("B", "pP"),
|
|
|
+ ("C", "pP"),
|
|
|
+ ("D", "pP"),
|
|
|
+ ]);
|
|
|
+ // This one is OK
|
|
|
+ let st_ok = StatementTree::parse(&parse_quote! {
|
|
|
+ AND (
|
|
|
+ C = c*B + r*A,
|
|
|
+ D = d*B + s*A,
|
|
|
+ OR (
|
|
|
+ AND (
|
|
|
+ C = c0*B + r0*A,
|
|
|
+ D = d0*B + s0*A,
|
|
|
+ c0 = d0,
|
|
|
+ ),
|
|
|
+ AND (
|
|
|
+ C = c1*B + r1*A,
|
|
|
+ D = d1*B + s1*A,
|
|
|
+ c1 = d1 + 1,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ )
|
|
|
+ })
|
|
|
+ .unwrap();
|
|
|
+ // not OK: c0 appears in two branches of the OR
|
|
|
+ let st_nok1 = StatementTree::parse(&parse_quote! {
|
|
|
+ AND (
|
|
|
+ C = c*B + r*A,
|
|
|
+ D = d*B + s*A,
|
|
|
+ OR (
|
|
|
+ AND (
|
|
|
+ C = c0*B + r0*A,
|
|
|
+ D = d0*B + s0*A,
|
|
|
+ c0 = d0,
|
|
|
+ ),
|
|
|
+ AND (
|
|
|
+ C = c0*B + r0*A,
|
|
|
+ D = d1*B + s1*A,
|
|
|
+ c0 = d1 + 1,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ )
|
|
|
+ })
|
|
|
+ .unwrap();
|
|
|
+ // not OK: c appears in one branch of the OR and also outside
|
|
|
+ // the OR
|
|
|
+ let st_nok2 = StatementTree::parse(&parse_quote! {
|
|
|
+ AND (
|
|
|
+ C = c*B + r*A,
|
|
|
+ D = d*B + s*A,
|
|
|
+ OR (
|
|
|
+ AND (
|
|
|
+ D = d0*B + s0*A,
|
|
|
+ c = d0,
|
|
|
+ ),
|
|
|
+ AND (
|
|
|
+ C = c1*B + r1*A,
|
|
|
+ D = d1*B + s1*A,
|
|
|
+ c1 = d1 + 1,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ )
|
|
|
+ })
|
|
|
+ .unwrap();
|
|
|
+ // not OK: c and d appear in both branches of the OR, and also
|
|
|
+ // outside it
|
|
|
+ let st_nok3 = StatementTree::parse(&parse_quote! {
|
|
|
+ AND (
|
|
|
+ C = c*B + r*A,
|
|
|
+ D = d*B + s*A,
|
|
|
+ OR (
|
|
|
+ c = d,
|
|
|
+ c = d + 1,
|
|
|
+ )
|
|
|
+ )
|
|
|
+ })
|
|
|
+ .unwrap();
|
|
|
+ st_ok.check_disjunction_invariant(&vars).unwrap();
|
|
|
+ st_nok1.check_disjunction_invariant(&vars).unwrap_err();
|
|
|
+ st_nok2.check_disjunction_invariant(&vars).unwrap_err();
|
|
|
+ st_nok3.check_disjunction_invariant(&vars).unwrap_err();
|
|
|
+ }
|
|
|
}
|