Bläddra i källkod

doprf: add variant with masked output

Lennart Braun 2 år sedan
förälder
incheckning
997a29630d
2 ändrade filer med 705 tillägg och 4 borttagningar
  1. 88 1
      oram/benches/doprf.rs
  2. 617 3
      oram/src/doprf.rs

+ 88 - 1
oram/benches/doprf.rs

@@ -2,6 +2,7 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criteri
 use ff::Field;
 use ff::Field;
 use oram::doprf::LegendrePrf;
 use oram::doprf::LegendrePrf;
 use oram::doprf::{DOPrfParty1, DOPrfParty2, DOPrfParty3};
 use oram::doprf::{DOPrfParty1, DOPrfParty2, DOPrfParty3};
+use oram::doprf::{MaskedDOPrfParty1, MaskedDOPrfParty2, MaskedDOPrfParty3};
 use rand::thread_rng;
 use rand::thread_rng;
 use utils::field::Fp;
 use utils::field::Fp;
 
 
@@ -105,9 +106,95 @@ pub fn bench_doprf(c: &mut Criterion) {
     group.finish();
     group.finish();
 }
 }
 
 
+pub fn bench_masked_doprf(c: &mut Criterion) {
+    let output_bitsize = 128;
+    let mut group = c.benchmark_group("MaskedDOPrf");
+
+    let mut party_1 = MaskedDOPrfParty1::<Fp>::new(output_bitsize);
+    let mut party_2 = MaskedDOPrfParty2::<Fp>::new(output_bitsize);
+    let mut party_3 = MaskedDOPrfParty3::<Fp>::new(output_bitsize);
+
+    group.bench_function("init", |b| {
+        b.iter(|| {
+            party_1.reset();
+            party_2.reset();
+            party_3.reset();
+            let (msg_1_2, msg_1_3) = party_1.init_round_0();
+            let (msg_2_1, msg_2_3) = party_2.init_round_0();
+            let (msg_3_1, msg_3_2) = party_3.init_round_0();
+            party_1.init_round_1(msg_2_1, msg_3_1);
+            party_2.init_round_1(msg_1_2, msg_3_2);
+            party_3.init_round_1(msg_1_3, msg_2_3);
+        });
+    });
+
+    {
+        party_1.reset();
+        party_2.reset();
+        party_3.reset();
+        let (msg_1_2, msg_1_3) = party_1.init_round_0();
+        let (msg_2_1, msg_2_3) = party_2.init_round_0();
+        let (msg_3_1, msg_3_2) = party_3.init_round_0();
+        party_1.init_round_1(msg_2_1, msg_3_1);
+        party_2.init_round_1(msg_1_2, msg_3_2);
+        party_3.init_round_1(msg_1_3, msg_2_3);
+    }
+
+    for log_num_evaluations in LOG_NUM_EVALUATIONS {
+        group.bench_with_input(
+            BenchmarkId::new("preprocess", log_num_evaluations),
+            &log_num_evaluations,
+            |b, &log_num_evaluations| {
+                let num = 1 << log_num_evaluations;
+                b.iter(|| {
+                    party_1.reset_preprocessing();
+                    party_2.reset_preprocessing();
+                    party_3.reset_preprocessing();
+                    let (msg_1_2, msg_1_3) = party_1.preprocess_round_0(num);
+                    let (msg_2_1, msg_2_3) = party_2.preprocess_round_0(num);
+                    let (msg_3_1, msg_3_2) = party_3.preprocess_round_0(num);
+                    party_1.preprocess_round_1(num, msg_2_1, msg_3_1);
+                    party_2.preprocess_round_1(num, msg_1_2, msg_3_2);
+                    party_3.preprocess_round_1(num, msg_1_3, msg_2_3);
+                });
+            },
+        );
+    }
+
+    for log_num_evaluations in LOG_NUM_EVALUATIONS {
+        group.bench_with_input(
+            BenchmarkId::new("preprocess+eval", log_num_evaluations),
+            &log_num_evaluations,
+            |b, &log_num_evaluations| {
+                let num = 1 << log_num_evaluations;
+                let shares_1: Vec<Fp> = (0..num).map(|_| Fp::random(thread_rng())).collect();
+                let shares_2: Vec<Fp> = (0..num).map(|_| Fp::random(thread_rng())).collect();
+                let shares_3: Vec<Fp> = (0..num).map(|_| Fp::random(thread_rng())).collect();
+                b.iter(|| {
+                    let (msg_1_2, msg_1_3) = party_1.preprocess_round_0(num);
+                    let (msg_2_1, msg_2_3) = party_2.preprocess_round_0(num);
+                    let (msg_3_1, msg_3_2) = party_3.preprocess_round_0(num);
+                    party_1.preprocess_round_1(num, msg_2_1, msg_3_1);
+                    party_2.preprocess_round_1(num, msg_1_2, msg_3_2);
+                    party_3.preprocess_round_1(num, msg_1_3, msg_2_3);
+
+                    let (_, msg_1_3) = party_1.eval_round_0(num, &shares_1);
+                    let (_, msg_2_3) = party_2.eval_round_0(num, &shares_2);
+                    let (msg_3_1, _) = party_3.eval_round_1(num, &shares_3, &msg_1_3, &msg_2_3);
+                    let _masked_output = party_1.eval_round_2(num, &shares_1, (), msg_3_1);
+                    let _mask2 = party_2.eval_get_output(num);
+                    let _mask3 = party_3.eval_get_output(num);
+                });
+            },
+        );
+    }
+
+    group.finish();
+}
+
 criterion_group!(
 criterion_group!(
     name = benches;
     name = benches;
     config = Criterion::default().sample_size(10);
     config = Criterion::default().sample_size(10);
-    targets = bench_legendre_prf, bench_doprf
+    targets = bench_legendre_prf, bench_doprf, bench_masked_doprf
 );
 );
 criterion_main!(benches);
 criterion_main!(benches);

+ 617 - 3
oram/src/doprf.rs

@@ -1,7 +1,8 @@
-use bitvec::vec::BitVec;
+use bitvec::{slice::BitSlice, vec::BitVec};
 use core::marker::PhantomData;
 use core::marker::PhantomData;
 use itertools::izip;
 use itertools::izip;
-use rand::thread_rng;
+use rand::{thread_rng, Rng, RngCore, SeedableRng};
+use rand_chacha::ChaChaRng;
 use std::iter::repeat;
 use std::iter::repeat;
 use utils::field::{FromLimbs, FromPrf, LegendreSymbol, Modulus128};
 use utils::field::{FromLimbs, FromPrf, LegendreSymbol, Modulus128};
 
 
@@ -86,6 +87,7 @@ where
     F: LegendreSymbol + FromPrf,
     F: LegendreSymbol + FromPrf,
 {
 {
     pub fn new(output_bitsize: usize) -> Self {
     pub fn new(output_bitsize: usize) -> Self {
+        assert!(output_bitsize > 0);
         Self {
         Self {
             _phantom: PhantomData,
             _phantom: PhantomData,
             output_bitsize,
             output_bitsize,
@@ -213,6 +215,7 @@ where
     F: LegendreSymbol + FromPrf,
     F: LegendreSymbol + FromPrf,
 {
 {
     pub fn new(output_bitsize: usize) -> Self {
     pub fn new(output_bitsize: usize) -> Self {
+        assert!(output_bitsize > 0);
         Self {
         Self {
             _phantom: PhantomData,
             _phantom: PhantomData,
             output_bitsize,
             output_bitsize,
@@ -326,6 +329,7 @@ where
     F: LegendreSymbol + FromPrf + FromLimbs + Modulus128,
     F: LegendreSymbol + FromPrf + FromLimbs + Modulus128,
 {
 {
     pub fn new(output_bitsize: usize) -> Self {
     pub fn new(output_bitsize: usize) -> Self {
+        assert!(output_bitsize > 0);
         Self {
         Self {
             _phantom: PhantomData,
             _phantom: PhantomData,
             output_bitsize,
             output_bitsize,
@@ -456,7 +460,7 @@ where
                 let mut bv = BitVec::with_capacity(self.output_bitsize);
                 let mut bv = BitVec::with_capacity(self.output_bitsize);
                 for &x in chunk.iter() {
                 for &x in chunk.iter() {
                     let ls = F::legendre_symbol(x);
                     let ls = F::legendre_symbol(x);
-                    assert!(ls != F::ZERO, "unlikely");
+                    debug_assert!(ls != F::ZERO, "unlikely");
                     bv.push(ls == F::ONE);
                     bv.push(ls == F::ONE);
                 }
                 }
                 bv
                 bv
@@ -467,6 +471,494 @@ where
     }
     }
 }
 }
 
 
+pub struct MaskedDOPrfParty1<F: LegendreSymbol + FromPrf> {
+    _phantom: PhantomData<F>,
+    output_bitsize: usize,
+    shared_prf_1_2: Option<SharedPrf<F>>,
+    shared_prf_1_3: Option<SharedPrf<F>>,
+    legendre_prf_key: Option<LegendrePrfKey<F>>,
+    is_initialized: bool,
+    num_preprocessed_invocations: usize,
+    preprocessed_rerand_m1: Vec<F>,
+    preprocessed_mt_a: Vec<F>,
+    preprocessed_mt_c1: Vec<F>,
+    preprocessed_mult_e: Vec<F>,
+    mult_d: Vec<F>,
+}
+
+impl<F> MaskedDOPrfParty1<F>
+where
+    F: LegendreSymbol + FromPrf,
+{
+    pub fn new(output_bitsize: usize) -> Self {
+        assert!(output_bitsize > 0);
+        Self {
+            _phantom: PhantomData,
+            output_bitsize,
+            shared_prf_1_2: None,
+            shared_prf_1_3: None,
+            legendre_prf_key: None,
+            is_initialized: false,
+            num_preprocessed_invocations: 0,
+            preprocessed_rerand_m1: Default::default(),
+            preprocessed_mt_a: Default::default(),
+            preprocessed_mt_c1: Default::default(),
+            preprocessed_mult_e: Default::default(),
+            mult_d: Default::default(),
+        }
+    }
+
+    pub fn reset(&mut self) {
+        *self = Self::new(self.output_bitsize)
+    }
+
+    pub fn reset_preprocessing(&mut self) {
+        self.num_preprocessed_invocations = 0;
+        self.preprocessed_rerand_m1 = Default::default();
+        self.preprocessed_mt_a = Default::default();
+        self.preprocessed_mt_c1 = Default::default();
+        self.preprocessed_mult_e = Default::default();
+    }
+
+    pub fn init_round_0(&mut self) -> (F::PrfKey, ()) {
+        assert!(!self.is_initialized);
+        // sample and share a PRF key with Party 2
+        self.shared_prf_1_2 = Some(SharedPrf::key_gen());
+        (self.shared_prf_1_2.as_ref().unwrap().get_key(), ())
+    }
+
+    pub fn init_round_1(&mut self, _: (), shared_prf_key_1_3: F::PrfKey) {
+        assert!(!self.is_initialized);
+        // receive shared PRF key from Party 3
+        self.shared_prf_1_3 = Some(SharedPrf::from_key(shared_prf_key_1_3));
+        // generate Legendre PRF key
+        self.legendre_prf_key = Some(LegendrePrf::key_gen(self.output_bitsize));
+        self.is_initialized = true;
+    }
+
+    pub fn get_legendre_prf_key(&self) -> LegendrePrfKey<F> {
+        assert!(self.is_initialized);
+        self.legendre_prf_key.as_ref().unwrap().clone()
+    }
+
+    pub fn preprocess_round_0(&mut self, num: usize) -> ((), ()) {
+        assert!(self.is_initialized);
+        let n = num * self.output_bitsize;
+        self.preprocessed_rerand_m1
+            .extend((0..num).map(|_| self.shared_prf_1_2.as_mut().unwrap().eval()));
+        self.preprocessed_mt_a
+            .extend((0..n).map(|_| self.shared_prf_1_2.as_mut().unwrap().eval()));
+        self.preprocessed_mt_c1
+            .extend((0..n).map(|_| self.shared_prf_1_2.as_mut().unwrap().eval()));
+        self.preprocessed_mult_e
+            .extend((0..n).map(|_| self.shared_prf_1_2.as_mut().unwrap().eval()));
+        ((), ())
+    }
+
+    pub fn preprocess_round_1(&mut self, num: usize, _: (), _: ()) {
+        assert!(self.is_initialized);
+        self.num_preprocessed_invocations += num;
+    }
+
+    pub fn get_num_preprocessed_invocations(&self) -> usize {
+        self.num_preprocessed_invocations
+    }
+
+    pub fn get_preprocessed_data(&self) -> (&[F], &[F], &[F], &[F]) {
+        (
+            &self.preprocessed_rerand_m1,
+            &self.preprocessed_mt_a,
+            &self.preprocessed_mt_c1,
+            &self.preprocessed_mult_e,
+        )
+    }
+
+    pub fn check_preprocessing(&self) {
+        let num = self.num_preprocessed_invocations;
+        let n = num * self.output_bitsize;
+        assert_eq!(self.preprocessed_rerand_m1.len(), num);
+        assert_eq!(self.preprocessed_mt_a.len(), n);
+        assert_eq!(self.preprocessed_mt_c1.len(), n);
+        assert_eq!(self.preprocessed_mult_e.len(), n);
+    }
+
+    pub fn eval_round_0(&mut self, num: usize, shares1: &[F]) -> ((), Vec<F>) {
+        assert!(num <= self.num_preprocessed_invocations);
+        assert_eq!(shares1.len(), num);
+        let n = num * self.output_bitsize;
+        let k = &self.legendre_prf_key.as_ref().unwrap().keys;
+        self.mult_d = izip!(
+            k.iter().cycle(),
+            shares1
+                .iter()
+                .flat_map(|s1| repeat(s1).take(self.output_bitsize)),
+            self.preprocessed_rerand_m1
+                .iter()
+                .take(num)
+                .flat_map(|m1| repeat(m1).take(self.output_bitsize)),
+            self.preprocessed_mt_a.drain(0..n),
+        )
+        .map(|(&k_i, &s1_i, m1_i, a_i)| k_i + s1_i + m1_i - a_i)
+        .collect();
+        assert_eq!(self.mult_d.len(), n);
+        ((), self.mult_d.clone())
+    }
+
+    pub fn eval_round_2(
+        &mut self,
+        num: usize,
+        shares1: &[F],
+        _: (),
+        output_shares_z3: Vec<F>,
+    ) -> Vec<BitVec> {
+        assert!(num <= self.num_preprocessed_invocations);
+        let n = num * self.output_bitsize;
+        assert_eq!(shares1.len(), num);
+        assert_eq!(output_shares_z3.len(), n);
+        let k = &self.legendre_prf_key.as_ref().unwrap().keys;
+        let lprf_inputs: Vec<F> = izip!(
+            k.iter().cycle(),
+            shares1
+                .iter()
+                .flat_map(|s1| repeat(s1).take(self.output_bitsize)),
+            self.preprocessed_rerand_m1
+                .drain(0..num)
+                .flat_map(|m1| repeat(m1).take(self.output_bitsize)),
+            self.preprocessed_mult_e.drain(0..n),
+            self.mult_d.drain(..),
+            self.preprocessed_mt_c1.drain(0..n),
+            output_shares_z3.iter(),
+        )
+        .map(|(&k_j, &s1_i, m1_i, e_ij, d_ij, c1_ij, &z3_ij)| {
+            e_ij * (k_j + s1_i + m1_i) + c1_ij + z3_ij - d_ij * e_ij
+        })
+        .collect();
+        assert_eq!(lprf_inputs.len(), n);
+        let output: Vec<BitVec> = lprf_inputs
+            .chunks_exact(self.output_bitsize)
+            .map(|chunk| {
+                let mut bv = BitVec::with_capacity(self.output_bitsize);
+                for &x in chunk.iter() {
+                    let ls = F::legendre_symbol(x);
+                    debug_assert!(ls != F::ZERO, "unlikely");
+                    bv.push(ls == F::ONE);
+                }
+                bv
+            })
+            .collect();
+        self.num_preprocessed_invocations -= num;
+        output
+    }
+}
+
+pub struct MaskedDOPrfParty2<F: LegendreSymbol + FromPrf> {
+    _phantom: PhantomData<F>,
+    output_bitsize: usize,
+    shared_prf_1_2: Option<SharedPrf<F>>,
+    shared_prf_2_3: Option<SharedPrf<F>>,
+    shared_prg_2_3: Option<ChaChaRng>,
+    is_initialized: bool,
+    num_preprocessed_invocations: usize,
+    preprocessed_rerand_m2: Vec<F>,
+    preprocessed_r: BitVec,
+}
+
+impl<F> MaskedDOPrfParty2<F>
+where
+    F: LegendreSymbol + FromPrf,
+{
+    pub fn new(output_bitsize: usize) -> Self {
+        assert!(output_bitsize > 0);
+        Self {
+            _phantom: PhantomData,
+            output_bitsize,
+            shared_prf_1_2: None,
+            shared_prf_2_3: None,
+            shared_prg_2_3: None,
+            is_initialized: false,
+            num_preprocessed_invocations: 0,
+            preprocessed_rerand_m2: Default::default(),
+            preprocessed_r: Default::default(),
+        }
+    }
+
+    pub fn reset(&mut self) {
+        *self = Self::new(self.output_bitsize)
+    }
+
+    pub fn reset_preprocessing(&mut self) {
+        self.num_preprocessed_invocations = 0;
+        self.preprocessed_rerand_m2 = Default::default();
+    }
+
+    pub fn init_round_0(&mut self) -> ((), (F::PrfKey, <ChaChaRng as SeedableRng>::Seed)) {
+        assert!(!self.is_initialized);
+        self.shared_prf_2_3 = Some(SharedPrf::key_gen());
+        self.shared_prg_2_3 = Some(ChaChaRng::from_seed(thread_rng().gen()));
+        (
+            (),
+            (
+                self.shared_prf_2_3.as_ref().unwrap().get_key(),
+                self.shared_prg_2_3.as_ref().unwrap().get_seed(),
+            ),
+        )
+    }
+
+    pub fn init_round_1(&mut self, shared_prf_key_1_2: F::PrfKey, _: ()) {
+        assert!(!self.is_initialized);
+        // receive shared PRF key from Party 1
+        self.shared_prf_1_2 = Some(SharedPrf::from_key(shared_prf_key_1_2));
+        self.is_initialized = true;
+    }
+
+    pub fn preprocess_round_0(&mut self, num: usize) -> ((), Vec<F>) {
+        assert!(self.is_initialized);
+        let n = num * self.output_bitsize;
+
+        let mut preprocessed_t: Vec<_> = (0..n)
+            .map(|_| self.shared_prf_2_3.as_mut().unwrap().eval().square())
+            .collect();
+        debug_assert!(!preprocessed_t.contains(&F::ZERO));
+        {
+            let mut random_bytes = vec![0u8; (n + 7) / 8];
+            self.shared_prg_2_3
+                .as_mut()
+                .unwrap()
+                .fill_bytes(&mut random_bytes);
+            let new_r_slice = BitSlice::<u8>::from_slice(&random_bytes);
+            self.preprocessed_r.extend(&new_r_slice[..n]);
+            for (i, r_i) in new_r_slice.iter().by_vals().take(n).enumerate() {
+                if r_i {
+                    preprocessed_t[i] *= F::get_non_random_qnr();
+                }
+            }
+        }
+        self.preprocessed_rerand_m2
+            .extend((0..num).map(|_| -self.shared_prf_1_2.as_mut().unwrap().eval()));
+        let preprocessed_mt_a: Vec<F> = (0..n)
+            .map(|_| self.shared_prf_1_2.as_mut().unwrap().eval())
+            .collect();
+        let preprocessed_mt_c1: Vec<F> = (0..n)
+            .map(|_| self.shared_prf_1_2.as_mut().unwrap().eval())
+            .collect();
+        let preprocessed_mult_e: Vec<F> = (0..n)
+            .map(|_| self.shared_prf_1_2.as_mut().unwrap().eval())
+            .collect();
+        let preprocessed_c3: Vec<F> = izip!(
+            preprocessed_t.iter(),
+            preprocessed_mult_e.iter(),
+            preprocessed_mt_a.iter(),
+            preprocessed_mt_c1.iter(),
+        )
+        .map(|(&t, &e, &a, &c1)| a * (t - e) - c1)
+        .collect();
+        self.num_preprocessed_invocations += num;
+        ((), preprocessed_c3)
+    }
+
+    pub fn preprocess_round_1(&mut self, _: usize, _: (), _: ()) {
+        assert!(self.is_initialized);
+    }
+
+    pub fn get_num_preprocessed_invocations(&self) -> usize {
+        self.num_preprocessed_invocations
+    }
+
+    pub fn get_preprocessed_data(&self) -> (&BitSlice, &[F]) {
+        (&self.preprocessed_r, &self.preprocessed_rerand_m2)
+    }
+
+    pub fn check_preprocessing(&self) {
+        let num = self.num_preprocessed_invocations;
+        assert_eq!(self.preprocessed_rerand_m2.len(), num);
+    }
+
+    pub fn eval_round_0(&mut self, num: usize, shares2: &[F]) -> ((), Vec<F>) {
+        assert!(num <= self.num_preprocessed_invocations);
+        assert_eq!(shares2.len(), num);
+        let masked_shares2: Vec<F> =
+            izip!(shares2.iter(), self.preprocessed_rerand_m2.drain(0..num),)
+                .map(|(&s2i, m2i)| s2i + m2i)
+                .collect();
+        assert_eq!(masked_shares2.len(), num);
+        ((), masked_shares2)
+    }
+
+    pub fn eval_get_output(&mut self, num: usize) -> Vec<BitVec> {
+        assert!(num <= self.num_preprocessed_invocations);
+        let n = num * self.output_bitsize;
+        let mut output = Vec::with_capacity(num);
+        for chunk in self
+            .preprocessed_r
+            .chunks_exact(self.output_bitsize)
+            .take(num)
+        {
+            output.push(chunk.to_bitvec());
+        }
+        let (_, last_r) = self.preprocessed_r.split_at(n);
+        self.preprocessed_r = last_r.to_bitvec();
+        self.num_preprocessed_invocations -= num;
+        output
+    }
+}
+
+pub struct MaskedDOPrfParty3<F: LegendreSymbol + FromPrf> {
+    _phantom: PhantomData<F>,
+    output_bitsize: usize,
+    shared_prf_1_3: Option<SharedPrf<F>>,
+    shared_prf_2_3: Option<SharedPrf<F>>,
+    shared_prg_2_3: Option<ChaChaRng>,
+    is_initialized: bool,
+    num_preprocessed_invocations: usize,
+    preprocessed_r: BitVec,
+    preprocessed_t: Vec<F>,
+    preprocessed_mt_c3: Vec<F>,
+}
+
+impl<F> MaskedDOPrfParty3<F>
+where
+    F: LegendreSymbol + FromPrf + FromLimbs + Modulus128,
+{
+    pub fn new(output_bitsize: usize) -> Self {
+        assert!(output_bitsize > 0);
+        Self {
+            _phantom: PhantomData,
+            output_bitsize,
+            shared_prf_1_3: None,
+            shared_prf_2_3: None,
+            shared_prg_2_3: None,
+            is_initialized: false,
+            num_preprocessed_invocations: 0,
+            preprocessed_r: Default::default(),
+            preprocessed_t: Default::default(),
+            preprocessed_mt_c3: Default::default(),
+        }
+    }
+
+    pub fn reset(&mut self) {
+        *self = Self::new(self.output_bitsize)
+    }
+
+    pub fn reset_preprocessing(&mut self) {
+        self.num_preprocessed_invocations = 0;
+        self.preprocessed_t = Default::default();
+        self.preprocessed_mt_c3 = Default::default();
+    }
+
+    pub fn init_round_0(&mut self) -> (F::PrfKey, ()) {
+        assert!(!self.is_initialized);
+        self.shared_prf_1_3 = Some(SharedPrf::key_gen());
+        (self.shared_prf_1_3.as_ref().unwrap().get_key(), ())
+    }
+
+    pub fn init_round_1(
+        &mut self,
+        _: (),
+        (shared_prf_key_2_3, shared_prg_seed_2_3): (F::PrfKey, <ChaChaRng as SeedableRng>::Seed),
+    ) {
+        self.shared_prf_2_3 = Some(SharedPrf::from_key(shared_prf_key_2_3));
+        self.shared_prg_2_3 = Some(ChaChaRng::from_seed(shared_prg_seed_2_3));
+        self.is_initialized = true;
+    }
+
+    pub fn preprocess_round_0(&mut self, num: usize) -> ((), ()) {
+        assert!(self.is_initialized);
+        let n = num * self.output_bitsize;
+        let start_index = self.num_preprocessed_invocations * self.output_bitsize;
+
+        self.preprocessed_t
+            .extend((0..n).map(|_| self.shared_prf_2_3.as_mut().unwrap().eval().square()));
+        debug_assert!(!self.preprocessed_t[start_index..].contains(&F::ZERO));
+        {
+            let mut random_bytes = vec![0u8; (n + 7) / 8];
+            self.shared_prg_2_3
+                .as_mut()
+                .unwrap()
+                .fill_bytes(&mut random_bytes);
+            let new_r_slice = BitSlice::<u8>::from_slice(&random_bytes);
+            self.preprocessed_r.extend(&new_r_slice[..n]);
+            for (i, r_i) in new_r_slice.iter().by_vals().take(n).enumerate() {
+                if r_i {
+                    self.preprocessed_t[start_index + i] *= F::get_non_random_qnr();
+                }
+            }
+        }
+        ((), ())
+    }
+
+    pub fn preprocess_round_1(&mut self, num: usize, _: (), preprocessed_mt_c3: Vec<F>) {
+        assert!(self.is_initialized);
+        let n = num * self.output_bitsize;
+        assert_eq!(preprocessed_mt_c3.len(), n);
+        self.preprocessed_mt_c3.extend(preprocessed_mt_c3);
+        self.num_preprocessed_invocations += num;
+    }
+
+    pub fn get_num_preprocessed_invocations(&self) -> usize {
+        self.num_preprocessed_invocations
+    }
+
+    pub fn get_preprocessed_data(&self) -> (&BitSlice, &[F], &[F]) {
+        (
+            &self.preprocessed_r,
+            &self.preprocessed_t,
+            &self.preprocessed_mt_c3,
+        )
+    }
+
+    pub fn check_preprocessing(&self) {
+        let num = self.num_preprocessed_invocations;
+        let n = num * self.output_bitsize;
+        assert_eq!(self.preprocessed_t.len(), n);
+        assert_eq!(self.preprocessed_mt_c3.len(), n);
+    }
+
+    pub fn eval_round_1(
+        &mut self,
+        num: usize,
+        shares3: &[F],
+        mult_d: &[F],
+        masked_shares2: &[F],
+    ) -> (Vec<F>, ()) {
+        assert!(num <= self.num_preprocessed_invocations);
+        let n = num * self.output_bitsize;
+        assert_eq!(shares3.len(), num);
+        assert_eq!(masked_shares2.len(), num);
+        assert_eq!(mult_d.len(), n);
+        let output_shares_z3: Vec<F> = izip!(
+            shares3
+                .iter()
+                .flat_map(|s1i| repeat(s1i).take(self.output_bitsize)),
+            masked_shares2
+                .iter()
+                .flat_map(|ms2i| repeat(ms2i).take(self.output_bitsize)),
+            self.preprocessed_t.drain(0..n),
+            self.preprocessed_mt_c3.drain(0..n),
+            mult_d,
+        )
+        .map(|(&s3_i, &ms2_i, t_ij, c3_ij, &d_ij)| t_ij * (s3_i + ms2_i) + d_ij * t_ij + c3_ij)
+        .collect();
+        (output_shares_z3, ())
+    }
+
+    pub fn eval_get_output(&mut self, num: usize) -> Vec<BitVec> {
+        assert!(num <= self.num_preprocessed_invocations);
+        let n = num * self.output_bitsize;
+        let mut output = Vec::with_capacity(num);
+        for chunk in self
+            .preprocessed_r
+            .chunks_exact(self.output_bitsize)
+            .take(num)
+        {
+            output.push(chunk.to_bitvec());
+        }
+        let (_, last_r) = self.preprocessed_r.split_at(n);
+        self.preprocessed_r = last_r.to_bitvec();
+        self.num_preprocessed_invocations -= num;
+        output
+    }
+}
+
 #[cfg(test)]
 #[cfg(test)]
 mod tests {
 mod tests {
     use super::*;
     use super::*;
@@ -592,4 +1084,126 @@ mod tests {
             assert_eq!(output[i], output_i);
             assert_eq!(output[i], output_i);
         }
         }
     }
     }
+
+    #[test]
+    fn test_masked_doprf() {
+        let output_bitsize = 42;
+
+        let mut party_1 = MaskedDOPrfParty1::<Fp>::new(output_bitsize);
+        let mut party_2 = MaskedDOPrfParty2::<Fp>::new(output_bitsize);
+        let mut party_3 = MaskedDOPrfParty3::<Fp>::new(output_bitsize);
+
+        let (msg_1_2, msg_1_3) = party_1.init_round_0();
+        let (msg_2_1, msg_2_3) = party_2.init_round_0();
+        let (msg_3_1, msg_3_2) = party_3.init_round_0();
+        party_1.init_round_1(msg_2_1, msg_3_1);
+        party_2.init_round_1(msg_1_2, msg_3_2);
+        party_3.init_round_1(msg_1_3, msg_2_3);
+
+        // preprocess num invocations
+        let num = 20;
+
+        let (msg_1_2, msg_1_3) = party_1.preprocess_round_0(num);
+        let (msg_2_1, msg_2_3) = party_2.preprocess_round_0(num);
+        let (msg_3_1, msg_3_2) = party_3.preprocess_round_0(num);
+        party_1.preprocess_round_1(num, msg_2_1, msg_3_1);
+        party_2.preprocess_round_1(num, msg_1_2, msg_3_2);
+        party_3.preprocess_round_1(num, msg_1_3, msg_2_3);
+
+        assert_eq!(party_1.get_num_preprocessed_invocations(), num);
+        assert_eq!(party_2.get_num_preprocessed_invocations(), num);
+        assert_eq!(party_3.get_num_preprocessed_invocations(), num);
+
+        party_1.check_preprocessing();
+        party_2.check_preprocessing();
+        party_3.check_preprocessing();
+
+        // preprocess another n invocations
+        let (msg_1_2, msg_1_3) = party_1.preprocess_round_0(num);
+        let (msg_2_1, msg_2_3) = party_2.preprocess_round_0(num);
+        let (msg_3_1, msg_3_2) = party_3.preprocess_round_0(num);
+        party_1.preprocess_round_1(num, msg_2_1, msg_3_1);
+        party_2.preprocess_round_1(num, msg_1_2, msg_3_2);
+        party_3.preprocess_round_1(num, msg_1_3, msg_2_3);
+
+        let num = 2 * num;
+
+        assert_eq!(party_1.get_num_preprocessed_invocations(), num);
+        assert_eq!(party_2.get_num_preprocessed_invocations(), num);
+        assert_eq!(party_3.get_num_preprocessed_invocations(), num);
+
+        party_1.check_preprocessing();
+        party_2.check_preprocessing();
+        party_3.check_preprocessing();
+
+        // verify preprocessed data
+        {
+            let n = num * output_bitsize;
+            let (rerand_m1, mt_a, mt_c1, mult_e) = party_1.get_preprocessed_data();
+            let (r2, rerand_m2) = party_2.get_preprocessed_data();
+            let (r3, ts, mt_c3) = party_3.get_preprocessed_data();
+
+            assert_eq!(r2.len(), n);
+            assert_eq!(r2, r3);
+            assert_eq!(ts.len(), n);
+            assert!(r2.iter().by_vals().zip(ts.iter()).all(|(r_i, &t_i)| {
+                if r_i {
+                    Fp::legendre_symbol(t_i) == -Fp::ONE
+                } else {
+                    Fp::legendre_symbol(t_i) == Fp::ONE
+                }
+            }));
+
+            assert_eq!(rerand_m1.len(), num);
+            assert_eq!(rerand_m2.len(), num);
+            assert!(izip!(rerand_m1.iter(), rerand_m2.iter()).all(|(&m1, &m2)| m1 + m2 == Fp::ZERO));
+
+            let mt_b: Vec<Fp> = ts.iter().zip(mult_e.iter()).map(|(&t, &e)| t - e).collect();
+            assert_eq!(mult_e.len(), n);
+
+            assert_eq!(mt_a.len(), n);
+            assert_eq!(mt_b.len(), n);
+            assert_eq!(mt_c1.len(), n);
+            assert_eq!(mt_c3.len(), n);
+            let mut triple_it = izip!(mt_a.iter(), mt_b.iter(), mt_c1.iter(), mt_c3.iter());
+            assert_eq!(triple_it.clone().count(), n);
+            assert!(triple_it.all(|(&a, &b, &c1, &c3)| a * b == c1 + c3));
+        }
+
+        // perform n evaluations
+        let num = 15;
+
+        let shares_1: Vec<Fp> = (0..num).map(|_| Fp::random(thread_rng())).collect();
+        let shares_2: Vec<Fp> = (0..num).map(|_| Fp::random(thread_rng())).collect();
+        let shares_3: Vec<Fp> = (0..num).map(|_| Fp::random(thread_rng())).collect();
+
+        let (_, msg_1_3) = party_1.eval_round_0(num, &shares_1);
+        let (_, msg_2_3) = party_2.eval_round_0(num, &shares_2);
+        let (msg_3_1, _) = party_3.eval_round_1(num, &shares_3, &msg_1_3, &msg_2_3);
+        let masked_output = party_1.eval_round_2(num, &shares_1, (), msg_3_1);
+        let mask2 = party_2.eval_get_output(num);
+        let mask3 = party_3.eval_get_output(num);
+
+        assert_eq!(party_1.get_num_preprocessed_invocations(), 25);
+        assert_eq!(party_2.get_num_preprocessed_invocations(), 25);
+        assert_eq!(party_3.get_num_preprocessed_invocations(), 25);
+        party_1.check_preprocessing();
+        party_2.check_preprocessing();
+        party_3.check_preprocessing();
+
+        assert_eq!(masked_output.len(), num);
+        assert!(masked_output.iter().all(|bv| bv.len() == output_bitsize));
+        assert_eq!(mask2.len(), num);
+        assert_eq!(mask2, mask3);
+        assert!(mask2.iter().all(|bv| bv.len() == output_bitsize));
+
+        // check that the output matches the non-distributed version
+        let legendre_prf_key = party_1.get_legendre_prf_key();
+        for i in 0..num {
+            let input_i = shares_1[i] + shares_2[i] + shares_3[i];
+            let expected_output_i = LegendrePrf::<Fp>::eval(&legendre_prf_key, input_i);
+            let output_i = masked_output[i].clone() ^ &mask2[i];
+            assert_eq!(output_i, expected_output_i);
+        }
+    }
 }
 }