main.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  1. use troll_patrol::{request_handler::*, *};
  2. use clap::Parser;
  3. use futures::future;
  4. use futures::join;
  5. use hyper::{
  6. server::conn::AddrStream,
  7. service::{make_service_fn, service_fn},
  8. Body, Request, Response, Server,
  9. };
  10. use serde::Deserialize;
  11. use sled::Db;
  12. use std::{
  13. collections::{BTreeMap, HashMap, HashSet},
  14. convert::Infallible,
  15. fs::File,
  16. io::BufReader,
  17. net::SocketAddr,
  18. path::PathBuf,
  19. time::Duration,
  20. };
  21. use tokio::{
  22. signal, spawn,
  23. sync::{broadcast, mpsc, oneshot},
  24. time::sleep,
  25. };
  26. #[cfg(not(feature = "simulation"))]
  27. use tokio_cron::{Job, Scheduler};
  28. #[cfg(feature = "simulation")]
  29. use memory_stats::memory_stats;
  30. async fn shutdown_signal() {
  31. tokio::signal::ctrl_c()
  32. .await
  33. .expect("failed to listen for ctrl+c signal");
  34. println!("Shut down Troll Patrol Server");
  35. }
/// Command-line arguments, parsed with clap's derive API.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Name/path of the configuration file
    #[arg(short, long, default_value = "config.json")]
    config: PathBuf,
}
/// Server configuration, deserialized from the JSON file named by `Args::config`.
#[derive(Debug, Deserialize)]
pub struct Config {
    /// Database settings; see [`DbConfig`].
    pub db: DbConfig,
    // map of distributor name to IP:port to contact it
    pub distributors: BTreeMap<BridgeDistributor, String>,
    // base URL used by update_extra_infos to fetch extra-info documents
    extra_infos_base_url: String,
    // confidence required to consider a bridge blocked
    confidence: f64,
    // block open-entry bridges if they get more negative reports than this
    max_threshold: u32,
    // block open-entry bridges if they get more negative reports than
    // scaling_factor * bridge_ips
    scaling_factor: f64,
    // minimum number of historical days for statistical analysis
    min_historical_days: u32,
    // maximum number of historical days to consider in historical analysis
    max_historical_days: u32,
    //require_bridge_token: bool,
    // port for the public API server (bound on 0.0.0.0 in main)
    port: u16,
    // port for the updater endpoint (bound on 127.0.0.1 in main)
    updater_port: u16,
    // cron schedule for the daily updater
    // NOTE(review): only referenced by the commented-out updater job in main —
    // currently unused at runtime; confirm before removing.
    updater_schedule: String,
}
/// Location of the sled database opened by the context manager.
#[derive(Debug, Deserialize)]
pub struct DbConfig {
    // The path for the server database, default is "server_db"
    pub db_path: String,
}
  70. impl Default for DbConfig {
  71. fn default() -> DbConfig {
  72. DbConfig {
  73. db_path: "server_db".to_owned(),
  74. }
  75. }
  76. }
/// Run the once-per-day pipeline: ingest new extra-info documents and
/// negative/positive reports, guess new blockages, report them to the
/// distributors, and pre-generate tomorrow's negative-report key.
///
/// Returns a map from bridge fingerprint (20 raw bytes) to the set of
/// country codes where that bridge was newly judged blocked.
///
/// NOTE(review): `scaling_factor` is accepted but not used anywhere in this
/// body — presumably intended for the open-entry analysis; confirm.
async fn update_daily_info(
    db: &Db,
    distributors: &BTreeMap<BridgeDistributor, String>,
    extra_infos_base_url: &str,
    confidence: f64,
    max_threshold: u32,
    scaling_factor: f64,
    min_historical_days: u32,
    max_historical_days: u32,
) -> HashMap<[u8; 20], HashSet<String>> {
    // NOTE(review): unwrap panics on download/parse failure — confirm this
    // hard-fail is intended for the daily job.
    update_extra_infos(db, extra_infos_base_url).await.unwrap();
    update_negative_reports(db, distributors).await;
    update_positive_reports(db, distributors).await;
    let new_blockages = guess_blockages(
        db,
        // Using max_threshold for convenience
        &lox_analysis::LoxAnalyzer::new(max_threshold),
        confidence,
        min_historical_days,
        max_historical_days,
    );
    // Cloned because we also return the map to the caller below.
    report_blockages(distributors, new_blockages.clone()).await;
    // Generate tomorrow's key if we don't already have it
    new_negative_report_key(db, get_date() + 1);
    // Return new detected blockages
    new_blockages
}
  104. /*
  105. async fn run_updater(updater_tx: mpsc::Sender<Command>) {
  106. updater_tx.send(Command::Update {
  107. }).await.unwrap();
  108. }
  109. */
  110. async fn create_context_manager(
  111. db_config: DbConfig,
  112. distributors: BTreeMap<BridgeDistributor, String>,
  113. extra_infos_base_url: &str,
  114. confidence: f64,
  115. max_threshold: u32,
  116. scaling_factor: f64,
  117. min_historical_days: u32,
  118. max_historical_days: u32,
  119. context_rx: mpsc::Receiver<Command>,
  120. mut kill: broadcast::Receiver<()>,
  121. ) {
  122. tokio::select! {
  123. create_context = context_manager(db_config, distributors, extra_infos_base_url, confidence, max_threshold, scaling_factor, min_historical_days, max_historical_days, context_rx) => create_context,
  124. _ = kill.recv() => {println!("Shut down manager");},
  125. }
  126. }
/// Own the sled database and serve commands from `context_rx` until the
/// channel closes.
///
/// This is the only task that touches the database; the HTTP services send
/// it `Command`s over the mpsc channel and receive responses via the
/// oneshot senders embedded in each command.
async fn context_manager(
    db_config: DbConfig,
    distributors: BTreeMap<BridgeDistributor, String>,
    extra_infos_base_url: &str,
    confidence: f64,
    max_threshold: u32,
    scaling_factor: f64,
    min_historical_days: u32,
    max_historical_days: u32,
    mut context_rx: mpsc::Receiver<Command>,
) {
    // Peak memory counters, tracked only in simulation builds.
    #[cfg(feature = "simulation")]
    let (mut max_physical_mem, mut max_virtual_mem) = (0, 0);
    // NOTE(review): unwrap panics if the sled database cannot be opened.
    let db: Db = sled::open(&db_config.db_path).unwrap();
    // Create negative report key for today if we don't have one
    new_negative_report_key(&db, get_date());
    while let Some(cmd) = context_rx.recv().await {
        // Sample memory usage once per command and keep the running maxima.
        #[cfg(feature = "simulation")]
        if let Some(usage) = memory_stats() {
            if usage.physical_mem > max_physical_mem {
                max_physical_mem = usage.physical_mem;
            }
            if usage.virtual_mem > max_virtual_mem {
                max_virtual_mem = usage.virtual_mem;
            }
        } else {
            println!("Failed to get the current memory usage");
        }
        use Command::*;
        match cmd {
            // Ordinary API request: delegate to the request handler and send
            // the HTTP response back to the waiting hyper service.
            Request { req, sender } => {
                let response = handle(&db, req).await;
                if let Err(e) = sender.send(response) {
                    eprintln!("Server Response Error: {:?}", e);
                };
                sleep(Duration::from_millis(1)).await;
            }
            // Dropping shutdown_sig closes the broadcast channel; receivers
            // treat the resulting recv() error as the signal to stop.
            Shutdown { shutdown_sig } => {
                println!("Sending Shutdown Signal, all threads should shutdown.");
                drop(shutdown_sig);
                println!("Shutdown Sent.");
                #[cfg(feature = "simulation")]
                println!(
                    "\nMaximum physical memory usage: {}\nMaximum virtual memory usage: {}\n",
                    max_physical_mem, max_virtual_mem
                );
            }
            // Run the daily pipeline. Simulation builds answer with the JSON
            // map of newly detected blockages; normal builds just say "OK".
            Update { _req, sender } => {
                let blockages = update_daily_info(
                    &db,
                    &distributors,
                    extra_infos_base_url,
                    confidence,
                    max_threshold,
                    scaling_factor,
                    min_historical_days,
                    max_historical_days,
                )
                .await;
                let response = if cfg!(feature = "simulation") {
                    // Convert map keys from [u8; 20] to 40-character hex strings
                    let mut blockages_str = HashMap::<String, HashSet<String>>::new();
                    for (fingerprint, countries) in blockages {
                        let fpr_string = array_bytes::bytes2hex("", fingerprint);
                        blockages_str.insert(fpr_string, countries);
                    }
                    Ok(prepare_header(
                        serde_json::to_string(&blockages_str).unwrap(),
                    ))
                } else {
                    Ok(prepare_header("OK".to_string()))
                };
                if let Err(e) = sender.send(response) {
                    eprintln!("Update Response Error: {:?}", e);
                };
                sleep(Duration::from_millis(1)).await;
            }
        }
    }
}
// Each of the commands that can be handled
#[derive(Debug)]
enum Command {
    /// Handle an incoming HTTP request; the response travels back over `sender`.
    Request {
        req: Request<Body>,
        // oneshot channel the hyper service is awaiting on
        sender: oneshot::Sender<Result<Response<Body>, Infallible>>,
    },
    /// Broadcast shutdown to all tasks: the manager drops the carried sender,
    /// closing the broadcast channel every task subscribed to.
    Shutdown {
        shutdown_sig: broadcast::Sender<()>,
    },
    /// Trigger the daily update pipeline; the request itself is unused
    /// (hence the underscore-prefixed field).
    Update {
        _req: Request<Body>,
        sender: oneshot::Sender<Result<Response<Body>, Infallible>>,
    },
}
  222. #[tokio::main]
  223. async fn main() {
  224. let args: Args = Args::parse();
  225. let config: Config = serde_json::from_reader(BufReader::new(
  226. File::open(&args.config).expect("Could not read config file"),
  227. ))
  228. .expect("Reading config file from JSON failed");
  229. let (request_tx, request_rx) = mpsc::channel(32);
  230. let updater_tx = request_tx.clone();
  231. let shutdown_cmd_tx = request_tx.clone();
  232. // create the shutdown broadcast channel
  233. let (shutdown_tx, mut shutdown_rx) = broadcast::channel(16);
  234. let kill = shutdown_tx.subscribe();
  235. // Listen for ctrl_c, send signal to broadcast shutdown to all threads by dropping shutdown_tx
  236. let shutdown_handler = spawn(async move {
  237. tokio::select! {
  238. _ = signal::ctrl_c() => {
  239. let cmd = Command::Shutdown {
  240. shutdown_sig: shutdown_tx,
  241. };
  242. shutdown_cmd_tx.send(cmd).await.unwrap();
  243. sleep(Duration::from_secs(1)).await;
  244. _ = shutdown_rx.recv().await;
  245. }
  246. }
  247. });
  248. // TODO: Reintroduce this
  249. /*
  250. #[cfg(not(feature = "simulation"))]
  251. let updater = spawn(async move {
  252. // Run updater once per day
  253. let mut sched = Scheduler::utc();
  254. sched.add(Job::new(config.updater_schedule, move || {
  255. run_updater(updater_tx.clone())
  256. }));
  257. });
  258. */
  259. let context_manager = spawn(async move {
  260. create_context_manager(
  261. config.db,
  262. config.distributors,
  263. &config.extra_infos_base_url,
  264. config.confidence,
  265. config.max_threshold,
  266. config.scaling_factor,
  267. config.min_historical_days,
  268. config.max_historical_days,
  269. request_rx,
  270. kill,
  271. )
  272. .await
  273. });
  274. let make_service = make_service_fn(move |_conn: &AddrStream| {
  275. let request_tx = request_tx.clone();
  276. let service = service_fn(move |req| {
  277. let request_tx = request_tx.clone();
  278. let (response_tx, response_rx) = oneshot::channel();
  279. let cmd = Command::Request {
  280. req,
  281. sender: response_tx,
  282. };
  283. async move {
  284. request_tx.send(cmd).await.unwrap();
  285. response_rx.await.unwrap()
  286. }
  287. });
  288. async move { Ok::<_, Infallible>(service) }
  289. });
  290. let updater_make_service = make_service_fn(move |_conn: &AddrStream| {
  291. let request_tx = updater_tx.clone();
  292. let service = service_fn(move |_req| {
  293. let request_tx = request_tx.clone();
  294. let (response_tx, response_rx) = oneshot::channel();
  295. let cmd = Command::Update {
  296. _req,
  297. sender: response_tx,
  298. };
  299. async move {
  300. request_tx.send(cmd).await.unwrap();
  301. response_rx.await.unwrap()
  302. }
  303. });
  304. async move { Ok::<_, Infallible>(service) }
  305. });
  306. let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
  307. let server = Server::bind(&addr).serve(make_service);
  308. let graceful = server.with_graceful_shutdown(shutdown_signal());
  309. let updater_addr = SocketAddr::from(([127, 0, 0, 1], config.updater_port));
  310. let updater_server = Server::bind(&updater_addr).serve(updater_make_service);
  311. let updater_graceful = updater_server.with_graceful_shutdown(shutdown_signal());
  312. println!("Listening on {}", addr);
  313. println!("Updater listening on {}", updater_addr);
  314. let (a, b) = join!(graceful, updater_graceful);
  315. if a.is_err() {
  316. eprintln!("server error: {}", a.unwrap_err());
  317. }
  318. if b.is_err() {
  319. eprintln!("server error: {}", b.unwrap_err());
  320. }
  321. future::join_all([context_manager, shutdown_handler]).await;
  322. }