123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265 |
- use std::mem::size_of;
- use std::num::NonZeroU32;
- use tokio::io::{copy, sink, AsyncReadExt, AsyncWriteExt};
/// The minimum message size.
///
/// Receipts carry no explicit size on the wire (their `body_type` is `0`), so their
/// bodies are sent and read as this many zero bytes (see `MessageBody::size`).
/// `MessageBody::Size` bodies are *not* currently padded up to this length, even when
/// they are smaller.
/// NOTE(review): an earlier version of this comment claimed that every body shorter
/// than this is padded; confirm which behavior the protocol intends.
const MIN_MESSAGE_SIZE: u32 = 256; // FIXME: double check what this should be
/// Prints a line to stdout, prefixed with the current UTC time (`%F %T: `).
///
/// Takes the same arguments as `println!`; a trailing comma is allowed. The timestamp
/// and the message are emitted by a single `println!` call. Output from other threads
/// therefore cannot end up between the timestamp and its message, which could happen
/// when they were printed by two separate calls.
#[macro_export]
macro_rules! log {
    () => {
        println!("{}", chrono::offset::Utc::now().format("%F %T: "))
    };
    ( $( $x:expr ),+ $(,)? ) => {
        println!(
            "{}{}",
            chrono::offset::Utc::now().format("%F %T: "),
            format_args!($( $x ),+)
        )
    };
}
/// Errors produced while reading or decoding messages.
#[derive(Debug)]
pub enum Error {
    /// An underlying I/O operation failed.
    Io(std::io::Error),
    /// A serialized string was not valid UTF-8.
    Utf8Error(std::str::Utf8Error),
    /// A slice could not be converted into a fixed-size array.
    TryFromSliceError(std::array::TryFromSliceError),
    /// The input ended early or was otherwise malformed. Carries the offending bytes
    /// and a backtrace captured where the problem was detected.
    MalformedSerialization(Vec<u8>, std::backtrace::Backtrace),
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self)
    }
}

impl std::error::Error for Error {
    /// Exposes the wrapped error (if any), so error-reporting tools can walk the chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Error::Io(e) => Some(e),
            Error::Utf8Error(e) => Some(e),
            Error::TryFromSliceError(e) => Some(e),
            Error::MalformedSerialization(..) => None,
        }
    }
}

impl From<std::io::Error> for Error {
    fn from(e: std::io::Error) -> Self {
        Self::Io(e)
    }
}

impl From<std::str::Utf8Error> for Error {
    fn from(e: std::str::Utf8Error) -> Self {
        Self::Utf8Error(e)
    }
}

impl From<std::array::TryFromSliceError> for Error {
    fn from(e: std::array::TryFromSliceError) -> Self {
        Self::TryFromSliceError(e)
    }
}
/// Metadata for the body of the message.
///
/// Message contents are always 0-filled buffers, so they are never represented; only
/// the body's length is tracked. On the wire this is a single big-endian `u32`
/// `body_type`: `0` for a [`MessageBody::Receipt`], otherwise the body size in bytes.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum MessageBody {
    /// A receipt; its body is sent as `MIN_MESSAGE_SIZE` zero bytes.
    Receipt,
    /// A body of exactly this many zero bytes.
    Size(NonZeroU32),
}
- impl MessageBody {
- fn size(&self) -> u32 {
- match self {
- MessageBody::Receipt => MIN_MESSAGE_SIZE,
- MessageBody::Size(size) => size.get(),
- }
- }
- }
- /// Message metadata.
- ///
- /// This has everything needed to reconstruct a message.
- // FIXME: every String should be &str
- #[derive(Debug)]
- pub struct MessageHeader {
- pub sender: String,
- pub recipients: Vec<String>,
- pub body: MessageBody,
- }
- impl MessageHeader {
- /// Generate a consise serialization of the Message.
- pub fn serialize(&self) -> SerializedMessage {
- // serialized message header: {
- // header_len: u32,
- // sender: {u32, utf-8}
- // num_recipients: u32,
- // recipients: [{u32, utf-8}],
- // body_type: MessageBody (i.e., u32)
- // }
- let num_recipients = self.recipients.len();
- let body_type = match self.body {
- MessageBody::Receipt => 0,
- MessageBody::Size(s) => s.get(),
- };
- let header_len = (1 + 1 + 1 + num_recipients + 1) * size_of::<u32>()
- + self.sender.len()
- + self.recipients.iter().map(String::len).sum::<usize>();
- let mut header: Vec<u8> = Vec::with_capacity(header_len);
- let header_len = header_len as u32;
- let num_recipients = num_recipients as u32;
- header.extend(header_len.to_be_bytes());
- serialize_str_to(&self.sender, &mut header);
- header.extend(num_recipients.to_be_bytes());
- for recipient in self.recipients.iter() {
- serialize_str_to(recipient, &mut header);
- }
- header.extend(body_type.to_be_bytes());
- assert!(header.len() == header_len as usize);
- SerializedMessage {
- header,
- body: self.body,
- }
- }
- /// Creates a MessageHeader from bytes created via serialization,
- /// but with the size already parsed out.
- fn deserialize(buf: &[u8]) -> Result<Self, Error> {
- let (sender, buf) = deserialize_str(buf)?;
- let sender = sender.to_string();
- let (num_recipients, buf) = deserialize_u32(buf)?;
- debug_assert!(num_recipients != 0);
- let mut recipients = Vec::with_capacity(num_recipients as usize);
- let mut recipient;
- let mut buf = buf;
- for _ in 0..num_recipients {
- (recipient, buf) = deserialize_str(buf)?;
- let recipient = recipient.to_string();
- recipients.push(recipient);
- }
- let (body, _) = deserialize_u32(buf)?;
- let body = if let Some(size) = NonZeroU32::new(body) {
- MessageBody::Size(size)
- } else {
- MessageBody::Receipt
- };
- Ok(Self {
- sender,
- recipients,
- body,
- })
- }
- }
- /// Gets a message from the stream and constructs a MessageHeader object, also returning the raw byte buffer
- pub async fn get_message<T: AsyncReadExt + std::marker::Unpin>(
- stream: &mut T,
- ) -> Result<(MessageHeader, Vec<u8>), Error> {
- let mut header_size_bytes = [0u8; 4];
- stream.read_exact(&mut header_size_bytes).await?;
- log!("got header size");
- get_message_with_header_size(stream, header_size_bytes).await
- }
- pub async fn get_message_with_header_size<T: AsyncReadExt + std::marker::Unpin>(
- stream: &mut T,
- header_size_bytes: [u8; 4],
- ) -> Result<(MessageHeader, Vec<u8>), Error> {
- let header_size = u32::from_be_bytes(header_size_bytes);
- let mut header_buf = vec![0; header_size as usize];
- stream.read_exact(&mut header_buf[4..]).await?;
- let header = MessageHeader::deserialize(&header_buf[4..])?;
- log!(
- "got header from {} to {:?}, about to read {} bytes",
- header.sender,
- header.recipients,
- header.body.size()
- );
- let header_size_buf = &mut header_buf[..4];
- header_size_buf.copy_from_slice(&header_size_bytes);
- copy(&mut stream.take(header.body.size() as u64), &mut sink()).await?;
- Ok((header, header_buf))
- }
- pub fn serialize_str(s: &str) -> Vec<u8> {
- let mut buf = Vec::with_capacity(s.len() + size_of::<u32>());
- serialize_str_to(s, &mut buf);
- buf
- }
/// Appends `s` to `buf` as a big-endian `u32` byte length followed by its UTF-8 bytes.
///
/// # Panics
/// Panics if `s` is longer than `u32::MAX` bytes. Such a length cannot be encoded, and
/// silently truncating it would produce a corrupt serialization.
pub fn serialize_str_to(s: &str, buf: &mut Vec<u8>) {
    let strlen = u32::try_from(s.len()).expect("string too long to serialize (exceeds u32::MAX bytes)");
    buf.extend_from_slice(&strlen.to_be_bytes());
    buf.extend_from_slice(s.as_bytes());
}
- fn deserialize_u32(buf: &[u8]) -> Result<(u32, &[u8]), Error> {
- let bytes = buf.get(0..4).ok_or(Error::MalformedSerialization(
- buf.to_vec(),
- std::backtrace::Backtrace::capture(),
- ))?;
- Ok((u32::from_be_bytes(bytes.try_into()?), &buf[4..]))
- }
- fn deserialize_str(buf: &[u8]) -> Result<(&str, &[u8]), Error> {
- let (strlen, buf) = deserialize_u32(buf)?;
- let strlen = strlen as usize;
- let strbytes = buf.get(..strlen).ok_or(Error::MalformedSerialization(
- buf.to_vec(),
- std::backtrace::Backtrace::capture(),
- ))?;
- Ok((std::str::from_utf8(strbytes)?, &buf[strlen..]))
- }
- /// A message almost ready for sending.
- ///
- /// We represent each message in two halves: the header, and the body.
- /// This way, the server can parse out the header in its own buf,
- /// and just pass that around intact, without keeping a (possibly large)
- /// 0-filled body around.
- #[derive(Debug)]
- pub struct SerializedMessage {
- pub header: Vec<u8>,
- pub body: MessageBody,
- }
- impl SerializedMessage {
- pub async fn write_all_to<T: AsyncWriteExt + std::marker::Unpin>(
- &self,
- writer: &mut T,
- ) -> std::io::Result<()> {
- let body_buf = vec![0; self.body.size() as usize];
- // write_all_vectored is not yet stable x_x
- // https://github.com/rust-lang/rust/issues/70436
- let mut header: &[u8] = &self.header;
- let mut body: &[u8] = &body_buf;
- loop {
- let bufs = [std::io::IoSlice::new(header), std::io::IoSlice::new(body)];
- match writer.write_vectored(&bufs).await {
- Ok(written) => {
- if written == header.len() + body.len() {
- return Ok(());
- }
- if written >= header.len() {
- body = &body[written - header.len()..];
- break;
- } else if written == 0 {
- return Err(std::io::Error::new(
- std::io::ErrorKind::WriteZero,
- "failed to write any bytes from message with bytes remaining",
- ));
- } else {
- header = &header[written..];
- }
- }
- Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
- Err(e) => return Err(e),
- }
- }
- writer.write_all(body).await
- }
- }
|