소스 검색

Code generation for substitution statements

The generated code currently does an assert! that the `Params` and
`Witness` supplied to it do in fact satisfy the statements being
substituted.  It should change to return an Err instead.
Ian Goldberg 3 달 전
부모
커밋
e7f804289e
3개의 변경된 파일85개의 추가작업 그리고 53개의 파일을 삭제
  1. 65 47
      sigma_compiler_core/src/codegen.rs
  2. 1 1
      sigma_compiler_core/src/lib.rs
  3. 19 5
      sigma_compiler_core/src/transform.rs

+ 65 - 47
sigma_compiler_core/src/codegen.rs

@@ -4,9 +4,10 @@
 use super::syntax::*;
 use proc_macro2::TokenStream;
 use quote::{quote, ToTokens};
-use std::collections::HashMap;
+#[cfg(test)]
+use syn::parse_quote;
 use syn::visit_mut::{self, VisitMut};
-use syn::{parse_quote, Expr, Ident, Token};
+use syn::{Expr, Ident, Token};
 
 // Names and types of fields that might end up in a generated struct
 enum StructField {
@@ -78,6 +79,24 @@ impl StructFieldList {
         });
         quote! { #(#decls)* }
     }
+    /// Output a ToTokens of the list of fields
+    pub fn field_list(&self) -> impl ToTokens {
+        let field_ids = self.fields.iter().map(|f| match f {
+            StructField::Scalar(id) => quote! {
+                #id,
+            },
+            StructField::VecScalar(id) => quote! {
+                #id,
+            },
+            StructField::Point(id) => quote! {
+                #id,
+            },
+            StructField::VecPoint(id) => quote! {
+                #id,
+            },
+        });
+        quote! { #(#field_ids)* }
+    }
 }
 
 /// An implementation of the
@@ -86,35 +105,7 @@ impl StructFieldList {
 ///
 /// This massaging currently consists of:
 ///   - Changing equality from = to ==
-///   - Changing any identifier `id` to either `params.id` or
-///     `witness.id` depending on whether it is public or private
-struct StatementFixup {
-    idmap: HashMap<String, Expr>,
-}
-
-impl StatementFixup {
-    pub fn new(spec: &SigmaCompSpec) -> Self {
-        let mut idmap: HashMap<String, Expr> = HashMap::new();
-
-        // For each public identifier id (Points, or Scalars marked
-        // "pub"), add to the map "id" -> params.id.  For each private
-        // identifier (Scalars not marked "pub"), add to the map "id" ->
-        // witness.id.
-        for (id, is_pub) in spec.vars.values().map(|ti| match ti {
-            TaggedIdent::Scalar(st) => (&st.id, st.is_pub),
-            TaggedIdent::Point(pt) => (&pt.id, true),
-        }) {
-            let idexpr: Expr = if is_pub {
-                parse_quote! { params.#id }
-            } else {
-                parse_quote! { witness.#id }
-            };
-            idmap.insert(id.to_string(), idexpr);
-        }
-
-        Self { idmap }
-    }
-}
+struct StatementFixup {}
 
 impl VisitMut for StatementFixup {
     fn visit_expr_mut(&mut self, node: &mut Expr) {
@@ -126,15 +117,6 @@ impl VisitMut for StatementFixup {
                 right: assn.right.clone(),
             });
         }
-        if let Expr::Path(expath) = node {
-            if let Some(id) = expath.path.get_ident() {
-                if let Some(expr) = self.idmap.get(&id.to_string()) {
-                    *node = expr.clone();
-                    // Don't recurse
-                    return;
-                }
-            }
-        }
         // Unless we bailed out above, continue with the default
         // traversal
         visit_mut::visit_expr_mut(self, node);
@@ -153,6 +135,7 @@ pub struct CodeGen {
     proto_name: Ident,
     group_name: Ident,
     vars: TaggedVarDict,
+    prove_code: TokenStream,
 }
 
 impl CodeGen {
@@ -163,9 +146,30 @@ impl CodeGen {
             proto_name: spec.proto_name.clone(),
             group_name: spec.group_name.clone(),
             vars: spec.vars.clone(),
+            prove_code: quote! {},
         }
     }
 
+    #[cfg(test)]
+    /// Create an empty [`CodeGen`].  Primarily useful in testing.
+    pub fn new_empty() -> Self {
+        Self {
+            proto_name: parse_quote! { proto },
+            group_name: parse_quote! { G },
+            vars: TaggedVarDict::default(),
+            prove_code: quote! {},
+        }
+    }
+
+    /// Append some code to the generated `prove` function
+    pub fn prove_append(&mut self, code: TokenStream) {
+        let prove_code = &self.prove_code;
+        self.prove_code = quote! {
+            #prove_code
+            #code
+        };
+    }
+
     /// Generate the code to be output by this macro.
     ///
     /// `emit_prover` and `emit_verifier` are as in
@@ -184,11 +188,11 @@ impl CodeGen {
             pub type Point = super::#group_name;
         };
 
+        let mut pub_params_fields = StructFieldList::default();
+        pub_params_fields.push_vars(&self.vars, true);
+
         // Generate the public params struct definition
         let params_def = {
-            let mut pub_params_fields = StructFieldList::default();
-            pub_params_fields.push_vars(&self.vars, true);
-
             let decls = pub_params_fields.field_decls();
             let dump_impl = if cfg!(feature = "dump") {
                 let dump_chunks = pub_params_fields.fields.iter().map(|f| match f {
@@ -250,11 +254,11 @@ impl CodeGen {
             }
         };
 
+        let mut witness_fields = StructFieldList::default();
+        witness_fields.push_vars(&self.vars, false);
+
         // Generate the witness struct definition
         let witness_def = if emit_prover {
-            let mut witness_fields = StructFieldList::default();
-            witness_fields.push_vars(&self.vars, false);
-
             let decls = witness_fields.field_decls();
             quote! {
                 pub struct Witness {
@@ -276,17 +280,26 @@ impl CodeGen {
             } else {
                 quote! {}
             };
+            let params_ids = pub_params_fields.field_list();
+            let witness_ids = witness_fields.field_list();
             let mut assert_statementtree = spec.statements.clone();
-            let mut statement_fixup = StatementFixup::new(spec);
+            let mut statement_fixup = StatementFixup {};
             assert_statementtree
                 .leaves_mut()
                 .into_iter()
                 .for_each(|expr| statement_fixup.visit_expr_mut(expr));
             let assert_statements = assert_statementtree.leaves_mut();
+            let prove_code = &self.prove_code;
 
             quote! {
+                // The "#[allow(unused_variables)]" is temporary, until we
+                // actually call the underlying sigma macro
+                #[allow(unused_variables)]
                 pub fn prove(params: &Params, witness: &Witness) -> Result<Vec<u8>,()> {
                     #dumper
+                    let Params { #params_ids } = *params;
+                    let Witness { #witness_ids } = *witness;
+                    #prove_code
                     #(assert!(#assert_statements);)*
                     Ok(Vec::<u8>::default())
                 }
@@ -306,9 +319,14 @@ impl CodeGen {
             } else {
                 quote! {}
             };
+            let params_ids = pub_params_fields.field_list();
             quote! {
+                // The "#[allow(unused_variables)]" is temporary, until we
+                // actually call the underlying sigma macro
+                #[allow(unused_variables)]
                 pub fn verify(params: &Params, proof: &[u8]) -> Result<(),()> {
                     #dumper
+                    let Params { #params_ids } = *params;
                     Ok(())
                 }
             }

+ 1 - 1
sigma_compiler_core/src/lib.rs

@@ -33,7 +33,7 @@ pub fn sigma_compiler_core(
     let mut codegen = codegen::CodeGen::new(spec);
 
     // Apply any substitution transformations
-    transform::apply_substitutions(&mut spec.statements, &mut spec.vars).unwrap();
+    transform::apply_substitutions(&mut codegen, &mut spec.statements, &mut spec.vars).unwrap();
 
     codegen.generate(spec, emit_prover, emit_verifier)
 }

+ 19 - 5
sigma_compiler_core/src/transform.rs

@@ -3,9 +3,11 @@
 //!
 //! [disjunction invariant]: StatementTree::check_disjunction_invariant
 
+use super::codegen::CodeGen;
 use super::sigma::combiners::*;
 use super::syntax::taggedvardict_to_vardict;
 use super::{TaggedIdent, TaggedScalar, TaggedVarDict};
+use quote::quote;
 use std::collections::{HashSet, VecDeque};
 use syn::visit::Visit;
 use syn::visit_mut::{self, VisitMut};
@@ -200,7 +202,10 @@ fn do_substitution<'a>(expr: &mut Expr, idstr: &'a str, replacement: &'a Expr) {
 /// from the [`TaggedVarDict`].  The leaves of the [`StatementTree`]
 /// containing the substitution statements themselves will be turned
 /// into the constant `true` and then pruned using
-/// [`prune_statement_tree`].
+/// [`prune_statement_tree`].  The [`CodeGen`] will be used to generate
+/// tests in the generated `prove` function that the `Params` and
+/// `Witness` supplied to it do in fact satisfy the statements being
+/// substituted.
 ///
 /// It is the case that if the [disjunction invariant] is satisfied
 /// before this function is called (and the caller must ensure that it
@@ -209,7 +214,11 @@ fn do_substitution<'a>(expr: &mut Expr, idstr: &'a str, replacement: &'a Expr) {
 ///
 /// [arithmetic expression]: super::sigma::types::expr_type
 /// [disjunction invariant]: StatementTree::check_disjunction_invariant
-pub fn apply_substitutions(st: &mut StatementTree, vars: &mut TaggedVarDict) -> Result<()> {
+pub fn apply_substitutions(
+    codegen: &mut CodeGen,
+    st: &mut StatementTree,
+    vars: &mut TaggedVarDict,
+) -> Result<()> {
     // Gather mutable references to all Exprs in the leaves of the
     // StatementTree.  Note that this ignores the combiner structure in
     // the StatementTree, but that's fine.
@@ -235,8 +244,9 @@ pub fn apply_substitutions(st: &mut StatementTree, vars: &mut TaggedVarDict) ->
         }
         if let Some(id) = is_subs {
             // If this leaf is a substitution of a private Scalar, add
-            // it to subs and replace it in the StatementTree with the
-            // constant true.
+            // it to subs, replace it in the StatementTree with the
+            // constant true, and generate some code for `prove` to
+            // check the statement.
             let mut expr: Expr = parse_quote! { true };
             std::mem::swap(&mut expr, *leafexpr);
             // This "if let" is guaranteed to succeed
@@ -246,6 +256,9 @@ pub fn apply_substitutions(st: &mut StatementTree, vars: &mut TaggedVarDict) ->
                     return Err(Error::new(id.span(), "variable substituted multiple times"));
                 }
                 let right = paren_if_needed(*right);
+                codegen.prove_append(quote! {
+                    assert!(#id == #right);
+                });
                 subs.push_back((id, right, used_priv_scalars));
             }
         }
@@ -438,7 +451,8 @@ mod tests {
     ) -> Result<()> {
         let mut taggedvardict = taggedvardict_from_strs(vars);
         let mut st = StatementTree::parse(&e).unwrap();
-        apply_substitutions(&mut st, &mut taggedvardict)?;
+        let mut codegen = CodeGen::new_empty();
+        apply_substitutions(&mut codegen, &mut st, &mut taggedvardict)?;
         let subbed_taggedvardict = taggedvardict_from_strs(subbed_vars);
         let subbed_st = StatementTree::parse(&subbed_e).unwrap();
         assert_eq!(st, subbed_st);