123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254 |
- use troll_patrol::{request_handler::handle, *};
- use clap::Parser;
- use futures::future;
- use hyper::{
- server::conn::AddrStream,
- service::{make_service_fn, service_fn},
- Body, Request, Response, Server,
- };
- use serde::Deserialize;
- use sled::Db;
- use std::{
- collections::BTreeMap, convert::Infallible, fs::File, io::BufReader, net::SocketAddr,
- path::PathBuf, time::Duration,
- };
- use tokio::{
- signal, spawn,
- sync::{broadcast, mpsc, oneshot},
- time::sleep,
- };
- use tokio_cron::{Job, Scheduler};
- async fn shutdown_signal() {
- tokio::signal::ctrl_c()
- .await
- .expect("failed to listen for ctrl+c signal");
- println!("Shut down Troll Patrol Server");
- }
/// Command-line arguments, parsed via clap's derive API.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Name/path of the configuration file
    #[arg(short, long, default_value = "config.json")]
    config: PathBuf,
}
/// Server configuration, deserialized from the JSON file named by `Args::config`.
#[derive(Debug, Deserialize)]
pub struct Config {
    pub db: DbConfig,
    // map of distributor name to IP:port to contact it
    pub distributors: BTreeMap<BridgeDistributor, String>,
    // base URL from which bridge extra-info documents are downloaded
    extra_infos_base_url: String,
    // confidence required to consider a bridge blocked
    confidence: f64,
    // block open-entry bridges if they get more negative reports than this
    max_threshold: u32,
    // block open-entry bridges if they get more negative reports than
    // scaling_factor * bridge_ips
    scaling_factor: f64,
    //require_bridge_token: bool,
    // port the HTTP server listens on (bound on 0.0.0.0)
    port: u16,
    // cron-style schedule for the daily updater job (passed to tokio_cron)
    updater_schedule: String,
}
/// Location of the server's sled database.
#[derive(Debug, Deserialize)]
pub struct DbConfig {
    // The path for the server database, default is "server_db"
    pub db_path: String,
}
- impl Default for DbConfig {
- fn default() -> DbConfig {
- DbConfig {
- db_path: "server_db".to_owned(),
- }
- }
- }
- async fn update_daily_info(
- db: &Db,
- distributors: &BTreeMap<BridgeDistributor, String>,
- extra_infos_base_url: &str,
- confidence: f64,
- max_threshold: u32,
- scaling_factor: f64,
- ) {
- update_extra_infos(&db, &extra_infos_base_url)
- .await
- .unwrap();
- update_negative_reports(&db, &distributors).await;
- update_positive_reports(&db, &distributors).await;
- let new_blockages = guess_blockages(
- &db,
- &analysis::NormalAnalyzer::new(max_threshold, scaling_factor),
- confidence,
- );
- report_blockages(&distributors, new_blockages).await;
- // Generate tomorrow's key if we don't already have it
- new_negative_report_key(&db, get_date() + 1);
- }
- async fn run_updater(updater_tx: mpsc::Sender<Command>) {
- updater_tx.send(Command::Update {}).await.unwrap();
- }
- async fn create_context_manager(
- db_config: DbConfig,
- distributors: BTreeMap<BridgeDistributor, String>,
- extra_infos_base_url: &str,
- confidence: f64,
- max_threshold: u32,
- scaling_factor: f64,
- context_rx: mpsc::Receiver<Command>,
- mut kill: broadcast::Receiver<()>,
- ) {
- tokio::select! {
- create_context = context_manager(db_config, distributors, extra_infos_base_url, confidence, max_threshold, scaling_factor, context_rx) => create_context,
- _ = kill.recv() => {println!("Shut down manager");},
- }
- }
- async fn context_manager(
- db_config: DbConfig,
- distributors: BTreeMap<BridgeDistributor, String>,
- extra_infos_base_url: &str,
- confidence: f64,
- max_threshold: u32,
- scaling_factor: f64,
- mut context_rx: mpsc::Receiver<Command>,
- ) {
- let db: Db = sled::open(&db_config.db_path).unwrap();
- while let Some(cmd) = context_rx.recv().await {
- use Command::*;
- match cmd {
- Request { req, sender } => {
- let response = handle(&db, req).await;
- if let Err(e) = sender.send(response) {
- eprintln!("Server Response Error: {:?}", e);
- };
- sleep(Duration::from_millis(1)).await;
- }
- Shutdown { shutdown_sig } => {
- println!("Sending Shutdown Signal, all threads should shutdown.");
- drop(shutdown_sig);
- println!("Shutdown Sent.");
- }
- Update {} => {
- update_daily_info(
- &db,
- &distributors,
- &extra_infos_base_url,
- confidence,
- max_threshold,
- scaling_factor,
- )
- .await;
- }
- }
- }
- }
// Each of the commands that can be handled
#[derive(Debug)]
enum Command {
    // Forward an HTTP request to the database handler; the response
    // travels back to the hyper service through `sender`.
    Request {
        req: Request<Body>,
        sender: oneshot::Sender<Result<Response<Body>, Infallible>>,
    },
    // Carries the shutdown broadcast sender; the manager drops it,
    // which signals every subscribed receiver to shut down.
    Shutdown {
        shutdown_sig: broadcast::Sender<()>,
    },
    // Run the daily update (extra-infos, reports, blockage analysis).
    Update {},
}
#[tokio::main]
async fn main() {
    let args: Args = Args::parse();
    // Load configuration from the JSON file given on the command line.
    let config: Config = serde_json::from_reader(BufReader::new(
        File::open(&args.config).expect("Could not read config file"),
    ))
    .expect("Reading config file from JSON failed");
    // Command channel into the context manager; cloned for the updater
    // job and the shutdown handler so each can submit commands.
    let (request_tx, request_rx) = mpsc::channel(32);
    let updater_tx = request_tx.clone();
    let shutdown_cmd_tx = request_tx.clone();
    // create the shutdown broadcast channel and clone for every thread
    let (shutdown_tx, mut shutdown_rx) = broadcast::channel(16);
    let kill = shutdown_tx.subscribe();
    // TODO: Gracefully shut down updater
    let kill_updater = shutdown_tx.subscribe();
    // Listen for ctrl_c, send signal to broadcast shutdown to all threads by dropping shutdown_tx
    let shutdown_handler = spawn(async move {
        tokio::select! {
            _ = signal::ctrl_c() => {
                // Hand the broadcast sender to the manager, which drops
                // it; that drop is what fans the shutdown out.
                let cmd = Command::Shutdown {
                    shutdown_sig: shutdown_tx,
                };
                shutdown_cmd_tx.send(cmd).await.unwrap();
                sleep(Duration::from_secs(1)).await;
                // Wait until our own subscription observes the shutdown.
                _ = shutdown_rx.recv().await;
            }
        }
    });
    let updater = spawn(async move {
        // Run updater once per day
        let mut sched = Scheduler::utc();
        sched.add(Job::new(config.updater_schedule, move || {
            run_updater(updater_tx.clone())
        }));
    });
    // Single owner of the database; all request handling funnels
    // through this task via the command channel.
    let context_manager = spawn(async move {
        create_context_manager(
            config.db,
            config.distributors,
            &config.extra_infos_base_url,
            config.confidence,
            config.max_threshold,
            config.scaling_factor,
            request_rx,
            kill,
        )
        .await
    });
    // Each connection/request gets a clone of the command sender and a
    // fresh oneshot for the response.
    let make_service = make_service_fn(move |_conn: &AddrStream| {
        let request_tx = request_tx.clone();
        let service = service_fn(move |req| {
            let request_tx = request_tx.clone();
            let (response_tx, response_rx) = oneshot::channel();
            let cmd = Command::Request {
                req,
                sender: response_tx,
            };
            async move {
                request_tx.send(cmd).await.unwrap();
                response_rx.await.unwrap()
            }
        });
        async move { Ok::<_, Infallible>(service) }
    });
    let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
    let server = Server::bind(&addr).serve(make_service);
    // Stop accepting new connections once ctrl+c is seen.
    let graceful = server.with_graceful_shutdown(shutdown_signal());
    println!("Listening on {}", addr);
    if let Err(e) = graceful.await {
        eprintln!("server error: {}", e);
    }
    // Let the background tasks finish before exiting.
    future::join_all([context_manager, updater, shutdown_handler]).await;
}
|