use troll_patrol::{request_handler::*, *}; use clap::Parser; use futures::future; use futures::join; use hyper::{ server::conn::AddrStream, service::{make_service_fn, service_fn}, Body, Request, Response, Server, }; use serde::Deserialize; use sled::Db; use std::{ collections::{BTreeMap, HashMap, HashSet}, convert::Infallible, fs::File, io::BufReader, net::SocketAddr, path::PathBuf, time::Duration, }; use tokio::{ signal, spawn, sync::{broadcast, mpsc, oneshot}, time::sleep, }; #[cfg(not(feature = "simulation"))] use tokio_cron::{Job, Scheduler}; #[cfg(feature = "simulation")] use memory_stats::memory_stats; async fn shutdown_signal() { tokio::signal::ctrl_c() .await .expect("failed to listen for ctrl+c signal"); println!("Shut down Troll Patrol Server"); } #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { /// Name/path of the configuration file #[arg(short, long, default_value = "config.json")] config: PathBuf, } #[derive(Debug, Deserialize)] pub struct Config { pub db: DbConfig, // map of distributor name to IP:port to contact it pub distributors: BTreeMap, extra_infos_base_url: String, // confidence required to consider a bridge blocked confidence: f64, // block open-entry bridges if they get more negative reports than this max_threshold: u32, // block open-entry bridges if they get more negative reports than // scaling_factor * bridge_ips scaling_factor: f64, // minimum number of historical days for statistical analysis min_historical_days: u32, // maximum number of historical days to consider in historical analysis max_historical_days: u32, //require_bridge_token: bool, port: u16, updater_port: u16, updater_schedule: String, } #[derive(Debug, Deserialize)] pub struct DbConfig { // The path for the server database, default is "server_db" pub db_path: String, } impl Default for DbConfig { fn default() -> DbConfig { DbConfig { db_path: "server_db".to_owned(), } } } async fn update_daily_info( db: &Db, distributors: &BTreeMap, extra_infos_base_url: &str, confidence: f64, max_threshold: u32, scaling_factor: f64, min_historical_days: u32, max_historical_days: u32, ) -> HashMap<[u8; 20], HashSet> { update_extra_infos(db, extra_infos_base_url).await.unwrap(); update_negative_reports(db, distributors).await; update_positive_reports(db, distributors).await; let new_blockages = guess_blockages( db, &analysis::NormalAnalyzer::new(max_threshold, scaling_factor), confidence, min_historical_days, max_historical_days, ); report_blockages(distributors, new_blockages.clone()).await; // Generate tomorrow's key if we don't already have it new_negative_report_key(db, get_date() + 1); // Return new detected blockages new_blockages } /* async fn run_updater(updater_tx: mpsc::Sender) { updater_tx.send(Command::Update { }).await.unwrap(); } */ async fn create_context_manager( db_config: DbConfig, distributors: BTreeMap, extra_infos_base_url: &str, confidence: f64, max_threshold: u32, scaling_factor: f64, min_historical_days: u32, max_historical_days: u32, context_rx: mpsc::Receiver, mut kill: broadcast::Receiver<()>, ) { tokio::select! { create_context = context_manager(db_config, distributors, extra_infos_base_url, confidence, max_threshold, scaling_factor, min_historical_days, max_historical_days, context_rx) => create_context, _ = kill.recv() => {println!("Shut down manager");}, } } async fn context_manager( db_config: DbConfig, distributors: BTreeMap, extra_infos_base_url: &str, confidence: f64, max_threshold: u32, scaling_factor: f64, min_historical_days: u32, max_historical_days: u32, mut context_rx: mpsc::Receiver, ) { #[cfg(feature = "simulation")] let (mut max_physical_mem, mut max_virtual_mem) = (0, 0); let db: Db = sled::open(&db_config.db_path).unwrap(); // Create negative report key for today if we don't have one new_negative_report_key(&db, get_date()); while let Some(cmd) = context_rx.recv().await { #[cfg(feature = "simulation")] if let Some(usage) = memory_stats() { if usage.physical_mem > max_physical_mem { max_physical_mem = usage.physical_mem; } if usage.virtual_mem > max_virtual_mem { max_virtual_mem = usage.virtual_mem; } } else { println!("Failed to get the current memory usage"); } use Command::*; match cmd { Request { req, sender } => { let response = handle(&db, req).await; if let Err(e) = sender.send(response) { eprintln!("Server Response Error: {:?}", e); }; sleep(Duration::from_millis(1)).await; } Shutdown { shutdown_sig } => { println!("Sending Shutdown Signal, all threads should shutdown."); drop(shutdown_sig); println!("Shutdown Sent."); #[cfg(feature = "simulation")] println!( "\nMaximum physical memory usage: {}\nMaximum virtual memory usage: {}\n", max_physical_mem, max_virtual_mem ); } Update { _req, sender } => { let blockages = update_daily_info( &db, &distributors, extra_infos_base_url, confidence, max_threshold, scaling_factor, min_historical_days, max_historical_days, ) .await; let response = if cfg!(feature = "simulation") { // Convert map keys from [u8; 20] to 40-character hex strings let mut blockages_str = HashMap::>::new(); for (fingerprint, countries) in blockages { let fpr_string = array_bytes::bytes2hex("", fingerprint); blockages_str.insert(fpr_string, countries); } Ok(prepare_header( serde_json::to_string(&blockages_str).unwrap(), )) } else { Ok(prepare_header("OK".to_string())) }; if let Err(e) = sender.send(response) { eprintln!("Update Response Error: {:?}", e); }; sleep(Duration::from_millis(1)).await; } } } } // Each of the commands that can be handled #[derive(Debug)] enum Command { Request { req: Request, sender: oneshot::Sender, Infallible>>, }, Shutdown { shutdown_sig: broadcast::Sender<()>, }, Update { _req: Request, sender: oneshot::Sender, Infallible>>, }, } #[tokio::main] async fn main() { let args: Args = Args::parse(); let config: Config = serde_json::from_reader(BufReader::new( File::open(&args.config).expect("Could not read config file"), )) .expect("Reading config file from JSON failed"); let (request_tx, request_rx) = mpsc::channel(32); let updater_tx = request_tx.clone(); let shutdown_cmd_tx = request_tx.clone(); // create the shutdown broadcast channel let (shutdown_tx, mut shutdown_rx) = broadcast::channel(16); let kill = shutdown_tx.subscribe(); // Listen for ctrl_c, send signal to broadcast shutdown to all threads by dropping shutdown_tx let shutdown_handler = spawn(async move { tokio::select! { _ = signal::ctrl_c() => { let cmd = Command::Shutdown { shutdown_sig: shutdown_tx, }; shutdown_cmd_tx.send(cmd).await.unwrap(); sleep(Duration::from_secs(1)).await; _ = shutdown_rx.recv().await; } } }); // TODO: Reintroduce this /* #[cfg(not(feature = "simulation"))] let updater = spawn(async move { // Run updater once per day let mut sched = Scheduler::utc(); sched.add(Job::new(config.updater_schedule, move || { run_updater(updater_tx.clone()) })); }); */ let context_manager = spawn(async move { create_context_manager( config.db, config.distributors, &config.extra_infos_base_url, config.confidence, config.max_threshold, config.scaling_factor, config.min_historical_days, config.max_historical_days, request_rx, kill, ) .await }); let make_service = make_service_fn(move |_conn: &AddrStream| { let request_tx = request_tx.clone(); let service = service_fn(move |req| { let request_tx = request_tx.clone(); let (response_tx, response_rx) = oneshot::channel(); let cmd = Command::Request { req, sender: response_tx, }; async move { request_tx.send(cmd).await.unwrap(); response_rx.await.unwrap() } }); async move { Ok::<_, Infallible>(service) } }); let updater_make_service = make_service_fn(move |_conn: &AddrStream| { let request_tx = updater_tx.clone(); let service = service_fn(move |_req| { let request_tx = request_tx.clone(); let (response_tx, response_rx) = oneshot::channel(); let cmd = Command::Update { _req, sender: response_tx, }; async move { request_tx.send(cmd).await.unwrap(); response_rx.await.unwrap() } }); async move { Ok::<_, Infallible>(service) } }); let addr = SocketAddr::from(([0, 0, 0, 0], config.port)); let server = Server::bind(&addr).serve(make_service); let graceful = server.with_graceful_shutdown(shutdown_signal()); let updater_addr = SocketAddr::from(([127, 0, 0, 1], config.updater_port)); let updater_server = Server::bind(&updater_addr).serve(updater_make_service); let updater_graceful = updater_server.with_graceful_shutdown(shutdown_signal()); println!("Listening on {}", addr); println!("Updater listening on {}", updater_addr); let (a, b) = join!(graceful, updater_graceful); if a.is_err() { eprintln!("server error: {}", a.unwrap_err()); } if b.is_err() { eprintln!("server error: {}", b.unwrap_err()); } future::join_all([context_manager, shutdown_handler]).await; }