|
@@ -38,7 +38,8 @@ impl<T: Serializable> Fut<T> for MyFut<T> {
|
|
|
#[derive(Debug)]
|
|
|
struct ReceiverThread {
|
|
|
buf_rx: Arc<Mutex<Receiver<Vec<u8>>>>,
|
|
|
- join_handle: thread::JoinHandle<Result<(usize, usize), Error>>,
|
|
|
+ join_handle: thread::JoinHandle<Result<(), Error>>,
|
|
|
+ stats: Arc<Mutex<[usize; 2]>>,
|
|
|
}
|
|
|
|
|
|
impl ReceiverThread {
|
|
@@ -46,32 +47,36 @@ impl ReceiverThread {
|
|
|
let mut reader = BufReader::with_capacity(1 << 16, reader);
|
|
|
let (buf_tx, buf_rx) = channel::<Vec<u8>>();
|
|
|
let buf_rx = Arc::new(Mutex::new(buf_rx));
|
|
|
+ let stats = Arc::new(Mutex::new([0usize; 2]));
|
|
|
+ let stats_clone = stats.clone();
|
|
|
let join_handle = thread::Builder::new()
|
|
|
.name("Receiver".to_owned())
|
|
|
.spawn(move || {
|
|
|
- let mut num_msgs_received = 0;
|
|
|
- let mut num_bytes_received = 0;
|
|
|
loop {
|
|
|
let mut msg_size = [0u8; 4];
|
|
|
reader.read_exact(&mut msg_size)?;
|
|
|
let msg_size = u32::from_be_bytes(msg_size) as usize;
|
|
|
if msg_size == 0xffffffff {
|
|
|
- return Ok((num_msgs_received, num_bytes_received));
|
|
|
+ return Ok(());
|
|
|
}
|
|
|
let mut buf = vec![0u8; msg_size];
|
|
|
reader.read_exact(&mut buf)?;
|
|
|
match buf_tx.send(buf) {
|
|
|
Ok(_) => (),
|
|
|
- Err(_) => return Ok((num_msgs_received, num_bytes_received)), // we need to shutdown
|
|
|
+ Err(_) => return Ok(()), // we need to shutdown
|
|
|
+ }
|
|
|
+ {
|
|
|
+ let mut guard = stats.lock().unwrap();
|
|
|
+ guard[0] += 1;
|
|
|
+ guard[1] += 4 + msg_size;
|
|
|
}
|
|
|
- num_msgs_received += 1;
|
|
|
- num_bytes_received += 4 + msg_size;
|
|
|
}
|
|
|
})
|
|
|
.unwrap();
|
|
|
Self {
|
|
|
join_handle,
|
|
|
buf_rx,
|
|
|
+ stats: stats_clone,
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -79,17 +84,25 @@ impl ReceiverThread {
|
|
|
Ok(MyFut::new(self.buf_rx.clone()))
|
|
|
}
|
|
|
|
|
|
- pub fn join(self) -> Result<(usize, usize), Error> {
|
|
|
+ pub fn join(self) -> Result<(), Error> {
|
|
|
drop(self.buf_rx);
|
|
|
self.join_handle.join().expect("join failed")
|
|
|
}
|
|
|
+
|
|
|
+ pub fn get_stats(&self) -> [usize; 2] {
|
|
|
+ self.stats.lock().unwrap().clone()
|
|
|
+ }
|
|
|
+
|
|
|
+ pub fn reset_stats(&mut self) {
|
|
|
+ *self.stats.lock().unwrap() = [0usize; 2];
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
/// Thread to send messages in the background.
|
|
|
#[derive(Debug)]
|
|
|
struct SenderThread {
|
|
|
buf_tx: Sender<Vec<u8>>,
|
|
|
- join_handle: thread::JoinHandle<Result<(usize, usize), Error>>,
|
|
|
+ join_handle: thread::JoinHandle<Result<(), Error>>,
|
|
|
}
|
|
|
|
|
|
impl SenderThread {
|
|
@@ -99,18 +112,15 @@ impl SenderThread {
|
|
|
let join_handle = thread::Builder::new()
|
|
|
.name("Sender-1".to_owned())
|
|
|
.spawn(move || {
|
|
|
- let mut num_msgs_sent = 0;
|
|
|
- let mut num_bytes_sent = 0;
|
|
|
for buf in buf_rx.iter() {
|
|
|
writer.write_all(&((buf.len() as u32).to_be_bytes()))?;
|
|
|
+ debug_assert!(buf.len() <= u32::MAX as usize);
|
|
|
writer.write_all(&buf)?;
|
|
|
writer.flush()?;
|
|
|
- num_msgs_sent += 1;
|
|
|
- num_bytes_sent += 4 + buf.len();
|
|
|
}
|
|
|
writer.write_all(&[0xff, 0xff, 0xff, 0xff])?;
|
|
|
writer.flush()?;
|
|
|
- Ok((num_msgs_sent, num_bytes_sent))
|
|
|
+ Ok(())
|
|
|
})
|
|
|
.unwrap();
|
|
|
Self {
|
|
@@ -119,21 +129,23 @@ impl SenderThread {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- pub fn send<T: Serializable>(&mut self, data: T) -> Result<(), Error> {
|
|
|
+ pub fn send<T: Serializable>(&mut self, data: T) -> Result<usize, Error> {
|
|
|
let buf =
|
|
|
bincode::encode_to_vec(data, bincode::config::standard().skip_fixed_array_length())?;
|
|
|
+ let num_bytes = 4 + buf.len();
|
|
|
self.buf_tx.send(buf)?;
|
|
|
- Ok(())
|
|
|
+ Ok(num_bytes)
|
|
|
}
|
|
|
|
|
|
- pub fn send_slice<T: Serializable>(&mut self, data: &[T]) -> Result<(), Error> {
|
|
|
+ pub fn send_slice<T: Serializable>(&mut self, data: &[T]) -> Result<usize, Error> {
|
|
|
let buf =
|
|
|
bincode::encode_to_vec(data, bincode::config::standard().skip_fixed_array_length())?;
|
|
|
+ let num_bytes = 4 + buf.len();
|
|
|
self.buf_tx.send(buf)?;
|
|
|
- Ok(())
|
|
|
+ Ok(num_bytes)
|
|
|
}
|
|
|
|
|
|
- pub fn join(self) -> Result<(usize, usize), Error> {
|
|
|
+ pub fn join(self) -> Result<(), Error> {
|
|
|
drop(self.buf_tx);
|
|
|
self.join_handle.join().expect("join failed")
|
|
|
}
|
|
@@ -144,6 +156,7 @@ impl SenderThread {
|
|
|
pub struct Communicator {
|
|
|
num_parties: usize,
|
|
|
my_id: usize,
|
|
|
+ comm_stats: HashMap<usize, CommunicationStats>,
|
|
|
receiver_threads: HashMap<usize, ReceiverThread>,
|
|
|
sender_threads: HashMap<usize, SenderThread>,
|
|
|
}
|
|
@@ -176,9 +189,20 @@ impl Communicator {
|
|
|
sender_threads.insert(pid, SenderThread::from_writer(writer));
|
|
|
}
|
|
|
|
|
|
+ let comm_stats = (0..num_parties)
|
|
|
+ .filter_map(|party_id| {
|
|
|
+ if party_id == my_id {
|
|
|
+ None
|
|
|
+ } else {
|
|
|
+ Some((party_id, Default::default()))
|
|
|
+ }
|
|
|
+ })
|
|
|
+ .collect();
|
|
|
+
|
|
|
Self {
|
|
|
num_parties,
|
|
|
my_id,
|
|
|
+ comm_stats,
|
|
|
receiver_threads,
|
|
|
sender_threads,
|
|
|
}
|
|
@@ -199,7 +223,10 @@ impl AbstractCommunicator for Communicator {
|
|
|
fn send<T: Serializable>(&mut self, party_id: usize, val: T) -> Result<(), Error> {
|
|
|
match self.sender_threads.get_mut(&party_id) {
|
|
|
Some(t) => {
|
|
|
- t.send(val)?;
|
|
|
+ let num_bytes = t.send(val)?;
|
|
|
+ let cs = self.comm_stats.get_mut(&party_id).unwrap();
|
|
|
+ cs.num_bytes_sent += num_bytes;
|
|
|
+ cs.num_msgs_sent += 1;
|
|
|
Ok(())
|
|
|
}
|
|
|
None => Err(Error::LogicError(format!(
|
|
@@ -212,7 +239,10 @@ impl AbstractCommunicator for Communicator {
|
|
|
fn send_slice<T: Serializable>(&mut self, party_id: usize, val: &[T]) -> Result<(), Error> {
|
|
|
match self.sender_threads.get_mut(&party_id) {
|
|
|
Some(t) => {
|
|
|
- t.send_slice(val)?;
|
|
|
+ let num_bytes = t.send_slice(val)?;
|
|
|
+ let cs = self.comm_stats.get_mut(&party_id).unwrap();
|
|
|
+ cs.num_bytes_sent += num_bytes;
|
|
|
+ cs.num_msgs_sent += 1;
|
|
|
Ok(())
|
|
|
}
|
|
|
None => Err(Error::LogicError(format!(
|
|
@@ -232,35 +262,34 @@ impl AbstractCommunicator for Communicator {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- fn shutdown(&mut self) -> HashMap<usize, CommunicationStats> {
|
|
|
- let mut comm_stats: HashMap<usize, CommunicationStats> = self
|
|
|
- .sender_threads
|
|
|
- .drain()
|
|
|
- .map(|(party_id, t)| {
|
|
|
- (party_id, {
|
|
|
- let (num_msgs_sent, num_bytes_sent) = t
|
|
|
- .join()
|
|
|
- .expect(&format!("join of sender thread {party_id} failed"));
|
|
|
- CommunicationStats {
|
|
|
- num_msgs_sent,
|
|
|
- num_bytes_sent,
|
|
|
- num_msgs_received: 0,
|
|
|
- num_bytes_received: 0,
|
|
|
- }
|
|
|
- })
|
|
|
- })
|
|
|
- .collect();
|
|
|
+ fn shutdown(&mut self) {
|
|
|
+ self.sender_threads.drain().for_each(|(party_id, t)| {
|
|
|
+ t.join()
|
|
|
+ .expect(&format!("join of sender thread {party_id} failed"))
|
|
|
+ });
|
|
|
self.receiver_threads.drain().for_each(|(party_id, t)| {
|
|
|
- let (num_msgs_received, num_bytes_received) = t
|
|
|
- .join()
|
|
|
- .expect(&format!("join of receiver thread {party_id} failed"));
|
|
|
- let cs = comm_stats
|
|
|
- .get_mut(&party_id)
|
|
|
- .expect(&format!("no comm stats for party {party_id} found"));
|
|
|
- cs.num_msgs_received = num_msgs_received;
|
|
|
- cs.num_bytes_received = num_bytes_received;
|
|
|
+ t.join()
|
|
|
+ .expect(&format!("join of receiver thread {party_id} failed"))
|
|
|
});
|
|
|
+ }
|
|
|
+
|
|
|
+ fn get_stats(&self) -> HashMap<usize, CommunicationStats> {
|
|
|
+ let mut cs = self.comm_stats.clone();
|
|
|
+ self.receiver_threads.iter().for_each(|(party_id, t)| {
|
|
|
+ let [num_msgs_received, num_bytes_received] = t.get_stats();
|
|
|
+ let cs_i = cs.get_mut(party_id).unwrap();
|
|
|
+ cs_i.num_msgs_received = num_msgs_received;
|
|
|
+ cs_i.num_bytes_received = num_bytes_received;
|
|
|
+ });
|
|
|
+ cs
|
|
|
+ }
|
|
|
|
|
|
- comm_stats
|
|
|
+ fn reset_stats(&mut self) {
|
|
|
+ self.comm_stats
|
|
|
+ .iter_mut()
|
|
|
+ .for_each(|(_, cs)| *cs = Default::default());
|
|
|
+ self.receiver_threads
|
|
|
+ .iter_mut()
|
|
|
+ .for_each(|(_, t)| t.reset_stats());
|
|
|
}
|
|
|
}
|