Browse Source

doprf: update tests

Lennart Braun 2 years ago
parent
commit
e6cd84972f
1 changed files with 213 additions and 48 deletions
  1. 213 48
      oram/src/doprf.rs

+ 213 - 48
oram/src/doprf.rs

@@ -942,30 +942,66 @@ mod tests {
     use ff::Field;
     use utils::field::Fp;
 
-    #[test]
-    fn test_doprf() {
-        let output_bitsize = 42;
-
-        let mut party_1 = DOPrfParty1::<Fp>::new(output_bitsize);
-        let mut party_2 = DOPrfParty2::<Fp>::new(output_bitsize);
-        let mut party_3 = DOPrfParty3::<Fp>::new(output_bitsize);
-
+    fn doprf_init(
+        party_1: &mut DOPrfParty1<Fp>,
+        party_2: &mut DOPrfParty2<Fp>,
+        party_3: &mut DOPrfParty3<Fp>,
+    ) {
         let (msg_1_2, msg_1_3) = party_1.init_round_0();
         let (msg_2_1, msg_2_3) = party_2.init_round_0();
         let (msg_3_1, msg_3_2) = party_3.init_round_0();
         party_1.init_round_1(msg_2_1, msg_3_1);
         party_2.init_round_1(msg_1_2, msg_3_2);
         party_3.init_round_1(msg_1_3, msg_2_3);
+    }
 
-        // preprocess num invocations
-        let num = 20;
-
+    fn doprf_preprocess(
+        party_1: &mut DOPrfParty1<Fp>,
+        party_2: &mut DOPrfParty2<Fp>,
+        party_3: &mut DOPrfParty3<Fp>,
+        num: usize,
+    ) {
         let (msg_1_2, msg_1_3) = party_1.preprocess_round_0(num);
         let (msg_2_1, msg_2_3) = party_2.preprocess_round_0(num);
         let (msg_3_1, msg_3_2) = party_3.preprocess_round_0(num);
         party_1.preprocess_round_1(num, msg_2_1, msg_3_1);
         party_2.preprocess_round_1(num, msg_1_2, msg_3_2);
         party_3.preprocess_round_1(num, msg_1_3, msg_2_3);
+    }
+
+    fn doprf_eval(
+        party_1: &mut DOPrfParty1<Fp>,
+        party_2: &mut DOPrfParty2<Fp>,
+        party_3: &mut DOPrfParty3<Fp>,
+        shares_1: &[Fp],
+        shares_2: &[Fp],
+        shares_3: &[Fp],
+        num: usize,
+    ) -> Vec<BitVec> {
+        assert_eq!(shares_1.len(), num);
+        assert_eq!(shares_2.len(), num);
+        assert_eq!(shares_3.len(), num);
+
+        let (msg_2_1, msg_2_3) = party_2.eval_round_0(num, &shares_2);
+        let (msg_3_1, _) = party_3.eval_round_0(num, &shares_3);
+        let (_, msg_1_3) = party_1.eval_round_1(num, &shares_1, &msg_2_1, &msg_3_1);
+        let output = party_3.eval_round_2(num, &shares_3, msg_1_3, msg_2_3);
+        output
+    }
+
+    #[test]
+    fn test_doprf() {
+        let output_bitsize = 42;
+
+        let mut party_1 = DOPrfParty1::<Fp>::new(output_bitsize);
+        let mut party_2 = DOPrfParty2::<Fp>::new(output_bitsize);
+        let mut party_3 = DOPrfParty3::<Fp>::new(output_bitsize);
+
+        doprf_init(&mut party_1, &mut party_2, &mut party_3);
+
+        // preprocess num invocations
+        let num = 20;
+        doprf_preprocess(&mut party_1, &mut party_2, &mut party_3, num);
 
         assert_eq!(party_1.get_num_preprocessed_invocations(), num);
         assert_eq!(party_2.get_num_preprocessed_invocations(), num);
@@ -976,12 +1012,7 @@ mod tests {
         party_3.check_preprocessing();
 
         // preprocess another n invocations
-        let (msg_1_2, msg_1_3) = party_1.preprocess_round_0(num);
-        let (msg_2_1, msg_2_3) = party_2.preprocess_round_0(num);
-        let (msg_3_1, msg_3_2) = party_3.preprocess_round_0(num);
-        party_1.preprocess_round_1(num, msg_2_1, msg_3_1);
-        party_2.preprocess_round_1(num, msg_1_2, msg_3_2);
-        party_3.preprocess_round_1(num, msg_1_3, msg_2_3);
+        doprf_preprocess(&mut party_1, &mut party_2, &mut party_3, num);
 
         let num = 2 * num;
 
@@ -1013,9 +1044,6 @@ mod tests {
                 .map(|(&s, &d)| s - d)
                 .collect();
             assert_eq!(mult_d.len(), n);
-            // assert!(
-            //     izip!(squares.iter(), mt_a.iter(), mult_d.iter()).all(|(&s, &a, &d)| d == s - a)
-            // );
 
             assert_eq!(mt_a.len(), n);
             assert_eq!(mt_b.len(), num);
@@ -1038,10 +1066,15 @@ mod tests {
         let shares_2: Vec<Fp> = (0..num).map(|_| Fp::random(thread_rng())).collect();
         let shares_3: Vec<Fp> = (0..num).map(|_| Fp::random(thread_rng())).collect();
 
-        let (msg_2_1, msg_2_3) = party_2.eval_round_0(num, &shares_2);
-        let (msg_3_1, _) = party_3.eval_round_0(num, &shares_3);
-        let (_, msg_1_3) = party_1.eval_round_1(num, &shares_1, &msg_2_1, &msg_3_1);
-        let output = party_3.eval_round_2(num, &shares_3, msg_1_3, msg_2_3);
+        let output = doprf_eval(
+            &mut party_1,
+            &mut party_2,
+            &mut party_3,
+            &shares_1,
+            &shares_2,
+            &shares_3,
+            num,
+        );
 
         assert_eq!(party_1.get_num_preprocessed_invocations(), 25);
         assert_eq!(party_2.get_num_preprocessed_invocations(), 25);
@@ -1062,30 +1095,68 @@ mod tests {
         }
     }
 
-    #[test]
-    fn test_masked_doprf() {
-        let output_bitsize = 42;
-
-        let mut party_1 = MaskedDOPrfParty1::<Fp>::new(output_bitsize);
-        let mut party_2 = MaskedDOPrfParty2::<Fp>::new(output_bitsize);
-        let mut party_3 = MaskedDOPrfParty3::<Fp>::new(output_bitsize);
-
+    fn mdoprf_init(
+        party_1: &mut MaskedDOPrfParty1<Fp>,
+        party_2: &mut MaskedDOPrfParty2<Fp>,
+        party_3: &mut MaskedDOPrfParty3<Fp>,
+    ) {
         let (msg_1_2, msg_1_3) = party_1.init_round_0();
         let (msg_2_1, msg_2_3) = party_2.init_round_0();
         let (msg_3_1, msg_3_2) = party_3.init_round_0();
         party_1.init_round_1(msg_2_1, msg_3_1);
         party_2.init_round_1(msg_1_2, msg_3_2);
         party_3.init_round_1(msg_1_3, msg_2_3);
+    }
 
-        // preprocess num invocations
-        let num = 20;
-
+    fn mdoprf_preprocess(
+        party_1: &mut MaskedDOPrfParty1<Fp>,
+        party_2: &mut MaskedDOPrfParty2<Fp>,
+        party_3: &mut MaskedDOPrfParty3<Fp>,
+        num: usize,
+    ) {
         let (msg_1_2, msg_1_3) = party_1.preprocess_round_0(num);
         let (msg_2_1, msg_2_3) = party_2.preprocess_round_0(num);
         let (msg_3_1, msg_3_2) = party_3.preprocess_round_0(num);
         party_1.preprocess_round_1(num, msg_2_1, msg_3_1);
         party_2.preprocess_round_1(num, msg_1_2, msg_3_2);
         party_3.preprocess_round_1(num, msg_1_3, msg_2_3);
+    }
+
+    fn mdoprf_eval(
+        party_1: &mut MaskedDOPrfParty1<Fp>,
+        party_2: &mut MaskedDOPrfParty2<Fp>,
+        party_3: &mut MaskedDOPrfParty3<Fp>,
+        shares_1: &[Fp],
+        shares_2: &[Fp],
+        shares_3: &[Fp],
+        num: usize,
+    ) -> (Vec<BitVec>, Vec<BitVec>, Vec<BitVec>) {
+        assert_eq!(shares_1.len(), num);
+        assert_eq!(shares_2.len(), num);
+        assert_eq!(shares_3.len(), num);
+
+        let (_, msg_1_3) = party_1.eval_round_0(num, &shares_1);
+        let (_, msg_2_3) = party_2.eval_round_0(num, &shares_2);
+        let (msg_3_1, _) = party_3.eval_round_1(num, &shares_3, &msg_1_3, &msg_2_3);
+        let masked_output = party_1.eval_round_2(num, &shares_1, (), msg_3_1);
+        let mask2 = party_2.eval_get_output(num);
+        let mask3 = party_3.eval_get_output(num);
+        (masked_output, mask2, mask3)
+    }
+
+    #[test]
+    fn test_masked_doprf() {
+        let output_bitsize = 42;
+
+        let mut party_1 = MaskedDOPrfParty1::<Fp>::new(output_bitsize);
+        let mut party_2 = MaskedDOPrfParty2::<Fp>::new(output_bitsize);
+        let mut party_3 = MaskedDOPrfParty3::<Fp>::new(output_bitsize);
+
+        mdoprf_init(&mut party_1, &mut party_2, &mut party_3);
+
+        // preprocess num invocations
+        let num = 20;
+        mdoprf_preprocess(&mut party_1, &mut party_2, &mut party_3, num);
 
         assert_eq!(party_1.get_num_preprocessed_invocations(), num);
         assert_eq!(party_2.get_num_preprocessed_invocations(), num);
@@ -1096,12 +1167,7 @@ mod tests {
         party_3.check_preprocessing();
 
         // preprocess another n invocations
-        let (msg_1_2, msg_1_3) = party_1.preprocess_round_0(num);
-        let (msg_2_1, msg_2_3) = party_2.preprocess_round_0(num);
-        let (msg_3_1, msg_3_2) = party_3.preprocess_round_0(num);
-        party_1.preprocess_round_1(num, msg_2_1, msg_3_1);
-        party_2.preprocess_round_1(num, msg_1_2, msg_3_2);
-        party_3.preprocess_round_1(num, msg_1_3, msg_2_3);
+        mdoprf_preprocess(&mut party_1, &mut party_2, &mut party_3, num);
 
         let num = 2 * num;
 
@@ -1153,13 +1219,15 @@ mod tests {
         let shares_1: Vec<Fp> = (0..num).map(|_| Fp::random(thread_rng())).collect();
         let shares_2: Vec<Fp> = (0..num).map(|_| Fp::random(thread_rng())).collect();
         let shares_3: Vec<Fp> = (0..num).map(|_| Fp::random(thread_rng())).collect();
-
-        let (_, msg_1_3) = party_1.eval_round_0(num, &shares_1);
-        let (_, msg_2_3) = party_2.eval_round_0(num, &shares_2);
-        let (msg_3_1, _) = party_3.eval_round_1(num, &shares_3, &msg_1_3, &msg_2_3);
-        let masked_output = party_1.eval_round_2(num, &shares_1, (), msg_3_1);
-        let mask2 = party_2.eval_get_output(num);
-        let mask3 = party_3.eval_get_output(num);
+        let (masked_output, mask2, mask3) = mdoprf_eval(
+            &mut party_1,
+            &mut party_2,
+            &mut party_3,
+            &shares_1,
+            &shares_2,
+            &shares_3,
+            num,
+        );
 
         assert_eq!(party_1.get_num_preprocessed_invocations(), 25);
         assert_eq!(party_2.get_num_preprocessed_invocations(), 25);
@@ -1182,5 +1250,102 @@ mod tests {
             let output_i = masked_output[i].clone() ^ &mask2[i];
             assert_eq!(output_i, expected_output_i);
         }
+
+        // preprocess another n invocations
+        mdoprf_preprocess(&mut party_1, &mut party_2, &mut party_3, num);
+
+        // perform another n evaluations on the same inputs
+        let num = 15;
+
+        let (new_masked_output, new_mask2, new_mask3) = mdoprf_eval(
+            &mut party_1,
+            &mut party_2,
+            &mut party_3,
+            &shares_1,
+            &shares_2,
+            &shares_3,
+            num,
+        );
+
+        assert_eq!(party_1.get_num_preprocessed_invocations(), 25);
+        assert_eq!(party_2.get_num_preprocessed_invocations(), 25);
+        assert_eq!(party_3.get_num_preprocessed_invocations(), 25);
+        party_1.check_preprocessing();
+        party_2.check_preprocessing();
+        party_3.check_preprocessing();
+
+        assert_eq!(new_masked_output.len(), num);
+        assert!(new_masked_output
+            .iter()
+            .all(|bv| bv.len() == output_bitsize));
+        assert_eq!(new_mask2.len(), num);
+        assert_eq!(new_mask2, new_mask3);
+        assert!(new_mask2.iter().all(|bv| bv.len() == output_bitsize));
+
+        // check that the new output matches the previous one
+        for i in 0..num {
+            let expected_output_i = masked_output[i].clone() ^ &mask2[i];
+            let output_i = new_masked_output[i].clone() ^ &new_mask2[i];
+            assert_eq!(output_i, expected_output_i);
+        }
+    }
+
+    #[test]
+    fn test_masked_doprf_single() {
+        let output_bitsize = 42;
+
+        let mut party_1 = MaskedDOPrfParty1::<Fp>::new(output_bitsize);
+        let mut party_2 = MaskedDOPrfParty2::<Fp>::new(output_bitsize);
+        let mut party_3 = MaskedDOPrfParty3::<Fp>::new(output_bitsize);
+
+        mdoprf_init(&mut party_1, &mut party_2, &mut party_3);
+
+        let share_1 = Fp::random(thread_rng());
+        let share_2 = Fp::random(thread_rng());
+        let share_3 = Fp::random(thread_rng());
+
+        mdoprf_preprocess(&mut party_1, &mut party_2, &mut party_3, 1);
+        let (masked_output_1, mask2_1, mask3_1) = mdoprf_eval(
+            &mut party_1,
+            &mut party_2,
+            &mut party_3,
+            &[share_1],
+            &[share_2],
+            &[share_3],
+            1,
+        );
+        mdoprf_preprocess(&mut party_1, &mut party_2, &mut party_3, 1);
+        let (masked_output_2, mask2_2, mask3_2) = mdoprf_eval(
+            &mut party_1,
+            &mut party_2,
+            &mut party_3,
+            &[share_1],
+            &[share_2],
+            &[share_3],
+            1,
+        );
+        mdoprf_preprocess(&mut party_1, &mut party_2, &mut party_3, 1);
+        let (masked_output_3, mask2_3, mask3_3) = mdoprf_eval(
+            &mut party_1,
+            &mut party_2,
+            &mut party_3,
+            &[share_1],
+            &[share_2],
+            &[share_3],
+            1,
+        );
+
+        assert_eq!(mask2_1, mask3_1);
+        assert_eq!(mask2_2, mask3_2);
+        assert_eq!(mask2_3, mask3_3);
+        let plain_output = masked_output_1[0].clone() ^ mask2_1[0].clone();
+        assert_eq!(
+            masked_output_2[0].clone() ^ mask2_2[0].clone(),
+            plain_output
+        );
+        assert_eq!(
+            masked_output_3[0].clone() ^ mask2_3[0].clone(),
+            plain_output
+        );
     }
 }