Browse Source

Be able to send extra generated public values along with the proof

We need to be able to send extra variables (not necessarily appearing in
the sigma_compiler macro invocation itself, since they may be generated
by the sigma_compiler itself) from the prover to the verifier along with
the proof.  These could include commitments to bits in range proofs, for
example.
Ian Goldberg 4 months ago
parent
commit
cd47fcf33c

+ 186 - 10
sigma_compiler_core/src/codegen.rs

@@ -1,7 +1,7 @@
 //! A module for generating the code produced by this macro.  This code
 //! will interact with the underlying `sigma` macro.
 
-use super::sigma::codegen::StructFieldList;
+use super::sigma::codegen::{StructField, StructFieldList};
 use super::syntax::*;
 use proc_macro2::TokenStream;
 use quote::{format_ident, quote};
@@ -18,13 +18,31 @@ use syn::Ident;
 /// [`CodeGen::generate`] with the modified [`SigmaCompSpec`] to generate the
 /// code output by this macro.
 pub struct CodeGen {
+    /// The protocol name specified in the `sigma_compiler` macro
+    /// invocation
     proto_name: Ident,
+    /// The group name specified in the `sigma_compiler` macro
+    /// invocation
     group_name: Ident,
+    /// The variables that were explicitly listed in the
+    /// `sigma_compiler` macro invocation
     vars: TaggedVarDict,
-    // A prefix that does not appear at the beginning of any variable
-    // name in vars
+    /// A prefix that does not appear at the beginning of any variable
+    /// name in `vars`
     unique_prefix: String,
+    /// Variables (not necessarily appearing in `vars`, since they may
+    /// be generated by the sigma_compiler itself) that the prover needs
+    /// to send to the verifier along with the proof.  These could
+    /// include commitments to bits in range proofs, for example.
+    sent_params: StructFieldList,
+    /// Extra code that will be emitted in the `prove` function
     prove_code: TokenStream,
+    /// Extra code that will be emitted in the `verify` function
+    verify_code: TokenStream,
+    /// Extra code that will be emitted in the `verify` function before
+    /// the `sent_params` are deserialized.  This is where the verifier
+    /// sets the lengths of vector variables in the `sent_params`.
+    verify_pre_params_code: TokenStream,
 }
 
 impl CodeGen {
@@ -58,7 +76,10 @@ impl CodeGen {
             group_name: spec.group_name.clone(),
             vars: spec.vars.clone(),
             unique_prefix: Self::unique_prefix(&spec.vars),
+            sent_params: StructFieldList::default(),
             prove_code: quote! {},
+            verify_code: quote! {},
+            verify_pre_params_code: quote! {},
         }
     }
 
@@ -70,10 +91,80 @@ impl CodeGen {
             group_name: parse_quote! { G },
             vars: TaggedVarDict::default(),
             unique_prefix: "gen__".into(),
+            sent_params: StructFieldList::default(),
             prove_code: quote! {},
+            verify_code: quote! {},
+            verify_pre_params_code: quote! {},
         }
     }
 
+    /// Create a new generated private Scalar variable to put in the
+    /// Witness.
+    ///
+    /// If you call this, you should also call
+    /// [`prove_append`](Self::prove_append) with code like `quote!{ let
+    /// #id = ... }` where `id` is the [`struct@Ident`] returned from
+    /// this function.
+    pub fn gen_scalar(
+        &self,
+        vars: &mut TaggedVarDict,
+        base: &Ident,
+        is_rand: bool,
+        is_vec: bool,
+    ) -> Ident {
+        let id = format_ident!("{}var_{}", self.unique_prefix, base);
+        vars.insert(
+            id.to_string(),
+            TaggedIdent::Scalar(TaggedScalar {
+                id: id.clone(),
+                is_pub: false,
+                is_rand,
+                is_vec,
+            }),
+        );
+        id
+    }
+
+    /// Create a new public Point variable to put in the Params,
+    /// optionally marking it as needing to be sent from the prover to
+    /// the verifier along with the proof.
+    ///
+    /// If you call this function, you should also call
+    /// [`prove_append`](Self::prove_append) with code like `quote!{ let
+    /// #id = ... }` where `id` is the [`struct@Ident`] returned from
+    /// this function.  If `is_vec` is `true`, then you should also call
+    /// [`verify_pre_params_append`](Self::verify_pre_params_append)
+    /// with code like `quote!{ let mut #id = Vec::<Point>::new();
+    /// #id.resize(#len); }` where `len` is the number of elements you
+    /// expect to have in the vector (computed at runtime, perhaps based
+    /// on the values of public parameters).
+    pub fn gen_point(
+        &mut self,
+        vars: &mut TaggedVarDict,
+        base: &Ident,
+        is_vec: bool,
+        send_to_verifier: bool,
+    ) -> Ident {
+        let id = format_ident!("{}var_{}", self.unique_prefix, base);
+        vars.insert(
+            id.to_string(),
+            TaggedIdent::Point(TaggedPoint {
+                id: id.clone(),
+                is_cind: false,
+                is_const: false,
+                is_vec,
+            }),
+        );
+        if send_to_verifier {
+            if is_vec {
+                self.sent_params.push_vecpoint(&id);
+            } else {
+                self.sent_params.push_point(&id);
+            }
+        }
+        id
+    }
+
     /// Append some code to the generated `prove` function
     pub fn prove_append(&mut self, code: TokenStream) {
         let prove_code = &self.prove_code;
@@ -83,6 +174,25 @@ impl CodeGen {
         };
     }
 
+    /// Append some code to the generated `verify` function
+    pub fn verify_append(&mut self, code: TokenStream) {
+        let verify_code = &self.verify_code;
+        self.verify_code = quote! {
+            #verify_code
+            #code
+        };
+    }
+
+    /// Append some code to the generated `verify` function to be run
+    /// before the `sent_params` are deserialized
+    pub fn verify_pre_params_append(&mut self, code: TokenStream) {
+        let verify_pre_params_code = &self.verify_pre_params_code;
+        self.verify_pre_params_code = quote! {
+            #verify_pre_params_code
+            #code
+        };
+    }
+
     /// Generate the code to be output by this macro.
     ///
     /// `emit_prover` and `emit_verifier` are as in
@@ -159,6 +269,7 @@ impl CodeGen {
                 quote! {}
             };
             quote! {
+                #[derive(Clone)]
                 pub struct Params {
                     #decls
                 }
@@ -171,6 +282,7 @@ impl CodeGen {
         let witness_def = if emit_prover {
             let decls = witness_fields.field_decls();
             quote! {
+                #[derive(Clone)]
                 pub struct Witness {
                     #decls
                 }
@@ -179,7 +291,7 @@ impl CodeGen {
             quote! {}
         };
 
-        // Generate the (currently dummy) prove function
+        // Generate the prove function
         let prove_func = if emit_prover {
             let dumper = if cfg!(feature = "dump") {
                 quote! {
@@ -197,27 +309,49 @@ impl CodeGen {
             let prove_code = &self.prove_code;
             let codegen_params_var = format_ident!("{}sigma_params", self.unique_prefix);
             let codegen_witness_var = format_ident!("{}sigma_witness", self.unique_prefix);
+            let proof_var = format_ident!("{}proof", self.unique_prefix);
+            let sent_params_code = {
+                let chunks = self.sent_params.fields.iter().map(|sf| match sf {
+                    StructField::Point(id) => quote! {
+                        #proof_var.extend(sigma_rs::serialization::serialize_elements(
+                            std::slice::from_ref(&#codegen_params_var.#id)
+                        ));
+                    },
+                    StructField::VecPoint(id) => quote! {
+                        #proof_var.extend(sigma_rs::serialization::serialize_elements(
+                            &#codegen_params_var.#id
+                        ));
+                    },
+                    _ => quote! {},
+                });
+                quote! { #(#chunks)* }
+            };
 
             quote! {
                 pub fn prove(params: &Params, witness: &Witness) -> Result<Vec<u8>, SigmaError> {
                     #dumper
-                    let Params { #params_ids } = *params;
-                    let Witness { #witness_ids } = *witness;
+                    let Params { #params_ids } = params.clone();
+                    let Witness { #witness_ids } = witness.clone();
                     #prove_code
+                    let mut #proof_var = Vec::<u8>::new();
                     let #codegen_params_var = sigma::Params {
                         #sigma_rs_params_ids
                     };
                     let #codegen_witness_var = sigma::Witness {
                         #sigma_rs_witness_ids
                     };
-                    sigma::prove(&#codegen_params_var, &#codegen_witness_var)
+                    #sent_params_code
+                    #proof_var.extend(
+                        sigma::prove(&#codegen_params_var, &#codegen_witness_var)?
+                    );
+                    Ok(#proof_var)
                 }
             }
         } else {
             quote! {}
         };
 
-        // Generate the (currently dummy) verify function
+        // Generate the verify function
         let verify_func = if emit_verifier {
             let dumper = if cfg!(feature = "dump") {
                 quote! {
@@ -230,15 +364,57 @@ impl CodeGen {
             };
             let params_ids = pub_params_fields.field_list();
             let sigma_rs_params_ids = sigma_rs_params_fields.field_list();
+            let verify_pre_params_code = &self.verify_pre_params_code;
+            let verify_code = &self.verify_code;
             let codegen_params_var = format_ident!("{}sigma_params", self.unique_prefix);
+            let element_len_var = format_ident!("{}element_len", self.unique_prefix);
+            let offset_var = format_ident!("{}proof_offset", self.unique_prefix);
+            let sent_params_code = {
+                let element_len_code = if self.sent_params.fields.is_empty() {
+                    quote! {}
+                } else {
+                    quote! {
+                        let #element_len_var =
+                            <Point as group::GroupEncoding>::Repr::default().as_ref().len();
+                    }
+                };
+
+                let chunks = self.sent_params.fields.iter().map(|sf| match sf {
+                    StructField::Point(id) => quote! {
+                        let #id = sigma_rs::serialization::deserialize_elements(
+                                &proof[#offset_var..],
+                                1,
+                            ).ok_or(SigmaError::VerificationFailure)?[0];
+                        #offset_var += #element_len_var;
+                    },
+                    StructField::VecPoint(id) => quote! {
+                        #id = sigma_rs::serialization::deserialize_elements(
+                                &proof[#offset_var..],
+                                #id.len(),
+                            ).ok_or(SigmaError::VerificationFailure)?;
+                        #offset_var += #element_len_var * #id.len();
+                    },
+                    _ => quote! {},
+                });
+
+                quote! {
+                    let mut #offset_var = 0usize;
+                    #element_len_code
+                    #(#chunks)*
+                }
+            };
+
             quote! {
                 pub fn verify(params: &Params, proof: &[u8]) -> Result<(), SigmaError> {
                     #dumper
-                    let Params { #params_ids } = *params;
+                    let Params { #params_ids } = params.clone();
+                    #verify_pre_params_code
+                    #sent_params_code
+                    #verify_code
                     let #codegen_params_var = sigma::Params {
                         #sigma_rs_params_ids
                     };
-                    sigma::verify(&#codegen_params_var, proof)
+                    sigma::verify(&#codegen_params_var, &proof[#offset_var..])
                 }
             }
         } else {

+ 14 - 0
sigma_compiler_core/src/lib.rs

@@ -37,5 +37,19 @@ pub fn sigma_compiler_core(
     // Apply any substitution transformations
     transform::apply_substitutions(&mut codegen, &mut spec.statements, &mut spec.vars).unwrap();
 
+    /* Just some test code for now:
+    let C_var = codegen.gen_point(&mut spec.vars, &quote::format_ident!("C"), false, true);
+    let V_var = codegen.gen_point(&mut spec.vars, &quote::format_ident!("V"), true, true);
+    codegen.prove_append(quote::quote! {
+        let #C_var = <Point as group::Group>::generator();
+        let #V_var = vec![<Point as group::Group>::generator(), <Point as
+        group::Group>::generator()];
+    });
+    codegen.verify_pre_params_append(quote::quote! {
+        let mut #V_var = Vec::<Point>::new();
+        #V_var.resize(2, Point::default());
+    });
+    */
+
     codegen.generate(spec, emit_prover, emit_verifier)
 }

+ 10 - 7
sigma_compiler_core/src/sigma/codegen.rs

@@ -9,18 +9,18 @@ use proc_macro2::TokenStream;
 use quote::{format_ident, quote, ToTokens};
 use syn::Ident;
 
-// Names and types of fields that might end up in a generated struct
-enum StructField {
+/// Names and types of fields that might end up in a generated struct
+pub enum StructField {
     Scalar(Ident),
     VecScalar(Ident),
     Point(Ident),
     VecPoint(Ident),
 }
 
-// A list of StructField items
+/// A list of StructField items
 #[derive(Default)]
 pub struct StructFieldList {
-    fields: Vec<StructField>,
+    pub fields: Vec<StructField>,
 }
 
 impl StructFieldList {
@@ -204,6 +204,7 @@ impl<'a> CodeGen<'a> {
                 quote! {}
             };
             quote! {
+                #[derive(Clone)]
                 pub struct Params {
                     #decls
                 }
@@ -219,6 +220,7 @@ impl<'a> CodeGen<'a> {
         let witness_def = if emit_prover {
             let decls = witness_fields.field_decls();
             quote! {
+                #[derive(Clone)]
                 pub struct Witness {
                     #decls
                 }
@@ -244,8 +246,8 @@ impl<'a> CodeGen<'a> {
             quote! {
                 pub fn prove(params: &Params, witness: &Witness) -> Result<Vec<u8>, SigmaError> {
                     #dumper
-                    let Params { #params_ids } = *params;
-                    let Witness { #witness_ids } = *witness;
+                    let Params { #params_ids } = params.clone();
+                    let Witness { #witness_ids } = witness.clone();
                     Ok(Vec::<u8>::default())
                 }
             }
@@ -268,7 +270,7 @@ impl<'a> CodeGen<'a> {
             quote! {
                 pub fn verify(params: &Params, proof: &[u8]) -> Result<(), SigmaError> {
                     #dumper
-                    let Params { #params_ids } = *params;
+                    let Params { #params_ids } = params.clone();
                     Ok(())
                 }
             }
@@ -287,6 +289,7 @@ impl<'a> CodeGen<'a> {
         quote! {
             #[allow(non_snake_case)]
             pub mod #proto_name {
+                use group::ff::PrimeField;
                 use sigma_compiler::sigma_rs::errors::Error as SigmaError;
                 #dump_use