Browse Source

communicator: rework and buffering

Lennart Braun 2 years ago
parent
commit
2f4ec280b3
2 changed files with 135 additions and 56 deletions
  1. 125 55
      communicator/src/communicator.rs
  2. 10 1
      communicator/src/lib.rs

+ 125 - 55
communicator/src/communicator.rs

@@ -2,102 +2,168 @@ use crate::{AbstractCommunicator, Error, Fut, Serializable};
 use bincode;
 use std::collections::HashMap;
 use std::fmt::Debug;
-use std::io::{Read, Write};
-use std::sync::mpsc::{channel, sync_channel, Receiver, Sender};
+use std::io::{BufReader, BufWriter, Read, Write};
+use std::marker::PhantomData;
+use std::mem;
+use std::sync::mpsc::{channel, Sender};
+use std::sync::{Arc, Condvar, Mutex};
 use std::thread;
 
+struct SharedState {
+    pub mutex: Mutex<Option<Vec<u8>>>,
+    pub cvar: Condvar,
+}
+
+impl SharedState {
+    pub fn new() -> Self {
+        Self {
+            mutex: Mutex::new(None),
+            cvar: Condvar::new(),
+        }
+    }
+
+    pub fn put(&self, data: Vec<u8>) {
+        let mut lock = self.mutex.lock().unwrap();
+        *lock = Some(data);
+        self.cvar.notify_one();
+    }
+
+    pub fn get(&self) -> Vec<u8> {
+        let mut val = None;
+        let mut lock = self
+            .cvar
+            .wait_while(self.mutex.lock().unwrap(), |x| x.is_none())
+            .unwrap();
+        mem::swap(&mut val, &mut lock);
+        val.unwrap()
+    }
+}
+
 pub struct MyFut<T: Serializable> {
-    data_rx: Receiver<Result<T, Error>>,
+    shared_state: Arc<SharedState>,
+    _phantom: PhantomData<T>,
 }
 
 impl<T: Serializable> MyFut<T> {
-    pub fn new(data_rx: Receiver<Result<T, Error>>) -> Self {
-        Self { data_rx }
+    fn new(shared_state: Arc<SharedState>) -> Self {
+        Self {
+            shared_state,
+            _phantom: PhantomData,
+        }
     }
 }
 
 impl<T: Serializable> Fut<T> for MyFut<T> {
     fn get(self) -> Result<T, Error> {
-        match self.data_rx.recv() {
-            Ok(x) => x,
-            Err(e) => Err(e.into()),
-        }
+        let buf = self.shared_state.get();
+        let (data, size) = bincode::decode_from_slice(&buf, bincode::config::standard())?;
+        assert_eq!(size, buf.len());
+        Ok(data)
     }
 }
 
 /// Thread to receive messages in the background.
 #[derive(Debug)]
 struct ReceiverThread {
-    data_request_tx: Sender<Box<dyn FnOnce(&mut dyn Read) + Send>>,
-    join_handle: thread::JoinHandle<()>,
+    promise_tx: Sender<Arc<SharedState>>,
+    join_handle_1: thread::JoinHandle<Result<(), Error>>,
+    join_handle_2: thread::JoinHandle<Result<(), Error>>,
 }
 
 impl ReceiverThread {
-    pub fn from_reader<R: Debug + Read + Send + 'static>(mut reader: R) -> Self {
-        let (data_request_tx, data_request_rx) = channel::<Box<dyn FnOnce(&mut dyn Read) + Send>>();
-        let join_handle = thread::spawn(move || {
-            for func in data_request_rx.iter() {
-                func(&mut reader);
-            }
-        });
+    pub fn from_reader<R: Debug + Read + Send + 'static>(reader: R) -> Self {
+        let mut reader = BufReader::new(reader);
+        let (promise_tx, promise_rx) = channel::<Arc<SharedState>>();
+        let (buf_tx, buf_rx) = channel::<Vec<u8>>();
+        let join_handle_1 = thread::Builder::new()
+            .name("Receiver-1".to_owned())
+            .spawn(move || {
+                loop {
+                    let mut msg_size = [0u8; 4];
+                    reader.read_exact(&mut msg_size)?;
+                    let msg_size = u32::from_be_bytes(msg_size) as usize;
+                    if msg_size == 0xffffffff {
+                        return Ok(());
+                    }
+                    let mut buf = vec![0u8; msg_size];
+                    reader.read_exact(&mut buf)?;
+                    match buf_tx.send(buf) {
+                        Ok(_) => (),
+                        Err(_) => return Ok(()), // we need to shutdown
+                    }
+                }
+            })
+            .unwrap();
+        let join_handle_2 = thread::Builder::new()
+            .name("Receiver-2".to_owned())
+            .spawn(move || {
+                for shared_state in promise_rx {
+                    let buf = buf_rx.recv()?;
+                    shared_state.put(buf);
+                }
+                Ok(())
+            })
+            .unwrap();
         Self {
-            data_request_tx,
-            join_handle,
+            promise_tx,
+            join_handle_1,
+            join_handle_2,
         }
     }
 
     pub fn receive<T: Serializable>(&mut self) -> Result<MyFut<T>, Error> {
-        let (data_tx, data_rx) = sync_channel(1);
-        self.data_request_tx
-            .send(Box::new(move |mut reader: &mut dyn Read| {
-                let new: Result<T, Error> =
-                    bincode::decode_from_std_read(&mut reader, bincode::config::standard())
-                        .map_err(|e| e.into());
-                data_tx.send(new).expect("send failed");
-            }))?;
-        Ok(MyFut::new(data_rx.into()))
-    }
-
-    pub fn join(self) {
-        drop(self.data_request_tx);
-        self.join_handle.join().expect("join failed")
+        let shared_state_promise = Arc::new(SharedState::new());
+        let shared_state_future = shared_state_promise.clone();
+        self.promise_tx.send(shared_state_promise)?;
+        Ok(MyFut::new(shared_state_future))
+    }
+
+    pub fn join(self) -> Result<(), Error> {
+        drop(self.promise_tx);
+        self.join_handle_1.join().expect("join failed")?;
+        self.join_handle_2.join().expect("join failed")?;
+        Ok(())
     }
 }
 
 /// Thread to send messages in the background.
 #[derive(Debug)]
 struct SenderThread {
-    data_submission_tx: Sender<Box<dyn FnOnce(&mut dyn Write) + Send>>,
-    join_handle: thread::JoinHandle<()>,
+    buf_tx: Sender<Vec<u8>>,
+    join_handle: thread::JoinHandle<Result<(), Error>>,
 }
 
 impl SenderThread {
-    pub fn from_writer<W: Debug + Write + Send + 'static>(mut writer: W) -> Self {
-        let (data_submission_tx, data_submission_rx) =
-            channel::<Box<dyn FnOnce(&mut dyn Write) + Send>>();
-        let join_handle = thread::spawn(move || {
-            for func in data_submission_rx.iter() {
-                func(&mut writer);
-            }
-            writer.flush().expect("flush failed");
-        });
+    pub fn from_writer<W: Debug + Write + Send + 'static>(writer: W) -> Self {
+        let mut writer = BufWriter::with_capacity(1 << 15, writer);
+        let (buf_tx, buf_rx) = channel::<Vec<u8>>();
+        let join_handle = thread::Builder::new()
+            .name("Sender-1".to_owned())
+            .spawn(move || {
+                for buf in buf_rx.iter() {
+                    writer.write_all(&((buf.len() as u32).to_be_bytes()))?;
+                    writer.write_all(&buf)?;
+                    writer.flush()?;
+                }
+                writer.write_all(&[0xff, 0xff, 0xff, 0xff])?;
+                writer.flush()?;
+                Ok(())
+            })
+            .unwrap();
         Self {
-            data_submission_tx,
+            buf_tx,
             join_handle,
         }
     }
 
     pub fn send<T: Serializable>(&mut self, data: T) -> Result<(), Error> {
-        self.data_submission_tx
-            .send(Box::new(move |mut writer: &mut dyn Write| {
-                bincode::encode_into_std_write(data, &mut writer, bincode::config::standard())
-                    .expect("encode failed");
-            }))?;
+        let buf = bincode::encode_to_vec(data, bincode::config::standard())?;
+        self.buf_tx.send(buf)?;
         Ok(())
     }
 
-    pub fn join(self) {
-        drop(self.data_submission_tx);
+    pub fn join(self) -> Result<(), Error> {
+        drop(self.buf_tx);
         self.join_handle.join().expect("join failed")
     }
 }
@@ -183,7 +249,11 @@ impl AbstractCommunicator for Communicator {
     }
 
     fn shutdown(&mut self) {
-        self.sender_threads.drain().for_each(|(_, t)| t.join());
-        self.receiver_threads.drain().for_each(|(_, t)| t.join());
+        self.sender_threads
+            .drain()
+            .for_each(|(_, t)| t.join().unwrap());
+        self.receiver_threads
+            .drain()
+            .for_each(|(_, t)| t.join().unwrap());
     }
 }

+ 10 - 1
communicator/src/lib.rs

@@ -2,7 +2,7 @@ pub mod communicator;
 pub mod tcp;
 pub mod unix;
 
-use bincode::error::DecodeError;
+use bincode::error::{EncodeError, DecodeError};
 use std::io::Error as IoError;
 use std::sync::mpsc::{RecvError, SendError};
 
@@ -88,6 +88,8 @@ pub enum Error {
     /// Some std::sync::mpsc::SendError appeared
     SendError(String),
     /// Some bincode::error::DecodeError appeared
+    EncodeError(EncodeError),
+    /// Some bincode::error::DecodeError appeared
     DecodeError(DecodeError),
     /// Serialization of data failed
     SerializationError(String),
@@ -116,6 +118,13 @@ impl<T> From<SendError<T>> for Error {
     }
 }
 
+/// Enable automatic conversions from bincode::error::EncodeError
+impl From<EncodeError> for Error {
+    fn from(e: EncodeError) -> Error {
+        Error::EncodeError(e)
+    }
+}
+
 /// Enable automatic conversions from bincode::error::DecodeError
 impl From<DecodeError> for Error {
     fn from(e: DecodeError) -> Error {