server.rs 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. use actix_web::error::PayloadError;
  2. use futures::StreamExt;
  3. use spiral_rs::aligned_memory::*;
  4. use spiral_rs::client::*;
  5. use spiral_rs::params::*;
  6. use spiral_rs::server::*;
  7. use spiral_rs::util::*;
  8. use std::collections::HashMap;
  9. use std::env;
  10. use std::fs::File;
  11. use std::sync::Mutex;
  12. use actix_web::{get, http, post, web, App, HttpServer};
  13. use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod};
  /// Private key matching [`CERT_FNAME`], loaded into the TLS acceptor in `main`.
  const KEY_FNAME: &str = "/etc/letsencrypt/live/spiralwiki.com/privkey.pem";
  /// Certificate chain (`fullchain.pem`) presented by the HTTPS listener in `main`.
  const CERT_FNAME: &str = "/etc/letsencrypt/live/spiralwiki.com/fullchain.pem";
  16. struct ServerState<'a> {
  17. params: &'a Params,
  18. db: AlignedMemory64,
  19. pub_params_map: Mutex<HashMap<String, PublicParameters<'a>>>,
  20. }
  21. async fn get_request_bytes(
  22. mut body: web::Payload,
  23. sz_bytes: usize,
  24. ) -> Result<Vec<u8>, http::Error> {
  25. let mut bytes = web::BytesMut::new();
  26. while let Some(item) = body.next().await {
  27. bytes.extend_from_slice(&item?);
  28. if bytes.len() > sz_bytes {
  29. println!("too big! {}", sz_bytes);
  30. return Err(PayloadError::Overflow.into());
  31. }
  32. }
  33. Ok(bytes.to_vec())
  34. }
  35. fn get_other_io_err() -> PayloadError {
  36. PayloadError::Io(std::io::Error::from(std::io::ErrorKind::Other))
  37. }
  38. fn other_io_err<T>(_: T) -> PayloadError {
  39. get_other_io_err()
  40. }
  41. fn get_not_found_err() -> PayloadError {
  42. PayloadError::Io(std::io::Error::from(std::io::ErrorKind::NotFound))
  43. }
  44. #[get("/")]
  45. async fn index<'a>(data: web::Data<ServerState<'a>>) -> String {
  46. format!("Hello {} {}!", data.params.poly_len, data.db.as_slice()[5])
  47. }
  48. #[post("/setup")]
  49. async fn setup<'a>(
  50. body: web::Payload,
  51. data: web::Data<ServerState<'a>>,
  52. ) -> Result<String, http::Error> {
  53. println!("/setup");
  54. // Parse the request
  55. let request_bytes = get_request_bytes(body, data.params.setup_bytes()).await?;
  56. let pub_params = PublicParameters::deserialize(data.params, request_bytes.as_slice());
  57. // Generate a UUID and store it
  58. let uuid = uuid::Uuid::new_v4();
  59. let mut pub_params_map = data.pub_params_map.lock().map_err(other_io_err)?;
  60. pub_params_map.insert(uuid.to_string(), pub_params);
  61. Ok(format!("{{\"id\":\"{}\"}}", uuid.to_string()))
  62. }
  /// Byte length of a hyphenated UUID string: 32 hex digits plus 4 hyphens.
  const UUID_V4_STR_BYTES: usize = 32 + 4;
  64. #[post("/query")]
  65. async fn query<'a>(
  66. body: web::Payload,
  67. data: web::Data<ServerState<'a>>,
  68. ) -> Result<Vec<u8>, http::Error> {
  69. println!("/query");
  70. // Parse the UUID
  71. let request_bytes =
  72. get_request_bytes(body, UUID_V4_STR_BYTES + data.params.query_bytes()).await?;
  73. let uuid_bytes = &request_bytes.as_slice()[..UUID_V4_STR_BYTES];
  74. let data_bytes = &request_bytes.as_slice()[UUID_V4_STR_BYTES..];
  75. let uuid =
  76. uuid::Uuid::try_parse_ascii(uuid_bytes).map_err(|_| PayloadError::EncodingCorrupted)?;
  77. // Look up UUID and get public parameters
  78. let pub_params_map = data.pub_params_map.lock().map_err(other_io_err)?;
  79. let pub_params = pub_params_map
  80. .get(&uuid.to_string())
  81. .ok_or(get_not_found_err())?;
  82. // Parse the query
  83. let query = Query::deserialize(data.params, data_bytes);
  84. // Process the query
  85. let result = process_query(data.params, pub_params, &query, data.db.as_slice());
  86. Ok(result)
  87. }
  88. #[actix_web::main]
  89. async fn main() -> std::io::Result<()> {
  90. let args: Vec<String> = env::args().collect();
  91. let db_preprocessed_path = &args[1];
  92. let cfg_expand = r#"
  93. {'n': 2,
  94. 'nu_1': 10,
  95. 'nu_2': 6,
  96. 'p': 512,
  97. 'q2_bits': 21,
  98. 's_e': 85.83255142749422,
  99. 't_gsw': 10,
  100. 't_conv': 4,
  101. 't_exp_left': 16,
  102. 't_exp_right': 56,
  103. 'instances': 11,
  104. 'db_item_size': 100000 }
  105. "#;
  106. let box_params = Box::new(params_from_json(&cfg_expand.replace("'", "\"")));
  107. let params: &'static Params = Box::leak(box_params);
  108. let mut file = File::open(db_preprocessed_path).unwrap();
  109. let db = load_preprocessed_db_from_file(params, &mut file);
  110. let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
  111. builder
  112. .set_private_key_file(KEY_FNAME, SslFiletype::PEM)
  113. .unwrap();
  114. builder.set_certificate_chain_file(CERT_FNAME).unwrap();
  115. let server_state = ServerState {
  116. params: params,
  117. db: db,
  118. pub_params_map: Mutex::new(HashMap::new()),
  119. };
  120. let state = web::Data::new(server_state);
  121. let res = HttpServer::new(move || {
  122. App::new()
  123. .app_data(state.clone())
  124. .service(index)
  125. .service(setup)
  126. .service(query)
  127. })
  128. .bind_openssl("0.0.0.0:8088", builder)?
  129. // .bind_openssl("127.0.0.1:7777", builder)?
  130. .run()
  131. .await;
  132. res
  133. }