|
@@ -1,18 +1,20 @@
|
|
|
-use mgen::{log, parse_identifier, updater::Updater, MessageHeaderRef, SerializedMessage};
|
|
|
+use mgen::{log, updater::Updater, Handshake, MessageBody, MessageHeaderRef, SerializedMessage};
|
|
|
use std::collections::HashMap;
|
|
|
use std::error::Error;
|
|
|
use std::result::Result;
|
|
|
-use std::sync::{Arc, RwLock};
|
|
|
+use std::sync::Arc;
|
|
|
use tokio::io::AsyncWriteExt;
|
|
|
use tokio::net::{
|
|
|
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
|
|
TcpListener,
|
|
|
};
|
|
|
-use tokio::sync::{mpsc, Notify};
|
|
|
+use tokio::sync::{mpsc, Notify, RwLock};
|
|
|
|
|
|
// FIXME: identifiers should be interned
|
|
|
type ID = String;
|
|
|
|
|
|
+type ReaderToSender = mpsc::UnboundedSender<Arc<SerializedMessage>>;
|
|
|
+
|
|
|
#[tokio::main]
|
|
|
async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
let args: Vec<String> = std::env::args().collect();
|
|
@@ -25,12 +27,10 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
|
|
|
log!("Listening on {}", listen_addr);
|
|
|
|
|
|
- let snd_db = Arc::new(RwLock::new(HashMap::<
|
|
|
- ID,
|
|
|
- mpsc::UnboundedSender<Arc<SerializedMessage>>,
|
|
|
- >::new()));
|
|
|
-
|
|
|
- let mut writer_db = HashMap::<ID, Updater<(OwnedWriteHalf, Arc<Notify>)>>::new();
|
|
|
+ // Maps group name to the table of message channels.
|
|
|
+ let mut snd_db = HashMap::<ID, Arc<RwLock<HashMap<ID, ReaderToSender>>>>::new();
|
|
|
+ // Maps the (sender, group) pair to the socket updater.
|
|
|
+ let mut writer_db = HashMap::<Handshake, Updater<(OwnedWriteHalf, Arc<Notify>)>>::new();
|
|
|
|
|
|
loop {
|
|
|
let socket = match listener.accept().await {
|
|
@@ -42,21 +42,23 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
};
|
|
|
let (mut rd, wr) = socket.into_split();
|
|
|
|
|
|
- let id = match parse_identifier(&mut rd).await {
|
|
|
- Ok(id) => id,
|
|
|
- Err(mgen::Error::Utf8Error(e)) => {
|
|
|
- let err: Box<dyn Error> = Box::new(e);
|
|
|
- return Err(err);
|
|
|
- }
|
|
|
- Err(_) => continue,
|
|
|
- };
|
|
|
- log!("Accepting \"{id}\"");
|
|
|
- if let Some(socket_updater) = writer_db.get(&id) {
|
|
|
+ let handshake = mgen::get_handshake(&mut rd).await?;
|
|
|
+ log!(
|
|
|
+ "Accepting channel from {} to {}",
|
|
|
+ handshake.sender,
|
|
|
+ handshake.group
|
|
|
+ );
|
|
|
+
|
|
|
+ if let Some(socket_updater) = writer_db.get(&handshake) {
|
|
|
// we've seen this client before
|
|
|
|
|
|
// start the new reader thread with a new notify
|
|
|
+ // (we can't use the existing notify channel, else we get race conditions where
|
|
|
+ // the reader thread terminates and spawns again before the sender thread
|
|
|
+ // notices and activates its existing notify channel)
|
|
|
let notify = Arc::new(Notify::new());
|
|
|
- spawn_message_receiver(id, rd, snd_db.clone(), notify.clone());
|
|
|
+ let db = snd_db[&handshake.group].clone();
|
|
|
+ spawn_message_receiver(handshake.sender, handshake.group, rd, db, notify.clone());
|
|
|
|
|
|
// give the writer thread the new write half of the socket and notify
|
|
|
socket_updater.send((wr, notify));
|
|
@@ -64,29 +66,38 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
// newly-registered client
|
|
|
log!("New client");
|
|
|
|
|
|
- // message channel, used to send messages between threads
|
|
|
+ // message channel, for sending messages between threads
|
|
|
let (msg_snd, msg_rcv) = mpsc::unbounded_channel::<Arc<SerializedMessage>>();
|
|
|
|
|
|
- let snd_db = snd_db.clone();
|
|
|
- {
|
|
|
- let mut locked_db = snd_db.write().unwrap();
|
|
|
- locked_db.insert(id.clone(), msg_snd);
|
|
|
- }
|
|
|
-
|
|
|
- // socket notify, used to terminate the socket if the sender encounters an error
|
|
|
+ let group_snds = if let Some(table) = snd_db.get(&handshake.group) {
|
|
|
+ table
|
|
|
+ } else {
|
|
|
+ let table = Arc::new(RwLock::new(HashMap::new()));
|
|
|
+ snd_db.insert(handshake.group.clone(), table);
|
|
|
+ &snd_db[&handshake.group]
|
|
|
+ };
|
|
|
+ group_snds
|
|
|
+ .write()
|
|
|
+ .await
|
|
|
+ .insert(handshake.sender.clone(), msg_snd);
|
|
|
+
|
|
|
+ // socket notify, for terminating the socket if the sender encounters an error
|
|
|
let notify = Arc::new(Notify::new());
|
|
|
|
|
|
- // socket updater, used to give the sender thread a new socket + notify channel
|
|
|
- // (we can't use the existing notify channel, else we get race conditions where
|
|
|
- // the reader thread terminates and spawns again before the sender thread
|
|
|
- // notices and activates its existing notify channel)
|
|
|
+ // socket updater, for giving the sender thread a new socket + notify channel
|
|
|
let socket_updater_snd = Updater::new();
|
|
|
let socket_updater_rcv = socket_updater_snd.clone();
|
|
|
socket_updater_snd.send((wr, notify.clone()));
|
|
|
|
|
|
- writer_db.insert(id.clone(), socket_updater_snd);
|
|
|
+ spawn_message_receiver(
|
|
|
+ handshake.sender.clone(),
|
|
|
+ handshake.group.clone(),
|
|
|
+ rd,
|
|
|
+ group_snds.clone(),
|
|
|
+ notify,
|
|
|
+ );
|
|
|
+ writer_db.insert(handshake, socket_updater_snd);
|
|
|
|
|
|
- spawn_message_receiver(id, rd, snd_db, notify);
|
|
|
tokio::spawn(async move {
|
|
|
send_messages(msg_rcv, socket_updater_rcv).await;
|
|
|
});
|
|
@@ -96,13 +107,17 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
|
|
|
fn spawn_message_receiver(
|
|
|
sender: String,
|
|
|
+ group: String,
|
|
|
rd: OwnedReadHalf,
|
|
|
- db: Arc<RwLock<HashMap<ID, mpsc::UnboundedSender<Arc<SerializedMessage>>>>>,
|
|
|
+ db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
|
|
|
notify: Arc<Notify>,
|
|
|
) {
|
|
|
tokio::spawn(async move {
|
|
|
tokio::select! {
|
|
|
- ret = get_messages(&sender, rd, db) => {
|
|
|
+ // n.b.: get_message is not cancellation safe,
|
|
|
+ // but this is one of the cases where that's expected
|
|
|
+ // (we only cancel when something is wrong with the stream anyway)
|
|
|
+ ret = get_messages(sender, group, rd, db) => {
|
|
|
if let Err(e) = ret {
|
|
|
log!("message receiver failed: {:?}", e);
|
|
|
}
|
|
@@ -118,50 +133,44 @@ fn spawn_message_receiver(
|
|
|
/// Loop for receiving messages on the socket, figuring out who to deliver them to,
|
|
|
/// and forwarding them locally to the respective channel.
|
|
|
async fn get_messages(
|
|
|
- sender: &str,
|
|
|
+ sender: String,
|
|
|
+ group: String,
|
|
|
mut socket: OwnedReadHalf,
|
|
|
- db: Arc<RwLock<HashMap<ID, mpsc::UnboundedSender<Arc<SerializedMessage>>>>>,
|
|
|
+ global_db: Arc<RwLock<HashMap<ID, ReaderToSender>>>,
|
|
|
) -> Result<(), Box<dyn Error>> {
|
|
|
- // stores snd's for contacts this client has already sent messages to, to reduce contention on the main db
|
|
|
- // if memory ends up being more of a constraint, could be worth getting rid of this
|
|
|
- let mut localdb = HashMap::<ID, mpsc::UnboundedSender<Arc<SerializedMessage>>>::new();
|
|
|
+ // Wait for the next message to be received before populating our local copy of the db,
|
|
|
+ // that way other clients have time to register.
|
|
|
+ let mut discard = vec![0u8];
|
|
|
+ while socket.peek(&mut discard).await? == 0 {}
|
|
|
+
|
|
|
+ let db = global_db.read().await.clone();
|
|
|
+ let message_channels: Vec<_> = db
|
|
|
+ .iter()
|
|
|
+ .filter_map(|(k, v)| if *k != sender { Some(v) } else { None })
|
|
|
+ .collect();
|
|
|
+
|
|
|
loop {
|
|
|
- log!("waiting for message from {}", sender);
|
|
|
+ log!("waiting for message from {} to {}", sender, group);
|
|
|
let buf = mgen::get_message_bytes(&mut socket).await?;
|
|
|
let message = MessageHeaderRef::deserialize(&buf[4..])?;
|
|
|
- log!(
|
|
|
- "got message from {} for {:?}",
|
|
|
- message.sender,
|
|
|
- message.recipients
|
|
|
- );
|
|
|
-
|
|
|
- let missing = message
|
|
|
- .recipients
|
|
|
- .iter()
|
|
|
- .filter(|&&r| !localdb.contains_key(r))
|
|
|
- .collect::<Vec<&&str>>();
|
|
|
-
|
|
|
- {
|
|
|
- let locked_db = db.read().unwrap();
|
|
|
- for m in missing {
|
|
|
- if let Some(snd) = locked_db.get(*m) {
|
|
|
- localdb.insert(m.to_string(), snd.clone());
|
|
|
+ assert!(message.sender == sender);
|
|
|
+
|
|
|
+ match message.body {
|
|
|
+ MessageBody::Size(_) => {
|
|
|
+ assert!(message.group == group);
|
|
|
+ log!("got message from {} for {}", sender, group);
|
|
|
+ let body = message.body;
|
|
|
+ let m = Arc::new(SerializedMessage { header: buf, body });
|
|
|
+ for recipient in message_channels.iter() {
|
|
|
+ recipient.send(m.clone()).unwrap();
|
|
|
}
|
|
|
}
|
|
|
- }
|
|
|
-
|
|
|
- let body = message.body;
|
|
|
-
|
|
|
- let m = Arc::new(SerializedMessage { header: buf, body });
|
|
|
-
|
|
|
- // FIXME: we could avoid this second deserialization if we stored recipients by group instead of in the message itself
|
|
|
- let message = MessageHeaderRef::deserialize(&m.header[4..])?;
|
|
|
-
|
|
|
- for recipient in message.recipients.iter() {
|
|
|
- let recipient_sender = &localdb[*recipient];
|
|
|
- recipient_sender
|
|
|
- .send(m.clone())
|
|
|
- .expect("Recipient closed channel with messages still being sent");
|
|
|
+ MessageBody::Receipt => {
|
|
|
+ let recipient = &db[message.group];
|
|
|
+ let body = message.body;
|
|
|
+ let m = Arc::new(SerializedMessage { header: buf, body });
|
|
|
+ recipient.send(m.clone()).unwrap();
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|