Explorar el Código

In prove(), check that all of the statements are actually true

This is currently accomplished by internally massaging the parse trees
of the provided statements to:
 - Change equality from = to ==
 - Change any identifier `id` to either `params.id` or
   `witness.id` depending on whether it is public or private

The result should then be a valid Rust expression, which is then
(for now) put into an assert!().  Perhaps later it will return an Err
instead.
Ian Goldberg hace 10 meses
padre
commit
6ceace1cf3
Se han modificado 2 ficheros con 82 adiciones y 2 borrados
  1. 1 1
      sigma_compiler_derive/Cargo.toml
  2. 81 1
      sigma_compiler_derive/src/lib.rs

+ 1 - 1
sigma_compiler_derive/Cargo.toml

@@ -7,7 +7,7 @@ edition = "2021"
 proc-macro = true
 
 [dependencies]
-syn = { version = "2.0", features = ["extra-traits"] }
+syn = { version = "2.0", features = ["extra-traits", "visit-mut", "full"] }
 quote = "1.0"
 
 [features]

+ 81 - 1
sigma_compiler_derive/src/lib.rs

@@ -1,8 +1,10 @@
 use proc_macro::TokenStream;
 use quote::{format_ident, quote, ToTokens};
+use std::collections::HashMap;
 use syn::parse::{Parse, ParseStream, Result};
 use syn::punctuated::Punctuated;
-use syn::{parenthesized, parse_macro_input, Expr, Ident, Token};
+use syn::visit_mut::{self, VisitMut};
+use syn::{parenthesized, parse_macro_input, parse_quote, Expr, Ident, Token};
 
 // Either an Ident or "vec(Ident)"
 #[derive(Debug)]
@@ -160,6 +162,78 @@ impl StructFieldList {
     }
 }
 
+/// An implementation of the
+/// [`VisitMut`](https://docs.rs/syn/latest/syn/visit_mut/trait.VisitMut.html)
+/// trait that massages the provided statements.  This massaging
+/// currently consists of:
+///   - Changing equality from = to ==
+///   - Changing any identifier `id` to either `params.id` or
+///     `witness.id` depending on whether it is public or private
+struct StatementFixup {
+    idmap: HashMap<String, Expr>,
+}
+
+impl StatementFixup {
+    pub fn new(spec: &SigmaCompSpec) -> Self {
+        let mut idmap: HashMap<String, Expr> = HashMap::new();
+
+        // For each public identifier id (Scalar, Point, or vectors of
+        // one of those), add to the map "id" -> params.id
+        for vid in spec
+            .pub_scalars
+            .iter()
+            .chain(spec.cind_points.iter())
+            .chain(spec.pub_points.iter())
+            .chain(spec.const_points.iter())
+        {
+            match vid {
+                VecIdent::Ident(id) | VecIdent::VecIdent(id) => {
+                    let idexpr: Expr = parse_quote! { params.#id };
+                    idmap.insert(id.to_string(), idexpr);
+                }
+            }
+        }
+
+        // For each private identifier id (Scalar, Point, or vectors of
+        // one of those), add to the map "id" -> witness.id
+        for vid in spec.rand_scalars.iter().chain(spec.priv_scalars.iter()) {
+            match vid {
+                VecIdent::Ident(id) | VecIdent::VecIdent(id) => {
+                    let idexpr: Expr = parse_quote! { witness.#id };
+                    idmap.insert(id.to_string(), idexpr);
+                }
+            }
+        }
+
+        Self { idmap }
+    }
+}
+
+impl VisitMut for StatementFixup {
+    fn visit_expr_mut(&mut self, node: &mut Expr) {
+        if let Expr::Assign(assn) = node {
+            *node = Expr::Binary(syn::ExprBinary {
+                attrs: assn.attrs.clone(),
+                left: assn.left.clone(),
+                op: syn::BinOp::Eq(Token![==](assn.eq_token.span)),
+                right: assn.right.clone(),
+            });
+        }
+        if let Expr::Path(expath) = node {
+            if let Some(id) = expath.path.get_ident() {
+                if let Some(expr) = self.idmap.get(&id.to_string()) {
+                    *node = expr.clone();
+                    // Don't recurse
+                    return;
+                }
+            }
+        }
+        // Unless we bailed out above, continue with the default
+        // traversal
+        visit_mut::visit_expr_mut(self, node);
+    }
+}
+
 fn sigma_compiler_impl(
     spec: &SigmaCompSpec,
     emit_prover: bool,
@@ -269,9 +343,15 @@ fn sigma_compiler_impl(
         } else {
             quote! {}
         };
+        let mut assert_statements = spec.statements.clone();
+        let mut statement_fixup = StatementFixup::new(&spec);
+        assert_statements
+            .iter_mut()
+            .for_each(|expr| statement_fixup.visit_expr_mut(expr));
         quote! {
             pub fn prove(params: &Params, witness: &Witness) -> Result<Vec<u8>,()> {
                 #dumper
+                #(assert!(#assert_statements);)*
                 Ok(Vec::<u8>::default())
             }
         }