Browse Source

oram: implement preprocessing

Lennart Braun 2 years ago
parent
commit
962c0ba635
2 changed files with 348 additions and 68 deletions
  1. 32 3
      oram/examples/bench_doram.rs
  2. 316 65
      oram/src/oram.rs

+ 32 - 3
oram/examples/bench_doram.rs

@@ -24,6 +24,9 @@ struct Cli {
     /// Log2 of the database size, must be even
     #[arg(long, short = 's', value_parser = parse_log_db_size)]
     pub log_db_size: u32,
+    /// Use preprocessing
+    #[arg(long)]
+    pub preprocess: bool,
     /// Which address to listen on for incoming connections
     #[arg(long, short = 'l')]
     pub listen_host: String,
@@ -141,6 +144,19 @@ fn main() {
 
     let mut runtimes = Runtimes::default();
 
+    let d_preprocess = if cli.preprocess {
+        let t_start = Instant::now();
+
+        runtimes = doram
+            .preprocess_with_runtimes(&mut comm, 1, Some(runtimes))
+            .expect("preprocess failed")
+            .unwrap();
+
+        t_start.elapsed()
+    } else {
+        Default::default()
+    };
+
     let t_start = Instant::now();
     for (_i, inst) in instructions.iter().enumerate() {
         // println!("executing instruction #{i}: {inst:?}");
@@ -152,10 +168,23 @@ fn main() {
     }
     let d_accesses = Instant::now() - t_start;
 
-    println!("time init: {:.3} s", d_init.as_secs_f64());
-    println!("time accesses: {:.3} s", d_accesses.as_secs_f64());
+    println!("time init:        {:7.3} s", d_init.as_secs_f64());
+    println!("time preprocess:  {:7.3} s", d_preprocess.as_secs_f64());
+    println!(
+        "   per accesses: {:7.3} s",
+        d_preprocess.as_secs_f64() / stash_size as f64
+    );
+    println!(
+        "time accesses:    {:7.3} s{}",
+        d_accesses.as_secs_f64(),
+        if cli.preprocess {
+            "  (online only)"
+        } else {
+            ""
+        }
+    );
     println!(
-        "time per accesses: {:.3} s",
+        "   per accesses: {:7.3} s",
         d_accesses.as_secs_f64() / stash_size as f64
     );
 

+ 316 - 65
oram/src/oram.rs

@@ -10,6 +10,7 @@ use dpf::{mpdpf::MultiPointDpf, spdpf::SinglePointDpf};
 use ff::PrimeField;
 use itertools::{izip, Itertools};
 use rand::thread_rng;
+use std::collections::VecDeque;
 use std::iter::repeat;
 use std::marker::PhantomData;
 use std::time::{Duration, Instant};
@@ -27,6 +28,12 @@ where
 
     fn init<C: AbstractCommunicator>(&mut self, comm: &mut C, db_share: &[F]) -> Result<(), Error>;
 
+    fn preprocess<C: AbstractCommunicator>(
+        &mut self,
+        comm: &mut C,
+        number_epochs: usize,
+    ) -> Result<(), Error>;
+
     fn access<C: AbstractCommunicator>(
         &mut self,
         comm: &mut C,
@@ -50,7 +57,13 @@ fn compute_oram_prf_output_bitsize(memory_size: usize) -> usize {
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, strum_macros::EnumIter, strum_macros::Display)]
 pub enum ProtocolStep {
-    AccessStashRead = 0,
+    Preprocess = 0,
+    PreprocessLPRFKeyGenPrev,
+    PreprocessLPRFEvalSortPrev,
+    PreprocessLPRFKeyRecvNext,
+    PreprocessLPRFEvalSortNext,
+    PreprocessRecvTagsMine,
+    AccessStashRead,
     AccessAddressSelection,
     AccessDatabaseRead,
     AccessStashWrite,
@@ -62,11 +75,12 @@ pub enum ProtocolStep {
     DbWriteMpDpfKeyExchange,
     DbWriteMpDpfEvaluations,
     DbWriteUpdateMemory,
+    RefreshJitPreprocess,
     RefreshResetFuncs,
     RefreshInitStash,
     RefreshInitDOPrf,
     RefreshInitPOt,
-    RefreshPrfEvaluations,
+    RefreshGetPreproc,
     RefreshSorting,
     RefreshPOtExpandMasking,
     RefreshReceivingShare,
@@ -74,7 +88,7 @@ pub enum ProtocolStep {
 
 #[derive(Debug, Default, Clone, Copy)]
 pub struct Runtimes {
-    durations: [Duration; 20],
+    durations: [Duration; 27],
     stash_runtimes: StashRuntimes,
 }
 
@@ -97,11 +111,28 @@ impl Runtimes {
     }
 
     pub fn print(&self, party_id: usize, num_accesses: usize) {
-        println!("=============== Party {} ===============", party_id);
+        println!(
+            "==================== Party {} ====================",
+            party_id
+        );
         println!("- times per access over {num_accesses} accesses in total");
+        println!(
+            "{:30}    {:.5} s",
+            ProtocolStep::Preprocess,
+            self.get(ProtocolStep::Preprocess).as_secs_f64() / num_accesses as f64
+        );
+        for step in ProtocolStep::iter()
+            .filter(|x| x.to_string().starts_with("Preprocess") && *x != ProtocolStep::Preprocess)
+        {
+            println!(
+                "    {:26}    {:.5} s",
+                step,
+                self.get(step).as_secs_f64() / num_accesses as f64
+            );
+        }
         for step in ProtocolStep::iter().filter(|x| x.to_string().starts_with("Access")) {
             println!(
-                "{:28}    {:.5} s",
+                "{:30}    {:.5} s",
                 step,
                 self.get(step).as_secs_f64() / num_accesses as f64
             );
@@ -110,7 +141,7 @@ impl Runtimes {
                     for step in ProtocolStep::iter().filter(|x| x.to_string().starts_with("DbRead"))
                     {
                         println!(
-                            "    {:24}    {:.5} s",
+                            "    {:26}    {:.5} s",
                             step,
                             self.get(step).as_secs_f64() / num_accesses as f64
                         );
@@ -121,7 +152,7 @@ impl Runtimes {
                         x.to_string().starts_with("DbWrite") || x.to_string().starts_with("Refresh")
                     }) {
                         println!(
-                            "    {:24}    {:.5} s",
+                            "    {:26}    {:.5} s",
                             step,
                             self.get(step).as_secs_f64() / num_accesses as f64
                         );
@@ -132,7 +163,7 @@ impl Runtimes {
                         StashProtocolStep::iter().filter(|x| x.to_string().starts_with("Read"))
                     {
                         println!(
-                            "    {:24}    {:.5} s",
+                            "    {:26}    {:.5} s",
                             step,
                             self.stash_runtimes.get(step).as_secs_f64() / num_accesses as f64
                         );
@@ -143,7 +174,7 @@ impl Runtimes {
                         StashProtocolStep::iter().filter(|x| x.to_string().starts_with("Write"))
                     {
                         println!(
-                            "    {:24}    {:.5} s",
+                            "    {:26}    {:.5} s",
                             step,
                             self.stash_runtimes.get(step).as_secs_f64() / num_accesses as f64
                         );
@@ -151,10 +182,8 @@ impl Runtimes {
                 }
                 _ => {}
             }
-            // if step.to_string().starts_with("") {
-            // }
         }
-        println!("========================================");
+        println!("==================================================");
     }
 }
 
@@ -172,11 +201,21 @@ where
     stash_size: usize,
     memory_size: usize,
     memory_share: Vec<F>,
+    prf_output_bitsize: usize,
+    number_preprocessed_epochs: usize,
+    preprocessed_legendre_prf_key_next: VecDeque<LegendrePrfKey<F>>,
+    preprocessed_legendre_prf_key_prev: VecDeque<LegendrePrfKey<F>>,
+    preprocessed_memory_index_tags_prev: VecDeque<Vec<u128>>,
+    preprocessed_memory_index_tags_next: VecDeque<Vec<u128>>,
+    preprocessed_memory_index_tags_mine_sorted: VecDeque<Vec<u128>>,
+    preprocessed_memory_index_tags_prev_sorted: VecDeque<Vec<u128>>,
+    preprocessed_memory_index_tags_next_sorted: VecDeque<Vec<u128>>,
     memory_index_tags_prev: Vec<u128>,
     memory_index_tags_next: Vec<u128>,
     memory_index_tags_prev_sorted: Vec<u128>,
     memory_index_tags_next_sorted: Vec<u128>,
-    garbled_memory_share: Vec<(u128, F)>,
+    memory_index_tags_mine_sorted: Vec<u128>,
+    garbled_memory_share: Vec<F>,
     is_initialized: bool,
     address_tags_read: Vec<u128>,
     stash: StashProtocol<F, SPDPF>,
@@ -211,10 +250,20 @@ where
             stash_size,
             memory_size,
             memory_share: Default::default(),
+            number_preprocessed_epochs: 0,
+            prf_output_bitsize,
+            preprocessed_legendre_prf_key_next: Default::default(),
+            preprocessed_legendre_prf_key_prev: Default::default(),
+            preprocessed_memory_index_tags_prev: Default::default(),
+            preprocessed_memory_index_tags_next: Default::default(),
+            preprocessed_memory_index_tags_mine_sorted: Default::default(),
+            preprocessed_memory_index_tags_prev_sorted: Default::default(),
+            preprocessed_memory_index_tags_next_sorted: Default::default(),
             memory_index_tags_prev: Default::default(),
             memory_index_tags_next: Default::default(),
             memory_index_tags_prev_sorted: Default::default(),
             memory_index_tags_next_sorted: Default::default(),
+            memory_index_tags_mine_sorted: Default::default(),
             garbled_memory_share: Default::default(),
             is_initialized: false,
             address_tags_read: Default::default(),
@@ -252,9 +301,9 @@ where
     }
 
     fn pos_mine(&self, tag: u128) -> usize {
-        debug_assert_eq!(self.garbled_memory_share.len(), self.memory_size);
-        self.garbled_memory_share
-            .binary_search_by_key(&tag, |x| x.0)
+        debug_assert_eq!(self.memory_index_tags_mine_sorted.len(), self.memory_size);
+        self.memory_index_tags_mine_sorted
+            .binary_search(&tag)
             .expect("tag not found")
     }
 
@@ -278,7 +327,7 @@ where
 
         // 3. Compute index in garbled memory and retrieve share
         let garbled_index = self.pos_mine(address_tag);
-        value_share += self.garbled_memory_share[garbled_index].1;
+        value_share += self.garbled_memory_share[garbled_index];
 
         let t_after_index_computation = Instant::now();
 
@@ -388,6 +437,166 @@ where
         Ok(runtimes)
     }
 
+    pub fn preprocess_with_runtimes<C: AbstractCommunicator>(
+        &mut self,
+        comm: &mut C,
+        number_epochs: usize,
+        runtimes: Option<Runtimes>,
+    ) -> Result<Option<Runtimes>, Error> {
+        let already_preprocessed = self.number_preprocessed_epochs;
+
+        // Reserve some space
+        self.preprocessed_legendre_prf_key_prev
+            .reserve(number_epochs);
+        self.preprocessed_legendre_prf_key_next
+            .reserve(number_epochs);
+        self.preprocessed_memory_index_tags_prev
+            .reserve(number_epochs);
+        self.preprocessed_memory_index_tags_next
+            .reserve(number_epochs);
+        self.preprocessed_memory_index_tags_prev_sorted
+            .reserve(number_epochs);
+        self.preprocessed_memory_index_tags_next_sorted
+            .reserve(number_epochs);
+        self.preprocessed_memory_index_tags_mine_sorted
+            .reserve(number_epochs);
+
+        let t_start = Instant::now();
+
+        // Generate Legendre PRF keys
+        let fut_lpks_next = comm.receive_previous::<Vec<LegendrePrfKey<F>>>()?;
+        let fut_tags_mine_sorted = comm.receive_previous::<Vec<Vec<u128>>>()?;
+        self.preprocessed_legendre_prf_key_prev
+            .extend((0..number_epochs).map(|_| LegendrePrf::key_gen(self.prf_output_bitsize)));
+        let new_lpks_prev =
+            &self.preprocessed_legendre_prf_key_prev.make_contiguous()[already_preprocessed..];
+        comm.send_next(new_lpks_prev.to_vec())?;
+
+        let t_after_gen_lpks_prev = Instant::now();
+
+        // Compute memory index tags
+        for lpk_prev in new_lpks_prev {
+            let mut memory_index_tags_prev = Vec::with_capacity(self.memory_size);
+            memory_index_tags_prev
+                .extend((0..self.memory_size).map(|j| {
+                    LegendrePrf::eval_to_uint::<u128>(&lpk_prev, F::from_u128(j as u128))
+                }));
+            let memory_index_tags_prev_sorted: Vec<_> = memory_index_tags_prev
+                .iter()
+                .copied()
+                .sorted_unstable()
+                .collect();
+            self.preprocessed_memory_index_tags_prev
+                .push_back(memory_index_tags_prev);
+            self.preprocessed_memory_index_tags_prev_sorted
+                .push_back(memory_index_tags_prev_sorted);
+        }
+
+        let t_after_computing_index_tags_prev = Instant::now();
+
+        self.preprocessed_legendre_prf_key_next
+            .extend(fut_lpks_next.get()?.into_iter());
+        let new_lpks_next =
+            &self.preprocessed_legendre_prf_key_next.make_contiguous()[already_preprocessed..];
+
+        let t_after_receiving_lpks_next = Instant::now();
+
+        for lpk_next in new_lpks_next {
+            let mut memory_index_tags_next = Vec::with_capacity(self.memory_size);
+            memory_index_tags_next
+                .extend((0..self.memory_size).map(|j| {
+                    LegendrePrf::eval_to_uint::<u128>(&lpk_next, F::from_u128(j as u128))
+                }));
+            let memory_index_tags_next_with_index_sorted: Vec<_> = memory_index_tags_next
+                .iter()
+                .copied()
+                .enumerate()
+                .sorted_unstable_by_key(|(_, x)| *x)
+                .collect();
+            self.preprocessed_memory_index_tags_next
+                .push_back(memory_index_tags_next);
+            self.preprocessed_memory_index_tags_next_sorted.push_back(
+                memory_index_tags_next_with_index_sorted
+                    .iter()
+                    .map(|(_, x)| *x)
+                    .collect(),
+            );
+        }
+        comm.send_next(
+            self.preprocessed_memory_index_tags_next_sorted
+                .make_contiguous()[already_preprocessed..]
+                .to_vec(),
+        )?;
+
+        let t_after_computing_index_tags_next = Instant::now();
+
+        self.preprocessed_memory_index_tags_mine_sorted
+            .extend(fut_tags_mine_sorted.get()?);
+
+        let t_after_receiving_index_tags_mine = Instant::now();
+
+        self.number_preprocessed_epochs += number_epochs;
+
+        debug_assert_eq!(
+            self.preprocessed_legendre_prf_key_prev.len(),
+            self.number_preprocessed_epochs
+        );
+        debug_assert_eq!(
+            self.preprocessed_legendre_prf_key_next.len(),
+            self.number_preprocessed_epochs
+        );
+        debug_assert_eq!(
+            self.preprocessed_memory_index_tags_prev.len(),
+            self.number_preprocessed_epochs
+        );
+        debug_assert_eq!(
+            self.preprocessed_memory_index_tags_prev_sorted.len(),
+            self.number_preprocessed_epochs
+        );
+        debug_assert_eq!(
+            self.preprocessed_memory_index_tags_next.len(),
+            self.number_preprocessed_epochs
+        );
+        debug_assert_eq!(
+            self.preprocessed_memory_index_tags_next_sorted.len(),
+            self.number_preprocessed_epochs
+        );
+        debug_assert_eq!(
+            self.preprocessed_memory_index_tags_mine_sorted.len(),
+            self.number_preprocessed_epochs
+        );
+
+        let runtimes = runtimes.map(|mut r| {
+            r.record(
+                ProtocolStep::Preprocess,
+                t_after_receiving_index_tags_mine - t_start,
+            );
+            r.record(
+                ProtocolStep::PreprocessLPRFKeyGenPrev,
+                t_after_gen_lpks_prev - t_start,
+            );
+            r.record(
+                ProtocolStep::PreprocessLPRFEvalSortPrev,
+                t_after_computing_index_tags_prev - t_after_gen_lpks_prev,
+            );
+            r.record(
+                ProtocolStep::PreprocessLPRFKeyRecvNext,
+                t_after_receiving_lpks_next - t_after_computing_index_tags_prev,
+            );
+            r.record(
+                ProtocolStep::PreprocessLPRFEvalSortNext,
+                t_after_computing_index_tags_next - t_after_receiving_lpks_next,
+            );
+            r.record(
+                ProtocolStep::PreprocessRecvTagsMine,
+                t_after_receiving_index_tags_mine - t_after_computing_index_tags_next,
+            );
+            r
+        });
+
+        Ok(runtimes)
+    }
+
     fn refresh<C: AbstractCommunicator>(
         &mut self,
         comm: &mut C,
@@ -395,6 +604,16 @@ where
     ) -> Result<Option<Runtimes>, Error> {
         let t_start = Instant::now();
 
+        // -1. Do preprocessing if not already done
+
+        let runtimes = if self.number_preprocessed_epochs == 0 {
+            self.preprocess_with_runtimes(comm, 1, runtimes)?
+        } else {
+            runtimes
+        };
+
+        let t_after_jit_preprocessing = Instant::now();
+
         // 0. Reset the functionalities
         self.stash.reset();
         self.joint_doprf.reset();
@@ -412,11 +631,15 @@ where
         // 2. Run r-DB init protocol
         // a) Initialize DOPRF
         {
+            let legendre_prf_key_prev =
+                self.preprocessed_legendre_prf_key_prev.pop_front().unwrap();
+            let legendre_prf_key_next =
+                self.preprocessed_legendre_prf_key_next.pop_front().unwrap();
+            self.joint_doprf
+                .set_legendre_prf_key_prev(legendre_prf_key_prev.clone());
             self.joint_doprf.init(comm)?;
-            let fut_lpk_next = comm.receive_previous::<LegendrePrfKey<F>>()?;
-            comm.send_next(self.joint_doprf.get_legendre_prf_key_prev())?;
-            self.legendre_prf_key_prev = Some(self.joint_doprf.get_legendre_prf_key_prev());
-            self.legendre_prf_key_next = Some(fut_lpk_next.get()?);
+            self.legendre_prf_key_prev = Some(legendre_prf_key_prev);
+            self.legendre_prf_key_next = Some(legendre_prf_key_next);
 
             // preprocessing for stash_size number evaluations
             self.joint_doprf.preprocess(comm, self.stash_size)?;
@@ -448,56 +671,60 @@ where
 
         let t_after_init_pot = Instant::now();
 
-        // c) Compute index tags and garble the memory share for the next party
-        let fut_garbled_memory_share = comm.receive_previous()?;
-        self.memory_index_tags_prev = Vec::with_capacity(self.memory_size);
-        self.memory_index_tags_prev
-            .extend((0..self.memory_size).map(|j| {
-                LegendrePrf::eval_to_uint::<u128>(
-                    &self.legendre_prf_key_prev.as_ref().unwrap(),
-                    F::from_u128(j as u128),
-                )
-            }));
-        let mut garbled_memory_share_next: Vec<_> = self
-            .memory_share
-            .iter()
-            .copied()
-            .enumerate()
-            .map(|(j, x)| {
-                (
-                    LegendrePrf::eval_to_uint::<u128>(
-                        &self.legendre_prf_key_next.as_ref().unwrap(),
-                        F::from_u128(j as u128),
-                    ),
-                    x,
-                )
-            })
-            .collect();
-
-        let t_after_prf = Instant::now();
+        // c) Expect to receive garbled memory share
+        let fut_garbled_memory_share = comm.receive_previous::<Vec<F>>()?;
 
+        // d) Retrieve preprocessed index tags
+        self.memory_index_tags_prev = self
+            .preprocessed_memory_index_tags_prev
+            .pop_front()
+            .unwrap();
         self.memory_index_tags_prev_sorted = self
-            .memory_index_tags_prev
-            .iter()
-            .copied()
-            .sorted_unstable()
-            .collect();
+            .preprocessed_memory_index_tags_prev_sorted
+            .pop_front()
+            .unwrap();
+        self.memory_index_tags_next = self
+            .preprocessed_memory_index_tags_next
+            .pop_front()
+            .unwrap();
+        self.memory_index_tags_next_sorted = self
+            .preprocessed_memory_index_tags_next_sorted
+            .pop_front()
+            .unwrap();
+        self.memory_index_tags_mine_sorted = self
+            .preprocessed_memory_index_tags_mine_sorted
+            .pop_front()
+            .unwrap();
         debug_assert!(
             self.memory_index_tags_prev_sorted
                 .windows(2)
                 .all(|w| w[0] < w[1]),
             "index tags not sorted or colliding"
         );
-        self.memory_index_tags_next = garbled_memory_share_next.iter().map(|x| x.0).collect();
-        garbled_memory_share_next.sort_unstable_by_key(|x| x.0);
-        self.memory_index_tags_next_sorted =
-            garbled_memory_share_next.iter().map(|x| x.0).collect();
         debug_assert!(
             self.memory_index_tags_next_sorted
                 .windows(2)
                 .all(|w| w[0] < w[1]),
             "index tags not sorted or colliding"
         );
+        debug_assert!(
+            self.memory_index_tags_mine_sorted
+                .windows(2)
+                .all(|w| w[0] < w[1]),
+            "index tags not sorted or colliding"
+        );
+
+        let t_after_get_preprocessed_data = Instant::now();
+
+        // e) Garble the memory share for the next party
+        let mut garbled_memory_share_next: Vec<_> = self
+            .memory_share
+            .iter()
+            .copied()
+            .zip(self.memory_index_tags_next.iter().copied())
+            .sorted_unstable_by_key(|(_, i)| *i)
+            .map(|(x, _)| x)
+            .collect();
 
         let t_after_sort = Instant::now();
 
@@ -506,10 +733,12 @@ where
         // - pos_(i+1)(tag) -> index of tag in mem_idx_tags_next
 
         let mask = self.pot_key_party.expand();
-        for j in 0..self.memory_size {
-            let (tag, val) = garbled_memory_share_next[j];
-            let masked_val = val + mask[self.pos_next(tag)];
-            garbled_memory_share_next[j] = (tag, masked_val);
+        for (&tag, val) in self
+            .memory_index_tags_next_sorted
+            .iter()
+            .zip(garbled_memory_share_next.iter_mut())
+        {
+            *val += mask[self.pos_next(tag)];
         }
         comm.send_next(garbled_memory_share_next)?;
 
@@ -521,8 +750,18 @@ where
 
         let t_after_receiving = Instant::now();
 
+        // account that we used one set of preprocessing material
+        self.number_preprocessed_epochs -= 1;
+
         let runtimes = runtimes.map(|mut r| {
-            r.record(ProtocolStep::RefreshResetFuncs, t_after_reset - t_start);
+            r.record(
+                ProtocolStep::RefreshJitPreprocess,
+                t_after_jit_preprocessing - t_start,
+            );
+            r.record(
+                ProtocolStep::RefreshResetFuncs,
+                t_after_reset - t_after_jit_preprocessing,
+            );
             r.record(
                 ProtocolStep::RefreshInitStash,
                 t_after_init_stash - t_after_reset,
@@ -536,10 +775,13 @@ where
                 t_after_init_pot - t_after_init_doprf,
             );
             r.record(
-                ProtocolStep::RefreshPrfEvaluations,
-                t_after_prf - t_after_init_pot,
+                ProtocolStep::RefreshGetPreproc,
+                t_after_get_preprocessed_data - t_after_init_pot,
+            );
+            r.record(
+                ProtocolStep::RefreshSorting,
+                t_after_sort - t_after_get_preprocessed_data,
             );
-            r.record(ProtocolStep::RefreshSorting, t_after_sort - t_after_prf);
             r.record(
                 ProtocolStep::RefreshPOtExpandMasking,
                 t_after_pot_expand - t_after_sort,
@@ -678,6 +920,15 @@ where
         Ok(())
     }
 
+    fn preprocess<C: AbstractCommunicator>(
+        &mut self,
+        comm: &mut C,
+        number_epochs: usize,
+    ) -> Result<(), Error> {
+        self.preprocess_with_runtimes(comm, number_epochs, None)
+            .map(|_| ())
+    }
+
     fn access<C: AbstractCommunicator>(
         &mut self,
         comm: &mut C,