Browse Source

only limit message sizes in client-server

Justin Tracey 1 month ago
parent
commit
a305d44730
4 changed files with 19 additions and 17 deletions
  1. 2 2
      src/bin/mgen-client.rs
  2. 2 2
      src/bin/mgen-peer.rs
  3. 5 2
      src/bin/mgen-server.rs
  4. 10 11
      src/lib.rs

+ 2 - 2
src/bin/mgen-client.rs

@@ -176,7 +176,7 @@ async fn reader(
             let mut message_stream = socket_updater.recv().await;
 
             loop {
-                let msg = match mgen::get_message(&mut message_stream).await {
+                let msg = match mgen::get_message::<false, _>(&mut message_stream).await {
                     Ok(msg) => msg,
                     Err(e) => {
                         error_channel.send(e.into()).expect("Error channel closed");
@@ -289,7 +289,7 @@ async fn writer(
                     .expect("Attachment channel closed");
             }
 
-            if let Err(e) = msg.write_all_to(&mut stream).await {
+            if let Err(e) = msg.write_all_to::<false, _>(&mut stream).await {
                 error_channel.send(e.into()).expect("Error channel closed");
                 break;
             }

+ 2 - 2
src/bin/mgen-peer.rs

@@ -139,7 +139,7 @@ async fn reader(
         // wait for listener or writer thread to give us a stream to read from
         let mut stream = connection_channel.recv().await;
         loop {
-            let Ok(msg) = mgen::get_message(&mut stream).await else {
+            let Ok(msg) = mgen::get_message::<true, _>(&mut stream).await else {
                 // Unlike the client-server case, we can assume that if there
                 // were a message someone was trying to send us, they'd make
                 // sure to re-establish the connection; so when the socket
@@ -186,7 +186,7 @@ async fn writer<'a>(
     .expect("Fatal error establishing connection");
 
     loop {
-        while msg.write_all_to(&mut stream).await.is_err() {
+        while msg.write_all_to::<true, _>(&mut stream).await.is_err() {
             stream = establish_connection(
                 &mut write_socket_updater,
                 &read_socket_updater,

+ 5 - 2
src/bin/mgen-server.rs

@@ -383,7 +383,7 @@ async fn get_messages<T: tokio::io::AsyncRead>(
         .collect();
 
     loop {
-        let buf = mgen::get_message_bytes(&mut socket).await?;
+        let buf = mgen::get_message_bytes::<false, _>(&mut socket).await?;
         let message = MessageHeaderRef::deserialize(&buf[4..])?;
         assert!(message.sender == sender);
 
@@ -430,7 +430,10 @@ async fn send_messages<T: Send + Sync + tokio::io::AsyncWrite>(
         } else {
             msg_rcv.recv().await.expect("message channel closed")
         };
-        if message.write_all_to(&mut current_socket).await.is_err()
+        if message
+            .write_all_to::<false, _>(&mut current_socket)
+            .await
+            .is_err()
             || current_socket.flush().await.is_err()
         {
             message_cache = Some(message);

+ 10 - 11
src/lib.rs

@@ -12,7 +12,6 @@ pub const PADDING_BLOCK_SIZE: u32 = 10 * 128 / 8;
 // from https://github.com/signalapp/Signal-Android/blob/36a8c4d8ba9fdb62905ecb9a20e3eeba4d2f9022/app/src/main/java/org/thoughtcrime/securesms/mms/PushMediaConstraints.java
 pub const MAX_BLOCKS_IN_BODY: u32 = (100 * 1024 * 1024) / PADDING_BLOCK_SIZE;
 /// The maxmimum number of bytes that can be sent inline; larger values use the HTTP server.
-// FIXME: should only apply to client-server, not p2p
 // In actuality, this is 2000 for Signal:
 // https://github.com/signalapp/Signal-Android/blob/244902ecfc30e21287a35bb1680e2dbe6366975b/app/src/main/java/org/thoughtcrime/securesms/util/PushCharacterCalculator.java#L23
 // but we align to a close block count since in practice we sample from block counts
@@ -75,12 +74,12 @@ impl MessageBody {
     }
 
     /// Size on the wire of the message's body, exluding bytes fetched via http
-    fn inline_size(&self) -> usize {
+    fn inline_size<const P2P: bool>(&self) -> usize {
         match self {
             MessageBody::Receipt => PADDING_BLOCK_SIZE as usize,
             MessageBody::Size(size) => {
                 let size = size.get();
-                if size <= INLINE_MAX_SIZE {
+                if P2P || size <= INLINE_MAX_SIZE {
                     size as usize
                 } else {
                     INLINE_MAX_SIZE as usize
@@ -228,24 +227,24 @@ pub async fn parse_identifier<T: AsyncReadExt + std::marker::Unpin>(
 }
 
 /// Gets a message from the stream, returning the raw byte buffer
-pub async fn get_message_bytes<T: AsyncReadExt + std::marker::Unpin>(
+pub async fn get_message_bytes<const P2P: bool, T: AsyncReadExt + std::marker::Unpin>(
     stream: &mut T,
 ) -> Result<Vec<u8>, Error> {
     let mut header_size_bytes = [0u8; 4];
     stream.read_exact(&mut header_size_bytes).await?;
-    get_message_with_header_size(stream, header_size_bytes).await
+    get_message_with_header_size::<P2P, _>(stream, header_size_bytes).await
 }
 
 /// Gets a message from the stream and constructs a MessageHeader object
-pub async fn get_message<T: AsyncReadExt + std::marker::Unpin>(
+pub async fn get_message<const P2P: bool, T: AsyncReadExt + std::marker::Unpin>(
     stream: &mut T,
 ) -> Result<MessageHeader, Error> {
-    let buf = get_message_bytes(stream).await?;
+    let buf = get_message_bytes::<P2P, _>(stream).await?;
     let msg = MessageHeader::deserialize(&buf[4..])?;
     Ok(msg)
 }
 
-async fn get_message_with_header_size<T: AsyncReadExt + std::marker::Unpin>(
+async fn get_message_with_header_size<const P2P: bool, T: AsyncReadExt + std::marker::Unpin>(
     stream: &mut T,
     header_size_bytes: [u8; 4],
 ) -> Result<Vec<u8>, Error> {
@@ -256,7 +255,7 @@ async fn get_message_with_header_size<T: AsyncReadExt + std::marker::Unpin>(
     let header_size_buf = &mut header_buf[..4];
     header_size_buf.copy_from_slice(&header_size_bytes);
     copy(
-        &mut stream.take(header.body.inline_size() as u64),
+        &mut stream.take(header.body.inline_size::<P2P>() as u64),
         &mut sink(),
     )
     .await?;
@@ -304,11 +303,11 @@ pub struct SerializedMessage {
 }
 
 impl SerializedMessage {
-    pub async fn write_all_to<T: AsyncWriteExt + std::marker::Unpin>(
+    pub async fn write_all_to<const P2P: bool, T: AsyncWriteExt + std::marker::Unpin>(
         &self,
         writer: &mut T,
     ) -> std::io::Result<()> {
-        let body_buf = vec![0; self.body.inline_size()];
+        let body_buf = vec![0; self.body.inline_size::<P2P>()];
 
         // write_all_vectored is not yet stable x_x
         // https://github.com/rust-lang/rust/issues/70436