pedersen.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. //! A module for finding and manipulating Pedersen commitments in a
  2. //! [`StatementTree`].
  3. //!
  4. //! A Pedersen commitment to a private `Scalar` `x` looks like
  5. //!
  6. //! `C = (a*x+b)*A + (c*r+d)*B`
  7. //!
  //! where `a` and `c` are constant non-zero `Scalar`s (both default to
  //! [`Scalar::ONE`](https://docs.rs/ff/0.13.1/ff/trait.Field.html#associatedconstant.ONE)),
  //! `b` and `d` are public `Scalar`s or constants (or combinations of
  11. //! those), `r` is a random private `Scalar` that appears nowhere else
  12. //! in the [`StatementTree`], `C` is a public `Point`, and `A` and `B`
  13. //! are computationally independent public `Point`s.
  14. use super::sigma::combiners::*;
  15. use super::sigma::types::*;
  16. use super::syntax::*;
  17. use std::collections::{HashMap, HashSet};
  18. use syn::parse::Result;
  19. use syn::visit::Visit;
  20. use syn::{parse_quote, Error, Expr, Ident};
  21. /// Find all random private `Scalar`s (according to the
  22. /// [`TaggedVarDict`]) that appear exactly once in the
  23. /// [`StatementTree`].
  24. pub fn unique_random_scalars(vars: &TaggedVarDict, st: &StatementTree) -> HashSet<String> {
  25. // Filter the TaggedVarDict so that it only contains the private
  26. // _random_ Scalars
  27. let random_private_scalars: VarDict = vars
  28. .iter()
  29. .filter(|(_, v)| {
  30. matches!(
  31. v,
  32. TaggedIdent::Scalar(TaggedScalar {
  33. is_pub: false,
  34. is_rand: true,
  35. ..
  36. })
  37. )
  38. })
  39. .map(|(k, v)| (k.clone(), AExprType::from(v)))
  40. .collect();
  41. let mut seen_randoms: HashMap<String, usize> = HashMap::new();
  42. // Create a PrivScalarMap that will call the given closure for each
  43. // private Scalar (listed in the VarDict) in a supplied expression
  44. let mut var_map = PrivScalarMap {
  45. vars: &random_private_scalars,
  46. // The closure counts how many times each private random Scalar
  47. // in the VarDict appears in total
  48. closure: &mut |ident| {
  49. let id_str = ident.to_string();
  50. let val = seen_randoms.get(&id_str);
  51. let newval = match val {
  52. Some(n) => n + 1,
  53. None => 1,
  54. };
  55. seen_randoms.insert(id_str, newval);
  56. Ok(())
  57. },
  58. result: Ok(()),
  59. };
  60. // Call the PrivScalarMap for each leaf expression in the
  61. // StatementTree
  62. for e in st.leaves() {
  63. var_map.visit_expr(e);
  64. }
  65. // Return a HashSet of the ones that we saw exactly once
  66. seen_randoms
  67. .into_iter()
  68. .filter_map(|(k, v)| if v == 1 { Some(k) } else { None })
  69. .collect()
  70. }
  71. /// A representation of `a*x + b` where `a` is a constant `Scalar`, `b`
  72. /// is a public `Scalar` [arithmetic expression], and `x` is a private
  73. /// `Scalar` variable
  74. ///
  75. /// [arithmetic expression]: expr_type
  76. pub struct LinScalar {
  77. /// The coefficient `a`
  78. pub coeff: i128,
  79. /// The public `Scalar` expression `b`, if present
  80. pub pub_scalar_expr: Option<Expr>,
  81. /// The private `Scalar` `x`
  82. pub id: Ident,
  83. /// Whether `x` is a vector variable
  84. pub is_vec: bool,
  85. }
  86. /// A representation of `(a*x + b)*A` where `a` is a constant `Scalar`,
  87. /// `b` is a public `Scalar` [arithmetic expression], `x` is a private
  88. /// `Scalar` variable, and `A` is a computationally independent `Point`
  89. pub struct Term {
  90. /// The `Scalar` expression `a*x + b`
  91. pub coeff: LinScalar,
  92. /// The public `Point` `A`
  93. pub id: Ident,
  94. }
  95. /// A representation of `(a*x+b)*A + (c*r+d)*B` where `a` and `c` are a
  96. /// constant non-zero `Scalar`s, `b`, and `d` are public `Scalar`s or
  97. /// constants (or combinations of those), `r` is a random private
  98. /// `Scalar` that appears nowhere else in the [`StatementTree`], and `A`
  99. /// and `B` are computationally independent public `Point`s.
  100. pub struct Pedersen {
  101. /// The term containing the variable being committed to (`x` above)
  102. pub var_term: Term,
  103. /// The term containing the random variable (`r` above)
  104. pub rand_term: Term,
  105. }
  106. /// Get the `Ident` for the committed private `Scalar` in a [`Pedersen`]
  107. impl Pedersen {
  108. pub fn var(&self) -> Option<Ident> {
  109. Some(self.var_term.coeff.id.clone())
  110. }
  111. }
  112. /// Components of a Pedersen commitment
  113. pub enum PedersenExpr {
  114. PubScalarExpr(Expr),
  115. LinScalar(LinScalar),
  116. CIndPoint(Ident),
  117. Term(Term),
  118. Pedersen(Pedersen),
  119. }
  120. /// A struct that implements [`AExprFold`] in service of [`recognize`]
  121. struct RecognizeFold<'a> {
  122. /// The [`TaggedVarDict`] that maps variable names to their types
  123. vars: &'a TaggedVarDict,
  124. /// The HashSet of random variables that appear exactly once in the
  125. /// parent [`StatementTree`]
  126. randoms: &'a HashSet<String>,
  127. }
  128. impl<'a> AExprFold<PedersenExpr> for RecognizeFold<'a> {
  129. /// Called when an identifier found in the [`VarDict`] is
  130. /// encountered in the [`Expr`]
  131. fn ident(&mut self, id: &Ident, _restype: AExprType) -> Result<PedersenExpr> {
  132. let Some(vartype) = self.vars.get(&id.to_string()) else {
  133. return Err(Error::new(id.span(), "unknown identifier"));
  134. };
  135. match vartype {
  136. TaggedIdent::Scalar(TaggedScalar { is_pub: true, .. }) => {
  137. // A bare public Scalar is a simple PubScalarExpr
  138. Ok(PedersenExpr::PubScalarExpr(parse_quote! { #id }))
  139. }
  140. TaggedIdent::Scalar(TaggedScalar {
  141. is_pub: false,
  142. is_vec,
  143. ..
  144. }) => {
  145. // A bare private Scalar is a simple LinScalar
  146. Ok(PedersenExpr::LinScalar(LinScalar {
  147. coeff: 1i128,
  148. pub_scalar_expr: None,
  149. id: id.clone(),
  150. is_vec: *is_vec,
  151. }))
  152. }
  153. TaggedIdent::Point(TaggedPoint { is_cind: true, .. }) => {
  154. // A bare cind Point is a CIndPoint
  155. Ok(PedersenExpr::CIndPoint(id.clone()))
  156. }
  157. TaggedIdent::Point(TaggedPoint { is_cind: false, .. }) => {
  158. // Not a part of a valid Pedersen expression
  159. Err(Error::new(id.span(), "non-cind Point"))
  160. }
  161. }
  162. }
  163. /// Called when the arithmetic expression evaluates to a constant
  164. /// [`i128`] value.
  165. fn const_i128(&mut self, restype: AExprType) -> Result<PedersenExpr> {
  166. let AExprType::Scalar { val: Some(val), .. } = restype else {
  167. return Err(Error::new(
  168. proc_macro2::Span::call_site(),
  169. "BUG: it should not happen that const_i128 is called without a value",
  170. ));
  171. };
  172. Ok(PedersenExpr::PubScalarExpr(parse_quote! { #val }))
  173. }
  174. /// Called for unary negation
  175. fn neg(&mut self, arg: (AExprType, PedersenExpr), restype: AExprType) -> Result<PedersenExpr> {
  176. Ok(arg.1)
  177. }
  178. /// Called for a parenthesized expression
  179. fn paren(
  180. &mut self,
  181. arg: (AExprType, PedersenExpr),
  182. restype: AExprType,
  183. ) -> Result<PedersenExpr> {
  184. Ok(arg.1)
  185. }
  186. /// Called when adding two `Scalar`s
  187. fn add_scalars(
  188. &mut self,
  189. larg: (AExprType, PedersenExpr),
  190. rarg: (AExprType, PedersenExpr),
  191. restype: AExprType,
  192. ) -> Result<PedersenExpr> {
  193. Ok(larg.1)
  194. }
  195. /// Called when adding two `Point`s
  196. fn add_points(
  197. &mut self,
  198. larg: (AExprType, PedersenExpr),
  199. rarg: (AExprType, PedersenExpr),
  200. restype: AExprType,
  201. ) -> Result<PedersenExpr> {
  202. Ok(larg.1)
  203. }
  204. /// Called when subtracting two `Scalar`s
  205. fn sub_scalars(
  206. &mut self,
  207. larg: (AExprType, PedersenExpr),
  208. rarg: (AExprType, PedersenExpr),
  209. restype: AExprType,
  210. ) -> Result<PedersenExpr> {
  211. Ok(larg.1)
  212. }
  213. /// Called when subtracting two `Point`s
  214. fn sub_points(
  215. &mut self,
  216. larg: (AExprType, PedersenExpr),
  217. rarg: (AExprType, PedersenExpr),
  218. restype: AExprType,
  219. ) -> Result<PedersenExpr> {
  220. Ok(larg.1)
  221. }
  222. /// Called when multiplying two `Scalar`s
  223. fn mul_scalars(
  224. &mut self,
  225. larg: (AExprType, PedersenExpr),
  226. rarg: (AExprType, PedersenExpr),
  227. restype: AExprType,
  228. ) -> Result<PedersenExpr> {
  229. Ok(larg.1)
  230. }
  231. /// Called when multiplying a `Scalar` and a `Point` (the `Scalar`
  232. /// will always be passed as the first argument)
  233. fn mul_scalar_point(
  234. &mut self,
  235. sarg: (AExprType, PedersenExpr),
  236. parg: (AExprType, PedersenExpr),
  237. restype: AExprType,
  238. ) -> Result<PedersenExpr> {
  239. Ok(sarg.1)
  240. }
  241. }
  242. /// Parse the right-hand side of the = in an [`Expr`] to see if we
  243. /// recognize it as a Pedersen commitment
  244. pub fn recognize(
  245. vars: &TaggedVarDict,
  246. randoms: &HashSet<String>,
  247. vardict: &VarDict,
  248. expr: &Expr,
  249. ) -> Option<Pedersen> {
  250. let mut fold = RecognizeFold { vars, randoms };
  251. let Ok((aetype, PedersenExpr::Pedersen(pedersen))) = fold.fold(vardict, expr) else {
  252. return None;
  253. };
  254. // It's not allowed for the overall expression to be a vector type,
  255. // but the randomizer variable be a non-vector
  256. if let Some(TaggedIdent::Scalar(TaggedScalar { is_vec: false, .. })) =
  257. vars.get(&pedersen.rand_term.id.to_string())
  258. {
  259. if matches!(aetype, AExprType::Point { is_vec: true, .. }) {
  260. return None;
  261. }
  262. }
  263. Some(pedersen)
  264. }
  #[cfg(test)]
  mod test {
      use super::*;
      use syn::{parse_quote, Expr};

      /// Parse `e` into a [`StatementTree`] and check that
      /// [`unique_random_scalars`] returns exactly the names in `expected`
      fn unique_random_scalars_tester(vars: (&[&str], &[&str]), e: Expr, expected: &[&str]) {
          let taggedvardict = taggedvardict_from_strs(vars);
          let st = StatementTree::parse(&e).unwrap();
          let expected_out = expected.iter().map(|s| s.to_string()).collect();
          let output = unique_random_scalars(&taggedvardict, &st);
          assert_eq!(output, expected_out);
      }

      #[test]
      fn unique_random_scalars_test() {
          let vars = (
              ["x", "y", "z", "rand r", "rand s", "rand t"].as_slice(),
              ["C", "cind A", "cind B"].as_slice(),
          );
          unique_random_scalars_tester(
              vars,
              parse_quote! {
                  C = x*A + r*B
              },
              ["r"].as_slice(),
          );
          unique_random_scalars_tester(
              vars,
              parse_quote! {
                  AND (
                      C = x*A + r*B,
                      D = y*A + s*B,
                  )
              },
              ["r", "s"].as_slice(),
          );
          unique_random_scalars_tester(
              vars,
              parse_quote! {
                  AND (
                      C = x*A + r*B,
                      OR (
                          D = y*A + s*B,
                          E = y*A + t*B,
                      ),
                      E = z*A + r*B,
                  )
              },
              ["s", "t"].as_slice(),
          );
          // A random Scalar used twice within one statement is not unique
          unique_random_scalars_tester(
              vars,
              parse_quote! {
                  C = x*A + r*B + r*A
              },
              &[],
          );
          // A statement without any random Scalars yields nothing
          unique_random_scalars_tester(
              vars,
              parse_quote! {
                  C = x*A + y*B
              },
              &[],
          );
      }
  }