server.rs 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. use futures::StreamExt;
  2. use spiral_rs::aligned_memory::*;
  3. use spiral_rs::client::*;
  4. use spiral_rs::params::*;
  5. use spiral_rs::server::*;
  6. use spiral_rs::util::*;
  7. use std::collections::HashMap;
  8. use std::collections::VecDeque;
  9. use std::env;
  10. use std::fs::File;
  11. use std::sync::Mutex;
  12. use actix_cors::Cors;
  13. use actix_http::HttpServiceBuilder;
  14. use actix_server::Server;
  15. use actix_service::map_config;
  16. use actix_web::error::PayloadError;
  17. use actix_web::{get, http, middleware, post, web, App};
  18. use serde::Deserialize;
  19. const PUB_PARAMS_MAX: usize = 250;
  20. struct ServerState<'a> {
  21. params: &'a Params,
  22. db: AlignedMemory64,
  23. pub_params_map: Mutex<(VecDeque<String>, HashMap<String, PublicParameters<'a>>)>,
  24. }
  25. async fn get_request_bytes(
  26. mut body: web::Payload,
  27. sz_bytes: usize,
  28. ) -> Result<Vec<u8>, http::Error> {
  29. let mut bytes = web::BytesMut::new();
  30. while let Some(item) = body.next().await {
  31. let item_ref = &item?;
  32. bytes.extend_from_slice(item_ref);
  33. if bytes.len() > sz_bytes {
  34. println!("too big! {}", sz_bytes);
  35. return Err(PayloadError::Overflow.into());
  36. }
  37. }
  38. Ok(bytes.to_vec())
  39. }
  40. fn get_other_io_err() -> PayloadError {
  41. PayloadError::Io(std::io::Error::from(std::io::ErrorKind::Other))
  42. }
  43. fn other_io_err<T>(_: T) -> PayloadError {
  44. get_other_io_err()
  45. }
  46. fn get_not_found_err() -> PayloadError {
  47. PayloadError::Io(std::io::Error::from(std::io::ErrorKind::NotFound))
  48. }
  49. #[get("/")]
  50. async fn index<'a>(data: web::Data<ServerState<'a>>) -> String {
  51. format!("Hello {} {}!", data.params.poly_len, data.db.as_slice()[5])
  52. }
  53. #[derive(Deserialize)]
  54. pub struct CheckUuid {
  55. uuid: String,
  56. }
  57. #[get("/check")]
  58. async fn check<'a>(
  59. web::Query(query_params): web::Query<CheckUuid>,
  60. data: web::Data<ServerState<'a>>,
  61. ) -> Result<String, http::Error> {
  62. let pub_params_map = data.pub_params_map.lock().map_err(other_io_err)?;
  63. let has_uuid = pub_params_map.1.contains_key(&query_params.uuid);
  64. Ok(format!(
  65. "{{\"uuid\":\"{}\", \"is_valid\":{}}}",
  66. query_params.uuid, has_uuid
  67. ))
  68. }
  69. #[post("/setup")]
  70. async fn setup<'a>(
  71. body: web::Bytes,
  72. data: web::Data<ServerState<'a>>,
  73. ) -> Result<String, http::Error> {
  74. // Parse the request
  75. let pub_params = PublicParameters::deserialize(data.params, &body);
  76. // Generate a UUID and store it
  77. let uuid = uuid::Uuid::new_v4();
  78. let mut pub_params_map = data.pub_params_map.lock().map_err(other_io_err)?;
  79. pub_params_map.0.push_back(uuid.to_string());
  80. pub_params_map.1.insert(uuid.to_string(), pub_params);
  81. // If too many public parameters, remove by LRU
  82. if pub_params_map.1.len() > PUB_PARAMS_MAX {
  83. let lru_uuid_str = pub_params_map.0.pop_front().ok_or(get_other_io_err())?;
  84. pub_params_map.1.remove(&lru_uuid_str);
  85. }
  86. Ok(format!("{{\"id\":\"{}\"}}", uuid.to_string()))
  87. }
  88. const UUID_V4_STR_BYTES: usize = 36;
  89. #[post("/query")]
  90. async fn query<'a>(
  91. body: web::Payload,
  92. data: web::Data<ServerState<'a>>,
  93. ) -> Result<Vec<u8>, http::Error> {
  94. // Parse the UUID
  95. let request_bytes =
  96. get_request_bytes(body, UUID_V4_STR_BYTES + data.params.query_bytes()).await?;
  97. let uuid_bytes = &request_bytes.as_slice()[..UUID_V4_STR_BYTES];
  98. let data_bytes = &request_bytes.as_slice()[UUID_V4_STR_BYTES..];
  99. let uuid =
  100. uuid::Uuid::try_parse_ascii(uuid_bytes).map_err(|_| PayloadError::EncodingCorrupted)?;
  101. // Look up UUID and get public parameters
  102. let pub_params_map = data.pub_params_map.lock().map_err(other_io_err)?;
  103. let pub_params = pub_params_map
  104. .1
  105. .get(&uuid.to_string())
  106. .ok_or(get_not_found_err())?;
  107. // Parse the query
  108. let query = Query::deserialize(data.params, data_bytes);
  109. // Process the query
  110. let result = process_query(data.params, pub_params, &query, data.db.as_slice());
  111. Ok(result)
  112. }
  113. #[actix_web::main]
  114. async fn main() -> std::io::Result<()> {
  115. let args: Vec<String> = env::args().collect();
  116. let db_preprocessed_path = &args[1];
  117. let cfg_expand = r#"
  118. {'n': 2,
  119. 'nu_1': 10,
  120. 'nu_2': 6,
  121. 'p': 512,
  122. 'q2_bits': 21,
  123. 's_e': 85.83255142749422,
  124. 't_gsw': 10,
  125. 't_conv': 4,
  126. 't_exp_left': 16,
  127. 't_exp_right': 56,
  128. 'instances': 11,
  129. 'db_item_size': 100000 }
  130. "#;
  131. let box_params = Box::new(params_from_json(&cfg_expand.replace("'", "\"")));
  132. let params: &'static Params = Box::leak(box_params);
  133. let mut file = File::open(db_preprocessed_path).unwrap();
  134. let db = load_preprocessed_db_from_file(params, &mut file);
  135. let server_state = ServerState {
  136. params: params,
  137. db: db,
  138. pub_params_map: Mutex::new((VecDeque::new(), HashMap::new())),
  139. };
  140. let state = web::Data::new(server_state);
  141. let app_builder = move || {
  142. let cors = Cors::default()
  143. .allow_any_origin()
  144. .allowed_headers([
  145. http::header::ORIGIN,
  146. http::header::CONTENT_TYPE,
  147. http::header::ACCEPT,
  148. ])
  149. .allow_any_method()
  150. .max_age(3600);
  151. App::new()
  152. .wrap(middleware::Compress::default())
  153. .wrap(cors)
  154. .app_data(state.clone())
  155. .app_data(web::PayloadConfig::new(1 << 25))
  156. .service(setup)
  157. .service(query)
  158. .service(check)
  159. };
  160. Server::build()
  161. .bind("http/1", "localhost:8088", move || {
  162. HttpServiceBuilder::default()
  163. .h1(map_config(app_builder(), |_| {
  164. actix_web::dev::AppConfig::default()
  165. }))
  166. .tcp()
  167. })?
  168. .run()
  169. .await
  170. }