lib.rs 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. use std::mem::size_of;
  2. use std::num::NonZeroU32;
  3. use tokio::io::{copy, sink, AsyncReadExt, AsyncWriteExt};
  4. /// The minimum message size.
  5. /// All messages bodies less than this size (notably, receipts) will be padded to this length.
  6. const MIN_MESSAGE_SIZE: u32 = 256; // FIXME: double check what this should be
  7. #[macro_export]
  8. macro_rules! log {
  9. ( $( $x:expr ),* ) => {
  10. print!("{}", chrono::offset::Utc::now().format("%F %T: "));
  11. println!($( $x ),*)
  12. }
  13. }
  14. #[derive(Debug)]
  15. pub enum Error {
  16. Io(std::io::Error),
  17. Utf8Error(std::str::Utf8Error),
  18. TryFromSliceError(std::array::TryFromSliceError),
  19. MalformedSerialization(Vec<u8>, std::backtrace::Backtrace),
  20. }
  21. impl std::fmt::Display for Error {
  22. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  23. write!(f, "{:?}", self)
  24. }
  25. }
  26. impl std::error::Error for Error {}
  27. impl From<std::io::Error> for Error {
  28. fn from(e: std::io::Error) -> Self {
  29. Self::Io(e)
  30. }
  31. }
  32. impl From<std::str::Utf8Error> for Error {
  33. fn from(e: std::str::Utf8Error) -> Self {
  34. Self::Utf8Error(e)
  35. }
  36. }
  37. impl From<std::array::TryFromSliceError> for Error {
  38. fn from(e: std::array::TryFromSliceError) -> Self {
  39. Self::TryFromSliceError(e)
  40. }
  41. }
  42. /// Metadata for the body of the message.
  43. ///
  44. /// Message contents are always 0-filled buffers, so never represented.
  45. #[derive(Copy, Clone, Debug, PartialEq)]
  46. pub enum MessageBody {
  47. Receipt,
  48. Size(NonZeroU32),
  49. }
  50. impl MessageBody {
  51. fn size(&self) -> u32 {
  52. match self {
  53. MessageBody::Receipt => MIN_MESSAGE_SIZE,
  54. MessageBody::Size(size) => size.get(),
  55. }
  56. }
  57. }
  58. /// Message metadata.
  59. ///
  60. /// This has everything needed to reconstruct a message.
  61. // FIXME: every String should be &str
  62. #[derive(Debug)]
  63. pub struct MessageHeader {
  64. pub sender: String,
  65. pub recipients: Vec<String>,
  66. pub body: MessageBody,
  67. }
  68. impl MessageHeader {
  69. /// Generate a consise serialization of the Message.
  70. pub fn serialize(&self) -> SerializedMessage {
  71. // serialized message header: {
  72. // header_len: u32,
  73. // sender: {u32, utf-8}
  74. // num_recipients: u32,
  75. // recipients: [{u32, utf-8}],
  76. // body_type: MessageBody (i.e., u32)
  77. // }
  78. let num_recipients = self.recipients.len();
  79. let body_type = match self.body {
  80. MessageBody::Receipt => 0,
  81. MessageBody::Size(s) => s.get(),
  82. };
  83. let header_len = (1 + 1 + 1 + num_recipients + 1) * size_of::<u32>()
  84. + self.sender.len()
  85. + self.recipients.iter().map(String::len).sum::<usize>();
  86. let mut header: Vec<u8> = Vec::with_capacity(header_len);
  87. let header_len = header_len as u32;
  88. let num_recipients = num_recipients as u32;
  89. header.extend(header_len.to_be_bytes());
  90. serialize_str_to(&self.sender, &mut header);
  91. header.extend(num_recipients.to_be_bytes());
  92. for recipient in self.recipients.iter() {
  93. serialize_str_to(recipient, &mut header);
  94. }
  95. header.extend(body_type.to_be_bytes());
  96. assert!(header.len() == header_len as usize);
  97. SerializedMessage {
  98. header,
  99. body: self.body,
  100. }
  101. }
  102. /// Creates a MessageHeader from bytes created via serialization,
  103. /// but with the size already parsed out.
  104. fn deserialize(buf: &[u8]) -> Result<Self, Error> {
  105. let (sender, buf) = deserialize_str(buf)?;
  106. let sender = sender.to_string();
  107. let (num_recipients, buf) = deserialize_u32(buf)?;
  108. debug_assert!(num_recipients != 0);
  109. let mut recipients = Vec::with_capacity(num_recipients as usize);
  110. let mut recipient;
  111. let mut buf = buf;
  112. for _ in 0..num_recipients {
  113. (recipient, buf) = deserialize_str(buf)?;
  114. let recipient = recipient.to_string();
  115. recipients.push(recipient);
  116. }
  117. let (body, _) = deserialize_u32(buf)?;
  118. let body = if let Some(size) = NonZeroU32::new(body) {
  119. MessageBody::Size(size)
  120. } else {
  121. MessageBody::Receipt
  122. };
  123. Ok(Self {
  124. sender,
  125. recipients,
  126. body,
  127. })
  128. }
  129. }
  130. /// Gets a message from the stream and constructs a MessageHeader object, also returning the raw byte buffer
  131. pub async fn get_message<T: AsyncReadExt + std::marker::Unpin>(
  132. stream: &mut T,
  133. ) -> Result<(MessageHeader, Vec<u8>), Error> {
  134. let mut header_size_bytes = [0u8; 4];
  135. stream.read_exact(&mut header_size_bytes).await?;
  136. log!("got header size");
  137. get_message_with_header_size(stream, header_size_bytes).await
  138. }
  139. pub async fn get_message_with_header_size<T: AsyncReadExt + std::marker::Unpin>(
  140. stream: &mut T,
  141. header_size_bytes: [u8; 4],
  142. ) -> Result<(MessageHeader, Vec<u8>), Error> {
  143. let header_size = u32::from_be_bytes(header_size_bytes);
  144. let mut header_buf = vec![0; header_size as usize];
  145. stream.read_exact(&mut header_buf[4..]).await?;
  146. let header = MessageHeader::deserialize(&header_buf[4..])?;
  147. log!(
  148. "got header from {} to {:?}, about to read {} bytes",
  149. header.sender,
  150. header.recipients,
  151. header.body.size()
  152. );
  153. let header_size_buf = &mut header_buf[..4];
  154. header_size_buf.copy_from_slice(&header_size_bytes);
  155. copy(&mut stream.take(header.body.size() as u64), &mut sink()).await?;
  156. Ok((header, header_buf))
  157. }
  158. pub fn serialize_str(s: &str) -> Vec<u8> {
  159. let mut buf = Vec::with_capacity(s.len() + size_of::<u32>());
  160. serialize_str_to(s, &mut buf);
  161. buf
  162. }
  163. pub fn serialize_str_to(s: &str, buf: &mut Vec<u8>) {
  164. let strlen = s.len() as u32;
  165. buf.extend(strlen.to_be_bytes());
  166. buf.extend(s.as_bytes());
  167. }
  168. fn deserialize_u32(buf: &[u8]) -> Result<(u32, &[u8]), Error> {
  169. let bytes = buf.get(0..4).ok_or(Error::MalformedSerialization(
  170. buf.to_vec(),
  171. std::backtrace::Backtrace::capture(),
  172. ))?;
  173. Ok((u32::from_be_bytes(bytes.try_into()?), &buf[4..]))
  174. }
  175. fn deserialize_str(buf: &[u8]) -> Result<(&str, &[u8]), Error> {
  176. let (strlen, buf) = deserialize_u32(buf)?;
  177. let strlen = strlen as usize;
  178. let strbytes = buf.get(..strlen).ok_or(Error::MalformedSerialization(
  179. buf.to_vec(),
  180. std::backtrace::Backtrace::capture(),
  181. ))?;
  182. Ok((std::str::from_utf8(strbytes)?, &buf[strlen..]))
  183. }
  184. /// A message almost ready for sending.
  185. ///
  186. /// We represent each message in two halves: the header, and the body.
  187. /// This way, the server can parse out the header in its own buf,
  188. /// and just pass that around intact, without keeping a (possibly large)
  189. /// 0-filled body around.
  190. #[derive(Debug)]
  191. pub struct SerializedMessage {
  192. pub header: Vec<u8>,
  193. pub body: MessageBody,
  194. }
  195. impl SerializedMessage {
  196. pub async fn write_all_to<T: AsyncWriteExt + std::marker::Unpin>(
  197. &self,
  198. writer: &mut T,
  199. ) -> std::io::Result<()> {
  200. let body_buf = vec![0; self.body.size() as usize];
  201. // write_all_vectored is not yet stable x_x
  202. // https://github.com/rust-lang/rust/issues/70436
  203. let mut header: &[u8] = &self.header;
  204. let mut body: &[u8] = &body_buf;
  205. loop {
  206. let bufs = [std::io::IoSlice::new(header), std::io::IoSlice::new(body)];
  207. match writer.write_vectored(&bufs).await {
  208. Ok(written) => {
  209. if written == header.len() + body.len() {
  210. return Ok(());
  211. }
  212. if written >= header.len() {
  213. body = &body[written - header.len()..];
  214. break;
  215. } else if written == 0 {
  216. return Err(std::io::Error::new(
  217. std::io::ErrorKind::WriteZero,
  218. "failed to write any bytes from message with bytes remaining",
  219. ));
  220. } else {
  221. header = &header[written..];
  222. }
  223. }
  224. Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
  225. Err(e) => return Err(e),
  226. }
  227. }
  228. writer.write_all(body).await
  229. }
  230. }