Procházet zdrojové kódy

Pass byte vectors around, not strings

Vecna před 3 roky
rodič
revize
cb26f44d68
5 změnil soubory, kde provedl 61 přidání a 33 odebrání
  1. 4 2
      src/bin/bridgedb.rs
  2. 4 2
      src/bin/lox_auth.rs
  3. 2 2
      src/bin/lox_client.rs
  4. 24 13
      src/client_net.rs
  5. 27 14
      src/server_net.rs

+ 4 - 2
src/bin/bridgedb.rs

@@ -52,6 +52,8 @@ async fn main() {
     listen(addr, to_uppercase).await;
 }
 
-fn to_uppercase(str: String) -> String {
-    str.to_uppercase()
+// This function assumes the byte vector is a valid string.
+fn to_uppercase(str_vec: Vec<u8>) -> Vec<u8> {
+    let str = std::str::from_utf8(&str_vec).unwrap();
+    str.to_uppercase().into()
 }

+ 4 - 2
src/bin/lox_auth.rs

@@ -68,6 +68,8 @@ async fn main() {
     listen(addr, reverse_string).await;
 }
 
-fn reverse_string(str: String) -> String {
-    str.trim().chars().rev().collect::<String>() + "\n"
+// This function assumes the byte vector is a valid string.
+fn reverse_string(str_vec: Vec<u8>) -> Vec<u8> {
+    let str = std::str::from_utf8(&str_vec).unwrap();
+    str.trim().chars().rev().collect::<String>().into()
 }

+ 2 - 2
src/bin/lox_client.rs

@@ -35,7 +35,7 @@ async fn main() {
     let reachability_pub = &lox_auth_pubkeys[3];
     let invitation_pub = &lox_auth_pubkeys[4];
 
-    let s = send(addr, msg).await;
+    let s = send(addr, msg.into()).await;
 
-    println!("{}", s);
+    println!("{}", std::str::from_utf8(&s).unwrap());
 }

+ 24 - 13
src/client_net.rs

@@ -5,32 +5,43 @@ to a server process and sending it data. */
 use tokio::io::{AsyncReadExt, AsyncWriteExt};
 use tokio::net::TcpStream;
 
-use std::str;
-
 // may need to change strings to byte vectors in the future
-pub async fn send(addr: String, str: String) -> String {
+pub async fn send(addr: String, payload: Vec<u8>) -> Vec<u8> {
     let mut stream = TcpStream::connect(&addr)
         .await
         .expect("Failed to create TcpStream");
 
+    // send number of bytes in payload
+    let payload_size = usize::to_be_bytes(payload.len());
+    stream
+        .write_all(&payload_size)
+        .await
+        .expect("Failed to write number of bytes to listen for");
+
     // send data
     stream
-        .write_all(str.as_bytes())
+        .write_all(&payload)
         .await
         .expect("Failed to write data to stream");
 
-    // read response
-    let mut buf = vec![0; 1024];
-    let n = stream
-        .read(&mut buf)
+    // get number of bytes in response
+    let mut nbuf: [u8; 8] = [0; 8];
+    stream
+        .read(&mut nbuf)
         .await
-        .expect("Failed to read data from socket");
+        .expect("Failed to get number of bytes to read");
+    let n = usize::from_be_bytes(nbuf);
 
     if n == 0 {
-        return "".to_string();
+        return vec![0; 0];
     }
 
-    str::from_utf8(&buf[0..n])
-        .expect("Invalid UTF-8 sequence")
-        .to_string()
+    // receive response
+    let mut buf = vec![0; n];
+    stream
+        .read(&mut buf)
+        .await
+        .expect("Failed to read data from socket");
+
+    buf
 }

+ 27 - 14
src/server_net.rs

@@ -9,9 +9,7 @@ these work. */
 use tokio::io::{AsyncReadExt, AsyncWriteExt};
 use tokio::net::TcpListener;
 
-use std::str;
-
-pub async fn listen(addr: String, fun: fn(String) -> String) {
+pub async fn listen(addr: String, fun: fn(Vec<u8>) -> Vec<u8>) {
     let listener = TcpListener::bind(&addr)
         .await
         .expect("Failed to create TcpListener");
@@ -22,11 +20,23 @@ pub async fn listen(addr: String, fun: fn(String) -> String) {
         let (mut socket, _) = listener.accept().await.expect("Failed to create socket");
 
         tokio::spawn(async move {
-            let mut buf = vec![0; 1024];
-
-            // read data, perform function on it, write result
             loop {
-                let n = socket
+                // get number of bytes to receive
+                let mut nbuf: [u8; 8] = [0; 8];
+                socket
+                    .read(&mut nbuf)
+                    .await
+                    .expect("Failed to get number of bytes to read");
+                let n = usize::from_be_bytes(nbuf);
+
+                if n == 0 {
+                    return;
+                }
+
+                let mut buf = vec![0; n];
+
+                // receive data
+                socket
                     .read(&mut buf)
                     .await
                     .expect("Failed to read data from socket");
@@ -35,15 +45,18 @@ pub async fn listen(addr: String, fun: fn(String) -> String) {
                     return;
                 }
 
-                // I think this is a problem if there's more data than fits in the buffer...
-                // But that's a problem for future me.
-                let s = str::from_utf8(&buf[0..n])
-                    .expect("Invalid UTF-8 sequence")
-                    .to_string();
-                let response = fun(s);
+                let response = fun(buf);
+
+                // send number of bytes in response
+                let response_size = usize::to_be_bytes(response.len());
+                socket
+                    .write_all(&response_size)
+                    .await
+                    .expect("Failed to write number of bytes to listen for");
 
+                // send response
                 socket
-                    .write_all(response.as_bytes())
+                    .write_all(&response)
                     .await
                     .expect("Failed to write data to socket");
             }