Browse Source

utils: add test for permutation

Lennart Braun 2 years ago
parent
commit
67815f3320
1 changed files with 32 additions and 2 deletions
  1. 32 2
      utils/src/permutation.rs

+ 32 - 2
utils/src/permutation.rs

@@ -1,6 +1,7 @@
 use rand::{thread_rng, Rng, SeedableRng};
 use rand::{thread_rng, Rng, SeedableRng};
 use rand_chacha::ChaCha20Rng;
 use rand_chacha::ChaCha20Rng;
 
 
+/// Trait that models a random permutation.
 pub trait Permutation {
 pub trait Permutation {
     type Key: Copy;
     type Key: Copy;
 
 
@@ -19,6 +20,7 @@ pub struct FisherYatesPermutationKey {
     prg_seed: [u8; 32],
     prg_seed: [u8; 32],
 }
 }
 
 
+/// Random permutation based on a Fisher-Yates shuffle of [0, N) with a seeded PRG.
 #[derive(Clone, Debug)]
 #[derive(Clone, Debug)]
 pub struct FisherYatesPermutation {
 pub struct FisherYatesPermutation {
     key: FisherYatesPermutationKey,
     key: FisherYatesPermutationKey,
@@ -36,10 +38,13 @@ impl Permutation for FisherYatesPermutation {
     }
     }
 
 
     fn from_key(key: Self::Key) -> Self {
     fn from_key(key: Self::Key) -> Self {
+        // rng seeded by the key
         let mut rng = ChaCha20Rng::from_seed(key.prg_seed);
         let mut rng = ChaCha20Rng::from_seed(key.prg_seed);
-        let mut permuted_vector: Vec<usize> = (0..(1 << key.log_domain_size)).collect();
-        // To shuffle an array a of n elements (indices 0..n-1):
+        // size of the domain
         let n = 1 << key.log_domain_size;
         let n = 1 << key.log_domain_size;
+        // vector to store permutation explicitly
+        let mut permuted_vector: Vec<usize> = (0..n).collect();
+        // run Fisher-Yates
         for i in (1..n).rev() {
         for i in (1..n).rev() {
             let j: usize = rng.gen_range(0..=i);
             let j: usize = rng.gen_range(0..=i);
             permuted_vector.swap(j, i);
             permuted_vector.swap(j, i);
@@ -63,3 +68,28 @@ impl Permutation for FisherYatesPermutation {
         self.permuted_vector[x]
         self.permuted_vector[x]
     }
     }
 }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn test_permutation<Perm: Permutation>(log_domain_size: u32) {
+        let n: usize = 1 << log_domain_size;
+        let key = Perm::sample(log_domain_size);
+        let perm = Perm::from_key(key);
+        let mut buffer = vec![0usize; n];
+        for i in 0..n {
+            buffer[i] = perm.permute(i);
+        }
+        buffer.sort();
+        for i in 0..n {
+            assert_eq!(buffer[i], i);
+        }
+    }
+
+    #[test]
+    fn test_all_permutations() {
+        let log_domain_size = 10;
+        test_permutation::<FisherYatesPermutation>(log_domain_size);
+    }
+}