lib.rs 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. // We really want points to be capital letters and scalars to be
  2. // lowercase letters
  3. #![allow(non_snake_case)]
  4. mod aligned_memory_mt;
  5. pub mod client;
  6. mod ot;
  7. mod params;
  8. pub mod server;
  9. mod spiral_mt;
  10. use aes::cipher::{BlockEncrypt, KeyInit};
  11. use aes::Aes128Enc;
  12. use aes::Block;
  13. use std::env;
  14. use std::mem;
  15. use std::time::Instant;
  16. use rand::RngCore;
  17. use std::os::raw::c_uchar;
  18. use rayon::scope;
  19. use rayon::ThreadPoolBuilder;
  20. use spiral_rs::client::*;
  21. use spiral_rs::params::*;
  22. use spiral_rs::server::*;
  23. use crate::spiral_mt::*;
  24. use crate::ot::{otkey_init, xor16};
/// The type of a single database element. Buffers produced by
/// `encdb_xor_keys` store each element little-endian in
/// `size_of::<DbEntry>()` consecutive bytes.
pub type DbEntry = u64;
  26. // Encrypt a database of 2^r elements, where each element is a DbEntry,
  27. // using the 2*r provided keys (r pairs of keys). Also rotate the
  28. // database by rot positions, and add the provided blinding factor to
  29. // each element before encryption (the same blinding factor for all
  30. // elements). Each element is encrypted in AES counter mode, with the
  31. // counter being the element number and the key computed as the XOR of r
  32. // of the provided keys, one from each pair, according to the bits of
  33. // the element number. Outputs a byte vector containing the encrypted
  34. // database.
  35. pub fn encdb_xor_keys(
  36. db: &[DbEntry],
  37. keys: &[[u8; 16]],
  38. r: usize,
  39. rot: usize,
  40. blind: DbEntry,
  41. num_threads: usize,
  42. ) -> Vec<u8> {
  43. let num_records: usize = 1 << r;
  44. let num_record_mask: usize = num_records - 1;
  45. let negrot = num_records - rot;
  46. let mut ret = Vec::<u8>::with_capacity(num_records * mem::size_of::<DbEntry>());
  47. ret.resize(num_records * mem::size_of::<DbEntry>(), 0);
  48. scope(|s| {
  49. let mut record_thread_start = 0usize;
  50. let records_per_thread_base = num_records / num_threads;
  51. let records_per_thread_extra = num_records % num_threads;
  52. let mut retslice = ret.as_mut_slice();
  53. for thr in 0..num_threads {
  54. let records_this_thread =
  55. records_per_thread_base + if thr < records_per_thread_extra { 1 } else { 0 };
  56. let record_thread_end = record_thread_start + records_this_thread;
  57. let (thread_ret, retslice_) =
  58. retslice.split_at_mut(records_this_thread * mem::size_of::<DbEntry>());
  59. retslice = retslice_;
  60. s.spawn(move |_| {
  61. let mut offset = 0usize;
  62. for j in record_thread_start..record_thread_end {
  63. let rec = (j + negrot) & num_record_mask;
  64. let mut key = Block::from([0u8; 16]);
  65. for i in 0..r {
  66. let bit = if (j & (1 << i)) == 0 { 0 } else { 1 };
  67. xor16(&mut key, &keys[2 * i + bit]);
  68. }
  69. let aes = Aes128Enc::new(&key);
  70. let mut block = Block::from([0u8; 16]);
  71. block[0..8].copy_from_slice(&j.to_le_bytes());
  72. aes.encrypt_block(&mut block);
  73. let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
  74. let encelem = (db[rec].wrapping_add(blind)) ^ aeskeystream;
  75. thread_ret[offset..offset + mem::size_of::<DbEntry>()]
  76. .copy_from_slice(&encelem.to_le_bytes());
  77. offset += mem::size_of::<DbEntry>();
  78. }
  79. });
  80. record_thread_start = record_thread_end;
  81. }
  82. });
  83. ret
  84. }
  85. // Generate the keys for encrypting the database
  86. pub fn gen_db_enc_keys(r: usize) -> Vec<[u8; 16]> {
  87. let mut keys: Vec<[u8; 16]> = Vec::new();
  88. let mut rng = rand::thread_rng();
  89. for _ in 0..2 * r {
  90. let mut k: [u8; 16] = [0; 16];
  91. rng.fill_bytes(&mut k);
  92. keys.push(k);
  93. }
  94. keys
  95. }
  96. // Having received the key for element q with r parallel 1-out-of-2 OTs,
  97. // and having received the encrypted element with (non-symmetric) PIR,
  98. // use the key to decrypt the element.
  99. pub fn otkey_decrypt(key: &Block, q: usize, encelement: DbEntry) -> DbEntry {
  100. let aes = Aes128Enc::new(key);
  101. let mut block = Block::from([0u8; 16]);
  102. block[0..8].copy_from_slice(&q.to_le_bytes());
  103. aes.encrypt_block(&mut block);
  104. let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
  105. encelement ^ aeskeystream
  106. }
  107. // Things that are only done once total, not once for each SPIR
  108. pub fn init(num_threads: usize) {
  109. otkey_init();
  110. // Initialize the thread pool
  111. ThreadPoolBuilder::new()
  112. .num_threads(num_threads)
  113. .build_global()
  114. .unwrap();
  115. }
  116. pub fn print_params_summary(params: &Params) {
  117. let db_elem_size = params.item_size();
  118. let total_size = params.num_items() * db_elem_size;
  119. println!(
  120. "Using a {} x {} byte database ({} bytes total)",
  121. params.num_items(),
  122. db_elem_size,
  123. total_size
  124. );
  125. }
  126. #[no_mangle]
  127. pub extern "C" fn spir_init(num_threads: u32) {
  128. init(num_threads as usize);
  129. }
/// C-compatible (`repr(C)`) description of a read-only byte buffer: a
/// pointer to the bytes plus a length and a capacity.
///
/// The field order is part of the C ABI and must not be changed.
/// NOTE(review): the fields mirror a `Vec<u8>`'s raw parts; presumably this
/// is built by decomposing such a `Vec` for C callers. Confirm against the
/// code that constructs it.
#[repr(C)]
pub struct VecData {
    /// Pointer to the first byte (const: not mutated through this handle).
    data: *const c_uchar,
    /// Length in bytes.
    len: usize,
    /// Capacity of the underlying allocation, in bytes.
    cap: usize,
}
/// C-compatible (`repr(C)`) description of a mutable, owned byte buffer:
/// a pointer to the bytes plus a length and a capacity.
///
/// `spir_vecdata_free` passes these three fields straight to
/// `Vec::from_raw_parts` to release the buffer, so they must be the raw
/// parts of a Rust `Vec`. The field order is part of the C ABI and must
/// not be changed.
#[repr(C)]
pub struct VecMutData {
    /// Pointer to the first byte of the allocation.
    data: *mut c_uchar,
    /// Length in bytes.
    len: usize,
    /// Capacity of the underlying allocation, in bytes.
    cap: usize,
}
  142. #[no_mangle]
  143. pub extern "C" fn spir_vecdata_free(vecdata: VecMutData) {
  144. unsafe { Vec::from_raw_parts(vecdata.data, vecdata.len, vecdata.cap) };
  145. }