//! `omq.rs`: defines [`ObliviousMultiQueue`], a message store serving batched sends and fetches.
  1. use crate::request::Request;
  2. use otils::ObliviousOps;
  3. use rayon::ThreadPool;
  4. #[derive(Debug)]
  5. pub struct ObliviousMultiQueue {
  6. num_threads: usize,
  7. pool: ThreadPool,
  8. message_store: Vec<Request>,
  9. }
  10. impl ObliviousMultiQueue {
  11. pub fn new(num_threads: usize) -> Self {
  12. let pool = rayon::ThreadPoolBuilder::new()
  13. .num_threads(num_threads)
  14. .build()
  15. .unwrap();
  16. ObliviousMultiQueue {
  17. num_threads,
  18. pool,
  19. message_store: Vec::new(),
  20. }
  21. }
  22. pub fn batch_send(&mut self, sends: Vec<Request>) {
  23. self.message_store.reserve(sends.len());
  24. self.message_store.extend(sends);
  25. }
  26. fn update_store(&mut self, fetches: Vec<Request>, fetch_sum: usize) {
  27. self.message_store.reserve(fetches.len() + fetch_sum);
  28. for fetch in fetches.iter() {
  29. self.message_store
  30. .extend(Request::dummies(fetch.uid, fetch.volume));
  31. }
  32. self.message_store.extend(fetches);
  33. }
  34. pub fn batch_fetch(&mut self, fetches: Vec<Request>) -> Vec<Request> {
  35. let final_size = self.message_store.len();
  36. let fetch_sum = fetches.iter().fold(0, |acc, f| acc + f.volume) as usize;
  37. self.update_store(fetches, fetch_sum);
  38. self.message_store = otils::sort(
  39. std::mem::take(&mut self.message_store),
  40. &self.pool,
  41. self.num_threads,
  42. );
  43. let mut user_sum: isize = 0;
  44. let mut prev_user: i32 = -1;
  45. for request in self.message_store.iter_mut() {
  46. let same_user = prev_user == request.uid;
  47. user_sum = isize::oselect(same_user, user_sum, 0);
  48. let fetch_more = user_sum > 0;
  49. request.mark = u16::oselect(request.is_fetch(), 0, u16::oselect(fetch_more, 1, 0));
  50. prev_user = request.uid;
  51. user_sum += isize::oselect(
  52. request.is_fetch(),
  53. request.volume as isize,
  54. isize::oselect(fetch_more, -1, 0),
  55. );
  56. }
  57. otils::compact(
  58. &mut self.message_store[..],
  59. |r| r.should_deliver(),
  60. &self.pool,
  61. self.num_threads,
  62. );
  63. let deliver: Vec<Request> = self.message_store.drain(0..fetch_sum).collect();
  64. // for r in deliver.iter() {
  65. // println!("{:?}", r);
  66. // }
  67. otils::compact(
  68. &mut self.message_store[..],
  69. |r| r.should_defer(),
  70. &self.pool,
  71. self.num_threads,
  72. );
  73. self.message_store.truncate(final_size);
  74. // for r in self.message_store.iter() {
  75. // println!("{:?}", r);
  76. // }
  77. deliver
  78. }
  79. }
  80. #[cfg(test)]
  81. mod tests {
  82. use super::*;
  83. extern crate test;
  84. use test::Bencher;
  85. #[bench]
  86. fn bench_fetch(b: &mut Bencher) {
  87. let mut o = ObliviousMultiQueue::new(8);
  88. let sends: Vec<Request> = (0..1048576)
  89. .map(|x| Request::new_send(0, x.try_into().unwrap()))
  90. .collect();
  91. o.batch_send(sends);
  92. // b.iter(|| 1 + 1);
  93. b.iter(|| o.batch_fetch(vec![Request::new_fetch(0, 1048575)]));
  94. }
  95. }