Преглед изворни кода

More support for vector variables

Ian Goldberg пре 3 месеци
родитељ
комит
0f2fdca04c
3 измењених фајлова са 199 додато и 53 уклоњено
  1. 96 53
      sigma_compiler_core/src/pedersen.rs
  2. 51 0
      tests/disj_vec.rs
  3. 52 0
      tests/pubscalars_or_vec.rs

+ 96 - 53
sigma_compiler_core/src/pedersen.rs

@@ -879,39 +879,75 @@ pub fn convert_commitment(
     vardict: &VarDict,
 ) -> Result<TokenStream> {
     let orig_commitment = &ped_assign.id;
+    let mut is_vec = matches!(
+        vardict.get(&orig_commitment.to_string()),
+        Some(AExprType::Point { is_vec: true, .. })
+    );
+    let mut needs_clone = is_vec;
     let ped_assign_linscalar = &ped_assign.pedersen.var_term.coeff;
     let generator = &ped_assign.pedersen.var_term.id;
+    let generator_is_vec = matches!(
+        vardict.get(&generator.to_string()),
+        Some(AExprType::Point { is_vec: true, .. })
+    );
     let mut generated_code = quote! { #orig_commitment };
     // Subtract the pub_scalar_expr in ped_assign_linscalar (if present)
     // times the generator
     if let Some(ref pse) = ped_assign_linscalar.pub_scalar_expr {
-        let ppse_tokens = expr_type_tokens(vardict, &paren_if_needed(pse.clone()))?.1;
-        generated_code = quote! {
-            ( #generated_code - #ppse_tokens * #generator )
-        };
+        let (ppse_type, ppse_tokens) = expr_type_tokens(vardict, &paren_if_needed(pse.clone()))?;
+        let ppse_is_vec = matches!(ppse_type, AExprType::Scalar { is_vec: true, .. });
+        generated_code = tokens_sub_maybe_vec(
+            generated_code,
+            is_vec,
+            tokens_mul_maybe_vec(
+                ppse_tokens,
+                ppse_is_vec,
+                quote! { #generator },
+                generator_is_vec,
+            ),
+            ppse_is_vec | generator_is_vec,
+        );
+        is_vec |= ppse_is_vec | generator_is_vec;
+        needs_clone = false;
     }
     // Divide by the coeff in ped_assign_linscalar, if present (noting
     // it also cannot be 0, so will have an inverse)
     if ped_assign_linscalar.coeff != 1 {
         let coeff_tokens = const_i128_tokens(ped_assign_linscalar.coeff);
-        generated_code = quote! {
-            <Scalar as Field>::invert(&#coeff_tokens).unwrap() * #generated_code
-        };
+        generated_code = tokens_mul_maybe_vec(
+            quote! { <Scalar as Field>::invert(&#coeff_tokens).unwrap() },
+            false,
+            generated_code,
+            is_vec,
+        );
+        needs_clone = false;
     }
     // Now multiply by the coeff in new_linscalar, if present
     if new_linscalar.coeff != 1 {
         let coeff_tokens = const_i128_tokens(new_linscalar.coeff);
-        generated_code = quote! {
-            #coeff_tokens * #generated_code
-        };
+        generated_code = tokens_mul_maybe_vec(coeff_tokens, false, generated_code, is_vec);
+        needs_clone = false;
     }
     // And add the pub_scalar_expr in new_linscalar (if present) times
     // the generator
     if let Some(ref pse) = new_linscalar.pub_scalar_expr {
-        let ppse_tokens = expr_type_tokens(vardict, &paren_if_needed(pse.clone()))?.1;
-        generated_code = quote! {
-            #generated_code + #ppse_tokens * #generator
-        };
+        let (ppse_type, ppse_tokens) = expr_type_tokens(vardict, &paren_if_needed(pse.clone()))?;
+        let ppse_is_vec = matches!(ppse_type, AExprType::Scalar { is_vec: true, .. });
+        generated_code = tokens_add_maybe_vec(
+            generated_code,
+            is_vec,
+            tokens_mul_maybe_vec(
+                ppse_tokens,
+                ppse_is_vec,
+                quote! { #generator },
+                generator_is_vec,
+            ),
+            ppse_is_vec | generator_is_vec,
+        );
+        needs_clone = false;
+    }
+    if needs_clone {
+        generated_code = quote! { #generated_code.clone() };
     }
 
     Ok(quote! { let #output_commitment = #generated_code; })
@@ -928,25 +964,32 @@ pub fn convert_randomness(
 ) -> Result<TokenStream> {
     let ped_assign_linscalar = &ped_assign.pedersen.var_term.coeff;
     // Start with the LinScalar in ped_assign.pedersen.rand_term
-    let mut generated_code = expr_type_tokens(
+    let (coeff_type, mut generated_code) = expr_type_tokens(
         vardict,
         &paren_if_needed(ped_assign.pedersen.rand_term.coeff.to_expr()),
-    )?
-    .1;
+    )?;
+    let is_vec = matches!(coeff_type, AExprType::Scalar { is_vec: true, .. });
+    let mut needs_clone = is_vec;
     // Divide by the coeff in ped_assign_linscalar, if present (noting
     // it also cannot be 0, so will have an inverse)
     if ped_assign_linscalar.coeff != 1 {
         let coeff_tokens = const_i128_tokens(ped_assign_linscalar.coeff);
-        generated_code = quote! {
-            <Scalar as Field>::invert(&#coeff_tokens).unwrap() * #generated_code
-        };
+        generated_code = tokens_mul_maybe_vec(
+            quote! { <Scalar as Field>::invert(&#coeff_tokens).unwrap() },
+            false,
+            generated_code,
+            is_vec,
+        );
+        needs_clone = false;
     }
     // Now multiply by the coeff in new_linscalar, if present
     if new_linscalar.coeff != 1 {
         let coeff_tokens = const_i128_tokens(new_linscalar.coeff);
-        generated_code = quote! {
-            #coeff_tokens * #generated_code
-        };
+        generated_code = tokens_mul_maybe_vec(coeff_tokens, false, generated_code, is_vec);
+        needs_clone = false;
+    }
+    if needs_clone {
+        generated_code = quote! { #generated_code.clone() };
     }
 
     Ok(quote! { let #output_randomness = #generated_code; })
@@ -1702,8 +1745,8 @@ mod test {
             &randoms,
             parse_quote! { C = x*A + r*B },
             parse_quote! { 2 * x + 12 },
-            quote! { let out = Scalar::from_u128(2u128) * C +
-            Scalar::from_u128(12u128) * A; },
+            quote! { let out = (Scalar::from_u128(2u128) * C) +
+            (Scalar::from_u128(12u128) * A); },
             quote! { let out_rand = Scalar::from_u128(2u128) * r; },
         );
 
@@ -1712,8 +1755,8 @@ mod test {
             &randoms,
             parse_quote! { C = x*A + r*B },
             parse_quote! { 2 * x + 12 + a },
-            quote! { let out = Scalar::from_u128(2u128) * C +
-            (Scalar::from_u128(12u128) + a) * A; },
+            quote! { let out = (Scalar::from_u128(2u128) * C) +
+            ((Scalar::from_u128(12u128) + a) * A); },
             quote! { let out_rand = Scalar::from_u128(2u128) * r; },
         );
 
@@ -1722,11 +1765,11 @@ mod test {
             &randoms,
             parse_quote! { C = 3*x*A + r*B },
             parse_quote! { 2 * x + 12 + a },
-            quote! { let out = Scalar::from_u128(2u128) *
-            <Scalar as Field>::invert(&Scalar::from_u128(3u128)).unwrap() * C +
-            (Scalar::from_u128(12u128) + a) * A; },
+            quote! { let out = (Scalar::from_u128(2u128) *
+            (<Scalar as Field>::invert(&Scalar::from_u128(3u128)).unwrap() * C)) +
+            ((Scalar::from_u128(12u128) + a) * A); },
             quote! { let out_rand = Scalar::from_u128(2u128) *
-            <Scalar as Field>::invert(&Scalar::from_u128(3u128)).unwrap() * r; },
+            (<Scalar as Field>::invert(&Scalar::from_u128(3u128)).unwrap() * r); },
         );
 
         convert_commitment_randomness_tester(
@@ -1734,11 +1777,11 @@ mod test {
             &randoms,
             parse_quote! { C = -3*x*A + r*B },
             parse_quote! { 2 * x + 12 + a },
-            quote! { let out = Scalar::from_u128(2u128) *
-            <Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() * C +
-            (Scalar::from_u128(12u128) + a) * A; },
+            quote! { let out = (Scalar::from_u128(2u128) *
+            (<Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() * C)) +
+            ((Scalar::from_u128(12u128) + a) * A); },
             quote! { let out_rand = Scalar::from_u128(2u128) *
-            <Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() * r; },
+            (<Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() * r); },
         );
 
         convert_commitment_randomness_tester(
@@ -1746,12 +1789,12 @@ mod test {
             &randoms,
             parse_quote! { C = (-3*x+4+b)*A + r*B },
             parse_quote! { 2 * x + 12 + a },
-            quote! { let out = Scalar::from_u128(2u128) *
-            <Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() *
-            (C - (Scalar::from_u128(4u128) + b) * A) +
-            (Scalar::from_u128(12u128) + a) * A; },
+            quote! { let out = (Scalar::from_u128(2u128) *
+            (<Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() *
+            (C - ((Scalar::from_u128(4u128) + b) * A)))) +
+            ((Scalar::from_u128(12u128) + a) * A); },
             quote! { let out_rand = Scalar::from_u128(2u128) *
-            <Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() * r; },
+            (<Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() * r); },
         );
 
         convert_commitment_randomness_tester(
@@ -1759,13 +1802,13 @@ mod test {
             &randoms,
             parse_quote! { C = (-3*x+4+b)*A + 2*r*B },
             parse_quote! { 2 * x + 12 + a },
-            quote! { let out = Scalar::from_u128(2u128) *
-            <Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() *
-            (C - (Scalar::from_u128(4u128) + b) * A) +
-            (Scalar::from_u128(12u128) + a) * A; },
+            quote! { let out = (Scalar::from_u128(2u128) *
+            (<Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() *
+            (C - ((Scalar::from_u128(4u128) + b) * A)))) +
+            ((Scalar::from_u128(12u128) + a) * A); },
             quote! { let out_rand = Scalar::from_u128(2u128) *
-            <Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() *
-            (r * Scalar::from_u128(2u128)); },
+            (<Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() *
+            (r * Scalar::from_u128(2u128))); },
         );
 
         convert_commitment_randomness_tester(
@@ -1773,14 +1816,14 @@ mod test {
             &randoms,
             parse_quote! { C = (-3*x+4+b)*A + (2*r+c-3)*B },
             parse_quote! { 2 * x + 12 + a },
-            quote! { let out = Scalar::from_u128(2u128) *
-            <Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() *
-            (C - (Scalar::from_u128(4u128) + b) * A) +
-            (Scalar::from_u128(12u128) + a) * A; },
+            quote! { let out = (Scalar::from_u128(2u128) *
+            (<Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() *
+            (C - ((Scalar::from_u128(4u128) + b) * A)))) +
+            ((Scalar::from_u128(12u128) + a) * A); },
             quote! { let out_rand = Scalar::from_u128(2u128) *
-            <Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() *
-            (r * Scalar::from_u128(2u128) +
-            (c + (Scalar::from_u128(3u128).neg()))); },
+            (<Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() *
+            ((r * Scalar::from_u128(2u128)) +
+            (c + (Scalar::from_u128(3u128).neg())))); },
         );
     }
 }

+ 51 - 0
tests/disj_vec.rs

@@ -0,0 +1,51 @@
+#![allow(non_snake_case)]
+use curve25519_dalek::ristretto::RistrettoPoint as G;
+use group::ff::PrimeField;
+use group::Group;
+use sha2::Sha512;
+use sigma_compiler::*;
+
+fn disj_vec_test_vecsize(vecsize: usize) -> Result<(), sigma_rs::errors::Error> {
+    sigma_compiler! { proof,
+        (vec x, vec y, pub vec a, rand vec r, rand vec s),
+        (vec C, vec D, const cind A, const cind B),
+        C = (3*x+1)*A + r*B,
+        D = (2*y+a)*A + s*B,
+        OR (
+            y = 2*x,
+            y = 2*x + 1,
+        )
+    }
+
+    type Scalar = <G as Group>::Scalar;
+    let mut rng = rand::thread_rng();
+    let A = G::hash_from_bytes::<Sha512>(b"Generator A");
+    let B = G::generator();
+    let r: Vec<Scalar> = (0..vecsize).map(|_| Scalar::random(&mut rng)).collect();
+    let s: Vec<Scalar> = (0..vecsize).map(|_| Scalar::random(&mut rng)).collect();
+    let x: Vec<Scalar> = (0..vecsize).map(|i| Scalar::from_u128(i as u128)).collect();
+    let a: Vec<Scalar> = (0..vecsize)
+        .map(|i| Scalar::from_u128((3 * i + 12) as u128))
+        .collect();
+    let y: Vec<Scalar> = (0..vecsize).map(|i| x[i] + x[i]).collect();
+    let C: Vec<G> = (0..vecsize)
+        .map(|i| (Scalar::from_u128(3) * x[i] + Scalar::ONE) * A + r[i] * B)
+        .collect();
+    let D: Vec<G> = (0..vecsize)
+        .map(|i| (y[i] + y[i] + a[i]) * A + s[i] * B)
+        .collect();
+
+    let instance = proof::Instance { C, D, A, B, a };
+    let witness = proof::Witness { x, y, r, s };
+
+    let proof = proof::prove(&instance, &witness, b"disj_vec_test", &mut rng)?;
+    proof::verify(&instance, &proof, b"disj_vec_test")
+}
+
+#[test]
+fn disj_vec_test() {
+    disj_vec_test_vecsize(0).unwrap();
+    disj_vec_test_vecsize(1).unwrap();
+    disj_vec_test_vecsize(2).unwrap();
+    disj_vec_test_vecsize(20).unwrap();
+}

+ 52 - 0
tests/pubscalars_or_vec.rs

@@ -0,0 +1,52 @@
+#![allow(non_snake_case)]
+use curve25519_dalek::ristretto::RistrettoPoint as G;
+use group::ff::PrimeField;
+use group::Group;
+use sha2::Sha512;
+use sigma_compiler::*;
+
+fn pubscalars_or_vec_test_vecsize_val(
+    vecsize: usize,
+    b_val: u128,
+) -> Result<(), sigma_rs::errors::Error> {
+    sigma_compiler! { proof,
+        (vec x, pub vec a, pub vec b, rand vec r),
+        (vec C, const cind A, const cind B),
+        C = x*A + r*B,
+        OR (
+            b = 2*a,
+            b = 2*a - 3,
+        )
+    }
+
+    type Scalar = <G as Group>::Scalar;
+    let mut rng = rand::thread_rng();
+    let A = G::hash_from_bytes::<Sha512>(b"Generator A");
+    let B = G::generator();
+    let r: Vec<Scalar> = (0..vecsize).map(|_| Scalar::random(&mut rng)).collect();
+    let x: Vec<Scalar> = (0..vecsize).map(|i| Scalar::from_u128(i as u128)).collect();
+    let a: Vec<Scalar> = (0..vecsize)
+        .map(|i| Scalar::from_u128((i + 12) as u128))
+        .collect();
+    let b: Vec<Scalar> = (0..vecsize)
+        .map(|i| a[i] + a[i] - Scalar::from_u128(b_val))
+        .collect();
+    let C: Vec<G> = (0..vecsize).map(|i| x[i] * A + r[i] * B).collect();
+
+    let instance = proof::Instance { C, A, B, a, b };
+    let witness = proof::Witness { x, r };
+
+    let proof = proof::prove(&instance, &witness, b"pubscalars_vec_test", &mut rng)?;
+    proof::verify(&instance, &proof, b"pubscalars_vec_test")
+}
+
+#[test]
+fn pubscalars_or_vec_test() {
+    for vecsize in [0, 1, 2, 20] {
+        pubscalars_or_vec_test_vecsize_val(vecsize, 3).unwrap();
+        pubscalars_or_vec_test_vecsize_val(vecsize, 1).unwrap_err();
+        pubscalars_or_vec_test_vecsize_val(vecsize, 2).unwrap_err();
+        pubscalars_or_vec_test_vecsize_val(vecsize, 3).unwrap();
+        pubscalars_or_vec_test_vecsize_val(vecsize, 4).unwrap_err();
+    }
+}