unix.rs 3.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. use crate::communicator::Communicator;
  2. use crate::AbstractCommunicator;
  3. use std::collections::HashMap;
  4. use std::os::unix::net::UnixStream;
  5. /// Create a set of connected Communicators that are based on local Unix sockets
  6. pub fn make_unix_communicators(num_parties: usize) -> Vec<impl AbstractCommunicator> {
  7. // prepare maps for each parties to store readers and writers to every other party
  8. let mut rw_maps: Vec<_> = (0..num_parties)
  9. .map(|_| HashMap::with_capacity(num_parties - 1))
  10. .collect();
  11. // create pairs of unix sockets connecting each pair of parties
  12. for party_i in 0..num_parties {
  13. for party_j in 0..party_i {
  14. let (stream_i_to_j, stream_j_to_i) = UnixStream::pair().unwrap();
  15. rw_maps[party_i].insert(party_j, (stream_i_to_j.try_clone().unwrap(), stream_i_to_j));
  16. rw_maps[party_j].insert(party_i, (stream_j_to_i.try_clone().unwrap(), stream_j_to_i));
  17. }
  18. }
  19. // create communicators from the reader/writer maps
  20. rw_maps
  21. .into_iter()
  22. .enumerate()
  23. .map(|(party_id, rw_map)| Communicator::from_reader_writer(num_parties, party_id, rw_map))
  24. .collect()
  25. }
  #[cfg(test)]
  mod tests {
      use super::*;
      use crate::Fut;
      use std::thread;

      /// Three parties, each on its own thread, send one message of a party-specific
      /// type to both peers and check the values they receive from those peers.
      #[test]
      fn test_unix_communicators() {
          let num_parties = 3;
          let msg_0: u8 = 42;
          let msg_1: u32 = 0x_dead_beef;
          let msg_2: [u32; 2] = [0x_1333_3337, 0x_c0ff_ffee];

          let mut workers = Vec::new();
          for (id, mut comm) in make_unix_communicators(num_parties).into_iter().enumerate() {
              workers.push(thread::spawn(move || {
                  match id {
                      0 => {
                          // Party 0 sends a u8; expects a u32 from 1 and a [u32; 2] from 2.
                          let from_1 = comm.receive::<u32>(1);
                          let from_2 = comm.receive::<[u32; 2]>(2);
                          comm.send(1, msg_0);
                          comm.send(2, msg_0);
                          let (got_1, got_2) = (from_1.get(), from_2.get());
                          assert!(got_1.is_ok() && got_2.is_ok());
                          assert_eq!(got_1.unwrap(), msg_1);
                          assert_eq!(got_2.unwrap(), msg_2);
                      }
                      1 => {
                          // Party 1 sends a u32; expects a u8 from 0 and a [u32; 2] from 2.
                          let from_0 = comm.receive::<u8>(0);
                          let from_2 = comm.receive::<[u32; 2]>(2);
                          comm.send(0, msg_1);
                          comm.send(2, msg_1);
                          let (got_0, got_2) = (from_0.get(), from_2.get());
                          assert!(got_0.is_ok() && got_2.is_ok());
                          assert_eq!(got_0.unwrap(), msg_0);
                          assert_eq!(got_2.unwrap(), msg_2);
                      }
                      2 => {
                          // Party 2 sends a [u32; 2]; expects a u8 from 0 and a u32 from 1.
                          let from_0 = comm.receive::<u8>(0);
                          let from_1 = comm.receive::<u32>(1);
                          comm.send(0, msg_2);
                          comm.send(1, msg_2);
                          let (got_0, got_1) = (from_0.get(), from_1.get());
                          assert!(got_0.is_ok() && got_1.is_ok());
                          assert_eq!(got_0.unwrap(), msg_0);
                          assert_eq!(got_1.unwrap(), msg_1);
                      }
                      // Only `num_parties` (= 3) communicators are created, so ids are 0..3.
                      _ => unreachable!("unexpected party id {}", id),
                  }
                  comm.shutdown();
              }));
          }

          for worker in workers {
              worker.join().unwrap();
          }
      }
  }