// server.rs
  1. use criterion::measurement::WallTime;
  2. use criterion::BenchmarkGroup;
  3. use criterion::{black_box, criterion_group, criterion_main, Criterion};
  4. use pprof::criterion::{Output, PProfProfiler};
  5. use rand::Rng;
  6. use spiral_rs::aligned_memory::AlignedMemory64;
  7. use spiral_rs::client::*;
  8. use spiral_rs::params::*;
  9. use spiral_rs::poly::*;
  10. use spiral_rs::server::*;
  11. use spiral_rs::util::*;
  12. use std::time::Duration;
  13. pub fn generate_random_incorrect_db(params: &Params) -> AlignedMemory64 {
  14. let instances = params.instances;
  15. let trials = params.n * params.n;
  16. let dim0 = 1 << params.db_dim_1;
  17. let num_per = 1 << params.db_dim_2;
  18. let num_items = dim0 * num_per;
  19. let db_size_words = instances * trials * num_items * params.poly_len;
  20. let mut out = AlignedMemory64::new(db_size_words);
  21. let out_mut_slice = out.as_mut_slice();
  22. let mut rng = get_seeded_rng();
  23. for i in 0..out_mut_slice.len() {
  24. out_mut_slice[i] = rng.gen();
  25. }
  26. out
  27. }
  28. fn test_full_processing(group: &mut BenchmarkGroup<WallTime>) {
  29. // let names = ["server_processing_20_256", "server_processing_16_100000"];
  30. // let cfgs = [CFG_20_256, CFG_16_100000];
  31. let names = ["server_processing_16_100000"];
  32. let cfgs = [CFG_16_100000];
  33. for i in 0..names.len() {
  34. let name = names[i];
  35. let cfg = cfgs[i];
  36. let params = params_from_json(&cfg.replace("'", "\""));
  37. let mut seeded_rng = get_seeded_rng();
  38. let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
  39. let mut client = Client::init(&params, &mut seeded_rng);
  40. let public_params = client.generate_keys();
  41. let query = client.generate_query(target_idx);
  42. println!("Generating database...");
  43. let db = generate_random_incorrect_db(&params);
  44. println!("Done generating database.");
  45. group.bench_function(name, |b| {
  46. b.iter(|| {
  47. black_box(process_query(
  48. black_box(&params),
  49. black_box(&public_params),
  50. black_box(&query),
  51. black_box(db.as_slice()),
  52. ));
  53. });
  54. });
  55. }
  56. }
  57. fn criterion_benchmark(c: &mut Criterion) {
  58. let mut group = c.benchmark_group("server");
  59. group
  60. .sample_size(10)
  61. .measurement_time(Duration::from_secs(30));
  62. let params = get_expansion_testing_params();
  63. let v_neg1 = params.get_v_neg1();
  64. let mut seeded_rng = get_seeded_rng();
  65. let mut client = Client::init(&params, &mut seeded_rng);
  66. let public_params = client.generate_keys();
  67. let mut v = Vec::new();
  68. for _ in 0..params.poly_len {
  69. v.push(PolyMatrixNTT::zero(&params, 2, 1));
  70. }
  71. let scale_k = params.modulus / params.pt_modulus;
  72. let mut sigma = PolyMatrixRaw::zero(&params, 1, 1);
  73. sigma.data[7] = scale_k;
  74. v[0] = client.encrypt_matrix_reg(&sigma.ntt());
  75. let v_w_left = public_params.v_expansion_left.unwrap();
  76. let v_w_right = public_params.v_expansion_right.unwrap();
  77. // note: the benchmark on AVX2 is 545ms for the c++ impl
  78. group.bench_function("coefficient_expansion", |b| {
  79. b.iter(|| {
  80. coefficient_expansion(
  81. black_box(&mut v),
  82. black_box(params.g()),
  83. black_box(params.stop_round()),
  84. black_box(&params),
  85. black_box(&v_w_left),
  86. black_box(&v_w_right),
  87. black_box(&v_neg1),
  88. black_box(params.t_gsw * params.db_dim_2),
  89. )
  90. });
  91. });
  92. let mut seeded_rng = get_seeded_rng();
  93. let trials = params.n * params.n;
  94. let dim0 = 1 << params.db_dim_1;
  95. let num_per = 1 << params.db_dim_2;
  96. let num_items = dim0 * num_per;
  97. let db_size_words = trials * num_items * params.poly_len;
  98. let mut db = vec![0u64; db_size_words];
  99. for i in 0..db_size_words {
  100. db[i] = seeded_rng.gen();
  101. }
  102. let v_reg_sz = dim0 * 2 * params.poly_len;
  103. let mut v_reg_reoriented = AlignedMemory64::new(v_reg_sz);
  104. for i in 0..v_reg_sz {
  105. v_reg_reoriented[i] = seeded_rng.gen();
  106. }
  107. let mut out = Vec::with_capacity(num_per);
  108. for _ in 0..dim0 {
  109. out.push(PolyMatrixNTT::zero(&params, 2, 1));
  110. }
  111. // note: the benchmark on AVX2 is 45ms for the c++ impl
  112. group.bench_function("first_dimension_processing", |b| {
  113. b.iter(|| {
  114. multiply_reg_by_database(
  115. black_box(&mut out),
  116. black_box(db.as_slice()),
  117. black_box(v_reg_reoriented.as_slice()),
  118. black_box(&params),
  119. black_box(dim0),
  120. black_box(num_per),
  121. )
  122. });
  123. });
  124. // full server processing benchmark
  125. test_full_processing(&mut group);
  126. group.finish();
  127. }
  128. // criterion_group!(benches, criterion_benchmark);
  129. criterion_group! {
  130. name = benches;
  131. config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)));
  132. targets = criterion_benchmark
  133. }
  134. criterion_main!(benches);