use std::mem::size_of; use std::num::NonZeroU32; use tokio::io::{copy, sink, AsyncReadExt, AsyncWriteExt}; /// The padding interval. All message bodies are a size of some multiple of this. /// All messages bodies are a minimum of this size. // FIXME: double check what this should be pub const PADDING_BLOCK_SIZE: u32 = 256; /// The most blocks a message body can contain. // from https://github.com/signalapp/Signal-Android/blob/36a8c4d8ba9fdb62905ecb9a20e3eeba4d2f9022/app/src/main/java/org/thoughtcrime/securesms/mms/PushMediaConstraints.java pub const MAX_BLOCKS_IN_BODY: u32 = (100 * 1024 * 1024) / PADDING_BLOCK_SIZE; #[macro_export] macro_rules! log { ( $( $x:expr ),* ) => { print!("{}", chrono::offset::Utc::now().format("%F %T: ")); println!($( $x ),*) } } #[derive(Debug)] pub enum Error { Io(std::io::Error), Utf8Error(std::str::Utf8Error), MalformedSerialization(Vec, std::backtrace::Backtrace), } impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{:?}", self) } } impl std::error::Error for Error {} impl From for Error { fn from(e: std::io::Error) -> Self { Self::Io(e) } } impl From for Error { fn from(e: std::str::Utf8Error) -> Self { Self::Utf8Error(e) } } /// Metadata for the body of the message. /// /// Message contents are always 0-filled buffers, so never represented. #[derive(Copy, Clone, Debug, PartialEq)] pub enum MessageBody { Receipt, Size(NonZeroU32), } impl MessageBody { fn size(&self) -> u32 { match self { MessageBody::Receipt => PADDING_BLOCK_SIZE, MessageBody::Size(size) => size.get(), } } } /// Message metadata. /// /// This has everything needed to reconstruct a message. // FIXME: every String should be &str #[derive(Debug)] pub struct MessageHeader { pub sender: String, pub recipients: Vec, pub body: MessageBody, } impl MessageHeader { /// Generate a concise serialization of the Message. pub fn serialize(&self) -> SerializedMessage { // serialized message header: { // header_len: u32, // sender: {u32, utf-8} // num_recipients: u32, // recipients: [{u32, utf-8}], // body_type: MessageBody (i.e., u32) // } let num_recipients = self.recipients.len(); let body_type = match self.body { MessageBody::Receipt => 0, MessageBody::Size(s) => s.get(), }; let header_len = (1 + 1 + 1 + num_recipients + 1) * size_of::() + self.sender.len() + self.recipients.iter().map(String::len).sum::(); let mut header: Vec = Vec::with_capacity(header_len); let header_len = header_len as u32; let num_recipients = num_recipients as u32; header.extend(header_len.to_be_bytes()); serialize_str_to(&self.sender, &mut header); header.extend(num_recipients.to_be_bytes()); for recipient in self.recipients.iter() { serialize_str_to(recipient, &mut header); } header.extend(body_type.to_be_bytes()); assert!(header.len() == header_len as usize); SerializedMessage { header, body: self.body, } } /// Creates a MessageHeader from bytes created via serialization, /// but with the size already parsed out. fn deserialize(buf: &[u8]) -> Result { let (sender, buf) = deserialize_str(buf)?; let sender = sender.to_string(); let (num_recipients, buf) = deserialize_u32(buf)?; debug_assert!(num_recipients != 0); let mut recipients = Vec::with_capacity(num_recipients as usize); let mut recipient; let mut buf = buf; for _ in 0..num_recipients { (recipient, buf) = deserialize_str(buf)?; let recipient = recipient.to_string(); recipients.push(recipient); } let (body, _) = deserialize_u32(buf)?; let body = if let Some(size) = NonZeroU32::new(body) { MessageBody::Size(size) } else { MessageBody::Receipt }; Ok(Self { sender, recipients, body, }) } } /// Parse the identifier from the start of the TcpStream. pub async fn parse_identifier(stream: &mut T) -> Result { // this should maybe be buffered let strlen = stream.read_u32().await?; let mut buf = vec![0u8; strlen as usize]; stream.read_exact(&mut buf).await?; let s = std::str::from_utf8(&buf)?; Ok(s.to_string()) } /// Gets a message from the stream and constructs a MessageHeader object, also returning the raw byte buffer pub async fn get_message( stream: &mut T, ) -> Result<(MessageHeader, Vec), Error> { let mut header_size_bytes = [0u8; 4]; stream.read_exact(&mut header_size_bytes).await?; log!("got header size"); get_message_with_header_size(stream, header_size_bytes).await } pub async fn get_message_with_header_size( stream: &mut T, header_size_bytes: [u8; 4], ) -> Result<(MessageHeader, Vec), Error> { let header_size = u32::from_be_bytes(header_size_bytes); let mut header_buf = vec![0; header_size as usize]; stream.read_exact(&mut header_buf[4..]).await?; let header = MessageHeader::deserialize(&header_buf[4..])?; log!( "got header from {} to {:?}, about to read {} bytes", header.sender, header.recipients, header.body.size() ); let header_size_buf = &mut header_buf[..4]; header_size_buf.copy_from_slice(&header_size_bytes); copy(&mut stream.take(header.body.size() as u64), &mut sink()).await?; Ok((header, header_buf)) } pub fn serialize_str(s: &str) -> Vec { let mut buf = Vec::with_capacity(s.len() + size_of::()); serialize_str_to(s, &mut buf); buf } pub fn serialize_str_to(s: &str, buf: &mut Vec) { let strlen = s.len() as u32; buf.extend(strlen.to_be_bytes()); buf.extend(s.as_bytes()); } fn deserialize_u32(buf: &[u8]) -> Result<(u32, &[u8]), Error> { let bytes = buf.get(0..4).ok_or_else(|| { Error::MalformedSerialization(buf.to_vec(), std::backtrace::Backtrace::capture()) })?; Ok((u32::from_be_bytes(bytes.try_into().unwrap()), &buf[4..])) } fn deserialize_str(buf: &[u8]) -> Result<(&str, &[u8]), Error> { let (strlen, buf) = deserialize_u32(buf)?; let strlen = strlen as usize; let strbytes = buf.get(..strlen).ok_or(Error::MalformedSerialization( buf.to_vec(), std::backtrace::Backtrace::capture(), ))?; Ok((std::str::from_utf8(strbytes)?, &buf[strlen..])) } /// A message almost ready for sending. /// /// We represent each message in two halves: the header, and the body. /// This way, the server can parse out the header in its own buf, /// and just pass that around intact, without keeping a (possibly large) /// 0-filled body around. #[derive(Debug)] pub struct SerializedMessage { pub header: Vec, pub body: MessageBody, } impl SerializedMessage { pub async fn write_all_to( &self, writer: &mut T, ) -> std::io::Result<()> { let body_buf = vec![0; self.body.size() as usize]; // write_all_vectored is not yet stable x_x // https://github.com/rust-lang/rust/issues/70436 let mut header: &[u8] = &self.header; let mut body: &[u8] = &body_buf; loop { let bufs = [std::io::IoSlice::new(header), std::io::IoSlice::new(body)]; match writer.write_vectored(&bufs).await { Ok(written) => { if written == header.len() + body.len() { return Ok(()); } if written >= header.len() { body = &body[written - header.len()..]; break; } else if written == 0 { return Err(std::io::Error::new( std::io::ErrorKind::WriteZero, "failed to write any bytes from message with bytes remaining", )); } else { header = &header[written..]; } } Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue, Err(e) => return Err(e), } } writer.write_all(body).await } }