lib.rs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. use std::mem::size_of;
  2. use std::num::NonZeroU32;
  3. use tokio::io::{copy, sink, AsyncReadExt, AsyncWriteExt};
  4. pub mod updater;
  /// Granularity of message-body padding, in bytes.
  ///
  /// Every message body is a whole multiple of this size, and no body is smaller than one block.
  // from https://github.com/signalapp/libsignal/blob/af7bb8567c812aa13625fc90076bf71a59d64ff5/rust/protocol/src/crypto.rs#L92C41-L92C41
  pub const PADDING_BLOCK_SIZE: u32 = (10 * 128) / 8;

  /// Upper bound on the number of padding blocks a single message body can contain.
  // from https://github.com/signalapp/Signal-Android/blob/36a8c4d8ba9fdb62905ecb9a20e3eeba4d2f9022/app/src/main/java/org/thoughtcrime/securesms/mms/PushMediaConstraints.java
  pub const MAX_BLOCKS_IN_BODY: u32 = 100 * 1024 * 1024 / PADDING_BLOCK_SIZE;

  /// The maximum number of bytes that can be sent inline; larger bodies use the HTTP server.
  // FIXME: should only apply to client-server, not p2p
  // In actuality, this is 2000 for Signal:
  // https://github.com/signalapp/Signal-Android/blob/244902ecfc30e21287a35bb1680e2dbe6366975b/app/src/main/java/org/thoughtcrime/securesms/util/PushCharacterCalculator.java#L23
  // but we align to a close block count since in practice we sample from block counts
  pub const INLINE_MAX_SIZE: u32 = PADDING_BLOCK_SIZE * 14;
  /// Prints one line to stdout: the current UTC time rendered with the
  /// `"%F %T,%s.%f,"` format string, immediately followed by the message built
  /// from the macro arguments (which are forwarded unchanged to `format_args!`).
  #[macro_export]
  macro_rules! log {
      ( $( $arg:expr ),* ) => {
          println!(
              "{}{}",
              chrono::offset::Utc::now().format("%F %T,%s.%f,"),
              format_args!( $( $arg ),* )
          );
      }
  }
  /// Failures that can occur while reading or decoding wire data.
  #[derive(Debug)]
  pub enum Error {
      /// An underlying I/O operation failed.
      Io(std::io::Error),
      /// Bytes that were expected to be a string were not valid UTF-8.
      Utf8Error(std::str::Utf8Error),
      /// A buffer could not be decoded; carries a copy of the offending bytes
      /// and a backtrace captured where the failure was detected.
      MalformedSerialization(Vec<u8>, std::backtrace::Backtrace),
  }

  impl std::fmt::Display for Error {
      /// Renders the error using its `Debug` representation.
      fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
          out.write_fmt(format_args!("{:?}", self))
      }
  }

  impl std::error::Error for Error {}

  impl From<std::io::Error> for Error {
      /// Wraps an I/O error so it can be propagated with `?`.
      fn from(err: std::io::Error) -> Self {
          Error::Io(err)
      }
  }

  impl From<std::str::Utf8Error> for Error {
      /// Wraps a UTF-8 validation error so it can be propagated with `?`.
      fn from(err: std::str::Utf8Error) -> Self {
          Error::Utf8Error(err)
      }
  }
  49. /// Metadata for the body of the message.
  50. ///
  51. /// Message contents are always 0-filled buffers, so never represented.
  52. #[derive(Copy, Clone, Debug, PartialEq)]
  53. pub enum MessageBody {
  54. Receipt,
  55. Size(NonZeroU32),
  56. }
  57. impl MessageBody {
  58. /// Whether the body of the message requires an HTTP GET
  59. /// (attachment size is the message size).
  60. pub fn has_attachment(&self) -> bool {
  61. match self {
  62. MessageBody::Receipt => false,
  63. MessageBody::Size(size) => size > &NonZeroU32::new(INLINE_MAX_SIZE).unwrap(),
  64. }
  65. }
  66. /// Size on the wire of the message's body, exluding bytes fetched via http
  67. fn inline_size(&self) -> usize {
  68. match self {
  69. MessageBody::Receipt => PADDING_BLOCK_SIZE as usize,
  70. MessageBody::Size(size) => {
  71. let size = size.get();
  72. if size <= INLINE_MAX_SIZE {
  73. size as usize
  74. } else {
  75. INLINE_MAX_SIZE as usize
  76. }
  77. }
  78. }
  79. }
  80. /// Size of the message's body, including bytes fetched via http
  81. pub fn total_size(&self) -> usize {
  82. match self {
  83. MessageBody::Receipt => PADDING_BLOCK_SIZE as usize,
  84. MessageBody::Size(size) => size.get() as usize,
  85. }
  86. }
  87. }
  88. /// Message metadata.
  89. ///
  90. /// This has everything needed to reconstruct a message.
  91. // FIXME: we should try to replace MessageHeader with MessageHeaderRef
  92. #[derive(Debug, PartialEq)]
  93. pub struct MessageHeader {
  94. /// User who constructed the message.
  95. pub sender: String,
  96. /// Group associated with the message.
  97. /// In client-server mode receipts, this is the recipient instead.
  98. pub group: String,
  99. /// ID unique to a message and its receipt for a (sender, group) pair.
  100. pub id: u32,
  101. /// The type and size of the message payload.
  102. pub body: MessageBody,
  103. }
  104. impl MessageHeader {
  105. /// Generate a concise serialization of the Message.
  106. pub fn serialize(&self) -> SerializedMessage {
  107. // serialized message header: {
  108. // header_len: u32,
  109. // sender: {u32, utf-8},
  110. // group: {u32, utf-8},
  111. // id: u32,
  112. // body_type: MessageBody (i.e., u32)
  113. // }
  114. let body_type = match self.body {
  115. MessageBody::Receipt => 0,
  116. MessageBody::Size(s) => s.get(),
  117. };
  118. let header_len =
  119. (1 + 1 + 1 + 1 + 1) * size_of::<u32>() + self.sender.len() + self.group.len();
  120. let mut header: Vec<u8> = Vec::with_capacity(header_len);
  121. let header_len = header_len as u32;
  122. header.extend(header_len.to_be_bytes());
  123. serialize_str_to(&self.sender, &mut header);
  124. serialize_str_to(&self.group, &mut header);
  125. header.extend(self.id.to_be_bytes());
  126. header.extend(body_type.to_be_bytes());
  127. assert!(header.len() == header_len as usize);
  128. SerializedMessage {
  129. header,
  130. body: self.body,
  131. }
  132. }
  133. /// Creates a MessageHeader from bytes created via serialization,
  134. /// but with the size already parsed out.
  135. fn deserialize(buf: &[u8]) -> Result<Self, Error> {
  136. let (sender, buf) = deserialize_str(buf)?;
  137. let sender = sender.to_string();
  138. let (group, buf) = deserialize_str(buf)?;
  139. let group = group.to_string();
  140. let (id, buf) = deserialize_u32(buf)?;
  141. let (body, _) = deserialize_u32(buf)?;
  142. let body = if let Some(size) = NonZeroU32::new(body) {
  143. MessageBody::Size(size)
  144. } else {
  145. MessageBody::Receipt
  146. };
  147. Ok(Self {
  148. sender,
  149. group,
  150. id,
  151. body,
  152. })
  153. }
  154. }
  155. /// Message metadata.
  156. ///
  157. /// This has everything needed to reconstruct a message.
  158. #[derive(Debug)]
  159. pub struct MessageHeaderRef<'a> {
  160. pub sender: &'a str,
  161. pub group: &'a str,
  162. pub id: u32,
  163. pub body: MessageBody,
  164. }
  165. impl<'a> MessageHeaderRef<'a> {
  166. /// Creates a MessageHeader from bytes created via serialization,
  167. /// but with the size already parsed out.
  168. pub fn deserialize(buf: &'a [u8]) -> Result<Self, Error> {
  169. let (sender, buf) = deserialize_str(buf)?;
  170. let sender = sender;
  171. let (group, buf) = deserialize_str(buf)?;
  172. let group = group;
  173. let (id, buf) = deserialize_u32(buf)?;
  174. let (body, _) = deserialize_u32(buf)?;
  175. let body = if let Some(size) = NonZeroU32::new(body) {
  176. MessageBody::Size(size)
  177. } else {
  178. MessageBody::Receipt
  179. };
  180. Ok(Self {
  181. sender,
  182. group,
  183. id,
  184. body,
  185. })
  186. }
  187. }
  188. /// Parse the identifier from the start of the TcpStream.
  189. pub async fn parse_identifier<T: AsyncReadExt + std::marker::Unpin>(
  190. stream: &mut T,
  191. ) -> Result<String, Error> {
  192. // this should maybe be buffered
  193. let strlen = stream.read_u32().await?;
  194. let mut buf = vec![0u8; strlen as usize];
  195. stream.read_exact(&mut buf).await?;
  196. let s = std::str::from_utf8(&buf)?;
  197. Ok(s.to_string())
  198. }
  199. /// Gets a message from the stream, returning the raw byte buffer
  200. pub async fn get_message_bytes<T: AsyncReadExt + std::marker::Unpin>(
  201. stream: &mut T,
  202. ) -> Result<Vec<u8>, Error> {
  203. let mut header_size_bytes = [0u8; 4];
  204. stream.read_exact(&mut header_size_bytes).await?;
  205. get_message_with_header_size(stream, header_size_bytes).await
  206. }
  207. /// Gets a message from the stream and constructs a MessageHeader object
  208. pub async fn get_message<T: AsyncReadExt + std::marker::Unpin>(
  209. stream: &mut T,
  210. ) -> Result<MessageHeader, Error> {
  211. let buf = get_message_bytes(stream).await?;
  212. let msg = MessageHeader::deserialize(&buf[4..])?;
  213. Ok(msg)
  214. }
  215. async fn get_message_with_header_size<T: AsyncReadExt + std::marker::Unpin>(
  216. stream: &mut T,
  217. header_size_bytes: [u8; 4],
  218. ) -> Result<Vec<u8>, Error> {
  219. let header_size = u32::from_be_bytes(header_size_bytes);
  220. let mut header_buf = vec![0; header_size as usize];
  221. stream.read_exact(&mut header_buf[4..]).await?;
  222. let header = MessageHeader::deserialize(&header_buf[4..])?;
  223. let header_size_buf = &mut header_buf[..4];
  224. header_size_buf.copy_from_slice(&header_size_bytes);
  225. copy(
  226. &mut stream.take(header.body.inline_size() as u64),
  227. &mut sink(),
  228. )
  229. .await?;
  230. Ok(header_buf)
  231. }
  /// Encodes `s` as a big-endian `u32` byte length followed by its UTF-8 bytes,
  /// returning the result in a freshly allocated buffer.
  pub fn serialize_str(s: &str) -> Vec<u8> {
      let len_prefix = (s.len() as u32).to_be_bytes();
      let mut out = Vec::with_capacity(len_prefix.len() + s.len());
      out.extend_from_slice(&len_prefix);
      out.extend_from_slice(s.as_bytes());
      out
  }
  /// Appends `s` to `buf` as a big-endian `u32` byte length followed by its UTF-8 bytes.
  ///
  /// Existing contents of `buf` are left untouched.
  pub fn serialize_str_to(s: &str, buf: &mut Vec<u8>) {
      let len_prefix = (s.len() as u32).to_be_bytes();
      buf.extend_from_slice(&len_prefix);
      buf.extend_from_slice(s.as_bytes());
  }
  242. fn deserialize_u32(buf: &[u8]) -> Result<(u32, &[u8]), Error> {
  243. let bytes = buf.get(0..4).ok_or_else(|| {
  244. Error::MalformedSerialization(buf.to_vec(), std::backtrace::Backtrace::capture())
  245. })?;
  246. Ok((u32::from_be_bytes(bytes.try_into().unwrap()), &buf[4..]))
  247. }
  248. fn deserialize_str(buf: &[u8]) -> Result<(&str, &[u8]), Error> {
  249. let (strlen, buf) = deserialize_u32(buf)?;
  250. let strlen = strlen as usize;
  251. let strbytes = buf.get(..strlen).ok_or_else(|| {
  252. Error::MalformedSerialization(buf.to_vec(), std::backtrace::Backtrace::capture())
  253. })?;
  254. Ok((std::str::from_utf8(strbytes)?, &buf[strlen..]))
  255. }
  256. /// A message almost ready for sending.
  257. ///
  258. /// We represent each message in two halves: the header, and the body.
  259. /// This way, the server can parse out the header in its own buf,
  260. /// and just pass that around intact, without keeping a (possibly large)
  261. /// 0-filled body around.
  262. #[derive(Debug)]
  263. pub struct SerializedMessage {
  264. pub header: Vec<u8>,
  265. pub body: MessageBody,
  266. }
  267. impl SerializedMessage {
  268. pub async fn write_all_to<T: AsyncWriteExt + std::marker::Unpin>(
  269. &self,
  270. writer: &mut T,
  271. ) -> std::io::Result<()> {
  272. let body_buf = vec![0; self.body.inline_size()];
  273. // write_all_vectored is not yet stable x_x
  274. // https://github.com/rust-lang/rust/issues/70436
  275. let mut header: &[u8] = &self.header;
  276. let mut body: &[u8] = &body_buf;
  277. loop {
  278. let bufs = [std::io::IoSlice::new(header), std::io::IoSlice::new(body)];
  279. match writer.write_vectored(&bufs).await {
  280. Ok(written) => {
  281. if written == header.len() + body.len() {
  282. return Ok(());
  283. }
  284. if written >= header.len() {
  285. body = &body[written - header.len()..];
  286. break;
  287. } else if written == 0 {
  288. return Err(std::io::Error::new(
  289. std::io::ErrorKind::WriteZero,
  290. "failed to write any bytes from message with bytes remaining",
  291. ));
  292. } else {
  293. header = &header[written..];
  294. }
  295. }
  296. Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
  297. Err(e) => return Err(e),
  298. }
  299. }
  300. writer.write_all(body).await
  301. }
  302. }
  303. /// Handshake between client and server (peers do not use).
  304. #[derive(Eq, Debug, Hash, PartialEq)]
  305. pub struct Handshake {
  306. /// Who is sending this handshake.
  307. pub sender: String,
  308. /// For normal messages, the group the message was sent to.
  309. /// For receipts, the client the receipt is for.
  310. pub group: String,
  311. }
  312. impl Handshake {
  313. /// Generate a serialized handshake message.
  314. pub fn serialize(&self) -> Vec<u8> {
  315. serialize_handshake(&self.sender, &self.group)
  316. }
  317. }
  318. /// Gets a handshake from the stream and constructs a Handshake object
  319. pub async fn get_handshake<T: AsyncReadExt + std::marker::Unpin>(
  320. stream: &mut T,
  321. ) -> Result<Handshake, Error> {
  322. let sender = parse_identifier(stream).await?;
  323. let group = parse_identifier(stream).await?;
  324. Ok(Handshake { sender, group })
  325. }
  326. /// A reference to a Handshake's fields.
  327. pub struct HandshakeRef<'a> {
  328. pub sender: &'a str,
  329. pub group: &'a str,
  330. }
  331. impl HandshakeRef<'_> {
  332. /// Generate a serialized handshake message.
  333. pub fn serialize(&self) -> Vec<u8> {
  334. serialize_handshake(self.sender, self.group)
  335. }
  336. }
  337. fn serialize_handshake(sender: &str, group: &str) -> Vec<u8> {
  338. // serialized handshake: {
  339. // sender: {u32, utf-8}
  340. // group: {u32, utf-8}
  341. // }
  342. let handshake_len = (1 + 1) * size_of::<u32>() + sender.len() + group.len();
  343. let mut handshake: Vec<u8> = Vec::with_capacity(handshake_len);
  344. serialize_str_to(sender, &mut handshake);
  345. serialize_str_to(group, &mut handshake);
  346. debug_assert!(handshake.len() == handshake_len);
  347. handshake
  348. }
  #[cfg(test)]
  mod tests {
      use super::*;
      use tokio::fs::{File, OpenOptions};

      /// Path of the scratch file used by the test called `name`.
      fn tmp_path(name: &str) -> std::path::PathBuf {
          std::env::temp_dir().join(format!("mgen-test-{}", name))
      }

      /// creates a temporary file for writing
      ///
      /// The file is truncated so bytes left over from an earlier run can never
      /// be mistaken for data written by the current one.
      async fn generate_tmp_file(name: &str) -> File {
          OpenOptions::new()
              .read(true)
              .write(true)
              .create(true)
              .truncate(true)
              .open(tmp_path(name))
              .await
              .unwrap()
      }

      /// get an existing temp file for reading
      async fn get_tmp_file(name: &str) -> File {
          OpenOptions::new()
              .read(true)
              .open(tmp_path(name))
              .await
              .unwrap()
      }

      /// Builds the header used throughout these tests with the given body.
      fn header_with(body: MessageBody) -> MessageHeader {
          MessageHeader {
              sender: "Alice".to_string(),
              group: "group".to_string(),
              id: 1024,
              body,
          }
      }

      /// A 256-byte regular message body.
      fn sized_body() -> MessageBody {
          MessageBody::Size(NonZeroU32::new(256).unwrap())
      }

      #[test]
      fn serialize_deserialize_message() {
          let m1 = header_with(sized_body());
          let serialized = m1.serialize();
          let m2 = MessageHeader::deserialize(&serialized.header[4..]).unwrap();
          assert_eq!(m1, m2);
      }

      #[test]
      fn serialize_deserialize_receipt() {
          let m1 = header_with(MessageBody::Receipt);
          let serialized = m1.serialize();
          let m2 = MessageHeader::deserialize(&serialized.header[4..]).unwrap();
          assert_eq!(m1, m2);
      }

      #[test]
      fn deserialize_message_ref() {
          let m1 = header_with(sized_body());
          let serialized = m1.serialize();
          let m2 = MessageHeaderRef::deserialize(&serialized.header[4..]).unwrap();
          assert_eq!(m1.sender, m2.sender);
          assert_eq!(m1.group, m2.group);
          assert_eq!(m1.id, m2.id);
          assert_eq!(m1.body, m2.body);
      }

      #[test]
      fn deserialize_receipt_ref() {
          let m1 = header_with(MessageBody::Receipt);
          let serialized = m1.serialize();
          let m2 = MessageHeaderRef::deserialize(&serialized.header[4..]).unwrap();
          assert_eq!(m1.sender, m2.sender);
          assert_eq!(m1.group, m2.group);
          assert_eq!(m1.id, m2.id);
          assert_eq!(m1.body, m2.body);
      }

      #[tokio::test]
      async fn serialize_get_message() {
          let m1 = header_with(sized_body());
          let serialized = m1.serialize();

          let file_name = "serialize_message_get";
          let mut f = generate_tmp_file(file_name).await;
          serialized.write_all_to(&mut f).await.unwrap();
          // A tokio `File` performs writes in the background; flush so the data
          // is on disk before the file is reopened for reading.
          f.flush().await.unwrap();

          let mut f = get_tmp_file(file_name).await;
          let m2 = get_message(&mut f).await.unwrap();
          assert_eq!(m1, m2);
      }

      #[tokio::test]
      async fn serialize_get_receipt() {
          let m1 = header_with(MessageBody::Receipt);
          let serialized = m1.serialize();

          let file_name = "serialize_receipt_get";
          let mut f = generate_tmp_file(file_name).await;
          serialized.write_all_to(&mut f).await.unwrap();
          // See serialize_get_message: make sure the write has completed.
          f.flush().await.unwrap();

          let mut f = get_tmp_file(file_name).await;
          let m2 = get_message(&mut f).await.unwrap();
          assert_eq!(m1, m2);
      }

      #[tokio::test]
      async fn serialize_get_handshake() {
          let h1 = Handshake {
              sender: "Alice".to_string(),
              group: "group".to_string(),
          };

          let file_name = "handshake";
          let mut f = generate_tmp_file(file_name).await;
          f.write_all(&h1.serialize()).await.unwrap();
          // Make sure the write has completed before reading the file back.
          f.flush().await.unwrap();

          let mut f = get_tmp_file(file_name).await;
          let h2 = get_handshake(&mut f).await.unwrap();
          assert_eq!(h1, h2);
      }

      #[tokio::test]
      async fn serialize_get_handshake_ref() {
          let h1 = HandshakeRef {
              sender: "Alice",
              group: "group",
          };

          let file_name = "handshake-ref";
          let mut f = generate_tmp_file(file_name).await;
          f.write_all(&h1.serialize()).await.unwrap();
          // Make sure the write has completed before reading the file back.
          f.flush().await.unwrap();

          let mut f = get_tmp_file(file_name).await;
          let h2 = get_handshake(&mut f).await.unwrap();
          assert_eq!(h1.sender, h2.sender);
          assert_eq!(h1.group, h2.group);
      }
  }