Browse Source

Multithreaded db encryption

Ian Goldberg 1 year ago
parent
commit
4126f187ef
1 changed files with 45 additions and 15 deletions
  1. 45 15
      src/main.rs

+ 45 - 15
src/main.rs

@@ -27,6 +27,8 @@ use curve25519_dalek::ristretto::RistrettoBasepointTable;
 use curve25519_dalek::ristretto::RistrettoPoint;
 use curve25519_dalek::scalar::Scalar;
 
+use crossbeam::thread;
+
 use spiral_rs::client::*;
 use spiral_rs::params::*;
 use spiral_rs::server::*;
@@ -63,23 +65,51 @@ fn xor16(outar: &mut Block, inar: &[u8; 16]) {
 // as the XOR of r of the provided keys, one from each pair, according
 // to the bits of the element number.  Outputs a byte vector containing
 // the encrypted database.
-fn encdb_xor_keys(db: &[DbEntry], keys: &[[u8; 16]], r: usize, blind: DbEntry) -> Vec<u8> {
+fn encdb_xor_keys(
+    db: &[DbEntry],
+    keys: &[[u8; 16]],
+    r: usize,
+    blind: DbEntry,
+    num_threads: usize,
+) -> Vec<u8> {
     let num_records: usize = 1 << r;
     let mut ret = Vec::<u8>::with_capacity(num_records * mem::size_of::<DbEntry>());
-    for j in 0..num_records {
-        let mut key = Block::from([0u8; 16]);
-        for i in 0..r {
-            let bit = if (j & (1 << i)) == 0 { 0 } else { 1 };
-            xor16(&mut key, &keys[2 * i + bit]);
+    ret.resize(num_records * mem::size_of::<DbEntry>(), 0);
+    thread::scope(|s| {
+        let mut record_thread_start = 0usize;
+        let records_per_thread_base = num_records / num_threads;
+        let records_per_thread_extra = num_records % num_threads;
+        let mut retslice = ret.as_mut_slice();
+        for thr in 0..num_threads {
+            let records_this_thread =
+                records_per_thread_base + if thr < records_per_thread_extra { 1 } else { 0 };
+            let record_thread_end = record_thread_start + records_this_thread;
+            let (thread_ret, retslice_) =
+                retslice.split_at_mut(records_this_thread * mem::size_of::<DbEntry>());
+            retslice = retslice_;
+            s.spawn(move |_| {
+                let mut offset = 0usize;
+                for j in record_thread_start..record_thread_end {
+                    let mut key = Block::from([0u8; 16]);
+                    for i in 0..r {
+                        let bit = if (j & (1 << i)) == 0 { 0 } else { 1 };
+                        xor16(&mut key, &keys[2 * i + bit]);
+                    }
+                    let aes = Aes128Enc::new(&key);
+                    let mut block = Block::from([0u8; 16]);
+                    block[0..8].copy_from_slice(&j.to_le_bytes());
+                    aes.encrypt_block(&mut block);
+                    let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
+                    let encelem = (db[j].wrapping_add(blind)) ^ aeskeystream;
+                    thread_ret[offset..offset + mem::size_of::<DbEntry>()]
+                        .copy_from_slice(&encelem.to_le_bytes());
+                    offset += mem::size_of::<DbEntry>();
+                }
+            });
+            record_thread_start = record_thread_end;
         }
-        let aes = Aes128Enc::new(&key);
-        let mut block = Block::from([0u8; 16]);
-        block[0..8].copy_from_slice(&j.to_le_bytes());
-        aes.encrypt_block(&mut block);
-        let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
-        let encelem = (db[j].wrapping_add(blind)) ^ aeskeystream;
-        ret.extend(encelem.to_le_bytes());
-    }
+    })
+    .unwrap();
     ret
 }
 
@@ -309,7 +339,7 @@ fn main() {
     let blind: DbEntry = 20;
     let encdb_start = Instant::now();
     db.rotate_right(idx_offset);
-    let encdb = encdb_xor_keys(&db, &dbkeys, r, blind);
+    let encdb = encdb_xor_keys(&db, &dbkeys, r, blind, num_threads);
     let encdb_us = encdb_start.elapsed().as_micros();
     println!("Server encrypt database {} µs", encdb_us);