浏览代码

oram: more parallelization

Lennart Braun 2 年之前
父节点
当前提交
3f46cf84a5
共有 2 个文件被更改,包括 50 次插入40 次删除
  1. 40 38
      oram/src/oram.rs
  2. 10 2
      oram/src/p_ot.rs

+ 40 - 38
oram/src/oram.rs

@@ -8,7 +8,7 @@ use crate::stash::{
 use communicator::{AbstractCommunicator, Fut, Serializable};
 use dpf::{mpdpf::MultiPointDpf, spdpf::SinglePointDpf};
 use ff::PrimeField;
-use itertools::{izip, Itertools};
+use itertools::Itertools;
 use rand::thread_rng;
 use rayon::prelude::*;
 use std::collections::VecDeque;
@@ -236,10 +236,10 @@ where
 impl<F, MPDPF, SPDPF> DistributedOramProtocol<F, MPDPF, SPDPF>
 where
     F: FromPrf + LegendreSymbol + Serializable,
-    F::PrfKey: Serializable,
-    MPDPF: MultiPointDpf<Value = F>,
+    F::PrfKey: Serializable + Sync,
+    MPDPF: MultiPointDpf<Value = F> + Sync,
     MPDPF::Key: Serializable,
-    SPDPF: SinglePointDpf<Value = F>,
+    SPDPF: SinglePointDpf<Value = F> + Sync,
     SPDPF::Key: Serializable + Sync,
 {
     pub fn new(party_id: usize, log_db_size: u32) -> Self {
@@ -372,28 +372,30 @@ where
         let fut_dpf_key_from_prev = comm.receive_previous()?;
         let fut_dpf_key_from_next = comm.receive_next()?;
 
-        let mut points = Vec::with_capacity(self.stash_size);
-        let mut values = Vec::with_capacity(self.stash_size);
         let (_, stash_values_share, stash_old_values_share) = self.stash.get_stash_share();
         assert_eq!(stash_values_share.len(), self.get_access_counter());
         assert_eq!(stash_old_values_share.len(), self.get_access_counter());
         assert_eq!(self.address_tags_read.len(), self.get_access_counter());
-        for (tag, val, old_val) in izip!(
-            self.address_tags_read.iter().copied(),
-            stash_values_share.iter().copied(),
-            stash_old_values_share.iter().copied()
-        ) {
-            points.push(self.pos_mine(tag) as u64);
-            values.push(val - old_val);
-        }
+        let mut points: Vec<_> = self
+            .address_tags_read
+            .par_iter()
+            .copied()
+            .map(|tag| self.pos_mine(tag) as u64)
+            .collect();
+        let values: Vec<_> = stash_values_share
+            .par_iter()
+            .copied()
+            .zip(stash_old_values_share.par_iter().copied())
+            .map(|(val, old_val)| val - old_val)
+            .collect();
         self.address_tags_read.truncate(0);
 
         // sort point, value pairs
         let (points, values): (Vec<u64>, Vec<F>) = {
             let mut indices: Vec<usize> = (0..points.len()).collect();
-            indices.sort_by_key(|&i| points[i]);
-            points.sort();
-            let new_values = indices.iter().map(|&i| values[i]).collect();
+            indices.par_sort_unstable_by_key(|&i| points[i]);
+            points.par_sort();
+            let new_values = indices.par_iter().map(|&i| values[i]).collect();
             (points, new_values)
         };
 
@@ -413,11 +415,14 @@ where
         {
             let mut memory_share = Vec::new();
             std::mem::swap(&mut self.memory_share, &mut memory_share);
-            for j in 0..self.memory_size {
-                memory_share[j] += new_memory_share_from_prev
-                    [self.pos_prev(self.memory_index_tags_prev[j])]
-                    + new_memory_share_from_next[self.pos_next(self.memory_index_tags_next[j])];
-            }
+            memory_share
+                .par_iter_mut()
+                .enumerate()
+                .for_each(|(j, mem_cell)| {
+                    *mem_cell += new_memory_share_from_prev
+                        [self.pos_prev(self.memory_index_tags_prev[j])]
+                        + new_memory_share_from_next[self.pos_next(self.memory_index_tags_next[j])];
+                });
             std::mem::swap(&mut self.memory_share, &mut memory_share);
         }
 
@@ -485,11 +490,8 @@ where
                 .into_par_iter()
                 .map(|j| LegendrePrf::eval_to_uint::<u128>(&lpk_prev, F::from_u128(j as u128)))
                 .collect();
-            let memory_index_tags_prev_sorted: Vec<_> = memory_index_tags_prev
-                .iter()
-                .copied()
-                .sorted_unstable()
-                .collect();
+            let mut memory_index_tags_prev_sorted = memory_index_tags_prev.clone();
+            memory_index_tags_prev_sorted.par_sort_unstable();
             self.preprocessed_memory_index_tags_prev
                 .push_back(memory_index_tags_prev);
             self.preprocessed_memory_index_tags_prev_sorted
@@ -520,7 +522,7 @@ where
                 .push_back(memory_index_tags_next);
             self.preprocessed_memory_index_tags_next_sorted.push_back(
                 memory_index_tags_next_with_index_sorted
-                    .iter()
+                    .par_iter()
                     .map(|(_, x)| *x)
                     .collect(),
             );
@@ -744,13 +746,12 @@ where
         // - pos_(i+1)(tag) -> index of tag in mem_idx_tags_next
 
         let mask = self.pot_key_party.expand();
-        for (&tag, val) in self
-            .memory_index_tags_next_sorted
-            .iter()
-            .zip(garbled_memory_share_next.iter_mut())
-        {
-            *val += mask[self.pos_next(tag)];
-        }
+        self.memory_index_tags_next_sorted
+            .par_iter()
+            .zip(garbled_memory_share_next.par_iter_mut())
+            .for_each(|(&tag, val)| {
+                *val += mask[self.pos_next(tag)];
+            });
         comm.send_next(garbled_memory_share_next)?;
 
         let t_after_pot_expand = Instant::now();
@@ -900,10 +901,10 @@ where
 impl<F, MPDPF, SPDPF> DistributedOram<F> for DistributedOramProtocol<F, MPDPF, SPDPF>
 where
     F: FromPrf + LegendreSymbol + Serializable,
-    F::PrfKey: Serializable,
-    MPDPF: MultiPointDpf<Value = F>,
+    F::PrfKey: Serializable + Sync,
+    MPDPF: MultiPointDpf<Value = F> + Sync,
     MPDPF::Key: Serializable,
-    SPDPF: SinglePointDpf<Value = F>,
+    SPDPF: SinglePointDpf<Value = F> + Sync,
     SPDPF::Key: Serializable + Sync,
 {
     fn get_party_id(&self) -> usize {
@@ -992,6 +993,7 @@ mod tests {
     use dpf::mpdpf::DummyMpDpf;
     use dpf::spdpf::DummySpDpf;
     use ff::Field;
+    use itertools::izip;
     use std::thread;
     use utils::field::Fp;
 

+ 10 - 2
oram/src/p_ot.rs

@@ -2,6 +2,7 @@ use crate::common::Error;
 use communicator::{AbstractCommunicator, Fut, Serializable};
 use core::marker::PhantomData;
 use ff::Field;
+use rayon::prelude::*;
 use utils::field::FromPrf;
 use utils::permutation::Permutation;
 
@@ -19,7 +20,12 @@ pub struct POTKeyParty<F: FromPrf, Perm> {
     _phantom: PhantomData<F>,
 }
 
-impl<F: Field + FromPrf, Perm: Permutation> POTKeyParty<F, Perm> {
+impl<F, Perm> POTKeyParty<F, Perm>
+where
+    F: Field + FromPrf,
+    F::PrfKey: Sync,
+    Perm: Permutation + Sync,
+{
     pub fn new(domain_size: usize) -> Self {
         Self {
             domain_size,
@@ -66,6 +72,7 @@ impl<F: Field + FromPrf, Perm: Permutation> POTKeyParty<F, Perm> {
     pub fn expand(&self) -> Vec<F> {
         assert!(self.is_initialized);
         (0..self.domain_size)
+            .into_par_iter()
             .map(|x| {
                 let pi_x = self.permutation.as_ref().unwrap().permute(x);
                 F::prf(&self.prf_key_i.unwrap(), pi_x as u64)
@@ -210,7 +217,8 @@ mod tests {
     fn test_pot<F, Perm>(log_domain_size: u32)
     where
         F: Field + FromPrf,
-        Perm: Permutation,
+        F::PrfKey: Sync,
+        Perm: Permutation + Sync,
     {
         let domain_size = 1 << log_domain_size;