|
|
@@ -657,7 +657,24 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
|
|
|
) -> Result<TokenStream> {
|
|
|
let le = larg.1;
|
|
|
let re = rarg.1;
|
|
|
- Ok(quote! { #le + #re })
|
|
|
+ let AExprType::Scalar {
|
|
|
+ is_vec: l_is_vec, ..
|
|
|
+ } = larg.0
|
|
|
+ else {
|
|
|
+ panic!("Should not happen: non-Scalar passed to add_scalars");
|
|
|
+ };
|
|
|
+ let AExprType::Scalar {
|
|
|
+ is_vec: r_is_vec, ..
|
|
|
+ } = rarg.0
|
|
|
+ else {
|
|
|
+ panic!("Should not happen: non-Scalar passed to add_scalars");
|
|
|
+ };
|
|
|
+ match (l_is_vec, r_is_vec) {
|
|
|
+ (true, true) => Ok(quote! { add_vecs(&(#le), &(#re)) }),
|
|
|
+ (false, true) => Ok(quote! { add_nv_vec(&(#le), &(#re)) }),
|
|
|
+ (true, false) => Ok(quote! { add_vec_nv(&(#le), &(#re)) }),
|
|
|
+ (false, false) => Ok(quote! { #le + #re }),
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
/// Called when adding two `Point`s
|
|
|
@@ -669,7 +686,24 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
|
|
|
) -> Result<TokenStream> {
|
|
|
let le = larg.1;
|
|
|
let re = rarg.1;
|
|
|
- Ok(quote! { #le + #re })
|
|
|
+ let AExprType::Point {
|
|
|
+ is_vec: l_is_vec, ..
|
|
|
+ } = larg.0
|
|
|
+ else {
|
|
|
+ panic!("Should not happen: non-Point passed to add_points");
|
|
|
+ };
|
|
|
+ let AExprType::Point {
|
|
|
+ is_vec: r_is_vec, ..
|
|
|
+ } = rarg.0
|
|
|
+ else {
|
|
|
+ panic!("Should not happen: non-Point passed to add_points");
|
|
|
+ };
|
|
|
+ match (l_is_vec, r_is_vec) {
|
|
|
+ (true, true) => Ok(quote! { add_vecs(&(#le), &(#re)) }),
|
|
|
+ (false, true) => Ok(quote! { add_nv_vec(&(#le), &(#re)) }),
|
|
|
+ (true, false) => Ok(quote! { add_vec_nv(&(#le), &(#re)) }),
|
|
|
+ (false, false) => Ok(quote! { #le + #re }),
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
/// Called when subtracting two `Scalar`s
|
|
|
@@ -681,7 +715,24 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
|
|
|
) -> Result<TokenStream> {
|
|
|
let le = larg.1;
|
|
|
let re = rarg.1;
|
|
|
- Ok(quote! { #le + (-#re) })
|
|
|
+ let AExprType::Scalar {
|
|
|
+ is_vec: l_is_vec, ..
|
|
|
+ } = larg.0
|
|
|
+ else {
|
|
|
+ panic!("Should not happen: non-Scalar passed to sub_scalars");
|
|
|
+ };
|
|
|
+ let AExprType::Scalar {
|
|
|
+ is_vec: r_is_vec, ..
|
|
|
+ } = rarg.0
|
|
|
+ else {
|
|
|
+ panic!("Should not happen: non-Scalar passed to sub_scalars");
|
|
|
+ };
|
|
|
+ match (l_is_vec, r_is_vec) {
|
|
|
+ (true, true) => Ok(quote! { sub_vecs(&(#le), &(#re)) }),
|
|
|
+ (false, true) => Ok(quote! { sub_nv_vec(&(#le), &(#re)) }),
|
|
|
+ (true, false) => Ok(quote! { sub_vec_nv(&(#le), &(#re)) }),
|
|
|
+ (false, false) => Ok(quote! { #le + (-#re) }),
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
/// Called when subtracting two `Point`s
|
|
|
@@ -693,7 +744,24 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
|
|
|
) -> Result<TokenStream> {
|
|
|
let le = larg.1;
|
|
|
let re = rarg.1;
|
|
|
- Ok(quote! { #le - #re })
|
|
|
+ let AExprType::Point {
|
|
|
+ is_vec: l_is_vec, ..
|
|
|
+ } = larg.0
|
|
|
+ else {
|
|
|
+ panic!("Should not happen: non-Point passed to sub_points");
|
|
|
+ };
|
|
|
+ let AExprType::Point {
|
|
|
+ is_vec: r_is_vec, ..
|
|
|
+ } = rarg.0
|
|
|
+ else {
|
|
|
+ panic!("Should not happen: non-Point passed to sub_points");
|
|
|
+ };
|
|
|
+ match (l_is_vec, r_is_vec) {
|
|
|
+ (true, true) => Ok(quote! { sub_vecs(&(#le), &(#re)) }),
|
|
|
+ (false, true) => Ok(quote! { sub_nv_vec(&(#le), &(#re)) }),
|
|
|
+ (true, false) => Ok(quote! { sub_vec_nv(&(#le), &(#re)) }),
|
|
|
+ (false, false) => Ok(quote! { #le - #re }),
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
/// Called when multiplying two `Scalar`s
|
|
|
@@ -703,16 +771,34 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
|
|
|
rarg: (AExprType, TokenStream),
|
|
|
_restype: AExprType,
|
|
|
) -> Result<TokenStream> {
|
|
|
- let le = larg.1;
|
|
|
- let re = rarg.1;
|
|
|
+ let AExprType::Scalar {
|
|
|
+ is_pub: l_is_pub,
|
|
|
+ is_vec: l_is_vec,
|
|
|
+ ..
|
|
|
+ } = larg.0
|
|
|
+ else {
|
|
|
+ panic!("Should not happen: non-Scalar passed to mul_scalars");
|
|
|
+ };
|
|
|
+ let AExprType::Scalar {
|
|
|
+ is_pub: r_is_pub,
|
|
|
+ is_vec: r_is_vec,
|
|
|
+ ..
|
|
|
+ } = rarg.0
|
|
|
+ else {
|
|
|
+ panic!("Should not happen: non-Scalar passed to mul_scalars");
|
|
|
+ };
|
|
|
// If one is public and one is private, put the private one on
|
|
|
// the left
|
|
|
- if matches!(larg.0, AExprType::Scalar { is_pub: true, .. })
|
|
|
- && matches!(rarg.0, AExprType::Scalar { is_pub: false, .. })
|
|
|
- {
|
|
|
- Ok(quote! { #re * #le })
|
|
|
+ let (lv, le, rv, re) = if l_is_pub && !r_is_pub {
|
|
|
+ (r_is_vec, rarg.1, l_is_vec, larg.1)
|
|
|
} else {
|
|
|
- Ok(quote! { #le * #re })
|
|
|
+ (l_is_vec, larg.1, r_is_vec, rarg.1)
|
|
|
+ };
|
|
|
+ match (lv, rv) {
|
|
|
+ (true, true) => Ok(quote! { mul_vecs(&(#le), &(#re)) }),
|
|
|
+ (false, true) => Ok(quote! { mul_nv_vec(&(#le), &(#re)) }),
|
|
|
+ (true, false) => Ok(quote! { mul_vec_nv(&(#le), &(#re)) }),
|
|
|
+ (false, false) => Ok(quote! { #le * #re }),
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -724,13 +810,31 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
|
|
|
parg: (AExprType, TokenStream),
|
|
|
_restype: AExprType,
|
|
|
) -> Result<TokenStream> {
|
|
|
- let se = sarg.1;
|
|
|
- let pe = parg.1;
|
|
|
+ let AExprType::Scalar {
|
|
|
+ is_pub: s_is_pub,
|
|
|
+ is_vec: s_is_vec,
|
|
|
+ ..
|
|
|
+ } = sarg.0
|
|
|
+ else {
|
|
|
+ panic!("Should not happen: non-Scalar passed to mul_scalar_point");
|
|
|
+ };
|
|
|
+ let AExprType::Point {
|
|
|
+ is_vec: p_is_vec, ..
|
|
|
+ } = parg.0
|
|
|
+ else {
|
|
|
+ panic!("Should not happen: non-Point passed to mul_scalar_point");
|
|
|
+ };
|
|
|
// If the Scalar is public, put it on the right
|
|
|
- if matches!(sarg.0, AExprType::Scalar { is_pub: true, .. }) {
|
|
|
- Ok(quote! { #pe * #se })
|
|
|
+ let (lv, le, rv, re) = if s_is_pub {
|
|
|
+ (p_is_vec, parg.1, s_is_vec, sarg.1)
|
|
|
} else {
|
|
|
- Ok(quote! { #se * #pe })
|
|
|
+ (s_is_vec, sarg.1, p_is_vec, parg.1)
|
|
|
+ };
|
|
|
+ match (lv, rv) {
|
|
|
+ (true, true) => Ok(quote! { mul_vecs(&(#le), &(#re)) }),
|
|
|
+ (false, true) => Ok(quote! { mul_nv_vec(&(#le), &(#re)) }),
|
|
|
+ (true, false) => Ok(quote! { mul_vec_nv(&(#le), &(#re)) }),
|
|
|
+ (false, false) => Ok(quote! { #le * #re }),
|
|
|
}
|
|
|
}
|
|
|
}
|