Browse Source

Refactor StructFieldList from codegen to sigma/codegen

Ian Goldberg 4 months ago
parent
commit
7e4b865b1f

+ 28 - 117
sigma_compiler_core/src/codegen.rs

@@ -2,101 +2,38 @@
 //! will interact with the underlying `sigma` macro.
 
 use super::syntax::*;
+use super::sigma::codegen::StructFieldList;
 use proc_macro2::TokenStream;
-use quote::{quote, ToTokens};
+use quote::quote;
 #[cfg(test)]
 use syn::parse_quote;
 use syn::visit_mut::{self, VisitMut};
 use syn::{Expr, Ident, Token};
 
-// Names and types of fields that might end up in a generated struct
-enum StructField {
-    Scalar(Ident),
-    VecScalar(Ident),
-    Point(Ident),
-    VecPoint(Ident),
-}
-
-// A list of StructField items
-#[derive(Default)]
-struct StructFieldList {
-    fields: Vec<StructField>,
-}
-
-impl StructFieldList {
-    pub fn push_scalar(&mut self, s: &Ident) {
-        self.fields.push(StructField::Scalar(s.clone()));
-    }
-    pub fn push_vecscalar(&mut self, s: &Ident) {
-        self.fields.push(StructField::VecScalar(s.clone()));
-    }
-    pub fn push_point(&mut self, s: &Ident) {
-        self.fields.push(StructField::Point(s.clone()));
-    }
-    pub fn push_vecpoint(&mut self, s: &Ident) {
-        self.fields.push(StructField::VecPoint(s.clone()));
-    }
-    pub fn push_vars(&mut self, vardict: &TaggedVarDict, is_pub: bool) {
-        for (_, ti) in vardict.iter() {
-            match ti {
-                TaggedIdent::Scalar(st) => {
-                    if st.is_pub == is_pub {
-                        if st.is_vec {
-                            self.push_vecscalar(&st.id)
-                        } else {
-                            self.push_scalar(&st.id)
-                        }
+fn push_vars(structfieldlist: &mut StructFieldList,
+    vars: &TaggedVarDict, is_pub: bool) {
+    for (_, ti) in vars.iter() {
+        match ti {
+            TaggedIdent::Scalar(st) => {
+                if st.is_pub == is_pub {
+                    if st.is_vec {
+                        structfieldlist.push_vecscalar(&st.id)
+                    } else {
+                        structfieldlist.push_scalar(&st.id)
                     }
                 }
-                TaggedIdent::Point(pt) => {
-                    if is_pub {
-                        if pt.is_vec {
-                            self.push_vecpoint(&pt.id)
-                        } else {
-                            self.push_point(&pt.id)
-                        }
+            }
+            TaggedIdent::Point(pt) => {
+                if is_pub {
+                    if pt.is_vec {
+                        structfieldlist.push_vecpoint(&pt.id)
+                    } else {
+                        structfieldlist.push_point(&pt.id)
                     }
                 }
             }
         }
     }
-    /// Output a ToTokens of the fields as they would appear in a struct
-    /// definition
-    pub fn field_decls(&self) -> impl ToTokens {
-        let decls = self.fields.iter().map(|f| match f {
-            StructField::Scalar(id) => quote! {
-                pub #id: Scalar,
-            },
-            StructField::VecScalar(id) => quote! {
-                pub #id: Vec<Scalar>,
-            },
-            StructField::Point(id) => quote! {
-                pub #id: Point,
-            },
-            StructField::VecPoint(id) => quote! {
-                pub #id: Vec<Point>,
-            },
-        });
-        quote! { #(#decls)* }
-    }
-    /// Output a ToTokens of the list of fields
-    pub fn field_list(&self) -> impl ToTokens {
-        let field_ids = self.fields.iter().map(|f| match f {
-            StructField::Scalar(id) => quote! {
-                #id,
-            },
-            StructField::VecScalar(id) => quote! {
-                #id,
-            },
-            StructField::Point(id) => quote! {
-                #id,
-            },
-            StructField::VecPoint(id) => quote! {
-                #id,
-            },
-        });
-        quote! { #(#field_ids)* }
-    }
 }
 
 /// An implementation of the
@@ -190,42 +127,14 @@ impl CodeGen {
         };
 
         let mut pub_params_fields = StructFieldList::default();
-        pub_params_fields.push_vars(&self.vars, true);
+        push_vars(&mut pub_params_fields, &self.vars, true);
 
         // Generate the public params struct definition
         let params_def = {
             let decls = pub_params_fields.field_decls();
-            let dump_impl = if cfg!(feature = "dump") {
-                let dump_chunks = pub_params_fields.fields.iter().map(|f| match f {
-                    StructField::Scalar(id) => quote! {
-                        print!("  {}: ", stringify!(#id));
-                        Params::dump_scalar(&self.#id);
-                        println!("");
-                    },
-                    StructField::VecScalar(id) => quote! {
-                        print!("  {}: [", stringify!(#id));
-                        for s in self.#id.iter() {
-                            print!("    ");
-                            Params::dump_scalar(s);
-                            println!(",");
-                        }
-                        println!("  ]");
-                    },
-                    StructField::Point(id) => quote! {
-                        print!("  {}: ", stringify!(#id));
-                        Params::dump_point(&self.#id);
-                        println!("");
-                    },
-                    StructField::VecPoint(id) => quote! {
-                        print!("  {}: [", stringify!(#id));
-                        for p in self.#id.iter() {
-                            print!("    ");
-                            Params::dump_point(p);
-                            println!(",");
-                        }
-                        println!("  ]");
-                    },
-                });
+            #[cfg(feature = "dump")]
+            let dump_impl = {
+                let dump_chunks = pub_params_fields.dump();
                 quote! {
                     impl Params {
                         fn dump_scalar(s: &Scalar) {
@@ -239,11 +148,13 @@ impl CodeGen {
                         }
 
                         pub fn dump(&self) {
-                            #(#dump_chunks)*
+                            #dump_chunks
                         }
                     }
                 }
-            } else {
+            };
+            #[cfg(not(feature = "dump"))]
+            let dump_impl = {
                 quote! {}
             };
             quote! {
@@ -256,7 +167,7 @@ impl CodeGen {
         };
 
         let mut witness_fields = StructFieldList::default();
-        witness_fields.push_vars(&self.vars, false);
+        push_vars(&mut witness_fields, &self.vars, false);
 
         // Generate the witness struct definition
         let witness_def = if emit_prover {

+ 1 - 0
sigma_compiler_core/src/lib.rs

@@ -4,6 +4,7 @@ use proc_macro2::TokenStream;
 /// `sigma` crate are for now included as submodules of a local `sigma`
 /// module
 pub mod sigma {
+    pub mod codegen;
     pub mod combiners;
     pub mod types;
 }

+ 108 - 0
sigma_compiler_core/src/sigma/codegen.rs

@@ -0,0 +1,108 @@
+//! A module for generating the code that uses the `sigma-rs` crate API.
+//!
+//! If that crate gets its own macro interface, it can use this module
+//! directly.
+
+use quote::{quote, ToTokens};
+use syn::Ident;
+
+// Names and types of fields that might end up in a generated struct
+enum StructField {
+    Scalar(Ident),
+    VecScalar(Ident),
+    Point(Ident),
+    VecPoint(Ident),
+}
+
+// A list of StructField items
+#[derive(Default)]
+pub struct StructFieldList {
+    fields: Vec<StructField>,
+}
+
+impl StructFieldList {
+    pub fn push_scalar(&mut self, s: &Ident) {
+        self.fields.push(StructField::Scalar(s.clone()));
+    }
+    pub fn push_vecscalar(&mut self, s: &Ident) {
+        self.fields.push(StructField::VecScalar(s.clone()));
+    }
+    pub fn push_point(&mut self, s: &Ident) {
+        self.fields.push(StructField::Point(s.clone()));
+    }
+    pub fn push_vecpoint(&mut self, s: &Ident) {
+        self.fields.push(StructField::VecPoint(s.clone()));
+    }
+    #[cfg(feature = "dump")]
+    /// Output a ToTokens of the contents of the fields
+    pub fn dump(&self) -> impl ToTokens {
+        let dump_chunks = self.fields.iter().map(|f| match f {
+            StructField::Scalar(id) => quote! {
+                print!("  {}: ", stringify!(#id));
+                Params::dump_scalar(&self.#id);
+                println!("");
+            },
+            StructField::VecScalar(id) => quote! {
+                print!("  {}: [", stringify!(#id));
+                for s in self.#id.iter() {
+                    print!("    ");
+                    Params::dump_scalar(s);
+                    println!(",");
+                }
+                println!("  ]");
+            },
+            StructField::Point(id) => quote! {
+                print!("  {}: ", stringify!(#id));
+                Params::dump_point(&self.#id);
+                println!("");
+            },
+            StructField::VecPoint(id) => quote! {
+                print!("  {}: [", stringify!(#id));
+                for p in self.#id.iter() {
+                    print!("    ");
+                    Params::dump_point(p);
+                    println!(",");
+                }
+                println!("  ]");
+            },
+        });
+        quote! { #(#dump_chunks)* }
+    }
+    /// Output a ToTokens of the fields as they would appear in a struct
+    /// definition
+    pub fn field_decls(&self) -> impl ToTokens {
+        let decls = self.fields.iter().map(|f| match f {
+            StructField::Scalar(id) => quote! {
+                pub #id: Scalar,
+            },
+            StructField::VecScalar(id) => quote! {
+                pub #id: Vec<Scalar>,
+            },
+            StructField::Point(id) => quote! {
+                pub #id: Point,
+            },
+            StructField::VecPoint(id) => quote! {
+                pub #id: Vec<Point>,
+            },
+        });
+        quote! { #(#decls)* }
+    }
+    /// Output a ToTokens of the list of fields
+    pub fn field_list(&self) -> impl ToTokens {
+        let field_ids = self.fields.iter().map(|f| match f {
+            StructField::Scalar(id) => quote! {
+                #id,
+            },
+            StructField::VecScalar(id) => quote! {
+                #id,
+            },
+            StructField::Point(id) => quote! {
+                #id,
+            },
+            StructField::VecPoint(id) => quote! {
+                #id,
+            },
+        });
+        quote! { #(#field_ids)* }
+    }
+}