Quellcode durchsuchen

oram: preprocess init of functionalities

Lennart Braun vor 2 Jahren
Ursprung
Commit
d0518c3892
1 geänderte Dateien mit 131 neuen und 72 gelöschten Zeilen
  1. 131 72
      oram/src/oram.rs

+ 131 - 72
oram/src/oram.rs

@@ -49,8 +49,8 @@ where
 }
 
 const PARTY_1: usize = 0;
-const PARTY_2: usize = 1;
-const PARTY_3: usize = 2;
+// const PARTY_2: usize = 1;
+// const PARTY_3: usize = 2;
 
 fn compute_oram_prf_output_bitsize(memory_size: usize) -> usize {
     (usize::BITS - memory_size.leading_zeros()) as usize + 40
@@ -65,6 +65,9 @@ pub enum ProtocolStep {
     PreprocessLPRFEvalSortNext,
     PreprocessMpDpdfPrecomp,
     PreprocessRecvTagsMine,
+    PreprocessStash,
+    PreprocessDOPrf,
+    PreprocessPOt,
     AccessStashRead,
     AccessAddressSelection,
     AccessDatabaseRead,
@@ -79,9 +82,6 @@ pub enum ProtocolStep {
     DbWriteUpdateMemory,
     RefreshJitPreprocess,
     RefreshResetFuncs,
-    RefreshInitStash,
-    RefreshInitDOPrf,
-    RefreshInitPOt,
     RefreshGetPreproc,
     RefreshSorting,
     RefreshPOtExpandMasking,
@@ -214,6 +214,10 @@ where
     preprocessed_memory_index_tags_mine_sorted: VecDeque<Vec<u128>>,
     preprocessed_memory_index_tags_prev_sorted: VecDeque<Vec<u128>>,
     preprocessed_memory_index_tags_next_sorted: VecDeque<Vec<u128>>,
+    preprocessed_stash: VecDeque<StashProtocol<F, SPDPF>>,
+    preprocessed_doprf: VecDeque<JointDOPrf<F>>,
+    preprocessed_pot: VecDeque<JointPOTParties<F, FisherYatesPermutation>>,
+    preprocessed_pot_expands: VecDeque<Vec<F>>,
     memory_index_tags_prev: Vec<u128>,
     memory_index_tags_next: Vec<u128>,
     memory_index_tags_prev_sorted: Vec<u128>,
@@ -222,11 +226,11 @@ where
     garbled_memory_share: Vec<F>,
     is_initialized: bool,
     address_tags_read: Vec<u128>,
-    stash: StashProtocol<F, SPDPF>,
-    joint_doprf: JointDOPrf<F>,
+    stash: Option<StashProtocol<F, SPDPF>>,
+    joint_doprf: Option<JointDOPrf<F>>,
     legendre_prf_key_next: Option<LegendrePrfKey<F>>,
     legendre_prf_key_prev: Option<LegendrePrfKey<F>>,
-    joint_pot: JointPOTParties<F, FisherYatesPermutation>,
+    joint_pot: Option<JointPOTParties<F, FisherYatesPermutation>>,
     mpdpf: MPDPF,
     _phantom: PhantomData<MPDPF>,
 }
@@ -262,6 +266,10 @@ where
             preprocessed_memory_index_tags_mine_sorted: Default::default(),
             preprocessed_memory_index_tags_prev_sorted: Default::default(),
             preprocessed_memory_index_tags_next_sorted: Default::default(),
+            preprocessed_stash: Default::default(),
+            preprocessed_doprf: Default::default(),
+            preprocessed_pot: Default::default(),
+            preprocessed_pot_expands: Default::default(),
             memory_index_tags_prev: Default::default(),
             memory_index_tags_next: Default::default(),
             memory_index_tags_prev_sorted: Default::default(),
@@ -270,22 +278,22 @@ where
             garbled_memory_share: Default::default(),
             is_initialized: false,
             address_tags_read: Default::default(),
-            stash: StashProtocol::new(party_id, stash_size),
-            joint_doprf: JointDOPrf::new(prf_output_bitsize),
+            stash: None,
+            joint_doprf: None,
             legendre_prf_key_next: None,
             legendre_prf_key_prev: None,
-            joint_pot: JointPOTParties::new(memory_size),
+            joint_pot: None,
             mpdpf: MPDPF::new(memory_size, stash_size),
             _phantom: PhantomData,
         }
     }
 
     pub fn get_access_counter(&self) -> usize {
-        self.stash.get_access_counter()
+        self.stash.as_ref().unwrap().get_access_counter()
     }
 
     pub fn get_stash(&self) -> &StashProtocol<F, SPDPF> {
-        &self.stash
+        self.stash.as_ref().unwrap()
     }
 
     fn pos_prev(&self, tag: u128) -> usize {
@@ -320,7 +328,11 @@ where
         let t_start = Instant::now();
 
         // 1. Compute address tag
-        let address_tag: u128 = self.joint_doprf.eval_to_uint(comm, &[address_share])?[0];
+        let address_tag: u128 = self
+            .joint_doprf
+            .as_mut()
+            .unwrap()
+            .eval_to_uint(comm, &[address_share])?[0];
 
         // 2. Update tags read list
         self.address_tags_read.push(address_tag);
@@ -334,7 +346,11 @@ where
         let t_after_index_computation = Instant::now();
 
         // 4. Run p-OT.Access
-        value_share -= self.joint_pot.access(comm, garbled_index)?;
+        value_share -= self
+            .joint_pot
+            .as_ref()
+            .unwrap()
+            .access(comm, garbled_index)?;
 
         let t_after_pot_access = Instant::now();
 
@@ -367,7 +383,8 @@ where
         let fut_dpf_key_from_prev = comm.receive_previous()?;
         let fut_dpf_key_from_next = comm.receive_next()?;
 
-        let (_, stash_values_share, stash_old_values_share) = self.stash.get_stash_share();
+        let (_, stash_values_share, stash_old_values_share) =
+            self.stash.as_ref().unwrap().get_stash_share();
         assert_eq!(stash_values_share.len(), self.get_access_counter());
         assert_eq!(stash_old_values_share.len(), self.get_access_counter());
         assert_eq!(self.address_tags_read.len(), self.get_access_counter());
@@ -539,6 +556,53 @@ where
 
         let t_after_receiving_index_tags_mine = Instant::now();
 
+        // Initialize Stash instances
+        self.preprocessed_stash
+            .extend((0..number_epochs).map(|_| StashProtocol::new(self.party_id, self.stash_size)));
+        for stash in self
+            .preprocessed_stash
+            .iter_mut()
+            .skip(already_preprocessed)
+        {
+            stash.init(comm)?;
+        }
+
+        let t_after_init_stash = Instant::now();
+
+        // Initialize DOPRF instances
+        self.preprocessed_doprf
+            .extend((0..number_epochs).map(|_| JointDOPrf::new(self.prf_output_bitsize)));
+        for (doprf, lpk_prev) in self
+            .preprocessed_doprf
+            .iter_mut()
+            .skip(already_preprocessed)
+            .zip(
+                self.preprocessed_legendre_prf_key_prev
+                    .iter()
+                    .skip(already_preprocessed),
+            )
+        {
+            doprf.set_legendre_prf_key_prev(lpk_prev.clone());
+            doprf.init(comm)?;
+            doprf.preprocess(comm, self.stash_size)?;
+        }
+
+        let t_after_init_doprf = Instant::now();
+
+        // Precompute p-OTs and expand the mask
+        self.preprocessed_pot
+            .extend((0..number_epochs).map(|_| JointPOTParties::new(self.memory_size)));
+        for pot in self.preprocessed_pot.iter_mut().skip(already_preprocessed) {
+            pot.init(comm)?;
+        }
+        self.preprocessed_pot_expands.extend(
+            self.preprocessed_pot.make_contiguous()[already_preprocessed..]
+                .iter()
+                .map(|pot| pot.expand()),
+        );
+
+        let t_after_preprocess_pot = Instant::now();
+
         self.number_preprocessed_epochs += number_epochs;
 
         debug_assert_eq!(
@@ -569,6 +633,19 @@ where
             self.preprocessed_memory_index_tags_mine_sorted.len(),
             self.number_preprocessed_epochs
         );
+        debug_assert_eq!(
+            self.preprocessed_stash.len(),
+            self.number_preprocessed_epochs
+        );
+        debug_assert_eq!(
+            self.preprocessed_doprf.len(),
+            self.number_preprocessed_epochs
+        );
+        debug_assert_eq!(self.preprocessed_pot.len(), self.number_preprocessed_epochs);
+        debug_assert_eq!(
+            self.preprocessed_pot_expands.len(),
+            self.number_preprocessed_epochs
+        );
 
         let runtimes = runtimes.map(|mut r| {
             r.record(
@@ -599,6 +676,18 @@ where
                 ProtocolStep::PreprocessRecvTagsMine,
                 t_after_receiving_index_tags_mine - t_after_mpdpf_precomp,
             );
+            r.record(
+                ProtocolStep::PreprocessStash,
+                t_after_init_stash - t_after_receiving_index_tags_mine,
+            );
+            r.record(
+                ProtocolStep::PreprocessDOPrf,
+                t_after_init_doprf - t_after_init_stash,
+            );
+            r.record(
+                ProtocolStep::PreprocessPOt,
+                t_after_preprocess_pot - t_after_init_doprf,
+            );
             r
         });
 
@@ -612,7 +701,7 @@ where
     ) -> Result<Option<Runtimes>, Error> {
         let t_start = Instant::now();
 
-        // -1. Do preprocessing if not already done
+        // 0. Do preprocessing if not already done
 
         let runtimes = if self.number_preprocessed_epochs == 0 {
             self.preprocess_with_runtimes(comm, 1, runtimes)?
@@ -622,44 +711,26 @@ where
 
         let t_after_jit_preprocessing = Instant::now();
 
-        // 0. Reset the functionalities
-        self.stash.reset();
-        self.joint_doprf.reset();
-        self.joint_pot.reset();
-
-        let t_after_reset = Instant::now();
-
-        // 1. Initialize the stash
-        self.stash.init(comm)?;
-
-        let t_after_init_stash = Instant::now();
-
-        // 2. Run r-DB init protocol
-        // a) Initialize DOPRF
-        {
-            let legendre_prf_key_prev =
-                self.preprocessed_legendre_prf_key_prev.pop_front().unwrap();
-            let legendre_prf_key_next =
-                self.preprocessed_legendre_prf_key_next.pop_front().unwrap();
-            self.joint_doprf
-                .set_legendre_prf_key_prev(legendre_prf_key_prev.clone());
-            self.joint_doprf.init(comm)?;
-            self.legendre_prf_key_prev = Some(legendre_prf_key_prev);
-            self.legendre_prf_key_next = Some(legendre_prf_key_next);
-
-            // preprocessing for stash_size number evaluations
-            self.joint_doprf.preprocess(comm, self.stash_size)?;
-        }
+        // 1. Expect to receive garbled memory share
+        let fut_garbled_memory_share = comm.receive_previous::<Vec<F>>()?;
 
-        let t_after_init_doprf = Instant::now();
+        // 2. Get fresh (initialized) instances of the functionalities
 
-        // b) Initialize p-OT
-        self.joint_pot.init(comm)?;
+        // a) Stash
+        self.stash = self.preprocessed_stash.pop_front();
+        debug_assert!(self.stash.is_some());
 
-        let t_after_init_pot = Instant::now();
+        // b) DOPRF
+        self.legendre_prf_key_prev = self.preprocessed_legendre_prf_key_prev.pop_front();
+        self.legendre_prf_key_next = self.preprocessed_legendre_prf_key_next.pop_front();
+        self.joint_doprf = self.preprocessed_doprf.pop_front();
+        debug_assert!(self.legendre_prf_key_prev.is_some());
+        debug_assert!(self.legendre_prf_key_next.is_some());
+        debug_assert!(self.joint_doprf.is_some());
 
-        // c) Expect to receive garbled memory share
-        let fut_garbled_memory_share = comm.receive_previous::<Vec<F>>()?;
+        // c) p-OT
+        self.joint_pot = Some(self.preprocessed_pot.pop_front().unwrap());
+        debug_assert!(self.joint_pot.is_some());
 
         // d) Retrieve preprocessed index tags
         self.memory_index_tags_prev = self
@@ -703,7 +774,7 @@ where
 
         let t_after_get_preprocessed_data = Instant::now();
 
-        // e) Garble the memory share for the next party
+        // 2.) Garble the memory share for the next party
         let mut garbled_memory_share_next: Vec<_> = self
             .memory_share
             .iter()
@@ -719,7 +790,7 @@ where
         // - pos_(i-1)(tag) -> index of tag in mem_idx_tags_prev
         // - pos_(i+1)(tag) -> index of tag in mem_idx_tags_next
 
-        let mask = self.joint_pot.expand();
+        let mask = self.preprocessed_pot_expands.pop_front().unwrap();
         self.memory_index_tags_next_sorted
             .par_iter()
             .zip(garbled_memory_share_next.par_iter_mut())
@@ -744,25 +815,9 @@ where
                 ProtocolStep::RefreshJitPreprocess,
                 t_after_jit_preprocessing - t_start,
             );
-            r.record(
-                ProtocolStep::RefreshResetFuncs,
-                t_after_reset - t_after_jit_preprocessing,
-            );
-            r.record(
-                ProtocolStep::RefreshInitStash,
-                t_after_init_stash - t_after_reset,
-            );
-            r.record(
-                ProtocolStep::RefreshInitDOPrf,
-                t_after_init_doprf - t_after_init_stash,
-            );
-            r.record(
-                ProtocolStep::RefreshInitPOt,
-                t_after_init_pot - t_after_init_doprf,
-            );
             r.record(
                 ProtocolStep::RefreshGetPreproc,
-                t_after_get_preprocessed_data - t_after_init_pot,
+                t_after_get_preprocessed_data - t_after_jit_preprocessing,
             );
             r.record(
                 ProtocolStep::RefreshSorting,
@@ -792,7 +847,7 @@ where
 
         // 1. Read from the stash
         let t_start = Instant::now();
-        let (stash_state, stash_runtimes) = self.stash.read_with_runtimes(
+        let (stash_state, stash_runtimes) = self.stash.as_mut().unwrap().read_with_runtimes(
             comm,
             instruction,
             runtimes.map(|r| r.get_stash_runtimes()),
@@ -818,7 +873,7 @@ where
         let t_after_db_read = Instant::now();
 
         // 4. Write the read value into the stash
-        let stash_runtime = self.stash.write_with_runtimes(
+        let stash_runtime = self.stash.as_mut().unwrap().write_with_runtimes(
             comm,
             instruction,
             stash_state,
@@ -971,6 +1026,10 @@ mod tests {
     use std::thread;
     use utils::field::Fp;
 
+    const PARTY_1: usize = 0;
+    const PARTY_2: usize = 1;
+    const PARTY_3: usize = 2;
+
     fn run_init<F, C, P>(
         mut doram_party: P,
         mut comm: C,