소스 검색

p-ot: log_domain_size -> domain_size, and usize for indices

Lennart Braun 2 년 전
부모
커밋
29bb5c3479
1개의 변경된 파일31개의 추가작업 그리고 31개의 파일을 삭제
  1. 31 31
      oram/src/p_ot.rs

+ 31 - 31
oram/src/p_ot.rs

@@ -5,7 +5,7 @@ use utils::permutation::Permutation;
 
 pub struct POTKeyParty<F: FromPrf, Perm> {
     /// log of the database size
-    log_domain_size: u32,
+    domain_size: usize,
     /// if init was run
     is_initialized: bool,
     /// PRF key of the Index Party
@@ -18,9 +18,9 @@ pub struct POTKeyParty<F: FromPrf, Perm> {
 }
 
 impl<F: Field + FromPrf, Perm: Permutation> POTKeyParty<F, Perm> {
-    pub fn new(log_domain_size: u32) -> Self {
+    pub fn new(domain_size: usize) -> Self {
         Self {
-            log_domain_size,
+            domain_size,
             is_initialized: false,
             prf_key_i: None,
             prf_key_r: None,
@@ -34,14 +34,14 @@ impl<F: Field + FromPrf, Perm: Permutation> POTKeyParty<F, Perm> {
     }
 
     pub fn reset(&mut self) {
-        *self = Self::new(self.log_domain_size);
+        *self = Self::new(self.domain_size);
     }
 
     pub fn init(&mut self) -> ((F::PrfKey, Perm::Key), F::PrfKey) {
         assert!(!self.is_initialized);
         self.prf_key_i = Some(F::prf_key_gen());
         self.prf_key_r = Some(F::prf_key_gen());
-        let permutation_key = Perm::sample(self.log_domain_size);
+        let permutation_key = Perm::sample(self.domain_size);
         self.permutation = Some(Perm::from_key(permutation_key));
         self.is_initialized = true;
         (
@@ -52,11 +52,11 @@ impl<F: Field + FromPrf, Perm: Permutation> POTKeyParty<F, Perm> {
 
     pub fn expand(&self) -> Vec<F> {
         assert!(self.is_initialized);
-        let n = 1 << self.log_domain_size;
-        (0..n)
+        (0..self.domain_size)
             .map(|x| {
-                let pi_x = self.permutation.as_ref().unwrap().permute(x) as u64;
-                F::prf(&self.prf_key_i.unwrap(), pi_x) + F::prf(&self.prf_key_r.unwrap(), pi_x)
+                let pi_x = self.permutation.as_ref().unwrap().permute(x);
+                F::prf(&self.prf_key_i.unwrap(), pi_x as u64)
+                    + F::prf(&self.prf_key_r.unwrap(), pi_x as u64)
             })
             .collect()
     }
@@ -64,7 +64,7 @@ impl<F: Field + FromPrf, Perm: Permutation> POTKeyParty<F, Perm> {
 
 pub struct POTIndexParty<F: FromPrf, Perm> {
     /// log of the database size
-    log_domain_size: u32,
+    domain_size: usize,
     /// if init was run
     is_initialized: bool,
     /// PRF key of the Index Party
@@ -75,9 +75,9 @@ pub struct POTIndexParty<F: FromPrf, Perm> {
 }
 
 impl<F: Field + FromPrf, Perm: Permutation> POTIndexParty<F, Perm> {
-    pub fn new(log_domain_size: u32) -> Self {
+    pub fn new(domain_size: usize) -> Self {
         Self {
-            log_domain_size,
+            domain_size,
             is_initialized: false,
             prf_key_i: None,
             permutation: None,
@@ -90,7 +90,7 @@ impl<F: Field + FromPrf, Perm: Permutation> POTIndexParty<F, Perm> {
     }
 
     pub fn reset(&mut self) {
-        *self = Self::new(self.log_domain_size);
+        *self = Self::new(self.domain_size);
     }
 
     pub fn init(&mut self, prf_key_i: F::PrfKey, permutation_key: Perm::Key) {
@@ -100,16 +100,16 @@ impl<F: Field + FromPrf, Perm: Permutation> POTIndexParty<F, Perm> {
         self.is_initialized = true;
     }
 
-    pub fn access(&self, index: u64) -> (u64, F) {
-        assert!(index < (1 << self.log_domain_size));
-        let pi_x = self.permutation.as_ref().unwrap().permute(index as usize) as u64;
-        (pi_x, F::prf(&self.prf_key_i.unwrap(), pi_x))
+    pub fn access(&self, index: usize) -> (usize, F) {
+        assert!(index < self.domain_size);
+        let pi_x = self.permutation.as_ref().unwrap().permute(index);
+        (pi_x, F::prf(&self.prf_key_i.unwrap(), pi_x as u64))
     }
 }
 
 pub struct POTReceiverParty<F: FromPrf> {
     /// log of the database size
-    log_domain_size: u32,
+    domain_size: usize,
     /// if init was run
     is_initialized: bool,
     /// PRF key of the Receiver Party
@@ -118,9 +118,9 @@ pub struct POTReceiverParty<F: FromPrf> {
 }
 
 impl<F: Field + FromPrf> POTReceiverParty<F> {
-    pub fn new(log_domain_size: u32) -> Self {
+    pub fn new(domain_size: usize) -> Self {
         Self {
-            log_domain_size,
+            domain_size,
             is_initialized: false,
             prf_key_r: None,
             _phantom: PhantomData,
@@ -132,7 +132,7 @@ impl<F: Field + FromPrf> POTReceiverParty<F> {
     }
 
     pub fn reset(&mut self) {
-        *self = Self::new(self.log_domain_size);
+        *self = Self::new(self.domain_size);
     }
 
     pub fn init(&mut self, prf_key_r: F::PrfKey) {
@@ -141,9 +141,9 @@ impl<F: Field + FromPrf> POTReceiverParty<F> {
         self.is_initialized = true;
     }
 
-    pub fn access(&self, permuted_index: u64, output_share: F) -> F {
-        assert!(permuted_index < (1 << self.log_domain_size));
-        F::prf(&self.prf_key_r.unwrap(), permuted_index) + output_share
+    pub fn access(&self, permuted_index: usize, output_share: F) -> F {
+        assert!(permuted_index < self.domain_size);
+        F::prf(&self.prf_key_r.unwrap(), permuted_index as u64) + output_share
     }
 }
 
@@ -158,12 +158,12 @@ mod tests {
         F: Field + FromPrf,
         Perm: Permutation,
     {
-        let n = 1 << log_domain_size;
+        let domain_size = 1 << log_domain_size;
 
         // creation
-        let mut key_party = POTKeyParty::<F, Perm>::new(log_domain_size);
-        let mut index_party = POTIndexParty::<F, Perm>::new(log_domain_size);
-        let mut receiver_party = POTReceiverParty::<F>::new(log_domain_size);
+        let mut key_party = POTKeyParty::<F, Perm>::new(domain_size);
+        let mut index_party = POTIndexParty::<F, Perm>::new(domain_size);
+        let mut receiver_party = POTReceiverParty::<F>::new(domain_size);
         assert!(!key_party.is_initialized());
         assert!(!index_party.is_initialized());
         assert!(!receiver_party.is_initialized());
@@ -178,13 +178,13 @@ mod tests {
 
         // expand to the key party's output
         let output_k = key_party.expand();
-        assert_eq!(output_k.len(), n);
+        assert_eq!(output_k.len(), domain_size);
 
         // access each index and verify consistency with key party's output
-        for i in 0..(n as u64) {
+        for i in 0..domain_size {
             let msg_to_receiver_party = index_party.access(i);
             let output = receiver_party.access(msg_to_receiver_party.0, msg_to_receiver_party.1);
-            assert_eq!(output, output_k[i as usize]);
+            assert_eq!(output, output_k[i]);
         }
     }