123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358 |
- use troll_patrol::{request_handler::*, *};
- use clap::Parser;
- use futures::future;
- use futures::join;
- use hyper::{
- server::conn::AddrStream,
- service::{make_service_fn, service_fn},
- Body, Request, Response, Server,
- };
- use serde::Deserialize;
- use sled::Db;
- use std::{
- collections::{BTreeMap, HashMap, HashSet},
- convert::Infallible,
- fs::File,
- io::BufReader,
- net::SocketAddr,
- path::PathBuf,
- time::Duration,
- };
- use tokio::{
- signal, spawn,
- sync::{broadcast, mpsc, oneshot},
- time::sleep,
- };
- #[cfg(not(feature = "simulation"))]
- use tokio_cron::{Job, Scheduler};
- #[cfg(feature = "simulation")]
- use memory_stats::memory_stats;
- async fn shutdown_signal() {
- tokio::signal::ctrl_c()
- .await
- .expect("failed to listen for ctrl+c signal");
- println!("Shut down Troll Patrol Server");
- }
- #[derive(Parser, Debug)]
- #[command(author, version, about, long_about = None)]
- struct Args {
- /// Name/path of the configuration file
- #[arg(short, long, default_value = "config.json")]
- config: PathBuf,
- }
- #[derive(Debug, Deserialize)]
- pub struct Config {
- pub db: DbConfig,
- // map of distributor name to IP:port to contact it
- pub distributors: BTreeMap<BridgeDistributor, String>,
- extra_infos_base_url: String,
- // confidence required to consider a bridge blocked
- confidence: f64,
- // block open-entry bridges if they get more negative reports than this
- max_threshold: u32,
- // block open-entry bridges if they get more negative reports than
- // scaling_factor * bridge_ips
- scaling_factor: f64,
- // minimum number of historical days for statistical analysis
- min_historical_days: u32,
- // maximum number of historical days to consider in historical analysis
- max_historical_days: u32,
- //require_bridge_token: bool,
- port: u16,
- updater_port: u16,
- updater_schedule: String,
- }
- #[derive(Debug, Deserialize)]
- pub struct DbConfig {
- // The path for the server database, default is "server_db"
- pub db_path: String,
- }
- impl Default for DbConfig {
- fn default() -> DbConfig {
- DbConfig {
- db_path: "server_db".to_owned(),
- }
- }
- }
- async fn update_daily_info(
- db: &Db,
- distributors: &BTreeMap<BridgeDistributor, String>,
- extra_infos_base_url: &str,
- confidence: f64,
- max_threshold: u32,
- scaling_factor: f64,
- min_historical_days: u32,
- max_historical_days: u32,
- ) -> HashMap<[u8; 20], HashSet<String>> {
- update_extra_infos(db, extra_infos_base_url).await.unwrap();
- update_negative_reports(db, distributors).await;
- update_positive_reports(db, distributors).await;
- let new_blockages = guess_blockages(
- db,
- // Using max_threshold for convenience
- &lox_analysis::LoxAnalyzer::new(max_threshold),
- confidence,
- min_historical_days,
- max_historical_days,
- );
- report_blockages(distributors, new_blockages.clone()).await;
- // Generate tomorrow's key if we don't already have it
- new_negative_report_key(db, get_date() + 1);
- // Return new detected blockages
- new_blockages
- }
- /*
- async fn run_updater(updater_tx: mpsc::Sender<Command>) {
- updater_tx.send(Command::Update {
- }).await.unwrap();
- }
- */
- async fn create_context_manager(
- db_config: DbConfig,
- distributors: BTreeMap<BridgeDistributor, String>,
- extra_infos_base_url: &str,
- confidence: f64,
- max_threshold: u32,
- scaling_factor: f64,
- min_historical_days: u32,
- max_historical_days: u32,
- context_rx: mpsc::Receiver<Command>,
- mut kill: broadcast::Receiver<()>,
- ) {
- tokio::select! {
- create_context = context_manager(db_config, distributors, extra_infos_base_url, confidence, max_threshold, scaling_factor, min_historical_days, max_historical_days, context_rx) => create_context,
- _ = kill.recv() => {println!("Shut down manager");},
- }
- }
- async fn context_manager(
- db_config: DbConfig,
- distributors: BTreeMap<BridgeDistributor, String>,
- extra_infos_base_url: &str,
- confidence: f64,
- max_threshold: u32,
- scaling_factor: f64,
- min_historical_days: u32,
- max_historical_days: u32,
- mut context_rx: mpsc::Receiver<Command>,
- ) {
- #[cfg(feature = "simulation")]
- let (mut max_physical_mem, mut max_virtual_mem) = (0, 0);
- let db: Db = sled::open(&db_config.db_path).unwrap();
- // Create negative report key for today if we don't have one
- new_negative_report_key(&db, get_date());
- while let Some(cmd) = context_rx.recv().await {
- #[cfg(feature = "simulation")]
- if let Some(usage) = memory_stats() {
- if usage.physical_mem > max_physical_mem {
- max_physical_mem = usage.physical_mem;
- }
- if usage.virtual_mem > max_virtual_mem {
- max_virtual_mem = usage.virtual_mem;
- }
- } else {
- println!("Failed to get the current memory usage");
- }
- use Command::*;
- match cmd {
- Request { req, sender } => {
- let response = handle(&db, req).await;
- if let Err(e) = sender.send(response) {
- eprintln!("Server Response Error: {:?}", e);
- };
- sleep(Duration::from_millis(1)).await;
- }
- Shutdown { shutdown_sig } => {
- println!("Sending Shutdown Signal, all threads should shutdown.");
- drop(shutdown_sig);
- println!("Shutdown Sent.");
- #[cfg(feature = "simulation")]
- println!(
- "\nMaximum physical memory usage: {}\nMaximum virtual memory usage: {}\n",
- max_physical_mem, max_virtual_mem
- );
- }
- Update { _req, sender } => {
- let blockages = update_daily_info(
- &db,
- &distributors,
- extra_infos_base_url,
- confidence,
- max_threshold,
- scaling_factor,
- min_historical_days,
- max_historical_days,
- )
- .await;
- let response = if cfg!(feature = "simulation") {
- // Convert map keys from [u8; 20] to 40-character hex strings
- let mut blockages_str = HashMap::<String, HashSet<String>>::new();
- for (fingerprint, countries) in blockages {
- let fpr_string = array_bytes::bytes2hex("", fingerprint);
- blockages_str.insert(fpr_string, countries);
- }
- Ok(prepare_header(
- serde_json::to_string(&blockages_str).unwrap(),
- ))
- } else {
- Ok(prepare_header("OK".to_string()))
- };
- if let Err(e) = sender.send(response) {
- eprintln!("Update Response Error: {:?}", e);
- };
- sleep(Duration::from_millis(1)).await;
- }
- }
- }
- }
- // Each of the commands that can be handled
- #[derive(Debug)]
- enum Command {
- Request {
- req: Request<Body>,
- sender: oneshot::Sender<Result<Response<Body>, Infallible>>,
- },
- Shutdown {
- shutdown_sig: broadcast::Sender<()>,
- },
- Update {
- _req: Request<Body>,
- sender: oneshot::Sender<Result<Response<Body>, Infallible>>,
- },
- }
- #[tokio::main]
- async fn main() {
- let args: Args = Args::parse();
- let config: Config = serde_json::from_reader(BufReader::new(
- File::open(&args.config).expect("Could not read config file"),
- ))
- .expect("Reading config file from JSON failed");
- let (request_tx, request_rx) = mpsc::channel(32);
- let updater_tx = request_tx.clone();
- let shutdown_cmd_tx = request_tx.clone();
- // create the shutdown broadcast channel
- let (shutdown_tx, mut shutdown_rx) = broadcast::channel(16);
- let kill = shutdown_tx.subscribe();
- // Listen for ctrl_c, send signal to broadcast shutdown to all threads by dropping shutdown_tx
- let shutdown_handler = spawn(async move {
- tokio::select! {
- _ = signal::ctrl_c() => {
- let cmd = Command::Shutdown {
- shutdown_sig: shutdown_tx,
- };
- shutdown_cmd_tx.send(cmd).await.unwrap();
- sleep(Duration::from_secs(1)).await;
- _ = shutdown_rx.recv().await;
- }
- }
- });
- // TODO: Reintroduce this
- /*
- #[cfg(not(feature = "simulation"))]
- let updater = spawn(async move {
- // Run updater once per day
- let mut sched = Scheduler::utc();
- sched.add(Job::new(config.updater_schedule, move || {
- run_updater(updater_tx.clone())
- }));
- });
- */
- let context_manager = spawn(async move {
- create_context_manager(
- config.db,
- config.distributors,
- &config.extra_infos_base_url,
- config.confidence,
- config.max_threshold,
- config.scaling_factor,
- config.min_historical_days,
- config.max_historical_days,
- request_rx,
- kill,
- )
- .await
- });
- let make_service = make_service_fn(move |_conn: &AddrStream| {
- let request_tx = request_tx.clone();
- let service = service_fn(move |req| {
- let request_tx = request_tx.clone();
- let (response_tx, response_rx) = oneshot::channel();
- let cmd = Command::Request {
- req,
- sender: response_tx,
- };
- async move {
- request_tx.send(cmd).await.unwrap();
- response_rx.await.unwrap()
- }
- });
- async move { Ok::<_, Infallible>(service) }
- });
- let updater_make_service = make_service_fn(move |_conn: &AddrStream| {
- let request_tx = updater_tx.clone();
- let service = service_fn(move |_req| {
- let request_tx = request_tx.clone();
- let (response_tx, response_rx) = oneshot::channel();
- let cmd = Command::Update {
- _req,
- sender: response_tx,
- };
- async move {
- request_tx.send(cmd).await.unwrap();
- response_rx.await.unwrap()
- }
- });
- async move { Ok::<_, Infallible>(service) }
- });
- let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
- let server = Server::bind(&addr).serve(make_service);
- let graceful = server.with_graceful_shutdown(shutdown_signal());
- let updater_addr = SocketAddr::from(([127, 0, 0, 1], config.updater_port));
- let updater_server = Server::bind(&updater_addr).serve(updater_make_service);
- let updater_graceful = updater_server.with_graceful_shutdown(shutdown_signal());
- println!("Listening on {}", addr);
- println!("Updater listening on {}", updater_addr);
- let (a, b) = join!(graceful, updater_graceful);
- if a.is_err() {
- eprintln!("server error: {}", a.unwrap_err());
- }
- if b.is_err() {
- eprintln!("server error: {}", b.unwrap_err());
- }
- future::join_all([context_manager, shutdown_handler]).await;
- }
|