浏览代码

communicator: rework stats collection

Lennart Braun 1 年之前
父节点
当前提交
146659c3cc
共有 3 个文件被更改,包括 87 次插入51 次删除
  1. 77 48
      communicator/src/communicator.rs
  2. 8 2
      communicator/src/lib.rs
  3. 2 1
      oram/examples/bench_doram.rs

+ 77 - 48
communicator/src/communicator.rs

@@ -38,7 +38,8 @@ impl<T: Serializable> Fut<T> for MyFut<T> {
 #[derive(Debug)]
 struct ReceiverThread {
     buf_rx: Arc<Mutex<Receiver<Vec<u8>>>>,
-    join_handle: thread::JoinHandle<Result<(usize, usize), Error>>,
+    join_handle: thread::JoinHandle<Result<(), Error>>,
+    stats: Arc<Mutex<[usize; 2]>>,
 }
 
 impl ReceiverThread {
@@ -46,32 +47,36 @@ impl ReceiverThread {
         let mut reader = BufReader::with_capacity(1 << 16, reader);
         let (buf_tx, buf_rx) = channel::<Vec<u8>>();
         let buf_rx = Arc::new(Mutex::new(buf_rx));
+        let stats = Arc::new(Mutex::new([0usize; 2]));
+        let stats_clone = stats.clone();
         let join_handle = thread::Builder::new()
             .name("Receiver".to_owned())
             .spawn(move || {
-                let mut num_msgs_received = 0;
-                let mut num_bytes_received = 0;
                 loop {
                     let mut msg_size = [0u8; 4];
                     reader.read_exact(&mut msg_size)?;
                     let msg_size = u32::from_be_bytes(msg_size) as usize;
                     if msg_size == 0xffffffff {
-                        return Ok((num_msgs_received, num_bytes_received));
+                        return Ok(());
                     }
                     let mut buf = vec![0u8; msg_size];
                     reader.read_exact(&mut buf)?;
                     match buf_tx.send(buf) {
                         Ok(_) => (),
-                        Err(_) => return Ok((num_msgs_received, num_bytes_received)), // we need to shutdown
+                        Err(_) => return Ok(()), // we need to shutdown
+                    }
+                    {
+                        let mut guard = stats.lock().unwrap();
+                        guard[0] += 1;
+                        guard[1] += 4 + msg_size;
                     }
-                    num_msgs_received += 1;
-                    num_bytes_received += 4 + msg_size;
                 }
             })
             .unwrap();
         Self {
             join_handle,
             buf_rx,
+            stats: stats_clone,
         }
     }
 
@@ -79,17 +84,25 @@ impl ReceiverThread {
         Ok(MyFut::new(self.buf_rx.clone()))
     }
 
-    pub fn join(self) -> Result<(usize, usize), Error> {
+    pub fn join(self) -> Result<(), Error> {
         drop(self.buf_rx);
         self.join_handle.join().expect("join failed")
     }
+
+    pub fn get_stats(&self) -> [usize; 2] {
+        self.stats.lock().unwrap().clone()
+    }
+
+    pub fn reset_stats(&mut self) {
+        *self.stats.lock().unwrap() = [0usize; 2];
+    }
 }
 
 /// Thread to send messages in the background.
 #[derive(Debug)]
 struct SenderThread {
     buf_tx: Sender<Vec<u8>>,
-    join_handle: thread::JoinHandle<Result<(usize, usize), Error>>,
+    join_handle: thread::JoinHandle<Result<(), Error>>,
 }
 
 impl SenderThread {
@@ -99,18 +112,15 @@ impl SenderThread {
         let join_handle = thread::Builder::new()
             .name("Sender-1".to_owned())
             .spawn(move || {
-                let mut num_msgs_sent = 0;
-                let mut num_bytes_sent = 0;
                 for buf in buf_rx.iter() {
                     writer.write_all(&((buf.len() as u32).to_be_bytes()))?;
+                    debug_assert!(buf.len() <= u32::MAX as usize);
                     writer.write_all(&buf)?;
                     writer.flush()?;
-                    num_msgs_sent += 1;
-                    num_bytes_sent += 4 + buf.len();
                 }
                 writer.write_all(&[0xff, 0xff, 0xff, 0xff])?;
                 writer.flush()?;
-                Ok((num_msgs_sent, num_bytes_sent))
+                Ok(())
             })
             .unwrap();
         Self {
@@ -119,21 +129,23 @@ impl SenderThread {
         }
     }
 
-    pub fn send<T: Serializable>(&mut self, data: T) -> Result<(), Error> {
+    pub fn send<T: Serializable>(&mut self, data: T) -> Result<usize, Error> {
         let buf =
             bincode::encode_to_vec(data, bincode::config::standard().skip_fixed_array_length())?;
+        let num_bytes = 4 + buf.len();
         self.buf_tx.send(buf)?;
-        Ok(())
+        Ok(num_bytes)
     }
 
-    pub fn send_slice<T: Serializable>(&mut self, data: &[T]) -> Result<(), Error> {
+    pub fn send_slice<T: Serializable>(&mut self, data: &[T]) -> Result<usize, Error> {
         let buf =
             bincode::encode_to_vec(data, bincode::config::standard().skip_fixed_array_length())?;
+        let num_bytes = 4 + buf.len();
         self.buf_tx.send(buf)?;
-        Ok(())
+        Ok(num_bytes)
     }
 
-    pub fn join(self) -> Result<(usize, usize), Error> {
+    pub fn join(self) -> Result<(), Error> {
         drop(self.buf_tx);
         self.join_handle.join().expect("join failed")
     }
@@ -144,6 +156,7 @@ impl SenderThread {
 pub struct Communicator {
     num_parties: usize,
     my_id: usize,
+    comm_stats: HashMap<usize, CommunicationStats>,
     receiver_threads: HashMap<usize, ReceiverThread>,
     sender_threads: HashMap<usize, SenderThread>,
 }
@@ -176,9 +189,20 @@ impl Communicator {
             sender_threads.insert(pid, SenderThread::from_writer(writer));
         }
 
+        let comm_stats = (0..num_parties)
+            .filter_map(|party_id| {
+                if party_id == my_id {
+                    None
+                } else {
+                    Some((party_id, Default::default()))
+                }
+            })
+            .collect();
+
         Self {
             num_parties,
             my_id,
+            comm_stats,
             receiver_threads,
             sender_threads,
         }
@@ -199,7 +223,10 @@ impl AbstractCommunicator for Communicator {
     fn send<T: Serializable>(&mut self, party_id: usize, val: T) -> Result<(), Error> {
         match self.sender_threads.get_mut(&party_id) {
             Some(t) => {
-                t.send(val)?;
+                let num_bytes = t.send(val)?;
+                let cs = self.comm_stats.get_mut(&party_id).unwrap();
+                cs.num_bytes_sent += num_bytes;
+                cs.num_msgs_sent += 1;
                 Ok(())
             }
             None => Err(Error::LogicError(format!(
@@ -212,7 +239,10 @@ impl AbstractCommunicator for Communicator {
     fn send_slice<T: Serializable>(&mut self, party_id: usize, val: &[T]) -> Result<(), Error> {
         match self.sender_threads.get_mut(&party_id) {
             Some(t) => {
-                t.send_slice(val)?;
+                let num_bytes = t.send_slice(val)?;
+                let cs = self.comm_stats.get_mut(&party_id).unwrap();
+                cs.num_bytes_sent += num_bytes;
+                cs.num_msgs_sent += 1;
                 Ok(())
             }
             None => Err(Error::LogicError(format!(
@@ -232,35 +262,34 @@ impl AbstractCommunicator for Communicator {
         }
     }
 
-    fn shutdown(&mut self) -> HashMap<usize, CommunicationStats> {
-        let mut comm_stats: HashMap<usize, CommunicationStats> = self
-            .sender_threads
-            .drain()
-            .map(|(party_id, t)| {
-                (party_id, {
-                    let (num_msgs_sent, num_bytes_sent) = t
-                        .join()
-                        .expect(&format!("join of sender thread {party_id} failed"));
-                    CommunicationStats {
-                        num_msgs_sent,
-                        num_bytes_sent,
-                        num_msgs_received: 0,
-                        num_bytes_received: 0,
-                    }
-                })
-            })
-            .collect();
+    fn shutdown(&mut self) {
+        self.sender_threads.drain().for_each(|(party_id, t)| {
+            t.join()
+                .expect(&format!("join of sender thread {party_id} failed"))
+        });
         self.receiver_threads.drain().for_each(|(party_id, t)| {
-            let (num_msgs_received, num_bytes_received) = t
-                .join()
-                .expect(&format!("join of receiver thread {party_id} failed"));
-            let cs = comm_stats
-                .get_mut(&party_id)
-                .expect(&format!("no comm stats for party {party_id} found"));
-            cs.num_msgs_received = num_msgs_received;
-            cs.num_bytes_received = num_bytes_received;
+            t.join()
+                .expect(&format!("join of receiver thread {party_id} failed"))
         });
+    }
+
+    fn get_stats(&self) -> HashMap<usize, CommunicationStats> {
+        let mut cs = self.comm_stats.clone();
+        self.receiver_threads.iter().for_each(|(party_id, t)| {
+            let [num_msgs_received, num_bytes_received] = t.get_stats();
+            let cs_i = cs.get_mut(party_id).unwrap();
+            cs_i.num_msgs_received = num_msgs_received;
+            cs_i.num_bytes_received = num_bytes_received;
+        });
+        cs
+    }
 
-        comm_stats
+    fn reset_stats(&mut self) {
+        self.comm_stats
+            .iter_mut()
+            .for_each(|(_, cs)| *cs = Default::default());
+        self.receiver_threads
+            .iter_mut()
+            .for_each(|(_, t)| t.reset_stats());
     }
 }

+ 8 - 2
communicator/src/lib.rs

@@ -17,7 +17,7 @@ pub trait Fut<T> {
     fn get(self) -> Result<T, Error>;
 }
 
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Default, Clone, Copy)]
 pub struct CommunicationStats {
     pub num_msgs_received: usize,
     pub num_bytes_received: usize,
@@ -96,7 +96,13 @@ pub trait AbstractCommunicator {
     }
 
     /// Shutdown the communication system
-    fn shutdown(&mut self) -> HashMap<usize, CommunicationStats>;
+    fn shutdown(&mut self);
+
+    /// Obtain statistics about how many messages/bytes were send/received
+    fn get_stats(&self) -> HashMap<usize, CommunicationStats>;
+
+    /// Reset statistics
+    fn reset_stats(&mut self);
 }
 
 /// Custom error type

+ 2 - 1
oram/examples/bench_doram.rs

@@ -203,7 +203,8 @@ fn main() {
         d_accesses.as_secs_f64() * 1000.0 / stash_size as f64
     );
 
-    let comm_stats = comm.shutdown();
+    let comm_stats = comm.get_stats();
+    comm.shutdown();
 
     runtimes.print(cli.party_id as usize + 1, stash_size);