Browse Source

doprf: add methods for simpler protocol usage

Lennart Braun 2 years ago
parent
commit
cf7d87f1b8
1 changed files with 229 additions and 0 deletions
  1. 229 0
      oram/src/doprf.rs

+ 229 - 0
oram/src/doprf.rs

@@ -1,4 +1,6 @@
+use crate::common::Error;
 use bitvec;
+use communicator::{AbstractCommunicator, Fut, Serializable};
 use core::marker::PhantomData;
 use itertools::izip;
 use rand::{thread_rng, Rng, RngCore, SeedableRng};
@@ -112,6 +114,14 @@ where
         self.is_initialized = true;
     }
 
+    pub fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
+        let fut_3_1 = comm.receive_previous()?;
+        let (msg_1_2, _) = self.init_round_0();
+        comm.send_next(msg_1_2)?;
+        self.init_round_1((), fut_3_1.get()?);
+        Ok(())
+    }
+
     pub fn get_legendre_prf_key(&self) -> LegendrePrfKey<F> {
         assert!(self.legendre_prf_key.is_some());
         self.legendre_prf_key.as_ref().unwrap().clone()
@@ -133,6 +143,20 @@ where
         self.num_preprocessed_invocations += num;
     }
 
+    pub fn preprocess<C: AbstractCommunicator>(
+        &mut self,
+        comm: &mut C,
+        num: usize,
+    ) -> Result<(), Error>
+    where
+        F: Serializable,
+    {
+        let fut_2_1 = comm.receive_next()?;
+        self.preprocess_round_0(num);
+        self.preprocess_round_1(num, fut_2_1.get()?, ());
+        Ok(())
+    }
+
     pub fn get_num_preprocessed_invocations(&self) -> usize {
         self.num_preprocessed_invocations
     }
@@ -183,6 +207,23 @@ where
         self.num_preprocessed_invocations -= num;
         ((), output_shares_z1)
     }
+
+    pub fn eval<C: AbstractCommunicator>(
+        &mut self,
+        comm: &mut C,
+        num: usize,
+        shares1: &[F],
+    ) -> Result<(), Error>
+    where
+        F: Serializable,
+    {
+        assert_eq!(shares1.len(), num);
+        let fut_2_1 = comm.receive_next::<Vec<_>>()?;
+        let fut_3_1 = comm.receive_previous::<Vec<_>>()?;
+        let (_, msg_1_3) = self.eval_round_1(num, shares1, &fut_2_1.get()?, &fut_3_1.get()?);
+        comm.send_previous(msg_1_3)?;
+        Ok(())
+    }
 }
 
 pub struct DOPrfParty2<F: LegendreSymbol> {
@@ -234,6 +275,14 @@ where
         self.is_initialized = true;
     }
 
+    pub fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
+        let fut_1_2 = comm.receive_previous()?;
+        let (_, msg_2_3) = self.init_round_0();
+        comm.send_next(msg_2_3)?;
+        self.init_round_1(fut_1_2.get()?, ());
+        Ok(())
+    }
+
     pub fn preprocess_round_0(&mut self, num: usize) -> (Vec<F>, ()) {
         assert!(self.is_initialized);
         let n = num * self.output_bitsize;
@@ -270,6 +319,20 @@ where
         assert!(self.is_initialized);
     }
 
+    pub fn preprocess<C: AbstractCommunicator>(
+        &mut self,
+        comm: &mut C,
+        num: usize,
+    ) -> Result<(), Error>
+    where
+        F: Serializable,
+    {
+        let (msg_2_1, _) = self.preprocess_round_0(num);
+        comm.send_previous(msg_2_1)?;
+        self.preprocess_round_1(num, (), ());
+        Ok(())
+    }
+
     pub fn get_num_preprocessed_invocations(&self) -> usize {
         self.num_preprocessed_invocations
     }
@@ -293,6 +356,21 @@ where
         self.num_preprocessed_invocations -= num;
         (masked_shares2, ())
     }
+
+    pub fn eval<C: AbstractCommunicator>(
+        &mut self,
+        comm: &mut C,
+        num: usize,
+        shares2: &[F],
+    ) -> Result<(), Error>
+    where
+        F: Serializable,
+    {
+        assert_eq!(shares2.len(), num);
+        let (msg_2_1, _) = self.eval_round_0(1, shares2);
+        comm.send_previous(msg_2_1)?;
+        Ok(())
+    }
 }
 
 pub struct DOPrfParty3<F: LegendreSymbol> {
@@ -354,6 +432,14 @@ where
         self.is_initialized = true;
     }
 
+    pub fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
+        let fut_2_3 = comm.receive_previous()?;
+        let (msg_3_1, _) = self.init_round_0();
+        comm.send_next(msg_3_1)?;
+        self.init_round_1((), fut_2_3.get()?);
+        Ok(())
+    }
+
     pub fn preprocess_round_0(&mut self, num: usize) -> ((), ()) {
         assert!(self.is_initialized);
         let n = num * self.output_bitsize;
@@ -374,6 +460,19 @@ where
         self.num_preprocessed_invocations += num;
     }
 
+    pub fn preprocess<C: AbstractCommunicator>(
+        &mut self,
+        _comm: &mut C,
+        num: usize,
+    ) -> Result<(), Error>
+    where
+        F: Serializable,
+    {
+        self.preprocess_round_0(num);
+        self.preprocess_round_1(num, (), ());
+        Ok(())
+    }
+
     pub fn get_num_preprocessed_invocations(&self) -> usize {
         self.num_preprocessed_invocations
     }
@@ -454,6 +553,23 @@ where
         self.num_preprocessed_invocations -= num;
         output
     }
+
+    pub fn eval<C: AbstractCommunicator>(
+        &mut self,
+        comm: &mut C,
+        num: usize,
+        shares3: &[F],
+    ) -> Result<Vec<BitVec>, Error>
+    where
+        F: Serializable,
+    {
+        assert_eq!(shares3.len(), num);
+        let fut_1_3 = comm.receive_next()?;
+        let (msg_3_1, _) = self.eval_round_0(num, shares3);
+        comm.send_next(msg_3_1)?;
+        let output = self.eval_round_2(num, shares3, fut_1_3.get()?, ());
+        Ok(output)
+    }
 }
 
 pub struct MaskedDOPrfParty1<F: LegendreSymbol> {
@@ -529,6 +645,14 @@ where
         self.is_initialized = true;
     }
 
+    pub fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
+        let fut_3_1 = comm.receive_previous()?;
+        let (msg_1_2, _) = self.init_round_0();
+        comm.send_next(msg_1_2)?;
+        self.init_round_1((), fut_3_1.get()?);
+        Ok(())
+    }
+
     pub fn get_legendre_prf_key(&self) -> LegendrePrfKey<F> {
         assert!(self.is_initialized);
         self.legendre_prf_key.as_ref().unwrap().clone()
@@ -553,6 +677,16 @@ where
         self.num_preprocessed_invocations += num;
     }
 
+    pub fn preprocess<C: AbstractCommunicator>(
+        &mut self,
+        _comm: &mut C,
+        num: usize,
+    ) -> Result<(), Error> {
+        self.preprocess_round_0(num);
+        self.preprocess_round_1(num, (), ());
+        Ok(())
+    }
+
     pub fn get_num_preprocessed_invocations(&self) -> usize {
         self.num_preprocessed_invocations
     }
@@ -642,6 +776,23 @@ where
         self.num_preprocessed_invocations -= num;
         output
     }
+
+    pub fn eval<C: AbstractCommunicator>(
+        &mut self,
+        comm: &mut C,
+        num: usize,
+        shares1: &[F],
+    ) -> Result<Vec<BitVec>, Error>
+    where
+        F: Serializable,
+    {
+        assert_eq!(shares1.len(), num);
+        let fut_3_1 = comm.receive_previous()?;
+        let (_, msg_1_3) = self.eval_round_0(num, shares1);
+        comm.send_previous(msg_1_3)?;
+        let output = self.eval_round_2(1, shares1, (), fut_3_1.get()?);
+        Ok(output)
+    }
 }
 
 pub struct MaskedDOPrfParty2<F: LegendreSymbol> {
@@ -695,6 +846,14 @@ where
         self.is_initialized = true;
     }
 
+    pub fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
+        let fut_1_2 = comm.receive_previous()?;
+        let (_, msg_2_3) = self.init_round_0();
+        comm.send_next(msg_2_3)?;
+        self.init_round_1(fut_1_2.get()?, ());
+        Ok(())
+    }
+
     pub fn preprocess_round_0(&mut self, num: usize) -> ((), Vec<F>) {
         assert!(self.is_initialized);
         let n = num * self.output_bitsize;
@@ -744,6 +903,20 @@ where
         assert!(self.is_initialized);
     }
 
+    pub fn preprocess<C: AbstractCommunicator>(
+        &mut self,
+        comm: &mut C,
+        num: usize,
+    ) -> Result<(), Error>
+    where
+        F: Serializable,
+    {
+        let (_, msg_2_3) = self.preprocess_round_0(num);
+        comm.send_next(msg_2_3)?;
+        self.preprocess_round_1(num, (), ());
+        Ok(())
+    }
+
     pub fn get_num_preprocessed_invocations(&self) -> usize {
         self.num_preprocessed_invocations
     }
@@ -784,6 +957,22 @@ where
         self.num_preprocessed_invocations -= num;
         output
     }
+
+    pub fn eval<C: AbstractCommunicator>(
+        &mut self,
+        comm: &mut C,
+        num: usize,
+        shares2: &[F],
+    ) -> Result<Vec<BitVec>, Error>
+    where
+        F: Serializable,
+    {
+        assert_eq!(shares2.len(), num);
+        let (_, msg_2_3) = self.eval_round_0(num, shares2);
+        comm.send_next(msg_2_3)?;
+        let output = self.eval_get_output(num);
+        Ok(output)
+    }
 }
 
 pub struct MaskedDOPrfParty3<F: LegendreSymbol> {
@@ -838,6 +1027,14 @@ where
         self.is_initialized = true;
     }
 
+    pub fn init<C: AbstractCommunicator>(&mut self, comm: &mut C) -> Result<(), Error> {
+        let fut_2_3 = comm.receive_previous()?;
+        let (msg_3_1, _) = self.init_round_0();
+        comm.send_next(msg_3_1)?;
+        self.init_round_1((), fut_2_3.get()?);
+        Ok(())
+    }
+
     pub fn preprocess_round_0(&mut self, num: usize) -> ((), ()) {
         assert!(self.is_initialized);
         let n = num * self.output_bitsize;
@@ -871,6 +1068,20 @@ where
         self.num_preprocessed_invocations += num;
     }
 
+    pub fn preprocess<C: AbstractCommunicator>(
+        &mut self,
+        comm: &mut C,
+        num: usize,
+    ) -> Result<(), Error>
+    where
+        F: Serializable,
+    {
+        let fut_2_3 = comm.receive_previous()?;
+        self.preprocess_round_0(num);
+        self.preprocess_round_1(num, (), fut_2_3.get()?);
+        Ok(())
+    }
+
     pub fn get_num_preprocessed_invocations(&self) -> usize {
         self.num_preprocessed_invocations
     }
@@ -934,6 +1145,24 @@ where
         self.num_preprocessed_invocations -= num;
         output
     }
+
+    pub fn eval<C: AbstractCommunicator>(
+        &mut self,
+        comm: &mut C,
+        num: usize,
+        shares3: &[F],
+    ) -> Result<Vec<BitVec>, Error>
+    where
+        F: Serializable,
+    {
+        assert_eq!(shares3.len(), num);
+        let fut_1_3 = comm.receive_next::<Vec<_>>()?;
+        let fut_2_3 = comm.receive_previous::<Vec<_>>()?;
+        let (msg_3_1, _) = self.eval_round_1(1, shares3, &fut_1_3.get()?, &fut_2_3.get()?);
+        comm.send_next(msg_3_1)?;
+        let output = self.eval_get_output(num);
+        Ok(output)
+    }
 }
 
 #[cfg(test)]