|
@@ -426,7 +426,17 @@ int main(int argc, char *argv[])
|
|
|
const size_t n_threads = atoi(argv[2]);
|
|
|
const size_t number_of_sockets = 5 * n_threads;
|
|
|
const size_t expo = atoi(argv[3]);
|
|
|
+
|
|
|
+ const size_t maxRAM = atoi(argv[4]);
|
|
|
+
|
|
|
const size_t db_nitems = 1ULL << expo;
|
|
|
+ size_t RAM_needed = 0;
|
|
|
+ RAM_needed = n_threads * 9 * ((sizeof(__m128i) * db_nitems));
|
|
|
+ //std::cout << "RAM needed = " << RAM_needed << " bytes = " << RAM_needed/1073741824 << " GiB" << std::endl;
|
|
|
+ size_t n_batches = std::ceil(double(RAM_needed)/(1073741824 * maxRAM));
|
|
|
+ //std::cout << "n_batches = " << n_batches << std::endl;
|
|
|
+ size_t thread_per_batch = std::ceil(double(n_threads)/n_batches);
|
|
|
+ //std::cout << "thread_per_batch = " << thread_per_batch << std::endl;
|
|
|
|
|
|
std::vector<socket_t> socketsPb;
|
|
|
for (size_t j = 0; j < number_of_sockets + 1; ++j)
|
|
@@ -480,14 +490,14 @@ bool party;
|
|
|
#endif
|
|
|
|
|
|
|
|
|
- __m128i *final_correction_word = (__m128i *)std::aligned_alloc(sizeof(__m256i), n_threads * sizeof(__m128i));
|
|
|
+ __m128i *final_correction_word = (__m128i *)std::aligned_alloc(sizeof(__m256i), thread_per_batch * sizeof(__m128i));
|
|
|
|
|
|
AES_KEY aeskey;
|
|
|
|
|
|
- __m128i **output = (__m128i **)malloc(sizeof(__m128i *) * n_threads);
|
|
|
- int8_t **flags = (int8_t **)malloc(sizeof(uint8_t *) * n_threads);
|
|
|
+ __m128i **output = (__m128i **)malloc(sizeof(__m128i *) * thread_per_batch);
|
|
|
+ int8_t **flags = (int8_t **)malloc(sizeof(uint8_t *) * thread_per_batch);
|
|
|
|
|
|
- for (size_t j = 0; j < n_threads; ++j)
|
|
|
+ for (size_t j = 0; j < thread_per_batch; ++j)
|
|
|
{
|
|
|
output[j] = (__m128i *)std::aligned_alloc(sizeof(node_t), db_nitems * sizeof(__m128i));
|
|
|
flags[j] = (int8_t *)std::aligned_alloc(sizeof(node_t), db_nitems * sizeof(uint8_t));
|
|
@@ -499,31 +509,35 @@ bool party;
|
|
|
const size_t depth = std::ceil(std::log2(db_nitems));
|
|
|
const size_t nbits = std::ceil(std::log2(db_nitems));
|
|
|
const size_t nodes_in_interval = db_nitems - 1;
|
|
|
+ auto start = std::chrono::steady_clock::now();
|
|
|
|
|
|
- boost::asio::thread_pool pool(n_threads);
|
|
|
|
|
|
- //#ifdef VERBOSE
|
|
|
+#ifdef VERBOSE
|
|
|
printf("n_threads = %zu\n\n", n_threads);
|
|
|
- //#endif
|
|
|
+#endif
|
|
|
|
|
|
- auto start = std::chrono::steady_clock::now();
|
|
|
|
|
|
- uint8_t **target_share_read = new uint8_t *[n_threads];
|
|
|
|
|
|
- generate_random_targets(target_share_read, n_threads, party, expo);
|
|
|
|
|
|
- for (size_t j = 0; j < n_threads; ++j)
|
|
|
- {
|
|
|
- boost::asio::post(pool, std::bind(evalfull_mpc, std::ref(nodes_per_leaf), std::ref(depth), std::ref(nbits), std::ref(nodes_in_interval),
|
|
|
- std::ref(aeskey), target_share_read[j], std::ref(socketsPb), 0, db_nitems - 1, output[j],
|
|
|
- flags[j], std::ref(final_correction_word[j]), party, 5 * j));
|
|
|
- }
|
|
|
|
|
|
- pool.join();
|
|
|
+ for(size_t iters = 0; iters < n_batches; ++iters)
|
|
|
+{
|
|
|
+ uint8_t **target_share_read = new uint8_t *[thread_per_batch];
|
|
|
+ generate_random_targets(target_share_read, thread_per_batch, party, expo);
|
|
|
+ boost::asio::thread_pool pool(thread_per_batch);
|
|
|
+ for (size_t j = 0; j < thread_per_batch; ++j)
|
|
|
+ {
|
|
|
+ boost::asio::post(pool, std::bind(evalfull_mpc, std::ref(nodes_per_leaf), std::ref(depth), std::ref(nbits), std::ref(nodes_in_interval),
|
|
|
+ std::ref(aeskey), target_share_read[j], std::ref(socketsPb), 0, db_nitems - 1, output[j],
|
|
|
+ flags[j], std::ref(final_correction_word[j]), party, 5 * j));
|
|
|
+ }
|
|
|
+
|
|
|
+ pool.join();
|
|
|
|
|
|
|
|
|
- convert_shares(output, flags, n_threads, db_nitems, final_correction_word, socketsPb[0], party);
|
|
|
- auto end = std::chrono::steady_clock::now();
|
|
|
+ convert_shares(output, flags, thread_per_batch, db_nitems, final_correction_word, socketsPb[0], party);
|
|
|
+}
|
|
|
+ auto end = std::chrono::steady_clock::now();
|
|
|
std::chrono::duration<double> elapsed_seconds = end - start;
|
|
|
std::cout << "WallClockTime: " << elapsed_seconds.count() << "s" << std::endl<< std::endl;
|
|
|
std::cout << "CommunicationCost: " << communication_cost/1024 << " KiB" << std::endl;
|