|
|
@@ -18,15 +18,14 @@ use super::syntax::*;
|
|
|
use super::transform::paren_if_needed;
|
|
|
use proc_macro2::TokenStream;
|
|
|
use quote::quote;
|
|
|
-use std::collections::{HashMap, HashSet};
|
|
|
+use std::collections::HashSet;
|
|
|
use syn::parse::Result;
|
|
|
use syn::visit::Visit;
|
|
|
use syn::{parse_quote, Error, Expr, Ident};
|
|
|
|
|
|
/// Find all random private `Scalar`s (according to the
|
|
|
-/// [`TaggedVarDict`]) that appear exactly once in the
|
|
|
-/// [`StatementTree`].
|
|
|
-pub fn unique_random_scalars(vars: &TaggedVarDict, st: &StatementTree) -> HashSet<String> {
|
|
|
+/// [`TaggedVarDict`]) that appear in the [`StatementTree`].
|
|
|
+pub fn random_scalars(vars: &TaggedVarDict, st: &StatementTree) -> HashSet<String> {
|
|
|
// Filter the TaggedVarDict so that it only contains the private
|
|
|
// _random_ Scalars
|
|
|
let random_private_scalars: VarDict = vars
|
|
|
@@ -44,22 +43,17 @@ pub fn unique_random_scalars(vars: &TaggedVarDict, st: &StatementTree) -> HashSe
|
|
|
.map(|(k, v)| (k.clone(), AExprType::from(v)))
|
|
|
.collect();
|
|
|
|
|
|
- let mut seen_randoms: HashMap<String, usize> = HashMap::new();
|
|
|
+ let mut seen_randoms: HashSet<String> = HashSet::new();
|
|
|
|
|
|
// Create a PrivScalarMap that will call the given closure for each
|
|
|
// private Scalar (listed in the VarDict) in a supplied expression
|
|
|
let mut var_map = PrivScalarMap {
|
|
|
vars: &random_private_scalars,
|
|
|
- // The closure counts how many times each private random Scalar
|
|
|
- // in the VarDict appears in total
|
|
|
+ // The closure records each private random Scalar in the VarDict
|
|
|
+ // we encounter
|
|
|
closure: &mut |ident| {
|
|
|
let id_str = ident.to_string();
|
|
|
- let val = seen_randoms.get(&id_str);
|
|
|
- let newval = match val {
|
|
|
- Some(n) => n + 1,
|
|
|
- None => 1,
|
|
|
- };
|
|
|
- seen_randoms.insert(id_str, newval);
|
|
|
+ seen_randoms.insert(id_str);
|
|
|
Ok(())
|
|
|
},
|
|
|
result: Ok(()),
|
|
|
@@ -69,11 +63,8 @@ pub fn unique_random_scalars(vars: &TaggedVarDict, st: &StatementTree) -> HashSe
|
|
|
for e in st.leaves() {
|
|
|
var_map.visit_expr(e);
|
|
|
}
|
|
|
- // Return a HashSet of the ones that we saw exactly once
|
|
|
+ // Return a HashSet of the ones that we saw
|
|
|
seen_randoms
|
|
|
- .into_iter()
|
|
|
- .filter_map(|(k, v)| if v == 1 { Some(k) } else { None })
|
|
|
- .collect()
|
|
|
}
|
|
|
|
|
|
/// A representation of `a*x + b` where `a` is a constant `Scalar`, `b`
|
|
|
@@ -371,9 +362,6 @@ impl Pedersen {
|
|
|
..self
|
|
|
})
|
|
|
} else if self.rand_term.id == arg.id {
|
|
|
- // This branch actually can't happen, since the private
|
|
|
- // random Scalar variable can only appear once in the
|
|
|
- // StatementTree.
|
|
|
Ok(Self {
|
|
|
rand_term: self.rand_term.add_cind(arg)?,
|
|
|
..self
|
|
|
@@ -394,9 +382,6 @@ impl Pedersen {
|
|
|
..self
|
|
|
})
|
|
|
} else if self.rand_term.id == arg.id {
|
|
|
- // This branch actually can't happen, since the private
|
|
|
- // random Scalar variable can only appear once in the
|
|
|
- // StatementTree.
|
|
|
Ok(Self {
|
|
|
rand_term: self.rand_term.add_term(arg)?,
|
|
|
..self
|
|
|
@@ -1029,22 +1014,22 @@ mod test {
|
|
|
use quote::format_ident;
|
|
|
use syn::{parse_quote, Expr};
|
|
|
|
|
|
- fn unique_random_scalars_tester(vars: (&[&str], &[&str]), e: Expr, expected: &[&str]) {
|
|
|
+ fn random_scalars_tester(vars: (&[&str], &[&str]), e: Expr, expected: &[&str]) {
|
|
|
let taggedvardict = taggedvardict_from_strs(vars);
|
|
|
let st = StatementTree::parse(&e).unwrap();
|
|
|
let expected_out = expected.iter().map(|s| s.to_string()).collect();
|
|
|
- let output = unique_random_scalars(&taggedvardict, &st);
|
|
|
+ let output = random_scalars(&taggedvardict, &st);
|
|
|
assert_eq!(output, expected_out);
|
|
|
}
|
|
|
|
|
|
#[test]
|
|
|
- fn unique_random_scalars_test() {
|
|
|
+ fn random_scalars_test() {
|
|
|
let vars = (
|
|
|
["x", "y", "z", "rand r", "rand s", "rand t"].as_slice(),
|
|
|
["C", "cind A", "cind B"].as_slice(),
|
|
|
);
|
|
|
|
|
|
- unique_random_scalars_tester(
|
|
|
+ random_scalars_tester(
|
|
|
vars,
|
|
|
parse_quote! {
|
|
|
C = x*A + r*B
|
|
|
@@ -1052,7 +1037,7 @@ mod test {
|
|
|
["r"].as_slice(),
|
|
|
);
|
|
|
|
|
|
- unique_random_scalars_tester(
|
|
|
+ random_scalars_tester(
|
|
|
vars,
|
|
|
parse_quote! {
|
|
|
AND (
|
|
|
@@ -1063,7 +1048,7 @@ mod test {
|
|
|
["r", "s"].as_slice(),
|
|
|
);
|
|
|
|
|
|
- unique_random_scalars_tester(
|
|
|
+ random_scalars_tester(
|
|
|
vars,
|
|
|
parse_quote! {
|
|
|
AND (
|
|
|
@@ -1075,7 +1060,7 @@ mod test {
|
|
|
E = z*A + r*B,
|
|
|
)
|
|
|
},
|
|
|
- ["s", "t"].as_slice(),
|
|
|
+ ["r", "s", "t"].as_slice(),
|
|
|
);
|
|
|
}
|
|
|
|