Browse Source

Start making and using the StatementTree

The statements in the ZKP form a tree.  The leaves are basic
statements of various kinds; for example, equations or inequalities
about Scalars and Points.  The interior nodes are combiners: `And`,
`Or`, or `Thresh` (with a given constant threshold).  A leaf is true
if the basic statement it contains is true.  An `And` node is true
if all of its children are true.  An `Or` node is true if at least
one of its children is true.  A `Thresh` node (with threshold `k`) is
true if at least `k` of its children are true.
Ian Goldberg 4 months ago
parent
commit
9525945aa9

+ 43 - 0
sigma_compiler_core/src/combiners.rs

@@ -0,0 +1,43 @@
+use syn::parse::Result;
+use syn::Expr;
+
+/// The statements in the ZKP form a tree.  The leaves are basic
+/// statements of various kinds; for example, equations or inequalities
+/// about Scalars and Points.  The interior nodes are combiners: `And`,
+/// `Or`, or `Thresh` (with a given constant threshold).  A leaf is true
+/// if the basic statement it contains is true.  An `And` node is true
+/// if all of its children are true.  An `Or` node is true if at least
+/// one of its children is true.  A `Thresh` node (with threshold `k`) is
+/// true if at least `k` of its children are true.
+
+#[derive(Clone, Debug)]
+pub enum StatementTree {
+    Leaf(Expr),
+    And(Vec<StatementTree>),
+    Or(Vec<StatementTree>),
+    Thresh(usize, Vec<StatementTree>),
+}
+
+impl StatementTree {
+    pub fn parse(expr: &Expr) -> Result<Self> {
+        Ok(StatementTree::Leaf(expr.clone()))
+    }
+
+    pub fn parse_andlist(exprlist: &[Expr]) -> Result<Self> {
+        let children: Result<Vec<StatementTree>> =
+            exprlist.iter().map(|e| Self::parse(e)).collect();
+        Ok(StatementTree::And(children?))
+    }
+
+    pub fn leaves_mut(&mut self) -> Vec<&mut Expr> {
+        match self {
+            StatementTree::Leaf(ref mut e) => vec![e],
+            StatementTree::And(v) | StatementTree::Or(v) | StatementTree::Thresh(_, v) => {
+                v.iter_mut().fold(Vec::<&mut Expr>::new(), |mut b, st| {
+                    b.extend(st.leaves_mut());
+                    b
+                })
+            }
+        }
+    }
+}

+ 7 - 3
sigma_compiler_core/src/lib.rs

@@ -4,6 +4,7 @@ use std::collections::HashMap;
 use syn::visit_mut::{self, VisitMut};
 use syn::visit_mut::{self, VisitMut};
 use syn::{parse_quote, Expr, Ident, Token};
 use syn::{parse_quote, Expr, Ident, Token};
 
 
+mod combiners;
 mod syntax;
 mod syntax;
 
 
 pub use syntax::{SigmaCompSpec, TaggedIdent, TaggedPoint, TaggedScalar, VarDict};
 pub use syntax::{SigmaCompSpec, TaggedIdent, TaggedPoint, TaggedScalar, VarDict};
@@ -245,11 +246,14 @@ pub fn sigma_compiler_core(
         } else {
         } else {
             quote! {}
             quote! {}
         };
         };
-        let mut assert_statements = spec.statements.clone();
+        let mut assert_statementtree = spec.statements.clone();
         let mut statement_fixup = StatementFixup::new(spec);
         let mut statement_fixup = StatementFixup::new(spec);
-        assert_statements
-            .iter_mut()
+        assert_statementtree
+            .leaves_mut()
+            .into_iter()
             .for_each(|expr| statement_fixup.visit_expr_mut(expr));
             .for_each(|expr| statement_fixup.visit_expr_mut(expr));
+        let assert_statements = assert_statementtree.leaves_mut();
+
         quote! {
         quote! {
             pub fn prove(params: &Params, witness: &Witness) -> Result<Vec<u8>,()> {
             pub fn prove(params: &Params, witness: &Witness) -> Result<Vec<u8>,()> {
                 #dumper
                 #dumper

+ 4 - 2
sigma_compiler_core/src/syntax.rs

@@ -1,3 +1,4 @@
+use super::combiners::StatementTree;
 use quote::format_ident;
 use quote::format_ident;
 use std::collections::HashMap;
 use std::collections::HashMap;
 use syn::ext::IdentExt;
 use syn::ext::IdentExt;
@@ -129,7 +130,7 @@ pub struct SigmaCompSpec {
     pub proto_name: Ident,
     pub proto_name: Ident,
     pub group_name: Ident,
     pub group_name: Ident,
     pub vars: VarDict,
     pub vars: VarDict,
-    pub statements: Vec<Expr>,
+    pub statements: StatementTree,
 }
 }
 
 
 // T is TaggedScalar or TaggedPoint
 // T is TaggedScalar or TaggedPoint
@@ -174,7 +175,8 @@ impl Parse for SigmaCompSpec {
 
 
         let statementpunc: Punctuated<Expr, Token![,]> =
         let statementpunc: Punctuated<Expr, Token![,]> =
             input.parse_terminated(Expr::parse, Token![,])?;
             input.parse_terminated(Expr::parse, Token![,])?;
-        let statements: Vec<Expr> = statementpunc.into_iter().collect();
+        let statementlist: Vec<Expr> = statementpunc.into_iter().collect();
+        let statements = StatementTree::parse_andlist(&statementlist)?;
 
 
         Ok(SigmaCompSpec {
         Ok(SigmaCompSpec {
             proto_name,
             proto_name,