lib.rs 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. // We really want points to be capital letters and scalars to be
  2. // lowercase letters
  3. #![allow(non_snake_case)]
  4. mod aligned_memory_mt;
  5. pub mod client;
  6. mod params;
  7. pub mod server;
  8. mod spiral_mt;
  9. use aes::cipher::{BlockEncrypt, KeyInit};
  10. use aes::Aes128Enc;
  11. use aes::Block;
  12. use std::env;
  13. use std::mem;
  14. use std::time::Instant;
  15. use subtle::Choice;
  16. use subtle::ConditionallySelectable;
  17. use rand::RngCore;
  18. use sha2::Digest;
  19. use sha2::Sha256;
  20. use sha2::Sha512;
  21. use std::os::raw::c_uchar;
  22. use curve25519_dalek::constants as dalek_constants;
  23. use curve25519_dalek::ristretto::CompressedRistretto;
  24. use curve25519_dalek::ristretto::RistrettoBasepointTable;
  25. use curve25519_dalek::ristretto::RistrettoPoint;
  26. use curve25519_dalek::scalar::Scalar;
  27. use rayon::scope;
  28. use rayon::ThreadPoolBuilder;
  29. use spiral_rs::client::*;
  30. use spiral_rs::params::*;
  31. use spiral_rs::server::*;
  32. use crate::spiral_mt::*;
  33. use lazy_static::lazy_static;
  /// One database element: an unsigned 64-bit integer, serialized in
  /// little-endian byte order when the database is encrypted.
  pub type DbEntry = u64;
  35. // Generators of the Ristretto group (the standard B and another one C,
  36. // for which the DL relationship is unknown), and their precomputed
  37. // multiplication tables. Used for the Oblivious Transfer protocol
  38. lazy_static! {
  39. pub static ref OT_B: RistrettoPoint = dalek_constants::RISTRETTO_BASEPOINT_POINT;
  40. pub static ref OT_C: RistrettoPoint =
  41. RistrettoPoint::hash_from_bytes::<Sha512>(b"OT Generator C");
  42. pub static ref OT_B_TABLE: RistrettoBasepointTable = dalek_constants::RISTRETTO_BASEPOINT_TABLE;
  43. pub static ref OT_C_TABLE: RistrettoBasepointTable = RistrettoBasepointTable::create(&OT_C);
  44. }
  45. // XOR a 16-byte slice into a Block (which will be used as an AES key)
  46. fn xor16(outar: &mut Block, inar: &[u8; 16]) {
  47. for i in 0..16 {
  48. outar[i] ^= inar[i];
  49. }
  50. }
  51. // Encrypt a database of 2^r elements, where each element is a DbEntry,
  52. // using the 2*r provided keys (r pairs of keys). Also rotate the
  53. // database by rot positions, and add the provided blinding factor to
  54. // each element before encryption (the same blinding factor for all
  55. // elements). Each element is encrypted in AES counter mode, with the
  56. // counter being the element number and the key computed as the XOR of r
  57. // of the provided keys, one from each pair, according to the bits of
  58. // the element number. Outputs a byte vector containing the encrypted
  59. // database.
  60. pub fn encdb_xor_keys(
  61. db: &[DbEntry],
  62. keys: &[[u8; 16]],
  63. r: usize,
  64. rot: usize,
  65. blind: DbEntry,
  66. num_threads: usize,
  67. ) -> Vec<u8> {
  68. let num_records: usize = 1 << r;
  69. let num_record_mask: usize = num_records - 1;
  70. let negrot = num_records - rot;
  71. let mut ret = Vec::<u8>::with_capacity(num_records * mem::size_of::<DbEntry>());
  72. ret.resize(num_records * mem::size_of::<DbEntry>(), 0);
  73. scope(|s| {
  74. let mut record_thread_start = 0usize;
  75. let records_per_thread_base = num_records / num_threads;
  76. let records_per_thread_extra = num_records % num_threads;
  77. let mut retslice = ret.as_mut_slice();
  78. for thr in 0..num_threads {
  79. let records_this_thread =
  80. records_per_thread_base + if thr < records_per_thread_extra { 1 } else { 0 };
  81. let record_thread_end = record_thread_start + records_this_thread;
  82. let (thread_ret, retslice_) =
  83. retslice.split_at_mut(records_this_thread * mem::size_of::<DbEntry>());
  84. retslice = retslice_;
  85. s.spawn(move |_| {
  86. let mut offset = 0usize;
  87. for j in record_thread_start..record_thread_end {
  88. let rec = (j + negrot) & num_record_mask;
  89. let mut key = Block::from([0u8; 16]);
  90. for i in 0..r {
  91. let bit = if (j & (1 << i)) == 0 { 0 } else { 1 };
  92. xor16(&mut key, &keys[2 * i + bit]);
  93. }
  94. let aes = Aes128Enc::new(&key);
  95. let mut block = Block::from([0u8; 16]);
  96. block[0..8].copy_from_slice(&j.to_le_bytes());
  97. aes.encrypt_block(&mut block);
  98. let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
  99. let encelem = (db[rec].wrapping_add(blind)) ^ aeskeystream;
  100. thread_ret[offset..offset + mem::size_of::<DbEntry>()]
  101. .copy_from_slice(&encelem.to_le_bytes());
  102. offset += mem::size_of::<DbEntry>();
  103. }
  104. });
  105. record_thread_start = record_thread_end;
  106. }
  107. });
  108. ret
  109. }
  110. // Generate the keys for encrypting the database
  111. pub fn gen_db_enc_keys(r: usize) -> Vec<[u8; 16]> {
  112. let mut keys: Vec<[u8; 16]> = Vec::new();
  113. let mut rng = rand::thread_rng();
  114. for _ in 0..2 * r {
  115. let mut k: [u8; 16] = [0; 16];
  116. rng.fill_bytes(&mut k);
  117. keys.push(k);
  118. }
  119. keys
  120. }
  121. // 1-out-of-2 Oblivious Transfer (OT)
  122. fn ot12_request(sel: Choice) -> ((Choice, Scalar), [u8; 32]) {
  123. let Btable: &RistrettoBasepointTable = &OT_B_TABLE;
  124. let C: &RistrettoPoint = &OT_C;
  125. let mut rng = rand07::thread_rng();
  126. let x = Scalar::random(&mut rng);
  127. let xB = &x * Btable;
  128. let CmxB = C - xB;
  129. let P = RistrettoPoint::conditional_select(&xB, &CmxB, sel);
  130. ((sel, x), P.compress().to_bytes())
  131. }
  132. fn ot12_serve(query: &[u8; 32], m0: &[u8; 16], m1: &[u8; 16]) -> [u8; 64] {
  133. let Btable: &RistrettoBasepointTable = &OT_B_TABLE;
  134. let Ctable: &RistrettoBasepointTable = &OT_C_TABLE;
  135. let mut rng = rand07::thread_rng();
  136. let y = Scalar::random(&mut rng);
  137. let yB = &y * Btable;
  138. let yC = &y * Ctable;
  139. let P = CompressedRistretto::from_slice(query).decompress().unwrap();
  140. let yP0 = y * P;
  141. let yP1 = yC - yP0;
  142. let mut HyP0 = Sha256::digest(yP0.compress().as_bytes());
  143. for i in 0..16 {
  144. HyP0[i] ^= m0[i];
  145. }
  146. let mut HyP1 = Sha256::digest(yP1.compress().as_bytes());
  147. for i in 0..16 {
  148. HyP1[i] ^= m1[i];
  149. }
  150. let mut ret = [0u8; 64];
  151. ret[0..32].copy_from_slice(yB.compress().as_bytes());
  152. ret[32..48].copy_from_slice(&HyP0[0..16]);
  153. ret[48..64].copy_from_slice(&HyP1[0..16]);
  154. ret
  155. }
  156. fn ot12_receive(state: (Choice, Scalar), response: &[u8; 64]) -> [u8; 16] {
  157. let yB = CompressedRistretto::from_slice(&response[0..32])
  158. .decompress()
  159. .unwrap();
  160. let yP = state.1 * yB;
  161. let mut HyP = Sha256::digest(yP.compress().as_bytes());
  162. for i in 0..16 {
  163. HyP[i] ^= u8::conditional_select(&response[32 + i], &response[48 + i], state.0);
  164. }
  165. HyP[0..16].try_into().unwrap()
  166. }
  167. // Obliviously fetch the key for element q of the database (which has
  168. // 2^r elements total). Each bit of q is used in a 1-out-of-2 OT to get
  169. // one of the keys in each of the r pairs of keys on the server side.
  170. // The resulting r keys are XORed together.
  171. pub fn otkey_request(q: usize, r: usize) -> (Vec<(Choice, Scalar)>, Vec<[u8; 32]>) {
  172. let mut state: Vec<(Choice, Scalar)> = Vec::with_capacity(r);
  173. let mut query: Vec<[u8; 32]> = Vec::with_capacity(r);
  174. for i in 0..r {
  175. let bit = ((q >> i) & 1) as u8;
  176. let (si, qi) = ot12_request(bit.into());
  177. state.push(si);
  178. query.push(qi);
  179. }
  180. (state, query)
  181. }
  182. pub fn otkey_serve(query: Vec<[u8; 32]>, keys: &Vec<[u8; 16]>) -> Vec<[u8; 64]> {
  183. let r = query.len();
  184. assert!(keys.len() == 2 * r);
  185. let mut response: Vec<[u8; 64]> = Vec::with_capacity(r);
  186. for i in 0..r {
  187. response.push(ot12_serve(&query[i], &keys[2 * i], &keys[2 * i + 1]));
  188. }
  189. response
  190. }
  191. pub fn otkey_receive(state: Vec<(Choice, Scalar)>, response: &Vec<[u8; 64]>) -> Block {
  192. let r = state.len();
  193. assert!(response.len() == r);
  194. let mut key = Block::from([0u8; 16]);
  195. for i in 0..r {
  196. xor16(&mut key, &ot12_receive(state[i], &response[i]));
  197. }
  198. key
  199. }
  200. // Having received the key for element q with r parallel 1-out-of-2 OTs,
  201. // and having received the encrypted element with (non-symmetric) PIR,
  202. // use the key to decrypt the element.
  203. pub fn otkey_decrypt(key: &Block, q: usize, encelement: DbEntry) -> DbEntry {
  204. let aes = Aes128Enc::new(key);
  205. let mut block = Block::from([0u8; 16]);
  206. block[0..8].copy_from_slice(&q.to_le_bytes());
  207. aes.encrypt_block(&mut block);
  208. let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
  209. encelement ^ aeskeystream
  210. }
  211. // Things that are only done once total, not once for each SPIR
  212. pub fn init(num_threads: usize) {
  213. // Resolve the lazy statics
  214. let _B: &RistrettoPoint = &OT_B;
  215. let _Btable: &RistrettoBasepointTable = &OT_B_TABLE;
  216. let _C: &RistrettoPoint = &OT_C;
  217. let _Ctable: &RistrettoBasepointTable = &OT_C_TABLE;
  218. // Initialize the thread pool
  219. ThreadPoolBuilder::new()
  220. .num_threads(num_threads)
  221. .build_global()
  222. .unwrap();
  223. }
  224. pub fn print_params_summary(params: &Params) {
  225. let db_elem_size = params.item_size();
  226. let total_size = params.num_items() * db_elem_size;
  227. println!(
  228. "Using a {} x {} byte database ({} bytes total)",
  229. params.num_items(),
  230. db_elem_size,
  231. total_size
  232. );
  233. }
  234. #[no_mangle]
  235. pub extern "C" fn spir_init(num_threads: u32) {
  236. init(num_threads as usize);
  237. }
  /// C-layout descriptor of a read-only byte buffer: a pointer to the bytes
  /// plus a length and capacity (presumably those of a Rust `Vec<u8>` —
  /// confirm against the producers of this struct).
  /// Field order is part of the C ABI and must not change.
  #[repr(C)]
  pub struct VecData {
      /// Pointer to the first byte.
      data: *const c_uchar,
      /// Number of bytes in use.
      len: usize,
      /// Allocated capacity, in bytes.
      cap: usize,
  }
  /// C-layout descriptor of an owned, mutable byte buffer: pointer, length
  /// and capacity. `spir_vecdata_free` rebuilds a `Vec<u8>` from exactly
  /// these three fields. Field order is part of the C ABI and must not change.
  #[repr(C)]
  pub struct VecMutData {
      /// Pointer to the first byte.
      data: *mut c_uchar,
      /// Number of bytes in use.
      len: usize,
      /// Allocated capacity, in bytes.
      cap: usize,
  }
  250. #[no_mangle]
  251. pub extern "C" fn spir_vecdata_free(vecdata: VecMutData) {
  252. unsafe { Vec::from_raw_parts(vecdata.data, vecdata.len, vecdata.cap) };
  253. }