Browse Source

Move the oblivious transfer code into its own module

Ian Goldberg 1 year ago
parent
commit
2bd0149ea8
4 changed files with 135 additions and 128 deletions
  1. 0 2
      src/client.rs
  2. 3 124
      src/lib.rs
  3. 132 0
      src/ot.rs
  4. 0 2
      src/server.rs

+ 0 - 2
src/client.rs

@@ -41,13 +41,11 @@ impl Client {
 
             // Wait for commands
             loop {
-                println!("Client waiting");
                 match incoming_cmd_recv.recv() {
                     Err(_) => break,
                     _ => panic!("Received something unexpected"),
                 }
             }
-            println!("Client ending");
         });
         let pub_params = match outgoing_resp.recv() {
             Ok(Response::PubParams(x)) => x,

+ 3 - 124
src/lib.rs

@@ -4,6 +4,7 @@
 
 mod aligned_memory_mt;
 pub mod client;
+mod ot;
 mod params;
 pub mod server;
 mod spiral_mt;
@@ -14,23 +15,11 @@ use aes::Block;
 use std::env;
 use std::mem;
 use std::time::Instant;
-use subtle::Choice;
-use subtle::ConditionallySelectable;
 
 use rand::RngCore;
 
-use sha2::Digest;
-use sha2::Sha256;
-use sha2::Sha512;
-
 use std::os::raw::c_uchar;
 
-use curve25519_dalek::constants as dalek_constants;
-use curve25519_dalek::ristretto::CompressedRistretto;
-use curve25519_dalek::ristretto::RistrettoBasepointTable;
-use curve25519_dalek::ristretto::RistrettoPoint;
-use curve25519_dalek::scalar::Scalar;
-
 use rayon::scope;
 use rayon::ThreadPoolBuilder;
 
@@ -39,29 +28,10 @@ use spiral_rs::params::*;
 use spiral_rs::server::*;
 
 use crate::spiral_mt::*;
-
-use lazy_static::lazy_static;
+use crate::ot::{otkey_init, xor16};
 
 pub type DbEntry = u64;
 
-// Generators of the Ristretto group (the standard B and another one C,
-// for which the DL relationship is unknown), and their precomputed
-// multiplication tables.  Used for the Oblivious Transfer protocol
-lazy_static! {
-    pub static ref OT_B: RistrettoPoint = dalek_constants::RISTRETTO_BASEPOINT_POINT;
-    pub static ref OT_C: RistrettoPoint =
-        RistrettoPoint::hash_from_bytes::<Sha512>(b"OT Generator C");
-    pub static ref OT_B_TABLE: RistrettoBasepointTable = dalek_constants::RISTRETTO_BASEPOINT_TABLE;
-    pub static ref OT_C_TABLE: RistrettoBasepointTable = RistrettoBasepointTable::create(&OT_C);
-}
-
-// XOR a 16-byte slice into a Block (which will be used as an AES key)
-fn xor16(outar: &mut Block, inar: &[u8; 16]) {
-    for i in 0..16 {
-        outar[i] ^= inar[i];
-    }
-}
-
 // Encrypt a database of 2^r elements, where each element is a DbEntry,
 // using the 2*r provided keys (r pairs of keys).  Also rotate the
 // database by rot positions, and add the provided blinding factor to
@@ -135,93 +105,6 @@ pub fn gen_db_enc_keys(r: usize) -> Vec<[u8; 16]> {
     keys
 }
 
-// 1-out-of-2 Oblivious Transfer (OT)
-
-fn ot12_request(sel: Choice) -> ((Choice, Scalar), [u8; 32]) {
-    let Btable: &RistrettoBasepointTable = &OT_B_TABLE;
-    let C: &RistrettoPoint = &OT_C;
-    let mut rng = rand07::thread_rng();
-    let x = Scalar::random(&mut rng);
-    let xB = &x * Btable;
-    let CmxB = C - xB;
-    let P = RistrettoPoint::conditional_select(&xB, &CmxB, sel);
-    ((sel, x), P.compress().to_bytes())
-}
-
-fn ot12_serve(query: &[u8; 32], m0: &[u8; 16], m1: &[u8; 16]) -> [u8; 64] {
-    let Btable: &RistrettoBasepointTable = &OT_B_TABLE;
-    let Ctable: &RistrettoBasepointTable = &OT_C_TABLE;
-    let mut rng = rand07::thread_rng();
-    let y = Scalar::random(&mut rng);
-    let yB = &y * Btable;
-    let yC = &y * Ctable;
-    let P = CompressedRistretto::from_slice(query).decompress().unwrap();
-    let yP0 = y * P;
-    let yP1 = yC - yP0;
-    let mut HyP0 = Sha256::digest(yP0.compress().as_bytes());
-    for i in 0..16 {
-        HyP0[i] ^= m0[i];
-    }
-    let mut HyP1 = Sha256::digest(yP1.compress().as_bytes());
-    for i in 0..16 {
-        HyP1[i] ^= m1[i];
-    }
-    let mut ret = [0u8; 64];
-    ret[0..32].copy_from_slice(yB.compress().as_bytes());
-    ret[32..48].copy_from_slice(&HyP0[0..16]);
-    ret[48..64].copy_from_slice(&HyP1[0..16]);
-    ret
-}
-
-fn ot12_receive(state: (Choice, Scalar), response: &[u8; 64]) -> [u8; 16] {
-    let yB = CompressedRistretto::from_slice(&response[0..32])
-        .decompress()
-        .unwrap();
-    let yP = state.1 * yB;
-    let mut HyP = Sha256::digest(yP.compress().as_bytes());
-    for i in 0..16 {
-        HyP[i] ^= u8::conditional_select(&response[32 + i], &response[48 + i], state.0);
-    }
-    HyP[0..16].try_into().unwrap()
-}
-
-// Obliviously fetch the key for element q of the database (which has
-// 2^r elements total).  Each bit of q is used in a 1-out-of-2 OT to get
-// one of the keys in each of the r pairs of keys on the server side.
-// The resulting r keys are XORed together.
-
-pub fn otkey_request(q: usize, r: usize) -> (Vec<(Choice, Scalar)>, Vec<[u8; 32]>) {
-    let mut state: Vec<(Choice, Scalar)> = Vec::with_capacity(r);
-    let mut query: Vec<[u8; 32]> = Vec::with_capacity(r);
-    for i in 0..r {
-        let bit = ((q >> i) & 1) as u8;
-        let (si, qi) = ot12_request(bit.into());
-        state.push(si);
-        query.push(qi);
-    }
-    (state, query)
-}
-
-pub fn otkey_serve(query: Vec<[u8; 32]>, keys: &Vec<[u8; 16]>) -> Vec<[u8; 64]> {
-    let r = query.len();
-    assert!(keys.len() == 2 * r);
-    let mut response: Vec<[u8; 64]> = Vec::with_capacity(r);
-    for i in 0..r {
-        response.push(ot12_serve(&query[i], &keys[2 * i], &keys[2 * i + 1]));
-    }
-    response
-}
-
-pub fn otkey_receive(state: Vec<(Choice, Scalar)>, response: &Vec<[u8; 64]>) -> Block {
-    let r = state.len();
-    assert!(response.len() == r);
-    let mut key = Block::from([0u8; 16]);
-    for i in 0..r {
-        xor16(&mut key, &ot12_receive(state[i], &response[i]));
-    }
-    key
-}
-
 // Having received the key for element q with r parallel 1-out-of-2 OTs,
 // and having received the encrypted element with (non-symmetric) PIR,
 // use the key to decrypt the element.
@@ -236,11 +119,7 @@ pub fn otkey_decrypt(key: &Block, q: usize, encelement: DbEntry) -> DbEntry {
 
 // Things that are only done once total, not once for each SPIR
 pub fn init(num_threads: usize) {
-    // Resolve the lazy statics
-    let _B: &RistrettoPoint = &OT_B;
-    let _Btable: &RistrettoBasepointTable = &OT_B_TABLE;
-    let _C: &RistrettoPoint = &OT_C;
-    let _Ctable: &RistrettoBasepointTable = &OT_C_TABLE;
+    otkey_init();
 
     // Initialize the thread pool
     ThreadPoolBuilder::new()

+ 132 - 0
src/ot.rs

@@ -0,0 +1,132 @@
+// Oblivious transfer
+
+use subtle::Choice;
+use subtle::ConditionallySelectable;
+
+use aes::Block;
+
+use sha2::Digest;
+use sha2::Sha256;
+use sha2::Sha512;
+
+use curve25519_dalek::constants as dalek_constants;
+use curve25519_dalek::ristretto::CompressedRistretto;
+use curve25519_dalek::ristretto::RistrettoBasepointTable;
+use curve25519_dalek::ristretto::RistrettoPoint;
+use curve25519_dalek::scalar::Scalar;
+
+use lazy_static::lazy_static;
+
+// Generators of the Ristretto group (the standard B and another one C,
+// for which the DL relationship is unknown), and their precomputed
+// multiplication tables.  Used for the Oblivious Transfer protocol
+lazy_static! {
+    pub static ref OT_B: RistrettoPoint = dalek_constants::RISTRETTO_BASEPOINT_POINT;
+    pub static ref OT_C: RistrettoPoint =
+        RistrettoPoint::hash_from_bytes::<Sha512>(b"OT Generator C");
+    pub static ref OT_B_TABLE: RistrettoBasepointTable = dalek_constants::RISTRETTO_BASEPOINT_TABLE;
+    pub static ref OT_C_TABLE: RistrettoBasepointTable = RistrettoBasepointTable::create(&OT_C);
+}
+
+// 1-out-of-2 Oblivious Transfer (OT)
+
+fn ot12_request(sel: Choice) -> ((Choice, Scalar), [u8; 32]) {
+    let Btable: &RistrettoBasepointTable = &OT_B_TABLE;
+    let C: &RistrettoPoint = &OT_C;
+    let mut rng = rand07::thread_rng();
+    let x = Scalar::random(&mut rng);
+    let xB = &x * Btable;
+    let CmxB = C - xB;
+    let P = RistrettoPoint::conditional_select(&xB, &CmxB, sel);
+    ((sel, x), P.compress().to_bytes())
+}
+
+fn ot12_serve(query: &[u8; 32], m0: &[u8; 16], m1: &[u8; 16]) -> [u8; 64] {
+    let Btable: &RistrettoBasepointTable = &OT_B_TABLE;
+    let Ctable: &RistrettoBasepointTable = &OT_C_TABLE;
+    let mut rng = rand07::thread_rng();
+    let y = Scalar::random(&mut rng);
+    let yB = &y * Btable;
+    let yC = &y * Ctable;
+    let P = CompressedRistretto::from_slice(query).decompress().unwrap();
+    let yP0 = y * P;
+    let yP1 = yC - yP0;
+    let mut HyP0 = Sha256::digest(yP0.compress().as_bytes());
+    for i in 0..16 {
+        HyP0[i] ^= m0[i];
+    }
+    let mut HyP1 = Sha256::digest(yP1.compress().as_bytes());
+    for i in 0..16 {
+        HyP1[i] ^= m1[i];
+    }
+    let mut ret = [0u8; 64];
+    ret[0..32].copy_from_slice(yB.compress().as_bytes());
+    ret[32..48].copy_from_slice(&HyP0[0..16]);
+    ret[48..64].copy_from_slice(&HyP1[0..16]);
+    ret
+}
+
+fn ot12_receive(state: (Choice, Scalar), response: &[u8; 64]) -> [u8; 16] {
+    let yB = CompressedRistretto::from_slice(&response[0..32])
+        .decompress()
+        .unwrap();
+    let yP = state.1 * yB;
+    let mut HyP = Sha256::digest(yP.compress().as_bytes());
+    for i in 0..16 {
+        HyP[i] ^= u8::conditional_select(&response[32 + i], &response[48 + i], state.0);
+    }
+    HyP[0..16].try_into().unwrap()
+}
+
+// Obliviously fetch the key for element q of the database (which has
+// 2^r elements total).  Each bit of q is used in a 1-out-of-2 OT to get
+// one of the keys in each of the r pairs of keys on the server side.
+// The resulting r keys are XORed together.
+
+pub fn otkey_init() {
+    // Resolve the lazy statics
+    let _B: &RistrettoPoint = &OT_B;
+    let _Btable: &RistrettoBasepointTable = &OT_B_TABLE;
+    let _C: &RistrettoPoint = &OT_C;
+    let _Ctable: &RistrettoBasepointTable = &OT_C_TABLE;
+}
+
+pub fn otkey_request(q: usize, r: usize) -> (Vec<(Choice, Scalar)>, Vec<[u8; 32]>) {
+    let mut state: Vec<(Choice, Scalar)> = Vec::with_capacity(r);
+    let mut query: Vec<[u8; 32]> = Vec::with_capacity(r);
+    for i in 0..r {
+        let bit = ((q >> i) & 1) as u8;
+        let (si, qi) = ot12_request(bit.into());
+        state.push(si);
+        query.push(qi);
+    }
+    (state, query)
+}
+
+pub fn otkey_serve(query: Vec<[u8; 32]>, keys: &Vec<[u8; 16]>) -> Vec<[u8; 64]> {
+    let r = query.len();
+    assert!(keys.len() == 2 * r);
+    let mut response: Vec<[u8; 64]> = Vec::with_capacity(r);
+    for i in 0..r {
+        response.push(ot12_serve(&query[i], &keys[2 * i], &keys[2 * i + 1]));
+    }
+    response
+}
+
+// XOR a 16-byte slice into a Block (which will be used as an AES key)
+pub fn xor16(outar: &mut Block, inar: &[u8; 16]) {
+    for i in 0..16 {
+        outar[i] ^= inar[i];
+    }
+}
+
+pub fn otkey_receive(state: Vec<(Choice, Scalar)>, response: &Vec<[u8; 64]>) -> Block {
+    let r = state.len();
+    assert!(response.len() == r);
+    let mut key = Block::from([0u8; 16]);
+    for i in 0..r {
+        xor16(&mut key, &ot12_receive(state[i], &response[i]));
+    }
+    key
+}
+

+ 0 - 2
src/server.rs

@@ -31,13 +31,11 @@ impl Server {
             let spiral_params = params::get_spiral_params(r);
             let pub_params = PublicParameters::deserialize(&spiral_params, &pub_params);
             loop {
-                println!("Waiting");
                 match incoming_cmd_recv.recv() {
                     Err(_) => break,
                     _ => panic!("Received something unexpected"),
                 }
             }
-            println!("Ending");
         });
         Server {
             r,