select.rs 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. //! Implementation of an oblivious selection protocol.
  2. use crate::common::Error;
  3. use communicator::{AbstractCommunicator, Fut, Serializable};
  4. use ff::Field;
  5. use itertools::izip;
  6. use rand::{thread_rng, Rng, SeedableRng};
  7. use rand_chacha::ChaChaRng;
  8. use std::collections::VecDeque;
  9. /// Select between two shared values `<a>`, `<b>` based on a shared condition bit `<c>`:
  10. /// Output `<w> <- if <c> then <a> else <b>`.
  11. pub trait Select<F> {
  12. /// Initialize the protocol instance.
  13. fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error>;
  14. /// Run the preprocessing for `num` invocations.
  15. fn preprocess<C: AbstractCommunicator>(
  16. &mut self,
  17. comm: &mut C,
  18. num: usize,
  19. ) -> Result<(), Error>;
  20. /// Run the online protocol for one select operation.
  21. fn select<C: AbstractCommunicator>(
  22. &mut self,
  23. comm: &mut C,
  24. c_share: F,
  25. a_share: F,
  26. b_share: F,
  27. ) -> Result<F, Error>;
  28. }
  29. const PARTY_1: usize = 0;
  30. const PARTY_2: usize = 1;
  31. const PARTY_3: usize = 2;
  32. fn other_compute_party(my_id: usize) -> usize {
  33. match my_id {
  34. PARTY_2 => PARTY_3,
  35. PARTY_3 => PARTY_2,
  36. _ => panic!("invalid party id"),
  37. }
  38. }
  39. /// Implementation of the select protocol.
  40. #[derive(Default)]
  41. pub struct SelectProtocol<F> {
  42. shared_prg_1: Option<ChaChaRng>,
  43. shared_prg_2: Option<ChaChaRng>,
  44. shared_prg_3: Option<ChaChaRng>,
  45. is_initialized: bool,
  46. num_preprocessed_invocations: usize,
  47. preprocessed_mt_x: VecDeque<F>,
  48. preprocessed_mt_y: VecDeque<F>,
  49. preprocessed_mt_z: VecDeque<F>,
  50. preprocessed_c_1_2: VecDeque<F>,
  51. preprocessed_amb_1_2: VecDeque<F>,
  52. }
  53. impl<F> Select<F> for SelectProtocol<F>
  54. where
  55. F: Field + Serializable,
  56. {
  57. fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
  58. if comm.get_my_id() == PARTY_1 {
  59. self.shared_prg_2 = Some(ChaChaRng::from_seed(thread_rng().gen()));
  60. comm.send(PARTY_2, self.shared_prg_2.as_ref().unwrap().get_seed())?;
  61. self.shared_prg_3 = Some(ChaChaRng::from_seed(thread_rng().gen()));
  62. comm.send(PARTY_3, self.shared_prg_3.as_ref().unwrap().get_seed())?;
  63. } else {
  64. let fut_seed = comm.receive(PARTY_1)?;
  65. self.shared_prg_1 = Some(ChaChaRng::from_seed(fut_seed.get()?));
  66. }
  67. self.is_initialized = true;
  68. Ok(())
  69. }
  70. fn preprocess<C: AbstractCommunicator>(&mut self, comm: &mut C, n: usize) -> Result<(), Error> {
  71. assert!(self.is_initialized);
  72. let my_id = comm.get_my_id();
  73. if my_id == PARTY_1 {
  74. let x2s: Vec<F> = (0..n)
  75. .map(|_| F::random(self.shared_prg_2.as_mut().unwrap()))
  76. .collect();
  77. let y2s: Vec<F> = (0..n)
  78. .map(|_| F::random(self.shared_prg_2.as_mut().unwrap()))
  79. .collect();
  80. let z2s: Vec<F> = (0..n)
  81. .map(|_| F::random(self.shared_prg_2.as_mut().unwrap()))
  82. .collect();
  83. let x3s: Vec<F> = (0..n)
  84. .map(|_| F::random(self.shared_prg_3.as_mut().unwrap()))
  85. .collect();
  86. let y3s: Vec<F> = (0..n)
  87. .map(|_| F::random(self.shared_prg_3.as_mut().unwrap()))
  88. .collect();
  89. let z3s: Vec<F> = (0..n)
  90. .map(|_| F::random(self.shared_prg_3.as_mut().unwrap()))
  91. .collect();
  92. let z1s = izip!(x2s, y2s, z2s, x3s, y3s, z3s)
  93. .map(|(x_2, y_2, z_2, x_3, y_3, z_3)| (x_2 + x_3) * (y_2 + y_3) - z_2 - z_3);
  94. self.preprocessed_mt_z.extend(z1s);
  95. self.preprocessed_c_1_2
  96. .extend((0..n).map(|_| F::random(self.shared_prg_2.as_mut().unwrap())));
  97. self.preprocessed_amb_1_2
  98. .extend((0..n).map(|_| F::random(self.shared_prg_2.as_mut().unwrap())));
  99. } else {
  100. self.preprocessed_mt_x
  101. .extend((0..n).map(|_| F::random(self.shared_prg_1.as_mut().unwrap())));
  102. self.preprocessed_mt_y
  103. .extend((0..n).map(|_| F::random(self.shared_prg_1.as_mut().unwrap())));
  104. self.preprocessed_mt_z
  105. .extend((0..n).map(|_| F::random(self.shared_prg_1.as_mut().unwrap())));
  106. if my_id == PARTY_2 {
  107. self.preprocessed_c_1_2
  108. .extend((0..n).map(|_| F::random(self.shared_prg_1.as_mut().unwrap())));
  109. self.preprocessed_amb_1_2
  110. .extend((0..n).map(|_| F::random(self.shared_prg_1.as_mut().unwrap())));
  111. }
  112. }
  113. self.num_preprocessed_invocations += n;
  114. Ok(())
  115. }
  116. fn select<C: AbstractCommunicator>(
  117. &mut self,
  118. comm: &mut C,
  119. c_share: F,
  120. a_share: F,
  121. b_share: F,
  122. ) -> Result<F, Error> {
  123. let my_id = comm.get_my_id();
  124. // if further preprocessing is needed, do it now
  125. if self.num_preprocessed_invocations == 0 {
  126. self.preprocess(comm, 1)?;
  127. }
  128. self.num_preprocessed_invocations -= 1;
  129. if my_id == PARTY_1 {
  130. let c_1_2 = self.preprocessed_c_1_2.pop_front().unwrap();
  131. let amb_1_2 = self.preprocessed_amb_1_2.pop_front().unwrap();
  132. comm.send(PARTY_3, (c_share - c_1_2, (a_share - b_share) - amb_1_2))?;
  133. let z = self.preprocessed_mt_z.pop_front().unwrap();
  134. Ok(b_share + z)
  135. } else {
  136. let (c_1_i, amb_1_i) = if my_id == PARTY_2 {
  137. (
  138. self.preprocessed_c_1_2.pop_front().unwrap(),
  139. self.preprocessed_amb_1_2.pop_front().unwrap(),
  140. )
  141. } else {
  142. let fut_1 = comm.receive::<(F, F)>(PARTY_1)?;
  143. fut_1.get()?
  144. };
  145. let fut_de = comm.receive::<(F, F)>(other_compute_party(my_id))?;
  146. let x_i = self.preprocessed_mt_x.pop_front().unwrap();
  147. let y_i = self.preprocessed_mt_y.pop_front().unwrap();
  148. let mut z_i = self.preprocessed_mt_z.pop_front().unwrap();
  149. let d_i = (c_share + c_1_i) - x_i;
  150. let e_i = (a_share - b_share + amb_1_i) - y_i;
  151. comm.send(other_compute_party(my_id), (d_i, e_i))?;
  152. let (d_j, e_j) = fut_de.get()?;
  153. let (d, e) = (d_i + d_j, e_i + e_j);
  154. z_i += e * (c_share + c_1_i) + d * (a_share - b_share + amb_1_i);
  155. if my_id == PARTY_2 {
  156. z_i -= d * e;
  157. }
  158. Ok(b_share + z_i)
  159. }
  160. }
  161. }
  162. #[cfg(test)]
  163. mod tests {
  164. use super::*;
  165. use communicator::unix::make_unix_communicators;
  166. use std::thread;
  167. use utils::field::Fp;
  168. fn run_init<Proto: Select<F> + Send + 'static, F>(
  169. mut comm: impl AbstractCommunicator + Send + 'static,
  170. mut proto: Proto,
  171. ) -> thread::JoinHandle<(impl AbstractCommunicator, Proto)>
  172. where
  173. F: Field + Serializable,
  174. {
  175. thread::spawn(move || {
  176. proto.init(&mut comm).unwrap();
  177. (comm, proto)
  178. })
  179. }
  180. fn run_select<Proto: Select<F> + Send + 'static, F>(
  181. mut comm: impl AbstractCommunicator + Send + 'static,
  182. mut proto: Proto,
  183. c_share: F,
  184. a_share: F,
  185. b_share: F,
  186. ) -> thread::JoinHandle<(impl AbstractCommunicator, Proto, F)>
  187. where
  188. F: Field + Serializable,
  189. {
  190. thread::spawn(move || {
  191. let result = proto.select(&mut comm, c_share, a_share, b_share);
  192. (comm, proto, result.unwrap())
  193. })
  194. }
  195. #[test]
  196. fn test_select() {
  197. let proto_1 = SelectProtocol::<Fp>::default();
  198. let proto_2 = SelectProtocol::<Fp>::default();
  199. let proto_3 = SelectProtocol::<Fp>::default();
  200. let (comm_3, comm_2, comm_1) = {
  201. let mut comms = make_unix_communicators(3);
  202. (
  203. comms.pop().unwrap(),
  204. comms.pop().unwrap(),
  205. comms.pop().unwrap(),
  206. )
  207. };
  208. let h1 = run_init(comm_1, proto_1);
  209. let h2 = run_init(comm_2, proto_2);
  210. let h3 = run_init(comm_3, proto_3);
  211. let (comm_1, proto_1) = h1.join().unwrap();
  212. let (comm_2, proto_2) = h2.join().unwrap();
  213. let (comm_3, proto_3) = h3.join().unwrap();
  214. let mut rng = thread_rng();
  215. let (a_1, a_2, a_3) = (
  216. Fp::random(&mut rng),
  217. Fp::random(&mut rng),
  218. Fp::random(&mut rng),
  219. );
  220. let a = a_1 + a_2 + a_3;
  221. let (b_1, b_2, b_3) = (
  222. Fp::random(&mut rng),
  223. Fp::random(&mut rng),
  224. Fp::random(&mut rng),
  225. );
  226. let b = b_1 + b_2 + b_3;
  227. let (c_2, c_3) = (Fp::random(&mut rng), Fp::random(&mut rng));
  228. let c0_1 = -c_2 - c_3;
  229. let c1_1 = Fp::ONE - c_2 - c_3;
  230. // check for <c> = <0>
  231. let h1 = run_select(comm_1, proto_1, c0_1, a_1, b_1);
  232. let h2 = run_select(comm_2, proto_2, c_2, a_2, b_2);
  233. let h3 = run_select(comm_3, proto_3, c_3, a_3, b_3);
  234. let (comm_1, proto_1, x_1) = h1.join().unwrap();
  235. let (comm_2, proto_2, x_2) = h2.join().unwrap();
  236. let (comm_3, proto_3, x_3) = h3.join().unwrap();
  237. assert_eq!(c0_1 + c_2 + c_3, Fp::ZERO);
  238. assert_eq!(x_1 + x_2 + x_3, b);
  239. // check for <c> = <1>
  240. let h1 = run_select(comm_1, proto_1, c1_1, a_1, b_1);
  241. let h2 = run_select(comm_2, proto_2, c_2, a_2, b_2);
  242. let h3 = run_select(comm_3, proto_3, c_3, a_3, b_3);
  243. let (_, _, y_1) = h1.join().unwrap();
  244. let (_, _, y_2) = h2.join().unwrap();
  245. let (_, _, y_3) = h3.join().unwrap();
  246. assert_eq!(c1_1 + c_2 + c_3, Fp::ONE);
  247. assert_eq!(y_1 + y_2 + y_3, a);
  248. }
  249. }