Browse Source

dpf: fix clippy warnings

Lennart Braun 1 year ago
parent
commit
20c1f8ce9c
2 changed files with 29 additions and 21 deletions
  1. 4 4
      dpf/src/mpdpf.rs
  2. 25 17
      dpf/src/spdpf.rs

+ 4 - 4
dpf/src/mpdpf.rs

@@ -30,7 +30,7 @@ pub trait MultiPointDpf {
     fn evaluate_at(&self, key: &Self::Key, index: u64) -> Self::Value;
     fn evaluate_domain(&self, key: &Self::Key) -> Vec<Self::Value> {
         (0..key.get_domain_size())
-            .map(|x| self.evaluate_at(&key, x as u64))
+            .map(|x| self.evaluate_at(key, x as u64))
             .collect()
     }
 }
@@ -167,7 +167,7 @@ where
         } else {
             (" ", "")
         };
-        write!(f, "SmartMpDpfKey<SPDPF, H>{{{}", newline)?;
+        write!(f, "SmartMpDpfKey<SPDPF, H>{{{newline}")?;
         write!(
             f,
             "{}party_id: {:?},{}",
@@ -184,9 +184,9 @@ where
             indentation, self.number_points, newline
         )?;
         if f.alternate() {
-            write!(f, "    spdpf_keys:\n")?;
+            writeln!(f, "    spdpf_keys:")?;
             for (i, k) in self.spdpf_keys.iter().enumerate() {
-                write!(f, "        spdpf_keys[{}]: {:?}\n", i, k)?;
+                writeln!(f, "        spdpf_keys[{i}]: {k:?}")?;
             }
         } else {
             write!(f, " spdpf_keys: {:?},", self.spdpf_keys)?;

+ 25 - 17
dpf/src/spdpf.rs

@@ -139,7 +139,7 @@ where
 {
     const FIXED_KEY_AES_KEY: [u8; 16] =
         0xdead_beef_1337_4247_dead_beef_1337_4247_u128.to_le_bytes();
-    const HASH_KEY: u128 = 0xc000ffee_c0ffffee_c0ffeeee_c00ffeee_u128;
+    const HASH_KEY: u128 = 0xc000_ffee_c0ff_ffee_c0ff_eeee_c00f_feee_u128;
 }
 
 impl<V> SinglePointDpf for HalfTreeSpDpf<V>
@@ -195,10 +195,12 @@ where
         let mut st_1 = st_0 ^ delta;
         let party_seeds = (st_0, st_1);
 
-        for i in 0..(tree_height - 1) as usize {
-            let cw_i = hash(st_0) ^ hash(st_1) ^ (1 - alpha_bits[i] as u128) * delta;
-            st_0 = hash(st_0) ^ alpha_bits[i] as u128 * (st_0) ^ (st_0 & 1) * cw_i;
-            st_1 = hash(st_1) ^ alpha_bits[i] as u128 * (st_1) ^ (st_1 & 1) * cw_i;
+        debug_assert_eq!(alpha_bits.len(), tree_height);
+
+        for alpha_i in alpha_bits.iter().copied().take(tree_height - 1) {
+            let cw_i = hash(st_0) ^ hash(st_1) ^ ((1 - alpha_i as u128) * delta);
+            st_0 = hash(st_0) ^ (alpha_i as u128 * st_0) ^ ((st_0 & 1) * cw_i);
+            st_1 = hash(st_1) ^ (alpha_i as u128 * st_1) ^ ((st_1 & 1) * cw_i);
             correction_words.push(cw_i);
         }
 
@@ -212,8 +214,8 @@ where
             ((high_low[0][1] ^ high_low[1][1] ^ a_n as u128) & LOW_MASK) != 0,
         ];
 
-        st_0 = high_low[0][a_n as usize] ^ (st_0 & 1) * (hcw | lcw[a_n as usize] as u128);
-        st_1 = high_low[1][a_n as usize] ^ (st_1 & 1) * (hcw | lcw[a_n as usize] as u128);
+        st_0 = high_low[0][a_n as usize] ^ ((st_0 & 1) * (hcw | lcw[a_n as usize] as u128));
+        st_1 = high_low[1][a_n as usize] ^ ((st_1 & 1) * (hcw | lcw[a_n as usize] as u128));
         let correction_word_np1: V = match (st_0 & 1).wrapping_sub(st_1 & 1) {
             u128::MAX => convert(st_0 >> 1) - convert(st_1 >> 1) - beta,
             0 => V::zero(),
@@ -259,14 +261,20 @@ where
         let tree_height = (key.domain_size as f64).log2().ceil() as usize;
         let index_bits: Vec<bool> = bit_decompose(index, tree_height);
 
+        debug_assert_eq!(index_bits.len(), tree_height);
+
         let mut st_b = key.party_seed;
-        for i in 0..tree_height - 1 {
-            st_b =
-                hash(st_b) ^ (index_bits[i] as u128 * st_b) ^ (st_b & 1) * key.correction_words[i];
+        for (index_bit_i, correction_word_i) in index_bits
+            .iter()
+            .copied()
+            .zip(key.correction_words.iter())
+            .take(tree_height - 1)
+        {
+            st_b = hash(st_b) ^ (index_bit_i as u128 * st_b) ^ ((st_b & 1) * correction_word_i);
         }
         let x_n = index_bits[tree_height - 1];
         let high_low_b_xn = hash(st_b ^ x_n as u128);
-        st_b = high_low_b_xn ^ (st_b & 1) * (key.hcw | key.lcw[x_n as usize] as u128);
+        st_b = high_low_b_xn ^ ((st_b & 1) * (key.hcw | key.lcw[x_n as usize] as u128));
 
         let value = convert(st_b >> 1)
             + if st_b & 1 == 0 {
@@ -309,8 +317,8 @@ where
                 for j in (0..(last_index >> (tree_height - i)) + 1).rev() {
                     // for j in (0..(1 << i)).rev() {
                     let st = seeds[j];
-                    let st_0 = hash(st) ^ (st & 1) * key.correction_words[i];
-                    let st_1 = hash(st) ^ st ^ (st & 1) * key.correction_words[i];
+                    let st_0 = hash(st) ^ ((st & 1) * key.correction_words[i]);
+                    let st_1 = hash(st) ^ st ^ ((st & 1) * key.correction_words[i]);
                     seeds[2 * j] = st_0;
                     seeds[2 * j + 1] = st_1;
                 }
@@ -322,19 +330,19 @@ where
             // handle the last expansion separately, since we might not need both outputs
             let j = last_index >> 1;
             let st = seeds[j];
-            let st_0 = hash(st) ^ (st & 1) * (key.hcw | key.lcw[0] as u128);
+            let st_0 = hash(st) ^ ((st & 1) * (key.hcw | key.lcw[0] as u128));
             seeds[2 * j] = st_0;
             // check if we need both outputs
             if key.domain_size & 1 == 0 {
-                let st_1 = hash(st ^ 1) ^ (st & 1) * (key.hcw | key.lcw[1] as u128);
+                let st_1 = hash(st ^ 1) ^ ((st & 1) * (key.hcw | key.lcw[1] as u128));
                 seeds[2 * j + 1] = st_1;
             }
 
             // handle the other expansions as usual
             for j in (0..(last_index >> 1)).rev() {
                 let st = seeds[j];
-                let st_0 = hash(st) ^ (st & 1) * (key.hcw | key.lcw[0] as u128);
-                let st_1 = hash(st ^ 1) ^ (st & 1) * (key.hcw | key.lcw[1] as u128);
+                let st_0 = hash(st) ^ ((st & 1) * (key.hcw | key.lcw[0] as u128));
+                let st_1 = hash(st ^ 1) ^ ((st & 1) * (key.hcw | key.lcw[1] as u128));
                 seeds[2 * j] = st_0;
                 seeds[2 * j + 1] = st_1;
             }