main.rs

use troll_patrol::{request_handler::*, *};

use clap::Parser;
use futures::future;
use futures::join;
use hyper::{
    server::conn::AddrStream,
    service::{make_service_fn, service_fn},
    Body, Request, Response, Server,
};
use serde::Deserialize;
use sled::Db;
use std::{
    collections::{BTreeMap, HashMap, HashSet},
    convert::Infallible,
    fs::File,
    io::BufReader,
    net::SocketAddr,
    path::PathBuf,
    time::Duration,
};
use tokio::{
    signal, spawn,
    sync::{broadcast, mpsc, oneshot},
    time::sleep,
};
#[cfg(not(feature = "simulation"))]
use tokio_cron::{Job, Scheduler};

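// Resolves once ctrl+c is received; passed to hyper's with_graceful_shutdown
// so the servers stop accepting connections on shutdown.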
async fn shutdown_signal() {
    tokio::signal::ctrl_c()
        .await
        .expect("failed to listen for ctrl+c signal");
    println!("Shut down Troll Patrol Server");
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Name/path of the configuration file
    #[arg(short, long, default_value = "config.json")]
    config: PathBuf,
}

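// The configuration file is JSON with the same field names as Config below.
// A minimal sketch (all values here are illustrative assumptions, and the
// "Lox" key assumes a variant name of the BridgeDistributor enum from the
// troll_patrol crate):
//
// {
//     "db": { "db_path": "server_db" },
//     "distributors": { "Lox": "127.0.0.1:8002" },
//     "extra_infos_base_url": "https://collector.torproject.org/recent/bridge-descriptors/extra-infos/",
//     "confidence": 0.95,
//     "max_threshold": 10,
//     "scaling_factor": 0.25,
//     "min_historical_days": 30,
//     "max_historical_days": 90,
//     "port": 8003,
//     "updater_port": 8004,
//     "updater_schedule": "0 0 0 * * * *"
// }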
#[derive(Debug, Deserialize)]
pub struct Config {
    pub db: DbConfig,
    // map of distributor name to IP:port to contact it
    pub distributors: BTreeMap<BridgeDistributor, String>,
    extra_infos_base_url: String,
    // confidence required to consider a bridge blocked
    confidence: f64,
    // block open-entry bridges if they get more negative reports than this
    max_threshold: u32,
    // block open-entry bridges if they get more negative reports than
    // scaling_factor * bridge_ips
    scaling_factor: f64,
    // minimum number of historical days for statistical analysis
    min_historical_days: u32,
    // maximum number of historical days to consider in historical analysis
    max_historical_days: u32,
    //require_bridge_token: bool,
    port: u16,
    updater_port: u16,
    updater_schedule: String,
}

#[derive(Debug, Deserialize)]
pub struct DbConfig {
    // The path for the server database, default is "server_db"
    pub db_path: String,
}

impl Default for DbConfig {
    fn default() -> DbConfig {
        DbConfig {
            db_path: "server_db".to_owned(),
        }
    }
}

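// Performs the daily update: download new extra-info descriptors from
// extra_infos_base_url, process any new negative and positive reports, guess
// which bridges look blocked, report those blockages to the distributors, and
// return the newly detected blockages as a map of bridge fingerprint to the
// set of countries where the bridge appears blocked.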
async fn update_daily_info(
    db: &Db,
    distributors: &BTreeMap<BridgeDistributor, String>,
    extra_infos_base_url: &str,
    confidence: f64,
    max_threshold: u32,
    scaling_factor: f64,
    min_historical_days: u32,
    max_historical_days: u32,
) -> HashMap<[u8; 20], HashSet<String>> {
    update_extra_infos(db, extra_infos_base_url).await.unwrap();
    update_negative_reports(db, distributors).await;
    update_positive_reports(db, distributors).await;
    let new_blockages = guess_blockages(
        db,
        &analysis::NormalAnalyzer::new(max_threshold, scaling_factor),
        confidence,
        min_historical_days,
        max_historical_days,
    );
    report_blockages(distributors, new_blockages.clone()).await;

    // Generate tomorrow's key if we don't already have it
    new_negative_report_key(db, get_date() + 1);

    // Return new detected blockages
    new_blockages
}

/*
async fn run_updater(updater_tx: mpsc::Sender<Command>) {
    updater_tx.send(Command::Update {
    }).await.unwrap();
}
*/

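// Runs the context manager until the kill broadcast channel fires, at which
// point the manager is shut down.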
async fn create_context_manager(
    db_config: DbConfig,
    distributors: BTreeMap<BridgeDistributor, String>,
    extra_infos_base_url: &str,
    confidence: f64,
    max_threshold: u32,
    scaling_factor: f64,
    min_historical_days: u32,
    max_historical_days: u32,
    context_rx: mpsc::Receiver<Command>,
    mut kill: broadcast::Receiver<()>,
) {
    tokio::select! {
        create_context = context_manager(db_config, distributors, extra_infos_base_url, confidence, max_threshold, scaling_factor, min_historical_days, max_historical_days, context_rx) => create_context,
        _ = kill.recv() => {println!("Shut down manager");},
    }
}

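// Owns the sled database and serializes access to it: receives Commands over
// the mpsc channel and handles them one at a time.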
async fn context_manager(
    db_config: DbConfig,
    distributors: BTreeMap<BridgeDistributor, String>,
    extra_infos_base_url: &str,
    confidence: f64,
    max_threshold: u32,
    scaling_factor: f64,
    min_historical_days: u32,
    max_historical_days: u32,
    mut context_rx: mpsc::Receiver<Command>,
) {
    let db: Db = sled::open(&db_config.db_path).unwrap();

    // Create negative report key for today if we don't have one
    new_negative_report_key(&db, get_date());

    while let Some(cmd) = context_rx.recv().await {
        use Command::*;
        match cmd {
            Request { req, sender } => {
                let response = handle(&db, req).await;
                if let Err(e) = sender.send(response) {
                    eprintln!("Server Response Error: {:?}", e);
                };
                sleep(Duration::from_millis(1)).await;
            }
            Shutdown { shutdown_sig } => {
                println!("Sending Shutdown Signal, all threads should shutdown.");
                drop(shutdown_sig);
                println!("Shutdown Sent.");
            }
            Update { _req, sender } => {
                let blockages = update_daily_info(
                    &db,
                    &distributors,
                    extra_infos_base_url,
                    confidence,
                    max_threshold,
                    scaling_factor,
                    min_historical_days,
                    max_historical_days,
                )
                .await;
                let response = if cfg!(feature = "simulation") {
                    // Convert map keys from [u8; 20] to 40-character hex strings
                    let mut blockages_str = HashMap::<String, HashSet<String>>::new();
                    for (fingerprint, countries) in blockages {
                        let fpr_string = array_bytes::bytes2hex("", fingerprint);
                        blockages_str.insert(fpr_string, countries);
                    }
                    Ok(prepare_header(
                        serde_json::to_string(&blockages_str).unwrap(),
                    ))
                } else {
                    Ok(prepare_header("OK".to_string()))
                };
                if let Err(e) = sender.send(response) {
                    eprintln!("Update Response Error: {:?}", e);
                };
                sleep(Duration::from_millis(1)).await;
            }
        }
    }
}

// Each of the commands that the context manager can handle
#[derive(Debug)]
enum Command {
    // Handle an incoming HTTP request from the public API
    Request {
        req: Request<Body>,
        sender: oneshot::Sender<Result<Response<Body>, Infallible>>,
    },
    // Broadcast shutdown to all tasks by dropping the sender
    Shutdown {
        shutdown_sig: broadcast::Sender<()>,
    },
    // Run the daily update and respond once it completes
    Update {
        _req: Request<Body>,
        sender: oneshot::Sender<Result<Response<Body>, Infallible>>,
    },
}

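// Entry point: read the JSON config, spawn the context-manager task that owns
// the sled database, and run two hyper servers: the public API on
// 0.0.0.0:<port> and a localhost-only updater endpoint on
// 127.0.0.1:<updater_port>. Sending any HTTP request to the updater endpoint
// (e.g., from a daily cron job) triggers update_daily_info.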
#[tokio::main]
async fn main() {
    let args: Args = Args::parse();

    let config: Config = serde_json::from_reader(BufReader::new(
        File::open(&args.config).expect("Could not read config file"),
    ))
    .expect("Reading config file from JSON failed");

    let (request_tx, request_rx) = mpsc::channel(32);
    let updater_tx = request_tx.clone();
    let shutdown_cmd_tx = request_tx.clone();

    // create the shutdown broadcast channel
    let (shutdown_tx, mut shutdown_rx) = broadcast::channel(16);
    let kill = shutdown_tx.subscribe();

    // Listen for ctrl_c, send signal to broadcast shutdown to all threads by dropping shutdown_tx
    let shutdown_handler = spawn(async move {
        tokio::select! {
            _ = signal::ctrl_c() => {
                let cmd = Command::Shutdown {
                    shutdown_sig: shutdown_tx,
                };
                shutdown_cmd_tx.send(cmd).await.unwrap();
                sleep(Duration::from_secs(1)).await;
                _ = shutdown_rx.recv().await;
            }
        }
    });

    // TODO: Reintroduce this
    /*
    #[cfg(not(feature = "simulation"))]
    let updater = spawn(async move {
        // Run updater once per day
        let mut sched = Scheduler::utc();
        sched.add(Job::new(config.updater_schedule, move || {
            run_updater(updater_tx.clone())
        }));
    });
    */

    // Copy the listener ports out before the config is moved into the
    // context-manager task below.
    let port = config.port;
    let updater_port = config.updater_port;

    let context_manager = spawn(async move {
        create_context_manager(
            config.db,
            config.distributors,
            &config.extra_infos_base_url,
            config.confidence,
            config.max_threshold,
            config.scaling_factor,
            config.min_historical_days,
            config.max_historical_days,
            request_rx,
            kill,
        )
        .await
    });

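    // Public-facing service: wrap each incoming request in a Command::Request,
    // forward it to the context manager, and wait for its response.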
    let make_service = make_service_fn(move |_conn: &AddrStream| {
        let request_tx = request_tx.clone();
        let service = service_fn(move |req| {
            let request_tx = request_tx.clone();
            let (response_tx, response_rx) = oneshot::channel();
            let cmd = Command::Request {
                req,
                sender: response_tx,
            };
            async move {
                request_tx.send(cmd).await.unwrap();
                response_rx.await.unwrap()
            }
        });
        async move { Ok::<_, Infallible>(service) }
    });

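    // Updater service (localhost only): any request on the updater port is
    // turned into a Command::Update, which triggers update_daily_info.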
    let updater_make_service = make_service_fn(move |_conn: &AddrStream| {
        let request_tx = updater_tx.clone();
        let service = service_fn(move |_req| {
            let request_tx = request_tx.clone();
            let (response_tx, response_rx) = oneshot::channel();
            let cmd = Command::Update {
                _req,
                sender: response_tx,
            };
            async move {
                request_tx.send(cmd).await.unwrap();
                response_rx.await.unwrap()
            }
        });
        async move { Ok::<_, Infallible>(service) }
    });

    let addr = SocketAddr::from(([0, 0, 0, 0], port));
    let server = Server::bind(&addr).serve(make_service);
    let graceful = server.with_graceful_shutdown(shutdown_signal());

    let updater_addr = SocketAddr::from(([127, 0, 0, 1], updater_port));
    let updater_server = Server::bind(&updater_addr).serve(updater_make_service);
    let updater_graceful = updater_server.with_graceful_shutdown(shutdown_signal());

    println!("Listening on {}", addr);
    println!("Updater listening on {}", updater_addr);

    let (a, b) = join!(graceful, updater_graceful);
    if let Err(e) = a {
        eprintln!("server error: {}", e);
    }
    if let Err(e) = b {
        eprintln!("updater server error: {}", e);
    }

    future::join_all([context_manager, shutdown_handler]).await;
}