mask_index.rs 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
//! Implementation of a protocol to convert a secret-shared index in Fp into an equivalent secret
//! sharing modulo 2^k.
//!
//! The k-bit index is small compared to p. First, the two parties other than the receiver add a
//! large random mask to the shared index: one of them samples the mask and adds it to its share,
//! and the other learns the mask modulo 2^k. The mask statistically hides the index but is small
//! enough that the sum does not overflow modulo p. The masked index is then reconstructed for the
//! third party. Finally, all parties locally reduce their mask or masked value modulo 2^k.
//!
//! The protocol runs three instances of this at once, such that at the end each party holds one
//! masked index and the masks corresponding to the other two parties.
  11. use crate::common::Error;
  12. use communicator::{AbstractCommunicator, Fut, Serializable};
  13. use ff::PrimeField;
  14. use rand::{thread_rng, Rng};
  15. /// Interface specification.
  16. pub trait MaskIndex<F> {
  17. /// Run the mask index protocol where the shared index is at most `index_bits` big.
  18. fn mask_index<C: AbstractCommunicator>(
  19. comm: &mut C,
  20. index_bits: u32,
  21. index_share: F,
  22. ) -> Result<(u16, u16, u16), Error>;
  23. }
  24. /// Protocol implementation.
  25. pub struct MaskIndexProtocol {}
  26. impl<F> MaskIndex<F> for MaskIndexProtocol
  27. where
  28. F: PrimeField + Serializable,
  29. {
  30. fn mask_index<C: AbstractCommunicator>(
  31. comm: &mut C,
  32. index_bits: u32,
  33. index_share: F,
  34. ) -> Result<(u16, u16, u16), Error> {
  35. let random_bits = index_bits + 40;
  36. assert!(random_bits + 1 < F::NUM_BITS);
  37. assert!(index_bits <= 16);
  38. let bit_mask = (1 << index_bits) - 1;
  39. let fut_prev = comm.receive_previous::<F>()?;
  40. let fut_next = comm.receive_next::<(u16, F)>()?;
  41. // sample mask r_{i+1} and send it to P_{i-1}
  42. let r_next: u128 = thread_rng().gen_range(0..(1 << random_bits));
  43. // send masked share to P_{i+1}
  44. comm.send_next(index_share + F::from_u128(r_next))?;
  45. let r_next = (r_next & bit_mask) as u16;
  46. // send mask and our share to P_{i-1}
  47. comm.send_previous((r_next, index_share))?;
  48. let index_masked_prev_share = fut_prev.get()?;
  49. let (r_prev, index_next_share) = fut_next.get()?;
  50. let masked_index = index_share + index_next_share + index_masked_prev_share;
  51. let masked_index =
  52. u64::from_le_bytes(masked_index.to_repr().as_ref()[..8].try_into().unwrap());
  53. let masked_index = masked_index as u16 & bit_mask as u16;
  54. Ok((masked_index, r_prev, r_next))
  55. }
  56. }
  57. #[cfg(test)]
  58. mod tests {
  59. use super::*;
  60. use communicator::unix::make_unix_communicators;
  61. use ff::Field;
  62. use std::thread;
  63. use utils::field::Fp;
  64. fn run_mask_index<Proto: MaskIndex<F>, F>(
  65. mut comm: impl AbstractCommunicator + Send + 'static,
  66. index_bits: u32,
  67. index_share: F,
  68. ) -> thread::JoinHandle<(impl AbstractCommunicator, (u16, u16, u16))>
  69. where
  70. F: PrimeField + Serializable,
  71. {
  72. thread::spawn(move || {
  73. let result = Proto::mask_index(&mut comm, index_bits, index_share);
  74. (comm, result.unwrap())
  75. })
  76. }
  77. #[test]
  78. fn test_mask_index() {
  79. let (comm_3, comm_2, comm_1) = {
  80. let mut comms = make_unix_communicators(3);
  81. (
  82. comms.pop().unwrap(),
  83. comms.pop().unwrap(),
  84. comms.pop().unwrap(),
  85. )
  86. };
  87. let mut rng = thread_rng();
  88. let index_bits = 16;
  89. let bit_mask = ((1 << index_bits) - 1) as u16;
  90. let index = rng.gen_range(0..(1 << index_bits));
  91. let (index_2, index_3) = (Fp::random(&mut rng), Fp::random(&mut rng));
  92. let index_1 = Fp::from_u128(index as u128) - index_2 - index_3;
  93. // check for <c> = <0>
  94. let h1 = run_mask_index::<MaskIndexProtocol, _>(comm_1, index_bits, index_1);
  95. let h2 = run_mask_index::<MaskIndexProtocol, _>(comm_2, index_bits, index_2);
  96. let h3 = run_mask_index::<MaskIndexProtocol, _>(comm_3, index_bits, index_3);
  97. let (_, (mi_1, m3_1, m2_1)) = h1.join().unwrap();
  98. let (_, (mi_2, m1_2, m3_2)) = h2.join().unwrap();
  99. let (_, (mi_3, m2_3, m1_3)) = h3.join().unwrap();
  100. assert_eq!(m1_2, m1_3);
  101. assert_eq!(m2_1, m2_3);
  102. assert_eq!(m3_1, m3_2);
  103. assert_eq!(m1_2, m1_3);
  104. assert_eq!(mi_1, (index as u16).wrapping_add(m1_2) & bit_mask);
  105. assert_eq!(mi_2, (index as u16).wrapping_add(m2_1) & bit_mask);
  106. assert_eq!(mi_3, (index as u16).wrapping_add(m3_1) & bit_mask);
  107. }
  108. }