// server.rs
  1. use troll_patrol::{request_handler::handle, *};
  2. use clap::Parser;
  3. use futures::future;
  4. use hyper::{
  5. server::conn::AddrStream,
  6. service::{make_service_fn, service_fn},
  7. Body, Request, Response, Server,
  8. };
  9. use serde::Deserialize;
  10. use sled::Db;
  11. use std::{
  12. collections::BTreeMap, convert::Infallible, fs::File, io::BufReader, net::SocketAddr,
  13. path::PathBuf, time::Duration,
  14. };
  15. use tokio::{
  16. signal, spawn,
  17. sync::{broadcast, mpsc, oneshot},
  18. time::sleep,
  19. };
  20. use tokio_cron::{Job, Scheduler};
  21. async fn shutdown_signal() {
  22. tokio::signal::ctrl_c()
  23. .await
  24. .expect("failed to listen for ctrl+c signal");
  25. println!("Shut down Troll Patrol Server");
  26. }
// Command-line arguments. clap derives the parser; note that the `///` doc
// comment on the field below doubles as the --help text for that flag, and
// `#[command(about)]` would pull a struct-level doc comment into --help, so
// only ordinary `//` comments are added here.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Name/path of the configuration file
    #[arg(short, long, default_value = "config.json")]
    config: PathBuf,
}
  34. #[derive(Debug, Deserialize)]
  35. pub struct Config {
  36. pub db: DbConfig,
  37. // map of distributor name to IP:port to contact it
  38. pub distributors: BTreeMap<BridgeDistributor, String>,
  39. extra_infos_base_url: String,
  40. // confidence required to consider a bridge blocked
  41. confidence: f64,
  42. // block open-entry bridges if they get more negative reports than this
  43. max_threshold: u32,
  44. // block open-entry bridges if they get more negative reports than
  45. // scaling_factor * bridge_ips
  46. scaling_factor: f64,
  47. //require_bridge_token: bool,
  48. port: u16,
  49. updater_schedule: String,
  50. }
// Database-related configuration (currently just the sled database path).
#[derive(Debug, Deserialize)]
pub struct DbConfig {
    // The path for the server database, default is "server_db"
    pub db_path: String,
}
  56. impl Default for DbConfig {
  57. fn default() -> DbConfig {
  58. DbConfig {
  59. db_path: "server_db".to_owned(),
  60. }
  61. }
  62. }
  63. async fn update_daily_info(
  64. db: &Db,
  65. distributors: &BTreeMap<BridgeDistributor, String>,
  66. extra_infos_base_url: &str,
  67. confidence: f64,
  68. max_threshold: u32,
  69. scaling_factor: f64,
  70. ) {
  71. update_extra_infos(&db, &extra_infos_base_url)
  72. .await
  73. .unwrap();
  74. update_negative_reports(&db, &distributors).await;
  75. update_positive_reports(&db, &distributors).await;
  76. let new_blockages = guess_blockages(
  77. &db,
  78. &analysis::NormalAnalyzer::new(max_threshold, scaling_factor),
  79. confidence,
  80. );
  81. report_blockages(&distributors, new_blockages).await;
  82. // Generate tomorrow's key if we don't already have it
  83. new_negative_report_key(&db, get_date() + 1);
  84. }
  85. async fn run_updater(updater_tx: mpsc::Sender<Command>) {
  86. updater_tx.send(Command::Update {}).await.unwrap();
  87. }
  88. async fn create_context_manager(
  89. db_config: DbConfig,
  90. distributors: BTreeMap<BridgeDistributor, String>,
  91. extra_infos_base_url: &str,
  92. confidence: f64,
  93. max_threshold: u32,
  94. scaling_factor: f64,
  95. context_rx: mpsc::Receiver<Command>,
  96. mut kill: broadcast::Receiver<()>,
  97. ) {
  98. tokio::select! {
  99. create_context = context_manager(db_config, distributors, extra_infos_base_url, confidence, max_threshold, scaling_factor, context_rx) => create_context,
  100. _ = kill.recv() => {println!("Shut down manager");},
  101. }
  102. }
  103. async fn context_manager(
  104. db_config: DbConfig,
  105. distributors: BTreeMap<BridgeDistributor, String>,
  106. extra_infos_base_url: &str,
  107. confidence: f64,
  108. max_threshold: u32,
  109. scaling_factor: f64,
  110. mut context_rx: mpsc::Receiver<Command>,
  111. ) {
  112. let db: Db = sled::open(&db_config.db_path).unwrap();
  113. while let Some(cmd) = context_rx.recv().await {
  114. use Command::*;
  115. match cmd {
  116. Request { req, sender } => {
  117. let response = handle(&db, req).await;
  118. if let Err(e) = sender.send(response) {
  119. eprintln!("Server Response Error: {:?}", e);
  120. };
  121. sleep(Duration::from_millis(1)).await;
  122. }
  123. Shutdown { shutdown_sig } => {
  124. println!("Sending Shutdown Signal, all threads should shutdown.");
  125. drop(shutdown_sig);
  126. println!("Shutdown Sent.");
  127. }
  128. Update {} => {
  129. update_daily_info(
  130. &db,
  131. &distributors,
  132. &extra_infos_base_url,
  133. confidence,
  134. max_threshold,
  135. scaling_factor,
  136. )
  137. .await;
  138. }
  139. }
  140. }
  141. }
// Each of the commands that can be handled
#[derive(Debug)]
enum Command {
    // An incoming HTTP request; the handler's response is sent back over the
    // oneshot channel.
    Request {
        req: Request<Body>,
        sender: oneshot::Sender<Result<Response<Body>, Infallible>>,
    },
    // Carries the broadcast sender so the context manager can drop it,
    // closing the channel and signalling every subscriber to shut down.
    Shutdown {
        shutdown_sig: broadcast::Sender<()>,
    },
    // Trigger the daily update pass.
    Update {},
}
#[tokio::main]
async fn main() {
    let args: Args = Args::parse();
    // Parse the JSON configuration file named on the command line; both
    // failures (missing file, bad JSON) are fatal at startup.
    let config: Config = serde_json::from_reader(BufReader::new(
        File::open(&args.config).expect("Could not read config file"),
    ))
    .expect("Reading config file from JSON failed");
    // Command channel: HTTP handlers, the cron updater, and the shutdown
    // handler all send Commands to the single context-manager task.
    let (request_tx, request_rx) = mpsc::channel(32);
    let updater_tx = request_tx.clone();
    let shutdown_cmd_tx = request_tx.clone();
    // create the shutdown broadcast channel and clone for every thread
    let (shutdown_tx, mut shutdown_rx) = broadcast::channel(16);
    let kill = shutdown_tx.subscribe();
    // TODO: Gracefully shut down updater
    // NOTE(review): `kill_updater` is never used below — presumably reserved
    // for the graceful-updater-shutdown TODO above.
    let kill_updater = shutdown_tx.subscribe();
    // Listen for ctrl_c, send signal to broadcast shutdown to all threads by dropping shutdown_tx
    let shutdown_handler = spawn(async move {
        tokio::select! {
            _ = signal::ctrl_c() => {
                // Hand ownership of the broadcast sender to the context
                // manager, which drops it to fan the shutdown out to all
                // subscribers.
                let cmd = Command::Shutdown {
                    shutdown_sig: shutdown_tx,
                };
                shutdown_cmd_tx.send(cmd).await.unwrap();
                // Give other tasks a moment, then wait for the broadcast
                // channel to close (recv errors once the sender is dropped).
                sleep(Duration::from_secs(1)).await;
                _ = shutdown_rx.recv().await;
            }
        }
    });
    let updater = spawn(async move {
        // Run updater once per day
        let mut sched = Scheduler::utc();
        sched.add(Job::new(config.updater_schedule, move || {
            run_updater(updater_tx.clone())
        }));
        // NOTE(review): this task returns immediately after registering the
        // job; confirm tokio_cron keeps scheduled jobs alive after the
        // Scheduler handle goes out of scope.
    });
    let context_manager = spawn(async move {
        create_context_manager(
            config.db,
            config.distributors,
            &config.extra_infos_base_url,
            config.confidence,
            config.max_threshold,
            config.scaling_factor,
            request_rx,
            kill,
        )
        .await
    });
    // Each connection gets a clone of the command sender; each request is
    // forwarded to the context manager and the response awaited on a oneshot.
    let make_service = make_service_fn(move |_conn: &AddrStream| {
        let request_tx = request_tx.clone();
        let service = service_fn(move |req| {
            let request_tx = request_tx.clone();
            let (response_tx, response_rx) = oneshot::channel();
            let cmd = Command::Request {
                req,
                sender: response_tx,
            };
            async move {
                request_tx.send(cmd).await.unwrap();
                response_rx.await.unwrap()
            }
        });
        async move { Ok::<_, Infallible>(service) }
    });
    // Bind on all interfaces at the configured port; ctrl+c (via
    // shutdown_signal) triggers hyper's graceful shutdown.
    let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
    let server = Server::bind(&addr).serve(make_service);
    let graceful = server.with_graceful_shutdown(shutdown_signal());
    println!("Listening on {}", addr);
    if let Err(e) = graceful.await {
        eprintln!("server error: {}", e);
    }
    // Wait for the background tasks to wind down before exiting.
    future::join_all([context_manager, updater, shutdown_handler]).await;
}