3
1
Prechádzať zdrojové kódy

Flip multiplications around to suit sigma_rs

sigma_rs can do var * scalar, but not scalar * var, so automatically
flip things around as needed
Ian Goldberg 7 mesiacov pred
rodič
commit
d21bea09f9

+ 1 - 0
sigma_compiler_core/src/codegen.rs

@@ -484,6 +484,7 @@ impl CodeGen {
                 use sigma_compiler::rand::{CryptoRng, RngCore};
                 use sigma_compiler::sigma_rs;
                 use sigma_compiler::sigma_rs::errors::Error as SigmaError;
+                use std::ops::Neg;
                 #dump_use
 
                 #group_types

+ 2 - 2
sigma_compiler_core/src/rangeproof.rs

@@ -995,7 +995,7 @@ mod tests {
             (Scalar::from_u128(12u128) + a) * A; },
             quote! { let out_rand = Scalar::from_u128(2u128) *
             <Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() *
-            (Scalar::from_u128(2u128) * r); },
+            (r * Scalar::from_u128(2u128)); },
         );
 
         convert_commitment_randomness_tester(
@@ -1009,7 +1009,7 @@ mod tests {
             (Scalar::from_u128(12u128) + a) * A; },
             quote! { let out_rand = Scalar::from_u128(2u128) *
             <Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() *
-            (Scalar::from_u128(2u128) * r +
+            (r * Scalar::from_u128(2u128) +
             (c + (Scalar::from_u128(3u128).neg()))); },
         );
     }

+ 17 - 4
sigma_compiler_core/src/sigma/types.rs

@@ -705,7 +705,15 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
     ) -> Result<TokenStream> {
         let le = larg.1;
         let re = rarg.1;
-        Ok(quote! { #le * #re })
+        // If one is public and one is private, put the private one on
+        // the left
+        if matches!(larg.0, AExprType::Scalar { is_pub: true, .. })
+            && matches!(rarg.0, AExprType::Scalar { is_pub: false, .. })
+        {
+            Ok(quote! { #re * #le })
+        } else {
+            Ok(quote! { #le * #re })
+        }
     }
 
     /// Called when multiplying a `Scalar` and a `Point` (the `Scalar`
@@ -718,7 +726,12 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
     ) -> Result<TokenStream> {
         let se = sarg.1;
         let pe = parg.1;
-        Ok(quote! { #se * #pe })
+        // If the Scalar is public, put it on the right
+        if matches!(sarg.0, AExprType::Scalar { is_pub: true, .. }) {
+            Ok(quote! { #pe * #se })
+        } else {
+            Ok(quote! { #se * #pe })
+        }
     }
 }
 
@@ -846,13 +859,13 @@ mod tests {
             &vars,
             parse_quote! {(a-(2-3))*(A+(3*4)*A)},
             quote! {
-            (a-(Scalar::from_u128(1u128).neg()))*(A+(Scalar::from_u128(12u128))*A) },
+            (a-(Scalar::from_u128(1u128).neg()))*(A+A*(Scalar::from_u128(12u128))) },
         );
         check_tokens(
             &vars,
             parse_quote! {(a-(2-3))*(A+A*(3*4))},
             quote! {
-            (a-(Scalar::from_u128(1u128).neg()))*(A+(Scalar::from_u128(12u128))*A) },
+            (a-(Scalar::from_u128(1u128).neg()))*(A+A*(Scalar::from_u128(12u128))) },
         );
 
         // Tests that should fail

+ 1 - 1
tests/basic.rs

@@ -12,7 +12,7 @@ fn basic_test() -> Result<(), sigma_rs::errors::Error> {
         (C, D, const cind A, const cind B),
         C = x*A + r*B,
         D = z*A + s*B,
-        z = x*2 + 1,
+        z = 2*x + 1,
     }
 
     type Scalar = <G as Group>::Scalar;

+ 1 - 1
tests/pubscalars.rs

@@ -12,7 +12,7 @@ fn pubscalars_test() -> Result<(), sigma_rs::errors::Error> {
             (C, D, const cind A, const cind B),
             C = x*A + r*B,
             D = z*A + s*B,
-            z = x*2 + a,
+            z = 2*x + a,
     //        b = 2*a - 3,
         }
 

+ 1 - 1
tests/range.rs

@@ -10,7 +10,7 @@ fn range_test() -> Result<(), sigma_rs::errors::Error> {
     sigma_compiler! { proof,
         (x, y, pub a, rand r),
         (C, D, const cind A, const cind B),
-        C = (x*3+1)*A + (r*2+3)*B,
+        C = (3*x+1)*A + (2*r+3)*B,
         D = x*A + y*B,
         (a..20).contains(x),
         (0..a).contains(y),