main.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. use troll_patrol::{request_handler::*, *};
  2. use clap::Parser;
  3. use futures::future;
  4. use futures::join;
  5. use hyper::{
  6. server::conn::AddrStream,
  7. service::{make_service_fn, service_fn},
  8. Body, Request, Response, Server,
  9. };
  10. use serde::Deserialize;
  11. use sled::Db;
  12. use std::{
  13. collections::{BTreeMap, HashMap, HashSet},
  14. convert::Infallible,
  15. fs::File,
  16. io::BufReader,
  17. net::SocketAddr,
  18. path::PathBuf,
  19. time::Duration,
  20. };
  21. use tokio::{
  22. signal, spawn,
  23. sync::{broadcast, mpsc, oneshot},
  24. time::sleep,
  25. };
  26. #[cfg(not(feature = "simulation"))]
  27. use tokio_cron::{Job, Scheduler};
  28. #[cfg(feature = "simulation")]
  29. use memory_stats::memory_stats;
  30. async fn shutdown_signal() {
  31. tokio::signal::ctrl_c()
  32. .await
  33. .expect("failed to listen for ctrl+c signal");
  34. println!("Shut down Troll Patrol Server");
  35. }
/// Command-line arguments, parsed with clap.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Name/path of the configuration file
    #[arg(short, long, default_value = "config.json")]
    config: PathBuf,
}
/// Server configuration, deserialized from the JSON config file given
/// on the command line.
#[derive(Debug, Deserialize)]
pub struct Config {
    // database settings (path of the sled database)
    pub db: DbConfig,
    // map of distributor name to IP:port to contact it
    pub distributors: BTreeMap<BridgeDistributor, String>,
    // base URL passed to update_extra_infos for fetching extra-info documents
    extra_infos_base_url: String,
    // confidence required to consider a bridge blocked
    confidence: f64,
    // block open-entry bridges if they get more negative reports than this
    max_threshold: u32,
    // block open-entry bridges if they get more negative reports than
    // scaling_factor * bridge_ips
    scaling_factor: f64,
    // minimum number of historical days for statistical analysis
    min_historical_days: u32,
    // maximum number of historical days to consider in historical analysis
    max_historical_days: u32,
    //require_bridge_token: bool,
    // port for the public request server (bound on 0.0.0.0 in main)
    port: u16,
    // port for the updater server (bound on loopback 127.0.0.1 in main)
    updater_port: u16,
    // schedule string for the (currently disabled) daily updater job
    updater_schedule: String,
}
/// Database-related configuration.
#[derive(Debug, Deserialize)]
pub struct DbConfig {
    // The path for the server database, default is "server_db"
    pub db_path: String,
}
  70. impl Default for DbConfig {
  71. fn default() -> DbConfig {
  72. DbConfig {
  73. db_path: "server_db".to_owned(),
  74. }
  75. }
  76. }
  77. async fn update_daily_info(
  78. db: &Db,
  79. distributors: &BTreeMap<BridgeDistributor, String>,
  80. extra_infos_base_url: &str,
  81. confidence: f64,
  82. max_threshold: u32,
  83. scaling_factor: f64,
  84. min_historical_days: u32,
  85. max_historical_days: u32,
  86. ) -> HashMap<[u8; 20], HashSet<String>> {
  87. update_extra_infos(db, extra_infos_base_url).await.unwrap();
  88. update_negative_reports(db, distributors).await;
  89. update_positive_reports(db, distributors).await;
  90. let new_blockages = guess_blockages(
  91. db,
  92. &analysis::NormalAnalyzer::new(max_threshold, scaling_factor),
  93. confidence,
  94. min_historical_days,
  95. max_historical_days,
  96. );
  97. report_blockages(distributors, new_blockages.clone()).await;
  98. // Generate tomorrow's key if we don't already have it
  99. new_negative_report_key(db, get_date() + 1);
  100. // Return new detected blockages
  101. new_blockages
  102. }
  103. /*
  104. async fn run_updater(updater_tx: mpsc::Sender<Command>) {
  105. updater_tx.send(Command::Update {
  106. }).await.unwrap();
  107. }
  108. */
  109. async fn create_context_manager(
  110. db_config: DbConfig,
  111. distributors: BTreeMap<BridgeDistributor, String>,
  112. extra_infos_base_url: &str,
  113. confidence: f64,
  114. max_threshold: u32,
  115. scaling_factor: f64,
  116. min_historical_days: u32,
  117. max_historical_days: u32,
  118. context_rx: mpsc::Receiver<Command>,
  119. mut kill: broadcast::Receiver<()>,
  120. ) {
  121. tokio::select! {
  122. create_context = context_manager(db_config, distributors, extra_infos_base_url, confidence, max_threshold, scaling_factor, min_historical_days, max_historical_days, context_rx) => create_context,
  123. _ = kill.recv() => {println!("Shut down manager");},
  124. }
  125. }
/// Core event loop: opens and owns the sled database and serially
/// processes `Command`s received over `context_rx` until the channel
/// closes (all senders dropped).
async fn context_manager(
    db_config: DbConfig,
    distributors: BTreeMap<BridgeDistributor, String>,
    extra_infos_base_url: &str,
    confidence: f64,
    max_threshold: u32,
    scaling_factor: f64,
    min_historical_days: u32,
    max_historical_days: u32,
    mut context_rx: mpsc::Receiver<Command>,
) {
    // Peak memory usage trackers, maintained only in simulation builds.
    #[cfg(feature = "simulation")]
    let (mut max_physical_mem, mut max_virtual_mem) = (0, 0);
    let db: Db = sled::open(&db_config.db_path).unwrap();
    // Create negative report key for today if we don't have one
    new_negative_report_key(&db, get_date());
    while let Some(cmd) = context_rx.recv().await {
        // Sample memory usage once per command; keep the running maxima.
        #[cfg(feature = "simulation")]
        if let Some(usage) = memory_stats() {
            if usage.physical_mem > max_physical_mem {
                max_physical_mem = usage.physical_mem;
            }
            if usage.virtual_mem > max_virtual_mem {
                max_virtual_mem = usage.virtual_mem;
            }
        } else {
            println!("Failed to get the current memory usage");
        }
        use Command::*;
        match cmd {
            // Public HTTP request: handle it against the database and
            // return the response over the oneshot channel.
            Request { req, sender } => {
                let response = handle(&db, req).await;
                if let Err(e) = sender.send(response) {
                    eprintln!("Server Response Error: {:?}", e);
                };
                sleep(Duration::from_millis(1)).await;
            }
            // Dropping `shutdown_sig` (a broadcast Sender) closes the
            // channel, which wakes every subscriber waiting on recv().
            Shutdown { shutdown_sig } => {
                println!("Sending Shutdown Signal, all threads should shutdown.");
                drop(shutdown_sig);
                println!("Shutdown Sent.");
                #[cfg(feature = "simulation")]
                println!(
                    "\nMaximum physical memory usage: {}\nMaximum virtual memory usage: {}\n",
                    max_physical_mem, max_virtual_mem
                );
            }
            // Run the daily update. In simulation builds the response
            // body carries the new blockages as JSON; otherwise "OK".
            Update { _req, sender } => {
                let blockages = update_daily_info(
                    &db,
                    &distributors,
                    extra_infos_base_url,
                    confidence,
                    max_threshold,
                    scaling_factor,
                    min_historical_days,
                    max_historical_days,
                )
                .await;
                let response = if cfg!(feature = "simulation") {
                    // Convert map keys from [u8; 20] to 40-character hex strings
                    let mut blockages_str = HashMap::<String, HashSet<String>>::new();
                    for (fingerprint, countries) in blockages {
                        let fpr_string = array_bytes::bytes2hex("", fingerprint);
                        blockages_str.insert(fpr_string, countries);
                    }
                    Ok(prepare_header(
                        serde_json::to_string(&blockages_str).unwrap(),
                    ))
                } else {
                    Ok(prepare_header("OK".to_string()))
                };
                if let Err(e) = sender.send(response) {
                    eprintln!("Update Response Error: {:?}", e);
                };
                sleep(Duration::from_millis(1)).await;
            }
        }
    }
}
// Each of the commands that can be handled
#[derive(Debug)]
enum Command {
    /// Forward a public HTTP request to the handler; the response comes
    /// back over the oneshot `sender`.
    Request {
        req: Request<Body>,
        sender: oneshot::Sender<Result<Response<Body>, Infallible>>,
    },
    /// Broadcast shutdown: the receiver drops `shutdown_sig`, closing the
    /// broadcast channel for all subscribers.
    Shutdown {
        shutdown_sig: broadcast::Sender<()>,
    },
    /// Trigger the daily update; the request body itself is unused.
    Update {
        _req: Request<Body>,
        sender: oneshot::Sender<Result<Response<Body>, Infallible>>,
    },
}
#[tokio::main]
async fn main() {
    let args: Args = Args::parse();
    // Parse the JSON config file named on the command line.
    let config: Config = serde_json::from_reader(BufReader::new(
        File::open(&args.config).expect("Could not read config file"),
    ))
    .expect("Reading config file from JSON failed");
    // Command channel into the context manager; cloned for the updater
    // server and the ctrl-c handler below.
    let (request_tx, request_rx) = mpsc::channel(32);
    let updater_tx = request_tx.clone();
    let shutdown_cmd_tx = request_tx.clone();
    // create the shutdown broadcast channel
    let (shutdown_tx, mut shutdown_rx) = broadcast::channel(16);
    let kill = shutdown_tx.subscribe();
    // Listen for ctrl_c, send signal to broadcast shutdown to all threads by dropping shutdown_tx
    let shutdown_handler = spawn(async move {
        tokio::select! {
            _ = signal::ctrl_c() => {
                // Hand ownership of shutdown_tx to the context manager,
                // which drops it to close the broadcast channel.
                let cmd = Command::Shutdown {
                    shutdown_sig: shutdown_tx,
                };
                shutdown_cmd_tx.send(cmd).await.unwrap();
                sleep(Duration::from_secs(1)).await;
                // recv() returns once the channel is closed.
                _ = shutdown_rx.recv().await;
            }
        }
    });
    // TODO: Reintroduce this
    /*
    #[cfg(not(feature = "simulation"))]
    let updater = spawn(async move {
        // Run updater once per day
        let mut sched = Scheduler::utc();
        sched.add(Job::new(config.updater_schedule, move || {
            run_updater(updater_tx.clone())
        }));
    });
    */
    // Spawn the task that owns the database and serially processes
    // commands from both HTTP servers.
    let context_manager = spawn(async move {
        create_context_manager(
            config.db,
            config.distributors,
            &config.extra_infos_base_url,
            config.confidence,
            config.max_threshold,
            config.scaling_factor,
            config.min_historical_days,
            config.max_historical_days,
            request_rx,
            kill,
        )
        .await
    });
    // Public service: wrap each request in Command::Request and await
    // the context manager's response.
    let make_service = make_service_fn(move |_conn: &AddrStream| {
        let request_tx = request_tx.clone();
        let service = service_fn(move |req| {
            let request_tx = request_tx.clone();
            let (response_tx, response_rx) = oneshot::channel();
            let cmd = Command::Request {
                req,
                sender: response_tx,
            };
            async move {
                request_tx.send(cmd).await.unwrap();
                response_rx.await.unwrap()
            }
        });
        async move { Ok::<_, Infallible>(service) }
    });
    // Updater service: any incoming request triggers the daily update.
    let updater_make_service = make_service_fn(move |_conn: &AddrStream| {
        let request_tx = updater_tx.clone();
        let service = service_fn(move |_req| {
            let request_tx = request_tx.clone();
            let (response_tx, response_rx) = oneshot::channel();
            let cmd = Command::Update {
                _req,
                sender: response_tx,
            };
            async move {
                request_tx.send(cmd).await.unwrap();
                response_rx.await.unwrap()
            }
        });
        async move { Ok::<_, Infallible>(service) }
    });
    // Public server binds all interfaces; updater is loopback-only.
    let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
    let server = Server::bind(&addr).serve(make_service);
    let graceful = server.with_graceful_shutdown(shutdown_signal());
    let updater_addr = SocketAddr::from(([127, 0, 0, 1], config.updater_port));
    let updater_server = Server::bind(&updater_addr).serve(updater_make_service);
    let updater_graceful = updater_server.with_graceful_shutdown(shutdown_signal());
    println!("Listening on {}", addr);
    println!("Updater listening on {}", updater_addr);
    // Run both servers concurrently until ctrl-c triggers shutdown.
    let (a, b) = join!(graceful, updater_graceful);
    if a.is_err() {
        eprintln!("server error: {}", a.unwrap_err());
    }
    if b.is_err() {
        eprintln!("server error: {}", b.unwrap_err());
    }
    // Wait for the background tasks to finish before exiting.
    future::join_all([context_manager, shutdown_handler]).await;
}