Explorar el Código

Refactor sigma_compiler_derive into sigma_compiler_derive and sigma_compiler_core

All of the work is done in the sigma_compiler_core crate, which is just
a normal library crate, not a proc-macro crate, so that it's more easily
developed and tested.

sigma_compiler_derive is then just a tiny wrapper.

h/t https://medium.com/data-science/nine-rules-for-creating-procedural-macros-in-rust-595aa476a7ff
Ian Goldberg hace 4 meses
padre
commit
266b07ad4f

+ 15 - 0
sigma_compiler_core/Cargo.toml

@@ -0,0 +1,15 @@
+[package]
+name = "sigma_compiler_core"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+syn = { version = "2.0", features = ["extra-traits", "visit-mut", "full"] }
+quote = "1.0"
+proc-macro2 = "1.0"
+
+[features]
+# Dump (to stdout) the public params on both the prover's and verifier's
+# side.  They should match.
+dump = []
+# default = ["dump"]

+ 450 - 0
sigma_compiler_core/src/lib.rs

@@ -0,0 +1,450 @@
+use proc_macro2::TokenStream;
+use quote::{format_ident, quote, ToTokens};
+use std::collections::HashMap;
+use syn::ext::IdentExt;
+use syn::parse::{Parse, ParseStream, Result};
+use syn::punctuated::Punctuated;
+use syn::visit_mut::{self, VisitMut};
+use syn::{parenthesized, parse_quote, Expr, Ident, Token};
+
+// A `TaggedIdent` is an `Ident`, preceded by zero or more of the
+// following tags: `pub`, `rand`, `cind`, `const`, `vec`
+//
+// A `TaggedIndent` representing a `Scalar` can be preceded by:
+//  - (nothing)
+//  - `pub`
+//  - `rand`
+//  - `vec`
+//  - `pub vec`
+//  - `rand vec`
+//
+// A `TaggedIndent` representing a `Point` can be preceded by:
+//  - (nothing)
+//  - `cind`
+//  - `const`
+//  - `cind const`
+//  - `vec`
+//  - `cind vec`
+//  - `const vec`
+//  - `cind const vec`
+
+#[derive(Debug)]
+struct TaggedIdent {
+    id: Ident,
+    is_pub: bool,
+    is_rand: bool,
+    is_cind: bool,
+    is_const: bool,
+    is_vec: bool,
+}
+
+impl TaggedIdent {
+    // parse for a `Scalar` if point is false; parse for a `Point` if point
+    // is true
+    pub fn parse(input: ParseStream, point: bool) -> Result<Self> {
+        // Points are always pub
+        let (mut is_pub, mut is_rand, mut is_cind, mut is_const, mut is_vec) =
+            (point, false, false, false, false);
+        loop {
+            let id = input.call(Ident::parse_any)?;
+            match id.to_string().as_str() {
+                // pub and rand are only allowed for Scalars, and are
+                // mutually exclusive
+                "pub" if !point && !is_rand => {
+                    is_pub = true;
+                }
+                "rand" if !point && !is_pub => {
+                    is_rand = true;
+                }
+                // cind and const are only allowed for Points, but can
+                // be used together
+                "cind" if point => {
+                    is_cind = true;
+                }
+                "const" if point => {
+                    is_const = true;
+                }
+                // vec is allowed with either Scalars or Points, and
+                // with any other tag
+                "vec" => {
+                    is_vec = true;
+                }
+                _ => {
+                    return Ok(TaggedIdent {
+                        id,
+                        is_pub,
+                        is_rand,
+                        is_cind,
+                        is_const,
+                        is_vec,
+                    });
+                }
+            }
+        }
+    }
+
+    // Parse a `TaggedIndent` using the tags allowed for a `Scalar`
+    pub fn parse_scalar(input: ParseStream) -> Result<Self> {
+        Self::parse(input, false)
+    }
+
+    // Parse a `TaggedIndent` using the tags allowed for a `Point`
+    pub fn parse_point(input: ParseStream) -> Result<Self> {
+        Self::parse(input, true)
+    }
+}
+
+#[derive(Debug)]
+pub struct SigmaCompSpec {
+    proto_name: Ident,
+    group_name: Ident,
+    scalars: Vec<TaggedIdent>,
+    points: Vec<TaggedIdent>,
+    statements: Vec<Expr>,
+}
+
+// parse for a `Scalar` if point is false; parse for a `Point` if point
+// is true
+fn paren_taggedidents(input: ParseStream, point: bool) -> Result<Vec<TaggedIdent>> {
+    let content;
+    parenthesized!(content in input);
+    let punc: Punctuated<TaggedIdent, Token![,]> = content.parse_terminated(
+        if point {
+            TaggedIdent::parse_point
+        } else {
+            TaggedIdent::parse_scalar
+        },
+        Token![,],
+    )?;
+    Ok(punc.into_iter().collect())
+}
+
+impl Parse for SigmaCompSpec {
+    fn parse(input: ParseStream) -> Result<Self> {
+        let proto_name: Ident = input.parse()?;
+        // See if a group was specified
+        let group_name = if input.peek(Token![<]) {
+            input.parse::<Token![<]>()?;
+            let gr: Ident = input.parse()?;
+            input.parse::<Token![>]>()?;
+            gr
+        } else {
+            format_ident!("G")
+        };
+        input.parse::<Token![,]>()?;
+
+        let scalars = paren_taggedidents(input, false)?;
+        input.parse::<Token![,]>()?;
+
+        let points = paren_taggedidents(input, true)?;
+        input.parse::<Token![,]>()?;
+
+        let statementpunc: Punctuated<Expr, Token![,]> =
+            input.parse_terminated(Expr::parse, Token![,])?;
+        let statements: Vec<Expr> = statementpunc.into_iter().collect();
+
+        Ok(SigmaCompSpec {
+            proto_name,
+            group_name,
+            scalars,
+            points,
+            statements,
+        })
+    }
+}
+
+// Names and types of fields that might end up in a generated struct
+enum StructField {
+    Scalar(Ident),
+    VecScalar(Ident),
+    Point(Ident),
+    VecPoint(Ident),
+}
+
+// A list of StructField items
+#[derive(Default)]
+struct StructFieldList {
+    fields: Vec<StructField>,
+}
+
+impl StructFieldList {
+    pub fn push_scalar(&mut self, s: &Ident) {
+        self.fields.push(StructField::Scalar(s.clone()));
+    }
+    pub fn push_vecscalar(&mut self, s: &Ident) {
+        self.fields.push(StructField::VecScalar(s.clone()));
+    }
+    pub fn push_point(&mut self, s: &Ident) {
+        self.fields.push(StructField::Point(s.clone()));
+    }
+    pub fn push_vecpoint(&mut self, s: &Ident) {
+        self.fields.push(StructField::VecPoint(s.clone()));
+    }
+    pub fn push_scalars(&mut self, sl: &[TaggedIdent], is_pub: bool) {
+        for tid in sl.iter() {
+            if tid.is_pub == is_pub {
+                if tid.is_vec {
+                    self.push_vecscalar(&tid.id)
+                } else {
+                    self.push_scalar(&tid.id)
+                }
+            }
+        }
+    }
+    pub fn push_points(&mut self, sl: &[TaggedIdent], is_pub: bool) {
+        for tid in sl.iter() {
+            if tid.is_pub == is_pub {
+                if tid.is_vec {
+                    self.push_vecpoint(&tid.id)
+                } else {
+                    self.push_point(&tid.id)
+                }
+            }
+        }
+    }
+    /// Output a ToTokens of the fields as they would appear in a struct
+    /// definition
+    pub fn field_decls(&self) -> impl ToTokens {
+        let decls = self.fields.iter().map(|f| match f {
+            StructField::Scalar(id) => quote! {
+                pub #id: Scalar,
+            },
+            StructField::VecScalar(id) => quote! {
+                pub #id: Vec<Scalar>,
+            },
+            StructField::Point(id) => quote! {
+                pub #id: Point,
+            },
+            StructField::VecPoint(id) => quote! {
+                pub #id: Vec<Point>,
+            },
+        });
+        quote! { #(#decls)* }
+    }
+}
+
+/// An implementation of the
+/// [`VisitMut`](https://docs.rs/syn/latest/syn/visit_mut/trait.VisitMut.html)
+/// trait that massages the provided statements.  This massaging
+/// currently consists of:
+///   - Changing equality from = to ==
+///   - Changing any identifier `id` to either `params.id` or
+///     `witness.id` depending on whether it is public or private
+struct StatementFixup {
+    idmap: HashMap<String, Expr>,
+}
+
+impl StatementFixup {
+    pub fn new(spec: &SigmaCompSpec) -> Self {
+        let mut idmap: HashMap<String, Expr> = HashMap::new();
+
+        // For each public identifier id (Points, or Scalars marked
+        // "pub"), add to the map "id" -> params.id.  For each private
+        // identifier (Scalars not marked "pub"), add to the map "id" ->
+        // witness.id.
+        for tid in spec.scalars.iter().chain(spec.points.iter()) {
+            let id = &tid.id;
+            let idexpr: Expr = if tid.is_pub {
+                parse_quote! { params.#id }
+            } else {
+                parse_quote! { witness.#id }
+            };
+            idmap.insert(id.to_string(), idexpr);
+        }
+
+        Self { idmap }
+    }
+}
+
+impl VisitMut for StatementFixup {
+    fn visit_expr_mut(&mut self, node: &mut Expr) {
+        if let Expr::Assign(assn) = node {
+            *node = Expr::Binary(syn::ExprBinary {
+                attrs: assn.attrs.clone(),
+                left: assn.left.clone(),
+                op: syn::BinOp::Eq(Token![==](assn.eq_token.span)),
+                right: assn.right.clone(),
+            });
+        }
+        if let Expr::Path(expath) = node {
+            if let Some(id) = expath.path.get_ident() {
+                if let Some(expr) = self.idmap.get(&id.to_string()) {
+                    *node = expr.clone();
+                    // Don't recurse
+                    return;
+                }
+            }
+        }
+        // Unless we bailed out above, continue with the default
+        // traversal
+        visit_mut::visit_expr_mut(self, node);
+    }
+}
+
+pub fn sigma_compiler_core(
+    spec: &SigmaCompSpec,
+    emit_prover: bool,
+    emit_verifier: bool,
+) -> TokenStream {
+    let proto_name = &spec.proto_name;
+    let group_name = &spec.group_name;
+
+    let group_types = quote! {
+        pub type Scalar = <super::#group_name as super::Group>::Scalar;
+        pub type Point = super::#group_name;
+    };
+
+    // Generate the public params struct definition
+    let params_def = {
+        let mut pub_params_fields = StructFieldList::default();
+        pub_params_fields.push_points(&spec.points, true);
+        pub_params_fields.push_scalars(&spec.scalars, true);
+
+        let decls = pub_params_fields.field_decls();
+        let dump_impl = if cfg!(feature = "dump") {
+            let dump_chunks = pub_params_fields.fields.iter().map(|f| match f {
+                StructField::Scalar(id) => quote! {
+                    print!("  {}: ", stringify!(#id));
+                    Params::dump_scalar(&self.#id);
+                    println!("");
+                },
+                StructField::VecScalar(id) => quote! {
+                    print!("  {}: [", stringify!(#id));
+                    for s in self.#id.iter() {
+                        print!("    ");
+                        Params::dump_scalar(s);
+                        println!(",");
+                    }
+                    println!("  ]");
+                },
+                StructField::Point(id) => quote! {
+                    print!("  {}: ", stringify!(#id));
+                    Params::dump_point(&self.#id);
+                    println!("");
+                },
+                StructField::VecPoint(id) => quote! {
+                    print!("  {}: [", stringify!(#id));
+                    for p in self.#id.iter() {
+                        print!("    ");
+                        Params::dump_point(p);
+                        println!(",");
+                    }
+                    println!("  ]");
+                },
+            });
+            quote! {
+                impl Params {
+                    fn dump_scalar(s: &Scalar) {
+                        let bytes: &[u8] = &s.to_repr();
+                        print!("{:02x?}", bytes);
+                    }
+
+                    fn dump_point(p: &Point) {
+                        let bytes: &[u8] = &p.to_bytes();
+                        print!("{:02x?}", bytes);
+                    }
+
+                    pub fn dump(&self) {
+                        #(#dump_chunks)*
+                    }
+                }
+            }
+        } else {
+            quote! {}
+        };
+        quote! {
+            pub struct Params {
+                #decls
+            }
+
+            #dump_impl
+        }
+    };
+
+    // Generate the witness struct definition
+    let witness_def = if emit_prover {
+        let mut witness_fields = StructFieldList::default();
+        witness_fields.push_scalars(&spec.scalars, false);
+
+        let decls = witness_fields.field_decls();
+        quote! {
+            pub struct Witness {
+                #decls
+            }
+        }
+    } else {
+        quote! {}
+    };
+
+    // Generate the (currently dummy) prove function
+    let prove_func = if emit_prover {
+        let dumper = if cfg!(feature = "dump") {
+            quote! {
+                println!("prover params = {{");
+                params.dump();
+                println!("}}");
+            }
+        } else {
+            quote! {}
+        };
+        let mut assert_statements = spec.statements.clone();
+        let mut statement_fixup = StatementFixup::new(&spec);
+        assert_statements
+            .iter_mut()
+            .for_each(|expr| statement_fixup.visit_expr_mut(expr));
+        quote! {
+            pub fn prove(params: &Params, witness: &Witness) -> Result<Vec<u8>,()> {
+                #dumper
+                #(assert!(#assert_statements);)*
+                Ok(Vec::<u8>::default())
+            }
+        }
+    } else {
+        quote! {}
+    };
+
+    // Generate the (currently dummy) verify function
+    let verify_func = if emit_verifier {
+        let dumper = if cfg!(feature = "dump") {
+            quote! {
+                println!("verifier params = {{");
+                params.dump();
+                println!("}}");
+            }
+        } else {
+            quote! {}
+        };
+        quote! {
+            pub fn verify(params: &Params, proof: &[u8]) -> Result<(),()> {
+                #dumper
+                Ok(())
+            }
+        }
+    } else {
+        quote! {}
+    };
+
+    // Output the generated module for this protocol
+    let dump_use = if cfg!(feature = "dump") {
+        quote! {
+            use ff::PrimeField;
+            use group::GroupEncoding;
+        }
+    } else {
+        quote! {}
+    };
+    quote! {
+        #[allow(non_snake_case)]
+        pub mod #proto_name {
+            #dump_use
+
+            #group_types
+            #params_def
+            #witness_def
+
+            #prove_func
+            #verify_func
+        }
+    }
+    .into()
+}

+ 2 - 8
sigma_compiler_derive/Cargo.toml

@@ -7,11 +7,5 @@ edition = "2021"
 proc-macro = true
 
 [dependencies]
-syn = { version = "2.0", features = ["extra-traits", "visit-mut", "full"] }
-quote = "1.0"
-
-[features]
-# Dump (to stdout) the public params on both the prover's and verifier's
-# side.  They should match.
-dump = []
-# default = ["dump"]
+sigma_compiler_core = { path = "../sigma_compiler_core" }
+syn = "2.0"

+ 4 - 451
sigma_compiler_derive/src/lib.rs

@@ -1,456 +1,9 @@
 use proc_macro::TokenStream;
-use quote::{format_ident, quote, ToTokens};
-use std::collections::HashMap;
-use syn::ext::IdentExt;
-use syn::parse::{Parse, ParseStream, Result};
-use syn::punctuated::Punctuated;
-use syn::visit_mut::{self, VisitMut};
-use syn::{parenthesized, parse_macro_input, parse_quote, Expr, Ident, Token};
-
-// A `TaggedIdent` is an `Ident`, preceded by zero or more of the
-// following tags: `pub`, `rand`, `cind`, `const`, `vec`
-//
-// A `TaggedIndent` representing a `Scalar` can be preceded by:
-//  - (nothing)
-//  - `pub`
-//  - `rand`
-//  - `vec`
-//  - `pub vec`
-//  - `rand vec`
-//
-// A `TaggedIndent` representing a `Point` can be preceded by:
-//  - (nothing)
-//  - `cind`
-//  - `const`
-//  - `cind const`
-//  - `vec`
-//  - `cind vec`
-//  - `const vec`
-//  - `cind const vec`
-
-#[derive(Debug)]
-struct TaggedIdent {
-    id: Ident,
-    is_pub: bool,
-    is_rand: bool,
-    is_cind: bool,
-    is_const: bool,
-    is_vec: bool,
-}
-
-impl TaggedIdent {
-    // parse for a `Scalar` if point is false; parse for a `Point` if point
-    // is true
-    pub fn parse(input: ParseStream, point: bool) -> Result<Self> {
-        // Points are always pub
-        let (mut is_pub, mut is_rand, mut is_cind, mut is_const, mut is_vec) =
-            (point, false, false, false, false);
-        loop {
-            let id = input.call(Ident::parse_any)?;
-            match id.to_string().as_str() {
-                // pub and rand are only allowed for Scalars, and are
-                // mutually exclusive
-                "pub" if !point && !is_rand => {
-                    is_pub = true;
-                }
-                "rand" if !point && !is_pub => {
-                    is_rand = true;
-                }
-                // cind and const are only allowed for Points, but can
-                // be used together
-                "cind" if point => {
-                    is_cind = true;
-                }
-                "const" if point => {
-                    is_const = true;
-                }
-                // vec is allowed with either Scalars or Points, and
-                // with any other tag
-                "vec" => {
-                    is_vec = true;
-                }
-                _ => {
-                    return Ok(TaggedIdent {
-                        id,
-                        is_pub,
-                        is_rand,
-                        is_cind,
-                        is_const,
-                        is_vec,
-                    });
-                }
-            }
-        }
-    }
-
-    // Parse a `TaggedIndent` using the tags allowed for a `Scalar`
-    pub fn parse_scalar(input: ParseStream) -> Result<Self> {
-        Self::parse(input, false)
-    }
-
-    // Parse a `TaggedIndent` using the tags allowed for a `Point`
-    pub fn parse_point(input: ParseStream) -> Result<Self> {
-        Self::parse(input, true)
-    }
-}
-
-#[derive(Debug)]
-struct SigmaCompSpec {
-    proto_name: Ident,
-    group_name: Ident,
-    scalars: Vec<TaggedIdent>,
-    points: Vec<TaggedIdent>,
-    statements: Vec<Expr>,
-}
-
-// parse for a `Scalar` if point is false; parse for a `Point` if point
-// is true
-fn paren_taggedidents(input: ParseStream, point: bool) -> Result<Vec<TaggedIdent>> {
-    let content;
-    parenthesized!(content in input);
-    let punc: Punctuated<TaggedIdent, Token![,]> = content.parse_terminated(
-        if point {
-            TaggedIdent::parse_point
-        } else {
-            TaggedIdent::parse_scalar
-        },
-        Token![,],
-    )?;
-    Ok(punc.into_iter().collect())
-}
-
-impl Parse for SigmaCompSpec {
-    fn parse(input: ParseStream) -> Result<Self> {
-        let proto_name: Ident = input.parse()?;
-        // See if a group was specified
-        let group_name = if input.peek(Token![<]) {
-            input.parse::<Token![<]>()?;
-            let gr: Ident = input.parse()?;
-            input.parse::<Token![>]>()?;
-            gr
-        } else {
-            format_ident!("G")
-        };
-        input.parse::<Token![,]>()?;
-
-        let scalars = paren_taggedidents(input, false)?;
-        input.parse::<Token![,]>()?;
-
-        let points = paren_taggedidents(input, true)?;
-        input.parse::<Token![,]>()?;
-
-        let statementpunc: Punctuated<Expr, Token![,]> =
-            input.parse_terminated(Expr::parse, Token![,])?;
-        let statements: Vec<Expr> = statementpunc.into_iter().collect();
-
-        Ok(SigmaCompSpec {
-            proto_name,
-            group_name,
-            scalars,
-            points,
-            statements,
-        })
-    }
-}
-
-// Names and types of fields that might end up in a generated struct
-enum StructField {
-    Scalar(Ident),
-    VecScalar(Ident),
-    Point(Ident),
-    VecPoint(Ident),
-}
-
-// A list of StructField items
-#[derive(Default)]
-struct StructFieldList {
-    fields: Vec<StructField>,
-}
-
-impl StructFieldList {
-    pub fn push_scalar(&mut self, s: &Ident) {
-        self.fields.push(StructField::Scalar(s.clone()));
-    }
-    pub fn push_vecscalar(&mut self, s: &Ident) {
-        self.fields.push(StructField::VecScalar(s.clone()));
-    }
-    pub fn push_point(&mut self, s: &Ident) {
-        self.fields.push(StructField::Point(s.clone()));
-    }
-    pub fn push_vecpoint(&mut self, s: &Ident) {
-        self.fields.push(StructField::VecPoint(s.clone()));
-    }
-    pub fn push_scalars(&mut self, sl: &[TaggedIdent], is_pub: bool) {
-        for tid in sl.iter() {
-            if tid.is_pub == is_pub {
-                if tid.is_vec {
-                    self.push_vecscalar(&tid.id)
-                } else {
-                    self.push_scalar(&tid.id)
-                }
-            }
-        }
-    }
-    pub fn push_points(&mut self, sl: &[TaggedIdent], is_pub: bool) {
-        for tid in sl.iter() {
-            if tid.is_pub == is_pub {
-                if tid.is_vec {
-                    self.push_vecpoint(&tid.id)
-                } else {
-                    self.push_point(&tid.id)
-                }
-            }
-        }
-    }
-    /// Output a ToTokens of the fields as they would appear in a struct
-    /// definition
-    pub fn field_decls(&self) -> impl ToTokens {
-        let decls = self.fields.iter().map(|f| match f {
-            StructField::Scalar(id) => quote! {
-                pub #id: Scalar,
-            },
-            StructField::VecScalar(id) => quote! {
-                pub #id: Vec<Scalar>,
-            },
-            StructField::Point(id) => quote! {
-                pub #id: Point,
-            },
-            StructField::VecPoint(id) => quote! {
-                pub #id: Vec<Point>,
-            },
-        });
-        quote! { #(#decls)* }
-    }
-}
-
-/// An implementation of the
-/// [`VisitMut`](https://docs.rs/syn/latest/syn/visit_mut/trait.VisitMut.html)
-/// trait that massages the provided statements.  This massaging
-/// currently consists of:
-///   - Changing equality from = to ==
-///   - Changing any identifier `id` to either `params.id` or
-///     `witness.id` depending on whether it is public or private
-struct StatementFixup {
-    idmap: HashMap<String, Expr>,
-}
-
-impl StatementFixup {
-    pub fn new(spec: &SigmaCompSpec) -> Self {
-        let mut idmap: HashMap<String, Expr> = HashMap::new();
-
-        // For each public identifier id (Points, or Scalars marked
-        // "pub"), add to the map "id" -> params.id.  For each private
-        // identifier (Scalars not marked "pub"), add to the map "id" ->
-        // witness.id.
-        for tid in spec.scalars.iter().chain(spec.points.iter()) {
-            let id = &tid.id;
-            let idexpr: Expr = if tid.is_pub {
-                parse_quote! { params.#id }
-            } else {
-                parse_quote! { witness.#id }
-            };
-            idmap.insert(id.to_string(), idexpr);
-        }
-
-        Self { idmap }
-    }
-}
-
-impl VisitMut for StatementFixup {
-    fn visit_expr_mut(&mut self, node: &mut Expr) {
-        if let Expr::Assign(assn) = node {
-            *node = Expr::Binary(syn::ExprBinary {
-                attrs: assn.attrs.clone(),
-                left: assn.left.clone(),
-                op: syn::BinOp::Eq(Token![==](assn.eq_token.span)),
-                right: assn.right.clone(),
-            });
-        }
-        if let Expr::Path(expath) = node {
-            if let Some(id) = expath.path.get_ident() {
-                if let Some(expr) = self.idmap.get(&id.to_string()) {
-                    *node = expr.clone();
-                    // Don't recurse
-                    return;
-                }
-            }
-        }
-        // Unless we bailed out above, continue with the default
-        // traversal
-        visit_mut::visit_expr_mut(self, node);
-    }
-}
-
-fn sigma_compiler_impl(
-    spec: &SigmaCompSpec,
-    emit_prover: bool,
-    emit_verifier: bool,
-) -> TokenStream {
-    let proto_name = &spec.proto_name;
-    let group_name = &spec.group_name;
-
-    let group_types = quote! {
-        pub type Scalar = <super::#group_name as super::Group>::Scalar;
-        pub type Point = super::#group_name;
-    };
-
-    // Generate the public params struct definition
-    let params_def = {
-        let mut pub_params_fields = StructFieldList::default();
-        pub_params_fields.push_points(&spec.points, true);
-        pub_params_fields.push_scalars(&spec.scalars, true);
-
-        let decls = pub_params_fields.field_decls();
-        let dump_impl = if cfg!(feature = "dump") {
-            let dump_chunks = pub_params_fields.fields.iter().map(|f| match f {
-                StructField::Scalar(id) => quote! {
-                    print!("  {}: ", stringify!(#id));
-                    Params::dump_scalar(&self.#id);
-                    println!("");
-                },
-                StructField::VecScalar(id) => quote! {
-                    print!("  {}: [", stringify!(#id));
-                    for s in self.#id.iter() {
-                        print!("    ");
-                        Params::dump_scalar(s);
-                        println!(",");
-                    }
-                    println!("  ]");
-                },
-                StructField::Point(id) => quote! {
-                    print!("  {}: ", stringify!(#id));
-                    Params::dump_point(&self.#id);
-                    println!("");
-                },
-                StructField::VecPoint(id) => quote! {
-                    print!("  {}: [", stringify!(#id));
-                    for p in self.#id.iter() {
-                        print!("    ");
-                        Params::dump_point(p);
-                        println!(",");
-                    }
-                    println!("  ]");
-                },
-            });
-            quote! {
-                impl Params {
-                    fn dump_scalar(s: &Scalar) {
-                        let bytes: &[u8] = &s.to_repr();
-                        print!("{:02x?}", bytes);
-                    }
-
-                    fn dump_point(p: &Point) {
-                        let bytes: &[u8] = &p.to_bytes();
-                        print!("{:02x?}", bytes);
-                    }
-
-                    pub fn dump(&self) {
-                        #(#dump_chunks)*
-                    }
-                }
-            }
-        } else {
-            quote! {}
-        };
-        quote! {
-            pub struct Params {
-                #decls
-            }
-
-            #dump_impl
-        }
-    };
-
-    // Generate the witness struct definition
-    let witness_def = if emit_prover {
-        let mut witness_fields = StructFieldList::default();
-        witness_fields.push_scalars(&spec.scalars, false);
-
-        let decls = witness_fields.field_decls();
-        quote! {
-            pub struct Witness {
-                #decls
-            }
-        }
-    } else {
-        quote! {}
-    };
-
-    // Generate the (currently dummy) prove function
-    let prove_func = if emit_prover {
-        let dumper = if cfg!(feature = "dump") {
-            quote! {
-                println!("prover params = {{");
-                params.dump();
-                println!("}}");
-            }
-        } else {
-            quote! {}
-        };
-        let mut assert_statements = spec.statements.clone();
-        let mut statement_fixup = StatementFixup::new(&spec);
-        assert_statements
-            .iter_mut()
-            .for_each(|expr| statement_fixup.visit_expr_mut(expr));
-        quote! {
-            pub fn prove(params: &Params, witness: &Witness) -> Result<Vec<u8>,()> {
-                #dumper
-                #(assert!(#assert_statements);)*
-                Ok(Vec::<u8>::default())
-            }
-        }
-    } else {
-        quote! {}
-    };
-
-    // Generate the (currently dummy) verify function
-    let verify_func = if emit_verifier {
-        let dumper = if cfg!(feature = "dump") {
-            quote! {
-                println!("verifier params = {{");
-                params.dump();
-                println!("}}");
-            }
-        } else {
-            quote! {}
-        };
-        quote! {
-            pub fn verify(params: &Params, proof: &[u8]) -> Result<(),()> {
-                #dumper
-                Ok(())
-            }
-        }
-    } else {
-        quote! {}
-    };
-
-    // Output the generated module for this protocol
-    let dump_use = if cfg!(feature = "dump") {
-        quote! {
-            use ff::PrimeField;
-            use group::GroupEncoding;
-        }
-    } else {
-        quote! {}
-    };
-    quote! {
-        #[allow(non_snake_case)]
-        pub mod #proto_name {
-            #dump_use
-
-            #group_types
-            #params_def
-            #witness_def
-
-            #prove_func
-            #verify_func
-        }
-    }
-    .into()
-}
+use sigma_compiler_core::{SigmaCompSpec, sigma_compiler_core};
+use syn::parse_macro_input;
 
 #[proc_macro]
 pub fn sigma_compiler(input: TokenStream) -> TokenStream {
-    let spec: SigmaCompSpec = parse_macro_input!(input as SigmaCompSpec);
-    sigma_compiler_impl(&spec, true, true)
+    let spec = parse_macro_input!(input as SigmaCompSpec);
+    sigma_compiler_core(&spec, true, true).into()
 }

+ 1 - 1
src/lib.rs

@@ -1 +1 @@
-pub use sigma_compiler_derive::*;
+pub use sigma_compiler_derive::sigma_compiler;