Browse Source

doprf: add eval_to_uint

Lennart Braun 2 years ago
parent
commit
05ffc542e8
2 changed files with 83 additions and 31 deletions
  1. 66 0
      oram/src/doprf.rs
  2. 17 31
      oram/src/stash.rs

+ 66 - 0
oram/src/doprf.rs

@@ -64,6 +64,20 @@ impl<F: LegendreSymbol> LegendrePrf<F> {
     }
 }
 
+fn to_uint<T: Unsigned>(vs: impl IntoIterator<Item = impl IntoIterator<Item = bool>>) -> Vec<T> {
+    vs.into_iter()
+        .map(|v| {
+            let mut output = T::ZERO;
+            for (i, b) in v.into_iter().enumerate() {
+                if b {
+                    output |= T::ONE << i;
+                }
+            }
+            output
+        })
+        .collect()
+}
+
 type SharedSeed = [u8; 32];
 
 pub struct DOPrfParty1<F: LegendreSymbol> {
@@ -587,6 +601,19 @@ where
         let output = self.eval_round_2(num, shares3, fut_1_3.get()?, ());
         Ok(output)
     }
+
+    pub fn eval_to_uint<C: AbstractCommunicator, T: Unsigned>(
+        &mut self,
+        comm: &mut C,
+        num: usize,
+        shares3: &[F],
+    ) -> Result<Vec<T>, Error>
+    where
+        F: Serializable,
+    {
+        assert!(self.output_bitsize <= T::BITS as usize);
+        Ok(to_uint(self.eval(comm, num, shares3)?))
+    }
 }
 
 pub struct MaskedDOPrfParty1<F: LegendreSymbol> {
@@ -810,6 +837,19 @@ where
         let output = self.eval_round_2(1, shares1, (), fut_3_1.get()?);
         Ok(output)
     }
+
+    pub fn eval_to_uint<C: AbstractCommunicator, T: Unsigned>(
+        &mut self,
+        comm: &mut C,
+        num: usize,
+        shares1: &[F],
+    ) -> Result<Vec<T>, Error>
+    where
+        F: Serializable,
+    {
+        assert!(self.output_bitsize <= T::BITS as usize);
+        Ok(to_uint(self.eval(comm, num, shares1)?))
+    }
 }
 
 pub struct MaskedDOPrfParty2<F: LegendreSymbol> {
@@ -990,6 +1030,19 @@ where
         let output = self.eval_get_output(num);
         Ok(output)
     }
+
+    pub fn eval_to_uint<C: AbstractCommunicator, T: Unsigned>(
+        &mut self,
+        comm: &mut C,
+        num: usize,
+        shares2: &[F],
+    ) -> Result<Vec<T>, Error>
+    where
+        F: Serializable,
+    {
+        assert!(self.output_bitsize <= T::BITS as usize);
+        Ok(to_uint(self.eval(comm, num, shares2)?))
+    }
 }
 
 pub struct MaskedDOPrfParty3<F: LegendreSymbol> {
@@ -1180,6 +1233,19 @@ where
         let output = self.eval_get_output(num);
         Ok(output)
     }
+
+    pub fn eval_to_uint<C: AbstractCommunicator, T: Unsigned>(
+        &mut self,
+        comm: &mut C,
+        num: usize,
+        shares3: &[F],
+    ) -> Result<Vec<T>, Error>
+    where
+        F: Serializable,
+    {
+        assert!(self.output_bitsize <= T::BITS as usize);
+        Ok(to_uint(self.eval(comm, num, shares3)?))
+    }
 }
 
 #[cfg(test)]

+ 17 - 31
oram/src/stash.rs

@@ -75,15 +75,6 @@ fn compute_stash_prf_output_bitsize(stash_size: usize) -> usize {
     (usize::BITS - stash_size.leading_zeros()) as usize + 40
 }
 
-fn bits_to_u64(mut bits: BitVec) -> u64 {
-    assert!(bits.len() <= 64);
-    let mut bytes = [0u8; 8];
-    bits.force_align(); // important! otherwise, the first bit in the bitvec might not be the lsb
-                        // of the first byte in the underlying byte vector
-    bits.read(&mut bytes).unwrap();
-    u64::from_le_bytes(bytes)
-}
-
 fn stash_read_value<C, F, SPDPF>(
     comm: &mut C,
     access_counter: usize,
@@ -324,12 +315,11 @@ where
         let (flag_share, location_share) = match self.party_id {
             PARTY_1 => {
                 // 1. Compute tag y := PRF(k, <I.adr>) such that P1 obtains y + r and P2, P3 obtain the mask r.
-                let masked_address_tag = {
+                let masked_address_tag: u64 = {
                     let mdoprf_p1 = self.masked_doprf_party_1.as_mut().unwrap();
                     // for now do preprocessing on the fly
                     mdoprf_p1.preprocess(comm, 1)?;
-                    let mut masked_tag = mdoprf_p1.eval(comm, 1, &[instruction.address])?;
-                    bits_to_u64(masked_tag.pop().unwrap())
+                    mdoprf_p1.eval_to_uint(comm, 1, &[instruction.address])?[0]
                 };
 
                 // 2. Create and send DPF keys for the function f(x) = if x = y { 1 } else { 0 }
@@ -347,23 +337,20 @@ where
             }
             PARTY_2 | PARTY_3 => {
                 // 1. Compute tag y := PRF(k, <I.adr>) such that P1 obtains y + r and P2, P3 obtain the mask r.
-                let address_tag_mask = {
-                    let mut mask = match self.party_id {
-                        PARTY_2 => {
-                            let mdoprf_p2 = self.masked_doprf_party_2.as_mut().unwrap();
-                            // for now do preprocessing on the fly
-                            mdoprf_p2.preprocess(comm, 1)?;
-                            mdoprf_p2.eval(comm, 1, &[instruction.address])?
-                        }
-                        PARTY_3 => {
-                            let mdoprf_p3 = self.masked_doprf_party_3.as_mut().unwrap();
-                            // for now do preprocessing on the fly
-                            mdoprf_p3.preprocess(comm, 1)?;
-                            mdoprf_p3.eval(comm, 1, &[instruction.address])?
-                        }
-                        _ => panic!("invalid party id"),
-                    };
-                    bits_to_u64(mask.pop().unwrap())
+                let address_tag_mask: u64 = match self.party_id {
+                    PARTY_2 => {
+                        let mdoprf_p2 = self.masked_doprf_party_2.as_mut().unwrap();
+                        // for now do preprocessing on the fly
+                        mdoprf_p2.preprocess(comm, 1)?;
+                        mdoprf_p2.eval_to_uint(comm, 1, &[instruction.address])?[0]
+                    }
+                    PARTY_3 => {
+                        let mdoprf_p3 = self.masked_doprf_party_3.as_mut().unwrap();
+                        // for now do preprocessing on the fly
+                        mdoprf_p3.preprocess(comm, 1)?;
+                        mdoprf_p3.eval_to_uint(comm, 1, &[instruction.address])?[0]
+                    }
+                    _ => panic!("invalid party id"),
                 };
 
                 // 2. Receive DPF key for the function f(x) = if x = y { 1 } else { 0 }
@@ -468,8 +455,7 @@ where
                     let doprf_p3 = self.doprf_party_3.as_mut().unwrap();
                     // for now do preprocessing on the fly
                     doprf_p3.preprocess(comm, 1)?;
-                    let mut tag = doprf_p3.eval(comm, 1, &[db_address_share])?;
-                    let tag = bits_to_u64(tag.pop().unwrap());
+                    let tag = doprf_p3.eval_to_uint(comm, 1, &[db_address_share])?[0];
                     comm.send(PARTY_2, tag)?;
                     tag
                 };