Browse Source

stash: minor fixes and cleanup

Lennart Braun 2 years ago
parent
commit
fe92c054fc
1 changed files with 3 additions and 41 deletions
  1. 3 41
      oram/src/stash.rs

+ 3 - 41
oram/src/stash.rs

@@ -5,17 +5,13 @@ use crate::doprf::{
 };
 use crate::mask_index::{MaskIndex, MaskIndexProtocol};
 use crate::select::{Select, SelectProtocol};
-use bitvec;
 use communicator::{AbstractCommunicator, Fut, Serializable};
 use dpf::spdpf::SinglePointDpf;
 use ff::PrimeField;
 use rand::thread_rng;
-use std::io::Read;
 use std::marker::PhantomData;
 use utils::field::LegendreSymbol;
 
-type BitVec = bitvec::vec::BitVec<u8>;
-
 #[derive(Clone, Copy, Debug, Default)]
 pub struct StashEntryShare<F: PrimeField> {
     pub address: F,
@@ -136,7 +132,7 @@ where
 {
     // a) mask and reconstruct the stash index <loc>
     let index_bits = {
-        let bits = (access_counter as f64).log2().ceil() as u32;
+        let bits = usize::BITS - access_counter.leading_zeros();
         if bits > 0 {
             bits
         } else {
@@ -148,14 +144,6 @@ where
     let (masked_loc, r_prev, r_next) =
         MaskIndexProtocol::mask_index(comm, index_bits, location_share)?;
 
-    eprintln!(
-        "Party {}: masked index = {}, r_prev = {}, r_next = {}    ({} bits)",
-        comm.get_my_id(),
-        masked_loc,
-        r_prev,
-        r_next,
-        index_bits
-    );
     // b) use DPFs to read the stash value
     let fut_prev = comm.receive_previous::<SPDPF::Key>()?;
     let fut_next = comm.receive_next::<SPDPF::Key>()?;
@@ -168,8 +156,8 @@ where
     let dpf_key_prev = fut_prev.get()?;
     let dpf_key_next = fut_next.get()?;
     for j in 0..=access_counter {
-        let index_prev = ((j as u16 + r_prev) & bit_mask) as u64;
-        let index_next = ((j as u16 + r_next) & bit_mask) as u64;
+        let index_prev = ((j as u16).wrapping_add(r_prev) & bit_mask) as u64;
+        let index_next = ((j as u16).wrapping_add(r_next) & bit_mask) as u64;
         stash_values_share_mine[j] += SPDPF::evaluate_at(&dpf_key_prev, index_prev);
         stash_values_share_mine[j] += SPDPF::evaluate_at(&dpf_key_next, index_next);
     }
@@ -632,13 +620,6 @@ mod tests {
             Fp::ZERO
         );
 
-        eprintln!(
-            "stash_state = {{ flag = {:?}, location = {:?}, value = {:?} }}",
-            state_1.flag + state_2.flag + state_3.flag,
-            state_1.location + state_2.location + state_3.location,
-            state_1.value + state_2.value + state_3.value
-        );
-
         let h1 = run_write(
             party_1,
             comm_1,
@@ -677,9 +658,6 @@ mod tests {
             let (st_adrs_1, st_vals_1, st_old_vals_1) = party_1.get_stash_share();
             let (st_adrs_2, st_vals_2, st_old_vals_2) = party_2.get_stash_share();
             let (st_adrs_3, st_vals_3, st_old_vals_3) = party_3.get_stash_share();
-            eprintln!("{:?}    {:?}    {:?}", st_adrs_1, st_vals_1, st_old_vals_1);
-            eprintln!("{:?}    {:?}    {:?}", st_adrs_2, st_vals_2, st_old_vals_2);
-            eprintln!("{:?}    {:?}    {:?}", st_adrs_3, st_vals_3, st_old_vals_3);
             assert_eq!(st_adrs_1.len(), num_accesses);
             assert_eq!(st_vals_1.len(), num_accesses);
             assert_eq!(st_old_vals_1.len(), num_accesses);
@@ -764,9 +742,6 @@ mod tests {
             let (st_adrs_1, st_vals_1, st_old_vals_1) = party_1.get_stash_share();
             let (st_adrs_2, st_vals_2, st_old_vals_2) = party_2.get_stash_share();
             let (st_adrs_3, st_vals_3, st_old_vals_3) = party_3.get_stash_share();
-            eprintln!("{:?}    {:?}    {:?}", st_adrs_1, st_vals_1, st_old_vals_1);
-            eprintln!("{:?}    {:?}    {:?}", st_adrs_2, st_vals_2, st_old_vals_2);
-            eprintln!("{:?}    {:?}    {:?}", st_adrs_3, st_vals_3, st_old_vals_3);
             assert_eq!(st_adrs_1.len(), num_accesses);
             assert_eq!(st_vals_1.len(), num_accesses);
             assert_eq!(st_old_vals_1.len(), num_accesses);
@@ -820,13 +795,6 @@ mod tests {
             Fp::from_u128(old_value)
         );
 
-        eprintln!(
-            "stash_state = {{ flag = {:?}, location = {:?}, value = {:?} }}",
-            state_1.flag + state_2.flag + state_3.flag,
-            state_1.location + state_2.location + state_3.location,
-            state_1.value + state_2.value + state_3.value
-        );
-
         let h1 = run_write(
             party_1,
             comm_1,
@@ -868,9 +836,6 @@ mod tests {
             let (st_adrs_1, st_vals_1, st_old_vals_1) = party_1.get_stash_share();
             let (st_adrs_2, st_vals_2, st_old_vals_2) = party_2.get_stash_share();
             let (st_adrs_3, st_vals_3, st_old_vals_3) = party_3.get_stash_share();
-            eprintln!("{:?}    {:?}    {:?}", st_adrs_1, st_vals_1, st_old_vals_1);
-            eprintln!("{:?}    {:?}    {:?}", st_adrs_2, st_vals_2, st_old_vals_2);
-            eprintln!("{:?}    {:?}    {:?}", st_adrs_3, st_vals_3, st_old_vals_3);
             assert_eq!(st_adrs_1.len(), num_accesses);
             assert_eq!(st_vals_1.len(), num_accesses);
             assert_eq!(st_old_vals_1.len(), num_accesses);
@@ -955,9 +920,6 @@ mod tests {
             let (st_adrs_1, st_vals_1, st_old_vals_1) = party_1.get_stash_share();
             let (st_adrs_2, st_vals_2, st_old_vals_2) = party_2.get_stash_share();
             let (st_adrs_3, st_vals_3, st_old_vals_3) = party_3.get_stash_share();
-            eprintln!("{:?}    {:?}    {:?}", st_adrs_1, st_vals_1, st_old_vals_1);
-            eprintln!("{:?}    {:?}    {:?}", st_adrs_2, st_vals_2, st_old_vals_2);
-            eprintln!("{:?}    {:?}    {:?}", st_adrs_3, st_vals_3, st_old_vals_3);
             assert_eq!(st_adrs_1.len(), num_accesses);
             assert_eq!(st_vals_1.len(), num_accesses);
             assert_eq!(st_old_vals_1.len(), num_accesses);