server.rs 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. use futures::StreamExt;
  2. use spiral_rs::aligned_memory::*;
  3. use spiral_rs::client::*;
  4. use spiral_rs::params::*;
  5. use spiral_rs::server::*;
  6. use spiral_rs::util::*;
  7. use std::collections::HashMap;
  8. use std::env;
  9. use std::fs::File;
  10. use std::sync::Mutex;
  11. use actix_cors::Cors;
  12. use actix_files as fs;
  13. use actix_http::HttpServiceBuilder;
  14. use actix_server::{Server, ServerBuilder};
  15. use actix_service::map_config;
  16. use actix_web::error::PayloadError;
  17. use actix_web::{get, http, middleware, post, web, App, HttpServer};
  18. use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod};
  /// PEM private key for the HTTPS listener; must pair with the certificate in `CERT_FNAME`.
  const KEY_FNAME: &str = "/etc/letsencrypt/live/spiralwiki.com/privkey.pem";
  /// Certificate chain file for the HTTPS listener, loaded with `set_certificate_chain_file`.
  const CERT_FNAME: &str = "/etc/letsencrypt/live/spiralwiki.com/fullchain.pem";
  21. struct ServerState<'a> {
  22. params: &'a Params,
  23. db: AlignedMemory64,
  24. pub_params_map: Mutex<HashMap<String, PublicParameters<'a>>>,
  25. }
  26. async fn get_request_bytes(
  27. mut body: web::Payload,
  28. sz_bytes: usize,
  29. ) -> Result<Vec<u8>, http::Error> {
  30. let mut bytes = web::BytesMut::new();
  31. while let Some(item) = body.next().await {
  32. let item_ref = &item?;
  33. bytes.extend_from_slice(item_ref);
  34. if bytes.len() > sz_bytes {
  35. println!("too big! {}", sz_bytes);
  36. return Err(PayloadError::Overflow.into());
  37. }
  38. }
  39. Ok(bytes.to_vec())
  40. }
  41. fn get_other_io_err() -> PayloadError {
  42. PayloadError::Io(std::io::Error::from(std::io::ErrorKind::Other))
  43. }
  44. fn other_io_err<T>(_: T) -> PayloadError {
  45. get_other_io_err()
  46. }
  47. fn get_not_found_err() -> PayloadError {
  48. PayloadError::Io(std::io::Error::from(std::io::ErrorKind::NotFound))
  49. }
  50. #[get("/")]
  51. async fn index<'a>(data: web::Data<ServerState<'a>>) -> String {
  52. format!("Hello {} {}!", data.params.poly_len, data.db.as_slice()[5])
  53. }
  54. #[post("/setup")]
  55. async fn setup<'a>(
  56. body: web::Bytes,
  57. data: web::Data<ServerState<'a>>,
  58. ) -> Result<String, http::Error> {
  59. println!("/setup");
  60. // Parse the request
  61. let pub_params = PublicParameters::deserialize(data.params, &body);
  62. // Generate a UUID and store it
  63. let uuid = uuid::Uuid::new_v4();
  64. let mut pub_params_map = data.pub_params_map.lock().map_err(other_io_err)?;
  65. pub_params_map.insert(uuid.to_string(), pub_params);
  66. Ok(format!("{{\"id\":\"{}\"}}", uuid.to_string()))
  67. }
  /// Byte length of a hyphenated UUID string: 8-4-4-4-12 hex digits plus 4 hyphens (36).
  /// Every `/query` body starts with this many ASCII bytes.
  const UUID_V4_STR_BYTES: usize = 8 + 4 + 4 + 4 + 12 + 4;
  69. #[post("/query")]
  70. async fn query<'a>(
  71. body: web::Payload,
  72. data: web::Data<ServerState<'a>>,
  73. ) -> Result<Vec<u8>, http::Error> {
  74. println!("/query");
  75. // Parse the UUID
  76. let request_bytes =
  77. get_request_bytes(body, UUID_V4_STR_BYTES + data.params.query_bytes()).await?;
  78. let uuid_bytes = &request_bytes.as_slice()[..UUID_V4_STR_BYTES];
  79. let data_bytes = &request_bytes.as_slice()[UUID_V4_STR_BYTES..];
  80. let uuid =
  81. uuid::Uuid::try_parse_ascii(uuid_bytes).map_err(|_| PayloadError::EncodingCorrupted)?;
  82. // Look up UUID and get public parameters
  83. let pub_params_map = data.pub_params_map.lock().map_err(other_io_err)?;
  84. let pub_params = pub_params_map
  85. .get(&uuid.to_string())
  86. .ok_or(get_not_found_err())?;
  87. // Parse the query
  88. let query = Query::deserialize(data.params, data_bytes);
  89. // Process the query
  90. let result = process_query(data.params, pub_params, &query, data.db.as_slice());
  91. Ok(result)
  92. }
  93. #[actix_web::main]
  94. async fn main() -> std::io::Result<()> {
  95. let args: Vec<String> = env::args().collect();
  96. let db_preprocessed_path = &args[1];
  97. let cfg_expand = r#"
  98. {'n': 2,
  99. 'nu_1': 10,
  100. 'nu_2': 6,
  101. 'p': 512,
  102. 'q2_bits': 21,
  103. 's_e': 85.83255142749422,
  104. 't_gsw': 10,
  105. 't_conv': 4,
  106. 't_exp_left': 16,
  107. 't_exp_right': 56,
  108. 'instances': 11,
  109. 'db_item_size': 100000 }
  110. "#;
  111. let box_params = Box::new(params_from_json(&cfg_expand.replace("'", "\"")));
  112. let params: &'static Params = Box::leak(box_params);
  113. let mut file = File::open(db_preprocessed_path).unwrap();
  114. let db = load_preprocessed_db_from_file(params, &mut file);
  115. let server_state = ServerState {
  116. params: params,
  117. db: db,
  118. pub_params_map: Mutex::new(HashMap::new()),
  119. };
  120. let state = web::Data::new(server_state);
  121. let app_builder = move || {
  122. let cors = Cors::default()
  123. .allow_any_origin()
  124. .allowed_headers([
  125. http::header::ORIGIN,
  126. http::header::CONTENT_TYPE,
  127. http::header::ACCEPT,
  128. ])
  129. .allow_any_method()
  130. .max_age(3600);
  131. App::new()
  132. .wrap(middleware::Compress::default())
  133. .wrap(cors)
  134. .app_data(state.clone())
  135. .app_data(web::PayloadConfig::new(1 << 25))
  136. .service(setup)
  137. .service(query)
  138. .service(fs::Files::new("/", "../client/static").index_file("index.html"))
  139. };
  140. Server::build()
  141. .bind("http/1", "0.0.0.0:8088", move || {
  142. let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
  143. builder
  144. .set_private_key_file(KEY_FNAME, SslFiletype::PEM)
  145. .unwrap();
  146. builder.set_certificate_chain_file(CERT_FNAME).unwrap();
  147. builder.set_alpn_protos(b"\x08http/1.1").unwrap();
  148. HttpServiceBuilder::default()
  149. .h1(map_config(app_builder(), |_| {
  150. actix_web::dev::AppConfig::default()
  151. }))
  152. .openssl(builder.build())
  153. })?
  154. .run()
  155. .await
  156. }