mask_index.rs 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. use crate::common::Error;
  2. use communicator::{AbstractCommunicator, Fut, Serializable};
  3. use ff::PrimeField;
  4. use rand::{thread_rng, Rng};
  5. pub trait MaskIndex<F> {
  6. fn mask_index<C: AbstractCommunicator>(
  7. comm: &mut C,
  8. index_bits: u32,
  9. index_share: F,
  10. ) -> Result<(u16, u16, u16), Error>;
  11. }
  12. pub struct MaskIndexProtocol {}
  13. impl<F> MaskIndex<F> for MaskIndexProtocol
  14. where
  15. F: PrimeField + Serializable,
  16. {
  17. fn mask_index<C: AbstractCommunicator>(
  18. comm: &mut C,
  19. index_bits: u32,
  20. index_share: F,
  21. ) -> Result<(u16, u16, u16), Error> {
  22. let random_bits = index_bits + 40;
  23. assert!(random_bits + 1 < F::NUM_BITS);
  24. assert!(index_bits <= 16);
  25. let bit_mask = (1 << index_bits) - 1;
  26. let fut_prev = comm.receive_previous::<F>()?;
  27. let fut_next = comm.receive_next::<(u16, F)>()?;
  28. // sample mask r_{i+1} and send it to P_{i-1}
  29. let r_next: u128 = thread_rng().gen_range(0..(1 << random_bits));
  30. // send masked share to P_{i+1}
  31. comm.send_next(index_share + F::from_u128(r_next))?;
  32. let r_next = (r_next & bit_mask) as u16;
  33. // send mask and our share to P_{i-1}
  34. comm.send_previous((r_next, index_share))?;
  35. let index_masked_prev_share = fut_prev.get()?;
  36. let (r_prev, index_next_share) = fut_next.get()?;
  37. let masked_index = index_share + index_next_share + index_masked_prev_share;
  38. let masked_index =
  39. u64::from_le_bytes(masked_index.to_repr().as_ref()[..8].try_into().unwrap());
  40. let masked_index = masked_index as u16 & bit_mask as u16;
  41. Ok((masked_index, r_prev, r_next))
  42. }
  43. }
  #[cfg(test)]
mod tests {
    use super::*;
    use communicator::unix::make_unix_communicators;
    use ff::Field;
    use std::thread;
    use utils::field::Fp;

    /// Runs `Proto::mask_index` for one party on its own thread.
    ///
    /// The handle yields the communicator together with the unwrapped result.
    fn run_mask_index<Proto: MaskIndex<F>, F>(
        mut comm: impl AbstractCommunicator + Send + 'static,
        index_bits: u32,
        index_share: F,
    ) -> thread::JoinHandle<(impl AbstractCommunicator, (u16, u16, u16))>
    where
        F: PrimeField + Serializable,
    {
        thread::spawn(move || {
            let result = Proto::mask_index(&mut comm, index_bits, index_share);
            (comm, result.unwrap())
        })
    }

    /// Checks one run of the protocol for a random `index_bits`-bit index.
    ///
    /// The index is secret-shared among three parties and the protocol is run on all of them.
    /// The check covers mask agreement, output range and the masked-index relation.
    fn check_mask_index(index_bits: u32) {
        let (comm_3, comm_2, comm_1) = {
            let mut comms = make_unix_communicators(3);
            (
                comms.pop().unwrap(),
                comms.pop().unwrap(),
                comms.pop().unwrap(),
            )
        };
        let mut rng = thread_rng();
        let bit_mask = ((1u32 << index_bits) - 1) as u16;
        let index: u32 = rng.gen_range(0..(1u32 << index_bits));
        let (index_2, index_3) = (Fp::random(&mut rng), Fp::random(&mut rng));
        let index_1 = Fp::from_u128(u128::from(index)) - index_2 - index_3;

        let h1 = run_mask_index::<MaskIndexProtocol, _>(comm_1, index_bits, index_1);
        let h2 = run_mask_index::<MaskIndexProtocol, _>(comm_2, index_bits, index_2);
        let h3 = run_mask_index::<MaskIndexProtocol, _>(comm_3, index_bits, index_3);

        // `miY` is party Y's masked index.
        // `mX_Y` is mask X as returned by party Y; mask X is the one in party X's masked index.
        let (_, (mi_1, m3_1, m2_1)) = h1.join().unwrap();
        let (_, (mi_2, m1_2, m3_2)) = h2.join().unwrap();
        let (_, (mi_3, m2_3, m1_3)) = h3.join().unwrap();

        // Each mask is returned by two parties, who must agree on it.
        assert_eq!(m1_2, m1_3);
        assert_eq!(m2_1, m2_3);
        assert_eq!(m3_1, m3_2);

        // Every output is reduced to `index_bits` bits.
        for &value in &[mi_1, mi_2, mi_3, m1_2, m2_1, m3_1] {
            assert!(
                value <= bit_mask,
                "value {} exceeds the {}-bit range",
                value,
                index_bits
            );
        }

        // Party Y learns the index masked with mask Y, modulo 2^index_bits.
        assert_eq!(mi_1, (index as u16).wrapping_add(m1_2) & bit_mask);
        assert_eq!(mi_2, (index as u16).wrapping_add(m2_1) & bit_mask);
        assert_eq!(mi_3, (index as u16).wrapping_add(m3_1) & bit_mask);
    }

    #[test]
    fn test_mask_index() {
        // Widths below 16 exercise the reduction by `bit_mask`, not just the u16 truncation.
        for &index_bits in &[1u32, 8, 16] {
            check_mask_index(index_bits);
        }
    }
}