Browse Source

oram: add preprocessing for select

Lennart Braun 1 year ago
parent
commit
8ddcb8717a
3 changed files with 242 additions and 68 deletions
  1. 45 6
      oram/src/oram.rs
  2. 174 58
      oram/src/select.rs
  3. 23 4
      oram/src/stash.rs

+ 45 - 6
oram/src/oram.rs

@@ -68,6 +68,7 @@ pub enum ProtocolStep {
     PreprocessStash,
     PreprocessDOPrf,
     PreprocessPOt,
+    PreprocessSelect,
     AccessStashRead,
     AccessAddressSelection,
     AccessDatabaseRead,
@@ -90,7 +91,7 @@ pub enum ProtocolStep {
 
 #[derive(Debug, Default, Clone, Copy)]
 pub struct Runtimes {
-    durations: [Duration; 28],
+    durations: [Duration; 29],
     stash_runtimes: StashRuntimes,
 }
 
@@ -215,6 +216,7 @@ where
     preprocessed_memory_index_tags_prev_sorted: VecDeque<Vec<u128>>,
     preprocessed_memory_index_tags_next_sorted: VecDeque<Vec<u128>>,
     preprocessed_stash: VecDeque<StashProtocol<F, SPDPF>>,
+    preprocessed_select: VecDeque<SelectProtocol<F>>,
     preprocessed_doprf: VecDeque<JointDOPrf<F>>,
     preprocessed_pot: VecDeque<JointPOTParties<F, FisherYatesPermutation>>,
     preprocessed_pot_expands: VecDeque<Vec<F>>,
@@ -227,6 +229,7 @@ where
     is_initialized: bool,
     address_tags_read: Vec<u128>,
     stash: Option<StashProtocol<F, SPDPF>>,
+    select_party: Option<SelectProtocol<F>>,
     joint_doprf: Option<JointDOPrf<F>>,
     legendre_prf_key_next: Option<LegendrePrfKey<F>>,
     legendre_prf_key_prev: Option<LegendrePrfKey<F>>,
@@ -267,6 +270,7 @@ where
             preprocessed_memory_index_tags_prev_sorted: Default::default(),
             preprocessed_memory_index_tags_next_sorted: Default::default(),
             preprocessed_stash: Default::default(),
+            preprocessed_select: Default::default(),
             preprocessed_doprf: Default::default(),
             preprocessed_pot: Default::default(),
             preprocessed_pot_expands: Default::default(),
@@ -279,6 +283,7 @@ where
             is_initialized: false,
             address_tags_read: Default::default(),
             stash: None,
+            select_party: None,
             joint_doprf: None,
             legendre_prf_key_next: None,
             legendre_prf_key_prev: None,
@@ -482,6 +487,11 @@ where
             .reserve(number_epochs);
         self.preprocessed_memory_index_tags_mine_sorted
             .reserve(number_epochs);
+        self.preprocessed_stash.reserve(number_epochs);
+        self.preprocessed_select.reserve(number_epochs);
+        self.preprocessed_doprf.reserve(number_epochs);
+        self.preprocessed_pot.reserve(number_epochs);
+        self.preprocessed_pot_expands.reserve(number_epochs);
 
         let t_start = Instant::now();
 
@@ -603,6 +613,19 @@ where
 
         let t_after_preprocess_pot = Instant::now();
 
+        self.preprocessed_select
+            .extend((0..number_epochs).map(|_| SelectProtocol::default()));
+        for select in self
+            .preprocessed_select
+            .iter_mut()
+            .skip(already_preprocessed)
+        {
+            select.init(comm)?;
+            select.preprocess(comm, 2 * self.stash_size)?;
+        }
+
+        let t_after_preprocess_select = Instant::now();
+
         self.number_preprocessed_epochs += number_epochs;
 
         debug_assert_eq!(
@@ -646,6 +669,10 @@ where
             self.preprocessed_pot_expands.len(),
             self.number_preprocessed_epochs
         );
+        debug_assert_eq!(
+            self.preprocessed_select.len(),
+            self.number_preprocessed_epochs
+        );
 
         let runtimes = runtimes.map(|mut r| {
             r.record(
@@ -688,6 +715,10 @@ where
                 ProtocolStep::PreprocessPOt,
                 t_after_preprocess_pot - t_after_init_doprf,
             );
+            r.record(
+                ProtocolStep::PreprocessSelect,
+                t_after_preprocess_select - t_after_preprocess_pot,
+            );
             r
         });
 
@@ -729,10 +760,14 @@ where
         debug_assert!(self.joint_doprf.is_some());
 
         // c) p-OT
-        self.joint_pot = Some(self.preprocessed_pot.pop_front().unwrap());
+        self.joint_pot = self.preprocessed_pot.pop_front();
+        debug_assert!(self.joint_pot.is_some());
+
+        // d) select
+        self.select_party = self.preprocessed_select.pop_front();
         debug_assert!(self.joint_pot.is_some());
 
-        // d) Retrieve preprocessed index tags
+        // e) Retrieve preprocessed index tags
         self.memory_index_tags_prev = self
             .preprocessed_memory_index_tags_prev
             .pop_front()
@@ -859,7 +894,7 @@ where
             PARTY_1 => F::from_u128((1 << self.log_db_size) + self.get_access_counter() as u128),
             _ => F::ZERO,
         };
-        let db_address_share = SelectProtocol::select(
+        let db_address_share = self.select_party.as_mut().unwrap().select(
             comm,
             stash_state.flag,
             dummy_address_share,
@@ -884,8 +919,12 @@ where
         let t_after_stash_write = Instant::now();
 
         // 5. Select the right value to return
-        let read_value =
-            SelectProtocol::select(comm, stash_state.flag, stash_state.value, db_value_share)?;
+        let read_value = self.select_party.as_mut().unwrap().select(
+            comm,
+            stash_state.flag,
+            stash_state.value,
+            db_value_share,
+        )?;
         let t_after_value_selection = Instant::now();
 
         // 6. If the stash is full, write the value back into the database

+ 174 - 58
oram/src/select.rs

@@ -1,16 +1,28 @@
 use crate::common::Error;
 use communicator::{AbstractCommunicator, Fut, Serializable};
 use ff::Field;
-use rand::thread_rng;
+use itertools::izip;
+use rand::{thread_rng, Rng, SeedableRng};
+use rand_chacha::ChaChaRng;
+use std::collections::VecDeque;
 
 /// Select between two shared value <a>, <b> based on a shared condition bit <c>:
 /// Output <w> <- if <c> then <a> else <b>.
 pub trait Select<F> {
+    fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error>;
+
+    fn preprocess<C: AbstractCommunicator>(
+        &mut self,
+        comm: &mut C,
+        num: usize,
+    ) -> Result<(), Error>;
+
     fn select<C: AbstractCommunicator>(
+        &mut self,
         comm: &mut C,
+        c_share: F,
+        a_share: F,
         b_share: F,
-        x_share: F,
-        y_share: F,
     ) -> Result<F, Error>;
 }
 
@@ -26,13 +38,92 @@ fn other_compute_party(my_id: usize) -> usize {
     }
 }
 
-pub struct SelectProtocol {}
+#[derive(Default)]
+pub struct SelectProtocol<F> {
+    shared_prg_1: Option<ChaChaRng>,
+    shared_prg_2: Option<ChaChaRng>,
+    shared_prg_3: Option<ChaChaRng>,
+    is_initialized: bool,
+    num_preprocessed_invocations: usize,
+    preprocessed_mt_x: VecDeque<F>,
+    preprocessed_mt_y: VecDeque<F>,
+    preprocessed_mt_z: VecDeque<F>,
+    preprocessed_c_1_2: VecDeque<F>,
+    preprocessed_amb_1_2: VecDeque<F>,
+}
 
-impl<F> Select<F> for SelectProtocol
+impl<F> Select<F> for SelectProtocol<F>
 where
     F: Field + Serializable,
 {
+    fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
+        if comm.get_my_id() == PARTY_1 {
+            self.shared_prg_2 = Some(ChaChaRng::from_seed(thread_rng().gen()));
+            comm.send(PARTY_2, self.shared_prg_2.as_ref().unwrap().get_seed())?;
+            self.shared_prg_3 = Some(ChaChaRng::from_seed(thread_rng().gen()));
+            comm.send(PARTY_3, self.shared_prg_3.as_ref().unwrap().get_seed())?;
+        } else {
+            let fut_seed = comm.receive(PARTY_1)?;
+            self.shared_prg_1 = Some(ChaChaRng::from_seed(fut_seed.get()?));
+        }
+        self.is_initialized = true;
+        Ok(())
+    }
+
+    fn preprocess<C: AbstractCommunicator>(&mut self, comm: &mut C, n: usize) -> Result<(), Error> {
+        assert!(self.is_initialized);
+
+        let my_id = comm.get_my_id();
+
+        if my_id == PARTY_1 {
+            let x2s: Vec<F> = (0..n)
+                .map(|_| F::random(self.shared_prg_2.as_mut().unwrap()))
+                .collect();
+            let y2s: Vec<F> = (0..n)
+                .map(|_| F::random(self.shared_prg_2.as_mut().unwrap()))
+                .collect();
+            let z2s: Vec<F> = (0..n)
+                .map(|_| F::random(self.shared_prg_2.as_mut().unwrap()))
+                .collect();
+            let x3s: Vec<F> = (0..n)
+                .map(|_| F::random(self.shared_prg_3.as_mut().unwrap()))
+                .collect();
+            let y3s: Vec<F> = (0..n)
+                .map(|_| F::random(self.shared_prg_3.as_mut().unwrap()))
+                .collect();
+            let z3s: Vec<F> = (0..n)
+                .map(|_| F::random(self.shared_prg_3.as_mut().unwrap()))
+                .collect();
+
+            let z1s = izip!(x2s, y2s, z2s, x3s, y3s, z3s)
+                .map(|(x_2, y_2, z_2, x_3, y_3, z_3)| (x_2 + x_3) * (y_2 + y_3) - z_2 - z_3);
+            self.preprocessed_mt_z.extend(z1s);
+
+            self.preprocessed_c_1_2
+                .extend((0..n).map(|_| F::random(self.shared_prg_2.as_mut().unwrap())));
+            self.preprocessed_amb_1_2
+                .extend((0..n).map(|_| F::random(self.shared_prg_2.as_mut().unwrap())));
+        } else {
+            self.preprocessed_mt_x
+                .extend((0..n).map(|_| F::random(self.shared_prg_1.as_mut().unwrap())));
+            self.preprocessed_mt_y
+                .extend((0..n).map(|_| F::random(self.shared_prg_1.as_mut().unwrap())));
+            self.preprocessed_mt_z
+                .extend((0..n).map(|_| F::random(self.shared_prg_1.as_mut().unwrap())));
+            if my_id == PARTY_2 {
+                self.preprocessed_c_1_2
+                    .extend((0..n).map(|_| F::random(self.shared_prg_1.as_mut().unwrap())));
+                self.preprocessed_amb_1_2
+                    .extend((0..n).map(|_| F::random(self.shared_prg_1.as_mut().unwrap())));
+            }
+        }
+
+        self.num_preprocessed_invocations += n;
+        Ok(())
+    }
+
     fn select<C: AbstractCommunicator>(
+        &mut self,
         comm: &mut C,
         c_share: F,
         a_share: F,
@@ -40,46 +131,45 @@ where
     ) -> Result<F, Error> {
         let my_id = comm.get_my_id();
 
-        let output = b_share
-            + if my_id == PARTY_1 {
-                let mut rng = thread_rng();
-                // create multiplication triple
-                let x_2 = F::random(&mut rng);
-                let x_3 = F::random(&mut rng);
-                let y_2 = F::random(&mut rng);
-                let y_3 = F::random(&mut rng);
-                let z_2 = F::random(&mut rng);
-                let z_3 = F::random(&mut rng);
-                let z_1 = (x_2 + x_3) * (y_2 + y_3) - z_2 - z_3;
-                debug_assert_eq!((x_2 + x_3) * (y_2 + y_3), z_1 + z_2 + z_3);
-                let c_1_2 = F::random(&mut rng);
-                let amb_1_2 = F::random(&mut rng);
-                let c_1_3 = c_share - c_1_2;
-                let amb_1_3 = (a_share - b_share) - amb_1_2;
-
-                comm.send(PARTY_2, (x_2, y_2, z_2, c_1_2, amb_1_2))?;
-                comm.send(PARTY_3, (x_3, y_3, z_3, c_1_3, amb_1_3))?;
-
-                z_1
+        // if further preprocessing is needed, do it now
+        if self.num_preprocessed_invocations == 0 {
+            self.preprocess(comm, 1)?;
+        }
+        self.num_preprocessed_invocations -= 1;
+
+        if my_id == PARTY_1 {
+            let c_1_2 = self.preprocessed_c_1_2.pop_front().unwrap();
+            let amb_1_2 = self.preprocessed_amb_1_2.pop_front().unwrap();
+            comm.send(PARTY_3, (c_share - c_1_2, (a_share - b_share) - amb_1_2))?;
+            let z = self.preprocessed_mt_z.pop_front().unwrap();
+            Ok(b_share + z)
+        } else {
+            let (c_1_i, amb_1_i) = if my_id == PARTY_2 {
+                (
+                    self.preprocessed_c_1_2.pop_front().unwrap(),
+                    self.preprocessed_amb_1_2.pop_front().unwrap(),
+                )
             } else {
-                let fut_xzy = comm.receive::<(F, F, F, F, F)>(PARTY_1)?;
-                let fut_de = comm.receive::<(F, F)>(other_compute_party(my_id))?;
-                let (x_i, y_i, mut z_i, c_1_i, amb_1_i) = fut_xzy.get()?;
-                let d_i = (c_share + c_1_i) - x_i;
-                let e_i = (a_share - b_share + amb_1_i) - y_i;
-                comm.send(other_compute_party(my_id), (d_i, e_i))?;
-                let (d_j, e_j) = fut_de.get()?;
-                let (d, e) = (d_i + d_j, e_i + e_j);
-
-                z_i += e * (c_share + c_1_i) + d * (a_share - b_share + amb_1_i);
-                if my_id == PARTY_2 {
-                    z_i -= d * e;
-                }
-
-                z_i
+                let fut_1 = comm.receive::<(F, F)>(PARTY_1)?;
+                fut_1.get()?
             };
+            let fut_de = comm.receive::<(F, F)>(other_compute_party(my_id))?;
+            let x_i = self.preprocessed_mt_x.pop_front().unwrap();
+            let y_i = self.preprocessed_mt_y.pop_front().unwrap();
+            let mut z_i = self.preprocessed_mt_z.pop_front().unwrap();
+            let d_i = (c_share + c_1_i) - x_i;
+            let e_i = (a_share - b_share + amb_1_i) - y_i;
+            comm.send(other_compute_party(my_id), (d_i, e_i))?;
+            let (d_j, e_j) = fut_de.get()?;
+            let (d, e) = (d_i + d_j, e_i + e_j);
 
-        Ok(output)
+            z_i += e * (c_share + c_1_i) + d * (a_share - b_share + amb_1_i);
+            if my_id == PARTY_2 {
+                z_i -= d * e;
+            }
+
+            Ok(b_share + z_i)
+        }
     }
 }
 
@@ -90,23 +180,41 @@ mod tests {
     use std::thread;
     use utils::field::Fp;
 
-    fn run_select<Proto: Select<F>, F>(
+    fn run_init<Proto: Select<F> + Send + 'static, F>(
+        mut comm: impl AbstractCommunicator + Send + 'static,
+        mut proto: Proto,
+    ) -> thread::JoinHandle<(impl AbstractCommunicator, Proto)>
+    where
+        F: Field + Serializable,
+    {
+        thread::spawn(move || {
+            proto.init(&mut comm).unwrap();
+            (comm, proto)
+        })
+    }
+
+    fn run_select<Proto: Select<F> + Send + 'static, F>(
         mut comm: impl AbstractCommunicator + Send + 'static,
+        mut proto: Proto,
         c_share: F,
         a_share: F,
         b_share: F,
-    ) -> thread::JoinHandle<(impl AbstractCommunicator, F)>
+    ) -> thread::JoinHandle<(impl AbstractCommunicator, Proto, F)>
     where
         F: Field + Serializable,
     {
         thread::spawn(move || {
-            let result = Proto::select(&mut comm, c_share, a_share, b_share);
-            (comm, result.unwrap())
+            let result = proto.select(&mut comm, c_share, a_share, b_share);
+            (comm, proto, result.unwrap())
         })
     }
 
     #[test]
     fn test_select() {
+        let proto_1 = SelectProtocol::<Fp>::default();
+        let proto_2 = SelectProtocol::<Fp>::default();
+        let proto_3 = SelectProtocol::<Fp>::default();
+
         let (comm_3, comm_2, comm_1) = {
             let mut comms = make_unix_communicators(3);
             (
@@ -115,6 +223,14 @@ mod tests {
                 comms.pop().unwrap(),
             )
         };
+
+        let h1 = run_init(comm_1, proto_1);
+        let h2 = run_init(comm_2, proto_2);
+        let h3 = run_init(comm_3, proto_3);
+        let (comm_1, proto_1) = h1.join().unwrap();
+        let (comm_2, proto_2) = h2.join().unwrap();
+        let (comm_3, proto_3) = h3.join().unwrap();
+
         let mut rng = thread_rng();
 
         let (a_1, a_2, a_3) = (
@@ -135,23 +251,23 @@ mod tests {
         let c1_1 = Fp::ONE - c_2 - c_3;
 
         // check for <c> = <0>
-        let h1 = run_select::<SelectProtocol, _>(comm_1, c0_1, a_1, b_1);
-        let h2 = run_select::<SelectProtocol, _>(comm_2, c_2, a_2, b_2);
-        let h3 = run_select::<SelectProtocol, _>(comm_3, c_3, a_3, b_3);
-        let (comm_1, x_1) = h1.join().unwrap();
-        let (comm_2, x_2) = h2.join().unwrap();
-        let (comm_3, x_3) = h3.join().unwrap();
+        let h1 = run_select(comm_1, proto_1, c0_1, a_1, b_1);
+        let h2 = run_select(comm_2, proto_2, c_2, a_2, b_2);
+        let h3 = run_select(comm_3, proto_3, c_3, a_3, b_3);
+        let (comm_1, proto_1, x_1) = h1.join().unwrap();
+        let (comm_2, proto_2, x_2) = h2.join().unwrap();
+        let (comm_3, proto_3, x_3) = h3.join().unwrap();
 
         assert_eq!(c0_1 + c_2 + c_3, Fp::ZERO);
         assert_eq!(x_1 + x_2 + x_3, b);
 
         // check for <c> = <1>
-        let h1 = run_select::<SelectProtocol, _>(comm_1, c1_1, a_1, b_1);
-        let h2 = run_select::<SelectProtocol, _>(comm_2, c_2, a_2, b_2);
-        let h3 = run_select::<SelectProtocol, _>(comm_3, c_3, a_3, b_3);
-        let (_, y_1) = h1.join().unwrap();
-        let (_, y_2) = h2.join().unwrap();
-        let (_, y_3) = h3.join().unwrap();
+        let h1 = run_select(comm_1, proto_1, c1_1, a_1, b_1);
+        let h2 = run_select(comm_2, proto_2, c_2, a_2, b_2);
+        let h3 = run_select(comm_3, proto_3, c_3, a_3, b_3);
+        let (_, _, y_1) = h1.join().unwrap();
+        let (_, _, y_2) = h2.join().unwrap();
+        let (_, _, y_3) = h3.join().unwrap();
 
         assert_eq!(c1_1 + c_2 + c_3, Fp::ONE);
         assert_eq!(y_1 + y_2 + y_3, a);

+ 23 - 4
oram/src/stash.rs

@@ -123,6 +123,7 @@ where
     stash_values_share: Vec<F>,
     stash_old_values_share: Vec<F>,
     address_tag_list: Vec<u64>,
+    select_party: Option<SelectProtocol<F>>,
     doprf_party_1: Option<DOPrfParty1<F>>,
     doprf_party_2: Option<DOPrfParty2<F>>,
     doprf_party_3: Option<DOPrfParty3<F>>,
@@ -156,6 +157,7 @@ where
             } else {
                 Vec::with_capacity(stash_size)
             },
+            select_party: None,
             doprf_party_1: None,
             doprf_party_2: None,
             doprf_party_3: None,
@@ -213,6 +215,14 @@ where
             _ => panic!("invalid party id"),
         }
 
+        // run Select initialiation and preprocessing
+        {
+            let mut select_party = SelectProtocol::default();
+            select_party.init(comm)?;
+            select_party.preprocess(comm, 3 * self.stash_size)?;
+            self.select_party = Some(select_party);
+        }
+
         let t_end = Instant::now();
         let runtimes = runtimes.map(|mut r| {
             r.record(ProtocolStep::Init, t_end - t_start);
@@ -341,7 +351,12 @@ where
             } else {
                 F::ZERO
             };
-            SelectProtocol::select(comm, flag_share, location_share, access_counter_share)?
+            self.select_party.as_mut().unwrap().select(
+                comm,
+                flag_share,
+                location_share,
+                access_counter_share,
+            )?
         };
 
         let t_after_location_share = Instant::now();
@@ -518,10 +533,14 @@ where
         let t_after_store_triple = Instant::now();
 
         // 3. Update stash
-        let previous_value_share =
-            SelectProtocol::select(comm, stash_state.flag, stash_state.value, db_value_share)?;
+        let previous_value_share = self.select_party.as_mut().unwrap().select(
+            comm,
+            stash_state.flag,
+            stash_state.value,
+            db_value_share,
+        )?;
         let t_after_select_previous_value = Instant::now();
-        let value_share = SelectProtocol::select(
+        let value_share = self.select_party.as_mut().unwrap().select(
             comm,
             instruction.operation,
             instruction.value - previous_value_share,