瀏覽代碼

Add a binary that takes the macro input on stdin and outputs the macro output

With the `-t` option, show only the output of the transformations
Ian Goldberg 1 月之前
父節點
當前提交
d1ea73312c

+ 2 - 0
sigma_compiler_core/Cargo.toml

@@ -7,6 +7,8 @@ edition = "2021"
 syn = { version = "2.0", features = ["extra-traits", "visit", "visit-mut", "full"] }
 quote = "1.0"
 proc-macro2 = "1.0"
+clap = { version = "4.5", features = ["derive"] }
+prettyplease = "0.2"
 
 [features]
 # Dump (to stdout) the value of the instance on both the prover's and

+ 12 - 0
sigma_compiler_core/src/codegen.rs

@@ -229,6 +229,18 @@ impl CodeGen {
         };
     }
 
+    /// Extract (as [`String`]s) the code inserted by
+    /// [`prove_append`](Self::prove_append),
+    /// [`verify_append`](Self::verify_append), and
+    /// [`verify_pre_instance_append`](Self::verify_pre_instance_append).
+    pub fn code_strings(&self) -> (String, String, String) {
+        (
+            self.prove_code.to_string(),
+            self.verify_code.to_string(),
+            self.verify_pre_instance_code.to_string(),
+        )
+    }
+
     /// Generate the code to be output by this macro.
     ///
     /// `emit_prover` and `emit_verifier` are as in

+ 72 - 0
sigma_compiler_core/src/main.rs

@@ -0,0 +1,72 @@
+use clap::Parser;
+use sigma_compiler_core::*;
+use std::io;
+use std::process::ExitCode;
+
+#[derive(Parser, Debug)]
+#[clap(version, about, long_about = None)]
+struct Args {
+    /// show just the output of the transformations (as opposed to the
+    /// entire generated code)
+    #[arg(short, long)]
+    transforms: bool,
+}
+
+/// Produce a [`String`] representation of a [`TaggedVarDict`]
+fn taggedvardict_to_string(vd: &TaggedVarDict) -> String {
+    let scalars_str = vd
+        .values()
+        .filter_map(|v| match v {
+            TaggedIdent::Scalar(ts) => Some(ts.to_string()),
+            _ => None,
+        })
+        .collect::<Vec<String>>()
+        .join(", ");
+    let points_str = vd
+        .values()
+        .filter_map(|v| match v {
+            TaggedIdent::Point(tp) => Some(tp.to_string()),
+            _ => None,
+        })
+        .collect::<Vec<String>>()
+        .join(", ");
+    format!("({scalars_str}),\n({points_str}),\n")
+}
+
+fn pretty_print(code_str: &str) {
+    let parsed_output = syn::parse_file(code_str).unwrap();
+    let formatted_output = prettyplease::unparse(&parsed_output);
+    println!("{}", formatted_output);
+}
+
+fn main() -> ExitCode {
+    let args = Args::parse();
+    let emit_prover = true;
+    let emit_verifier = true;
+
+    let stdin = io::read_to_string(io::stdin()).unwrap();
+    let mut spec: SigmaCompSpec = match syn::parse_str(&stdin) {
+        Err(_) => {
+            eprintln!("Could not parse stdin as a sigma_compiler input");
+            return ExitCode::FAILURE;
+        }
+        Ok(spec) => spec,
+    };
+
+    let mut codegen = CodeGen::new(&spec);
+    enforce_disjunction_invariant(&mut codegen, &mut spec).unwrap();
+    apply_transformations(&mut codegen, &mut spec).unwrap();
+    if args.transforms {
+        print!("{}", taggedvardict_to_string(&spec.vars));
+        spec.statements.dump();
+        println!();
+        let (prove_code, verify_code, _) = codegen.code_strings();
+        pretty_print(&format!("fn prove_fragment() {{ {} }}", prove_code));
+        pretty_print(&format!("fn verify_fragment() {{ {} }}", verify_code));
+    } else {
+        let output = codegen.generate(&mut spec, emit_prover, emit_verifier);
+        pretty_print(&output.to_string());
+    }
+
+    ExitCode::SUCCESS
+}

+ 0 - 4
sigma_compiler_core/src/sigma/codegen.rs

@@ -553,10 +553,6 @@ impl<'a> CodeGen<'a> {
         // Flatten nested "And"s into single "And"s
         self.statements.flatten_ands();
 
-        println!("Statements = {{");
-        self.statements.dump();
-        println!("}}");
-
         let mut pub_instance_fields = StructFieldList::default();
         pub_instance_fields.push_vars(self.vars, true);
 

+ 37 - 0
sigma_compiler_core/src/syntax.rs

@@ -4,6 +4,7 @@ use super::sigma::combiners::StatementTree;
 use super::sigma::types::*;
 use quote::format_ident;
 use std::collections::HashMap;
+use std::fmt;
 use syn::ext::IdentExt;
 use syn::parse::{Parse, ParseStream, Result};
 use syn::punctuated::Punctuated;
@@ -62,6 +63,24 @@ impl Parse for TaggedScalar {
     }
 }
 
+impl fmt::Display for TaggedScalar {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let mut res = String::new();
+        if self.is_pub {
+            res += "pub ";
+        }
+        if self.is_rand {
+            res += "rand ";
+        }
+        if self.is_vec {
+            res += "vec ";
+        }
+        res += &self.id.to_string();
+
+        write!(f, "{res}")
+    }
+}
+
 /// A [`TaggedPoint`] is an [`struct@Ident`] representing a `Point`,
 /// preceded by zero or more of the following tags: `cind`, `const`,
 /// `vec`
@@ -117,6 +136,24 @@ impl Parse for TaggedPoint {
     }
 }
 
+impl fmt::Display for TaggedPoint {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let mut res = String::new();
+        if self.is_vec {
+            res += "vec ";
+        }
+        if self.is_const {
+            res += "const ";
+        }
+        if self.is_cind {
+            res += "cind ";
+        }
+        res += &self.id.to_string();
+
+        write!(f, "{res}")
+    }
+}
+
 /// A [`TaggedIdent`] can be either a [`TaggedScalar`] or a
 /// [`TaggedPoint`]
 #[derive(Clone, Debug, PartialEq, Eq)]