|
@@ -1,4 +1,4 @@
|
|
|
-use crate::{AbstractCommunicator, Error, Fut, Serializable};
|
|
|
+use crate::{AbstractCommunicator, CommunicationStats, Error, Fut, Serializable};
|
|
|
use bincode;
|
|
|
use std::collections::HashMap;
|
|
|
use std::fmt::Debug;
|
|
@@ -38,7 +38,7 @@ impl<T: Serializable> Fut<T> for MyFut<T> {
|
|
|
#[derive(Debug)]
|
|
|
struct ReceiverThread {
|
|
|
buf_rx: Arc<Mutex<Receiver<Vec<u8>>>>,
|
|
|
- join_handle: thread::JoinHandle<Result<(), Error>>,
|
|
|
+ join_handle: thread::JoinHandle<Result<(usize, usize), Error>>,
|
|
|
}
|
|
|
|
|
|
impl ReceiverThread {
|
|
@@ -49,19 +49,23 @@ impl ReceiverThread {
|
|
|
let join_handle = thread::Builder::new()
|
|
|
.name("Receiver".to_owned())
|
|
|
.spawn(move || {
|
|
|
+ let mut num_msgs_received = 0;
|
|
|
+ let mut num_bytes_received = 0;
|
|
|
loop {
|
|
|
let mut msg_size = [0u8; 4];
|
|
|
reader.read_exact(&mut msg_size)?;
|
|
|
let msg_size = u32::from_be_bytes(msg_size) as usize;
|
|
|
if msg_size == 0xffffffff {
|
|
|
- return Ok(());
|
|
|
+ return Ok((num_msgs_received, num_bytes_received));
|
|
|
}
|
|
|
let mut buf = vec![0u8; msg_size];
|
|
|
reader.read_exact(&mut buf)?;
|
|
|
match buf_tx.send(buf) {
|
|
|
Ok(_) => (),
|
|
|
- Err(_) => return Ok(()), // we need to shutdown
|
|
|
+ Err(_) => return Ok((num_msgs_received, num_bytes_received)), // we need to shutdown
|
|
|
}
|
|
|
+ num_msgs_received += 1;
|
|
|
+ num_bytes_received += 4 + msg_size;
|
|
|
}
|
|
|
})
|
|
|
.unwrap();
|
|
@@ -75,10 +79,9 @@ impl ReceiverThread {
|
|
|
Ok(MyFut::new(self.buf_rx.clone()))
|
|
|
}
|
|
|
|
|
|
- pub fn join(self) -> Result<(), Error> {
|
|
|
+ pub fn join(self) -> Result<(usize, usize), Error> {
|
|
|
drop(self.buf_rx);
|
|
|
- self.join_handle.join().expect("join failed")?;
|
|
|
- Ok(())
|
|
|
+ self.join_handle.join().expect("join failed")
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -86,7 +89,7 @@ impl ReceiverThread {
|
|
|
#[derive(Debug)]
|
|
|
struct SenderThread {
|
|
|
buf_tx: Sender<Vec<u8>>,
|
|
|
- join_handle: thread::JoinHandle<Result<(), Error>>,
|
|
|
+ join_handle: thread::JoinHandle<Result<(usize, usize), Error>>,
|
|
|
}
|
|
|
|
|
|
impl SenderThread {
|
|
@@ -96,14 +99,18 @@ impl SenderThread {
|
|
|
let join_handle = thread::Builder::new()
|
|
|
.name("Sender-1".to_owned())
|
|
|
.spawn(move || {
|
|
|
+ let mut num_msgs_sent = 0;
|
|
|
+ let mut num_bytes_sent = 0;
|
|
|
for buf in buf_rx.iter() {
|
|
|
writer.write_all(&((buf.len() as u32).to_be_bytes()))?;
|
|
|
writer.write_all(&buf)?;
|
|
|
writer.flush()?;
|
|
|
+ num_msgs_sent += 1;
|
|
|
+ num_bytes_sent += 4 + buf.len();
|
|
|
}
|
|
|
writer.write_all(&[0xff, 0xff, 0xff, 0xff])?;
|
|
|
writer.flush()?;
|
|
|
- Ok(())
|
|
|
+ Ok((num_msgs_sent, num_bytes_sent))
|
|
|
})
|
|
|
.unwrap();
|
|
|
Self {
|
|
@@ -119,7 +126,7 @@ impl SenderThread {
|
|
|
Ok(())
|
|
|
}
|
|
|
|
|
|
- pub fn join(self) -> Result<(), Error> {
|
|
|
+ pub fn join(self) -> Result<(usize, usize), Error> {
|
|
|
drop(self.buf_tx);
|
|
|
self.join_handle.join().expect("join failed")
|
|
|
}
|
|
@@ -205,12 +212,35 @@ impl AbstractCommunicator for Communicator {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- fn shutdown(&mut self) {
|
|
|
- self.sender_threads
|
|
|
+ fn shutdown(&mut self) -> HashMap<usize, CommunicationStats> {
|
|
|
+ let mut comm_stats: HashMap<usize, CommunicationStats> = self
|
|
|
+ .sender_threads
|
|
|
.drain()
|
|
|
- .for_each(|(_, t)| t.join().unwrap());
|
|
|
- self.receiver_threads
|
|
|
- .drain()
|
|
|
- .for_each(|(_, t)| t.join().unwrap());
|
|
|
+ .map(|(party_id, t)| {
|
|
|
+ (party_id, {
|
|
|
+ let (num_msgs_sent, num_bytes_sent) = t
|
|
|
+ .join()
|
|
|
+ .expect(&format!("join of sender thread {party_id} failed"));
|
|
|
+ CommunicationStats {
|
|
|
+ num_msgs_sent,
|
|
|
+ num_bytes_sent,
|
|
|
+ num_msgs_received: 0,
|
|
|
+ num_bytes_received: 0,
|
|
|
+ }
|
|
|
+ })
|
|
|
+ })
|
|
|
+ .collect();
|
|
|
+ self.receiver_threads.drain().for_each(|(party_id, t)| {
|
|
|
+ let (num_msgs_received, num_bytes_received) = t
|
|
|
+ .join()
|
|
|
+ .expect(&format!("join of receiver thread {party_id} failed"));
|
|
|
+ let cs = comm_stats
|
|
|
+ .get_mut(&party_id)
|
|
|
+ .expect(&format!("no comm stats for party {party_id} found"));
|
|
|
+ cs.num_msgs_received = num_msgs_received;
|
|
|
+ cs.num_bytes_received = num_bytes_received;
|
|
|
+ });
|
|
|
+
|
|
|
+ comm_stats
|
|
|
}
|
|
|
}
|