浏览代码

oram: parallelize prf and spdpf evaluations

Lennart Braun 2 年之前
父节点
当前提交
a3cfb6b791
共有 4 个文件被更改,包括 53 次插入43 次删除
  1. 1 0
      oram/Cargo.toml
  2. 10 1
      oram/examples/bench_doram.rs
  3. 11 12
      oram/src/oram.rs
  4. 31 30
      oram/src/stash.rs

+ 1 - 0
oram/Cargo.toml

@@ -18,6 +18,7 @@ num-bigint = "0.4.3"
 num-traits = "0.2.15"
 rand = "0.8.5"
 rand_chacha = "0.3.1"
+rayon = "1.6.1"
 strum = { version = "0.24.1", features = ["derive"] }
 strum_macros = "0.24"
 

+ 10 - 1
oram/examples/bench_doram.rs

@@ -9,6 +9,7 @@ use oram::common::{InstructionShare, Operation};
 use oram::oram::{DistributedOram, DistributedOramProtocol, Runtimes};
 use rand::{Rng, SeedableRng};
 use rand_chacha::ChaChaRng;
+use rayon;
 use std::process;
 use std::time::Instant;
 use utils::field::Fp;
@@ -27,6 +28,9 @@ struct Cli {
     /// Use preprocessing
     #[arg(long)]
     pub preprocess: bool,
+    /// How many threads to use for the computation
+    #[arg(long, short = 't', default_value_t = 1)]
+    pub threads: usize,
     /// Which address to listen on for incoming connections
     #[arg(long, short = 'l')]
     pub listen_host: String,
@@ -37,7 +41,7 @@ struct Cli {
     #[arg(long, short = 'c', value_name = "PARTY_ID>:<HOST>:<PORT", value_parser = parse_connect)]
     pub connect: Vec<(usize, String, u16)>,
     /// How long to try connecting before aborting
-    #[arg(long, short = 't', default_value_t = 10)]
+    #[arg(long, default_value_t = 10)]
     pub connect_timeout_seconds: usize,
 }
 
@@ -87,6 +91,11 @@ fn main() {
         connect_timeout_seconds: cli.connect_timeout_seconds,
     };
 
+    rayon::ThreadPoolBuilder::new()
+        .num_threads(cli.threads)
+        .build_global()
+        .unwrap();
+
     for c in cli.connect {
         if netopts.connect_info[c.0] != NetworkPartyInfo::Listen {
             println!(

+ 11 - 12
oram/src/oram.rs

@@ -10,6 +10,7 @@ use dpf::{mpdpf::MultiPointDpf, spdpf::SinglePointDpf};
 use ff::PrimeField;
 use itertools::{izip, Itertools};
 use rand::thread_rng;
+use rayon::prelude::*;
 use std::collections::VecDeque;
 use std::iter::repeat;
 use std::marker::PhantomData;
@@ -235,7 +236,7 @@ where
     MPDPF: MultiPointDpf<Value = F>,
     MPDPF::Key: Serializable,
     SPDPF: SinglePointDpf<Value = F>,
-    SPDPF::Key: Serializable,
+    SPDPF::Key: Serializable + Sync,
 {
     pub fn new(party_id: usize, log_db_size: u32) -> Self {
         assert!(party_id < 3);
@@ -476,11 +477,10 @@ where
 
         // Compute memory index tags
         for lpk_prev in new_lpks_prev {
-            let mut memory_index_tags_prev = Vec::with_capacity(self.memory_size);
-            memory_index_tags_prev
-                .extend((0..self.memory_size).map(|j| {
-                    LegendrePrf::eval_to_uint::<u128>(&lpk_prev, F::from_u128(j as u128))
-                }));
+            let memory_index_tags_prev: Vec<_> = (0..self.memory_size)
+                .into_par_iter()
+                .map(|j| LegendrePrf::eval_to_uint::<u128>(&lpk_prev, F::from_u128(j as u128)))
+                .collect();
             let memory_index_tags_prev_sorted: Vec<_> = memory_index_tags_prev
                 .iter()
                 .copied()
@@ -502,11 +502,10 @@ where
         let t_after_receiving_lpks_next = Instant::now();
 
         for lpk_next in new_lpks_next {
-            let mut memory_index_tags_next = Vec::with_capacity(self.memory_size);
-            memory_index_tags_next
-                .extend((0..self.memory_size).map(|j| {
-                    LegendrePrf::eval_to_uint::<u128>(&lpk_next, F::from_u128(j as u128))
-                }));
+            let memory_index_tags_next: Vec<_> = (0..self.memory_size)
+                .into_par_iter()
+                .map(|j| LegendrePrf::eval_to_uint::<u128>(&lpk_next, F::from_u128(j as u128)))
+                .collect();
             let memory_index_tags_next_with_index_sorted: Vec<_> = memory_index_tags_next
                 .iter()
                 .copied()
@@ -893,7 +892,7 @@ where
     MPDPF: MultiPointDpf<Value = F>,
     MPDPF::Key: Serializable,
     SPDPF: SinglePointDpf<Value = F>,
-    SPDPF::Key: Serializable,
+    SPDPF::Key: Serializable + Sync,
 {
     fn get_party_id(&self) -> usize {
         self.party_id

+ 31 - 30
oram/src/stash.rs

@@ -9,6 +9,7 @@ use communicator::{AbstractCommunicator, Fut, Serializable};
 use dpf::spdpf::SinglePointDpf;
 use ff::PrimeField;
 use rand::thread_rng;
+use rayon::prelude::*;
 use std::marker::PhantomData;
 use std::time::{Duration, Instant};
 use utils::field::LegendreSymbol;
@@ -135,7 +136,7 @@ impl<F, SPDPF> StashProtocol<F, SPDPF>
 where
     F: PrimeField + LegendreSymbol + Serializable,
     SPDPF: SinglePointDpf<Value = F>,
-    SPDPF::Key: Serializable,
+    SPDPF::Key: Serializable + Sync,
 {
     pub fn new(party_id: usize, stash_size: usize) -> Self {
         assert!(party_id < 3);
@@ -218,7 +219,6 @@ where
             r
         });
 
-        // panic!("not implemented");
         self.state = State::AwaitingRead;
         Ok(runtimes)
     }
@@ -311,18 +311,16 @@ where
                 // 3. Compute shares of <flag>, <loc>, i.e., if the address is present in the stash and if
                 //    so, where it is
                 {
-                    let mut flag_share = F::ZERO;
-                    let mut location_share = F::ZERO;
-                    let mut j_as_field_element = F::ZERO;
-                    for j in 0..self.address_tag_list.len() {
-                        let dpf_value_j = SPDPF::evaluate_at(
-                            &dpf_key_i,
-                            self.address_tag_list[j] ^ address_tag_mask,
-                        );
-                        flag_share += dpf_value_j;
-                        location_share += j_as_field_element * dpf_value_j;
-                        j_as_field_element += F::ONE;
-                    }
+                    let (flag_share, location_share) = self
+                        .address_tag_list
+                        .par_iter()
+                        .enumerate()
+                        .map(|(j, tag_j)| {
+                            let dpf_value_j =
+                                SPDPF::evaluate_at(&dpf_key_i, tag_j ^ address_tag_mask);
+                            (dpf_value_j, F::from_u128(j as u128) * dpf_value_j)
+                        })
+                        .reduce(|| (F::ZERO, F::ZERO), |(a, b), (c, d)| (a + c, b + d));
                     let t_after_compute_flag_loc = Instant::now();
                     (
                         flag_share,
@@ -399,15 +397,15 @@ where
             let dpf_key_prev = fut_prev.get()?;
             let dpf_key_next = fut_next.get()?;
             let t_after_dpf_key_distr = Instant::now();
-            let mut value_share = F::ZERO;
-            for j in 0..self.access_counter {
-                let index_prev = ((j as u16 + r_prev) & bit_mask) as u64;
-                let index_next = ((j as u16 + r_next) & bit_mask) as u64;
-                value_share +=
-                    SPDPF::evaluate_at(&dpf_key_prev, index_prev) * self.stash_values_share[j];
-                value_share +=
-                    SPDPF::evaluate_at(&dpf_key_next, index_next) * stash_values_share_prev[j];
-            }
+            let value_share: F = (0..self.access_counter)
+                .into_par_iter()
+                .map(|j| {
+                    let index_prev = ((j as u16 + r_prev) & bit_mask) as u64;
+                    let index_next = ((j as u16 + r_next) & bit_mask) as u64;
+                    SPDPF::evaluate_at(&dpf_key_prev, index_prev) * self.stash_values_share[j]
+                        + SPDPF::evaluate_at(&dpf_key_next, index_next) * stash_values_share_prev[j]
+                })
+                .sum();
             (
                 value_share,
                 t_after_convert_to_replicated,
@@ -565,12 +563,15 @@ where
             let dpf_key_prev = fut_prev.get()?;
             let dpf_key_next = fut_next.get()?;
             let t_after_dpf_key_distr = Instant::now();
-            for j in 0..=self.access_counter {
-                let index_prev = ((j as u16).wrapping_add(r_prev) & bit_mask) as u64;
-                let index_next = ((j as u16).wrapping_add(r_next) & bit_mask) as u64;
-                self.stash_values_share[j] += SPDPF::evaluate_at(&dpf_key_prev, index_prev);
-                self.stash_values_share[j] += SPDPF::evaluate_at(&dpf_key_next, index_next);
-            }
+            self.stash_values_share
+                .par_iter_mut()
+                .enumerate()
+                .for_each(|(j, svs_j)| {
+                    let index_prev = ((j as u16).wrapping_add(r_prev) & bit_mask) as u64;
+                    let index_next = ((j as u16).wrapping_add(r_next) & bit_mask) as u64;
+                    *svs_j += SPDPF::evaluate_at(&dpf_key_prev, index_prev)
+                        + SPDPF::evaluate_at(&dpf_key_next, index_next);
+                });
             (t_after_masked_index, t_after_dpf_key_distr)
         };
         let t_after_dpf_eval = Instant::now();
@@ -619,7 +620,7 @@ impl<F, SPDPF> Stash<F> for StashProtocol<F, SPDPF>
 where
     F: PrimeField + LegendreSymbol + Serializable,
     SPDPF: SinglePointDpf<Value = F>,
-    SPDPF::Key: Serializable,
+    SPDPF::Key: Serializable + Sync,
 {
     fn get_party_id(&self) -> usize {
         self.party_id