// load_balancer.rs

use crate::omap::ObliviousMap;
pub use crate::record::{IndexRecord, Record, RecordType, SubmapRecord};
use fastapprox::fast;
use otils::{self, ObliviousOps};
use rayon::ThreadPool;
use std::{
    cmp,
    f64::consts::E,
    sync::{Arc, Mutex},
};

// Statistical security parameter for the bin-overflow bound in pad_size.
const LAMBDA: usize = 128;
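
/// Oblivious load balancer: distributes batched send/fetch requests across
/// `num_submaps` oblivious submaps, padding every batch with dummies so the
/// observable access pattern is independent of which users are active.
///
/// A minimal usage sketch (hypothetical driver code; constructing `Record`s is
/// left to `crate::record`):
///
/// ```ignore
/// let mut lb = LoadBalancer::new(1024, 8, 4); // users, threads, submaps
/// lb.batch_send(sends);                       // sends: Vec<Record>
/// let delivered = lb.batch_fetch(fetches);    // fetches: Vec<Record>
/// ```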
pub struct LoadBalancer {
    num_users: i64,
    num_submaps: usize,
    num_threads: usize,
    pool: ThreadPool,
    pub user_store: Vec<IndexRecord>,
    pub submaps: Vec<ObliviousMap>,
}

impl LoadBalancer {
    pub fn new(num_users: i64, num_threads: usize, num_submaps: usize) -> Self {
        // Split the available threads between the load balancer and the submaps.
        let component_threads = num_threads / (num_submaps + 1);
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(component_threads)
            .build()
            .unwrap();

        let mut user_store = Vec::with_capacity(num_users as usize);
        user_store.extend((0..num_users).map(|i| IndexRecord::new(i, RecordType::User)));

        let mut submaps = Vec::with_capacity(num_submaps);
        submaps.extend((0..num_submaps).map(|_| ObliviousMap::new(component_threads)));

        LoadBalancer {
            num_users,
            num_submaps,
            num_threads: component_threads,
            pool,
            user_store,
            submaps,
        }
    }
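
    // Per-submap bin size for a batch of `num_requests` requests. With mean
    // load mu = num_requests / num_submaps, the smallest bin size whose per-bin
    // overflow probability (under a Poisson-style tail bound) is at most
    // e^-gamma has the closed form mu * e^(W((gamma / mu - 1) / e) + 1), where
    // W is the Lambert W function (approximated via fastapprox). The result is
    // capped at `num_requests`, since no bin can receive more than the whole
    // batch.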
    fn pad_size(&self, num_requests: f64) -> usize {
        let num_submaps = self.num_submaps as f64;
        let mu = num_requests / num_submaps;
        let gamma = (num_submaps + 2_f64.powf(LAMBDA as f64)).ln();
        let rhs = (gamma / mu - 1_f64) / E;
        num_requests
            .min(mu * E.powf(fast::lambertw(rhs as f32) as f64 + 1_f64))
            .ceil() as usize
    }
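
    // Append `submap_size` dummy requests for every submap so each submap's
    // batch can be padded to a fixed size independent of the real distribution.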
    pub fn pad_for_submap(
        &self,
        mut requests: Vec<SubmapRecord>,
        submap_size: usize,
        is_send: bool,
    ) -> Vec<SubmapRecord> {
        requests.reserve(self.num_submaps * submap_size);
        for submap in 0..self.num_submaps {
            if is_send {
                requests.extend(SubmapRecord::dummy_send(submap_size, submap as u8));
            } else {
                requests.extend(SubmapRecord::dummy_fetch(submap_size, submap as u8));
            }
        }
        requests
    }
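
    // Turns an indexed batch into fixed-size per-submap batches: pad with
    // dummies, obliviously sort so each submap's requests are contiguous (real
    // before dummy), keep exactly `submap_size` records per submap, and compact
    // the kept records to the front.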
    pub fn get_submap_requests(
        &self,
        requests: Vec<IndexRecord>,
        submap_size: usize,
        is_send: bool,
    ) -> Vec<SubmapRecord> {
        let requests: Vec<SubmapRecord> = requests.into_iter().map(|r| SubmapRecord(r.0)).collect();
        let mut requests = self.pad_for_submap(requests, submap_size, is_send);
        requests = otils::sort(requests, &self.pool, self.num_threads); // sort by omap, then by dummy

        // Obliviously mark the first `submap_size` requests destined for each
        // submap; `oselect` keeps the scan free of data-dependent branches.
        let mut prev_map = self.num_submaps;
        let mut remaining_marks = submap_size as i32;
        for request in requests.iter_mut() {
            let submap = request.0.map as u32;
            remaining_marks = i32::oselect(
                submap != prev_map as u32,
                submap_size as i32,
                remaining_marks,
            );
            request.0.mark = u16::oselect(remaining_marks > 0, 1, 0);
            remaining_marks += i32::oselect(remaining_marks > 0, -1, 0);
            prev_map = submap as usize;
        }

        // Move the marked requests to the front and drop the excess dummies.
        otils::compact(
            &mut requests[..],
            |r| r.0.mark == 1,
            &self.pool,
            self.num_threads,
        );
        requests.truncate(self.num_submaps * submap_size);
        requests
    }
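
    // A single oblivious pass over the sorted user store. Sorting places each
    // user's store record immediately before that user's pending sends, so the
    // scan can seed a running counter from the store record (continuing past
    // the larger of last_fetch and last_send), derive each send's oblivious
    // index and destination submap, and mark the last record of every uid so a
    // later compaction keeps exactly one refreshed store record per user.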
    fn propagate_send_indices(&mut self) {
        let mut idx: u32 = 0;
        let mut is_same_u: bool;
        let mut user_store_iter = self.user_store.iter_mut().peekable();
        while let Some(record) = user_store_iter.next() {
            let is_user_store = record.0.is_user_store();
            // Restart the counter at each user-store record; the requests that
            // follow take consecutive indices.
            idx = u32::oselect(
                is_user_store,
                cmp::max(record.0.last_fetch, record.0.last_send),
                idx + 1,
            );
            record.0.idx = u32::oselect(is_user_store, 0, record.get_idx(idx));
            record.0.map = (record.0.idx % (self.num_submaps as u32)) as u8;
            record.0.last_send = idx;
            // Mark the last record of every uid.
            if let Some(next_record) = user_store_iter.peek() {
                is_same_u = record.0.uid == next_record.0.uid;
            } else {
                is_same_u = false;
            }
            record.0.mark = u16::oselect(is_same_u, 0, 1);
        }
    }
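
    // Assigns send indices to a batch: merge the sends into the user store,
    // obliviously sort so each user's records are adjacent, propagate indices,
    // then compact twice: once to drain the indexed requests off the front, and
    // once to retain the refreshed store records before truncating back to one
    // record per user.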
    pub fn get_send_indices(&mut self, sends: Vec<IndexRecord>) -> Vec<IndexRecord> {
        let num_requests = sends.len();
        self.user_store.reserve(num_requests);
        self.user_store.extend(sends);
        self.user_store = otils::sort(
            std::mem::take(&mut self.user_store),
            &self.pool,
            self.num_threads,
        );
        self.propagate_send_indices();
        otils::compact(
            &mut self.user_store[..],
            |r| r.is_request(),
            &self.pool,
            self.num_threads,
        );
        let requests = self.user_store.drain(0..num_requests).collect();
        otils::compact(
            &mut self.user_store[..],
            |r| r.is_updated_user_store(),
            &self.pool,
            self.num_threads,
        );
        self.user_store.truncate(self.num_users as usize);
        self.user_store.iter_mut().for_each(|r| {
            r.set_user_store();
        });
        requests
    }
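
    // End-to-end ingestion of a send batch: assign per-user send indices, pad
    // into fixed-size per-submap batches, and dispatch one batch to each submap
    // in parallel.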
    pub fn batch_send(&mut self, sends: Vec<Record>) {
        let sends = sends.into_iter().map(|r| IndexRecord(r)).collect();
        let requests = self.get_send_indices(sends);
        let submap_size = self.pad_size(requests.len() as f64);
        let mut requests: Vec<Record> = self
            .get_submap_requests(requests, submap_size, true)
            .into_iter()
            .map(|r| r.0)
            .collect();

        // `split_at_mut` hands each spawned task exclusive access to its own
        // submap, so all submaps can be updated concurrently.
        let mut remaining_submaps = &mut self.submaps[..];
        self.pool.scope(|s| {
            for _ in 0..self.num_submaps {
                let (submap, rest_submaps) = remaining_submaps.split_at_mut(1);
                remaining_submaps = rest_submaps;
                let batch = requests.drain(0..submap_size).collect();
                s.spawn(move |_| submap[0].batch_send(batch));
            }
        });
    }
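
    // Expand each fetch request into its per-message records (one record per
    // message the user asked for) and append them to the user store.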
    fn update_with_fetches(&mut self, fetches: Vec<IndexRecord>, num_fetches: usize) {
        self.user_store.reserve(num_fetches);
        for fetch in fetches.into_iter() {
            self.user_store.extend(fetch.dummy_fetches());
        }
    }
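
    // Mirror of propagate_send_indices for fetches: the running counter resumes
    // from the store record's last_fetch, so a user's fetch records are
    // assigned the indices immediately following the last one fetched.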
    fn propagate_fetch_indices(&mut self) {
        let mut idx: u32 = 0;
        let mut is_same_u: bool;
        let mut user_store_iter = self.user_store.iter_mut().peekable();
        while let Some(record) = user_store_iter.next() {
            let is_user_store = record.0.is_user_store();
            idx = u32::oselect(is_user_store, record.0.last_fetch, idx + 1);
            record.0.idx = u32::oselect(is_user_store, 0, record.get_idx(idx));
            record.0.map = (record.0.idx % (self.num_submaps as u32)) as u8;
            record.0.last_fetch = idx;
            if let Some(next_record) = user_store_iter.peek() {
                is_same_u = record.0.uid == next_record.0.uid;
            } else {
                is_same_u = false;
            }
            record.0.mark = u16::oselect(is_same_u, 0, 1);
        }
    }
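
    // Fetch-side counterpart of get_send_indices: merge the expanded fetch
    // records into the user store, sort, propagate fetch indices, and compact
    // twice to split off the request batch and the refreshed store records.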
    pub fn get_fetch_indices(
        &mut self,
        fetches: Vec<IndexRecord>,
        num_requests: usize,
    ) -> Vec<IndexRecord> {
        self.update_with_fetches(fetches, num_requests);
        self.user_store = otils::sort(
            std::mem::take(&mut self.user_store),
            &self.pool,
            self.num_threads,
        );
        self.propagate_fetch_indices();
        otils::compact(
            &mut self.user_store[..],
            |r| r.is_request(),
            &self.pool,
            self.num_threads,
        );
        let deliver = self.user_store.drain(0..num_requests).collect();
        otils::compact(
            &mut self.user_store[..],
            |r| r.is_updated_user_store(),
            &self.pool,
            self.num_threads,
        );
        self.user_store.truncate(self.num_users as usize);
        self.user_store.iter_mut().for_each(|r| {
            r.set_user_store();
        });
        deliver
    }
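
    // Serves a fetch batch end to end: expand and index the requests, pad into
    // fixed-size per-submap batches, query all submaps in parallel, then
    // obliviously sort and compact the combined responses so the first
    // `num_requests` records are the real messages to return.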
    pub fn batch_fetch(&mut self, fetches: Vec<Record>) -> Vec<Record> {
        // Each fetch record asks for `data` messages in total.
        let num_requests = fetches
            .iter()
            .fold(0, |acc, fetch| acc + fetch.data as usize);
        let fetches = fetches.into_iter().map(|r| IndexRecord(r)).collect();
        let requests = self.get_fetch_indices(fetches, num_requests);
        let submap_size = self.pad_size(requests.len() as f64);
        let mut requests: Vec<Record> = self
            .get_submap_requests(requests, submap_size, false)
            .into_iter()
            .map(|r| r.0)
            .collect();

        // Fetch from all submaps in parallel, gathering the responses behind a
        // shared mutex; each task gets its own Arc handle and exclusive access
        // to one submap.
        let mut remaining_submaps = &mut self.submaps[..];
        let responses: Arc<Mutex<Vec<IndexRecord>>> = Arc::new(Mutex::new(Vec::with_capacity(
            submap_size * self.num_submaps,
        )));
        self.pool.scope(|s| {
            for _ in 0..self.num_submaps {
                let (submap, rest_submaps) = remaining_submaps.split_at_mut(1);
                remaining_submaps = rest_submaps;
                let batch = requests.drain(0..submap_size).collect();
                let responses = Arc::clone(&responses);
                s.spawn(move |_| {
                    let response = submap[0].batch_fetch(batch);
                    responses.lock().unwrap().extend(response);
                });
            }
        });
        let mutex = Arc::into_inner(responses).unwrap();
        let mut responses: Vec<IndexRecord> = mutex.into_inner().unwrap();

        // Restore a canonical order and move the real responses to the front.
        // This only really needs to be a shuffle.
        responses = otils::sort(responses, &self.pool, self.num_threads);
        otils::compact(
            &mut responses,
            |r| r.0.is_send(),
            &self.pool,
            self.num_threads,
        );
        responses.drain(0..num_requests).map(|r| r.0).collect()
    }
}