Explorar el Código

Improving support for vector variables

Ian Goldberg hace 6 meses
padre
commit
1de1cc05c4

+ 1 - 0
sigma_compiler_core/src/codegen.rs

@@ -501,6 +501,7 @@ impl CodeGen {
                 use sigma_compiler::rand::{CryptoRng, RngCore};
                 use sigma_compiler::sigma_rs;
                 use sigma_compiler::sigma_rs::errors::Error as SigmaError;
+                use sigma_compiler::vecutils::*;
                 use std::ops::Neg;
                 #dump_use
 

+ 8 - 2
sigma_compiler_core/src/pubscalareq.rs

@@ -53,19 +53,25 @@ pub fn transform(
                         let idstr = id.to_string();
                         if let Some(TaggedIdent::Scalar(TaggedScalar {
                             is_pub: true,
-                            is_vec: false,
+                            is_vec: l_is_vec,
                             ..
                         })) = vars.get(&idstr)
                         {
                             if let (
                                 AExprType::Scalar {
                                     is_pub: true,
-                                    is_vec: false,
+                                    is_vec: r_is_vec,
                                     ..
                                 },
                                 right_tokens,
                             ) = expr_type_tokens(&vardict, right)?
                             {
+                                if *l_is_vec != r_is_vec {
+                                    return Err(Error::new(
+                                        proc_macro2::Span::call_site(),
+                                        "Only one side of the public equality statement is a vector",
+                                    ));
+                                }
                                 // We found a public Scalar equality
                                 // statement.
                                 if in_root_disjunction_branch {

+ 16 - 15
sigma_compiler_core/src/sigma/codegen.rs

@@ -239,7 +239,7 @@ impl<'a> CodeGen<'a> {
                     }
                     | AExprType::Point { is_vec: true, .. } => {
                         vec_param_vars.insert(id.clone());
-                        Ok(quote! {#instance_var.#id[#vec_index_var]})
+                        Ok(quote! {#instance_var.#id})
                     }
                 })
                 .unwrap();
@@ -293,7 +293,7 @@ impl<'a> CodeGen<'a> {
                                 witnessvec.extend(witness.#id.clone());
                             };
                         }
-                        Ok(quote! {#id[#vec_index_var]})
+                        Ok(quote! { #id })
                     }
                     AExprType::Scalar {
                         is_vec: true,
@@ -301,7 +301,7 @@ impl<'a> CodeGen<'a> {
                         ..
                     } => {
                         vec_param_vars.insert(id.clone());
-                        Ok(quote! {#instance_var.#id[#vec_index_var]})
+                        Ok(quote! {#instance_var.#id})
                     }
                     AExprType::Point { is_vec: false, .. } => {
                         if allocated_vars.insert(id.clone()) {
@@ -335,7 +335,7 @@ impl<'a> CodeGen<'a> {
                                 }
                             };
                         }
-                        Ok(quote! {#id[#vec_index_var]})
+                        Ok(quote! { #id })
                     }
                 })
             else {
@@ -399,18 +399,18 @@ impl<'a> CodeGen<'a> {
             if right_is_vec {
                 eq_code = quote! {
                     #eq_code
-                    let #eq_id = (0..#vec_len_var)
-                        .map(|#vec_index_var| #lr_var.allocate_eq(#right_tokens))
+                    let #eq_id = (#right_tokens)
+                        .iter()
+                        .cloned()
+                        .map(|lr| #lr_var.allocate_eq(lr))
                         .collect::<Vec<_>>();
                 };
                 element_assigns = quote! {
                     #element_assigns
-                    for #vec_index_var in 0..#vec_len_var {
-                        #lr_var.set_element(
-                            #eq_id[#vec_index_var],
-                            #left_tokens,
-                        );
-                    }
+                    (#left_tokens)
+                        .iter()
+                        .zip(#eq_id.iter())
+                        .for_each(|(l,eq)| #lr_var.set_element(*eq, *l));
                 };
             } else {
                 eq_code = quote! {
@@ -729,14 +729,15 @@ impl<'a> CodeGen<'a> {
             #[allow(non_snake_case)]
             pub mod #proto_name {
                 use sigma_compiler::sigma_rs;
+                use sigma_compiler::group::ff::PrimeField;
+                use sigma_compiler::rand::{CryptoRng, RngCore};
+                use sigma_compiler::subtle::CtOption;
+                use sigma_compiler::vecutils::*;
                 use sigma_rs::{
                     composition::{ComposedRelation, ComposedWitness},
                     errors::Error as SigmaError,
                     LinearRelation, Nizk,
                 };
-                use sigma_compiler::rand::{CryptoRng, RngCore};
-                use sigma_compiler::group::ff::PrimeField;
-                use sigma_compiler::subtle::CtOption;
                 use std::ops::Neg;
                 #dump_use
 

+ 120 - 16
sigma_compiler_core/src/sigma/types.rs

@@ -657,7 +657,24 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
     ) -> Result<TokenStream> {
         let le = larg.1;
         let re = rarg.1;
-        Ok(quote! { #le + #re })
+        let AExprType::Scalar {
+            is_vec: l_is_vec, ..
+        } = larg.0
+        else {
+            panic!("Should not happen: non-Scalar passed to add_scalars");
+        };
+        let AExprType::Scalar {
+            is_vec: r_is_vec, ..
+        } = rarg.0
+        else {
+            panic!("Should not happen: non-Scalar passed to add_scalars");
+        };
+        match (l_is_vec, r_is_vec) {
+            (true, true) => Ok(quote! { add_vecs(&(#le), &(#re)) }),
+            (false, true) => Ok(quote! { add_nv_vec(&(#le), &(#re)) }),
+            (true, false) => Ok(quote! { add_vec_nv(&(#le), &(#re)) }),
+            (false, false) => Ok(quote! { #le + #re }),
+        }
     }
 
     /// Called when adding two `Point`s
@@ -669,7 +686,24 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
     ) -> Result<TokenStream> {
         let le = larg.1;
         let re = rarg.1;
-        Ok(quote! { #le + #re })
+        let AExprType::Point {
+            is_vec: l_is_vec, ..
+        } = larg.0
+        else {
+            panic!("Should not happen: non-Point passed to add_points");
+        };
+        let AExprType::Point {
+            is_vec: r_is_vec, ..
+        } = rarg.0
+        else {
+            panic!("Should not happen: non-Point passed to add_points");
+        };
+        match (l_is_vec, r_is_vec) {
+            (true, true) => Ok(quote! { add_vecs(&(#le), &(#re)) }),
+            (false, true) => Ok(quote! { add_nv_vec(&(#le), &(#re)) }),
+            (true, false) => Ok(quote! { add_vec_nv(&(#le), &(#re)) }),
+            (false, false) => Ok(quote! { #le + #re }),
+        }
     }
 
     /// Called when subtracting two `Scalar`s
@@ -681,7 +715,24 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
     ) -> Result<TokenStream> {
         let le = larg.1;
         let re = rarg.1;
-        Ok(quote! { #le + (-#re) })
+        let AExprType::Scalar {
+            is_vec: l_is_vec, ..
+        } = larg.0
+        else {
+            panic!("Should not happen: non-Scalar passed to sub_scalars");
+        };
+        let AExprType::Scalar {
+            is_vec: r_is_vec, ..
+        } = rarg.0
+        else {
+            panic!("Should not happen: non-Scalar passed to sub_scalars");
+        };
+        match (l_is_vec, r_is_vec) {
+            (true, true) => Ok(quote! { sub_vecs(&(#le), &(#re)) }),
+            (false, true) => Ok(quote! { sub_nv_vec(&(#le), &(#re)) }),
+            (true, false) => Ok(quote! { sub_vec_nv(&(#le), &(#re)) }),
+            (false, false) => Ok(quote! { #le + (-#re) }),
+        }
     }
 
     /// Called when subtracting two `Point`s
@@ -693,7 +744,24 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
     ) -> Result<TokenStream> {
         let le = larg.1;
         let re = rarg.1;
-        Ok(quote! { #le - #re })
+        let AExprType::Point {
+            is_vec: l_is_vec, ..
+        } = larg.0
+        else {
+            panic!("Should not happen: non-Point passed to sub_points");
+        };
+        let AExprType::Point {
+            is_vec: r_is_vec, ..
+        } = rarg.0
+        else {
+            panic!("Should not happen: non-Point passed to sub_points");
+        };
+        match (l_is_vec, r_is_vec) {
+            (true, true) => Ok(quote! { sub_vecs(&(#le), &(#re)) }),
+            (false, true) => Ok(quote! { sub_nv_vec(&(#le), &(#re)) }),
+            (true, false) => Ok(quote! { sub_vec_nv(&(#le), &(#re)) }),
+            (false, false) => Ok(quote! { #le - #re }),
+        }
     }
 
     /// Called when multiplying two `Scalar`s
@@ -703,16 +771,34 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
         rarg: (AExprType, TokenStream),
         _restype: AExprType,
     ) -> Result<TokenStream> {
-        let le = larg.1;
-        let re = rarg.1;
+        let AExprType::Scalar {
+            is_pub: l_is_pub,
+            is_vec: l_is_vec,
+            ..
+        } = larg.0
+        else {
+            panic!("Should not happen: non-Scalar passed to mul_scalars");
+        };
+        let AExprType::Scalar {
+            is_pub: r_is_pub,
+            is_vec: r_is_vec,
+            ..
+        } = rarg.0
+        else {
+            panic!("Should not happen: non-Scalar passed to mul_scalars");
+        };
         // If one is public and one is private, put the private one on
         // the left
-        if matches!(larg.0, AExprType::Scalar { is_pub: true, .. })
-            && matches!(rarg.0, AExprType::Scalar { is_pub: false, .. })
-        {
-            Ok(quote! { #re * #le })
+        let (lv, le, rv, re) = if l_is_pub && !r_is_pub {
+            (r_is_vec, rarg.1, l_is_vec, larg.1)
         } else {
-            Ok(quote! { #le * #re })
+            (l_is_vec, larg.1, r_is_vec, rarg.1)
+        };
+        match (lv, rv) {
+            (true, true) => Ok(quote! { mul_vecs(&(#le), &(#re)) }),
+            (false, true) => Ok(quote! { mul_nv_vec(&(#le), &(#re)) }),
+            (true, false) => Ok(quote! { mul_vec_nv(&(#le), &(#re)) }),
+            (false, false) => Ok(quote! { #le * #re }),
         }
     }
 
@@ -724,13 +810,31 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
         parg: (AExprType, TokenStream),
         _restype: AExprType,
     ) -> Result<TokenStream> {
-        let se = sarg.1;
-        let pe = parg.1;
+        let AExprType::Scalar {
+            is_pub: s_is_pub,
+            is_vec: s_is_vec,
+            ..
+        } = sarg.0
+        else {
+            panic!("Should not happen: non-Scalar passed to mul_scalar_point");
+        };
+        let AExprType::Point {
+            is_vec: p_is_vec, ..
+        } = parg.0
+        else {
+            panic!("Should not happen: non-Point passed to mul_scalar_point");
+        };
         // If the Scalar is public, put it on the right
-        if matches!(sarg.0, AExprType::Scalar { is_pub: true, .. }) {
-            Ok(quote! { #pe * #se })
+        let (lv, le, rv, re) = if s_is_pub {
+            (p_is_vec, parg.1, s_is_vec, sarg.1)
         } else {
-            Ok(quote! { #se * #pe })
+            (s_is_vec, sarg.1, p_is_vec, parg.1)
+        };
+        match (lv, rv) {
+            (true, true) => Ok(quote! { mul_vecs(&(#le), &(#re)) }),
+            (false, true) => Ok(quote! { mul_nv_vec(&(#le), &(#re)) }),
+            (true, false) => Ok(quote! { mul_vec_nv(&(#le), &(#re)) }),
+            (false, false) => Ok(quote! { #le * #re }),
         }
     }
 }

+ 1 - 0
src/lib.rs

@@ -5,3 +5,4 @@ pub use sigma_rs;
 pub use subtle;
 
 pub mod rangeutils;
+pub mod vecutils;

+ 97 - 0
src/vecutils.rs

@@ -0,0 +1,97 @@
+//! A module containing some utility functions useful for the runtime
+//! processing of vector operations.
+
+use std::ops::{Add, Mul, Sub};
+
+/// Add two vectors componentwise
+pub fn add_vecs<L, R, P>(left: &[L], right: &[R]) -> Vec<P>
+where
+    L: Add<R, Output = P> + Clone,
+    R: Clone,
+{
+    left.iter()
+        .cloned()
+        .zip(right.iter().cloned())
+        .map(|(l, r)| l + r)
+        .collect()
+}
+
+/// Add components of a vector by a non-vector
+pub fn add_vec_nv<L, R, P>(left: &[L], right: &R) -> Vec<P>
+where
+    L: Add<R, Output = P> + Clone,
+    R: Clone,
+{
+    left.iter().cloned().map(|l| l + right.clone()).collect()
+}
+
+/// Add a non-vector by components of a vector
+pub fn add_nv_vec<L, R, P>(left: &L, right: &[R]) -> Vec<P>
+where
+    L: Add<R, Output = P> + Clone,
+    R: Clone,
+{
+    right.iter().cloned().map(|r| left.clone() + r).collect()
+}
+
+/// Subtract two vectors componentwise
+pub fn sub_vecs<L, R, P>(left: &[L], right: &[R]) -> Vec<P>
+where
+    L: Sub<R, Output = P> + Clone,
+    R: Clone,
+{
+    left.iter()
+        .cloned()
+        .zip(right.iter().cloned())
+        .map(|(l, r)| l - r)
+        .collect()
+}
+
+/// Subtract components of a vector by a non-vector
+pub fn sub_vec_nv<L, R, P>(left: &[L], right: &R) -> Vec<P>
+where
+    L: Sub<R, Output = P> + Clone,
+    R: Clone,
+{
+    left.iter().cloned().map(|l| l - right.clone()).collect()
+}
+
+/// Subtract a non-vector by components of a vector
+pub fn sub_nv_vec<L, R, P>(left: &L, right: &[R]) -> Vec<P>
+where
+    L: Sub<R, Output = P> + Clone,
+    R: Clone,
+{
+    right.iter().cloned().map(|r| left.clone() - r).collect()
+}
+
+/// Multiply two vectors componentwise
+pub fn mul_vecs<L, R, P>(left: &[L], right: &[R]) -> Vec<P>
+where
+    L: Mul<R, Output = P> + Clone,
+    R: Clone,
+{
+    left.iter()
+        .cloned()
+        .zip(right.iter().cloned())
+        .map(|(l, r)| l * r)
+        .collect()
+}
+
+/// Multiply components of a vector by a non-vector
+pub fn mul_vec_nv<L, R, P>(left: &[L], right: &R) -> Vec<P>
+where
+    L: Mul<R, Output = P> + Clone,
+    R: Clone,
+{
+    left.iter().cloned().map(|l| l * right.clone()).collect()
+}
+
+/// Multiply a non-vector by components of a vector
+pub fn mul_nv_vec<L, R, P>(left: &L, right: &[R]) -> Vec<P>
+where
+    L: Mul<R, Output = P> + Clone,
+    R: Clone,
+{
+    right.iter().cloned().map(|r| left.clone() * r).collect()
+}

+ 48 - 0
tests/pubscalars_vec.rs

@@ -0,0 +1,48 @@
+#![allow(non_snake_case)]
+use curve25519_dalek::ristretto::RistrettoPoint as G;
+use group::ff::PrimeField;
+use group::Group;
+use sha2::Sha512;
+use sigma_compiler::*;
+
+fn pubscalars_vec_test_vecsize(vecsize: usize) -> Result<(), sigma_rs::errors::Error> {
+    sigma_compiler! { proof,
+        (vec x, vec y, pub vec a, pub vec b, rand vec r, rand vec s),
+        (vec C, vec D, const cind A, const cind B),
+        C = x*A + r*B,
+        D = y*A + s*B,
+        y = 2*x + b,
+        b = 3*a - 7,
+    }
+
+    type Scalar = <G as Group>::Scalar;
+    let mut rng = rand::thread_rng();
+    let A = G::hash_from_bytes::<Sha512>(b"Generator A");
+    let B = G::generator();
+    let r: Vec<Scalar> = (0..vecsize).map(|_| Scalar::random(&mut rng)).collect();
+    let s: Vec<Scalar> = (0..vecsize).map(|_| Scalar::random(&mut rng)).collect();
+    let x: Vec<Scalar> = (0..vecsize).map(|i| Scalar::from_u128(i as u128)).collect();
+    let a: Vec<Scalar> = (0..vecsize)
+        .map(|i| Scalar::from_u128((i + 12) as u128))
+        .collect();
+    let b: Vec<Scalar> = (0..vecsize)
+        .map(|i| a[i] + a[i] + a[i] - Scalar::from_u128(7))
+        .collect();
+    let y: Vec<Scalar> = (0..vecsize).map(|i| x[i] + x[i] + b[i]).collect();
+    let C: Vec<G> = (0..vecsize).map(|i| x[i] * A + r[i] * B).collect();
+    let D: Vec<G> = (0..vecsize).map(|i| y[i] * A + s[i] * B).collect();
+
+    let instance = proof::Instance { C, D, A, B, a, b };
+    let witness = proof::Witness { x, y, r, s };
+
+    let proof = proof::prove(&instance, &witness, b"pubscalars_vec_test", &mut rng)?;
+    proof::verify(&instance, &proof, b"pubscalars_vec_test")
+}
+
+#[test]
+fn pubscalars_vec_test() {
+    pubscalars_vec_test_vecsize(0).unwrap();
+    pubscalars_vec_test_vecsize(1).unwrap();
+    pubscalars_vec_test_vecsize(2).unwrap();
+    pubscalars_vec_test_vecsize(20).unwrap();
+}

+ 41 - 0
tests/substitution_vec.rs

@@ -0,0 +1,41 @@
+#![allow(non_snake_case)]
+use curve25519_dalek::ristretto::RistrettoPoint as G;
+use group::ff::PrimeField;
+use group::Group;
+use sha2::Sha512;
+use sigma_compiler::*;
+
+fn substitution_vec_test_vecsize(vecsize: usize) -> Result<(), sigma_rs::errors::Error> {
+    sigma_compiler! { proof,
+        (vec x, vec y, rand vec r, rand vec s),
+        (vec C, vec D, const cind A, const cind B),
+        C = x*A + r*B,
+        D = y*A + s*B,
+        y = 2*x + 1,
+    }
+
+    type Scalar = <G as Group>::Scalar;
+    let mut rng = rand::thread_rng();
+    let A = G::hash_from_bytes::<Sha512>(b"Generator A");
+    let B = G::generator();
+    let r: Vec<Scalar> = (0..vecsize).map(|_| Scalar::random(&mut rng)).collect();
+    let s: Vec<Scalar> = (0..vecsize).map(|_| Scalar::random(&mut rng)).collect();
+    let x: Vec<Scalar> = (0..vecsize).map(|i| Scalar::from_u128(i as u128)).collect();
+    let y: Vec<Scalar> = x.iter().map(|v| v + v + Scalar::ONE).collect();
+    let C: Vec<G> = (0..vecsize).map(|i| x[i] * A + r[i] * B).collect();
+    let D: Vec<G> = (0..vecsize).map(|i| y[i] * A + s[i] * B).collect();
+
+    let instance = proof::Instance { C, D, A, B };
+    let witness = proof::Witness { x, y, r, s };
+
+    let proof = proof::prove(&instance, &witness, b"substitution_vec_test", &mut rng)?;
+    proof::verify(&instance, &proof, b"substitution_vec_test")
+}
+
+#[test]
+fn substitution_vec_test() {
+    substitution_vec_test_vecsize(0).unwrap();
+    substitution_vec_test_vecsize(1).unwrap();
+    substitution_vec_test_vecsize(2).unwrap();
+    substitution_vec_test_vecsize(20).unwrap();
+}

+ 45 - 0
tests/subtract_vec.rs

@@ -0,0 +1,45 @@
+#![allow(non_snake_case)]
+use curve25519_dalek::ristretto::RistrettoPoint as G;
+use group::ff::PrimeField;
+use group::Group;
+use sha2::Sha512;
+use sigma_compiler::*;
+
+fn subtract_vec_test_vecsize(vecsize: usize) -> Result<(), sigma_rs::errors::Error> {
+    sigma_compiler! { proof,
+        (vec x),
+        (vec C, vec D, vec E, const cind A, const cind B),
+        C = (x-1)*A,
+        D = (x-2)*B - C,
+        E = (x-2)*B - A,
+    }
+
+    type Scalar = <G as Group>::Scalar;
+    let mut rng = rand::thread_rng();
+    let A = G::hash_from_bytes::<Sha512>(b"Generator A");
+    let B = G::generator();
+    let x: Vec<Scalar> = (0..vecsize)
+        .map(|i| Scalar::from_u128((i + 5) as u128))
+        .collect();
+    let C: Vec<G> = (0..vecsize).map(|i| (x[i] - Scalar::ONE) * A).collect();
+    let D: Vec<G> = (0..vecsize)
+        .map(|i| (x[i] - Scalar::from_u128(2)) * B - C[i])
+        .collect();
+    let E: Vec<G> = (0..vecsize)
+        .map(|i| (x[i] - Scalar::from_u128(2)) * B - A)
+        .collect();
+
+    let instance = proof::Instance { C, D, E, A, B };
+    let witness = proof::Witness { x };
+
+    let proof = proof::prove(&instance, &witness, b"subtract_vec_test", &mut rng)?;
+    proof::verify(&instance, &proof, b"subtract_vec_test")
+}
+
+#[test]
+fn subtract_vec_test() {
+    subtract_vec_test_vecsize(0).unwrap();
+    subtract_vec_test_vecsize(1).unwrap();
+    subtract_vec_test_vecsize(2).unwrap();
+    subtract_vec_test_vecsize(20).unwrap();
+}