ソースを参照

Code generation stubs for sigma_rs

Add stubs that will eventually generate the code that will call the
sigma_rs API.  The sigma_compiler codegen already calls those stubs (but
the stubs do nothing).
Ian Goldberg 5 ヶ月 前
コミット
31e596b0c3
5 ファイル変更280 行追加78 行削除
  1. 1 0
      Cargo.toml
  2. 50 76
      sigma_compiler_core/src/codegen.rs
  3. 194 1
      sigma_compiler_core/src/sigma/codegen.rs
  4. 1 1
      src/lib.rs
  5. 34 0
      tests/basic.rs

+ 1 - 0
Cargo.toml

@@ -11,3 +11,4 @@ sigma-rs = { path = "../sigma" }
 [dev-dependencies]
 curve25519-dalek = { version = "4", features = [ "group", "rand_core", "digest" ] }
 rand = "0.8.5"
+sha2 = "0.10"

+ 50 - 76
sigma_compiler_core/src/codegen.rs

@@ -1,64 +1,13 @@
 //! A module for generating the code produced by this macro.  This code
 //! will interact with the underlying `sigma` macro.
 
-use super::syntax::*;
 use super::sigma::codegen::StructFieldList;
+use super::syntax::*;
 use proc_macro2::TokenStream;
-use quote::quote;
+use quote::{format_ident, quote};
 #[cfg(test)]
 use syn::parse_quote;
-use syn::visit_mut::{self, VisitMut};
-use syn::{Expr, Ident, Token};
-
-fn push_vars(structfieldlist: &mut StructFieldList,
-    vars: &TaggedVarDict, is_pub: bool) {
-    for (_, ti) in vars.iter() {
-        match ti {
-            TaggedIdent::Scalar(st) => {
-                if st.is_pub == is_pub {
-                    if st.is_vec {
-                        structfieldlist.push_vecscalar(&st.id)
-                    } else {
-                        structfieldlist.push_scalar(&st.id)
-                    }
-                }
-            }
-            TaggedIdent::Point(pt) => {
-                if is_pub {
-                    if pt.is_vec {
-                        structfieldlist.push_vecpoint(&pt.id)
-                    } else {
-                        structfieldlist.push_point(&pt.id)
-                    }
-                }
-            }
-        }
-    }
-}
-
-/// An implementation of the
-/// [`VisitMut`](https://docs.rs/syn/latest/syn/visit_mut/trait.VisitMut.html)
-/// trait that massages the provided statements.
-///
-/// This massaging currently consists of:
-///   - Changing equality from = to ==
-struct StatementFixup {}
-
-impl VisitMut for StatementFixup {
-    fn visit_expr_mut(&mut self, node: &mut Expr) {
-        if let Expr::Assign(assn) = node {
-            *node = Expr::Binary(syn::ExprBinary {
-                attrs: assn.attrs.clone(),
-                left: assn.left.clone(),
-                op: syn::BinOp::Eq(Token![==](assn.eq_token.span)),
-                right: assn.right.clone(),
-            });
-        }
-        // Unless we bailed out above, continue with the default
-        // traversal
-        visit_mut::visit_expr_mut(self, node);
-    }
-}
+use syn::Ident;
 
 /// The main struct to handle code generation for this macro.
 ///
@@ -110,7 +59,7 @@ impl CodeGen {
     /// Generate the code to be output by this macro.
     ///
     /// `emit_prover` and `emit_verifier` are as in
-    /// [`super::sigma_compiler_core`].
+    /// [`sigma_compiler_core`](super::sigma_compiler_core).
     pub fn generate(
         &self,
         spec: &SigmaCompSpec,
@@ -126,8 +75,33 @@ impl CodeGen {
             pub type Point = super::#group_name;
         };
 
+        // vardict contains the variables that were defined in the macro
+        // call to [`sigma_compiler`]
+        let vardict = taggedvardict_to_vardict(&self.vars);
+        // sigma_rs_vardict contains the variables that we are passing
+        // to sigma_rs.  We may have removed some via substitution, and
+        // we may have added some when compiling statements like range
+        // assertions into underlying linear combination assertions.
+        let sigma_rs_vardict = taggedvardict_to_vardict(&spec.vars);
+
+        // Generate the code that uses the underlying sigma_rs API
+        let sigma_rs_codegen = super::sigma::codegen::CodeGen::new(
+            format_ident!("sigma"),
+            format_ident!("Point"),
+            &sigma_rs_vardict,
+            &spec.statements,
+        );
+        let sigma_rs_code = sigma_rs_codegen.generate(emit_prover, emit_verifier);
+
         let mut pub_params_fields = StructFieldList::default();
-        push_vars(&mut pub_params_fields, &self.vars, true);
+        pub_params_fields.push_vars(&vardict, true);
+        let mut witness_fields = StructFieldList::default();
+        witness_fields.push_vars(&vardict, false);
+
+        let mut sigma_rs_params_fields = StructFieldList::default();
+        sigma_rs_params_fields.push_vars(&sigma_rs_vardict, true);
+        let mut sigma_rs_witness_fields = StructFieldList::default();
+        sigma_rs_witness_fields.push_vars(&sigma_rs_vardict, false);
 
         // Generate the public params struct definition
         let params_def = {
@@ -166,9 +140,6 @@ impl CodeGen {
             }
         };
 
-        let mut witness_fields = StructFieldList::default();
-        push_vars(&mut witness_fields, &self.vars, false);
-
         // Generate the witness struct definition
         let witness_def = if emit_prover {
             let decls = witness_fields.field_decls();
@@ -194,26 +165,25 @@ impl CodeGen {
             };
             let params_ids = pub_params_fields.field_list();
             let witness_ids = witness_fields.field_list();
-            let mut assert_statementtree = spec.statements.clone();
-            let mut statement_fixup = StatementFixup {};
-            assert_statementtree
-                .leaves_mut()
-                .into_iter()
-                .for_each(|expr| statement_fixup.visit_expr_mut(expr));
-            let assert_statements = assert_statementtree.leaves_mut();
+            let sigma_rs_params_ids = sigma_rs_params_fields.field_list();
+            let sigma_rs_witness_ids = sigma_rs_witness_fields.field_list();
             let prove_code = &self.prove_code;
+            let codegen_params_var = format_ident!("{}_sigma_params", "codegen");
+            let codegen_witness_var = format_ident!("{}_sigma_witness", "codegen");
 
             quote! {
-                // The "#[allow(unused_variables)]" is temporary, until we
-                // actually call the underlying sigma macro
-                #[allow(unused_variables)]
                 pub fn prove(params: &Params, witness: &Witness) -> Result<Vec<u8>,()> {
                     #dumper
                     let Params { #params_ids } = *params;
                     let Witness { #witness_ids } = *witness;
                     #prove_code
-                    #(assert!(#assert_statements);)*
-                    Ok(Vec::<u8>::default())
+                    let #codegen_params_var = sigma::Params {
+                        #sigma_rs_params_ids
+                    };
+                    let #codegen_witness_var = sigma::Witness {
+                        #sigma_rs_witness_ids
+                    };
+                    sigma::prove(&#codegen_params_var, &#codegen_witness_var)
                 }
             }
         } else {
@@ -232,14 +202,16 @@ impl CodeGen {
                 quote! {}
             };
             let params_ids = pub_params_fields.field_list();
+            let sigma_rs_params_ids = sigma_rs_params_fields.field_list();
+            let codegen_params_var = format_ident!("{}_sigma_params", "codegen");
             quote! {
-                // The "#[allow(unused_variables)]" is temporary, until we
-                // actually call the underlying sigma macro
-                #[allow(unused_variables)]
                 pub fn verify(params: &Params, proof: &[u8]) -> Result<(),()> {
                     #dumper
                     let Params { #params_ids } = *params;
-                    Ok(())
+                    let #codegen_params_var = sigma::Params {
+                        #sigma_rs_params_ids
+                    };
+                    sigma::verify(&#codegen_params_var, proof)
                 }
             }
         } else {
@@ -261,9 +233,11 @@ impl CodeGen {
                 #dump_use
 
                 #group_types
+
+                #sigma_rs_code
+
                 #params_def
                 #witness_def
-
                 #prove_func
                 #verify_func
             }

+ 194 - 1
sigma_compiler_core/src/sigma/codegen.rs

@@ -3,7 +3,10 @@
 //! If that crate gets its own macro interface, it can use this module
 //! directly.
 
-use quote::{quote, ToTokens};
+use super::combiners::StatementTree;
+use super::types::{AExprType, VarDict};
+use proc_macro2::TokenStream;
+use quote::{format_ident, quote, ToTokens};
 use syn::Ident;
 
 // Names and types of fields that might end up in a generated struct
@@ -33,6 +36,30 @@ impl StructFieldList {
     pub fn push_vecpoint(&mut self, s: &Ident) {
         self.fields.push(StructField::VecPoint(s.clone()));
     }
+    pub fn push_vars(&mut self, vars: &VarDict, for_params: bool) {
+        for (id, ti) in vars.iter() {
+            match ti {
+                AExprType::Scalar { is_pub, is_vec, .. } => {
+                    if *is_pub == for_params {
+                        if *is_vec {
+                            self.push_vecscalar(&format_ident!("{}", id))
+                        } else {
+                            self.push_scalar(&format_ident!("{}", id))
+                        }
+                    }
+                }
+                AExprType::Point { is_vec, .. } => {
+                    if for_params {
+                        if *is_vec {
+                            self.push_vecpoint(&format_ident!("{}", id))
+                        } else {
+                            self.push_point(&format_ident!("{}", id))
+                        }
+                    }
+                }
+            }
+        }
+    }
     #[cfg(feature = "dump")]
     /// Output a ToTokens of the contents of the fields
     pub fn dump(&self) -> impl ToTokens {
@@ -106,3 +133,169 @@ impl StructFieldList {
         quote! { #(#field_ids)* }
     }
 }
+
+/// The main struct to handle code generation using the `sigma-rs` API.
+pub struct CodeGen<'a> {
+    proto_name: Ident,
+    group_name: Ident,
+    vars: &'a VarDict,
+    statements: &'a StatementTree,
+}
+
+impl<'a> CodeGen<'a> {
+    pub fn new(
+        proto_name: Ident,
+        group_name: Ident,
+        vars: &'a VarDict,
+        statements: &'a StatementTree,
+    ) -> Self {
+        Self {
+            proto_name,
+            group_name,
+            vars,
+            statements,
+        }
+    }
+
+    /// Generate the code that uses the `sigma-rs` API to prove and
+    /// verify the statements in the [`CodeGen`].
+    ///
+    /// `emit_prover` and `emit_verifier` are as in
+    /// [`sigma_compiler_core`](super::super::sigma_compiler_core).
+    pub fn generate(&self, emit_prover: bool, emit_verifier: bool) -> TokenStream {
+        let proto_name = &self.proto_name;
+        let group_name = &self.group_name;
+
+        let group_types = quote! {
+            use super::group;
+            pub type Scalar = <super::#group_name as group::Group>::Scalar;
+            pub type Point = super::#group_name;
+        };
+
+        let mut pub_params_fields = StructFieldList::default();
+        pub_params_fields.push_vars(self.vars, true);
+
+        // Generate the public params struct definition
+        let params_def = {
+            let decls = pub_params_fields.field_decls();
+            #[cfg(feature = "dump")]
+            let dump_impl = {
+                let dump_chunks = pub_params_fields.dump();
+                quote! {
+                    impl Params {
+                        fn dump_scalar(s: &Scalar) {
+                            let bytes: &[u8] = &s.to_repr();
+                            print!("{:02x?}", bytes);
+                        }
+
+                        fn dump_point(p: &Point) {
+                            let bytes: &[u8] = &p.to_bytes();
+                            print!("{:02x?}", bytes);
+                        }
+
+                        pub fn dump(&self) {
+                            #dump_chunks
+                        }
+                    }
+                }
+            };
+            #[cfg(not(feature = "dump"))]
+            let dump_impl = {
+                quote! {}
+            };
+            quote! {
+                pub struct Params {
+                    #decls
+                }
+
+                #dump_impl
+            }
+        };
+
+        let mut witness_fields = StructFieldList::default();
+        witness_fields.push_vars(self.vars, false);
+
+        // Generate the witness struct definition
+        let witness_def = if emit_prover {
+            let decls = witness_fields.field_decls();
+            quote! {
+                pub struct Witness {
+                    #decls
+                }
+            }
+        } else {
+            quote! {}
+        };
+
+        // Generate the (currently dummy) prove function
+        let prove_func = if emit_prover {
+            let dumper = if cfg!(feature = "dump") {
+                quote! {
+                    println!("prover params = {{");
+                    params.dump();
+                    println!("}}");
+                }
+            } else {
+                quote! {}
+            };
+            let params_ids = pub_params_fields.field_list();
+            let witness_ids = witness_fields.field_list();
+
+            quote! {
+                pub fn prove(params: &Params, witness: &Witness) -> Result<Vec<u8>,()> {
+                    #dumper
+                    let Params { #params_ids } = *params;
+                    let Witness { #witness_ids } = *witness;
+                    Ok(Vec::<u8>::default())
+                }
+            }
+        } else {
+            quote! {}
+        };
+
+        // Generate the (currently dummy) verify function
+        let verify_func = if emit_verifier {
+            let dumper = if cfg!(feature = "dump") {
+                quote! {
+                    println!("verifier params = {{");
+                    params.dump();
+                    println!("}}");
+                }
+            } else {
+                quote! {}
+            };
+            let params_ids = pub_params_fields.field_list();
+            quote! {
+                pub fn verify(params: &Params, proof: &[u8]) -> Result<(),()> {
+                    #dumper
+                    let Params { #params_ids } = *params;
+                    Ok(())
+                }
+            }
+        } else {
+            quote! {}
+        };
+
+        // Output the generated module for this protocol
+        let dump_use = if cfg!(feature = "dump") {
+            quote! {
+                use group::GroupEncoding;
+            }
+        } else {
+            quote! {}
+        };
+        quote! {
+            #[allow(non_snake_case)]
+            pub mod #proto_name {
+                #dump_use
+
+                #group_types
+                #params_def
+                #witness_def
+
+                #prove_func
+                #verify_func
+            }
+        }
+    }
+}

+ 1 - 1
src/lib.rs

@@ -1,3 +1,3 @@
-pub use sigma_compiler_derive::sigma_compiler;
 pub use group;
+pub use sigma_compiler_derive::sigma_compiler;
 pub use sigma_rs;

+ 34 - 0
tests/basic.rs

@@ -0,0 +1,34 @@
+#![allow(non_snake_case)]
+use curve25519_dalek::ristretto::RistrettoPoint as G;
+use group::ff::PrimeField;
+use group::Group;
+use sha2::Sha512;
+use sigma_compiler::*;
+
+#[test]
+fn basic_test() -> Result<(), ()> {
+    sigma_compiler! { proof,
+        (x, z, rand r, rand s),
+        (C, D, const cind A, const cind B),
+        C = x*A + r*B,
+        D = z*A + s*B,
+        z = 2*x + 1,
+    }
+
+    type Scalar = <G as Group>::Scalar;
+    let mut rng = rand::thread_rng();
+    let A = G::hash_from_bytes::<Sha512>(b"Generator A");
+    let B = G::generator();
+    let r = Scalar::random(&mut rng);
+    let s = Scalar::random(&mut rng);
+    let x = Scalar::from_u128(5);
+    let z = Scalar::from_u128(11);
+    let C = x * A + r * B;
+    let D = z * A + s * B;
+
+    let params = proof::Params { C, D, A, B };
+    let witness = proof::Witness { x, z, r, s };
+
+    let proof = proof::prove(&params, &witness)?;
+    proof::verify(&params, &proof)
+}