Browse Source

switch Mutex to RwLock for global stream table in server

Justin Tracey 1 year ago
parent
commit
58ca61082f
1 changed files with 6 additions and 7 deletions
  1. 6 7
      src/bin/server.rs

+ 6 - 7
src/bin/server.rs

@@ -2,7 +2,7 @@ use mgen::{log, parse_identifier, updater::Updater, SerializedMessage};
 use std::collections::HashMap;
 use std::error::Error;
 use std::result::Result;
-use std::sync::{Arc, Mutex};
+use std::sync::{Arc, RwLock};
 use tokio::io::AsyncWriteExt;
 use tokio::net::{
     tcp::{OwnedReadHalf, OwnedWriteHalf},
@@ -19,8 +19,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
 
     log!("Listening");
 
-    // FIXME: should probably be a readers-writer lock
-    let snd_db = Arc::new(Mutex::new(HashMap::<
+    let snd_db = Arc::new(RwLock::new(HashMap::<
         ID,
         mpsc::UnboundedSender<Arc<SerializedMessage>>,
     >::new()));
@@ -51,7 +50,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
 
             let snd_db = snd_db.clone();
             {
-                let mut locked_db = snd_db.lock().unwrap();
+                let mut locked_db = snd_db.write().unwrap();
                 locked_db.insert(id.clone(), msg_snd);
             }
 
@@ -77,7 +76,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
 
 fn spawn_message_receiver(
     rd: OwnedReadHalf,
-    db: Arc<Mutex<HashMap<ID, mpsc::UnboundedSender<Arc<SerializedMessage>>>>>,
+    db: Arc<RwLock<HashMap<ID, mpsc::UnboundedSender<Arc<SerializedMessage>>>>>,
     notify: Arc<Notify>,
 ) {
     tokio::spawn(async move {
@@ -99,7 +98,7 @@ fn spawn_message_receiver(
 /// and forwarding them locally to the respective channel.
 async fn get_messages(
     mut socket: OwnedReadHalf,
-    db: Arc<Mutex<HashMap<ID, mpsc::UnboundedSender<Arc<SerializedMessage>>>>>,
+    db: Arc<RwLock<HashMap<ID, mpsc::UnboundedSender<Arc<SerializedMessage>>>>>,
 ) -> Result<(), Box<dyn Error>> {
     // stores snd's for contacts this client has already sent messages to, to reduce contention on the main db
     // if memory ends up being more of a constraint, could be worth getting rid of this
@@ -121,7 +120,7 @@ async fn get_messages(
             .collect::<Vec<&ID>>();
 
         {
-            let locked_db = db.lock().unwrap();
+            let locked_db = db.read().unwrap();
             for m in missing {
                 if let Some(snd) = locked_db.get(m) {
                     localdb.insert(m.to_string(), snd.clone());