|
@@ -8,7 +8,7 @@ use crate::stash::{
|
|
|
use communicator::{AbstractCommunicator, Fut, Serializable};
|
|
|
use dpf::{mpdpf::MultiPointDpf, spdpf::SinglePointDpf};
|
|
|
use ff::PrimeField;
|
|
|
-use itertools::{izip, Itertools};
|
|
|
+use itertools::Itertools;
|
|
|
use rand::thread_rng;
|
|
|
use rayon::prelude::*;
|
|
|
use std::collections::VecDeque;
|
|
@@ -236,10 +236,10 @@ where
|
|
|
impl<F, MPDPF, SPDPF> DistributedOramProtocol<F, MPDPF, SPDPF>
|
|
|
where
|
|
|
F: FromPrf + LegendreSymbol + Serializable,
|
|
|
- F::PrfKey: Serializable,
|
|
|
- MPDPF: MultiPointDpf<Value = F>,
|
|
|
+ F::PrfKey: Serializable + Sync,
|
|
|
+ MPDPF: MultiPointDpf<Value = F> + Sync,
|
|
|
MPDPF::Key: Serializable,
|
|
|
- SPDPF: SinglePointDpf<Value = F>,
|
|
|
+ SPDPF: SinglePointDpf<Value = F> + Sync,
|
|
|
SPDPF::Key: Serializable + Sync,
|
|
|
{
|
|
|
pub fn new(party_id: usize, log_db_size: u32) -> Self {
|
|
@@ -372,28 +372,30 @@ where
|
|
|
let fut_dpf_key_from_prev = comm.receive_previous()?;
|
|
|
let fut_dpf_key_from_next = comm.receive_next()?;
|
|
|
|
|
|
- let mut points = Vec::with_capacity(self.stash_size);
|
|
|
- let mut values = Vec::with_capacity(self.stash_size);
|
|
|
let (_, stash_values_share, stash_old_values_share) = self.stash.get_stash_share();
|
|
|
assert_eq!(stash_values_share.len(), self.get_access_counter());
|
|
|
assert_eq!(stash_old_values_share.len(), self.get_access_counter());
|
|
|
assert_eq!(self.address_tags_read.len(), self.get_access_counter());
|
|
|
- for (tag, val, old_val) in izip!(
|
|
|
- self.address_tags_read.iter().copied(),
|
|
|
- stash_values_share.iter().copied(),
|
|
|
- stash_old_values_share.iter().copied()
|
|
|
- ) {
|
|
|
- points.push(self.pos_mine(tag) as u64);
|
|
|
- values.push(val - old_val);
|
|
|
- }
|
|
|
+ let mut points: Vec<_> = self
|
|
|
+ .address_tags_read
|
|
|
+ .par_iter()
|
|
|
+ .copied()
|
|
|
+ .map(|tag| self.pos_mine(tag) as u64)
|
|
|
+ .collect();
|
|
|
+ let values: Vec<_> = stash_values_share
|
|
|
+ .par_iter()
|
|
|
+ .copied()
|
|
|
+ .zip(stash_old_values_share.par_iter().copied())
|
|
|
+ .map(|(val, old_val)| val - old_val)
|
|
|
+ .collect();
|
|
|
self.address_tags_read.truncate(0);
|
|
|
|
|
|
// sort point, value pairs
|
|
|
let (points, values): (Vec<u64>, Vec<F>) = {
|
|
|
let mut indices: Vec<usize> = (0..points.len()).collect();
|
|
|
- indices.sort_by_key(|&i| points[i]);
|
|
|
- points.sort();
|
|
|
- let new_values = indices.iter().map(|&i| values[i]).collect();
|
|
|
+ indices.par_sort_unstable_by_key(|&i| points[i]);
|
|
|
+ points.par_sort();
|
|
|
+ let new_values = indices.par_iter().map(|&i| values[i]).collect();
|
|
|
(points, new_values)
|
|
|
};
|
|
|
|
|
@@ -413,11 +415,14 @@ where
|
|
|
{
|
|
|
let mut memory_share = Vec::new();
|
|
|
std::mem::swap(&mut self.memory_share, &mut memory_share);
|
|
|
- for j in 0..self.memory_size {
|
|
|
- memory_share[j] += new_memory_share_from_prev
|
|
|
- [self.pos_prev(self.memory_index_tags_prev[j])]
|
|
|
- + new_memory_share_from_next[self.pos_next(self.memory_index_tags_next[j])];
|
|
|
- }
|
|
|
+ memory_share
|
|
|
+ .par_iter_mut()
|
|
|
+ .enumerate()
|
|
|
+ .for_each(|(j, mem_cell)| {
|
|
|
+ *mem_cell += new_memory_share_from_prev
|
|
|
+ [self.pos_prev(self.memory_index_tags_prev[j])]
|
|
|
+ + new_memory_share_from_next[self.pos_next(self.memory_index_tags_next[j])];
|
|
|
+ });
|
|
|
std::mem::swap(&mut self.memory_share, &mut memory_share);
|
|
|
}
|
|
|
|
|
@@ -485,11 +490,8 @@ where
|
|
|
.into_par_iter()
|
|
|
.map(|j| LegendrePrf::eval_to_uint::<u128>(&lpk_prev, F::from_u128(j as u128)))
|
|
|
.collect();
|
|
|
- let memory_index_tags_prev_sorted: Vec<_> = memory_index_tags_prev
|
|
|
- .iter()
|
|
|
- .copied()
|
|
|
- .sorted_unstable()
|
|
|
- .collect();
|
|
|
+ let mut memory_index_tags_prev_sorted = memory_index_tags_prev.clone();
|
|
|
+ memory_index_tags_prev_sorted.par_sort_unstable();
|
|
|
self.preprocessed_memory_index_tags_prev
|
|
|
.push_back(memory_index_tags_prev);
|
|
|
self.preprocessed_memory_index_tags_prev_sorted
|
|
@@ -520,7 +522,7 @@ where
|
|
|
.push_back(memory_index_tags_next);
|
|
|
self.preprocessed_memory_index_tags_next_sorted.push_back(
|
|
|
memory_index_tags_next_with_index_sorted
|
|
|
- .iter()
|
|
|
+ .par_iter()
|
|
|
.map(|(_, x)| *x)
|
|
|
.collect(),
|
|
|
);
|
|
@@ -744,13 +746,12 @@ where
|
|
|
// - pos_(i+1)(tag) -> index of tag in mem_idx_tags_next
|
|
|
|
|
|
let mask = self.pot_key_party.expand();
|
|
|
- for (&tag, val) in self
|
|
|
- .memory_index_tags_next_sorted
|
|
|
- .iter()
|
|
|
- .zip(garbled_memory_share_next.iter_mut())
|
|
|
- {
|
|
|
- *val += mask[self.pos_next(tag)];
|
|
|
- }
|
|
|
+ self.memory_index_tags_next_sorted
|
|
|
+ .par_iter()
|
|
|
+ .zip(garbled_memory_share_next.par_iter_mut())
|
|
|
+ .for_each(|(&tag, val)| {
|
|
|
+ *val += mask[self.pos_next(tag)];
|
|
|
+ });
|
|
|
comm.send_next(garbled_memory_share_next)?;
|
|
|
|
|
|
let t_after_pot_expand = Instant::now();
|
|
@@ -900,10 +901,10 @@ where
|
|
|
impl<F, MPDPF, SPDPF> DistributedOram<F> for DistributedOramProtocol<F, MPDPF, SPDPF>
|
|
|
where
|
|
|
F: FromPrf + LegendreSymbol + Serializable,
|
|
|
- F::PrfKey: Serializable,
|
|
|
- MPDPF: MultiPointDpf<Value = F>,
|
|
|
+ F::PrfKey: Serializable + Sync,
|
|
|
+ MPDPF: MultiPointDpf<Value = F> + Sync,
|
|
|
MPDPF::Key: Serializable,
|
|
|
- SPDPF: SinglePointDpf<Value = F>,
|
|
|
+ SPDPF: SinglePointDpf<Value = F> + Sync,
|
|
|
SPDPF::Key: Serializable + Sync,
|
|
|
{
|
|
|
fn get_party_id(&self) -> usize {
|
|
@@ -992,6 +993,7 @@ mod tests {
|
|
|
use dpf::mpdpf::DummyMpDpf;
|
|
|
use dpf::spdpf::DummySpDpf;
|
|
|
use ff::Field;
|
|
|
+ use itertools::izip;
|
|
|
use std::thread;
|
|
|
use utils::field::Fp;
|
|
|
|