Просмотр исходного кода

Unify TaggedScalar and TaggedPoint into an enum TaggedIdent, and add VarDict type

VarDict is a HashMap from String (the string version of an Ident
representing a variable) to a TaggedIdent, which encodes the type
(Scalar or Point) and tags on the variable.
Ian Goldberg 4 месяцев назад
Родитель
Сommit
20057964af
2 измененных файлов с 54 добавлено и 31 удалено
  1. 27 27
      sigma_compiler_core/src/lib.rs
  2. 27 4
      sigma_compiler_core/src/syntax.rs

+ 27 - 27
sigma_compiler_core/src/lib.rs

@@ -6,8 +6,7 @@ use syn::{parse_quote, Expr, Ident, Token};
 
 mod syntax;
 
-pub use syntax::SigmaCompSpec;
-pub use syntax::{TaggedPoint, TaggedScalar};
+pub use syntax::{SigmaCompSpec, TaggedIdent, TaggedPoint, TaggedScalar, VarDict};
 
 // Names and types of fields that might end up in a generated struct
 enum StructField {
@@ -36,23 +35,27 @@ impl StructFieldList {
     pub fn push_vecpoint(&mut self, s: &Ident) {
         self.fields.push(StructField::VecPoint(s.clone()));
     }
-    pub fn push_scalars(&mut self, sl: &[TaggedScalar], is_pub: bool) {
-        for tid in sl.iter() {
-            if tid.is_pub == is_pub {
-                if tid.is_vec {
-                    self.push_vecscalar(&tid.id)
-                } else {
-                    self.push_scalar(&tid.id)
+    pub fn push_vars(&mut self, vardict: &VarDict, is_pub: bool) {
+        for (_, ti) in vardict.iter() {
+            match ti {
+                TaggedIdent::Scalar(st) => {
+                    if st.is_pub == is_pub {
+                        if st.is_vec {
+                            self.push_vecscalar(&st.id)
+                        } else {
+                            self.push_scalar(&st.id)
+                        }
+                    }
+                }
+                TaggedIdent::Point(pt) => {
+                    if is_pub {
+                        if pt.is_vec {
+                            self.push_vecpoint(&pt.id)
+                        } else {
+                            self.push_point(&pt.id)
+                        }
+                    }
                 }
-            }
-        }
-    }
-    pub fn push_points(&mut self, sl: &[TaggedPoint]) {
-        for tid in sl.iter() {
-            if tid.is_vec {
-                self.push_vecpoint(&tid.id)
-            } else {
-                self.push_point(&tid.id)
             }
         }
     }
@@ -96,12 +99,10 @@ impl StatementFixup {
         // "pub"), add to the map "id" -> params.id.  For each private
         // identifier (Scalars not marked "pub"), add to the map "id" ->
         // witness.id.
-        for (id, is_pub) in spec
-            .scalars
-            .iter()
-            .map(|ts| (&ts.id, ts.is_pub))
-            .chain(spec.points.iter().map(|tp| (&tp.id, true)))
-        {
+        for (id, is_pub) in spec.vars.values().map(|ti| match ti {
+            TaggedIdent::Scalar(st) => (&st.id, st.is_pub),
+            TaggedIdent::Point(pt) => (&pt.id, true),
+        }) {
             let idexpr: Expr = if is_pub {
                 parse_quote! { params.#id }
             } else {
@@ -155,8 +156,7 @@ pub fn sigma_compiler_core(
     // Generate the public params struct definition
     let params_def = {
         let mut pub_params_fields = StructFieldList::default();
-        pub_params_fields.push_points(&spec.points);
-        pub_params_fields.push_scalars(&spec.scalars, true);
+        pub_params_fields.push_vars(&spec.vars, true);
 
         let decls = pub_params_fields.field_decls();
         let dump_impl = if cfg!(feature = "dump") {
@@ -222,7 +222,7 @@ pub fn sigma_compiler_core(
     // Generate the witness struct definition
     let witness_def = if emit_prover {
         let mut witness_fields = StructFieldList::default();
-        witness_fields.push_scalars(&spec.scalars, false);
+        witness_fields.push_vars(&spec.vars, false);
 
         let decls = witness_fields.field_decls();
         quote! {

+ 27 - 4
sigma_compiler_core/src/syntax.rs

@@ -1,4 +1,5 @@
 use quote::format_ident;
+use std::collections::HashMap;
 use syn::ext::IdentExt;
 use syn::parse::{Parse, ParseStream, Result};
 use syn::punctuated::Punctuated;
@@ -78,6 +79,18 @@ pub struct TaggedPoint {
     pub is_vec: bool,
 }
 
+/// A `TaggedIdent` can be either a `TaggedScalar` or a `TaggedPoint`
+#[derive(Debug)]
+pub enum TaggedIdent {
+    Scalar(TaggedScalar),
+    Point(TaggedPoint),
+}
+
+/// A `VarDict` is a dictionary of the available variables, mapping
+/// the string version of `Ident`s to `TaggedIdent`, which includes
+/// their type (`Scalar` or `Point`)
+pub type VarDict = HashMap<String, TaggedIdent>;
+
 impl Parse for TaggedPoint {
     fn parse(input: ParseStream) -> Result<Self> {
         // Points are always pub
@@ -115,8 +128,7 @@ impl Parse for TaggedPoint {
 pub struct SigmaCompSpec {
     pub proto_name: Ident,
     pub group_name: Ident,
-    pub scalars: Vec<TaggedScalar>,
-    pub points: Vec<TaggedPoint>,
+    pub vars: VarDict,
     pub statements: Vec<Expr>,
 }
 
@@ -142,10 +154,22 @@ impl Parse for SigmaCompSpec {
         };
         input.parse::<Token![,]>()?;
 
+        let mut vars: VarDict = HashMap::new();
+
         let scalars = paren_taggedidents::<TaggedScalar>(input)?;
+        vars.extend(
+            scalars
+                .into_iter()
+                .map(|ts| (ts.id.to_string(), TaggedIdent::Scalar(ts))),
+        );
         input.parse::<Token![,]>()?;
 
         let points = paren_taggedidents::<TaggedPoint>(input)?;
+        vars.extend(
+            points
+                .into_iter()
+                .map(|tp| (tp.id.to_string(), TaggedIdent::Point(tp))),
+        );
         input.parse::<Token![,]>()?;
 
         let statementpunc: Punctuated<Expr, Token![,]> =
@@ -155,8 +179,7 @@ impl Parse for SigmaCompSpec {
         Ok(SigmaCompSpec {
             proto_name,
             group_name,
-            scalars,
-            points,
+            vars,
             statements,
         })
     }