use troll_patrol::{request_handler::handle, *}; use clap::Parser; use futures::future; use hyper::{ server::conn::AddrStream, service::{make_service_fn, service_fn}, Body, Request, Response, Server, }; use serde::Deserialize; use sled::Db; use std::{ collections::BTreeMap, convert::Infallible, fs::File, io::BufReader, net::SocketAddr, path::PathBuf, time::Duration, }; use tokio::{ signal, spawn, sync::{broadcast, mpsc, oneshot}, time::sleep, }; use tokio_cron::{Job, Scheduler}; async fn shutdown_signal() { tokio::signal::ctrl_c() .await .expect("failed to listen for ctrl+c signal"); println!("Shut down Troll Patrol Server"); } #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { /// Name/path of the configuration file #[arg(short, long, default_value = "config.json")] config: PathBuf, } #[derive(Debug, Deserialize)] pub struct Config { pub db: DbConfig, // map of distributor name to IP:port to contact it pub distributors: BTreeMap, extra_infos_base_url: String, // confidence required to consider a bridge blocked confidence: f64, // block open-entry bridges if they get more negative reports than this max_threshold: u32, // block open-entry bridges if they get more negative reports than // scaling_factor * bridge_ips scaling_factor: f64, //require_bridge_token: bool, port: u16, updater_schedule: String, } #[derive(Debug, Deserialize)] pub struct DbConfig { // The path for the server database, default is "server_db" pub db_path: String, } impl Default for DbConfig { fn default() -> DbConfig { DbConfig { db_path: "server_db".to_owned(), } } } async fn update_daily_info( db: &Db, distributors: &BTreeMap, extra_infos_base_url: &str, confidence: f64, max_threshold: u32, scaling_factor: f64, ) { update_extra_infos(&db, &extra_infos_base_url) .await .unwrap(); update_negative_reports(&db, &distributors).await; update_positive_reports(&db, &distributors).await; let new_blockages = guess_blockages( &db, &analysis::NormalAnalyzer::new(max_threshold, scaling_factor), confidence, ); report_blockages(&distributors, new_blockages).await; // Generate tomorrow's key if we don't already have it new_negative_report_key(&db, get_date() + 1); } async fn run_updater(updater_tx: mpsc::Sender) { updater_tx.send(Command::Update {}).await.unwrap(); } async fn create_context_manager( db_config: DbConfig, distributors: BTreeMap, extra_infos_base_url: &str, confidence: f64, max_threshold: u32, scaling_factor: f64, context_rx: mpsc::Receiver, mut kill: broadcast::Receiver<()>, ) { tokio::select! { create_context = context_manager(db_config, distributors, extra_infos_base_url, confidence, max_threshold, scaling_factor, context_rx) => create_context, _ = kill.recv() => {println!("Shut down manager");}, } } async fn context_manager( db_config: DbConfig, distributors: BTreeMap, extra_infos_base_url: &str, confidence: f64, max_threshold: u32, scaling_factor: f64, mut context_rx: mpsc::Receiver, ) { let db: Db = sled::open(&db_config.db_path).unwrap(); while let Some(cmd) = context_rx.recv().await { use Command::*; match cmd { Request { req, sender } => { let response = handle(&db, req).await; if let Err(e) = sender.send(response) { eprintln!("Server Response Error: {:?}", e); }; sleep(Duration::from_millis(1)).await; } Shutdown { shutdown_sig } => { println!("Sending Shutdown Signal, all threads should shutdown."); drop(shutdown_sig); println!("Shutdown Sent."); } Update {} => { update_daily_info( &db, &distributors, &extra_infos_base_url, confidence, max_threshold, scaling_factor, ) .await; } } } } // Each of the commands that can be handled #[derive(Debug)] enum Command { Request { req: Request, sender: oneshot::Sender, Infallible>>, }, Shutdown { shutdown_sig: broadcast::Sender<()>, }, Update {}, } #[tokio::main] async fn main() { let args: Args = Args::parse(); let config: Config = serde_json::from_reader(BufReader::new( File::open(&args.config).expect("Could not read config file"), )) .expect("Reading config file from JSON failed"); let (request_tx, request_rx) = mpsc::channel(32); let updater_tx = request_tx.clone(); let shutdown_cmd_tx = request_tx.clone(); // create the shutdown broadcast channel and clone for every thread let (shutdown_tx, mut shutdown_rx) = broadcast::channel(16); let kill = shutdown_tx.subscribe(); // TODO: Gracefully shut down updater let kill_updater = shutdown_tx.subscribe(); // Listen for ctrl_c, send signal to broadcast shutdown to all threads by dropping shutdown_tx let shutdown_handler = spawn(async move { tokio::select! { _ = signal::ctrl_c() => { let cmd = Command::Shutdown { shutdown_sig: shutdown_tx, }; shutdown_cmd_tx.send(cmd).await.unwrap(); sleep(Duration::from_secs(1)).await; _ = shutdown_rx.recv().await; } } }); let updater = spawn(async move { // Run updater once per day let mut sched = Scheduler::utc(); sched.add(Job::new(config.updater_schedule, move || { run_updater(updater_tx.clone()) })); }); let context_manager = spawn(async move { create_context_manager( config.db, config.distributors, &config.extra_infos_base_url, config.confidence, config.max_threshold, config.scaling_factor, request_rx, kill, ) .await }); let make_service = make_service_fn(move |_conn: &AddrStream| { let request_tx = request_tx.clone(); let service = service_fn(move |req| { let request_tx = request_tx.clone(); let (response_tx, response_rx) = oneshot::channel(); let cmd = Command::Request { req, sender: response_tx, }; async move { request_tx.send(cmd).await.unwrap(); response_rx.await.unwrap() } }); async move { Ok::<_, Infallible>(service) } }); let addr = SocketAddr::from(([0, 0, 0, 0], config.port)); let server = Server::bind(&addr).serve(make_service); let graceful = server.with_graceful_shutdown(shutdown_signal()); println!("Listening on {}", addr); if let Err(e) = graceful.await { eprintln!("server error: {}", e); } future::join_all([context_manager, updater, shutdown_handler]).await; }