oram.rs 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525
  1. //! Implementation of the main distributed oblivious RAM protocol.
  2. use crate::common::{Error, InstructionShare};
  3. use crate::doprf::{JointDOPrf, LegendrePrf, LegendrePrfKey};
  4. use crate::p_ot::JointPOTParties;
  5. use crate::select::{Select, SelectProtocol};
  6. use crate::stash::{
  7. ProtocolStep as StashProtocolStep, Runtimes as StashRuntimes, Stash, StashProtocol,
  8. };
  9. use communicator::{AbstractCommunicator, Fut, Serializable};
  10. use dpf::{mpdpf::MultiPointDpf, spdpf::SinglePointDpf};
  11. use ff::PrimeField;
  12. use itertools::Itertools;
  13. use rand::thread_rng;
  14. use rayon::prelude::*;
  15. use std::collections::VecDeque;
  16. use std::iter::repeat;
  17. use std::marker::PhantomData;
  18. use std::time::{Duration, Instant};
  19. use strum::IntoEnumIterator;
  20. use utils::field::{FromPrf, LegendreSymbol};
  21. use utils::permutation::FisherYatesPermutation;
  22. /// Specification of the DORAM interface.
  23. pub trait DistributedOram<F>
  24. where
  25. F: PrimeField,
  26. {
  27. /// Get the current parties ID.
  28. fn get_party_id(&self) -> usize;
  29. /// Get the database size.
  30. fn get_db_size(&self) -> usize;
  31. /// Run the initialization protocol using the given database share.
  32. fn init<C: AbstractCommunicator>(&mut self, comm: &mut C, db_share: &[F]) -> Result<(), Error>;
  33. /// Run the preprocessing protocol for the given `number_epochs`.
  34. fn preprocess<C: AbstractCommunicator>(
  35. &mut self,
  36. comm: &mut C,
  37. number_epochs: usize,
  38. ) -> Result<(), Error>;
  39. /// Run the access protocol for the given shared instruction.
  40. fn access<C: AbstractCommunicator>(
  41. &mut self,
  42. comm: &mut C,
  43. instruction: InstructionShare<F>,
  44. ) -> Result<F, Error>;
  45. /// Get the share of the database.
  46. ///
  47. /// If `rerandomize_shares` is true, perform extra rerandomization.
  48. fn get_db<C: AbstractCommunicator>(
  49. &mut self,
  50. comm: &mut C,
  51. rerandomize_shares: bool,
  52. ) -> Result<Vec<F>, Error>;
  53. }
  /// Party ID of the first party. Party IDs are zero-based; `DistributedOramProtocol::new`
  /// only accepts IDs below 3.
  const PARTY_1: usize = 0;
  // IDs of the other two parties; currently commented out.
  // const PARTY_2: usize = 1;
  // const PARTY_3: usize = 2;
  /// Compute the PRF output bit size to use for a memory of `memory_size` cells.
  ///
  /// The result is the bit length of `memory_size` (which is 0 for `memory_size == 0`) plus a
  /// fixed margin of 40 bits.
  fn compute_oram_prf_output_bitsize(memory_size: usize) -> usize {
      // Fixed number of bits added on top of the bit length of the memory size.
      const MARGIN_BITS: usize = 40;
      let bit_length = usize::BITS - memory_size.leading_zeros();
      bit_length as usize + MARGIN_BITS
  }
  /// Steps of the DORAM protocol.
  ///
  /// Each variant names one timed section of the protocol. The discriminants start at
  /// `Preprocess = 0` and count up in declaration order; `Runtimes` uses `id as usize` as the
  /// index into its array of accumulated durations, so the number of variants (currently 30)
  /// must not exceed the length of that array.
  ///
  /// The variant names are significant too: `Runtimes::print` groups the steps by the prefix of
  /// their `Display` string (`Preprocess`, `Access`, `DbRead`, `DbWrite`, `Refresh`).
  #[allow(missing_docs)]
  #[derive(Debug, Clone, Copy, PartialEq, Eq, strum_macros::EnumIter, strum_macros::Display)]
  pub enum ProtocolStep {
      // Preprocessing: the total, followed by its phases.
      Preprocess = 0,
      PreprocessLPRFKeyGenPrev,
      PreprocessLPRFEvalSortPrev,
      PreprocessLPRFKeyRecvNext,
      PreprocessLPRFEvalSortNext,
      PreprocessMpDpdfPrecomp,
      PreprocessRecvTagsMine,
      PreprocessStash,
      PreprocessDOPrf,
      PreprocessPOt,
      PreprocessSelect,
      // Steps whose names start with `Access`; printed as top-level rows.
      Access,
      AccessStashRead,
      AccessAddressSelection,
      AccessDatabaseRead,
      AccessStashWrite,
      AccessValueSelection,
      AccessRefresh,
      // Printed as sub-rows of `AccessDatabaseRead`.
      DbReadAddressTag,
      DbReadGarbledIndex,
      DbReadPotAccess,
      // Printed as sub-rows of `AccessRefresh` (together with the `Refresh*` steps).
      DbWriteMpDpfKeyExchange,
      DbWriteMpDpfEvaluations,
      DbWriteUpdateMemory,
      RefreshJitPreprocess,
      RefreshResetFuncs,
      RefreshGetPreproc,
      RefreshSorting,
      RefreshPOtExpandMasking,
      RefreshReceivingShare,
  }
  /// Collection of accumulated runtimes for the protocol steps.
  #[derive(Debug, Default, Clone, Copy)]
  pub struct Runtimes {
      /// One accumulated duration per `ProtocolStep`, indexed by the step's discriminant
      /// (`id as usize`). The length must cover every variant of `ProtocolStep`.
      durations: [Duration; 30],
      /// Runtimes of the stash protocol; read and replaced as a whole through
      /// `get_stash_runtimes` / `set_stash_runtimes`.
      stash_runtimes: StashRuntimes,
  }
  101. impl Runtimes {
  102. /// Add another duration to the accumulated runtimes for a protocol step.
  103. #[inline(always)]
  104. pub fn record(&mut self, id: ProtocolStep, duration: Duration) {
  105. self.durations[id as usize] += duration;
  106. }
  107. /// Get a copy of the recorded runtimes of the stash protocol,
  108. pub fn get_stash_runtimes(&self) -> StashRuntimes {
  109. self.stash_runtimes
  110. }
  111. /// Set the recorded runtimes of the stash protocol,
  112. pub fn set_stash_runtimes(&mut self, stash_runtimes: StashRuntimes) {
  113. self.stash_runtimes = stash_runtimes;
  114. }
  115. /// Get the accumulated durations for a protocol step.
  116. pub fn get(&self, id: ProtocolStep) -> Duration {
  117. self.durations[id as usize]
  118. }
  119. /// Pretty-print the recorded runtimes amortized over `num_accesses`.
  120. pub fn print(&self, party_id: usize, num_accesses: usize) {
  121. println!("==================== Party {party_id} ====================");
  122. println!("- times per access over {num_accesses} accesses in total");
  123. println!(
  124. "{:30} {:7.3} ms",
  125. ProtocolStep::Preprocess,
  126. self.get(ProtocolStep::Preprocess).as_secs_f64() * 1000.0 / num_accesses as f64
  127. );
  128. for step in ProtocolStep::iter()
  129. .filter(|x| x.to_string().starts_with("Preprocess") && *x != ProtocolStep::Preprocess)
  130. {
  131. println!(
  132. " {:26} {:7.3} ms",
  133. step,
  134. self.get(step).as_secs_f64() * 1000.0 / num_accesses as f64
  135. );
  136. }
  137. for step in ProtocolStep::iter().filter(|x| x.to_string().starts_with("Access")) {
  138. println!(
  139. "{:30} {:7.3} ms",
  140. step,
  141. self.get(step).as_secs_f64() * 1000.0 / num_accesses as f64
  142. );
  143. match step {
  144. ProtocolStep::AccessDatabaseRead => {
  145. for step in ProtocolStep::iter().filter(|x| x.to_string().starts_with("DbRead"))
  146. {
  147. println!(
  148. " {:26} {:7.3} ms",
  149. step,
  150. self.get(step).as_secs_f64() * 1000.0 / num_accesses as f64
  151. );
  152. }
  153. }
  154. ProtocolStep::AccessRefresh => {
  155. for step in ProtocolStep::iter().filter(|x| {
  156. x.to_string().starts_with("DbWrite") || x.to_string().starts_with("Refresh")
  157. }) {
  158. println!(
  159. " {:26} {:7.3} ms",
  160. step,
  161. self.get(step).as_secs_f64() * 1000.0 / num_accesses as f64
  162. );
  163. }
  164. }
  165. ProtocolStep::AccessStashRead => {
  166. for step in
  167. StashProtocolStep::iter().filter(|x| x.to_string().starts_with("Read"))
  168. {
  169. println!(
  170. " {:26} {:7.3} ms",
  171. step,
  172. self.stash_runtimes.get(step).as_secs_f64() * 1000.0
  173. / num_accesses as f64
  174. );
  175. }
  176. }
  177. ProtocolStep::AccessStashWrite => {
  178. for step in
  179. StashProtocolStep::iter().filter(|x| x.to_string().starts_with("Write"))
  180. {
  181. println!(
  182. " {:26} {:7.3} ms",
  183. step,
  184. self.stash_runtimes.get(step).as_secs_f64() * 1000.0
  185. / num_accesses as f64
  186. );
  187. }
  188. }
  189. _ => {}
  190. }
  191. }
  192. println!("==================================================");
  193. }
  194. }
  /// Implementation of the DORAM protocol.
  ///
  /// The state falls into three groups: static parameters fixed by `new`, queues of
  /// preprocessed material (one entry per preprocessed epoch, filled by
  /// `preprocess_with_runtimes` and consumed from the front by `refresh`), and the state of the
  /// currently running epoch.
  pub struct DistributedOramProtocol<F, MPDPF, SPDPF>
  where
      F: FromPrf + LegendreSymbol + Serializable,
      F::PrfKey: Serializable,
      MPDPF: MultiPointDpf<Value = F>,
      MPDPF::Key: Serializable,
      SPDPF: SinglePointDpf<Value = F>,
      SPDPF::Key: Serializable,
  {
      /// ID of this party; `new` asserts that it is below 3.
      party_id: usize,
      /// Database size as passed to `new`.
      db_size: usize,
      /// Stash size, set by `new` to the rounded square root of `db_size`.
      stash_size: usize,
      /// Size of the memory, set by `new` to `db_size + stash_size`.
      memory_size: usize,
      /// This party's share of the memory; updated by `update_database_from_stash`.
      memory_share: Vec<F>,
      /// Output bit size for the PRF, derived from `memory_size` in `new`.
      prf_output_bitsize: usize,
      /// Number of epochs for which preprocessed material has been added by
      /// `preprocess_with_runtimes`; its debug assertions check the queue lengths against it.
      number_preprocessed_epochs: usize,
      // Queues of preprocessed material, one entry per preprocessed epoch.
      preprocessed_legendre_prf_key_next: VecDeque<LegendrePrfKey<F>>,
      preprocessed_legendre_prf_key_prev: VecDeque<LegendrePrfKey<F>>,
      preprocessed_memory_index_tags_prev: VecDeque<Vec<u128>>,
      preprocessed_memory_index_tags_next: VecDeque<Vec<u128>>,
      preprocessed_memory_index_tags_mine_sorted: VecDeque<Vec<u128>>,
      preprocessed_memory_index_tags_prev_sorted: VecDeque<Vec<u128>>,
      preprocessed_memory_index_tags_next_sorted: VecDeque<Vec<u128>>,
      preprocessed_stash: VecDeque<StashProtocol<F, SPDPF>>,
      preprocessed_select: VecDeque<SelectProtocol<F>>,
      preprocessed_doprf: VecDeque<JointDOPrf<F>>,
      preprocessed_pot: VecDeque<JointPOTParties<F, FisherYatesPermutation>>,
      preprocessed_pot_expands: VecDeque<Vec<F>>,
      // Index tags of the current epoch, installed by `refresh`. The `*_sorted` lists are the
      // ones searched by `pos_prev`, `pos_next` and `pos_mine`.
      memory_index_tags_prev: Vec<u128>,
      memory_index_tags_next: Vec<u128>,
      memory_index_tags_prev_sorted: Vec<u128>,
      memory_index_tags_next_sorted: Vec<u128>,
      memory_index_tags_mine_sorted: Vec<u128>,
      /// Garbled memory share; `read_from_database` reads it at the position of the address tag.
      garbled_memory_share: Vec<F>,
      /// Starts as `false` in `new`. NOTE(review): presumably set by `init`, which is not
      /// visible here; confirm.
      is_initialized: bool,
      /// Address tags read since the last database update: appended by `read_from_database`,
      /// emptied by `update_database_from_stash`.
      address_tags_read: Vec<u128>,
      // Functionality instances and keys of the current epoch; `None` until `refresh` installs
      // them from the preprocessed queues.
      stash: Option<StashProtocol<F, SPDPF>>,
      select_party: Option<SelectProtocol<F>>,
      joint_doprf: Option<JointDOPrf<F>>,
      legendre_prf_key_next: Option<LegendrePrfKey<F>>,
      legendre_prf_key_prev: Option<LegendrePrfKey<F>>,
      joint_pot: Option<JointPOTParties<F, FisherYatesPermutation>>,
      /// Multi-point DPF instance, created in `new` from `memory_size` and `stash_size`.
      mpdpf: MPDPF,
      _phantom: PhantomData<MPDPF>,
  }
  241. impl<F, MPDPF, SPDPF> DistributedOramProtocol<F, MPDPF, SPDPF>
  242. where
  243. F: FromPrf + LegendreSymbol + Serializable,
  244. F::PrfKey: Serializable + Sync,
  245. MPDPF: MultiPointDpf<Value = F> + Sync,
  246. MPDPF::Key: Serializable,
  247. SPDPF: SinglePointDpf<Value = F> + Sync,
  248. SPDPF::Key: Serializable + Sync,
  249. {
  250. /// Create a new instance.
  251. pub fn new(party_id: usize, db_size: usize) -> Self {
  252. assert!(party_id < 3);
  253. let stash_size = (db_size as f64).sqrt().round() as usize;
  254. let memory_size = db_size + stash_size;
  255. let prf_output_bitsize = compute_oram_prf_output_bitsize(memory_size);
  256. Self {
  257. party_id,
  258. db_size,
  259. stash_size,
  260. memory_size,
  261. memory_share: Default::default(),
  262. number_preprocessed_epochs: 0,
  263. prf_output_bitsize,
  264. preprocessed_legendre_prf_key_next: Default::default(),
  265. preprocessed_legendre_prf_key_prev: Default::default(),
  266. preprocessed_memory_index_tags_prev: Default::default(),
  267. preprocessed_memory_index_tags_next: Default::default(),
  268. preprocessed_memory_index_tags_mine_sorted: Default::default(),
  269. preprocessed_memory_index_tags_prev_sorted: Default::default(),
  270. preprocessed_memory_index_tags_next_sorted: Default::default(),
  271. preprocessed_stash: Default::default(),
  272. preprocessed_select: Default::default(),
  273. preprocessed_doprf: Default::default(),
  274. preprocessed_pot: Default::default(),
  275. preprocessed_pot_expands: Default::default(),
  276. memory_index_tags_prev: Default::default(),
  277. memory_index_tags_next: Default::default(),
  278. memory_index_tags_prev_sorted: Default::default(),
  279. memory_index_tags_next_sorted: Default::default(),
  280. memory_index_tags_mine_sorted: Default::default(),
  281. garbled_memory_share: Default::default(),
  282. is_initialized: false,
  283. address_tags_read: Default::default(),
  284. stash: None,
  285. select_party: None,
  286. joint_doprf: None,
  287. legendre_prf_key_next: None,
  288. legendre_prf_key_prev: None,
  289. joint_pot: None,
  290. mpdpf: MPDPF::new(memory_size, stash_size),
  291. _phantom: PhantomData,
  292. }
  293. }
  294. /// Get the current access counter.
  295. pub fn get_access_counter(&self) -> usize {
  296. self.stash.as_ref().unwrap().get_access_counter()
  297. }
  298. /// Get a reference to the stash protocol instance.
  299. pub fn get_stash(&self) -> &StashProtocol<F, SPDPF> {
  300. self.stash.as_ref().unwrap()
  301. }
  302. /// Return the size of the stash.
  303. pub fn get_stash_size(&self) -> usize {
  304. self.stash_size
  305. }
  306. fn pos_prev(&self, tag: u128) -> usize {
  307. debug_assert_eq!(self.memory_index_tags_prev_sorted.len(), self.memory_size);
  308. self.memory_index_tags_prev_sorted
  309. .binary_search(&tag)
  310. .expect("tag not found")
  311. }
  312. fn pos_next(&self, tag: u128) -> usize {
  313. debug_assert_eq!(self.memory_index_tags_next_sorted.len(), self.memory_size);
  314. self.memory_index_tags_next_sorted
  315. .binary_search(&tag)
  316. .expect("tag not found")
  317. }
  318. fn pos_mine(&self, tag: u128) -> usize {
  319. debug_assert_eq!(self.memory_index_tags_mine_sorted.len(), self.memory_size);
  320. self.memory_index_tags_mine_sorted
  321. .binary_search(&tag)
  322. .expect("tag not found")
  323. }
  324. fn read_from_database<C: AbstractCommunicator>(
  325. &mut self,
  326. comm: &mut C,
  327. address_share: F,
  328. runtimes: Option<Runtimes>,
  329. ) -> Result<(F, Option<Runtimes>), Error> {
  330. let mut value_share = F::ZERO;
  331. let t_start = Instant::now();
  332. // 1. Compute address tag
  333. let address_tag: u128 = self
  334. .joint_doprf
  335. .as_mut()
  336. .unwrap()
  337. .eval_to_uint(comm, &[address_share])?[0];
  338. // 2. Update tags read list
  339. self.address_tags_read.push(address_tag);
  340. let t_after_address_tag = Instant::now();
  341. // 3. Compute index in garbled memory and retrieve share
  342. let garbled_index = self.pos_mine(address_tag);
  343. value_share += self.garbled_memory_share[garbled_index];
  344. let t_after_index_computation = Instant::now();
  345. // 4. Run p-OT.Access
  346. value_share -= self
  347. .joint_pot
  348. .as_ref()
  349. .unwrap()
  350. .access(comm, garbled_index)?;
  351. let t_after_pot_access = Instant::now();
  352. let runtimes = runtimes.map(|mut r| {
  353. r.record(
  354. ProtocolStep::DbReadAddressTag,
  355. t_after_address_tag - t_start,
  356. );
  357. r.record(
  358. ProtocolStep::DbReadAddressTag,
  359. t_after_index_computation - t_after_address_tag,
  360. );
  361. r.record(
  362. ProtocolStep::DbReadAddressTag,
  363. t_after_pot_access - t_after_index_computation,
  364. );
  365. r
  366. });
  367. Ok((value_share, runtimes))
  368. }
  369. fn update_database_from_stash<C: AbstractCommunicator>(
  370. &mut self,
  371. comm: &mut C,
  372. runtimes: Option<Runtimes>,
  373. ) -> Result<Option<Runtimes>, Error> {
  374. let t_start = Instant::now();
  375. let fut_dpf_key_from_prev = comm.receive_previous()?;
  376. let fut_dpf_key_from_next = comm.receive_next()?;
  377. let (_, stash_values_share, stash_old_values_share) =
  378. self.stash.as_ref().unwrap().get_stash_share();
  379. assert_eq!(stash_values_share.len(), self.get_access_counter());
  380. assert_eq!(stash_old_values_share.len(), self.get_access_counter());
  381. assert_eq!(self.address_tags_read.len(), self.get_access_counter());
  382. let mut points: Vec<_> = self
  383. .address_tags_read
  384. .par_iter()
  385. .copied()
  386. .map(|tag| self.pos_mine(tag) as u64)
  387. .collect();
  388. let values: Vec<_> = stash_values_share
  389. .par_iter()
  390. .copied()
  391. .zip(stash_old_values_share.par_iter().copied())
  392. .map(|(val, old_val)| val - old_val)
  393. .collect();
  394. self.address_tags_read.truncate(0);
  395. // sort point, value pairs
  396. let (points, values): (Vec<u64>, Vec<F>) = {
  397. let mut indices: Vec<usize> = (0..points.len()).collect();
  398. indices.par_sort_unstable_by_key(|&i| points[i]);
  399. points.par_sort();
  400. let new_values = indices.par_iter().map(|&i| values[i]).collect();
  401. (points, new_values)
  402. };
  403. let (dpf_key_prev, dpf_key_next) = self.mpdpf.generate_keys(&points, &values);
  404. comm.send_previous(dpf_key_prev)?;
  405. comm.send_next(dpf_key_next)?;
  406. let dpf_key_from_prev = fut_dpf_key_from_prev.get()?;
  407. let dpf_key_from_next = fut_dpf_key_from_next.get()?;
  408. let t_after_mpdpf_key_exchange = Instant::now();
  409. let new_memory_share_from_prev = self.mpdpf.evaluate_domain(&dpf_key_from_prev);
  410. let new_memory_share_from_next = self.mpdpf.evaluate_domain(&dpf_key_from_next);
  411. let t_after_mpdpf_evaluations = Instant::now();
  412. {
  413. let mut memory_share = Vec::new();
  414. std::mem::swap(&mut self.memory_share, &mut memory_share);
  415. memory_share
  416. .par_iter_mut()
  417. .enumerate()
  418. .for_each(|(j, mem_cell)| {
  419. *mem_cell += new_memory_share_from_prev
  420. [self.pos_prev(self.memory_index_tags_prev[j])]
  421. + new_memory_share_from_next[self.pos_next(self.memory_index_tags_next[j])];
  422. });
  423. std::mem::swap(&mut self.memory_share, &mut memory_share);
  424. }
  425. let t_after_memory_update = Instant::now();
  426. let runtimes = runtimes.map(|mut r| {
  427. r.record(
  428. ProtocolStep::DbWriteMpDpfKeyExchange,
  429. t_after_mpdpf_key_exchange - t_start,
  430. );
  431. r.record(
  432. ProtocolStep::DbWriteMpDpfEvaluations,
  433. t_after_mpdpf_evaluations - t_after_mpdpf_key_exchange,
  434. );
  435. r.record(
  436. ProtocolStep::DbWriteUpdateMemory,
  437. t_after_memory_update - t_after_mpdpf_evaluations,
  438. );
  439. r
  440. });
  441. Ok(runtimes)
  442. }
  443. /// Run the preprocessing protocol and collect runtime data.
  444. pub fn preprocess_with_runtimes<C: AbstractCommunicator>(
  445. &mut self,
  446. comm: &mut C,
  447. number_epochs: usize,
  448. runtimes: Option<Runtimes>,
  449. ) -> Result<Option<Runtimes>, Error> {
  450. let already_preprocessed = self.number_preprocessed_epochs;
  451. // Reserve some space
  452. self.preprocessed_legendre_prf_key_prev
  453. .reserve(number_epochs);
  454. self.preprocessed_legendre_prf_key_next
  455. .reserve(number_epochs);
  456. self.preprocessed_memory_index_tags_prev
  457. .reserve(number_epochs);
  458. self.preprocessed_memory_index_tags_next
  459. .reserve(number_epochs);
  460. self.preprocessed_memory_index_tags_prev_sorted
  461. .reserve(number_epochs);
  462. self.preprocessed_memory_index_tags_next_sorted
  463. .reserve(number_epochs);
  464. self.preprocessed_memory_index_tags_mine_sorted
  465. .reserve(number_epochs);
  466. self.preprocessed_stash.reserve(number_epochs);
  467. self.preprocessed_select.reserve(number_epochs);
  468. self.preprocessed_doprf.reserve(number_epochs);
  469. self.preprocessed_pot.reserve(number_epochs);
  470. self.preprocessed_pot_expands.reserve(number_epochs);
  471. let t_start = Instant::now();
  472. // Generate Legendre PRF keys
  473. let fut_lpks_next = comm.receive_previous::<Vec<LegendrePrfKey<F>>>()?;
  474. let fut_tags_mine_sorted = comm.receive_previous::<Vec<Vec<u128>>>()?;
  475. self.preprocessed_legendre_prf_key_prev
  476. .extend((0..number_epochs).map(|_| LegendrePrf::key_gen(self.prf_output_bitsize)));
  477. let new_lpks_prev =
  478. &self.preprocessed_legendre_prf_key_prev.make_contiguous()[already_preprocessed..];
  479. comm.send_slice_next(new_lpks_prev.as_ref())?;
  480. let t_after_gen_lpks_prev = Instant::now();
  481. // Compute memory index tags
  482. for lpk_prev in new_lpks_prev {
  483. let memory_index_tags_prev: Vec<_> = (0..self.memory_size)
  484. .into_par_iter()
  485. .map(|j| LegendrePrf::eval_to_uint::<u128>(lpk_prev, F::from_u128(j as u128)))
  486. .collect();
  487. let mut memory_index_tags_prev_sorted = memory_index_tags_prev.clone();
  488. memory_index_tags_prev_sorted.par_sort_unstable();
  489. self.preprocessed_memory_index_tags_prev
  490. .push_back(memory_index_tags_prev);
  491. self.preprocessed_memory_index_tags_prev_sorted
  492. .push_back(memory_index_tags_prev_sorted);
  493. }
  494. let t_after_computing_index_tags_prev = Instant::now();
  495. self.preprocessed_legendre_prf_key_next
  496. .extend(fut_lpks_next.get()?.into_iter());
  497. let new_lpks_next =
  498. &self.preprocessed_legendre_prf_key_next.make_contiguous()[already_preprocessed..];
  499. let t_after_receiving_lpks_next = Instant::now();
  500. for lpk_next in new_lpks_next {
  501. let memory_index_tags_next: Vec<_> = (0..self.memory_size)
  502. .into_par_iter()
  503. .map(|j| LegendrePrf::eval_to_uint::<u128>(lpk_next, F::from_u128(j as u128)))
  504. .collect();
  505. let memory_index_tags_next_with_index_sorted: Vec<_> = memory_index_tags_next
  506. .iter()
  507. .copied()
  508. .enumerate()
  509. .sorted_unstable_by_key(|(_, x)| *x)
  510. .collect();
  511. self.preprocessed_memory_index_tags_next
  512. .push_back(memory_index_tags_next);
  513. self.preprocessed_memory_index_tags_next_sorted.push_back(
  514. memory_index_tags_next_with_index_sorted
  515. .par_iter()
  516. .map(|(_, x)| *x)
  517. .collect(),
  518. );
  519. }
  520. comm.send_next(
  521. self.preprocessed_memory_index_tags_next_sorted
  522. .make_contiguous()[already_preprocessed..]
  523. .to_vec(),
  524. )?;
  525. let t_after_computing_index_tags_next = Instant::now();
  526. self.mpdpf.precompute();
  527. let t_after_mpdpf_precomp = Instant::now();
  528. self.preprocessed_memory_index_tags_mine_sorted
  529. .extend(fut_tags_mine_sorted.get()?);
  530. let t_after_receiving_index_tags_mine = Instant::now();
  531. // Initialize Stash instances
  532. self.preprocessed_stash
  533. .extend((0..number_epochs).map(|_| StashProtocol::new(self.party_id, self.stash_size)));
  534. for stash in self
  535. .preprocessed_stash
  536. .iter_mut()
  537. .skip(already_preprocessed)
  538. {
  539. stash.init(comm)?;
  540. }
  541. let t_after_init_stash = Instant::now();
  542. // Initialize DOPRF instances
  543. self.preprocessed_doprf
  544. .extend((0..number_epochs).map(|_| JointDOPrf::new(self.prf_output_bitsize)));
  545. for (doprf, lpk_prev) in self
  546. .preprocessed_doprf
  547. .iter_mut()
  548. .skip(already_preprocessed)
  549. .zip(
  550. self.preprocessed_legendre_prf_key_prev
  551. .iter()
  552. .skip(already_preprocessed),
  553. )
  554. {
  555. doprf.set_legendre_prf_key_prev(lpk_prev.clone());
  556. doprf.init(comm)?;
  557. doprf.preprocess(comm, self.stash_size)?;
  558. }
  559. let t_after_init_doprf = Instant::now();
  560. // Precompute p-OTs and expand the mask
  561. self.preprocessed_pot
  562. .extend((0..number_epochs).map(|_| JointPOTParties::new(self.memory_size)));
  563. for pot in self.preprocessed_pot.iter_mut().skip(already_preprocessed) {
  564. pot.init(comm)?;
  565. }
  566. self.preprocessed_pot_expands.extend(
  567. self.preprocessed_pot.make_contiguous()[already_preprocessed..]
  568. .iter()
  569. .map(|pot| pot.expand()),
  570. );
  571. let t_after_preprocess_pot = Instant::now();
  572. self.preprocessed_select
  573. .extend((0..number_epochs).map(|_| SelectProtocol::default()));
  574. for select in self
  575. .preprocessed_select
  576. .iter_mut()
  577. .skip(already_preprocessed)
  578. {
  579. select.init(comm)?;
  580. select.preprocess(comm, 2 * self.stash_size)?;
  581. }
  582. let t_after_preprocess_select = Instant::now();
  583. self.number_preprocessed_epochs += number_epochs;
  584. debug_assert_eq!(
  585. self.preprocessed_legendre_prf_key_prev.len(),
  586. self.number_preprocessed_epochs
  587. );
  588. debug_assert_eq!(
  589. self.preprocessed_legendre_prf_key_next.len(),
  590. self.number_preprocessed_epochs
  591. );
  592. debug_assert_eq!(
  593. self.preprocessed_memory_index_tags_prev.len(),
  594. self.number_preprocessed_epochs
  595. );
  596. debug_assert_eq!(
  597. self.preprocessed_memory_index_tags_prev_sorted.len(),
  598. self.number_preprocessed_epochs
  599. );
  600. debug_assert_eq!(
  601. self.preprocessed_memory_index_tags_next.len(),
  602. self.number_preprocessed_epochs
  603. );
  604. debug_assert_eq!(
  605. self.preprocessed_memory_index_tags_next_sorted.len(),
  606. self.number_preprocessed_epochs
  607. );
  608. debug_assert_eq!(
  609. self.preprocessed_memory_index_tags_mine_sorted.len(),
  610. self.number_preprocessed_epochs
  611. );
  612. debug_assert_eq!(
  613. self.preprocessed_stash.len(),
  614. self.number_preprocessed_epochs
  615. );
  616. debug_assert_eq!(
  617. self.preprocessed_doprf.len(),
  618. self.number_preprocessed_epochs
  619. );
  620. debug_assert_eq!(self.preprocessed_pot.len(), self.number_preprocessed_epochs);
  621. debug_assert_eq!(
  622. self.preprocessed_pot_expands.len(),
  623. self.number_preprocessed_epochs
  624. );
  625. debug_assert_eq!(
  626. self.preprocessed_select.len(),
  627. self.number_preprocessed_epochs
  628. );
  629. let runtimes = runtimes.map(|mut r| {
  630. r.record(
  631. ProtocolStep::PreprocessLPRFKeyGenPrev,
  632. t_after_gen_lpks_prev - t_start,
  633. );
  634. r.record(
  635. ProtocolStep::PreprocessLPRFEvalSortPrev,
  636. t_after_computing_index_tags_prev - t_after_gen_lpks_prev,
  637. );
  638. r.record(
  639. ProtocolStep::PreprocessLPRFKeyRecvNext,
  640. t_after_receiving_lpks_next - t_after_computing_index_tags_prev,
  641. );
  642. r.record(
  643. ProtocolStep::PreprocessLPRFEvalSortNext,
  644. t_after_computing_index_tags_next - t_after_receiving_lpks_next,
  645. );
  646. r.record(
  647. ProtocolStep::PreprocessMpDpdfPrecomp,
  648. t_after_mpdpf_precomp - t_after_computing_index_tags_next,
  649. );
  650. r.record(
  651. ProtocolStep::PreprocessRecvTagsMine,
  652. t_after_receiving_index_tags_mine - t_after_mpdpf_precomp,
  653. );
  654. r.record(
  655. ProtocolStep::PreprocessStash,
  656. t_after_init_stash - t_after_receiving_index_tags_mine,
  657. );
  658. r.record(
  659. ProtocolStep::PreprocessDOPrf,
  660. t_after_init_doprf - t_after_init_stash,
  661. );
  662. r.record(
  663. ProtocolStep::PreprocessPOt,
  664. t_after_preprocess_pot - t_after_init_doprf,
  665. );
  666. r.record(
  667. ProtocolStep::PreprocessSelect,
  668. t_after_preprocess_select - t_after_preprocess_pot,
  669. );
  670. r.record(
  671. ProtocolStep::Preprocess,
  672. t_after_preprocess_select - t_start,
  673. );
  674. r
  675. });
  676. Ok(runtimes)
  677. }
  678. /// Run the refresh protocol at the end of an epoch.
  679. fn refresh<C: AbstractCommunicator>(
  680. &mut self,
  681. comm: &mut C,
  682. runtimes: Option<Runtimes>,
  683. ) -> Result<Option<Runtimes>, Error> {
  684. let t_start = Instant::now();
  685. // 0. Do preprocessing if not already done
  686. let runtimes = if self.number_preprocessed_epochs == 0 {
  687. self.preprocess_with_runtimes(comm, 1, runtimes)?
  688. } else {
  689. runtimes
  690. };
  691. let t_after_jit_preprocessing = Instant::now();
  692. // 1. Expect to receive garbled memory share
  693. let fut_garbled_memory_share = comm.receive_previous::<Vec<F>>()?;
  694. // 2. Get fresh (initialized) instances of the functionalities
  695. // a) Stash
  696. self.stash = self.preprocessed_stash.pop_front();
  697. debug_assert!(self.stash.is_some());
  698. // b) DOPRF
  699. self.legendre_prf_key_prev = self.preprocessed_legendre_prf_key_prev.pop_front();
  700. self.legendre_prf_key_next = self.preprocessed_legendre_prf_key_next.pop_front();
  701. self.joint_doprf = self.preprocessed_doprf.pop_front();
  702. debug_assert!(self.legendre_prf_key_prev.is_some());
  703. debug_assert!(self.legendre_prf_key_next.is_some());
  704. debug_assert!(self.joint_doprf.is_some());
  705. // c) p-OT
  706. self.joint_pot = self.preprocessed_pot.pop_front();
  707. debug_assert!(self.joint_pot.is_some());
  708. // d) select
  709. self.select_party = self.preprocessed_select.pop_front();
  710. debug_assert!(self.joint_pot.is_some());
  711. // e) Retrieve preprocessed index tags
  712. self.memory_index_tags_prev = self
  713. .preprocessed_memory_index_tags_prev
  714. .pop_front()
  715. .unwrap();
  716. self.memory_index_tags_prev_sorted = self
  717. .preprocessed_memory_index_tags_prev_sorted
  718. .pop_front()
  719. .unwrap();
  720. self.memory_index_tags_next = self
  721. .preprocessed_memory_index_tags_next
  722. .pop_front()
  723. .unwrap();
  724. self.memory_index_tags_next_sorted = self
  725. .preprocessed_memory_index_tags_next_sorted
  726. .pop_front()
  727. .unwrap();
  728. self.memory_index_tags_mine_sorted = self
  729. .preprocessed_memory_index_tags_mine_sorted
  730. .pop_front()
  731. .unwrap();
  732. debug_assert!(
  733. self.memory_index_tags_prev_sorted
  734. .windows(2)
  735. .all(|w| w[0] < w[1]),
  736. "index tags not sorted or colliding"
  737. );
  738. debug_assert!(
  739. self.memory_index_tags_next_sorted
  740. .windows(2)
  741. .all(|w| w[0] < w[1]),
  742. "index tags not sorted or colliding"
  743. );
  744. debug_assert!(
  745. self.memory_index_tags_mine_sorted
  746. .windows(2)
  747. .all(|w| w[0] < w[1]),
  748. "index tags not sorted or colliding"
  749. );
  750. let t_after_get_preprocessed_data = Instant::now();
  751. // 2.) Garble the memory share for the next party
  752. let mut garbled_memory_share_next: Vec<_> = self
  753. .memory_share
  754. .iter()
  755. .copied()
  756. .zip(self.memory_index_tags_next.iter().copied())
  757. .sorted_unstable_by_key(|(_, i)| *i)
  758. .map(|(x, _)| x)
  759. .collect();
  760. let t_after_sort = Instant::now();
  761. // the memory_index_tags_{prev,next} now define the pos_{prev,next} maps
  762. // - pos_(i-1)(tag) -> index of tag in mem_idx_tags_prev
  763. // - pos_(i+1)(tag) -> index of tag in mem_idx_tags_next
  764. let mask = self.preprocessed_pot_expands.pop_front().unwrap();
  765. self.memory_index_tags_next_sorted
  766. .par_iter()
  767. .zip(garbled_memory_share_next.par_iter_mut())
  768. .for_each(|(&tag, val)| {
  769. *val += mask[self.pos_next(tag)];
  770. });
  771. comm.send_next(garbled_memory_share_next)?;
  772. let t_after_pot_expand = Instant::now();
  773. self.garbled_memory_share = fut_garbled_memory_share.get()?;
  774. // the garbled_memory_share now defines the pos_mine map:
  775. // - pos_i(tag) -> index of tag in garbled_memory_share
  776. let t_after_receiving = Instant::now();
  777. // account that we used one set of preprocessing material
  778. self.number_preprocessed_epochs -= 1;
  779. let runtimes = runtimes.map(|mut r| {
  780. r.record(
  781. ProtocolStep::RefreshJitPreprocess,
  782. t_after_jit_preprocessing - t_start,
  783. );
  784. r.record(
  785. ProtocolStep::RefreshGetPreproc,
  786. t_after_get_preprocessed_data - t_after_jit_preprocessing,
  787. );
  788. r.record(
  789. ProtocolStep::RefreshSorting,
  790. t_after_sort - t_after_get_preprocessed_data,
  791. );
  792. r.record(
  793. ProtocolStep::RefreshPOtExpandMasking,
  794. t_after_pot_expand - t_after_sort,
  795. );
  796. r.record(
  797. ProtocolStep::RefreshReceivingShare,
  798. t_after_receiving - t_after_pot_expand,
  799. );
  800. r
  801. });
  802. Ok(runtimes)
  803. }
  804. /// Run the access protocol and collect runtime data.
  805. pub fn access_with_runtimes<C: AbstractCommunicator>(
  806. &mut self,
  807. comm: &mut C,
  808. instruction: InstructionShare<F>,
  809. runtimes: Option<Runtimes>,
  810. ) -> Result<(F, Option<Runtimes>), Error> {
  811. assert!(self.is_initialized);
  812. // 1. Read from the stash
  813. let t_start = Instant::now();
  814. let (stash_state, stash_runtimes) = self.stash.as_mut().unwrap().read_with_runtimes(
  815. comm,
  816. instruction,
  817. runtimes.map(|r| r.get_stash_runtimes()),
  818. )?;
  819. let t_after_stash_read = Instant::now();
  820. // 2. If the value was found in a stash, we read from the dummy address
  821. let dummy_address_share = match self.party_id {
  822. PARTY_1 => F::from_u128((self.db_size + self.get_access_counter()) as u128),
  823. _ => F::ZERO,
  824. };
  825. let db_address_share = self.select_party.as_mut().unwrap().select(
  826. comm,
  827. stash_state.flag,
  828. dummy_address_share,
  829. instruction.address,
  830. )?;
  831. let t_after_address_selection = Instant::now();
  832. // 3. Read a (dummy or real) value from the database
  833. let (db_value_share, runtimes) =
  834. self.read_from_database(comm, db_address_share, runtimes)?;
  835. let t_after_db_read = Instant::now();
  836. // 4. Write the read value into the stash
  837. let stash_runtime = self.stash.as_mut().unwrap().write_with_runtimes(
  838. comm,
  839. instruction,
  840. stash_state,
  841. db_address_share,
  842. db_value_share,
  843. stash_runtimes,
  844. )?;
  845. let t_after_stash_write = Instant::now();
  846. // 5. Select the right value to return
  847. let read_value = self.select_party.as_mut().unwrap().select(
  848. comm,
  849. stash_state.flag,
  850. stash_state.value,
  851. db_value_share,
  852. )?;
  853. let t_after_value_selection = Instant::now();
  854. // 6. If the stash is full, write the value back into the database
  855. let runtimes = if self.get_access_counter() == self.stash_size {
  856. let runtimes = self.update_database_from_stash(comm, runtimes)?;
  857. self.refresh(comm, runtimes)?
  858. } else {
  859. runtimes
  860. };
  861. let t_after_refresh = Instant::now();
  862. let runtimes = runtimes.map(|mut r| {
  863. r.set_stash_runtimes(stash_runtime.unwrap());
  864. r.record(ProtocolStep::AccessStashRead, t_after_stash_read - t_start);
  865. r.record(
  866. ProtocolStep::AccessAddressSelection,
  867. t_after_address_selection - t_after_stash_read,
  868. );
  869. r.record(
  870. ProtocolStep::AccessDatabaseRead,
  871. t_after_db_read - t_after_address_selection,
  872. );
  873. r.record(
  874. ProtocolStep::AccessStashWrite,
  875. t_after_stash_write - t_after_db_read,
  876. );
  877. r.record(
  878. ProtocolStep::AccessValueSelection,
  879. t_after_value_selection - t_after_stash_write,
  880. );
  881. r.record(
  882. ProtocolStep::AccessRefresh,
  883. t_after_refresh - t_after_value_selection,
  884. );
  885. r.record(ProtocolStep::Access, t_after_refresh - t_start);
  886. r
  887. });
  888. Ok((read_value, runtimes))
  889. }
  890. }
  891. impl<F, MPDPF, SPDPF> DistributedOram<F> for DistributedOramProtocol<F, MPDPF, SPDPF>
  892. where
  893. F: FromPrf + LegendreSymbol + Serializable,
  894. F::PrfKey: Serializable + Sync,
  895. MPDPF: MultiPointDpf<Value = F> + Sync,
  896. MPDPF::Key: Serializable,
  897. SPDPF: SinglePointDpf<Value = F> + Sync,
  898. SPDPF::Key: Serializable + Sync,
  899. {
  900. fn get_party_id(&self) -> usize {
  901. self.party_id
  902. }
  903. fn get_db_size(&self) -> usize {
  904. self.db_size
  905. }
  906. fn init<C: AbstractCommunicator>(&mut self, comm: &mut C, db_share: &[F]) -> Result<(), Error> {
  907. assert_eq!(db_share.len(), self.db_size);
  908. // 1. Initialize memory share with given db share and pad with dummy values
  909. self.memory_share = Vec::with_capacity(self.memory_size);
  910. self.memory_share.extend_from_slice(db_share);
  911. self.memory_share
  912. .extend(repeat(F::ZERO).take(self.stash_size));
  913. // 2. Run the refresh protocol to initialize everything.
  914. self.refresh(comm, None)?;
  915. self.is_initialized = true;
  916. Ok(())
  917. }
  918. fn preprocess<C: AbstractCommunicator>(
  919. &mut self,
  920. comm: &mut C,
  921. number_epochs: usize,
  922. ) -> Result<(), Error> {
  923. self.preprocess_with_runtimes(comm, number_epochs, None)
  924. .map(|_| ())
  925. }
  926. fn access<C: AbstractCommunicator>(
  927. &mut self,
  928. comm: &mut C,
  929. instruction: InstructionShare<F>,
  930. ) -> Result<F, Error> {
  931. self.access_with_runtimes(comm, instruction, None)
  932. .map(|x| x.0)
  933. }
  934. fn get_db<C: AbstractCommunicator>(
  935. &mut self,
  936. comm: &mut C,
  937. rerandomize_shares: bool,
  938. ) -> Result<Vec<F>, Error> {
  939. assert!(self.is_initialized);
  940. if self.get_access_counter() > 0 {
  941. self.refresh(comm, None)?;
  942. }
  943. if rerandomize_shares {
  944. let fut = comm.receive_previous()?;
  945. let mut rng = thread_rng();
  946. let mask: Vec<_> = (0..self.db_size).map(|_| F::random(&mut rng)).collect();
  947. let mut masked_share: Vec<_> = self.memory_share[0..self.db_size]
  948. .iter()
  949. .zip(mask.iter())
  950. .map(|(&x, &m)| x + m)
  951. .collect();
  952. comm.send_next(mask)?;
  953. let mask_prev: Vec<F> = fut.get()?;
  954. masked_share
  955. .iter_mut()
  956. .zip(mask_prev.iter())
  957. .for_each(|(x, &mp)| *x -= mp);
  958. Ok(masked_share)
  959. } else {
  960. Ok(self.memory_share[0..self.db_size].to_vec())
  961. }
  962. }
  963. }
  964. #[cfg(test)]
  965. mod tests {
  966. use super::*;
  967. use crate::common::Operation;
  968. use communicator::unix::make_unix_communicators;
  969. use dpf::mpdpf::DummyMpDpf;
  970. use dpf::spdpf::DummySpDpf;
  971. use ff::Field;
  972. use itertools::izip;
  973. use std::thread;
  974. use utils::field::Fp;
// Party identifiers used to construct the three protocol instances in these tests.
const PARTY_1: usize = 0;
const PARTY_2: usize = 1;
const PARTY_3: usize = 2;
  978. fn run_init<F, C, P>(
  979. mut doram_party: P,
  980. mut comm: C,
  981. db_share: &[F],
  982. ) -> thread::JoinHandle<(P, C)>
  983. where
  984. F: PrimeField,
  985. C: AbstractCommunicator + Send + 'static,
  986. P: DistributedOram<F> + Send + 'static,
  987. {
  988. let db_share = db_share.to_vec();
  989. thread::Builder::new()
  990. .name(format!("Party {}", doram_party.get_party_id()))
  991. .spawn(move || {
  992. doram_party.init(&mut comm, &db_share).unwrap();
  993. (doram_party, comm)
  994. })
  995. .unwrap()
  996. }
  997. fn run_access<F, C, P>(
  998. mut doram_party: P,
  999. mut comm: C,
  1000. instruction: InstructionShare<F>,
  1001. ) -> thread::JoinHandle<(P, C, F)>
  1002. where
  1003. F: PrimeField,
  1004. C: AbstractCommunicator + Send + 'static,
  1005. P: DistributedOram<F> + Send + 'static,
  1006. {
  1007. thread::Builder::new()
  1008. .name(format!("Party {}", doram_party.get_party_id()))
  1009. .spawn(move || {
  1010. let output = doram_party.access(&mut comm, instruction).unwrap();
  1011. (doram_party, comm, output)
  1012. })
  1013. .unwrap()
  1014. }
  1015. fn run_get_db<F, C, P>(mut doram_party: P, mut comm: C) -> thread::JoinHandle<(P, C, Vec<F>)>
  1016. where
  1017. F: PrimeField,
  1018. C: AbstractCommunicator + Send + 'static,
  1019. P: DistributedOram<F> + Send + 'static,
  1020. {
  1021. thread::Builder::new()
  1022. .name(format!("Party {}", doram_party.get_party_id()))
  1023. .spawn(move || {
  1024. let output = doram_party.get_db(&mut comm, false).unwrap();
  1025. (doram_party, comm, output)
  1026. })
  1027. .unwrap()
  1028. }
  1029. fn mk_read(address: u128, value: u128) -> InstructionShare<Fp> {
  1030. InstructionShare {
  1031. operation: Operation::Read.encode(),
  1032. address: Fp::from_u128(address),
  1033. value: Fp::from_u128(value),
  1034. }
  1035. }
  1036. fn mk_write(address: u128, value: u128) -> InstructionShare<Fp> {
  1037. InstructionShare {
  1038. operation: Operation::Write.encode(),
  1039. address: Fp::from_u128(address),
  1040. value: Fp::from_u128(value),
  1041. }
  1042. }
// Instruction share whose components are all zero. The tests below hand it to
// parties 2 and 3 while party 1 receives the actual instruction.
const INST_ZERO_SHARE: InstructionShare<Fp> = InstructionShare {
    operation: Fp::ZERO,
    address: Fp::ZERO,
    value: Fp::ZERO,
};
// Dummy single-point / multi-point DPF types used to instantiate the protocol in tests.
type SPDPF = DummySpDpf<Fp>;
type MPDPF = DummyMpDpf<Fp>;
  1050. fn setup(
  1051. db_size: usize,
  1052. ) -> (
  1053. (
  1054. impl DistributedOram<Fp>,
  1055. impl DistributedOram<Fp>,
  1056. impl DistributedOram<Fp>,
  1057. ),
  1058. (
  1059. impl AbstractCommunicator,
  1060. impl AbstractCommunicator,
  1061. impl AbstractCommunicator,
  1062. ),
  1063. ) {
  1064. let party_1 = DistributedOramProtocol::<Fp, MPDPF, SPDPF>::new(PARTY_1, db_size);
  1065. let party_2 = DistributedOramProtocol::<Fp, MPDPF, SPDPF>::new(PARTY_2, db_size);
  1066. let party_3 = DistributedOramProtocol::<Fp, MPDPF, SPDPF>::new(PARTY_3, db_size);
  1067. assert_eq!(party_1.get_party_id(), PARTY_1);
  1068. assert_eq!(party_2.get_party_id(), PARTY_2);
  1069. assert_eq!(party_3.get_party_id(), PARTY_3);
  1070. assert_eq!(party_1.get_db_size(), db_size);
  1071. assert_eq!(party_2.get_db_size(), db_size);
  1072. assert_eq!(party_3.get_db_size(), db_size);
  1073. let (comm_3, comm_2, comm_1) = {
  1074. let mut comms = make_unix_communicators(3);
  1075. (
  1076. comms.pop().unwrap(),
  1077. comms.pop().unwrap(),
  1078. comms.pop().unwrap(),
  1079. )
  1080. };
  1081. // Initialize DB with zeros
  1082. let db_share_1: Vec<_> = repeat(Fp::ZERO).take(db_size).collect();
  1083. let db_share_2: Vec<_> = repeat(Fp::ZERO).take(db_size).collect();
  1084. let db_share_3: Vec<_> = repeat(Fp::ZERO).take(db_size).collect();
  1085. let h1 = run_init(party_1, comm_1, &db_share_1);
  1086. let h2 = run_init(party_2, comm_2, &db_share_2);
  1087. let h3 = run_init(party_3, comm_3, &db_share_3);
  1088. let (party_1, comm_1) = h1.join().unwrap();
  1089. let (party_2, comm_2) = h2.join().unwrap();
  1090. let (party_3, comm_3) = h3.join().unwrap();
  1091. ((party_1, party_2, party_3), (comm_1, comm_2, comm_3))
  1092. }
  1093. #[test]
  1094. fn test_oram_even_exp() {
  1095. let db_size = 1 << 4;
  1096. let stash_size = (db_size as f64).sqrt().round() as usize;
  1097. let ((mut party_1, mut party_2, mut party_3), (mut comm_1, mut comm_2, mut comm_3)) =
  1098. setup(db_size);
  1099. let number_cycles = 8;
  1100. let instructions = [
  1101. mk_write(12, 18),
  1102. mk_read(12, 899),
  1103. mk_write(13, 457),
  1104. mk_write(0, 77),
  1105. mk_write(13, 515),
  1106. mk_write(15, 421),
  1107. mk_write(13, 895),
  1108. mk_write(4, 941),
  1109. mk_write(1, 358),
  1110. mk_read(9, 894),
  1111. mk_read(7, 678),
  1112. mk_write(3, 110),
  1113. mk_read(15, 691),
  1114. mk_read(13, 335),
  1115. mk_write(9, 286),
  1116. mk_read(13, 217),
  1117. mk_write(10, 167),
  1118. mk_read(3, 909),
  1119. mk_write(2, 949),
  1120. mk_read(14, 245),
  1121. mk_write(3, 334),
  1122. mk_write(0, 378),
  1123. mk_write(2, 129),
  1124. mk_write(5, 191),
  1125. mk_write(15, 662),
  1126. mk_write(4, 724),
  1127. mk_write(1, 190),
  1128. mk_write(6, 887),
  1129. mk_write(9, 271),
  1130. mk_read(12, 666),
  1131. mk_write(0, 57),
  1132. mk_write(2, 185),
  1133. ];
  1134. let expected_values = [
  1135. Fp::from_u128(0),
  1136. Fp::from_u128(18),
  1137. Fp::from_u128(0),
  1138. Fp::from_u128(0),
  1139. Fp::from_u128(457),
  1140. Fp::from_u128(0),
  1141. Fp::from_u128(515),
  1142. Fp::from_u128(0),
  1143. Fp::from_u128(0),
  1144. Fp::from_u128(0),
  1145. Fp::from_u128(0),
  1146. Fp::from_u128(0),
  1147. Fp::from_u128(421),
  1148. Fp::from_u128(895),
  1149. Fp::from_u128(0),
  1150. Fp::from_u128(895),
  1151. Fp::from_u128(0),
  1152. Fp::from_u128(110),
  1153. Fp::from_u128(0),
  1154. Fp::from_u128(0),
  1155. Fp::from_u128(110),
  1156. Fp::from_u128(77),
  1157. Fp::from_u128(949),
  1158. Fp::from_u128(0),
  1159. Fp::from_u128(421),
  1160. Fp::from_u128(941),
  1161. Fp::from_u128(358),
  1162. Fp::from_u128(0),
  1163. Fp::from_u128(286),
  1164. Fp::from_u128(18),
  1165. Fp::from_u128(378),
  1166. Fp::from_u128(129),
  1167. ];
  1168. let expected_db_contents = [
  1169. [77, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 457, 0, 0],
  1170. [77, 0, 0, 0, 941, 0, 0, 0, 0, 0, 0, 0, 18, 895, 0, 421],
  1171. [77, 358, 0, 110, 941, 0, 0, 0, 0, 0, 0, 0, 18, 895, 0, 421],
  1172. [77, 358, 0, 110, 941, 0, 0, 0, 0, 286, 0, 0, 18, 895, 0, 421],
  1173. [
  1174. 77, 358, 949, 110, 941, 0, 0, 0, 0, 286, 167, 0, 18, 895, 0, 421,
  1175. ],
  1176. [
  1177. 378, 358, 129, 334, 941, 191, 0, 0, 0, 286, 167, 0, 18, 895, 0, 421,
  1178. ],
  1179. [
  1180. 378, 190, 129, 334, 724, 191, 887, 0, 0, 286, 167, 0, 18, 895, 0, 662,
  1181. ],
  1182. [
  1183. 57, 190, 185, 334, 724, 191, 887, 0, 0, 271, 167, 0, 18, 895, 0, 662,
  1184. ],
  1185. ];
  1186. for i in 0..number_cycles {
  1187. for j in 0..stash_size {
  1188. let inst = instructions[i * stash_size + j];
  1189. let expected_value = expected_values[i * stash_size + j];
  1190. let h1 = run_access(party_1, comm_1, inst);
  1191. let h2 = run_access(party_2, comm_2, INST_ZERO_SHARE);
  1192. let h3 = run_access(party_3, comm_3, INST_ZERO_SHARE);
  1193. let (p1, c1, value_1) = h1.join().unwrap();
  1194. let (p2, c2, value_2) = h2.join().unwrap();
  1195. let (p3, c3, value_3) = h3.join().unwrap();
  1196. (party_1, party_2, party_3) = (p1, p2, p3);
  1197. (comm_1, comm_2, comm_3) = (c1, c2, c3);
  1198. assert_eq!(value_1 + value_2 + value_3, expected_value);
  1199. }
  1200. let h1 = run_get_db(party_1, comm_1);
  1201. let h2 = run_get_db(party_2, comm_2);
  1202. let h3 = run_get_db(party_3, comm_3);
  1203. let (p1, c1, db_share_1) = h1.join().unwrap();
  1204. let (p2, c2, db_share_2) = h2.join().unwrap();
  1205. let (p3, c3, db_share_3) = h3.join().unwrap();
  1206. (party_1, party_2, party_3) = (p1, p2, p3);
  1207. (comm_1, comm_2, comm_3) = (c1, c2, c3);
  1208. let db: Vec<_> = izip!(db_share_1.iter(), db_share_2.iter(), db_share_3.iter())
  1209. .map(|(&x, &y, &z)| x + y + z)
  1210. .collect();
  1211. for k in 0..db_size {
  1212. assert_eq!(db[k], Fp::from_u128(expected_db_contents[i][k]));
  1213. }
  1214. }
  1215. }
  1216. #[test]
  1217. fn test_oram_odd_exp() {
  1218. let db_size = 1 << 5;
  1219. let stash_size = (db_size as f64).sqrt().round() as usize;
  1220. let ((mut party_1, mut party_2, mut party_3), (mut comm_1, mut comm_2, mut comm_3)) =
  1221. setup(db_size);
  1222. let number_cycles = 8;
  1223. let instructions = [
  1224. mk_write(26, 64),
  1225. mk_read(4, 141),
  1226. mk_write(25, 701),
  1227. mk_write(29, 927),
  1228. mk_read(28, 132),
  1229. mk_write(30, 990),
  1230. mk_write(23, 167),
  1231. mk_write(31, 347),
  1232. mk_write(26, 1020),
  1233. mk_write(20, 893),
  1234. mk_read(26, 805),
  1235. mk_write(3, 949),
  1236. mk_read(10, 195),
  1237. mk_write(29, 767),
  1238. mk_read(28, 107),
  1239. mk_write(30, 426),
  1240. mk_write(22, 605),
  1241. mk_write(0, 171),
  1242. mk_write(4, 210),
  1243. mk_read(12, 737),
  1244. mk_write(19, 977),
  1245. mk_read(16, 143),
  1246. mk_write(29, 775),
  1247. mk_read(28, 34),
  1248. mk_write(27, 95),
  1249. mk_write(30, 130),
  1250. mk_read(8, 89),
  1251. mk_read(23, 132),
  1252. mk_read(21, 12),
  1253. mk_read(4, 675),
  1254. mk_write(28, 225),
  1255. mk_write(5, 978),
  1256. mk_write(2, 833),
  1257. mk_write(1, 456),
  1258. mk_write(17, 921),
  1259. mk_read(26, 293),
  1260. mk_write(5, 474),
  1261. mk_write(7, 981),
  1262. mk_read(19, 189),
  1263. mk_write(1, 248),
  1264. mk_read(27, 573),
  1265. mk_read(17, 142),
  1266. mk_read(29, 945),
  1267. mk_read(16, 902),
  1268. mk_write(16, 799),
  1269. mk_read(28, 864),
  1270. mk_write(6, 986),
  1271. mk_read(2, 201),
  1272. ];
  1273. let expected_values = [
  1274. Fp::from_u128(0),
  1275. Fp::from_u128(0),
  1276. Fp::from_u128(0),
  1277. Fp::from_u128(0),
  1278. Fp::from_u128(0),
  1279. Fp::from_u128(0),
  1280. Fp::from_u128(0),
  1281. Fp::from_u128(0),
  1282. Fp::from_u128(64),
  1283. Fp::from_u128(0),
  1284. Fp::from_u128(1020),
  1285. Fp::from_u128(0),
  1286. Fp::from_u128(0),
  1287. Fp::from_u128(927),
  1288. Fp::from_u128(0),
  1289. Fp::from_u128(990),
  1290. Fp::from_u128(0),
  1291. Fp::from_u128(0),
  1292. Fp::from_u128(0),
  1293. Fp::from_u128(0),
  1294. Fp::from_u128(0),
  1295. Fp::from_u128(0),
  1296. Fp::from_u128(767),
  1297. Fp::from_u128(0),
  1298. Fp::from_u128(0),
  1299. Fp::from_u128(426),
  1300. Fp::from_u128(0),
  1301. Fp::from_u128(167),
  1302. Fp::from_u128(0),
  1303. Fp::from_u128(210),
  1304. Fp::from_u128(0),
  1305. Fp::from_u128(0),
  1306. Fp::from_u128(0),
  1307. Fp::from_u128(0),
  1308. Fp::from_u128(0),
  1309. Fp::from_u128(1020),
  1310. Fp::from_u128(978),
  1311. Fp::from_u128(0),
  1312. Fp::from_u128(977),
  1313. Fp::from_u128(456),
  1314. Fp::from_u128(95),
  1315. Fp::from_u128(921),
  1316. Fp::from_u128(775),
  1317. Fp::from_u128(0),
  1318. Fp::from_u128(0),
  1319. Fp::from_u128(225),
  1320. Fp::from_u128(0),
  1321. Fp::from_u128(833),
  1322. ];
  1323. let expected_db_contents = [
  1324. [
  1325. 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 701, 64,
  1326. 0, 0, 927, 990, 0,
  1327. ],
  1328. [
  1329. 0, 0, 0, 949, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 893, 0, 0, 167, 0,
  1330. 701, 1020, 0, 0, 927, 990, 347,
  1331. ],
  1332. [
  1333. 171, 0, 0, 949, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 893, 0, 605, 167,
  1334. 0, 701, 1020, 0, 0, 767, 426, 347,
  1335. ],
  1336. [
  1337. 171, 0, 0, 949, 210, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 977, 893, 0, 605,
  1338. 167, 0, 701, 1020, 0, 0, 775, 426, 347,
  1339. ],
  1340. [
  1341. 171, 0, 0, 949, 210, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 977, 893, 0, 605,
  1342. 167, 0, 701, 1020, 95, 0, 775, 130, 347,
  1343. ],
  1344. [
  1345. 171, 456, 833, 949, 210, 978, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 921, 0, 977, 893, 0,
  1346. 605, 167, 0, 701, 1020, 95, 225, 775, 130, 347,
  1347. ],
  1348. [
  1349. 171, 248, 833, 949, 210, 474, 0, 981, 0, 0, 0, 0, 0, 0, 0, 0, 0, 921, 0, 977, 893,
  1350. 0, 605, 167, 0, 701, 1020, 95, 225, 775, 130, 347,
  1351. ],
  1352. [
  1353. 171, 248, 833, 949, 210, 474, 986, 981, 0, 0, 0, 0, 0, 0, 0, 0, 799, 921, 0, 977,
  1354. 893, 0, 605, 167, 0, 701, 1020, 95, 225, 775, 130, 347,
  1355. ],
  1356. ];
  1357. for i in 0..number_cycles {
  1358. for j in 0..stash_size {
  1359. let inst = instructions[i * stash_size + j];
  1360. let expected_value = expected_values[i * stash_size + j];
  1361. let h1 = run_access(party_1, comm_1, inst);
  1362. let h2 = run_access(party_2, comm_2, INST_ZERO_SHARE);
  1363. let h3 = run_access(party_3, comm_3, INST_ZERO_SHARE);
  1364. let (p1, c1, value_1) = h1.join().unwrap();
  1365. let (p2, c2, value_2) = h2.join().unwrap();
  1366. let (p3, c3, value_3) = h3.join().unwrap();
  1367. (party_1, party_2, party_3) = (p1, p2, p3);
  1368. (comm_1, comm_2, comm_3) = (c1, c2, c3);
  1369. assert_eq!(value_1 + value_2 + value_3, expected_value);
  1370. }
  1371. let h1 = run_get_db(party_1, comm_1);
  1372. let h2 = run_get_db(party_2, comm_2);
  1373. let h3 = run_get_db(party_3, comm_3);
  1374. let (p1, c1, db_share_1) = h1.join().unwrap();
  1375. let (p2, c2, db_share_2) = h2.join().unwrap();
  1376. let (p3, c3, db_share_3) = h3.join().unwrap();
  1377. (party_1, party_2, party_3) = (p1, p2, p3);
  1378. (comm_1, comm_2, comm_3) = (c1, c2, c3);
  1379. let db: Vec<_> = izip!(db_share_1.iter(), db_share_2.iter(), db_share_3.iter())
  1380. .map(|(&x, &y, &z)| x + y + z)
  1381. .collect();
  1382. for k in 0..db_size {
  1383. assert_eq!(db[k], Fp::from_u128(expected_db_contents[i][k]));
  1384. }
  1385. }
  1386. }
  1387. }