Browse Source

cuckoo: fix clippy warnings

Lennart Braun 1 year ago
parent
commit
e7a9c23294
4 changed files with 28 additions and 33 deletions
  1. 1 7
      cuckoo/src/bin/params.rs
  2. 14 13
      cuckoo/src/cuckoo.rs
  3. 10 10
      dpf/src/mpdpf.rs
  4. 3 3
      dpf/src/spdpf.rs

+ 1 - 7
cuckoo/src/bin/params.rs

@@ -16,13 +16,7 @@ fn main() {
         let avg_bucket_size = buckets.iter().map(|b| b.len()).sum::<usize>() / number_buckets;
         let number_empty_buckets = buckets.iter().map(|b| b.len()).filter(|&l| l == 0).count();
         println!(
-            "{:6}  {:6}  {:6}  {:8}  {:8}  {:11}",
-            log_domain_size,
-            log_number_inputs,
-            number_buckets,
-            max_bucket_size,
-            avg_bucket_size,
-            number_empty_buckets
+            "{log_domain_size:6}  {log_number_inputs:6}  {number_buckets:6}  {max_bucket_size:8}  {avg_bucket_size:8}  {number_empty_buckets:11}"
         );
     }
 }

+ 14 - 13
cuckoo/src/cuckoo.rs

@@ -97,9 +97,8 @@ where
     pub fn from_seed(number_inputs: usize, seed: [u8; 32]) -> Self {
         let number_buckets = Self::compute_number_buckets(number_inputs);
         let mut rng = ChaCha12Rng::from_seed(seed);
-        let hash_function_descriptions = array::from_fn(|_| {
-            H::from_seed(number_buckets.try_into().unwrap(), rng.gen()).to_description()
-        });
+        let hash_function_descriptions =
+            array::from_fn(|_| H::from_seed(number_buckets, rng.gen()).to_description());
 
         Parameters::<H, Value> {
             number_inputs,
@@ -112,7 +111,7 @@ where
     pub fn sample(number_inputs: usize) -> Self {
         let number_buckets = Self::compute_number_buckets(number_inputs);
         let hash_function_descriptions =
-            array::from_fn(|_| H::sample(number_buckets.try_into().unwrap()).to_description());
+            array::from_fn(|_| H::sample(number_buckets).to_description());
 
         Parameters::<H, Value> {
             number_inputs,
@@ -209,10 +208,11 @@ where
         hashes: &[Vec<Value>; NUMBER_HASH_FUNCTIONS],
     ) -> Vec<Vec<u64>> {
         debug_assert!(hashes.iter().all(|v| v.len() as u64 == domain_size));
-        let mut hash_table = vec![Vec::new(); self.parameters.number_buckets as usize];
+        debug_assert_eq!(hashes.len(), NUMBER_HASH_FUNCTIONS);
+        let mut hash_table = vec![Vec::new(); self.parameters.number_buckets];
         for x in 0..domain_size {
-            for hash_function_index in 0..NUMBER_HASH_FUNCTIONS {
-                let h = hashes[hash_function_index][x as usize];
+            for hash_function_values in hashes.iter() {
+                let h = hash_function_values[x as usize];
                 hash_table[H::hash_value_as_usize(h)].push(x);
             }
         }
@@ -226,11 +226,12 @@ where
 
     /// Hash the given items into buckets all three hash functions
     pub fn hash_items_into_buckets(&self, items: &[u64]) -> Vec<Vec<u64>> {
-        let mut hash_table = vec![Vec::new(); self.parameters.number_buckets as usize];
+        let mut hash_table = vec![Vec::new(); self.parameters.number_buckets];
         let hashes = self.hash_items(items);
+        debug_assert_eq!(hashes.len(), NUMBER_HASH_FUNCTIONS);
         for (i, &x) in items.iter().enumerate() {
-            for hash_function_index in 0..NUMBER_HASH_FUNCTIONS {
-                let h = hashes[hash_function_index][i as usize];
+            for hash_function_values in hashes.iter() {
+                let h = hash_function_values[i];
                 hash_table[H::hash_value_as_usize(h)].push(x);
             }
         }
@@ -238,7 +239,7 @@ where
     }
 
     /// Compute a vector of the sizes of all buckets
-    pub fn compute_bucket_sizes(hash_table: &Vec<Vec<u64>>) -> Vec<usize> {
+    pub fn compute_bucket_sizes(hash_table: &[Vec<u64>]) -> Vec<usize> {
         hash_table.iter().map(|v| v.len()).collect()
     }
 
@@ -248,7 +249,7 @@ where
     /// is placed into buckets using three hash functions.
     pub fn compute_pos_lookup_table(
         domain_size: u64,
-        hash_table: &Vec<Vec<u64>>,
+        hash_table: &[Vec<u64>],
     ) -> Vec<[(usize, usize); 3]> {
         let mut lookup_table = vec![[(usize::MAX, usize::MAX); 3]; domain_size as usize];
         for (bucket_i, bucket) in hash_table.iter().enumerate() {
@@ -266,7 +267,7 @@ where
     }
 
     /// Use the lookup table for the position map
-    pub fn pos_lookup(lookup_table: &Vec<[(usize, usize); 3]>, bucket_i: usize, item: u64) -> u64 {
+    pub fn pos_lookup(lookup_table: &[[(usize, usize); 3]], bucket_i: usize, item: u64) -> u64 {
         for k in 0..NUMBER_HASH_FUNCTIONS {
             if lookup_table[item as usize][k].0 == bucket_i {
                 return lookup_table[item as usize][k].1 as u64;

+ 10 - 10
dpf/src/mpdpf.rs

@@ -117,15 +117,15 @@ where
                 party_id: 0,
                 domain_size: self.domain_size,
                 number_points: self.number_points,
-                alphas: alphas.iter().copied().collect(),
-                betas: betas.iter().copied().collect(),
+                alphas: alphas.to_vec(),
+                betas: betas.to_vec(),
             },
             DummyMpDpfKey {
                 party_id: 1,
                 domain_size: self.domain_size,
                 number_points: self.number_points,
-                alphas: alphas.iter().copied().collect(),
-                betas: betas.iter().copied().collect(),
+                alphas: alphas.to_vec(),
+                betas: betas.to_vec(),
             },
         )
     }
@@ -212,7 +212,7 @@ where
             domain_size: self.domain_size,
             number_points: self.number_points,
             spdpf_keys: self.spdpf_keys.clone(),
-            cuckoo_parameters: self.cuckoo_parameters.clone(),
+            cuckoo_parameters: self.cuckoo_parameters,
         }
     }
 }
@@ -436,14 +436,14 @@ where
                 domain_size: self.domain_size,
                 number_points,
                 spdpf_keys: keys_0,
-                cuckoo_parameters: cuckoo_parameters.clone(),
+                cuckoo_parameters: *cuckoo_parameters,
             },
             SmartMpDpfKey::<SPDPF, H> {
                 party_id: 1,
                 domain_size: self.domain_size,
                 number_points,
                 spdpf_keys: keys_1,
-                cuckoo_parameters: cuckoo_parameters.clone(),
+                cuckoo_parameters: *cuckoo_parameters,
             },
         )
     }
@@ -471,7 +471,7 @@ where
             debug_assert!(key.spdpf_keys[hash].is_some());
             let sp_key = key.spdpf_keys[hash].as_ref().unwrap();
             debug_assert_eq!(simple_htable[hash][pos(hash, index) as usize], index);
-            SPDPF::evaluate_at(&sp_key, pos(hash, index))
+            SPDPF::evaluate_at(sp_key, pos(hash, index))
         };
 
         // prevent adding the same term multiple times when we have collisions
@@ -493,7 +493,7 @@ where
             debug_assert!(key.spdpf_keys[hash].is_some());
             let sp_key = key.spdpf_keys[hash].as_ref().unwrap();
             debug_assert_eq!(simple_htable[hash][pos(hash, index) as usize], index);
-            output += SPDPF::evaluate_at(&sp_key, pos(hash, index));
+            output += SPDPF::evaluate_at(sp_key, pos(hash, index));
         }
 
         output
@@ -531,7 +531,7 @@ where
             .map(|sp_key_opt| {
                 sp_key_opt
                     .as_ref()
-                    .map_or(vec![], |sp_key| SPDPF::evaluate_domain(&sp_key))
+                    .map_or(vec![], |sp_key| SPDPF::evaluate_domain(sp_key))
             })
             .collect();
 

+ 3 - 3
dpf/src/spdpf.rs

@@ -21,7 +21,7 @@ pub trait SinglePointDpf {
     fn evaluate_at(key: &Self::Key, index: u64) -> Self::Value;
     fn evaluate_domain(key: &Self::Key) -> Vec<Self::Value> {
         (0..key.get_domain_size())
-            .map(|x| Self::evaluate_at(&key, x as u64))
+            .map(|x| Self::evaluate_at(key, x as u64))
             .collect()
     }
 }
@@ -326,7 +326,7 @@ where
             seeds[2 * j] = st_0;
             // check if we need both outputs
             if key.domain_size & 1 == 0 {
-                let st_1 = hash(st ^ 1 as u128) ^ (st & 1) * (key.hcw | key.lcw[1] as u128);
+                let st_1 = hash(st ^ 1) ^ (st & 1) * (key.hcw | key.lcw[1] as u128);
                 seeds[2 * j + 1] = st_1;
             }
 
@@ -334,7 +334,7 @@ where
             for j in (0..(last_index >> 1)).rev() {
                 let st = seeds[j];
                 let st_0 = hash(st) ^ (st & 1) * (key.hcw | key.lcw[0] as u128);
-                let st_1 = hash(st ^ 1 as u128) ^ (st & 1) * (key.hcw | key.lcw[1] as u128);
+                let st_1 = hash(st ^ 1) ^ (st & 1) * (key.hcw | key.lcw[1] as u128);
                 seeds[2 * j] = st_0;
                 seeds[2 * j + 1] = st_1;
             }