Просмотр исходного кода

Start on the codegen glue to the sigma-rs API

Ian Goldberg 7 месяцев назад
Родитель
Сommit
82a074f53c

+ 12 - 8
sigma_compiler_core/src/codegen.rs

@@ -330,6 +330,9 @@ impl CodeGen {
             let prove_code = &self.prove_code;
             let codegen_params_var = format_ident!("{}sigma_params", self.unique_prefix);
             let codegen_witness_var = format_ident!("{}sigma_witness", self.unique_prefix);
+            let params_var = format_ident!("{}params", self.unique_prefix);
+            let witness_var = format_ident!("{}witness", self.unique_prefix);
+            let rng_var = format_ident!("{}rng", self.unique_prefix);
             let proof_var = format_ident!("{}proof", self.unique_prefix);
             let sid_var = format_ident!("{}session_id", self.unique_prefix);
             let sent_params_code = {
@@ -351,14 +354,14 @@ impl CodeGen {
 
             quote! {
                 pub fn prove(
-                    params: &Params,
-                    witness: &Witness,
+                    #params_var: &Params,
+                    #witness_var: &Witness,
                     #sid_var: &[u8],
-                    rng: &mut (impl CryptoRng + RngCore),
+                    #rng_var: &mut (impl CryptoRng + RngCore),
                 ) -> Result<Vec<u8>, SigmaError> {
                     #dumper
-                    let Params { #params_ids } = params.clone();
-                    let Witness { #witness_ids } = witness.clone();
+                    let Params { #params_ids } = #params_var.clone();
+                    let Witness { #witness_ids } = #witness_var.clone();
                     #prove_code
                     let mut #proof_var = Vec::<u8>::new();
                     let #codegen_params_var = sigma::Params {
@@ -373,7 +376,7 @@ impl CodeGen {
                             &#codegen_params_var,
                             &#codegen_witness_var,
                             #sid_var,
-                            rng,
+                            #rng_var,
                         )?
                     );
                     Ok(#proof_var)
@@ -401,6 +404,7 @@ impl CodeGen {
             let codegen_params_var = format_ident!("{}sigma_params", self.unique_prefix);
             let element_len_var = format_ident!("{}element_len", self.unique_prefix);
             let offset_var = format_ident!("{}proof_offset", self.unique_prefix);
+            let params_var = format_ident!("{}params", self.unique_prefix);
             let proof_var = format_ident!("{}proof", self.unique_prefix);
             let sid_var = format_ident!("{}session_id", self.unique_prefix);
             let sent_params_code = {
@@ -440,12 +444,12 @@ impl CodeGen {
 
             quote! {
                 pub fn verify(
-                    params: &Params,
+                    #params_var: &Params,
                     #proof_var: &[u8],
                     #sid_var: &[u8],
                 ) -> Result<(), SigmaError> {
                     #dumper
-                    let Params { #params_ids } = params.clone();
+                    let Params { #params_ids } = #params_var.clone();
                     #verify_pre_params_code
                     #sent_params_code
                     #verify_code

+ 5 - 2
sigma_compiler_core/src/rangeproof.rs

@@ -333,6 +333,9 @@ pub fn transform(
     // Count how many range statements we've seen
     let mut range_stmt_index = 0usize;
 
+    // The generated variable name for the rng
+    let rng_var = codegen.gen_ident(&format_ident!("rng"));
+
     for leaf in leaves.iter_mut() {
         // For each leaf expression, see if it looks like a range statement
         let StatementTree::Leaf(leafexpr) = leaf else {
@@ -392,7 +395,7 @@ pub fn transform(
                 recognize_pedersen_assignment(vars, &randoms, &vardict, &ped_assign_expr).unwrap();
 
             codegen.prove_append(quote! {
-                let #rand_var = Scalar::random(rng);
+                let #rand_var = Scalar::random(#rng_var);
                 let #ped_assign_expr;
             });
 
@@ -602,7 +605,7 @@ pub fn transform(
             // Choose randomizers r for the commitments randomly
             let #bitrand_var: Vec<Scalar> =
                 (0..(#nbits_var-1))
-                    .map(|_| Scalar::random(rng))
+                    .map(|_| Scalar::random(#rng_var))
                     .collect();
             // The randomizers s for the commitments to the squares are
             // chosen as above: s=r if b=0 and s=0 if b=1.

+ 138 - 17
sigma_compiler_core/src/sigma/codegen.rs

@@ -139,10 +139,33 @@ pub struct CodeGen<'a> {
     proto_name: Ident,
     group_name: Ident,
     vars: &'a VarDict,
+    unique_prefix: String,
     statements: &'a mut StatementTree,
 }
 
 impl<'a> CodeGen<'a> {
+    /// Find a prefix that does not appear at the beginning of any
+    /// variable name in `vars`
+    fn unique_prefix(vars: &VarDict) -> String {
+        'outer: for tag in 0usize.. {
+            let try_prefix = if tag == 0 {
+                "sigma__".to_string()
+            } else {
+                format!("sigma{}__", tag)
+            };
+            for v in vars.keys() {
+                if v.starts_with(&try_prefix) {
+                    continue 'outer;
+                }
+            }
+            return try_prefix;
+        }
+        // The compiler complains if this isn't here, but it will only
+        // get hit if vars contains at least usize::MAX entries, which
+        // isn't going to happen.
+        String::new()
+    }
+
     pub fn new(
         proto_name: Ident,
         group_name: Ident,
@@ -153,10 +176,38 @@ impl<'a> CodeGen<'a> {
             proto_name,
             group_name,
             vars,
+            unique_prefix: Self::unique_prefix(vars),
             statements,
         }
     }
 
+    /// Generate the code for the `protocol` and `protocol_witness`
+    /// functions that create the `Protocol` and `ProtocolWitness`
+    /// structs, respectively, given a [`VarDict`] and a
+    /// [`StatementTree`] describing the statements to be proven.
+    /// `node_num` is a sequentially increasing counter for nodes in the
+    /// `Protocol` tree.  The function returns the next `node_num` to
+    /// use as the first component of its output.  The other components
+    /// are the code for the `protocol` and `protocol_witness`
+    /// functions, respectively.
+    fn proto_witness_codegen(
+        &self,
+        node_num: usize,
+        statement: &StatementTree,
+    ) -> (usize, TokenStream, TokenStream) {
+        let proto_var = format_ident!("{}proto_{}", self.unique_prefix, node_num,);
+        let proto_witness_var = format_ident!("{}proto_witness_{}", self.unique_prefix, node_num,);
+        (
+            0,
+            quote! {
+                let #proto_var = Protocol::from(LinearRelation::<Point>::new());
+            },
+            quote! {
+                let #proto_witness_var = ProtocolWitness::Simple(vec![]);
+            },
+        )
+    }
+
     /// Generate the code that uses the `sigma-rs` API to prove and
     /// verify the statements in the [`CodeGen`].
     ///
@@ -236,7 +287,47 @@ impl<'a> CodeGen<'a> {
             quote! {}
         };
 
-        // Generate the (currently dummy) prove function
+        let (_, protocol_code, witness_code) = self.proto_witness_codegen(0, self.statements);
+
+        // Generate the function that creates the sigma-rs Protocol
+        let protocol_func = {
+            let params_ids = pub_params_fields.field_list();
+            let params_var = format_ident!("{}params", self.unique_prefix);
+            let proto_var = format_ident!("{}proto_0", self.unique_prefix);
+
+            quote! {
+                fn protocol(
+                    #params_var: &Params,
+                ) -> Result<Protocol<Point>, SigmaError> {
+                    let Params { #params_ids } = #params_var.clone();
+                    #protocol_code
+                    Ok(#proto_var)
+                }
+            }
+        };
+
+        // Generate the function that creates the sigma-rs ProtocolWitness
+        let witness_func = {
+            let params_ids = pub_params_fields.field_list();
+            let witness_ids = witness_fields.field_list();
+            let params_var = format_ident!("{}params", self.unique_prefix);
+            let witness_var = format_ident!("{}witness", self.unique_prefix);
+            let proto_witness_var = format_ident!("{}proto_witness_0", self.unique_prefix);
+
+            quote! {
+                fn protocol_witness(
+                    #params_var: &Params,
+                    #witness_var: &Witness,
+                ) -> Result<ProtocolWitness<Point>, SigmaError> {
+                    let Params { #params_ids } = #params_var.clone();
+                    let Witness { #witness_ids } = #witness_var.clone();
+                    #witness_code
+                    Ok(#proto_witness_var)
+                }
+            }
+        };
+
+        // Generate the prove function
         let prove_func = if emit_prover {
             let dumper = if cfg!(feature = "dump") {
                 quote! {
@@ -247,27 +338,38 @@ impl<'a> CodeGen<'a> {
             } else {
                 quote! {}
             };
-            let params_ids = pub_params_fields.field_list();
-            let witness_ids = witness_fields.field_list();
+            let params_var = format_ident!("{}params", self.unique_prefix);
+            let witness_var = format_ident!("{}witness", self.unique_prefix);
+            let session_id_var = format_ident!("{}session_id", self.unique_prefix);
+            let rng_var = format_ident!("{}rng", self.unique_prefix);
+            let proto_var = format_ident!("{}proto", self.unique_prefix);
+            let proto_witness_var = format_ident!("{}proto_witness", self.unique_prefix);
+            let nizk_var = format_ident!("{}nizk", self.unique_prefix);
 
             quote! {
                 pub fn prove(
-                    params: &Params,
-                    witness: &Witness,
-                    session_id: &[u8],
-                    rng: &mut (impl CryptoRng + RngCore),
+                    #params_var: &Params,
+                    #witness_var: &Witness,
+                    #session_id_var: &[u8],
+                    #rng_var: &mut (impl CryptoRng + RngCore),
                 ) -> Result<Vec<u8>, SigmaError> {
                     #dumper
-                    let Params { #params_ids } = params.clone();
-                    let Witness { #witness_ids } = witness.clone();
-                    Ok(Vec::<u8>::default())
+                    let #proto_var = protocol(#params_var)?;
+                    let #proto_witness_var = protocol_witness(#params_var, #witness_var)?;
+                    let #nizk_var =
+                        NISigmaProtocol::<_, ShakeCodec<Point>>::new(
+                            #session_id_var,
+                            #proto_var,
+                        );
+
+                    #nizk_var.prove_batchable(&#proto_witness_var, #rng_var)
                 }
             }
         } else {
             quote! {}
         };
 
-        // Generate the (currently dummy) verify function
+        // Generate the verify function
         let verify_func = if emit_verifier {
             let dumper = if cfg!(feature = "dump") {
                 quote! {
@@ -278,16 +380,28 @@ impl<'a> CodeGen<'a> {
             } else {
                 quote! {}
             };
-            let params_ids = pub_params_fields.field_list();
+
+            let params_var = format_ident!("{}params", self.unique_prefix);
+            let proof_var = format_ident!("{}proof", self.unique_prefix);
+            let session_id_var = format_ident!("{}session_id", self.unique_prefix);
+            let proto_var = format_ident!("{}proto", self.unique_prefix);
+            let nizk_var = format_ident!("{}nizk", self.unique_prefix);
+
             quote! {
                 pub fn verify(
-                    params: &Params,
-                    proof: &[u8],
-                    session_id: &[u8],
+                    #params_var: &Params,
+                    #proof_var: &[u8],
+                    #session_id_var: &[u8],
                 ) -> Result<(), SigmaError> {
                     #dumper
-                    let Params { #params_ids } = params.clone();
-                    Ok(())
+                    let #proto_var = protocol(#params_var)?;
+                    let #nizk_var =
+                        NISigmaProtocol::<_, ShakeCodec<Point>>::new(
+                            #session_id_var,
+                            #proto_var,
+                        );
+
+                    #nizk_var.verify_batchable(#proof_var)
                 }
             }
         } else {
@@ -305,6 +419,11 @@ impl<'a> CodeGen<'a> {
         quote! {
             #[allow(non_snake_case)]
             pub mod #proto_name {
+                use sigma_rs::{
+                    codec::ShakeCodec,
+                    composition::{Protocol, ProtocolWitness},
+                    LinearRelation, NISigmaProtocol,
+                };
                 use sigma_compiler::rand::{CryptoRng, RngCore};
                 use sigma_compiler::group::ff::PrimeField;
                 use sigma_compiler::sigma_rs::errors::Error as SigmaError;
@@ -314,6 +433,8 @@ impl<'a> CodeGen<'a> {
                 #params_def
                 #witness_def
 
+                #protocol_func
+                #witness_func
                 #prove_func
                 #verify_func
             }