Browse Source

In prove(), check that all of the statements are actually true

This is currently accomplished by internally massaging the parse trees
of the provided statements to:
 - Change equality from = to ==
 - Change any identifier `id` to either `params.id` or
   `witness.id` depending on whether it is public or private

The result should then be a valid Rust expression, which is then
(for now) put into an assert!().  Perhaps later it will return an Err
instead.
Ian Goldberg 5 months ago
parent
commit
6ceace1cf3
2 changed files with 82 additions and 2 deletions
  1. 1 1
      sigma_compiler_derive/Cargo.toml
  2. 81 1
      sigma_compiler_derive/src/lib.rs

+ 1 - 1
sigma_compiler_derive/Cargo.toml

@@ -7,7 +7,7 @@ edition = "2021"
 proc-macro = true
 
 [dependencies]
-syn = { version = "2.0", features = ["extra-traits"] }
+syn = { version = "2.0", features = ["extra-traits", "visit-mut", "full"] }
 quote = "1.0"
 
 [features]

+ 81 - 1
sigma_compiler_derive/src/lib.rs

@@ -1,8 +1,10 @@
 use proc_macro::TokenStream;
 use quote::{format_ident, quote, ToTokens};
+use std::collections::HashMap;
 use syn::parse::{Parse, ParseStream, Result};
 use syn::punctuated::Punctuated;
-use syn::{parenthesized, parse_macro_input, Expr, Ident, Token};
+use syn::visit_mut::{self, VisitMut};
+use syn::{parenthesized, parse_macro_input, parse_quote, Expr, Ident, Token};
 
 // Either an Ident or "vec(Ident)"
 #[derive(Debug)]
@@ -160,6 +162,78 @@ impl StructFieldList {
     }
 }
 
+/// An implementation of the
+/// [`VisitMut`](https://docs.rs/syn/latest/syn/visit_mut/trait.VisitMut.html)
+/// trait that massages the provided statements.  This massaging
+/// currently consists of:
+///   - Changing equality from = to ==
+///   - Changing any identifier `id` to either `params.id` or
+///     `witness.id` depending on whether it is public or private
+struct StatementFixup {
+    idmap: HashMap<String, Expr>,
+}
+
+impl StatementFixup {
+    pub fn new(spec: &SigmaCompSpec) -> Self {
+        let mut idmap: HashMap<String, Expr> = HashMap::new();
+
+        // For each public identifier id (Scalar, Point, or vectors of
+        // one of those), add to the map "id" -> params.id
+        for vid in spec
+            .pub_scalars
+            .iter()
+            .chain(spec.cind_points.iter())
+            .chain(spec.pub_points.iter())
+            .chain(spec.const_points.iter())
+        {
+            match vid {
+                VecIdent::Ident(id) | VecIdent::VecIdent(id) => {
+                    let idexpr: Expr = parse_quote! { params.#id };
+                    idmap.insert(id.to_string(), idexpr);
+                }
+            }
+        }
+
+        // For each private identifier id (Scalar, Point, or vectors of
+        // one of those), add to the map "id" -> witness.id
+        for vid in spec.rand_scalars.iter().chain(spec.priv_scalars.iter()) {
+            match vid {
+                VecIdent::Ident(id) | VecIdent::VecIdent(id) => {
+                    let idexpr: Expr = parse_quote! { witness.#id };
+                    idmap.insert(id.to_string(), idexpr);
+                }
+            }
+        }
+
+        Self { idmap }
+    }
+}
+
+impl VisitMut for StatementFixup {
+    fn visit_expr_mut(&mut self, node: &mut Expr) {
+        if let Expr::Assign(assn) = node {
+            *node = Expr::Binary(syn::ExprBinary {
+                attrs: assn.attrs.clone(),
+                left: assn.left.clone(),
+                op: syn::BinOp::Eq(Token![==](assn.eq_token.span)),
+                right: assn.right.clone(),
+            });
+        }
+        if let Expr::Path(expath) = node {
+            if let Some(id) = expath.path.get_ident() {
+                if let Some(expr) = self.idmap.get(&id.to_string()) {
+                    *node = expr.clone();
+                    // Don't recurse
+                    return;
+                }
+            }
+        }
+        // Unless we bailed out above, continue with the default
+        // traversal
+        visit_mut::visit_expr_mut(self, node);
+    }
+}
+
 fn sigma_compiler_impl(
     spec: &SigmaCompSpec,
     emit_prover: bool,
@@ -269,9 +343,15 @@ fn sigma_compiler_impl(
         } else {
             quote! {}
         };
+        let mut assert_statements = spec.statements.clone();
+        let mut statement_fixup = StatementFixup::new(&spec);
+        assert_statements
+            .iter_mut()
+            .for_each(|expr| statement_fixup.visit_expr_mut(expr));
         quote! {
             pub fn prove(params: &Params, witness: &Witness) -> Result<Vec<u8>,()> {
                 #dumper
+                #(assert!(#assert_statements);)*
                 Ok(Vec::<u8>::default())
             }
         }