소스 검색

oram: add rerandomize output option

Lennart Braun 2 년 전
부모
커밋
779df9ddeb
1개의 변경된 파일34개의 추가작업 그리고 4개의 파일을 삭제
  1. 34 4
      oram/src/oram.rs

+ 34 - 4
oram/src/oram.rs

@@ -7,6 +7,7 @@ use communicator::{AbstractCommunicator, Fut, Serializable};
 use dpf::{mpdpf::MultiPointDpf, spdpf::SinglePointDpf};
 use ff::PrimeField;
 use itertools::{izip, Itertools};
+use rand::thread_rng;
 use std::iter::repeat;
 use std::marker::PhantomData;
 use utils::field::{FromPrf, LegendreSymbol};
@@ -28,7 +29,11 @@ where
         instruction: InstructionShare<F>,
     ) -> Result<F, Error>;
 
-    fn get_db<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<Vec<F>, Error>;
+    fn get_db<C: AbstractCommunicator>(
+        &mut self,
+        comm: &mut C,
+        rerandomize_shares: bool,
+    ) -> Result<Vec<F>, Error>;
 }
 
 const PARTY_1: usize = 0;
@@ -459,13 +464,38 @@ where
         Ok(read_value)
     }
 
-    fn get_db<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<Vec<F>, Error> {
+    fn get_db<C: AbstractCommunicator>(
+        &mut self,
+        comm: &mut C,
+        rerandomize_shares: bool,
+    ) -> Result<Vec<F>, Error> {
         assert!(self.is_initialized);
 
         if self.get_access_counter() > 0 {
             self.refresh(comm)?;
         }
-        return Ok(self.memory_share[0..1 << self.log_db_size].to_vec());
+
+        if rerandomize_shares {
+            let fut = comm.receive_previous()?;
+            let mut rng = thread_rng();
+            let mask: Vec<_> = (0..1 << self.log_db_size)
+                .map(|_| F::random(&mut rng))
+                .collect();
+            let mut masked_share: Vec<_> = self.memory_share[0..1 << self.log_db_size]
+                .iter()
+                .zip(mask.iter())
+                .map(|(&x, &m)| x + m)
+                .collect();
+            comm.send_next(mask)?;
+            let mask_prev: Vec<F> = fut.get()?;
+            masked_share
+                .iter_mut()
+                .zip(mask_prev.iter())
+                .for_each(|(x, &mp)| *x -= mp);
+            Ok(masked_share)
+        } else {
+            Ok(self.memory_share[0..1 << self.log_db_size].to_vec())
+        }
     }
 }
 
@@ -528,7 +558,7 @@ mod tests {
         thread::Builder::new()
             .name(format!("Party {}", doram_party.get_party_id()))
             .spawn(move || {
-                let output = doram_party.get_db(&mut comm).unwrap();
+                let output = doram_party.get_db(&mut comm, false).unwrap();
                 (doram_party, comm, output)
             })
             .unwrap()