Browse Source

communicator: return communication stats

Lennart Braun 2 years ago
parent
commit
ab17148adc
2 changed files with 57 additions and 18 deletions
  1. 46 16
      communicator/src/communicator.rs
  2. 11 2
      communicator/src/lib.rs

+ 46 - 16
communicator/src/communicator.rs

@@ -1,4 +1,4 @@
-use crate::{AbstractCommunicator, Error, Fut, Serializable};
+use crate::{AbstractCommunicator, CommunicationStats, Error, Fut, Serializable};
 use bincode;
 use std::collections::HashMap;
 use std::fmt::Debug;
@@ -38,7 +38,7 @@ impl<T: Serializable> Fut<T> for MyFut<T> {
 #[derive(Debug)]
 struct ReceiverThread {
     buf_rx: Arc<Mutex<Receiver<Vec<u8>>>>,
-    join_handle: thread::JoinHandle<Result<(), Error>>,
+    join_handle: thread::JoinHandle<Result<(usize, usize), Error>>,
 }
 
 impl ReceiverThread {
@@ -49,19 +49,23 @@ impl ReceiverThread {
         let join_handle = thread::Builder::new()
             .name("Receiver".to_owned())
             .spawn(move || {
+                let mut num_msgs_received = 0;
+                let mut num_bytes_received = 0;
                 loop {
                     let mut msg_size = [0u8; 4];
                     reader.read_exact(&mut msg_size)?;
                     let msg_size = u32::from_be_bytes(msg_size) as usize;
                     if msg_size == 0xffffffff {
-                        return Ok(());
+                        return Ok((num_msgs_received, num_bytes_received));
                     }
                     let mut buf = vec![0u8; msg_size];
                     reader.read_exact(&mut buf)?;
                     match buf_tx.send(buf) {
                         Ok(_) => (),
-                        Err(_) => return Ok(()), // we need to shutdown
+                        Err(_) => return Ok((num_msgs_received, num_bytes_received)), // we need to shutdown
                     }
+                    num_msgs_received += 1;
+                    num_bytes_received += 4 + msg_size;
                 }
             })
             .unwrap();
@@ -75,10 +79,9 @@ impl ReceiverThread {
         Ok(MyFut::new(self.buf_rx.clone()))
     }
 
-    pub fn join(self) -> Result<(), Error> {
+    pub fn join(self) -> Result<(usize, usize), Error> {
         drop(self.buf_rx);
-        self.join_handle.join().expect("join failed")?;
-        Ok(())
+        self.join_handle.join().expect("join failed")
     }
 }
 
@@ -86,7 +89,7 @@ impl ReceiverThread {
 #[derive(Debug)]
 struct SenderThread {
     buf_tx: Sender<Vec<u8>>,
-    join_handle: thread::JoinHandle<Result<(), Error>>,
+    join_handle: thread::JoinHandle<Result<(usize, usize), Error>>,
 }
 
 impl SenderThread {
@@ -96,14 +99,18 @@ impl SenderThread {
         let join_handle = thread::Builder::new()
             .name("Sender-1".to_owned())
             .spawn(move || {
+                let mut num_msgs_sent = 0;
+                let mut num_bytes_sent = 0;
                 for buf in buf_rx.iter() {
                     writer.write_all(&((buf.len() as u32).to_be_bytes()))?;
                     writer.write_all(&buf)?;
                     writer.flush()?;
+                    num_msgs_sent += 1;
+                    num_bytes_sent += 4 + buf.len();
                 }
                 writer.write_all(&[0xff, 0xff, 0xff, 0xff])?;
                 writer.flush()?;
-                Ok(())
+                Ok((num_msgs_sent, num_bytes_sent))
             })
             .unwrap();
         Self {
@@ -119,7 +126,7 @@ impl SenderThread {
         Ok(())
     }
 
-    pub fn join(self) -> Result<(), Error> {
+    pub fn join(self) -> Result<(usize, usize), Error> {
         drop(self.buf_tx);
         self.join_handle.join().expect("join failed")
     }
@@ -205,12 +212,35 @@ impl AbstractCommunicator for Communicator {
         }
     }
 
-    fn shutdown(&mut self) {
-        self.sender_threads
+    fn shutdown(&mut self) -> HashMap<usize, CommunicationStats> {
+        let mut comm_stats: HashMap<usize, CommunicationStats> = self
+            .sender_threads
             .drain()
-            .for_each(|(_, t)| t.join().unwrap());
-        self.receiver_threads
-            .drain()
-            .for_each(|(_, t)| t.join().unwrap());
+            .map(|(party_id, t)| {
+                (party_id, {
+                    let (num_msgs_sent, num_bytes_sent) = t
+                        .join()
+                        .expect(&format!("join of sender thread {party_id} failed"));
+                    CommunicationStats {
+                        num_msgs_sent,
+                        num_bytes_sent,
+                        num_msgs_received: 0,
+                        num_bytes_received: 0,
+                    }
+                })
+            })
+            .collect();
+        self.receiver_threads.drain().for_each(|(party_id, t)| {
+            let (num_msgs_received, num_bytes_received) = t
+                .join()
+                .expect(&format!("join of receiver thread {party_id} failed"));
+            let cs = comm_stats
+                .get_mut(&party_id)
+                .expect(&format!("no comm stats for party {party_id} found"));
+            cs.num_msgs_received = num_msgs_received;
+            cs.num_bytes_received = num_bytes_received;
+        });
+
+        comm_stats
     }
 }

+ 11 - 2
communicator/src/lib.rs

@@ -2,7 +2,8 @@ pub mod communicator;
 pub mod tcp;
 pub mod unix;
 
-use bincode::error::{EncodeError, DecodeError};
+use bincode::error::{DecodeError, EncodeError};
+use std::collections::HashMap;
 use std::io::Error as IoError;
 use std::sync::mpsc::{RecvError, SendError};
 
@@ -16,6 +17,14 @@ pub trait Fut<T> {
     fn get(self) -> Result<T, Error>;
 }
 
+#[derive(Debug, Clone, Copy)]
+pub struct CommunicationStats {
+    pub num_msgs_received: usize,
+    pub num_bytes_received: usize,
+    pub num_msgs_sent: usize,
+    pub num_bytes_sent: usize,
+}
+
 /// Abstract communication interface between multiple parties
 pub trait AbstractCommunicator {
     type Fut<T: Serializable>: Fut<T>;
@@ -71,7 +80,7 @@ pub trait AbstractCommunicator {
     }
 
     /// Shutdown the communication system
-    fn shutdown(&mut self);
+    fn shutdown(&mut self) -> HashMap<usize, CommunicationStats>;
 }
 
 /// Custom error type