communicator.rs 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. use crate::fut::{BytesFut, MyFut, MyMultiFut};
  2. use crate::AbstractCommunicator;
  3. use crate::Serializable;
  4. use std::collections::HashMap;
  5. use std::fmt::Debug;
  6. use std::io::{Read, Write};
  7. use std::sync::mpsc::{channel, sync_channel, Sender, SyncSender};
  8. use std::thread;
  9. /// Thread to receive messages in the background.
  10. #[derive(Clone, Debug)]
  11. struct ReceiverThread {
  12. data_request_tx: Sender<(usize, SyncSender<Vec<u8>>)>,
  13. }
  14. impl ReceiverThread {
  15. pub fn from_reader<R: Debug + Read + Send + 'static>(mut reader: R) -> Self {
  16. let (data_request_tx, data_request_rx) = channel::<(usize, SyncSender<Vec<u8>>)>();
  17. let _join_handle = thread::spawn(move || {
  18. for (size, sender) in data_request_rx.iter() {
  19. let mut buf = vec![0u8; size];
  20. reader.read_exact(&mut buf).expect("read failed");
  21. sender.send(buf).expect("send failed");
  22. }
  23. });
  24. Self { data_request_tx }
  25. }
  26. pub fn receive_bytes(&mut self, size: usize) -> BytesFut {
  27. let (data_tx, data_rx) = sync_channel(1);
  28. self.data_request_tx
  29. .send((size, data_tx))
  30. .expect("send failed");
  31. BytesFut { size, data_rx }
  32. }
  33. }
  /// Background worker that owns the write half of a connection and forwards
  /// queued buffers to it one at a time.
  #[derive(Clone, Debug)]
  struct SenderThread {
      /// Queue feeding the worker; each buffer is written in full and then flushed.
      data_tx: Sender<Vec<u8>>,
  }

  impl SenderThread {
      /// Spawns a detached thread that writes every queued buffer to `writer`.
      ///
      /// Each buffer is followed by a flush. Once every clone of the returned handle has
      /// been dropped, the worker flushes `writer` one last time and exits.
      ///
      /// # Panics
      /// The worker thread panics if a write or a flush fails.
      pub fn from_writer<W: Debug + Write + Send + 'static>(mut writer: W) -> Self {
          let (queue_tx, queue_rx) = channel::<Vec<u8>>();
          let _worker = thread::spawn(move || {
              while let Ok(chunk) = queue_rx.recv() {
                  writer.write_all(&chunk).expect("write failed");
                  writer.flush().expect("flush failed");
              }
              writer.flush().expect("flush failed");
          });
          SenderThread { data_tx: queue_tx }
      }

      /// Hands `buf` to the worker thread for writing.
      ///
      /// # Panics
      /// Panics if the worker thread is no longer running.
      pub fn send_bytes(&mut self, buf: Vec<u8>) {
          let queue = &self.data_tx;
          queue.send(buf).expect("send failed");
      }
  }
  55. /// Communicator that uses background threads to send and receive messages.
  56. #[derive(Clone, Debug)]
  57. pub struct Communicator {
  58. num_parties: usize,
  59. my_id: usize,
  60. receiver_threads: HashMap<usize, ReceiverThread>,
  61. sender_threads: HashMap<usize, SenderThread>,
  62. }
  63. impl Communicator {
  64. /// Create a new Communicator from a collection of readers and writers that are connected with
  65. /// the other parties.
  66. pub fn from_reader_writer<
  67. R: Read + Send + Debug + 'static,
  68. W: Send + Write + Debug + 'static,
  69. >(
  70. num_parties: usize,
  71. my_id: usize,
  72. mut rw_map: HashMap<usize, (R, W)>,
  73. ) -> Self {
  74. assert_eq!(rw_map.len(), num_parties - 1);
  75. assert!((0..num_parties)
  76. .filter(|&pid| pid != my_id)
  77. .all(|pid| rw_map.contains_key(&pid)));
  78. let mut receiver_threads = HashMap::with_capacity(num_parties - 1);
  79. let mut sender_threads = HashMap::with_capacity(num_parties - 1);
  80. for pid in 0..num_parties {
  81. if pid == my_id {
  82. continue;
  83. }
  84. let (reader, writer) = rw_map.remove(&pid).unwrap();
  85. receiver_threads.insert(pid, ReceiverThread::from_reader(reader));
  86. sender_threads.insert(pid, SenderThread::from_writer(writer));
  87. }
  88. Self {
  89. num_parties,
  90. my_id,
  91. receiver_threads,
  92. sender_threads,
  93. }
  94. }
  95. }
  96. impl AbstractCommunicator for Communicator {
  97. type Fut<T: Serializable> = MyFut<T>;
  98. type MultiFut<T: Serializable> = MyMultiFut<T>;
  99. fn get_num_parties(&self) -> usize {
  100. self.num_parties
  101. }
  102. fn get_my_id(&self) -> usize {
  103. self.my_id
  104. }
  105. fn send<T: Serializable>(&mut self, party_id: usize, val: T) {
  106. self.sender_threads
  107. .get_mut(&party_id)
  108. .expect(&format!("SenderThread for party {} not found", party_id))
  109. .send_bytes(val.to_bytes())
  110. }
  111. fn send_slice<T: Serializable>(&mut self, party_id: usize, val: &[T]) {
  112. let mut bytes = vec![0u8; val.len() * T::bytes_required()];
  113. for (i, v) in val.iter().enumerate() {
  114. bytes[i * T::bytes_required()..(i + 1) * T::bytes_required()]
  115. .copy_from_slice(&v.to_bytes());
  116. }
  117. self.sender_threads
  118. .get_mut(&party_id)
  119. .expect(&format!("SenderThread for party {} not found", party_id))
  120. .send_bytes(bytes);
  121. }
  122. fn receive<T: Serializable>(&mut self, party_id: usize) -> Self::Fut<T> {
  123. let bytes_fut = self
  124. .receiver_threads
  125. .get_mut(&party_id)
  126. .expect(&format!("ReceiverThread for party {} not found", party_id))
  127. .receive_bytes(T::bytes_required());
  128. MyFut::new(bytes_fut)
  129. }
  130. fn receive_n<T: Serializable>(&mut self, party_id: usize, n: usize) -> Self::MultiFut<T> {
  131. let bytes_fut = self
  132. .receiver_threads
  133. .get_mut(&party_id)
  134. .expect(&format!("ReceiverThread for party {} not found", party_id))
  135. .receive_bytes(n * T::bytes_required());
  136. MyMultiFut::new(n, bytes_fut)
  137. }
  138. fn shutdown(&mut self) {
  139. self.sender_threads.drain();
  140. self.receiver_threads.drain();
  141. }
  142. }