unix.rs 3.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. //! Functionality for communicators using Unix sockets.
  2. use crate::AbstractCommunicator;
  3. use crate::Communicator;
  4. use std::collections::HashMap;
  5. use std::os::unix::net::UnixStream;
  6. /// Create a set of connected Communicators that are based on local Unix sockets
  7. pub fn make_unix_communicators(num_parties: usize) -> Vec<impl AbstractCommunicator> {
  8. // prepare maps for each parties to store readers and writers to every other party
  9. let mut rw_maps: Vec<_> = (0..num_parties)
  10. .map(|_| HashMap::with_capacity(num_parties - 1))
  11. .collect();
  12. // create pairs of unix sockets connecting each pair of parties
  13. for party_i in 0..num_parties {
  14. for party_j in 0..party_i {
  15. let (stream_i_to_j, stream_j_to_i) = UnixStream::pair().unwrap();
  16. rw_maps[party_i].insert(party_j, (stream_i_to_j.try_clone().unwrap(), stream_i_to_j));
  17. rw_maps[party_j].insert(party_i, (stream_j_to_i.try_clone().unwrap(), stream_j_to_i));
  18. }
  19. }
  20. // create communicators from the reader/writer maps
  21. rw_maps
  22. .into_iter()
  23. .enumerate()
  24. .map(|(party_id, rw_map)| Communicator::from_reader_writer(num_parties, party_id, rw_map))
  25. .collect()
  26. }
  #[cfg(test)]
  mod tests {
      use super::*;
      use crate::Fut;
      use std::thread;

      /// Three parties, each on its own thread. Every party posts receives for
      /// both peers, sends its own message to both peers, then checks that
      /// what it received matches the peers' messages.
      #[test]
      fn test_unix_communicators() {
          let num_parties = 3;
          let msg_0: u8 = 42;
          let msg_1: u32 = 0x_dead_beef;
          let msg_2: [u32; 2] = [0x_1333_3337, 0x_c0ff_ffee];

          let handles: Vec<_> = make_unix_communicators(num_parties)
              .into_iter()
              .enumerate()
              .map(|(id, mut comm)| {
                  thread::spawn(move || {
                      match id {
                          0 => {
                              let from_1 = comm.receive::<u32>(1).unwrap();
                              let from_2 = comm.receive::<[u32; 2]>(2).unwrap();
                              comm.send(1, msg_0).unwrap();
                              comm.send(2, msg_0).unwrap();
                              let (got_1, got_2) = (from_1.get(), from_2.get());
                              assert!(got_1.is_ok() && got_2.is_ok());
                              assert_eq!(got_1.unwrap(), msg_1);
                              assert_eq!(got_2.unwrap(), msg_2);
                          }
                          1 => {
                              let from_0 = comm.receive::<u8>(0).unwrap();
                              let from_2 = comm.receive::<[u32; 2]>(2).unwrap();
                              comm.send(0, msg_1).unwrap();
                              comm.send(2, msg_1).unwrap();
                              let (got_0, got_2) = (from_0.get(), from_2.get());
                              assert!(got_0.is_ok() && got_2.is_ok());
                              assert_eq!(got_0.unwrap(), msg_0);
                              assert_eq!(got_2.unwrap(), msg_2);
                          }
                          2 => {
                              let from_0 = comm.receive::<u8>(0).unwrap();
                              let from_1 = comm.receive::<u32>(1).unwrap();
                              comm.send(0, msg_2).unwrap();
                              comm.send(1, msg_2).unwrap();
                              let (got_0, got_1) = (from_0.get(), from_1.get());
                              assert!(got_0.is_ok() && got_1.is_ok());
                              assert_eq!(got_0.unwrap(), msg_0);
                              assert_eq!(got_1.unwrap(), msg_1);
                          }
                          // Only parties 0..3 exist in this test.
                          _ => unreachable!(),
                      }
                      comm.shutdown();
                  })
              })
              .collect();

          for handle in handles {
              handle.join().unwrap();
          }
      }
  }