Преглед изворни кода

utils: make FY permuation key serializable

Lennart Braun пре 2 година
родитељ
комит
618c3cfa01
1 измењених фајлова са 42 додато и 1 уклоњено
  1. 42 1
      utils/src/permutation.rs

+ 42 - 1
utils/src/permutation.rs

@@ -1,3 +1,4 @@
+use communicator::{traits::Serializable, Error as CommError};
 use rand::{thread_rng, Rng, SeedableRng};
 use rand_chacha::ChaCha20Rng;
 
@@ -14,12 +15,41 @@ pub trait Permutation {
     // fn permuted_vector() -> Vec<usize>;
 }
 
-#[derive(Clone, Copy, Debug)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 pub struct FisherYatesPermutationKey {
     log_domain_size: u32,
     prg_seed: [u8; 32],
 }
 
+impl Serializable for FisherYatesPermutationKey {
+    fn bytes_required() -> usize {
+        36
+    }
+
+    fn into_bytes(&self, buf: &mut [u8]) -> Result<(), CommError> {
+        if buf.len() != Self::bytes_required() {
+            return Err(CommError::SerializationError(
+                "buffer has wrong size".to_owned(),
+            ));
+        }
+        buf[..4].copy_from_slice(&self.log_domain_size.to_be_bytes());
+        buf[4..36].copy_from_slice(&self.prg_seed);
+        Ok(())
+    }
+
+    fn from_bytes(buf: &[u8]) -> Result<Self, CommError> {
+        if buf.len() != Self::bytes_required() {
+            return Err(CommError::DeserializationError(
+                "buffer has wrong size".to_owned(),
+            ));
+        }
+        Ok(Self {
+            log_domain_size: u32::from_be_bytes(buf[..4].try_into().unwrap()),
+            prg_seed: buf[4..36].try_into().unwrap(),
+        })
+    }
+}
+
 /// Random permutation based on a Fisher-Yates shuffle of [0, N) with a seeded PRG.
 #[derive(Clone, Debug)]
 pub struct FisherYatesPermutation {
@@ -92,4 +122,15 @@ mod tests {
         let log_domain_size = 10;
         test_permutation::<FisherYatesPermutation>(log_domain_size);
     }
+
+    #[test]
+    fn test_serialization() {
+        for _ in 0..100 {
+            let log_domain_size = thread_rng().gen_range(1..30);
+            let key = FisherYatesPermutation::sample(log_domain_size);
+            let bytes = key.to_bytes();
+            assert_eq!(bytes.len(), FisherYatesPermutationKey::bytes_required());
+            assert_eq!(FisherYatesPermutationKey::from_bytes(&bytes).unwrap(), key);
+        }
+    }
 }