lib.rs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569
  1. use std::mem::size_of;
  2. use std::num::NonZeroU32;
  3. use tokio::io::{copy, sink, AsyncReadExt, AsyncWriteExt};
  4. pub mod updater;
  5. /// The padding interval in bytes. All message bodies are a size of some multiple of this.
  6. /// All messages bodies are a minimum of this size.
  7. // from https://github.com/signalapp/libsignal/blob/af7bb8567c812aa13625fc90076bf71a59d64ff5/rust/protocol/src/crypto.rs#L92C41-L92C41
  8. pub const PADDING_BLOCK_SIZE: u32 = 10 * 128 / 8;
  9. /// The most blocks a message body can contain.
  10. // from https://github.com/signalapp/Signal-Android/blob/36a8c4d8ba9fdb62905ecb9a20e3eeba4d2f9022/app/src/main/java/org/thoughtcrime/securesms/mms/PushMediaConstraints.java
  11. pub const MAX_BLOCKS_IN_BODY: u32 = (100 * 1024 * 1024) / PADDING_BLOCK_SIZE;
  12. /// The maxmimum number of bytes that can be sent inline; larger values use the HTTP server.
  13. // In actuality, this is 2000 for Signal:
  14. // https://github.com/signalapp/Signal-Android/blob/244902ecfc30e21287a35bb1680e2dbe6366975b/app/src/main/java/org/thoughtcrime/securesms/util/PushCharacterCalculator.java#L23
  15. // but we align to a close block count since in practice we sample from block counts
  16. pub const INLINE_MAX_SIZE: u32 = 14 * PADDING_BLOCK_SIZE;
  17. #[macro_export]
  18. macro_rules! log {
  19. ( $( $x:expr ),* ) => {
  20. println!("{}{}",
  21. chrono::offset::Utc::now().format("%F %T,%s.%f,"),
  22. format_args!($( $x ),*)
  23. );
  24. }
  25. }
  26. #[derive(Debug)]
  27. pub enum Error {
  28. Io(std::io::Error),
  29. Utf8Error(std::str::Utf8Error),
  30. MalformedSerialization(Vec<u8>, std::backtrace::Backtrace),
  31. }
  32. impl std::fmt::Display for Error {
  33. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  34. write!(f, "{:?}", self)
  35. }
  36. }
  37. impl std::error::Error for Error {}
  38. impl From<std::io::Error> for Error {
  39. fn from(e: std::io::Error) -> Self {
  40. Self::Io(e)
  41. }
  42. }
  43. impl From<std::str::Utf8Error> for Error {
  44. fn from(e: std::str::Utf8Error) -> Self {
  45. Self::Utf8Error(e)
  46. }
  47. }
  48. /// Metadata for the body of the message.
  49. ///
  50. /// Message contents are always 0-filled buffers, so never represented.
  51. #[derive(Copy, Clone, Debug, PartialEq)]
  52. pub enum MessageBody {
  53. Receipt,
  54. Size(NonZeroU32),
  55. }
  56. impl MessageBody {
  57. /// Whether the body of the message requires an HTTP GET
  58. /// (attachment size is the message size).
  59. pub fn has_attachment(&self) -> bool {
  60. match self {
  61. MessageBody::Receipt => false,
  62. MessageBody::Size(size) => size > &NonZeroU32::new(INLINE_MAX_SIZE).unwrap(),
  63. }
  64. }
  65. /// Size on the wire of the message's body, exluding bytes fetched via http
  66. fn inline_size<const P2P: bool>(&self) -> usize {
  67. match self {
  68. MessageBody::Receipt => PADDING_BLOCK_SIZE as usize,
  69. MessageBody::Size(size) => {
  70. let size = size.get();
  71. if P2P || size <= INLINE_MAX_SIZE {
  72. size as usize
  73. } else {
  74. INLINE_MAX_SIZE as usize
  75. }
  76. }
  77. }
  78. }
  79. /// Size of the message's body, including bytes fetched via http
  80. pub fn total_size(&self) -> usize {
  81. match self {
  82. MessageBody::Receipt => PADDING_BLOCK_SIZE as usize,
  83. MessageBody::Size(size) => size.get() as usize,
  84. }
  85. }
  86. }
  87. /// Message metadata.
  88. ///
  89. /// This has everything needed to reconstruct a message.
  90. // FIXME: we should try to replace MessageHeader with MessageHeaderRef
  91. #[derive(Debug, PartialEq)]
  92. pub struct MessageHeader {
  93. /// User who constructed the message.
  94. pub sender: String,
  95. /// Group associated with the message.
  96. /// In client-server mode receipts, this is the recipient instead.
  97. pub group: String,
  98. /// ID unique to a message and its receipt for a (sender, group) pair.
  99. pub id: u32,
  100. /// The type and size of the message payload.
  101. pub body: MessageBody,
  102. }
  103. impl MessageHeader {
  104. /// Generate a concise serialization of the Message.
  105. pub fn serialize(&self) -> SerializedMessage {
  106. // serialized message header: {
  107. // header_len: u32,
  108. // sender: {u32, utf-8},
  109. // group: {u32, utf-8},
  110. // id: u32,
  111. // body_type: MessageBody (i.e., u32)
  112. // }
  113. let body_type = match self.body {
  114. MessageBody::Receipt => 0,
  115. MessageBody::Size(s) => s.get(),
  116. };
  117. let header_len =
  118. (1 + 1 + 1 + 1 + 1) * size_of::<u32>() + self.sender.len() + self.group.len();
  119. let mut header: Vec<u8> = Vec::with_capacity(header_len);
  120. let header_len = header_len as u32;
  121. header.extend(header_len.to_be_bytes());
  122. serialize_str_to(&self.sender, &mut header);
  123. serialize_str_to(&self.group, &mut header);
  124. header.extend(self.id.to_be_bytes());
  125. header.extend(body_type.to_be_bytes());
  126. assert!(header.len() == header_len as usize);
  127. SerializedMessage {
  128. header,
  129. body: self.body,
  130. }
  131. }
  132. /// Creates a MessageHeader from bytes created via serialization,
  133. /// but with the size already parsed out.
  134. fn deserialize(buf: &[u8]) -> Result<Self, Error> {
  135. let (sender, buf) = deserialize_str(buf)?;
  136. let sender = sender.to_string();
  137. let (group, buf) = deserialize_str(buf)?;
  138. let group = group.to_string();
  139. let (id, buf) = deserialize_u32(buf)?;
  140. let (body, _) = deserialize_u32(buf)?;
  141. let body = if let Some(size) = NonZeroU32::new(body) {
  142. MessageBody::Size(size)
  143. } else {
  144. MessageBody::Receipt
  145. };
  146. Ok(Self {
  147. sender,
  148. group,
  149. id,
  150. body,
  151. })
  152. }
  153. }
  154. /// Message metadata.
  155. ///
  156. /// This has everything needed to reconstruct a message.
  157. #[derive(Debug)]
  158. pub struct MessageHeaderRef<'a> {
  159. pub sender: &'a str,
  160. pub group: &'a str,
  161. pub id: u32,
  162. pub body: MessageBody,
  163. }
  164. impl<'a> MessageHeaderRef<'a> {
  165. /// Creates a MessageHeader from bytes created via serialization,
  166. /// but with the size already parsed out.
  167. pub fn deserialize(buf: &'a [u8]) -> Result<Self, Error> {
  168. let (sender, buf) = deserialize_str(buf)?;
  169. let (group, buf) = deserialize_str(buf)?;
  170. let (id, buf) = deserialize_u32(buf)?;
  171. let (body, _) = deserialize_u32(buf)?;
  172. let body = if let Some(size) = NonZeroU32::new(body) {
  173. MessageBody::Size(size)
  174. } else {
  175. MessageBody::Receipt
  176. };
  177. Ok(Self {
  178. sender,
  179. group,
  180. id,
  181. body,
  182. })
  183. }
  184. }
  185. /// Parse the identifier from the start of the TcpStream.
  186. pub async fn parse_identifier<T: AsyncReadExt + std::marker::Unpin>(
  187. stream: &mut T,
  188. ) -> Result<String, Error> {
  189. // this should maybe be buffered
  190. let strlen = stream.read_u32().await?;
  191. let mut buf = vec![0u8; strlen as usize];
  192. stream.read_exact(&mut buf).await?;
  193. let s = std::str::from_utf8(&buf)?;
  194. Ok(s.to_string())
  195. }
  196. /// Gets a message from the stream, returning the raw byte buffer
  197. pub async fn get_message_bytes<const P2P: bool, T: AsyncReadExt + std::marker::Unpin>(
  198. stream: &mut T,
  199. ) -> Result<Vec<u8>, Error> {
  200. let mut header_size_bytes = [0u8; 4];
  201. stream.read_exact(&mut header_size_bytes).await?;
  202. get_message_with_header_size::<P2P, _>(stream, header_size_bytes).await
  203. }
  204. /// Gets a message from the stream and constructs a MessageHeader object
  205. pub async fn get_message<const P2P: bool, T: AsyncReadExt + std::marker::Unpin>(
  206. stream: &mut T,
  207. ) -> Result<MessageHeader, Error> {
  208. let buf = get_message_bytes::<P2P, _>(stream).await?;
  209. let msg = MessageHeader::deserialize(&buf[4..])?;
  210. Ok(msg)
  211. }
  212. async fn get_message_with_header_size<const P2P: bool, T: AsyncReadExt + std::marker::Unpin>(
  213. stream: &mut T,
  214. header_size_bytes: [u8; 4],
  215. ) -> Result<Vec<u8>, Error> {
  216. let header_size = u32::from_be_bytes(header_size_bytes);
  217. let mut header_buf = vec![0; header_size as usize];
  218. stream.read_exact(&mut header_buf[4..]).await?;
  219. let header = MessageHeader::deserialize(&header_buf[4..])?;
  220. let header_size_buf = &mut header_buf[..4];
  221. header_size_buf.copy_from_slice(&header_size_bytes);
  222. copy(
  223. &mut stream.take(header.body.inline_size::<P2P>() as u64),
  224. &mut sink(),
  225. )
  226. .await?;
  227. Ok(header_buf)
  228. }
  229. pub fn serialize_str(s: &str) -> Vec<u8> {
  230. let mut buf = Vec::with_capacity(s.len() + size_of::<u32>());
  231. serialize_str_to(s, &mut buf);
  232. buf
  233. }
  234. pub fn serialize_str_to(s: &str, buf: &mut Vec<u8>) {
  235. let strlen = s.len() as u32;
  236. buf.extend(strlen.to_be_bytes());
  237. buf.extend(s.as_bytes());
  238. }
  239. fn deserialize_u32(buf: &[u8]) -> Result<(u32, &[u8]), Error> {
  240. let bytes = buf.get(0..4).ok_or_else(|| {
  241. Error::MalformedSerialization(buf.to_vec(), std::backtrace::Backtrace::capture())
  242. })?;
  243. Ok((u32::from_be_bytes(bytes.try_into().unwrap()), &buf[4..]))
  244. }
  245. fn deserialize_str(buf: &[u8]) -> Result<(&str, &[u8]), Error> {
  246. let (strlen, buf) = deserialize_u32(buf)?;
  247. let strlen = strlen as usize;
  248. let strbytes = buf.get(..strlen).ok_or_else(|| {
  249. Error::MalformedSerialization(buf.to_vec(), std::backtrace::Backtrace::capture())
  250. })?;
  251. Ok((std::str::from_utf8(strbytes)?, &buf[strlen..]))
  252. }
  253. /// A message almost ready for sending.
  254. ///
  255. /// We represent each message in two halves: the header, and the body.
  256. /// This way, the server can parse out the header in its own buf,
  257. /// and just pass that around intact, without keeping a (possibly large)
  258. /// 0-filled body around.
  259. #[derive(Debug)]
  260. pub struct SerializedMessage {
  261. pub header: Vec<u8>,
  262. pub body: MessageBody,
  263. }
  264. impl SerializedMessage {
  265. pub async fn write_all_to<const P2P: bool, T: AsyncWriteExt + std::marker::Unpin>(
  266. &self,
  267. writer: &mut T,
  268. ) -> std::io::Result<()> {
  269. let body_buf = vec![0; self.body.inline_size::<P2P>()];
  270. // write_all_vectored is not yet stable x_x
  271. // https://github.com/rust-lang/rust/issues/70436
  272. let mut header: &[u8] = &self.header;
  273. let mut body: &[u8] = &body_buf;
  274. loop {
  275. let bufs = [std::io::IoSlice::new(header), std::io::IoSlice::new(body)];
  276. match writer.write_vectored(&bufs).await {
  277. Ok(written) => {
  278. if written == header.len() + body.len() {
  279. return Ok(());
  280. }
  281. if written >= header.len() {
  282. body = &body[written - header.len()..];
  283. break;
  284. } else if written == 0 {
  285. return Err(std::io::Error::new(
  286. std::io::ErrorKind::WriteZero,
  287. "failed to write any bytes from message with bytes remaining",
  288. ));
  289. } else {
  290. header = &header[written..];
  291. }
  292. }
  293. Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
  294. Err(e) => return Err(e),
  295. }
  296. }
  297. writer.write_all(body).await
  298. }
  299. }
  300. /// Handshake between client and server (peers do not use).
  301. #[derive(Eq, Debug, Hash, PartialEq)]
  302. pub struct Handshake {
  303. /// Who is sending this handshake.
  304. pub sender: String,
  305. /// For normal messages, the group the message was sent to.
  306. /// For receipts, the client the receipt is for.
  307. pub group: String,
  308. }
  309. impl Handshake {
  310. /// Generate a serialized handshake message.
  311. pub fn serialize(&self) -> Vec<u8> {
  312. serialize_handshake(&self.sender, &self.group)
  313. }
  314. }
  315. /// Gets a handshake from the stream and constructs a Handshake object
  316. pub async fn get_handshake<T: AsyncReadExt + std::marker::Unpin>(
  317. stream: &mut T,
  318. ) -> Result<Handshake, Error> {
  319. let sender = parse_identifier(stream).await?;
  320. let group = parse_identifier(stream).await?;
  321. Ok(Handshake { sender, group })
  322. }
  323. /// A reference to a Handshake's fields.
  324. pub struct HandshakeRef<'a> {
  325. pub sender: &'a str,
  326. pub group: &'a str,
  327. }
  328. impl HandshakeRef<'_> {
  329. /// Generate a serialized handshake message.
  330. pub fn serialize(&self) -> Vec<u8> {
  331. serialize_handshake(self.sender, self.group)
  332. }
  333. }
  334. fn serialize_handshake(sender: &str, group: &str) -> Vec<u8> {
  335. // serialized handshake: {
  336. // sender: {u32, utf-8}
  337. // group: {u32, utf-8}
  338. // }
  339. let handshake_len = (1 + 1) * size_of::<u32>() + sender.len() + group.len();
  340. let mut handshake: Vec<u8> = Vec::with_capacity(handshake_len);
  341. serialize_str_to(sender, &mut handshake);
  342. serialize_str_to(group, &mut handshake);
  343. debug_assert!(handshake.len() == handshake_len);
  344. handshake
  345. }
  346. #[cfg(test)]
  347. mod tests {
  348. use super::*;
  349. use tokio::fs::{File, OpenOptions};
  350. /// creates a temporary file for writing
  351. async fn generate_tmp_file(name: &str) -> File {
  352. let filename = format!("mgen-test-{}", name);
  353. let mut path = std::env::temp_dir();
  354. path.push(filename);
  355. OpenOptions::new()
  356. .read(true)
  357. .write(true)
  358. .create(true)
  359. .open(path)
  360. .await
  361. .unwrap()
  362. }
  363. /// get an existing temp file for reading
  364. async fn get_tmp_file(name: &str) -> File {
  365. let filename = format!("mgen-test-{}", name);
  366. let mut path = std::env::temp_dir();
  367. path.push(filename);
  368. OpenOptions::new().read(true).open(path).await.unwrap()
  369. }
  370. #[test]
  371. fn serialize_deserialize_message() {
  372. let m1 = MessageHeader {
  373. sender: "Alice".to_string(),
  374. group: "group".to_string(),
  375. id: 1024,
  376. body: MessageBody::Size(NonZeroU32::new(256).unwrap()),
  377. };
  378. let serialized = m1.serialize();
  379. let m2 = MessageHeader::deserialize(&serialized.header[4..]).unwrap();
  380. assert_eq!(m1, m2);
  381. }
  382. #[test]
  383. fn serialize_deserialize_receipt() {
  384. let m1 = MessageHeader {
  385. sender: "Alice".to_string(),
  386. group: "group".to_string(),
  387. id: 1024,
  388. body: MessageBody::Receipt,
  389. };
  390. let serialized = m1.serialize();
  391. let m2 = MessageHeader::deserialize(&serialized.header[4..]).unwrap();
  392. assert_eq!(m1, m2);
  393. }
  394. #[test]
  395. fn deserialize_message_ref() {
  396. let m1 = MessageHeader {
  397. sender: "Alice".to_string(),
  398. group: "group".to_string(),
  399. id: 1024,
  400. body: MessageBody::Size(NonZeroU32::new(256).unwrap()),
  401. };
  402. let serialized = m1.serialize();
  403. let m2 = MessageHeaderRef::deserialize(&serialized.header[4..]).unwrap();
  404. assert_eq!(m1.sender, m2.sender);
  405. assert_eq!(m1.group, m2.group);
  406. assert_eq!(m1.body, m2.body);
  407. }
  408. #[test]
  409. fn deserialize_receipt_ref() {
  410. let m1 = MessageHeader {
  411. sender: "Alice".to_string(),
  412. group: "group".to_string(),
  413. id: 1024,
  414. body: MessageBody::Receipt,
  415. };
  416. let serialized = m1.serialize();
  417. let m2 = MessageHeaderRef::deserialize(&serialized.header[4..]).unwrap();
  418. assert_eq!(m1.sender, m2.sender);
  419. assert_eq!(m1.group, m2.group);
  420. assert_eq!(m1.body, m2.body);
  421. }
  422. #[tokio::test]
  423. async fn serialize_get_message() {
  424. let m1 = MessageHeader {
  425. sender: "Alice".to_string(),
  426. group: "group".to_string(),
  427. id: 1024,
  428. body: MessageBody::Size(NonZeroU32::new(256).unwrap()),
  429. };
  430. let serialized = m1.serialize();
  431. let file_name = "serialize_message_get";
  432. let mut f = generate_tmp_file(file_name).await;
  433. serialized.write_all_to(&mut f).await.unwrap();
  434. let mut f = get_tmp_file(file_name).await;
  435. let m2 = get_message(&mut f).await.unwrap();
  436. assert_eq!(m1, m2);
  437. }
  438. #[tokio::test]
  439. async fn serialize_get_receipt() {
  440. let m1 = MessageHeader {
  441. sender: "Alice".to_string(),
  442. group: "group".to_string(),
  443. id: 1024,
  444. body: MessageBody::Receipt,
  445. };
  446. let serialized = m1.serialize();
  447. let file_name = "serialize_receipt_get";
  448. let mut f = generate_tmp_file(file_name).await;
  449. serialized.write_all_to(&mut f).await.unwrap();
  450. let mut f = get_tmp_file(file_name).await;
  451. let m2 = get_message(&mut f).await.unwrap();
  452. assert_eq!(m1, m2);
  453. }
  454. #[tokio::test]
  455. async fn serialize_get_handshake() {
  456. let h1 = Handshake {
  457. sender: "Alice".to_string(),
  458. group: "group".to_string(),
  459. };
  460. let file_name = "handshake";
  461. let mut f = generate_tmp_file(file_name).await;
  462. f.write_all(&h1.serialize()).await.unwrap();
  463. let mut f = get_tmp_file(file_name).await;
  464. let h2 = get_handshake(&mut f).await.unwrap();
  465. assert_eq!(h1, h2);
  466. }
  467. #[tokio::test]
  468. async fn serialize_get_handshake_ref() {
  469. let h1 = HandshakeRef {
  470. sender: "Alice",
  471. group: "group",
  472. };
  473. let file_name = "handshake-ref";
  474. let mut f = generate_tmp_file(file_name).await;
  475. f.write_all(&h1.serialize()).await.unwrap();
  476. let mut f = get_tmp_file(file_name).await;
  477. let h2 = get_handshake(&mut f).await.unwrap();
  478. assert_eq!(h1.sender, h2.sender);
  479. assert_eq!(h1.group, h2.group);
  480. }
  481. }