Просмотр исходного кода

prove() and verify() now take a session id as an additional argument

Ian Goldberg 3 месяцев назад
Родитель
Сommit
b5fc2df4c9

+ 17 - 4
sigma_compiler_core/src/codegen.rs

@@ -331,6 +331,7 @@ impl CodeGen {
             let codegen_params_var = format_ident!("{}sigma_params", self.unique_prefix);
             let codegen_witness_var = format_ident!("{}sigma_witness", self.unique_prefix);
             let proof_var = format_ident!("{}proof", self.unique_prefix);
+            let sid_var = format_ident!("{}session_id", self.unique_prefix);
             let sent_params_code = {
                 let chunks = self.sent_params.fields.iter().map(|sf| match sf {
                     StructField::Point(id) => quote! {
@@ -352,6 +353,7 @@ impl CodeGen {
                 pub fn prove(
                     params: &Params,
                     witness: &Witness,
+                    #sid_var: &[u8],
                     rng: &mut (impl CryptoRng + RngCore),
                 ) -> Result<Vec<u8>, SigmaError> {
                     #dumper
@@ -370,6 +372,7 @@ impl CodeGen {
                         sigma::prove(
                             &#codegen_params_var,
                             &#codegen_witness_var,
+                            #sid_var,
                             rng,
                         )?
                     );
@@ -398,6 +401,8 @@ impl CodeGen {
             let codegen_params_var = format_ident!("{}sigma_params", self.unique_prefix);
             let element_len_var = format_ident!("{}element_len", self.unique_prefix);
             let offset_var = format_ident!("{}proof_offset", self.unique_prefix);
+            let proof_var = format_ident!("{}proof", self.unique_prefix);
+            let sid_var = format_ident!("{}session_id", self.unique_prefix);
             let sent_params_code = {
                 let element_len_code = if self.sent_params.fields.is_empty() {
                     quote! {}
@@ -411,14 +416,14 @@ impl CodeGen {
                 let chunks = self.sent_params.fields.iter().map(|sf| match sf {
                     StructField::Point(id) => quote! {
                         let #id: Point = sigma_rs::serialization::deserialize_elements(
-                                &proof[#offset_var..],
+                                &#proof_var[#offset_var..],
                                 1,
                             ).ok_or(SigmaError::VerificationFailure)?[0];
                         #offset_var += #element_len_var;
                     },
                     StructField::VecPoint(id) => quote! {
                         #id = sigma_rs::serialization::deserialize_elements(
-                                &proof[#offset_var..],
+                                &#proof_var[#offset_var..],
                                 #id.len(),
                             ).ok_or(SigmaError::VerificationFailure)?;
                         #offset_var += #element_len_var * #id.len();
@@ -434,7 +439,11 @@ impl CodeGen {
             };
 
             quote! {
-                pub fn verify(params: &Params, proof: &[u8]) -> Result<(), SigmaError> {
+                pub fn verify(
+                    params: &Params,
+                    #proof_var: &[u8],
+                    #sid_var: &[u8],
+                ) -> Result<(), SigmaError> {
                     #dumper
                     let Params { #params_ids } = params.clone();
                     #verify_pre_params_code
@@ -443,7 +452,11 @@ impl CodeGen {
                     let #codegen_params_var = sigma::Params {
                         #sigma_rs_params_ids
                     };
-                    sigma::verify(&#codegen_params_var, &proof[#offset_var..])
+                    sigma::verify(
+                        &#codegen_params_var,
+                        &#proof_var[#offset_var..],
+                        #sid_var,
+                    )
                 }
             }
         } else {

+ 6 - 1
sigma_compiler_core/src/sigma/codegen.rs

@@ -247,6 +247,7 @@ impl<'a> CodeGen<'a> {
                 pub fn prove(
                     params: &Params,
                     witness: &Witness,
+                    session_id: &[u8],
                     rng: &mut (impl CryptoRng + RngCore),
                 ) -> Result<Vec<u8>, SigmaError> {
                     #dumper
@@ -272,7 +273,11 @@ impl<'a> CodeGen<'a> {
             };
             let params_ids = pub_params_fields.field_list();
             quote! {
-                pub fn verify(params: &Params, proof: &[u8]) -> Result<(), SigmaError> {
+                pub fn verify(
+                    params: &Params,
+                    proof: &[u8],
+                    session_id: &[u8],
+                ) -> Result<(), SigmaError> {
                     #dumper
                     let Params { #params_ids } = params.clone();
                     Ok(())

+ 2 - 2
tests/basic.rs

@@ -29,6 +29,6 @@ fn basic_test() -> Result<(), sigma_rs::errors::Error> {
     let params = proof::Params { C, D, A, B };
     let witness = proof::Witness { x, z, r, s };
 
-    let proof = proof::prove(&params, &witness, &mut rng)?;
-    proof::verify(&params, &proof)
+    let proof = proof::prove(&params, &witness, b"basic_test", &mut rng)?;
+    proof::verify(&params, &proof, b"basic_test")
 }

+ 2 - 2
tests/pubscalars.rs

@@ -32,6 +32,6 @@ fn pubscalars_test() -> Result<(), sigma_rs::errors::Error> {
     let params = proof::Params { C, D, A, B, a, b };
     let witness = proof::Witness { x, z, r, s };
 
-    let proof = proof::prove(&params, &witness, &mut rng)?;
-    proof::verify(&params, &proof)
+    let proof = proof::prove(&params, &witness, b"pubscalars_test", &mut rng)?;
+    proof::verify(&params, &proof, b"pubscalars_test")
 }

+ 2 - 2
tests/range.rs

@@ -30,6 +30,6 @@ fn range_test() -> Result<(), sigma_rs::errors::Error> {
     let params = proof::Params { C, D, a, A, B };
     let witness = proof::Witness { x, y, r };
 
-    let proof = proof::prove(&params, &witness, &mut rng)?;
-    proof::verify(&params, &proof)
+    let proof = proof::prove(&params, &witness, b"range_test", &mut rng)?;
+    proof::verify(&params, &proof, b"range_test")
 }