communicator.rs 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. use crate::{AbstractCommunicator, CommunicationStats, Error, Fut, Serializable};
  2. use std::collections::HashMap;
  3. use std::fmt::Debug;
  4. use std::io::{BufReader, BufWriter, Read, Write};
  5. use std::marker::PhantomData;
  6. use std::sync::mpsc::{channel, Receiver, Sender};
  7. use std::sync::{Arc, Mutex};
  8. use std::thread;
  9. /// Future type used by [`Communicator`].
  10. pub struct MyFut<T: Serializable> {
  11. buf_rx: Arc<Mutex<Receiver<Vec<u8>>>>,
  12. _phantom: PhantomData<T>,
  13. }
  14. impl<T: Serializable> MyFut<T> {
  15. fn new(buf_rx: Arc<Mutex<Receiver<Vec<u8>>>>) -> Self {
  16. Self {
  17. buf_rx,
  18. _phantom: PhantomData,
  19. }
  20. }
  21. }
  22. impl<T: Serializable> Fut<T> for MyFut<T> {
  23. fn get(self) -> Result<T, Error> {
  24. let buf = self.buf_rx.lock().unwrap().recv()?;
  25. let (data, size) = bincode::decode_from_slice(
  26. &buf,
  27. bincode::config::standard().skip_fixed_array_length(),
  28. )?;
  29. assert_eq!(size, buf.len());
  30. Ok(data)
  31. }
  32. }
  33. /// Thread to receive messages in the background.
  34. #[derive(Debug)]
  35. struct ReceiverThread {
  36. buf_rx: Arc<Mutex<Receiver<Vec<u8>>>>,
  37. join_handle: thread::JoinHandle<Result<(), Error>>,
  38. stats: Arc<Mutex<[usize; 2]>>,
  39. }
  40. impl ReceiverThread {
  41. pub fn from_reader<R: Debug + Read + Send + 'static>(reader: R) -> Self {
  42. let mut reader = BufReader::with_capacity(1 << 16, reader);
  43. let (buf_tx, buf_rx) = channel::<Vec<u8>>();
  44. let buf_rx = Arc::new(Mutex::new(buf_rx));
  45. let stats = Arc::new(Mutex::new([0usize; 2]));
  46. let stats_clone = stats.clone();
  47. let join_handle = thread::Builder::new()
  48. .name("Receiver".to_owned())
  49. .spawn(move || {
  50. loop {
  51. let mut msg_size = [0u8; 4];
  52. reader.read_exact(&mut msg_size)?;
  53. let msg_size = u32::from_be_bytes(msg_size) as usize;
  54. if msg_size == 0xffffffff {
  55. return Ok(());
  56. }
  57. let mut buf = vec![0u8; msg_size];
  58. reader.read_exact(&mut buf)?;
  59. match buf_tx.send(buf) {
  60. Ok(_) => (),
  61. Err(_) => return Ok(()), // we need to shutdown
  62. }
  63. {
  64. let mut guard = stats.lock().unwrap();
  65. guard[0] += 1;
  66. guard[1] += 4 + msg_size;
  67. }
  68. }
  69. })
  70. .unwrap();
  71. Self {
  72. join_handle,
  73. buf_rx,
  74. stats: stats_clone,
  75. }
  76. }
  77. pub fn receive<T: Serializable>(&mut self) -> Result<MyFut<T>, Error> {
  78. Ok(MyFut::new(self.buf_rx.clone()))
  79. }
  80. pub fn join(self) -> Result<(), Error> {
  81. drop(self.buf_rx);
  82. self.join_handle.join().expect("join failed")
  83. }
  84. pub fn get_stats(&self) -> [usize; 2] {
  85. *self.stats.lock().unwrap()
  86. }
  87. pub fn reset_stats(&mut self) {
  88. *self.stats.lock().unwrap() = [0usize; 2];
  89. }
  90. }
  91. /// Thread to send messages in the background.
  92. #[derive(Debug)]
  93. struct SenderThread {
  94. buf_tx: Sender<Vec<u8>>,
  95. join_handle: thread::JoinHandle<Result<(), Error>>,
  96. }
  97. impl SenderThread {
  98. pub fn from_writer<W: Debug + Write + Send + 'static>(writer: W) -> Self {
  99. let mut writer = BufWriter::with_capacity(1 << 16, writer);
  100. let (buf_tx, buf_rx) = channel::<Vec<u8>>();
  101. let join_handle = thread::Builder::new()
  102. .name("Sender-1".to_owned())
  103. .spawn(move || {
  104. for buf in buf_rx.iter() {
  105. writer.write_all(&((buf.len() as u32).to_be_bytes()))?;
  106. debug_assert!(buf.len() <= u32::MAX as usize);
  107. writer.write_all(&buf)?;
  108. writer.flush()?;
  109. }
  110. writer.write_all(&[0xff, 0xff, 0xff, 0xff])?;
  111. writer.flush()?;
  112. Ok(())
  113. })
  114. .unwrap();
  115. Self {
  116. buf_tx,
  117. join_handle,
  118. }
  119. }
  120. pub fn send<T: Serializable>(&mut self, data: T) -> Result<usize, Error> {
  121. let buf =
  122. bincode::encode_to_vec(data, bincode::config::standard().skip_fixed_array_length())?;
  123. let num_bytes = 4 + buf.len();
  124. self.buf_tx.send(buf)?;
  125. Ok(num_bytes)
  126. }
  127. pub fn send_slice<T: Serializable>(&mut self, data: &[T]) -> Result<usize, Error> {
  128. let buf =
  129. bincode::encode_to_vec(data, bincode::config::standard().skip_fixed_array_length())?;
  130. let num_bytes = 4 + buf.len();
  131. self.buf_tx.send(buf)?;
  132. Ok(num_bytes)
  133. }
  134. pub fn join(self) -> Result<(), Error> {
  135. drop(self.buf_tx);
  136. self.join_handle.join().expect("join failed")
  137. }
  138. }
  139. /// Implementation of [`AbstractCommunicator`] that uses background threads to send and receive
  140. /// messages.
  141. #[derive(Debug)]
  142. pub struct Communicator {
  143. num_parties: usize,
  144. my_id: usize,
  145. comm_stats: HashMap<usize, CommunicationStats>,
  146. receiver_threads: HashMap<usize, ReceiverThread>,
  147. sender_threads: HashMap<usize, SenderThread>,
  148. }
  149. impl Communicator {
  150. /// Create a new Communicator from a collection of readers and writers that are connected with
  151. /// the other parties.
  152. pub fn from_reader_writer<
  153. R: Read + Send + Debug + 'static,
  154. W: Send + Write + Debug + 'static,
  155. >(
  156. num_parties: usize,
  157. my_id: usize,
  158. mut rw_map: HashMap<usize, (R, W)>,
  159. ) -> Self {
  160. assert_eq!(rw_map.len(), num_parties - 1);
  161. assert!((0..num_parties)
  162. .filter(|&pid| pid != my_id)
  163. .all(|pid| rw_map.contains_key(&pid)));
  164. let mut receiver_threads = HashMap::with_capacity(num_parties - 1);
  165. let mut sender_threads = HashMap::with_capacity(num_parties - 1);
  166. for pid in 0..num_parties {
  167. if pid == my_id {
  168. continue;
  169. }
  170. let (reader, writer) = rw_map.remove(&pid).unwrap();
  171. receiver_threads.insert(pid, ReceiverThread::from_reader(reader));
  172. sender_threads.insert(pid, SenderThread::from_writer(writer));
  173. }
  174. let comm_stats = (0..num_parties)
  175. .filter_map(|party_id| {
  176. if party_id == my_id {
  177. None
  178. } else {
  179. Some((party_id, Default::default()))
  180. }
  181. })
  182. .collect();
  183. Self {
  184. num_parties,
  185. my_id,
  186. comm_stats,
  187. receiver_threads,
  188. sender_threads,
  189. }
  190. }
  191. }
  192. impl AbstractCommunicator for Communicator {
  193. type Fut<T: Serializable> = MyFut<T>;
  194. fn get_num_parties(&self) -> usize {
  195. self.num_parties
  196. }
  197. fn get_my_id(&self) -> usize {
  198. self.my_id
  199. }
  200. fn send<T: Serializable>(&mut self, party_id: usize, val: T) -> Result<(), Error> {
  201. match self.sender_threads.get_mut(&party_id) {
  202. Some(t) => {
  203. let num_bytes = t.send(val)?;
  204. let cs = self.comm_stats.get_mut(&party_id).unwrap();
  205. cs.num_bytes_sent += num_bytes;
  206. cs.num_msgs_sent += 1;
  207. Ok(())
  208. }
  209. None => Err(Error::LogicError(format!(
  210. "SenderThread for party {party_id} not found"
  211. ))),
  212. }
  213. }
  214. fn send_slice<T: Serializable>(&mut self, party_id: usize, val: &[T]) -> Result<(), Error> {
  215. match self.sender_threads.get_mut(&party_id) {
  216. Some(t) => {
  217. let num_bytes = t.send_slice(val)?;
  218. let cs = self.comm_stats.get_mut(&party_id).unwrap();
  219. cs.num_bytes_sent += num_bytes;
  220. cs.num_msgs_sent += 1;
  221. Ok(())
  222. }
  223. None => Err(Error::LogicError(format!(
  224. "SenderThread for party {party_id} not found"
  225. ))),
  226. }
  227. }
  228. fn receive<T: Serializable>(&mut self, party_id: usize) -> Result<Self::Fut<T>, Error> {
  229. match self.receiver_threads.get_mut(&party_id) {
  230. Some(t) => t.receive::<T>(),
  231. None => Err(Error::LogicError(format!(
  232. "ReceiverThread for party {party_id} not found"
  233. ))),
  234. }
  235. }
  236. fn shutdown(&mut self) {
  237. self.sender_threads.drain().for_each(|(party_id, t)| {
  238. t.join()
  239. .unwrap_or_else(|_| panic!("join of sender thread {party_id} failed"))
  240. });
  241. self.receiver_threads.drain().for_each(|(party_id, t)| {
  242. t.join()
  243. .unwrap_or_else(|_| panic!("join of receiver thread {party_id} failed"))
  244. });
  245. }
  246. fn get_stats(&self) -> HashMap<usize, CommunicationStats> {
  247. let mut cs = self.comm_stats.clone();
  248. self.receiver_threads.iter().for_each(|(party_id, t)| {
  249. let [num_msgs_received, num_bytes_received] = t.get_stats();
  250. let cs_i = cs.get_mut(party_id).unwrap();
  251. cs_i.num_msgs_received = num_msgs_received;
  252. cs_i.num_bytes_received = num_bytes_received;
  253. });
  254. cs
  255. }
  256. fn reset_stats(&mut self) {
  257. self.comm_stats
  258. .iter_mut()
  259. .for_each(|(_, cs)| *cs = Default::default());
  260. self.receiver_threads
  261. .iter_mut()
  262. .for_each(|(_, t)| t.reset_stats());
  263. }
  264. }