Browse Source

Add support for vector sums and dot products

In a statement, `sum(x)` will add the elements of the vector `x`.
Because of the SIMD nature of vector multiplication, `sum(x*A)` is then
naturally the dot product of the vectors `x` and `A`.
Ian Goldberg 1 month ago
parent
commit
7c81911952

+ 28 - 0
sigma_compiler_core/src/pedersen.rs

@@ -619,6 +619,34 @@ impl<'a> AExprFold<PedersenExpr> for RecognizeFold<'a> {
         }
     }
 
+    /// Called when summing a vector of `Scalar`s
+    fn sum_scalars(
+        &mut self,
+        _arg: (AExprType, PedersenExpr),
+        _restype: AExprType,
+    ) -> Result<PedersenExpr> {
+        // Sums are never recognized as components of Pedersen
+        // commitments
+        Err(Error::new(
+            proc_macro2::Span::call_site(),
+            "not a component of a Pedersen commitment",
+        ))
+    }
+
+    /// Called when summing a vector of `Point`s
+    fn sum_points(
+        &mut self,
+        _arg: (AExprType, PedersenExpr),
+        _restype: AExprType,
+    ) -> Result<PedersenExpr> {
+        // Sums are never recognized as components of Pedersen
+        // commitments
+        Err(Error::new(
+            proc_macro2::Span::call_site(),
+            "not a component of a Pedersen commitment",
+        ))
+    }
+
     /// Called when subtracting two `Scalar`s
     fn sub_scalars(
         &mut self,

+ 89 - 0
sigma_compiler_core/src/sigma/types.rs

@@ -136,6 +136,8 @@ pub fn const_i128_tokens(val: i128) -> TokenStream {
 ///   - the operations `*`, `+`, `-` (binary or unary)
 ///   - the operation `<<` where both operands are expressions with no
 ///     variables
+///   - the function `sum` that takes a single vector argument and
+///     returns the sum of its elements
 ///   - parens
 pub trait AExprFold<T> {
     /// Called when an identifier found in the [`VarDict`] is
@@ -168,6 +170,12 @@ pub trait AExprFold<T> {
         restype: AExprType,
     ) -> Result<T>;
 
+    /// Called when summing a vector of `Scalar`s
+    fn sum_scalars(&mut self, arg: (AExprType, T), restype: AExprType) -> Result<T>;
+
+    /// Called when summing a vector of `Point`s
+    fn sum_points(&mut self, arg: (AExprType, T), restype: AExprType) -> Result<T>;
+
     /// Called when subtracting two `Scalar`s
     fn sub_scalars(
         &mut self,
@@ -505,6 +513,53 @@ pub trait AExprFold<T> {
                     "invalid operation for arithmetic expression",
                 ))
             }
+            Expr::Call(syn::ExprCall { func, args, .. }) => {
+                let funcname = match func.as_ref() {
+                    Expr::Path(syn::ExprPath { path, .. }) => {
+                        path.get_ident().map(|id| id.to_string())
+                    }
+                    _ => None,
+                };
+                match funcname {
+                    Some(ref s) if s == "sum" => {
+                        if args.len() != 1 {
+                            return Err(Error::new(
+                                func.span(),
+                                "sum must have exactly one argument",
+                            ));
+                        }
+                        let (at, ae) = self.fold(vars, args.first().unwrap())?;
+                        match at {
+                            AExprType::Scalar {
+                                is_vec: true,
+                                is_pub,
+                                ..
+                            } => {
+                                let restype = AExprType::Scalar {
+                                    is_vec: false,
+                                    is_pub,
+                                    val: None,
+                                };
+                                let res = self.sum_scalars((at, ae), restype)?;
+                                Ok((restype, res))
+                            }
+                            AExprType::Point {
+                                is_vec: true,
+                                is_pub,
+                            } => {
+                                let restype = AExprType::Point {
+                                    is_vec: false,
+                                    is_pub,
+                                };
+                                let res = self.sum_points((at, ae), restype)?;
+                                Ok((restype, res))
+                            }
+                            _ => Err(Error::new(args.span(), "argument to sum must be a vector")),
+                        }
+                    }
+                    _ => Err(Error::new(func.span(), "unknown function")),
+                }
+            }
             _ => Err(Error::new(expr.span(), "not a valid arithmetic expression")),
         }
     }
@@ -555,6 +610,16 @@ impl AExprFold<()> for FoldNoOp {
         Ok(())
     }
 
+    /// Called when summing a vector of `Scalar`s
+    fn sum_scalars(&mut self, _arg: (AExprType, ()), _restype: AExprType) -> Result<()> {
+        Ok(())
+    }
+
+    /// Called when summing a vector of `Point`s
+    fn sum_points(&mut self, _arg: (AExprType, ()), _restype: AExprType) -> Result<()> {
+        Ok(())
+    }
+
     /// Called when subtracting two `Scalar`s
     fn sub_scalars(
         &mut self,
@@ -607,6 +672,8 @@ impl AExprFold<()> for FoldNoOp {
 ///   - the operations `*`, `+`, `-` (binary or unary)
 ///   - the operation `<<` where both operands are expressions with no
 ///     variables
+///   - the function `sum` that takes a single vector argument and
+///     returns the sum of its elements
 ///   - parens
 pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
     let mut fold = FoldNoOp {};
@@ -769,6 +836,26 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
         Ok(tokens_add_maybe_vec(le, l_is_vec, re, r_is_vec))
     }
 
+    /// Called when summing a vector of `Scalar`s
+    fn sum_scalars(
+        &mut self,
+        arg: (AExprType, TokenStream),
+        _restype: AExprType,
+    ) -> Result<TokenStream> {
+        let arge = arg.1;
+        Ok(quote! { sigma_compiler::vecutils::sum_vec(&(#arge)) })
+    }
+
+    /// Called when summing a vector of `Point`s
+    fn sum_points(
+        &mut self,
+        arg: (AExprType, TokenStream),
+        _restype: AExprType,
+    ) -> Result<TokenStream> {
+        let arge = arg.1;
+        Ok(quote! { sigma_compiler::vecutils::sum_vec(&(#arge)) })
+    }
+
     /// Called when subtracting two `Scalar`s
     fn sub_scalars(
         &mut self,
@@ -893,6 +980,8 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
 ///   - the operations `*`, `+`, `-` (binary or unary)
 ///   - the operation `<<` where both operands are expressions with no
 ///     variables
+///   - the function `sum` that takes a single vector argument and
+///     returns the sum of its elements
 ///   - parens
 pub fn expr_type_tokens(vars: &VarDict, expr: &Expr) -> Result<(AExprType, TokenStream)> {
     let mut fold = AExprTokenFold {

+ 4 - 1
sigma_compiler_derive/src/lib.rs

@@ -84,6 +84,8 @@ use syn::parse_macro_input;
 ///        - the operations `*`, `+`, `-` (binary or unary)
 ///        - the operation `<<` where both operands are expressions with no
 ///          variables
+///        - the function `sum` that takes a single vector argument and
+///          returns the sum of its elements
 ///        - parens
 ///
 ///        You cannot multiply together two private subexpressions, and
@@ -100,7 +102,8 @@ use syn::parse_macro_input;
 ///        statement must be the same, and the statement is proven to
 ///        hold component-wise.  Any non-vector variable in the
 ///        statement is considered equivalent to a vector variable, all
-///        of whose entries have the same value.
+///        of whose entries have the same value.  Note that you can do a
+///        dot product between two vectors `x` and `A` with `sum(x*A)`.
 ///
 ///        As an extension, you can also use an arithmetic expression
 ///        evaluating to a _public_ `Point` in place of `C` on the left

+ 16 - 0
src/vecutils.rs

@@ -1,6 +1,7 @@
 //! A module containing some utility functions useful for the runtime
 //! processing of vector operations.
 
+use core::iter::Sum;
 use std::ops::{Add, Mul, Sub};
 
 /// Add two vectors componentwise
@@ -95,3 +96,18 @@ where
 {
     right.iter().cloned().map(|r| left.clone() * r).collect()
 }
+
+/// Add the elements of a vector.
+///
+/// This wrapper avoids the problem of the standard
+/// [`sum`](Sum#tymethod.sum) function requiring you to explicitly
+/// specify the output type.  This wrapper gives you whatever type you
+/// get by adding two values of type `T` together.  `T` must be
+/// [`Clone`] because we're adding things of type `T` and not `&T`.
+pub fn sum_vec<T, S>(summable: &[T]) -> S
+where
+    T: Add<T, Output = S> + Clone,
+    S: Sum<T>,
+{
+    summable.iter().cloned().sum()
+}

+ 41 - 0
tests/basic_sum.rs

@@ -0,0 +1,41 @@
+#![allow(non_snake_case)]
+use curve25519_dalek::ristretto::RistrettoPoint as G;
+use group::ff::PrimeField;
+use group::Group;
+use sha2::Sha512;
+use sigma_compiler::*;
+
+fn basic_sum_test_vecsize(vecsize: usize) -> Result<(), sigma_rs::errors::Error> {
+    sigma_compiler! { proof,
+        (vec x, y, rand vec r, rand s),
+        (vec C, D, const cind A, const cind B),
+        C = x*A + r*B,
+        D = y*A + s*B,
+        y = sum(x),
+    }
+
+    type Scalar = <G as Group>::Scalar;
+    let mut rng = rand::thread_rng();
+    let A = G::hash_from_bytes::<Sha512>(b"Generator A");
+    let B = G::generator();
+    let r: Vec<Scalar> = (0..vecsize).map(|_| Scalar::random(&mut rng)).collect();
+    let s = Scalar::random(&mut rng);
+    let x: Vec<Scalar> = (0..vecsize).map(|i| Scalar::from_u128(i as u128)).collect();
+    let y: Scalar = x.iter().sum();
+    let C: Vec<G> = (0..vecsize).map(|i| x[i] * A + r[i] * B).collect();
+    let D = y * A + s * B;
+
+    let instance = proof::Instance { C, D, A, B };
+    let witness = proof::Witness { x, y, r, s };
+
+    let proof = proof::prove(&instance, &witness, b"basic_sum_test", &mut rng)?;
+    proof::verify(&instance, &proof, b"basic_sum_test")
+}
+
+#[test]
+fn basic_sum_test() {
+    basic_sum_test_vecsize(0).unwrap();
+    basic_sum_test_vecsize(1).unwrap();
+    basic_sum_test_vecsize(2).unwrap();
+    basic_sum_test_vecsize(20).unwrap();
+}

+ 49 - 0
tests/dot_product.rs

@@ -0,0 +1,49 @@
+#![allow(non_snake_case)]
+use curve25519_dalek::ristretto::RistrettoPoint as G;
+use group::Group;
+use sigma_compiler::*;
+
+fn dot_product_test_vecsize(vecsize: usize) -> Result<(), sigma_rs::errors::Error> {
+    sigma_compiler! { proof,
+        (vec x, pub vec a),
+        (C, D, E, F, vec A, B),
+        C = sum(x*A),
+        D = sum(a*A),
+        E = sum(a*x*A),
+        F = sum(a*x)*B,
+        F = sum(a*x*B),
+    }
+
+    type Scalar = <G as Group>::Scalar;
+    let mut rng = rand::thread_rng();
+    let A: Vec<G> = (0..vecsize).map(|_| G::random(&mut rng)).collect();
+    let B = G::generator();
+    let x: Vec<Scalar> = (0..vecsize).map(|_| Scalar::random(&mut rng)).collect();
+    let a: Vec<Scalar> = (0..vecsize).map(|_| Scalar::random(&mut rng)).collect();
+    let C: G = (0..vecsize).map(|i| x[i] * A[i]).sum();
+    let D: G = (0..vecsize).map(|i| a[i] * A[i]).sum();
+    let E: G = (0..vecsize).map(|i| a[i] * x[i] * A[i]).sum();
+    let F: G = (0..vecsize).map(|i| a[i] * x[i] * B).sum();
+
+    let instance = proof::Instance {
+        C,
+        D,
+        E,
+        F,
+        A,
+        B,
+        a,
+    };
+    let witness = proof::Witness { x };
+
+    let proof = proof::prove(&instance, &witness, b"dot_product_test", &mut rng)?;
+    proof::verify(&instance, &proof, b"dot_product_test")
+}
+
+#[test]
+fn dot_product_test() {
+    dot_product_test_vecsize(0).unwrap();
+    dot_product_test_vecsize(1).unwrap();
+    dot_product_test_vecsize(2).unwrap();
+    dot_product_test_vecsize(20).unwrap();
+}