lib.rs 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. use std::mem::size_of;
  2. use std::num::NonZeroU32;
  3. use tokio::io::{copy, sink, AsyncReadExt, AsyncWriteExt};
  4. /// The padding interval. All message bodies are a size of some multiple of this.
  5. /// All messages bodies are a minimum of this size.
  6. // FIXME: double check what this should be
  7. pub const PADDING_BLOCK_SIZE: u32 = 256;
  8. /// The most blocks a message body can contain.
  9. // from https://github.com/signalapp/Signal-Android/blob/36a8c4d8ba9fdb62905ecb9a20e3eeba4d2f9022/app/src/main/java/org/thoughtcrime/securesms/mms/PushMediaConstraints.java
  10. pub const MAX_BLOCKS_IN_BODY: u32 = (100 * 1024 * 1024) / PADDING_BLOCK_SIZE;
  11. #[macro_export]
  12. macro_rules! log {
  13. ( $( $x:expr ),* ) => {
  14. print!("{}", chrono::offset::Utc::now().format("%F %T: "));
  15. println!($( $x ),*)
  16. }
  17. }
  18. #[derive(Debug)]
  19. pub enum Error {
  20. Io(std::io::Error),
  21. Utf8Error(std::str::Utf8Error),
  22. MalformedSerialization(Vec<u8>, std::backtrace::Backtrace),
  23. }
  24. impl std::fmt::Display for Error {
  25. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  26. write!(f, "{:?}", self)
  27. }
  28. }
  29. impl std::error::Error for Error {}
  30. impl From<std::io::Error> for Error {
  31. fn from(e: std::io::Error) -> Self {
  32. Self::Io(e)
  33. }
  34. }
  35. impl From<std::str::Utf8Error> for Error {
  36. fn from(e: std::str::Utf8Error) -> Self {
  37. Self::Utf8Error(e)
  38. }
  39. }
  40. /// Metadata for the body of the message.
  41. ///
  42. /// Message contents are always 0-filled buffers, so never represented.
  43. #[derive(Copy, Clone, Debug, PartialEq)]
  44. pub enum MessageBody {
  45. Receipt,
  46. Size(NonZeroU32),
  47. }
  48. impl MessageBody {
  49. fn size(&self) -> u32 {
  50. match self {
  51. MessageBody::Receipt => PADDING_BLOCK_SIZE,
  52. MessageBody::Size(size) => size.get(),
  53. }
  54. }
  55. }
  56. /// Message metadata.
  57. ///
  58. /// This has everything needed to reconstruct a message.
  59. // FIXME: every String should be &str
  60. #[derive(Debug)]
  61. pub struct MessageHeader {
  62. pub sender: String,
  63. pub recipients: Vec<String>,
  64. pub body: MessageBody,
  65. }
  66. impl MessageHeader {
  67. /// Generate a concise serialization of the Message.
  68. pub fn serialize(&self) -> SerializedMessage {
  69. // serialized message header: {
  70. // header_len: u32,
  71. // sender: {u32, utf-8}
  72. // num_recipients: u32,
  73. // recipients: [{u32, utf-8}],
  74. // body_type: MessageBody (i.e., u32)
  75. // }
  76. let num_recipients = self.recipients.len();
  77. let body_type = match self.body {
  78. MessageBody::Receipt => 0,
  79. MessageBody::Size(s) => s.get(),
  80. };
  81. let header_len = (1 + 1 + 1 + num_recipients + 1) * size_of::<u32>()
  82. + self.sender.len()
  83. + self.recipients.iter().map(String::len).sum::<usize>();
  84. let mut header: Vec<u8> = Vec::with_capacity(header_len);
  85. let header_len = header_len as u32;
  86. let num_recipients = num_recipients as u32;
  87. header.extend(header_len.to_be_bytes());
  88. serialize_str_to(&self.sender, &mut header);
  89. header.extend(num_recipients.to_be_bytes());
  90. for recipient in self.recipients.iter() {
  91. serialize_str_to(recipient, &mut header);
  92. }
  93. header.extend(body_type.to_be_bytes());
  94. assert!(header.len() == header_len as usize);
  95. SerializedMessage {
  96. header,
  97. body: self.body,
  98. }
  99. }
  100. /// Creates a MessageHeader from bytes created via serialization,
  101. /// but with the size already parsed out.
  102. fn deserialize(buf: &[u8]) -> Result<Self, Error> {
  103. let (sender, buf) = deserialize_str(buf)?;
  104. let sender = sender.to_string();
  105. let (num_recipients, buf) = deserialize_u32(buf)?;
  106. debug_assert!(num_recipients != 0);
  107. let mut recipients = Vec::with_capacity(num_recipients as usize);
  108. let mut recipient;
  109. let mut buf = buf;
  110. for _ in 0..num_recipients {
  111. (recipient, buf) = deserialize_str(buf)?;
  112. let recipient = recipient.to_string();
  113. recipients.push(recipient);
  114. }
  115. let (body, _) = deserialize_u32(buf)?;
  116. let body = if let Some(size) = NonZeroU32::new(body) {
  117. MessageBody::Size(size)
  118. } else {
  119. MessageBody::Receipt
  120. };
  121. Ok(Self {
  122. sender,
  123. recipients,
  124. body,
  125. })
  126. }
  127. }
  128. /// Parse the identifier from the start of the TcpStream.
  129. pub async fn parse_identifier<T: AsyncReadExt + std::marker::Unpin>(stream: &mut T) -> Result<String, Error> {
  130. // this should maybe be buffered
  131. let strlen = stream.read_u32().await?;
  132. let mut buf = vec![0u8; strlen as usize];
  133. stream.read_exact(&mut buf).await?;
  134. let s = std::str::from_utf8(&buf)?;
  135. Ok(s.to_string())
  136. }
  137. /// Gets a message from the stream and constructs a MessageHeader object, also returning the raw byte buffer
  138. pub async fn get_message<T: AsyncReadExt + std::marker::Unpin>(
  139. stream: &mut T,
  140. ) -> Result<(MessageHeader, Vec<u8>), Error> {
  141. let mut header_size_bytes = [0u8; 4];
  142. stream.read_exact(&mut header_size_bytes).await?;
  143. log!("got header size");
  144. get_message_with_header_size(stream, header_size_bytes).await
  145. }
  146. pub async fn get_message_with_header_size<T: AsyncReadExt + std::marker::Unpin>(
  147. stream: &mut T,
  148. header_size_bytes: [u8; 4],
  149. ) -> Result<(MessageHeader, Vec<u8>), Error> {
  150. let header_size = u32::from_be_bytes(header_size_bytes);
  151. let mut header_buf = vec![0; header_size as usize];
  152. stream.read_exact(&mut header_buf[4..]).await?;
  153. let header = MessageHeader::deserialize(&header_buf[4..])?;
  154. log!(
  155. "got header from {} to {:?}, about to read {} bytes",
  156. header.sender,
  157. header.recipients,
  158. header.body.size()
  159. );
  160. let header_size_buf = &mut header_buf[..4];
  161. header_size_buf.copy_from_slice(&header_size_bytes);
  162. copy(&mut stream.take(header.body.size() as u64), &mut sink()).await?;
  163. Ok((header, header_buf))
  164. }
  165. pub fn serialize_str(s: &str) -> Vec<u8> {
  166. let mut buf = Vec::with_capacity(s.len() + size_of::<u32>());
  167. serialize_str_to(s, &mut buf);
  168. buf
  169. }
  170. pub fn serialize_str_to(s: &str, buf: &mut Vec<u8>) {
  171. let strlen = s.len() as u32;
  172. buf.extend(strlen.to_be_bytes());
  173. buf.extend(s.as_bytes());
  174. }
  175. fn deserialize_u32(buf: &[u8]) -> Result<(u32, &[u8]), Error> {
  176. let bytes = buf.get(0..4).ok_or_else(|| {
  177. Error::MalformedSerialization(buf.to_vec(), std::backtrace::Backtrace::capture())
  178. })?;
  179. Ok((u32::from_be_bytes(bytes.try_into().unwrap()), &buf[4..]))
  180. }
  181. fn deserialize_str(buf: &[u8]) -> Result<(&str, &[u8]), Error> {
  182. let (strlen, buf) = deserialize_u32(buf)?;
  183. let strlen = strlen as usize;
  184. let strbytes = buf.get(..strlen).ok_or(Error::MalformedSerialization(
  185. buf.to_vec(),
  186. std::backtrace::Backtrace::capture(),
  187. ))?;
  188. Ok((std::str::from_utf8(strbytes)?, &buf[strlen..]))
  189. }
  190. /// A message almost ready for sending.
  191. ///
  192. /// We represent each message in two halves: the header, and the body.
  193. /// This way, the server can parse out the header in its own buf,
  194. /// and just pass that around intact, without keeping a (possibly large)
  195. /// 0-filled body around.
  196. #[derive(Debug)]
  197. pub struct SerializedMessage {
  198. pub header: Vec<u8>,
  199. pub body: MessageBody,
  200. }
  201. impl SerializedMessage {
  202. pub async fn write_all_to<T: AsyncWriteExt + std::marker::Unpin>(
  203. &self,
  204. writer: &mut T,
  205. ) -> std::io::Result<()> {
  206. let body_buf = vec![0; self.body.size() as usize];
  207. // write_all_vectored is not yet stable x_x
  208. // https://github.com/rust-lang/rust/issues/70436
  209. let mut header: &[u8] = &self.header;
  210. let mut body: &[u8] = &body_buf;
  211. loop {
  212. let bufs = [std::io::IoSlice::new(header), std::io::IoSlice::new(body)];
  213. match writer.write_vectored(&bufs).await {
  214. Ok(written) => {
  215. if written == header.len() + body.len() {
  216. return Ok(());
  217. }
  218. if written >= header.len() {
  219. body = &body[written - header.len()..];
  220. break;
  221. } else if written == 0 {
  222. return Err(std::io::Error::new(
  223. std::io::ErrorKind::WriteZero,
  224. "failed to write any bytes from message with bytes remaining",
  225. ));
  226. } else {
  227. header = &header[written..];
  228. }
  229. }
  230. Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
  231. Err(e) => return Err(e),
  232. }
  233. }
  234. writer.write_all(body).await
  235. }
  236. }