server.rs

use troll_patrol::{request_handler::handle, *};

use clap::Parser;
use futures::future;
use hyper::{
    server::conn::AddrStream,
    service::{make_service_fn, service_fn},
    Body, Request, Response, Server,
};
use serde::Deserialize;
use sled::Db;
use std::{
    collections::BTreeMap, convert::Infallible, fs::File, io::BufReader, net::SocketAddr,
    path::PathBuf, time::Duration,
};
use tokio::{
    signal, spawn,
    sync::{broadcast, mpsc, oneshot},
    time::sleep,
};
use tokio_cron::{Job, Scheduler};

async fn shutdown_signal() {
    tokio::signal::ctrl_c()
        .await
        .expect("failed to listen for ctrl+c signal");
    println!("Shut down Troll Patrol Server");
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Name/path of the configuration file
    #[arg(short, long, default_value = "config.json")]
    config: PathBuf,
}

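// Invocation sketch (binary name and config path are illustrative, not
// prescribed by this file):
//
//     troll-patrol --config /etc/troll-patrol/config.json
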
#[derive(Debug, Deserialize)]
pub struct Config {
    pub db: DbConfig,
    // map of distributor name to IP:port to contact it
    pub distributors: BTreeMap<BridgeDistributor, String>,
    extra_infos_base_url: String,
    // confidence required to consider a bridge blocked
    confidence: f64,
    // block open-entry bridges if they get more negative reports than this
    max_threshold: u32,
    // block open-entry bridges if they get more negative reports than
    // scaling_factor * bridge_ips
    scaling_factor: f64,
    //require_bridge_token: bool,
    port: u16,
    updater_schedule: String,
}

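// A minimal config.json sketch matching the fields above. All values are
// illustrative assumptions rather than shipped defaults: the distributor
// name/address are placeholders, and updater_schedule is a cron expression
// as understood by tokio_cron (here, daily at midnight UTC):
//
// {
//     "db": { "db_path": "server_db" },
//     "distributors": { "Lox": "http://127.0.0.1:8002" },
//     "extra_infos_base_url": "https://collector.torproject.org/recent/bridge-descriptors/extra-infos/",
//     "confidence": 0.95,
//     "max_threshold": 10,
//     "scaling_factor": 0.25,
//     "port": 8003,
//     "updater_schedule": "0 0 0 * * * *"
// }
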
#[derive(Debug, Deserialize)]
pub struct DbConfig {
    // The path for the server database, default is "server_db"
    pub db_path: String,
}

impl Default for DbConfig {
    fn default() -> DbConfig {
        DbConfig {
            db_path: "server_db".to_owned(),
        }
    }
}

// Fetch new extra-infos and reports from the distributors, guess new
// blockages from the accumulated data, and report them back to the
// distributors.
async fn update_daily_info(
    db: &Db,
    distributors: &BTreeMap<BridgeDistributor, String>,
    extra_infos_base_url: &str,
    confidence: f64,
    max_threshold: u32,
    scaling_factor: f64,
) {
    update_extra_infos(&db, &extra_infos_base_url)
        .await
        .unwrap();
    update_negative_reports(&db, &distributors).await;
    update_positive_reports(&db, &distributors).await;
    let new_blockages = guess_blockages(
        &db,
        &analyzer::NormalAnalyzer::new(max_threshold, scaling_factor),
        confidence,
    );
    report_blockages(&distributors, new_blockages).await;
}

// Ask the context manager to run the daily update (invoked by the cron job).
async fn run_updater(updater_tx: mpsc::Sender<Command>) {
    updater_tx.send(Command::Update {}).await.unwrap();
}

// Run the context manager until we receive a kill signal.
async fn create_context_manager(
    db_config: DbConfig,
    distributors: BTreeMap<BridgeDistributor, String>,
    extra_infos_base_url: &str,
    confidence: f64,
    max_threshold: u32,
    scaling_factor: f64,
    context_rx: mpsc::Receiver<Command>,
    mut kill: broadcast::Receiver<()>,
) {
    tokio::select! {
        create_context = context_manager(
            db_config,
            distributors,
            extra_infos_base_url,
            confidence,
            max_threshold,
            scaling_factor,
            context_rx,
        ) => create_context,
        _ = kill.recv() => { println!("Shut down manager"); },
    }
}

// Open the database and process commands (requests, shutdown, updates) sent
// over the channel.
async fn context_manager(
    db_config: DbConfig,
    distributors: BTreeMap<BridgeDistributor, String>,
    extra_infos_base_url: &str,
    confidence: f64,
    max_threshold: u32,
    scaling_factor: f64,
    mut context_rx: mpsc::Receiver<Command>,
) {
    let db: Db = sled::open(&db_config.db_path).unwrap();

    while let Some(cmd) = context_rx.recv().await {
        use Command::*;
        match cmd {
            Request { req, sender } => {
                let response = handle(&db, req).await;
                if let Err(e) = sender.send(response) {
                    eprintln!("Server Response Error: {:?}", e);
                };
                sleep(Duration::from_millis(1)).await;
            }
            Shutdown { shutdown_sig } => {
                println!("Sending Shutdown Signal, all threads should shutdown.");
                drop(shutdown_sig);
                println!("Shutdown Sent.");
            }
            Update {} => {
                update_daily_info(
                    &db,
                    &distributors,
                    &extra_infos_base_url,
                    confidence,
                    max_threshold,
                    scaling_factor,
                )
                .await;
            }
        }
    }
}

// Each of the commands the context manager can handle
#[derive(Debug)]
enum Command {
    Request {
        req: Request<Body>,
        sender: oneshot::Sender<Result<Response<Body>, Infallible>>,
    },
    Shutdown {
        shutdown_sig: broadcast::Sender<()>,
    },
    Update {},
}

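// The round trip, as wired up in main() below: the HTTP service creates a
// oneshot channel, sends Command::Request { req, sender } to the context
// manager over the mpsc channel, and awaits the handler's response on the
// oneshot receiver.
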
#[tokio::main]
async fn main() {
    let args: Args = Args::parse();

    let config: Config = serde_json::from_reader(BufReader::new(
        File::open(&args.config).expect("Could not read config file"),
    ))
    .expect("Reading config file from JSON failed");

    let (request_tx, request_rx) = mpsc::channel(32);
    let updater_tx = request_tx.clone();
    let shutdown_cmd_tx = request_tx.clone();

    // Create the shutdown broadcast channel and subscribe for every thread
    let (shutdown_tx, mut shutdown_rx) = broadcast::channel(16);
    let kill = shutdown_tx.subscribe();

    // TODO: Gracefully shut down updater
    let kill_updater = shutdown_tx.subscribe();

    // Listen for ctrl+c; on receipt, hand shutdown_tx to the context manager,
    // which drops it to broadcast shutdown to all threads
    let shutdown_handler = spawn(async move {
        tokio::select! {
            _ = signal::ctrl_c() => {
                let cmd = Command::Shutdown {
                    shutdown_sig: shutdown_tx,
                };
                shutdown_cmd_tx.send(cmd).await.unwrap();
                sleep(Duration::from_secs(1)).await;
                _ = shutdown_rx.recv().await;
            }
        }
    });

    let updater = spawn(async move {
        // Run updater once per day
        let mut sched = Scheduler::utc();
        sched.add(Job::new(config.updater_schedule, move || {
            run_updater(updater_tx.clone())
        }));
    });

    let context_manager = spawn(async move {
        create_context_manager(
            config.db,
            config.distributors,
            &config.extra_infos_base_url,
            config.confidence,
            config.max_threshold,
            config.scaling_factor,
            request_rx,
            kill,
        )
        .await
    });

    let make_service = make_service_fn(move |_conn: &AddrStream| {
        let request_tx = request_tx.clone();
        let service = service_fn(move |req| {
            let request_tx = request_tx.clone();
            let (response_tx, response_rx) = oneshot::channel();
            let cmd = Command::Request {
                req,
                sender: response_tx,
            };
            async move {
                request_tx.send(cmd).await.unwrap();
                response_rx.await.unwrap()
            }
        });
        async move { Ok::<_, Infallible>(service) }
    });

    let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
    let server = Server::bind(&addr).serve(make_service);
    let graceful = server.with_graceful_shutdown(shutdown_signal());
    println!("Listening on {}", addr);

    if let Err(e) = graceful.await {
        eprintln!("server error: {}", e);
    }

    future::join_all([context_manager, updater, shutdown_handler]).await;
}