communicator.rs 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. use crate::{AbstractCommunicator, CommunicationStats, Error, Fut, Serializable};
  2. use bincode;
  3. use std::collections::HashMap;
  4. use std::fmt::Debug;
  5. use std::io::{BufReader, BufWriter, Read, Write};
  6. use std::marker::PhantomData;
  7. use std::sync::mpsc::{channel, Receiver, Sender};
  8. use std::sync::{Arc, Mutex};
  9. use std::thread;
  10. pub struct MyFut<T: Serializable> {
  11. buf_rx: Arc<Mutex<Receiver<Vec<u8>>>>,
  12. _phantom: PhantomData<T>,
  13. }
  14. impl<T: Serializable> MyFut<T> {
  15. fn new(buf_rx: Arc<Mutex<Receiver<Vec<u8>>>>) -> Self {
  16. Self {
  17. buf_rx,
  18. _phantom: PhantomData,
  19. }
  20. }
  21. }
  22. impl<T: Serializable> Fut<T> for MyFut<T> {
  23. fn get(self) -> Result<T, Error> {
  24. let buf = self.buf_rx.lock().unwrap().recv()?;
  25. let (data, size) = bincode::decode_from_slice(
  26. &buf,
  27. bincode::config::standard().skip_fixed_array_length(),
  28. )?;
  29. assert_eq!(size, buf.len());
  30. Ok(data)
  31. }
  32. }
  33. /// Thread to receive messages in the background.
  34. #[derive(Debug)]
  35. struct ReceiverThread {
  36. buf_rx: Arc<Mutex<Receiver<Vec<u8>>>>,
  37. join_handle: thread::JoinHandle<Result<(usize, usize), Error>>,
  38. }
  39. impl ReceiverThread {
  40. pub fn from_reader<R: Debug + Read + Send + 'static>(reader: R) -> Self {
  41. let mut reader = BufReader::with_capacity(1 << 16, reader);
  42. let (buf_tx, buf_rx) = channel::<Vec<u8>>();
  43. let buf_rx = Arc::new(Mutex::new(buf_rx));
  44. let join_handle = thread::Builder::new()
  45. .name("Receiver".to_owned())
  46. .spawn(move || {
  47. let mut num_msgs_received = 0;
  48. let mut num_bytes_received = 0;
  49. loop {
  50. let mut msg_size = [0u8; 4];
  51. reader.read_exact(&mut msg_size)?;
  52. let msg_size = u32::from_be_bytes(msg_size) as usize;
  53. if msg_size == 0xffffffff {
  54. return Ok((num_msgs_received, num_bytes_received));
  55. }
  56. let mut buf = vec![0u8; msg_size];
  57. reader.read_exact(&mut buf)?;
  58. match buf_tx.send(buf) {
  59. Ok(_) => (),
  60. Err(_) => return Ok((num_msgs_received, num_bytes_received)), // we need to shutdown
  61. }
  62. num_msgs_received += 1;
  63. num_bytes_received += 4 + msg_size;
  64. }
  65. })
  66. .unwrap();
  67. Self {
  68. join_handle,
  69. buf_rx,
  70. }
  71. }
  72. pub fn receive<T: Serializable>(&mut self) -> Result<MyFut<T>, Error> {
  73. Ok(MyFut::new(self.buf_rx.clone()))
  74. }
  75. pub fn join(self) -> Result<(usize, usize), Error> {
  76. drop(self.buf_rx);
  77. self.join_handle.join().expect("join failed")
  78. }
  79. }
  80. /// Thread to send messages in the background.
  81. #[derive(Debug)]
  82. struct SenderThread {
  83. buf_tx: Sender<Vec<u8>>,
  84. join_handle: thread::JoinHandle<Result<(usize, usize), Error>>,
  85. }
  86. impl SenderThread {
  87. pub fn from_writer<W: Debug + Write + Send + 'static>(writer: W) -> Self {
  88. let mut writer = BufWriter::with_capacity(1 << 16, writer);
  89. let (buf_tx, buf_rx) = channel::<Vec<u8>>();
  90. let join_handle = thread::Builder::new()
  91. .name("Sender-1".to_owned())
  92. .spawn(move || {
  93. let mut num_msgs_sent = 0;
  94. let mut num_bytes_sent = 0;
  95. for buf in buf_rx.iter() {
  96. writer.write_all(&((buf.len() as u32).to_be_bytes()))?;
  97. writer.write_all(&buf)?;
  98. writer.flush()?;
  99. num_msgs_sent += 1;
  100. num_bytes_sent += 4 + buf.len();
  101. }
  102. writer.write_all(&[0xff, 0xff, 0xff, 0xff])?;
  103. writer.flush()?;
  104. Ok((num_msgs_sent, num_bytes_sent))
  105. })
  106. .unwrap();
  107. Self {
  108. buf_tx,
  109. join_handle,
  110. }
  111. }
  112. pub fn send<T: Serializable>(&mut self, data: T) -> Result<(), Error> {
  113. let buf =
  114. bincode::encode_to_vec(data, bincode::config::standard().skip_fixed_array_length())?;
  115. self.buf_tx.send(buf)?;
  116. Ok(())
  117. }
  118. pub fn send_slice<T: Serializable>(&mut self, data: &[T]) -> Result<(), Error> {
  119. let buf =
  120. bincode::encode_to_vec(data, bincode::config::standard().skip_fixed_array_length())?;
  121. self.buf_tx.send(buf)?;
  122. Ok(())
  123. }
  124. pub fn join(self) -> Result<(usize, usize), Error> {
  125. drop(self.buf_tx);
  126. self.join_handle.join().expect("join failed")
  127. }
  128. }
  129. /// Communicator that uses background threads to send and receive messages.
  130. #[derive(Debug)]
  131. pub struct Communicator {
  132. num_parties: usize,
  133. my_id: usize,
  134. receiver_threads: HashMap<usize, ReceiverThread>,
  135. sender_threads: HashMap<usize, SenderThread>,
  136. }
  137. impl Communicator {
  138. /// Create a new Communicator from a collection of readers and writers that are connected with
  139. /// the other parties.
  140. pub fn from_reader_writer<
  141. R: Read + Send + Debug + 'static,
  142. W: Send + Write + Debug + 'static,
  143. >(
  144. num_parties: usize,
  145. my_id: usize,
  146. mut rw_map: HashMap<usize, (R, W)>,
  147. ) -> Self {
  148. assert_eq!(rw_map.len(), num_parties - 1);
  149. assert!((0..num_parties)
  150. .filter(|&pid| pid != my_id)
  151. .all(|pid| rw_map.contains_key(&pid)));
  152. let mut receiver_threads = HashMap::with_capacity(num_parties - 1);
  153. let mut sender_threads = HashMap::with_capacity(num_parties - 1);
  154. for pid in 0..num_parties {
  155. if pid == my_id {
  156. continue;
  157. }
  158. let (reader, writer) = rw_map.remove(&pid).unwrap();
  159. receiver_threads.insert(pid, ReceiverThread::from_reader(reader));
  160. sender_threads.insert(pid, SenderThread::from_writer(writer));
  161. }
  162. Self {
  163. num_parties,
  164. my_id,
  165. receiver_threads,
  166. sender_threads,
  167. }
  168. }
  169. }
  170. impl AbstractCommunicator for Communicator {
  171. type Fut<T: Serializable> = MyFut<T>;
  172. fn get_num_parties(&self) -> usize {
  173. self.num_parties
  174. }
  175. fn get_my_id(&self) -> usize {
  176. self.my_id
  177. }
  178. fn send<T: Serializable>(&mut self, party_id: usize, val: T) -> Result<(), Error> {
  179. match self.sender_threads.get_mut(&party_id) {
  180. Some(t) => {
  181. t.send(val)?;
  182. Ok(())
  183. }
  184. None => Err(Error::LogicError(format!(
  185. "SenderThread for party {} not found",
  186. party_id
  187. ))),
  188. }
  189. }
  190. fn send_slice<T: Serializable>(&mut self, party_id: usize, val: &[T]) -> Result<(), Error> {
  191. match self.sender_threads.get_mut(&party_id) {
  192. Some(t) => {
  193. t.send_slice(val)?;
  194. Ok(())
  195. }
  196. None => Err(Error::LogicError(format!(
  197. "SenderThread for party {} not found",
  198. party_id
  199. ))),
  200. }
  201. }
  202. fn receive<T: Serializable>(&mut self, party_id: usize) -> Result<Self::Fut<T>, Error> {
  203. match self.receiver_threads.get_mut(&party_id) {
  204. Some(t) => t.receive::<T>(),
  205. None => Err(Error::LogicError(format!(
  206. "ReceiverThread for party {} not found",
  207. party_id
  208. ))),
  209. }
  210. }
  211. fn shutdown(&mut self) -> HashMap<usize, CommunicationStats> {
  212. let mut comm_stats: HashMap<usize, CommunicationStats> = self
  213. .sender_threads
  214. .drain()
  215. .map(|(party_id, t)| {
  216. (party_id, {
  217. let (num_msgs_sent, num_bytes_sent) = t
  218. .join()
  219. .expect(&format!("join of sender thread {party_id} failed"));
  220. CommunicationStats {
  221. num_msgs_sent,
  222. num_bytes_sent,
  223. num_msgs_received: 0,
  224. num_bytes_received: 0,
  225. }
  226. })
  227. })
  228. .collect();
  229. self.receiver_threads.drain().for_each(|(party_id, t)| {
  230. let (num_msgs_received, num_bytes_received) = t
  231. .join()
  232. .expect(&format!("join of receiver thread {party_id} failed"));
  233. let cs = comm_stats
  234. .get_mut(&party_id)
  235. .expect(&format!("no comm stats for party {party_id} found"));
  236. cs.num_msgs_received = num_msgs_received;
  237. cs.num_bytes_received = num_bytes_received;
  238. });
  239. comm_stats
  240. }
  241. }