lib.rs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570
  1. use std::mem::size_of;
  2. use std::num::NonZeroU32;
  3. use tokio::io::{copy, sink, AsyncReadExt, AsyncWriteExt};
  4. pub mod updater;
  /// The padding interval in bytes (`10 * 128 / 8` = 160).
  /// All message bodies are a size of some multiple of this,
  /// and all message bodies are a minimum of this size.
  // from https://github.com/signalapp/libsignal/blob/af7bb8567c812aa13625fc90076bf71a59d64ff5/rust/protocol/src/crypto.rs#L92C41-L92C41
  pub const PADDING_BLOCK_SIZE: u32 = 10 * 128 / 8;
  /// The most blocks a message body can contain (100 MiB worth of padding blocks).
  // from https://github.com/signalapp/Signal-Android/blob/36a8c4d8ba9fdb62905ecb9a20e3eeba4d2f9022/app/src/main/java/org/thoughtcrime/securesms/mms/PushMediaConstraints.java
  pub const MAX_BLOCKS_IN_BODY: u32 = (100 * 1024 * 1024) / PADDING_BLOCK_SIZE;
  /// The maximum number of bytes that can be sent inline (14 blocks = 2240 bytes);
  /// larger values use the HTTP server.
  // FIXME: should only apply to client-server, not p2p
  // In actuality, this is 2000 for Signal:
  // https://github.com/signalapp/Signal-Android/blob/244902ecfc30e21287a35bb1680e2dbe6366975b/app/src/main/java/org/thoughtcrime/securesms/util/PushCharacterCalculator.java#L23
  // but we align to a close block count since in practice we sample from block counts
  pub const INLINE_MAX_SIZE: u32 = 14 * PADDING_BLOCK_SIZE;
  /// Prints one line to stdout, prefixed with the current UTC time.
  ///
  /// The prefix uses chrono's strftime pattern `%F %T,%s.%f,`: the
  /// human-readable date and time, then the Unix timestamp with its
  /// fractional-second part, each field followed by a comma. The macro's
  /// arguments are forwarded unchanged to `format_args!`.
  #[macro_export]
  macro_rules! log {
      ( $( $arg:expr ),* ) => {{
          let timestamp = chrono::offset::Utc::now().format("%F %T,%s.%f,");
          println!("{}{}", timestamp, format_args!($( $arg ),*));
      }};
  }
  /// Errors produced while reading, writing, or decoding protocol data.
  #[derive(Debug)]
  pub enum Error {
      /// A failure from the underlying reader or writer.
      Io(std::io::Error),
      /// Bytes that were expected to be UTF-8 text were not.
      Utf8Error(std::str::Utf8Error),
      /// Serialized data was structurally invalid (e.g. too short to hold the
      /// expected field). Carries the offending bytes and a backtrace captured
      /// where the problem was detected.
      MalformedSerialization(Vec<u8>, std::backtrace::Backtrace),
  }

  impl std::fmt::Display for Error {
      fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
          match self {
              Error::Io(e) => write!(f, "I/O error: {}", e),
              Error::Utf8Error(e) => write!(f, "invalid UTF-8: {}", e),
              Error::MalformedSerialization(bytes, backtrace) => write!(
                  f,
                  "malformed serialization ({} bytes: {:?})\n{}",
                  bytes.len(),
                  bytes,
                  backtrace
              ),
          }
      }
  }

  impl std::error::Error for Error {
      /// Exposes the wrapped I/O or UTF-8 error so error chains keep the cause.
      fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
          match self {
              Error::Io(e) => Some(e),
              Error::Utf8Error(e) => Some(e),
              Error::MalformedSerialization(..) => None,
          }
      }
  }

  impl From<std::io::Error> for Error {
      fn from(e: std::io::Error) -> Self {
          Self::Io(e)
      }
  }

  impl From<std::str::Utf8Error> for Error {
      fn from(e: std::str::Utf8Error) -> Self {
          Self::Utf8Error(e)
      }
  }
  49. /// Metadata for the body of the message.
  50. ///
  51. /// Message contents are always 0-filled buffers, so never represented.
  52. #[derive(Copy, Clone, Debug, PartialEq)]
  53. pub enum MessageBody {
  54. Receipt,
  55. Size(NonZeroU32),
  56. }
  57. impl MessageBody {
  58. /// Whether the body of the message requires an HTTP GET
  59. /// (attachment size is the message size).
  60. pub fn has_attachment(&self) -> bool {
  61. match self {
  62. MessageBody::Receipt => false,
  63. MessageBody::Size(size) => size > &NonZeroU32::new(INLINE_MAX_SIZE).unwrap(),
  64. }
  65. }
  66. /// Size on the wire of the message's body, exluding bytes fetched via http
  67. fn inline_size(&self) -> usize {
  68. match self {
  69. MessageBody::Receipt => PADDING_BLOCK_SIZE as usize,
  70. MessageBody::Size(size) => {
  71. let size = size.get();
  72. if size <= INLINE_MAX_SIZE {
  73. size as usize
  74. } else {
  75. INLINE_MAX_SIZE as usize
  76. }
  77. }
  78. }
  79. }
  80. /// Size of the message's body, including bytes fetched via http
  81. pub fn total_size(&self) -> usize {
  82. match self {
  83. MessageBody::Receipt => PADDING_BLOCK_SIZE as usize,
  84. MessageBody::Size(size) => size.get() as usize,
  85. }
  86. }
  87. }
  88. /// Message metadata.
  89. ///
  90. /// This has everything needed to reconstruct a message.
  91. // FIXME: we should try to replace MessageHeader with MessageHeaderRef
  92. #[derive(Debug, PartialEq)]
  93. pub struct MessageHeader {
  94. /// User who constructed the message.
  95. pub sender: String,
  96. /// Group associated with the message.
  97. /// In client-server mode receipts, this is the recipient instead.
  98. pub group: String,
  99. /// ID unique to a message and its receipt for a (sender, group) pair.
  100. pub id: u32,
  101. /// The type and size of the message payload.
  102. pub body: MessageBody,
  103. }
  104. impl MessageHeader {
  105. /// Generate a concise serialization of the Message.
  106. pub fn serialize(&self) -> SerializedMessage {
  107. // serialized message header: {
  108. // header_len: u32,
  109. // sender: {u32, utf-8},
  110. // group: {u32, utf-8},
  111. // id: u32,
  112. // body_type: MessageBody (i.e., u32)
  113. // }
  114. let body_type = match self.body {
  115. MessageBody::Receipt => 0,
  116. MessageBody::Size(s) => s.get(),
  117. };
  118. let header_len =
  119. (1 + 1 + 1 + 1 + 1) * size_of::<u32>() + self.sender.len() + self.group.len();
  120. let mut header: Vec<u8> = Vec::with_capacity(header_len);
  121. let header_len = header_len as u32;
  122. header.extend(header_len.to_be_bytes());
  123. serialize_str_to(&self.sender, &mut header);
  124. serialize_str_to(&self.group, &mut header);
  125. header.extend(self.id.to_be_bytes());
  126. header.extend(body_type.to_be_bytes());
  127. assert!(header.len() == header_len as usize);
  128. SerializedMessage {
  129. header,
  130. body: self.body,
  131. }
  132. }
  133. /// Creates a MessageHeader from bytes created via serialization,
  134. /// but with the size already parsed out.
  135. fn deserialize(buf: &[u8]) -> Result<Self, Error> {
  136. let (sender, buf) = deserialize_str(buf)?;
  137. let sender = sender.to_string();
  138. let (group, buf) = deserialize_str(buf)?;
  139. let group = group.to_string();
  140. let (id, buf) = deserialize_u32(buf)?;
  141. let (body, _) = deserialize_u32(buf)?;
  142. let body = if let Some(size) = NonZeroU32::new(body) {
  143. MessageBody::Size(size)
  144. } else {
  145. MessageBody::Receipt
  146. };
  147. Ok(Self {
  148. sender,
  149. group,
  150. id,
  151. body,
  152. })
  153. }
  154. }
  155. /// Message metadata.
  156. ///
  157. /// This has everything needed to reconstruct a message.
  158. #[derive(Debug)]
  159. pub struct MessageHeaderRef<'a> {
  160. pub sender: &'a str,
  161. pub group: &'a str,
  162. pub id: u32,
  163. pub body: MessageBody,
  164. }
  165. impl<'a> MessageHeaderRef<'a> {
  166. /// Creates a MessageHeader from bytes created via serialization,
  167. /// but with the size already parsed out.
  168. pub fn deserialize(buf: &'a [u8]) -> Result<Self, Error> {
  169. let (sender, buf) = deserialize_str(buf)?;
  170. let (group, buf) = deserialize_str(buf)?;
  171. let (id, buf) = deserialize_u32(buf)?;
  172. let (body, _) = deserialize_u32(buf)?;
  173. let body = if let Some(size) = NonZeroU32::new(body) {
  174. MessageBody::Size(size)
  175. } else {
  176. MessageBody::Receipt
  177. };
  178. Ok(Self {
  179. sender,
  180. group,
  181. id,
  182. body,
  183. })
  184. }
  185. }
  186. /// Parse the identifier from the start of the TcpStream.
  187. pub async fn parse_identifier<T: AsyncReadExt + std::marker::Unpin>(
  188. stream: &mut T,
  189. ) -> Result<String, Error> {
  190. // this should maybe be buffered
  191. let strlen = stream.read_u32().await?;
  192. let mut buf = vec![0u8; strlen as usize];
  193. stream.read_exact(&mut buf).await?;
  194. let s = std::str::from_utf8(&buf)?;
  195. Ok(s.to_string())
  196. }
  197. /// Gets a message from the stream, returning the raw byte buffer
  198. pub async fn get_message_bytes<T: AsyncReadExt + std::marker::Unpin>(
  199. stream: &mut T,
  200. ) -> Result<Vec<u8>, Error> {
  201. let mut header_size_bytes = [0u8; 4];
  202. stream.read_exact(&mut header_size_bytes).await?;
  203. get_message_with_header_size(stream, header_size_bytes).await
  204. }
  205. /// Gets a message from the stream and constructs a MessageHeader object
  206. pub async fn get_message<T: AsyncReadExt + std::marker::Unpin>(
  207. stream: &mut T,
  208. ) -> Result<MessageHeader, Error> {
  209. let buf = get_message_bytes(stream).await?;
  210. let msg = MessageHeader::deserialize(&buf[4..])?;
  211. Ok(msg)
  212. }
  213. async fn get_message_with_header_size<T: AsyncReadExt + std::marker::Unpin>(
  214. stream: &mut T,
  215. header_size_bytes: [u8; 4],
  216. ) -> Result<Vec<u8>, Error> {
  217. let header_size = u32::from_be_bytes(header_size_bytes);
  218. let mut header_buf = vec![0; header_size as usize];
  219. stream.read_exact(&mut header_buf[4..]).await?;
  220. let header = MessageHeader::deserialize(&header_buf[4..])?;
  221. let header_size_buf = &mut header_buf[..4];
  222. header_size_buf.copy_from_slice(&header_size_bytes);
  223. copy(
  224. &mut stream.take(header.body.inline_size() as u64),
  225. &mut sink(),
  226. )
  227. .await?;
  228. Ok(header_buf)
  229. }
  /// Encodes `s` as a big-endian `u32` byte length followed by its UTF-8 bytes.
  pub fn serialize_str(s: &str) -> Vec<u8> {
      let mut out = Vec::with_capacity(size_of::<u32>() + s.len());
      serialize_str_to(s, &mut out);
      out
  }

  /// Appends the length-prefixed encoding of `s` (see `serialize_str`) to `buf`.
  pub fn serialize_str_to(s: &str, buf: &mut Vec<u8>) {
      let prefix = (s.len() as u32).to_be_bytes();
      buf.extend_from_slice(&prefix);
      buf.extend_from_slice(s.as_bytes());
  }
  240. fn deserialize_u32(buf: &[u8]) -> Result<(u32, &[u8]), Error> {
  241. let bytes = buf.get(0..4).ok_or_else(|| {
  242. Error::MalformedSerialization(buf.to_vec(), std::backtrace::Backtrace::capture())
  243. })?;
  244. Ok((u32::from_be_bytes(bytes.try_into().unwrap()), &buf[4..]))
  245. }
  246. fn deserialize_str(buf: &[u8]) -> Result<(&str, &[u8]), Error> {
  247. let (strlen, buf) = deserialize_u32(buf)?;
  248. let strlen = strlen as usize;
  249. let strbytes = buf.get(..strlen).ok_or_else(|| {
  250. Error::MalformedSerialization(buf.to_vec(), std::backtrace::Backtrace::capture())
  251. })?;
  252. Ok((std::str::from_utf8(strbytes)?, &buf[strlen..]))
  253. }
  254. /// A message almost ready for sending.
  255. ///
  256. /// We represent each message in two halves: the header, and the body.
  257. /// This way, the server can parse out the header in its own buf,
  258. /// and just pass that around intact, without keeping a (possibly large)
  259. /// 0-filled body around.
  260. #[derive(Debug)]
  261. pub struct SerializedMessage {
  262. pub header: Vec<u8>,
  263. pub body: MessageBody,
  264. }
  265. impl SerializedMessage {
  266. pub async fn write_all_to<T: AsyncWriteExt + std::marker::Unpin>(
  267. &self,
  268. writer: &mut T,
  269. ) -> std::io::Result<()> {
  270. let body_buf = vec![0; self.body.inline_size()];
  271. // write_all_vectored is not yet stable x_x
  272. // https://github.com/rust-lang/rust/issues/70436
  273. let mut header: &[u8] = &self.header;
  274. let mut body: &[u8] = &body_buf;
  275. loop {
  276. let bufs = [std::io::IoSlice::new(header), std::io::IoSlice::new(body)];
  277. match writer.write_vectored(&bufs).await {
  278. Ok(written) => {
  279. if written == header.len() + body.len() {
  280. return Ok(());
  281. }
  282. if written >= header.len() {
  283. body = &body[written - header.len()..];
  284. break;
  285. } else if written == 0 {
  286. return Err(std::io::Error::new(
  287. std::io::ErrorKind::WriteZero,
  288. "failed to write any bytes from message with bytes remaining",
  289. ));
  290. } else {
  291. header = &header[written..];
  292. }
  293. }
  294. Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
  295. Err(e) => return Err(e),
  296. }
  297. }
  298. writer.write_all(body).await
  299. }
  300. }
  301. /// Handshake between client and server (peers do not use).
  302. #[derive(Eq, Debug, Hash, PartialEq)]
  303. pub struct Handshake {
  304. /// Who is sending this handshake.
  305. pub sender: String,
  306. /// For normal messages, the group the message was sent to.
  307. /// For receipts, the client the receipt is for.
  308. pub group: String,
  309. }
  310. impl Handshake {
  311. /// Generate a serialized handshake message.
  312. pub fn serialize(&self) -> Vec<u8> {
  313. serialize_handshake(&self.sender, &self.group)
  314. }
  315. }
  316. /// Gets a handshake from the stream and constructs a Handshake object
  317. pub async fn get_handshake<T: AsyncReadExt + std::marker::Unpin>(
  318. stream: &mut T,
  319. ) -> Result<Handshake, Error> {
  320. let sender = parse_identifier(stream).await?;
  321. let group = parse_identifier(stream).await?;
  322. Ok(Handshake { sender, group })
  323. }
  324. /// A reference to a Handshake's fields.
  325. pub struct HandshakeRef<'a> {
  326. pub sender: &'a str,
  327. pub group: &'a str,
  328. }
  329. impl HandshakeRef<'_> {
  330. /// Generate a serialized handshake message.
  331. pub fn serialize(&self) -> Vec<u8> {
  332. serialize_handshake(self.sender, self.group)
  333. }
  334. }
  335. fn serialize_handshake(sender: &str, group: &str) -> Vec<u8> {
  336. // serialized handshake: {
  337. // sender: {u32, utf-8}
  338. // group: {u32, utf-8}
  339. // }
  340. let handshake_len = (1 + 1) * size_of::<u32>() + sender.len() + group.len();
  341. let mut handshake: Vec<u8> = Vec::with_capacity(handshake_len);
  342. serialize_str_to(sender, &mut handshake);
  343. serialize_str_to(group, &mut handshake);
  344. debug_assert!(handshake.len() == handshake_len);
  345. handshake
  346. }
  #[cfg(test)]
  mod tests {
      use super::*;
      use tokio::fs::{File, OpenOptions};

      /// Path of a named scratch file in the system temp directory.
      fn tmp_path(name: &str) -> std::path::PathBuf {
          std::env::temp_dir().join(format!("mgen-test-{}", name))
      }

      /// Creates a temporary file for writing, discarding any leftover
      /// contents from a previous run.
      async fn generate_tmp_file(name: &str) -> File {
          OpenOptions::new()
              .read(true)
              .write(true)
              .create(true)
              .truncate(true)
              .open(tmp_path(name))
              .await
              .unwrap()
      }

      /// Opens an existing temp file for reading.
      async fn get_tmp_file(name: &str) -> File {
          OpenOptions::new().read(true).open(tmp_path(name)).await.unwrap()
      }

      #[test]
      fn serialize_deserialize_message() {
          let m1 = MessageHeader {
              sender: "Alice".to_string(),
              group: "group".to_string(),
              id: 1024,
              body: MessageBody::Size(NonZeroU32::new(256).unwrap()),
          };
          let serialized = m1.serialize();
          let m2 = MessageHeader::deserialize(&serialized.header[4..]).unwrap();
          assert_eq!(m1, m2);
      }

      #[test]
      fn serialize_deserialize_receipt() {
          let m1 = MessageHeader {
              sender: "Alice".to_string(),
              group: "group".to_string(),
              id: 1024,
              body: MessageBody::Receipt,
          };
          let serialized = m1.serialize();
          let m2 = MessageHeader::deserialize(&serialized.header[4..]).unwrap();
          assert_eq!(m1, m2);
      }

      #[test]
      fn deserialize_message_ref() {
          let m1 = MessageHeader {
              sender: "Alice".to_string(),
              group: "group".to_string(),
              id: 1024,
              body: MessageBody::Size(NonZeroU32::new(256).unwrap()),
          };
          let serialized = m1.serialize();
          let m2 = MessageHeaderRef::deserialize(&serialized.header[4..]).unwrap();
          assert_eq!(m1.sender, m2.sender);
          assert_eq!(m1.group, m2.group);
          assert_eq!(m1.id, m2.id);
          assert_eq!(m1.body, m2.body);
      }

      #[test]
      fn deserialize_receipt_ref() {
          let m1 = MessageHeader {
              sender: "Alice".to_string(),
              group: "group".to_string(),
              id: 1024,
              body: MessageBody::Receipt,
          };
          let serialized = m1.serialize();
          let m2 = MessageHeaderRef::deserialize(&serialized.header[4..]).unwrap();
          assert_eq!(m1.sender, m2.sender);
          assert_eq!(m1.group, m2.group);
          assert_eq!(m1.id, m2.id);
          assert_eq!(m1.body, m2.body);
      }

      #[tokio::test]
      async fn serialize_get_message() {
          let m1 = MessageHeader {
              sender: "Alice".to_string(),
              group: "group".to_string(),
              id: 1024,
              body: MessageBody::Size(NonZeroU32::new(256).unwrap()),
          };
          let serialized = m1.serialize();
          let file_name = "serialize_message_get";
          let mut f = generate_tmp_file(file_name).await;
          serialized.write_all_to(&mut f).await.unwrap();
          // A tokio File may still have a write in flight after `write_all`
          // returns; flush waits for it before the file is reopened.
          f.flush().await.unwrap();
          let mut f = get_tmp_file(file_name).await;
          let m2 = get_message(&mut f).await.unwrap();
          assert_eq!(m1, m2);
      }

      #[tokio::test]
      async fn serialize_get_receipt() {
          let m1 = MessageHeader {
              sender: "Alice".to_string(),
              group: "group".to_string(),
              id: 1024,
              body: MessageBody::Receipt,
          };
          let serialized = m1.serialize();
          let file_name = "serialize_receipt_get";
          let mut f = generate_tmp_file(file_name).await;
          serialized.write_all_to(&mut f).await.unwrap();
          f.flush().await.unwrap();
          let mut f = get_tmp_file(file_name).await;
          let m2 = get_message(&mut f).await.unwrap();
          assert_eq!(m1, m2);
      }

      #[tokio::test]
      async fn serialize_get_handshake() {
          let h1 = Handshake {
              sender: "Alice".to_string(),
              group: "group".to_string(),
          };
          let file_name = "handshake";
          let mut f = generate_tmp_file(file_name).await;
          f.write_all(&h1.serialize()).await.unwrap();
          f.flush().await.unwrap();
          let mut f = get_tmp_file(file_name).await;
          let h2 = get_handshake(&mut f).await.unwrap();
          assert_eq!(h1, h2);
      }

      #[tokio::test]
      async fn serialize_get_handshake_ref() {
          let h1 = HandshakeRef {
              sender: "Alice",
              group: "group",
          };
          let file_name = "handshake-ref";
          let mut f = generate_tmp_file(file_name).await;
          f.write_all(&h1.serialize()).await.unwrap();
          f.flush().await.unwrap();
          let mut f = get_tmp_file(file_name).await;
          let h2 = get_handshake(&mut f).await.unwrap();
          assert_eq!(h1.sender, h2.sender);
          assert_eq!(h1.group, h2.group);
      }
  }