Parcourir la source

New sigma_compiler! syntax

Instead of six lists of identifiers, there are now just two: a list of
identifiers for Scalars and a list of identifiers for Points.  Each
identifier can be preceded by zero or more tags (`pub`, `rand`, `cind`,
`const`, `vec`).

A `TaggedIndent` representing a `Scalar` can be preceded by:
 - (nothing)
 - `pub`
 - `rand`
 - `vec`
 - `pub vec`
 - `rand vec`

A `TaggedIndent` representing a `Point` can be preceded by:
 - (nothing)
 - `cind`
 - `const`
 - `cind const`
 - `vec`
 - `cind vec`
 - `const vec`
 - `cind const vec`
Ian Goldberg il y a 4 mois
Parent
commit
b357d5236f
1 fichiers modifiés avec 129 ajouts et 85 suppressions
  1. 129 85
      sigma_compiler_derive/src/lib.rs

+ 129 - 85
sigma_compiler_derive/src/lib.rs

@@ -1,50 +1,121 @@
 use proc_macro::TokenStream;
 use quote::{format_ident, quote, ToTokens};
 use std::collections::HashMap;
+use syn::ext::IdentExt;
 use syn::parse::{Parse, ParseStream, Result};
 use syn::punctuated::Punctuated;
 use syn::visit_mut::{self, VisitMut};
 use syn::{parenthesized, parse_macro_input, parse_quote, Expr, Ident, Token};
 
-// Either an Ident or "vec(Ident)"
+// A `TaggedIdent` is an `Ident`, preceded by zero or more of the
+// following tags: `pub`, `rand`, `cind`, `const`, `vec`
+//
+// A `TaggedIndent` representing a `Scalar` can be preceded by:
+//  - (nothing)
+//  - `pub`
+//  - `rand`
+//  - `vec`
+//  - `pub vec`
+//  - `rand vec`
+//
+// A `TaggedIndent` representing a `Point` can be preceded by:
+//  - (nothing)
+//  - `cind`
+//  - `const`
+//  - `cind const`
+//  - `vec`
+//  - `cind vec`
+//  - `const vec`
+//  - `cind const vec`
+
 #[derive(Debug)]
-enum VecIdent {
-    Ident(Ident),
-    VecIdent(Ident),
+struct TaggedIdent {
+    id: Ident,
+    is_pub: bool,
+    is_rand: bool,
+    is_cind: bool,
+    is_const: bool,
+    is_vec: bool,
 }
 
-impl Parse for VecIdent {
-    fn parse(input: ParseStream) -> Result<Self> {
-        let id: Ident = input.parse()?;
-        if id.to_string() == "vec" {
-            let content;
-            parenthesized!(content in input);
-            let vid: Ident = content.parse()?;
-            Ok(Self::VecIdent(vid))
-        } else {
-            Ok(Self::Ident(id))
+impl TaggedIdent {
+    // parse for a `Scalar` if point is false; parse for a `Point` if point
+    // is true
+    pub fn parse(input: ParseStream, point: bool) -> Result<Self> {
+        // Points are always pub
+        let (mut is_pub, mut is_rand, mut is_cind, mut is_const, mut is_vec) =
+            (point, false, false, false, false);
+        loop {
+            let id = input.call(Ident::parse_any)?;
+            match id.to_string().as_str() {
+                // pub and rand are only allowed for Scalars, and are
+                // mutually exclusive
+                "pub" if !point && !is_rand => {
+                    is_pub = true;
+                }
+                "rand" if !point && !is_pub => {
+                    is_rand = true;
+                }
+                // cind and const are only allowed for Points, but can
+                // be used together
+                "cind" if point => {
+                    is_cind = true;
+                }
+                "const" if point => {
+                    is_const = true;
+                }
+                // vec is allowed with either Scalars or Points, and
+                // with any other tag
+                "vec" => {
+                    is_vec = true;
+                }
+                _ => {
+                    return Ok(TaggedIdent {
+                        id,
+                        is_pub,
+                        is_rand,
+                        is_cind,
+                        is_const,
+                        is_vec,
+                    });
+                }
+            }
         }
     }
+
+    // Parse a `TaggedIndent` using the tags allowed for a `Scalar`
+    pub fn parse_scalar(input: ParseStream) -> Result<Self> {
+        Self::parse(input, false)
+    }
+
+    // Parse a `TaggedIndent` using the tags allowed for a `Point`
+    pub fn parse_point(input: ParseStream) -> Result<Self> {
+        Self::parse(input, true)
+    }
 }
 
 #[derive(Debug)]
 struct SigmaCompSpec {
     proto_name: Ident,
     group_name: Ident,
-    rand_scalars: Vec<VecIdent>,
-    priv_scalars: Vec<VecIdent>,
-    pub_scalars: Vec<VecIdent>,
-    cind_points: Vec<VecIdent>,
-    pub_points: Vec<VecIdent>,
-    const_points: Vec<VecIdent>,
+    scalars: Vec<TaggedIdent>,
+    points: Vec<TaggedIdent>,
     statements: Vec<Expr>,
 }
 
-fn paren_vecidents(input: ParseStream) -> Result<Vec<VecIdent>> {
+// parse for a `Scalar` if point is false; parse for a `Point` if point
+// is true
+fn paren_taggedidents(input: ParseStream, point: bool) -> Result<Vec<TaggedIdent>> {
     let content;
     parenthesized!(content in input);
-    let punc: Punctuated<VecIdent, Token![,]> =
-        content.parse_terminated(VecIdent::parse, Token![,])?;
+    let punc: Punctuated<TaggedIdent, Token![,]> = content.parse_terminated(
+        if point {
+            TaggedIdent::parse_point
+        } else {
+            TaggedIdent::parse_scalar
+        },
+        Token![,],
+    )?;
     Ok(punc.into_iter().collect())
 }
 
@@ -62,22 +133,10 @@ impl Parse for SigmaCompSpec {
         };
         input.parse::<Token![,]>()?;
 
-        let rand_scalars = paren_vecidents(input)?;
+        let scalars = paren_taggedidents(input, false)?;
         input.parse::<Token![,]>()?;
 
-        let priv_scalars = paren_vecidents(input)?;
-        input.parse::<Token![,]>()?;
-
-        let pub_scalars = paren_vecidents(input)?;
-        input.parse::<Token![,]>()?;
-
-        let cind_points = paren_vecidents(input)?;
-        input.parse::<Token![,]>()?;
-
-        let pub_points = paren_vecidents(input)?;
-        input.parse::<Token![,]>()?;
-
-        let const_points = paren_vecidents(input)?;
+        let points = paren_taggedidents(input, true)?;
         input.parse::<Token![,]>()?;
 
         let statementpunc: Punctuated<Expr, Token![,]> =
@@ -87,12 +146,8 @@ impl Parse for SigmaCompSpec {
         Ok(SigmaCompSpec {
             proto_name,
             group_name,
-            rand_scalars,
-            priv_scalars,
-            pub_scalars,
-            cind_points,
-            pub_points,
-            const_points,
+            scalars,
+            points,
             statements,
         })
     }
@@ -125,19 +180,25 @@ impl StructFieldList {
     pub fn push_vecpoint(&mut self, s: &Ident) {
         self.fields.push(StructField::VecPoint(s.clone()));
     }
-    pub fn push_scalars(&mut self, sl: &[VecIdent]) {
-        for vi in sl.iter() {
-            match vi {
-                VecIdent::Ident(id) => self.push_scalar(id),
-                VecIdent::VecIdent(id) => self.push_vecscalar(id),
+    pub fn push_scalars(&mut self, sl: &[TaggedIdent], is_pub: bool) {
+        for tid in sl.iter() {
+            if tid.is_pub == is_pub {
+                if tid.is_vec {
+                    self.push_vecscalar(&tid.id)
+                } else {
+                    self.push_scalar(&tid.id)
+                }
             }
         }
     }
-    pub fn push_points(&mut self, sl: &[VecIdent]) {
-        for vi in sl.iter() {
-            match vi {
-                VecIdent::Ident(id) => self.push_point(id),
-                VecIdent::VecIdent(id) => self.push_vecpoint(id),
+    pub fn push_points(&mut self, sl: &[TaggedIdent], is_pub: bool) {
+        for tid in sl.iter() {
+            if tid.is_pub == is_pub {
+                if tid.is_vec {
+                    self.push_vecpoint(&tid.id)
+                } else {
+                    self.push_point(&tid.id)
+                }
             }
         }
     }
@@ -177,32 +238,18 @@ impl StatementFixup {
     pub fn new(spec: &SigmaCompSpec) -> Self {
         let mut idmap: HashMap<String, Expr> = HashMap::new();
 
-        // For each public identifier id (Scalar, Point, or vectors of
-        // one of those), add to the map "id" -> params.id
-        for vid in spec
-            .pub_scalars
-            .iter()
-            .chain(spec.cind_points.iter())
-            .chain(spec.pub_points.iter())
-            .chain(spec.const_points.iter())
-        {
-            match vid {
-                VecIdent::Ident(id) | VecIdent::VecIdent(id) => {
-                    let idexpr: Expr = parse_quote! { params.#id };
-                    idmap.insert(id.to_string(), idexpr);
-                }
-            }
-        }
-
-        // For each private identifier id (Scalar, Point, or vectors of
-        // one of those), add to the map "id" -> witness.id
-        for vid in spec.rand_scalars.iter().chain(spec.priv_scalars.iter()) {
-            match vid {
-                VecIdent::Ident(id) | VecIdent::VecIdent(id) => {
-                    let idexpr: Expr = parse_quote! { witness.#id };
-                    idmap.insert(id.to_string(), idexpr);
-                }
-            }
+        // For each public identifier id (Points, or Scalars marked
+        // "pub"), add to the map "id" -> params.id.  For each private
+        // identifier (Scalars not marked "pub"), add to the map "id" ->
+        // witness.id.
+        for tid in spec.scalars.iter().chain(spec.points.iter()) {
+            let id = &tid.id;
+            let idexpr: Expr = if tid.is_pub {
+                parse_quote! { params.#id }
+            } else {
+                parse_quote! { witness.#id }
+            };
+            idmap.insert(id.to_string(), idexpr);
         }
 
         Self { idmap }
@@ -250,10 +297,8 @@ fn sigma_compiler_impl(
     // Generate the public params struct definition
     let params_def = {
         let mut pub_params_fields = StructFieldList::default();
-        pub_params_fields.push_points(&spec.const_points);
-        pub_params_fields.push_points(&spec.cind_points);
-        pub_params_fields.push_points(&spec.pub_points);
-        pub_params_fields.push_scalars(&spec.pub_scalars);
+        pub_params_fields.push_points(&spec.points, true);
+        pub_params_fields.push_scalars(&spec.scalars, true);
 
         let decls = pub_params_fields.field_decls();
         let dump_impl = if cfg!(feature = "dump") {
@@ -319,8 +364,7 @@ fn sigma_compiler_impl(
     // Generate the witness struct definition
     let witness_def = if emit_prover {
         let mut witness_fields = StructFieldList::default();
-        witness_fields.push_scalars(&spec.rand_scalars);
-        witness_fields.push_scalars(&spec.priv_scalars);
+        witness_fields.push_scalars(&spec.scalars, false);
 
         let decls = witness_fields.field_decls();
         quote! {