permutation.rs 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. //! Functionality for random permutations.
  2. use bincode;
  3. use rand::{thread_rng, Rng, SeedableRng};
  4. use rand_chacha::ChaCha20Rng;
  /// A keyed random permutation of the domain `[0, domain_size)`.
  ///
  /// A permutation is defined by a copyable [`Permutation::Key`]. Draw a key with
  /// [`Permutation::sample`], then build the permutation from it with
  /// [`Permutation::from_key`].
  pub trait Permutation {
      /// Key type that defines the permutation.
      type Key: Copy;

      /// Draws a random key for a permutation over `domain_size` elements.
      fn sample(domain_size: usize) -> Self::Key;

      /// Builds the permutation described by `key`.
      fn from_key(key: Self::Key) -> Self;

      /// Maps index `x` to its image under the permutation.
      fn permute(&self, x: usize) -> usize;

      /// Number of elements the permutation acts on.
      fn get_domain_size(&self) -> usize;

      /// Returns the key of this permutation instance.
      fn get_key(&self) -> Self::Key;
  }
  20. /// Key type for a [`FisherYatesPermutation`].
  21. #[derive(Clone, Copy, Debug, PartialEq, Eq, bincode::Encode, bincode::Decode)]
  22. pub struct FisherYatesPermutationKey {
  23. domain_size: usize,
  24. prg_seed: [u8; 32],
  25. }
  26. /// Random permutation based on a Fisher-Yates shuffle of [0, N) with a seeded PRG.
  27. #[derive(Clone, Debug)]
  28. pub struct FisherYatesPermutation {
  29. key: FisherYatesPermutationKey,
  30. permuted_vector: Vec<usize>,
  31. }
  32. impl Permutation for FisherYatesPermutation {
  33. type Key = FisherYatesPermutationKey;
  34. fn sample(domain_size: usize) -> Self::Key {
  35. Self::Key {
  36. domain_size,
  37. prg_seed: thread_rng().gen(),
  38. }
  39. }
  40. fn from_key(key: Self::Key) -> Self {
  41. // rng seeded by the key
  42. let mut rng = ChaCha20Rng::from_seed(key.prg_seed);
  43. // size of the domain
  44. let n = key.domain_size;
  45. // vector to store permutation explicitly
  46. let mut permuted_vector: Vec<usize> = (0..n).collect();
  47. // run Fisher-Yates
  48. for i in (1..n).rev() {
  49. let j: usize = rng.gen_range(0..=i);
  50. permuted_vector.swap(j, i);
  51. }
  52. Self {
  53. key,
  54. permuted_vector,
  55. }
  56. }
  57. fn get_key(&self) -> Self::Key {
  58. self.key
  59. }
  60. fn get_domain_size(&self) -> usize {
  61. self.key.domain_size
  62. }
  63. fn permute(&self, x: usize) -> usize {
  64. assert!(x < self.permuted_vector.len());
  65. self.permuted_vector[x]
  66. }
  67. }
  #[cfg(test)]
  mod tests {
      use super::*;

      /// Checks that `Perm`, over a domain of `2^log_domain_size` elements:
      /// - reports the right domain size;
      /// - is reproducible from its own key;
      /// - is a bijection on `[0, n)`.
      fn test_permutation<Perm: Permutation>(log_domain_size: u32) {
          let n: usize = 1 << log_domain_size;
          let perm = Perm::from_key(Perm::sample(n));
          assert_eq!(perm.get_domain_size(), n);

          let mut images: Vec<usize> = (0..n).map(|i| perm.permute(i)).collect();

          // Rebuilding from the same key must reproduce exactly the same mapping.
          let rebuilt = Perm::from_key(perm.get_key());
          assert!((0..n).all(|i| rebuilt.permute(i) == images[i]));

          // A permutation hits every index in [0, n) exactly once.
          images.sort_unstable();
          assert!(images.iter().copied().eq(0..n), "not a bijection on [0, n)");
      }

      #[test]
      fn test_all_permutations() {
          // Covers the singleton domain (2^0), a tiny domain and a larger one.
          for log_domain_size in [0, 1, 10] {
              test_permutation::<FisherYatesPermutation>(log_domain_size);
          }
      }

      #[test]
      fn test_empty_domain() {
          let perm = FisherYatesPermutation::from_key(FisherYatesPermutation::sample(0));
          assert_eq!(perm.get_domain_size(), 0);
      }

      #[test]
      #[should_panic]
      fn test_permute_out_of_range_panics() {
          let perm = FisherYatesPermutation::from_key(FisherYatesPermutation::sample(4));
          let _ = perm.permute(4);
      }

      #[test]
      fn test_serialization() {
          for _ in 0..100 {
              let log_domain_size = thread_rng().gen_range(1u32..30);
              // `sample` only creates the key (no table), so large domains are cheap.
              let key = FisherYatesPermutation::sample(1usize << log_domain_size);
              let bytes = bincode::encode_to_vec(key, bincode::config::standard()).unwrap();
              let (new_key, bytes_read): (FisherYatesPermutationKey, usize) =
                  bincode::decode_from_slice(&bytes, bincode::config::standard()).unwrap();
              assert_eq!(bytes_read, bytes.len());
              assert_eq!(new_key, key);
          }
      }
  }