Ver código fonte

doprf: only P1 has key, output bitstring

Lennart Braun 2 anos atrás
pai
commit
c0f521357c
3 arquivos alterados com 74 adições e 95 exclusões
  1. 1 0
      oram/Cargo.toml
  2. 7 5
      oram/benches/doprf.rs
  3. 66 90
      oram/src/doprf.rs

+ 1 - 0
oram/Cargo.toml

@@ -7,6 +7,7 @@ edition = "2021"
 
 [dependencies]
 utils = { path = "../utils" }
+bitvec = "1.0.1"
 ff = "0.13.0"
 itertools = "0.10.5"
 num-bigint = "0.4.3"

+ 7 - 5
oram/benches/doprf.rs

@@ -6,12 +6,13 @@ use rand::thread_rng;
 use utils::field::Fp;
 
 pub fn bench_legendre_prf(c: &mut Criterion) {
+    let output_bitsize = 128;
     let mut group = c.benchmark_group("LegendrePrf");
     group.bench_function("keygen", |b| {
-        b.iter(|| black_box(LegendrePrf::<Fp>::key_gen()))
+        b.iter(|| black_box(LegendrePrf::<Fp>::key_gen(output_bitsize)))
     });
     group.bench_function("eval", |b| {
-        let key = LegendrePrf::<Fp>::key_gen();
+        let key = LegendrePrf::<Fp>::key_gen(output_bitsize);
         let x = Fp::random(thread_rng());
         b.iter(|| black_box(LegendrePrf::<Fp>::eval(&key, x)))
     });
@@ -21,11 +22,12 @@ pub fn bench_legendre_prf(c: &mut Criterion) {
 const LOG_NUM_EVALUATIONS: [usize; 4] = [4, 6, 8, 10];
 
 pub fn bench_doprf(c: &mut Criterion) {
+    let output_bitsize = 128;
     let mut group = c.benchmark_group("DOPrf");
 
-    let mut party_1 = DOPrfParty1::<Fp>::new();
-    let mut party_2 = DOPrfParty2::<Fp>::new();
-    let mut party_3 = DOPrfParty3::<Fp>::new();
+    let mut party_1 = DOPrfParty1::<Fp>::new(output_bitsize);
+    let mut party_2 = DOPrfParty2::<Fp>::new(output_bitsize);
+    let mut party_3 = DOPrfParty3::<Fp>::new(output_bitsize);
 
     group.bench_function("init", |b| {
         b.iter(|| {

+ 66 - 90
oram/src/doprf.rs

@@ -1,51 +1,43 @@
-use utils::field::{FromLimbs, FromPrf, LegendreSymbol, Modulus128};
+use bitvec::vec::BitVec;
 use core::marker::PhantomData;
 use itertools::izip;
-use num_bigint::BigUint;
-use num_traits::identities::Zero;
 use rand::thread_rng;
 use std::iter::repeat;
+use utils::field::{FromLimbs, FromPrf, LegendreSymbol, Modulus128};
 
 #[derive(Clone, Debug, Eq, PartialEq)]
 pub struct LegendrePrfKey<F: LegendreSymbol> {
     pub keys: Vec<F>,
 }
 
+impl<F: LegendreSymbol> LegendrePrfKey<F> {
+    pub fn get_output_bitsize(&self) -> usize {
+        self.keys.len()
+    }
+}
+
 /// Legendre PRF: F x F -> F
 pub struct LegendrePrf<F> {
     _phantom: PhantomData<F>,
 }
 
 impl<F: LegendreSymbol> LegendrePrf<F> {
-    pub const BITS_NEEDED: usize = F::NUM_BITS as usize + 40;
-}
-
-impl<F> LegendrePrf<F>
-where
-    F: LegendreSymbol + Modulus128 + FromLimbs,
-{
-    pub fn key_gen() -> LegendrePrfKey<F> {
+    pub fn key_gen(output_bitsize: usize) -> LegendrePrfKey<F> {
         LegendrePrfKey {
-            keys: (0..Self::BITS_NEEDED)
+            keys: (0..output_bitsize)
                 .map(|_| F::random(thread_rng()))
                 .collect(),
         }
     }
 
-    pub fn eval(key: &LegendrePrfKey<F>, input: F) -> F {
-        let mut int = BigUint::zero();
-
-        for (i, &k) in key.keys.iter().enumerate() {
-            let output = F::legendre_symbol(k + input);
-            assert!(output != F::ZERO, "unlikely");
-            int.set_bit(i.try_into().unwrap(), output == F::ONE);
+    pub fn eval(key: &LegendrePrfKey<F>, input: F) -> BitVec {
+        let mut output = BitVec::with_capacity(key.keys.len());
+        for &k in key.keys.iter() {
+            let ls = F::legendre_symbol(k + input);
+            assert!(ls != F::ZERO, "unlikely");
+            output.push(ls == F::ONE);
         }
-
-        let int = int % F::MOD;
-        assert_eq!(F::NUM_BITS, 128);
-        let mut limbs = int.to_u64_digits();
-        limbs.push(0);
-        F::from_limbs(&limbs)
+        output
     }
 }
 
@@ -79,6 +71,7 @@ impl<F: FromPrf> SharedPrf<F> {
 
 pub struct DOPrfParty1<F: LegendreSymbol + FromPrf> {
     _phantom: PhantomData<F>,
+    output_bitsize: usize,
     shared_prf_1_2: Option<SharedPrf<F>>,
     shared_prf_1_3: Option<SharedPrf<F>>,
     legendre_prf_key: Option<LegendrePrfKey<F>>,
@@ -92,9 +85,10 @@ impl<F> DOPrfParty1<F>
 where
     F: LegendreSymbol + FromPrf,
 {
-    pub fn new() -> Self {
+    pub fn new(output_bitsize: usize) -> Self {
         Self {
             _phantom: PhantomData,
+            output_bitsize,
             shared_prf_1_2: None,
             shared_prf_1_3: None,
             legendre_prf_key: None,
@@ -106,7 +100,7 @@ where
     }
 
     pub fn reset(&mut self) {
-        *self = Self::new()
+        *self = Self::new(self.output_bitsize)
     }
 
     pub fn reset_preprocessing(&mut self) {
@@ -126,12 +120,8 @@ where
         assert!(!self.is_initialized);
         // receive shared PRF key from Party 3
         self.shared_prf_1_3 = Some(SharedPrf::from_key(shared_prf_key_1_3));
-        // generate Legendre PRF key (shared with party 3)
-        self.legendre_prf_key = Some(LegendrePrfKey {
-            keys: (0..LegendrePrf::<F>::BITS_NEEDED)
-                .map(|_| self.shared_prf_1_2.as_mut().unwrap().eval())
-                .collect(),
-        });
+        // generate Legendre PRF key
+        self.legendre_prf_key = Some(LegendrePrf::key_gen(self.output_bitsize));
         self.is_initialized = true;
     }
 
@@ -142,7 +132,7 @@ where
 
     pub fn preprocess_round_0(&mut self, num: usize) -> ((), ()) {
         assert!(self.is_initialized);
-        let n = num * LegendrePrf::<F>::BITS_NEEDED;
+        let n = num * self.output_bitsize;
         self.preprocessed_squares
             .extend((0..n).map(|_| self.shared_prf_1_2.as_mut().unwrap().eval().square()));
         ((), ())
@@ -150,7 +140,7 @@ where
 
     pub fn preprocess_round_1(&mut self, num: usize, preprocessed_mt_c1: Vec<F>, _: ()) {
         assert!(self.is_initialized);
-        let n = num * LegendrePrf::<F>::BITS_NEEDED;
+        let n = num * self.output_bitsize;
         assert_eq!(preprocessed_mt_c1.len(), n);
         self.preprocessed_mt_c1.extend(preprocessed_mt_c1);
         self.num_preprocessed_invocations += num;
@@ -166,7 +156,7 @@ where
 
     pub fn check_preprocessing(&self) {
         let num = self.num_preprocessed_invocations;
-        let n = num * LegendrePrf::<F>::BITS_NEEDED;
+        let n = num * self.output_bitsize;
         assert_eq!(self.preprocessed_squares.len(), n);
         assert_eq!(self.preprocessed_mt_c1.len(), n);
     }
@@ -179,25 +169,25 @@ where
         mult_e: &[F],
     ) -> ((), Vec<F>) {
         assert!(num <= self.num_preprocessed_invocations);
-        let n = num * LegendrePrf::<F>::BITS_NEEDED;
+        let n = num * self.output_bitsize;
         assert_eq!(shares1.len(), num);
         assert_eq!(masked_shares2.len(), num);
         assert_eq!(mult_e.len(), num);
         let k = &self.legendre_prf_key.as_ref().unwrap().keys;
-        assert_eq!(k.len(), LegendrePrf::<F>::BITS_NEEDED);
+        assert_eq!(k.len(), self.output_bitsize);
         let output_shares_z1: Vec<F> = izip!(
             shares1
                 .iter()
-                .flat_map(|s1i| repeat(s1i).take(LegendrePrf::<F>::BITS_NEEDED)),
+                .flat_map(|s1i| repeat(s1i).take(self.output_bitsize)),
             masked_shares2
                 .iter()
-                .flat_map(|ms2i| repeat(ms2i).take(LegendrePrf::<F>::BITS_NEEDED)),
+                .flat_map(|ms2i| repeat(ms2i).take(self.output_bitsize)),
             k.iter().cycle(),
             self.preprocessed_squares.drain(0..n),
             self.preprocessed_mt_c1.drain(0..n),
             mult_e
                 .iter()
-                .flat_map(|e| repeat(e).take(LegendrePrf::<F>::BITS_NEEDED)),
+                .flat_map(|e| repeat(e).take(self.output_bitsize)),
         )
         .map(|(&s1_i, &ms2_i, &k_j, sq_ij, c1_ij, &e_ij)| {
             sq_ij * (k_j + s1_i + ms2_i) + e_ij * sq_ij + c1_ij
@@ -210,9 +200,9 @@ where
 
 pub struct DOPrfParty2<F: LegendreSymbol + FromPrf> {
     _phantom: PhantomData<F>,
+    output_bitsize: usize,
     shared_prf_1_2: Option<SharedPrf<F>>,
     shared_prf_2_3: Option<SharedPrf<F>>,
-    legendre_prf_key: Option<LegendrePrfKey<F>>,
     is_initialized: bool,
     num_preprocessed_invocations: usize,
     preprocessed_rerand_m2: Vec<F>,
@@ -222,12 +212,12 @@ impl<F> DOPrfParty2<F>
 where
     F: LegendreSymbol + FromPrf,
 {
-    pub fn new() -> Self {
+    pub fn new(output_bitsize: usize) -> Self {
         Self {
             _phantom: PhantomData,
+            output_bitsize,
             shared_prf_1_2: None,
             shared_prf_2_3: None,
-            legendre_prf_key: None,
             is_initialized: false,
             num_preprocessed_invocations: 0,
             preprocessed_rerand_m2: Default::default(),
@@ -235,7 +225,7 @@ where
     }
 
     pub fn reset(&mut self) {
-        *self = Self::new()
+        *self = Self::new(self.output_bitsize)
     }
 
     pub fn reset_preprocessing(&mut self) {
@@ -253,23 +243,12 @@ where
         assert!(!self.is_initialized);
         // receive shared PRF key from Party 1
         self.shared_prf_1_2 = Some(SharedPrf::from_key(shared_prf_key_1_2));
-        // generate Legendre PRF key (shared with party 1)
-        self.legendre_prf_key = Some(LegendrePrfKey {
-            keys: (0..F::NUM_BITS + 40)
-                .map(|_| self.shared_prf_1_2.as_mut().unwrap().eval())
-                .collect(),
-        });
         self.is_initialized = true;
     }
 
-    pub fn get_legendre_prf_key(&self) -> LegendrePrfKey<F> {
-        assert!(self.is_initialized);
-        self.legendre_prf_key.as_ref().unwrap().clone()
-    }
-
     pub fn preprocess_round_0(&mut self, num: usize) -> (Vec<F>, ()) {
         assert!(self.is_initialized);
-        let n = num * LegendrePrf::<F>::BITS_NEEDED;
+        let n = num * self.output_bitsize;
 
         let preprocessed_squares: Vec<F> = (0..n)
             .map(|_| self.shared_prf_1_2.as_mut().unwrap().eval().square())
@@ -290,7 +269,7 @@ where
             preprocessed_mult_d.iter(),
             preprocessed_mt_b
                 .iter()
-                .flat_map(|b| repeat(b).take(LegendrePrf::<F>::BITS_NEEDED)),
+                .flat_map(|b| repeat(b).take(self.output_bitsize)),
             preprocessed_mt_c3.iter(),
         )
         .map(|(&s, &d, &b, &c3)| (s - d) * b - c3)
@@ -330,6 +309,7 @@ where
 
 pub struct DOPrfParty3<F: LegendreSymbol + FromPrf> {
     _phantom: PhantomData<F>,
+    output_bitsize: usize,
     shared_prf_1_3: Option<SharedPrf<F>>,
     shared_prf_2_3: Option<SharedPrf<F>>,
     is_initialized: bool,
@@ -345,9 +325,10 @@ impl<F> DOPrfParty3<F>
 where
     F: LegendreSymbol + FromPrf + FromLimbs + Modulus128,
 {
-    pub fn new() -> Self {
+    pub fn new(output_bitsize: usize) -> Self {
         Self {
             _phantom: PhantomData,
+            output_bitsize,
             shared_prf_1_3: None,
             shared_prf_2_3: None,
             is_initialized: false,
@@ -361,7 +342,7 @@ where
     }
 
     pub fn reset(&mut self) {
-        *self = Self::new()
+        *self = Self::new(self.output_bitsize)
     }
 
     pub fn reset_preprocessing(&mut self) {
@@ -386,7 +367,7 @@ where
 
     pub fn preprocess_round_0(&mut self, num: usize) -> ((), ()) {
         assert!(self.is_initialized);
-        let n = num * LegendrePrf::<F>::BITS_NEEDED;
+        let n = num * self.output_bitsize;
 
         self.preprocessed_rerand_m3
             .extend((0..num).map(|_| -self.shared_prf_2_3.as_mut().unwrap().eval()));
@@ -419,7 +400,7 @@ where
 
     pub fn check_preprocessing(&self) {
         let num = self.num_preprocessed_invocations;
-        let n = num * LegendrePrf::<F>::BITS_NEEDED;
+        let n = num * self.output_bitsize;
         assert_eq!(self.preprocessed_rerand_m3.len(), num);
         assert_eq!(self.preprocessed_mt_b.len(), num);
         assert_eq!(self.preprocessed_mt_c3.len(), n);
@@ -445,22 +426,22 @@ where
         shares3: &[F],
         output_shares_z1: Vec<F>,
         _: (),
-    ) -> Vec<F> {
+    ) -> Vec<BitVec> {
         assert!(num <= self.num_preprocessed_invocations);
-        let n = num * LegendrePrf::<F>::BITS_NEEDED;
+        let n = num * self.output_bitsize;
         assert_eq!(shares3.len(), num);
         assert_eq!(output_shares_z1.len(), n);
         let lprf_inputs: Vec<F> = izip!(
             shares3
                 .iter()
-                .flat_map(|s3| repeat(s3).take(LegendrePrf::<F>::BITS_NEEDED)),
+                .flat_map(|s3| repeat(s3).take(self.output_bitsize)),
             self.preprocessed_rerand_m3
                 .drain(0..num)
-                .flat_map(|m3| repeat(m3).take(LegendrePrf::<F>::BITS_NEEDED)),
+                .flat_map(|m3| repeat(m3).take(self.output_bitsize)),
             self.preprocessed_mult_d.drain(0..n),
             self.mult_e
                 .drain(0..num)
-                .flat_map(|e| repeat(e).take(LegendrePrf::<F>::BITS_NEEDED)),
+                .flat_map(|e| repeat(e).take(self.output_bitsize)),
             self.preprocessed_mt_c3.drain(0..n),
             output_shares_z1.iter(),
         )
@@ -469,20 +450,16 @@ where
         })
         .collect();
         assert_eq!(lprf_inputs.len(), n);
-        let output: Vec<F> = lprf_inputs
-            .chunks_exact(LegendrePrf::<F>::BITS_NEEDED)
+        let output: Vec<BitVec> = lprf_inputs
+            .chunks_exact(self.output_bitsize)
             .map(|chunk| {
-                let mut int = BigUint::zero();
-                for (i, &x) in chunk.iter().enumerate() {
-                    let output = F::legendre_symbol(x);
-                    assert!(output != F::ZERO, "unlikely");
-                    int.set_bit(i.try_into().unwrap(), output == F::ONE);
+                let mut bv = BitVec::with_capacity(self.output_bitsize);
+                for &x in chunk.iter() {
+                    let ls = F::legendre_symbol(x);
+                    assert!(ls != F::ZERO, "unlikely");
+                    bv.push(ls == F::ONE);
                 }
-                let int = int % F::MOD;
-                assert_eq!(F::NUM_BITS, 128);
-                let mut limbs = int.to_u64_digits();
-                limbs.push(0);
-                F::from_limbs(&limbs)
+                bv
             })
             .collect();
         self.num_preprocessed_invocations -= num;
@@ -493,14 +470,16 @@ where
 #[cfg(test)]
 mod tests {
     use super::*;
-    use utils::field::Fp;
     use ff::Field;
+    use utils::field::Fp;
 
     #[test]
     fn test_doprf() {
-        let mut party_1 = DOPrfParty1::<Fp>::new();
-        let mut party_2 = DOPrfParty2::<Fp>::new();
-        let mut party_3 = DOPrfParty3::<Fp>::new();
+        let output_bitsize = 42;
+
+        let mut party_1 = DOPrfParty1::<Fp>::new(output_bitsize);
+        let mut party_2 = DOPrfParty2::<Fp>::new(output_bitsize);
+        let mut party_3 = DOPrfParty3::<Fp>::new(output_bitsize);
 
         let (msg_1_2, msg_1_3) = party_1.init_round_0();
         let (msg_2_1, msg_2_3) = party_2.init_round_0();
@@ -509,10 +488,6 @@ mod tests {
         party_2.init_round_1(msg_1_2, msg_3_2);
         party_3.init_round_1(msg_1_3, msg_2_3);
 
-        // check that both parties generate the same Legendre PRF key
-        let legendre_prf_key = party_1.get_legendre_prf_key();
-        assert_eq!(legendre_prf_key, party_2.get_legendre_prf_key());
-
         // preprocess num invocations
         let num = 20;
 
@@ -551,7 +526,7 @@ mod tests {
 
         // verify preprocessed data
         {
-            let n = num * LegendrePrf::<Fp>::BITS_NEEDED;
+            let n = num * output_bitsize;
             let (squares, mt_c1) = party_1.get_preprocessed_data();
             let rerand_m2 = party_2.get_preprocessed_data();
             let (rerand_m3, mt_b, mt_c3, mult_d) = party_3.get_preprocessed_data();
@@ -579,8 +554,7 @@ mod tests {
             assert_eq!(mt_c3.len(), n);
             let mut triple_it = izip!(
                 mt_a.iter(),
-                mt_b.iter()
-                    .flat_map(|b| repeat(b).take(LegendrePrf::<Fp>::BITS_NEEDED)),
+                mt_b.iter().flat_map(|b| repeat(b).take(output_bitsize)),
                 mt_c1.iter(),
                 mt_c3.iter()
             );
@@ -608,8 +582,10 @@ mod tests {
         party_3.check_preprocessing();
 
         assert_eq!(output.len(), num);
+        assert!(output.iter().all(|bv| bv.len() == output_bitsize));
 
         // check that the output matches the non-distributed version
+        let legendre_prf_key = party_1.get_legendre_prf_key();
         for i in 0..num {
             let input_i = shares_1[i] + shares_2[i] + shares_3[i];
             let output_i = LegendrePrf::<Fp>::eval(&legendre_prf_key, input_i);