lib.rs 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. mod aligned_memory_mt;
  2. pub mod client;
  3. mod ot;
  4. mod params;
  5. pub mod server;
  6. mod spiral_mt;
  7. use aes::cipher::{BlockEncrypt, KeyInit};
  8. use aes::Aes128Enc;
  9. use aes::Block;
  10. use std::mem;
  11. use serde::{Deserialize, Serialize};
  12. use std::os::raw::c_uchar;
  13. use rayon::scope;
  14. use rayon::ThreadPoolBuilder;
  15. use serde_with::serde_as;
  16. use spiral_rs::params::*;
  17. use crate::ot::{otkey_init, xor16};
  18. use crate::spiral_mt::*;
  // Type of a single database record: a 64-bit unsigned integer.
  pub type DbEntry = u64;
  20. // Encrypt a database of 2^r elements, where each element is a DbEntry,
  21. // using the 2*r provided keys (r pairs of keys). Also rotate the
  22. // database by rot positions, and add the provided blinding factor to
  23. // each element before encryption (the same blinding factor for all
  24. // elements). Each element is encrypted in AES counter mode, with the
  25. // counter being the element number and the key computed as the XOR of r
  26. // of the provided keys, one from each pair, according to the bits of
  27. // the element number. Outputs a byte vector containing the encrypted
  28. // database.
  29. fn db_encrypt(
  30. db: &[DbEntry],
  31. keys: &[[u8; 16]],
  32. r: usize,
  33. rot: usize,
  34. blind: DbEntry,
  35. num_threads: usize,
  36. ) -> Vec<u8> {
  37. let num_records: usize = 1 << r;
  38. let num_record_mask: usize = num_records - 1;
  39. let mut ret = Vec::<u8>::with_capacity(num_records * mem::size_of::<DbEntry>());
  40. ret.resize(num_records * mem::size_of::<DbEntry>(), 0);
  41. scope(|s| {
  42. let mut record_thread_start = 0usize;
  43. let records_per_thread_base = num_records / num_threads;
  44. let records_per_thread_extra = num_records % num_threads;
  45. let mut retslice = ret.as_mut_slice();
  46. for thr in 0..num_threads {
  47. let records_this_thread =
  48. records_per_thread_base + if thr < records_per_thread_extra { 1 } else { 0 };
  49. let record_thread_end = record_thread_start + records_this_thread;
  50. let (thread_ret, retslice_) =
  51. retslice.split_at_mut(records_this_thread * mem::size_of::<DbEntry>());
  52. retslice = retslice_;
  53. s.spawn(move |_| {
  54. let mut offset = 0usize;
  55. for j in record_thread_start..record_thread_end {
  56. let rec = (j + rot) & num_record_mask;
  57. let mut key = Block::from([0u8; 16]);
  58. for i in 0..r {
  59. let bit = if (j & (1 << i)) == 0 { 0 } else { 1 };
  60. xor16(&mut key, &keys[2 * i + bit]);
  61. }
  62. let aes = Aes128Enc::new(&key);
  63. let mut block = Block::from([0u8; 16]);
  64. block[0..8].copy_from_slice(&j.to_le_bytes());
  65. aes.encrypt_block(&mut block);
  66. let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
  67. let encelem = (db[rec].wrapping_add(blind)) ^ aeskeystream;
  68. thread_ret[offset..offset + mem::size_of::<DbEntry>()]
  69. .copy_from_slice(&encelem.to_le_bytes());
  70. offset += mem::size_of::<DbEntry>();
  71. }
  72. });
  73. record_thread_start = record_thread_end;
  74. }
  75. });
  76. ret
  77. }
  78. // Having received the key for element q with r parallel 1-out-of-2 OTs,
  79. // and having received the encrypted element with (non-symmetric) PIR,
  80. // use the key to decrypt the element.
  81. fn dbentry_decrypt(key: &Block, q: usize, encelement: DbEntry) -> DbEntry {
  82. let aes = Aes128Enc::new(key);
  83. let mut block = Block::from([0u8; 16]);
  84. block[0..8].copy_from_slice(&q.to_le_bytes());
  85. aes.encrypt_block(&mut block);
  86. let aeskeystream = DbEntry::from_le_bytes(block[0..8].try_into().unwrap());
  87. encelement ^ aeskeystream
  88. }
  89. // Things that are only done once total, not once for each SPIR
  90. pub fn init(num_threads: usize) {
  91. otkey_init();
  92. // Initialize the thread pool
  93. ThreadPoolBuilder::new()
  94. .num_threads(num_threads)
  95. .build_global()
  96. .unwrap();
  97. }
  98. pub fn print_params_summary(params: &Params) {
  99. let db_elem_size = params.item_size();
  100. let total_size = params.num_items() * db_elem_size;
  101. println!(
  102. "Using a {} x {} byte database ({} bytes total)",
  103. params.num_items(),
  104. db_elem_size,
  105. total_size
  106. );
  107. }
  108. // The message format for a single preprocess query
  109. #[derive(Serialize, Deserialize)]
  110. struct PreProcSingleMsg {
  111. ot_query: Vec<[u8; 32]>,
  112. spc_query: Vec<u8>,
  113. }
  114. // The message format for a single preprocess response
  115. #[serde_as]
  116. #[derive(Serialize, Deserialize)]
  117. struct PreProcSingleRespMsg {
  118. #[serde_as(as = "Vec<[_; 64]>")]
  119. ot_resp: Vec<[u8; 64]>,
  120. }
  121. #[no_mangle]
  122. pub extern "C" fn spir_init(num_threads: u32) {
  123. init(num_threads as usize);
  124. }
  // C-compatible description of a byte buffer handed out to C code.
  // The layout is fixed by #[repr(C)]; do not reorder the fields.
  #[repr(C)]
  pub struct VecData {
      // Pointer to the first byte of the buffer.
      data: *const c_uchar,
      // Number of initialized bytes.
      len: usize,
      // Allocated capacity in bytes (needed to free the buffer correctly).
      cap: usize,
  }
  // C-compatible description of a mutable byte buffer, as received back from
  // C code. The layout is fixed by #[repr(C)]; do not reorder the fields.
  #[repr(C)]
  pub struct VecMutData {
      // Pointer to the first byte of the buffer.
      data: *mut c_uchar,
      // Number of initialized bytes.
      len: usize,
      // Allocated capacity in bytes.
      cap: usize,
  }
  137. pub fn to_vecdata(v: Vec<u8>) -> VecData {
  138. let vecdata = VecData {
  139. data: v.as_ptr(),
  140. len: v.len(),
  141. cap: v.capacity(),
  142. };
  143. std::mem::forget(v);
  144. vecdata
  145. }
  146. #[no_mangle]
  147. pub extern "C" fn spir_vecdata_free(vecdata: VecMutData) {
  148. unsafe { Vec::from_raw_parts(vecdata.data, vecdata.len, vecdata.cap) };
  149. }