avadapal pred 1 rokom
rodič
commit
0741414928
1 zmenený súbor: 33 pridaní a 19 odobraní
  1. 33 19
      2p-preprocessing/preprocessing.cpp

+ 33 - 19
2p-preprocessing/preprocessing.cpp

@@ -426,7 +426,17 @@ int main(int argc, char *argv[])
 	const size_t n_threads = atoi(argv[2]);
 	const size_t number_of_sockets = 5 * n_threads;
 	const size_t expo = atoi(argv[3]);
+
+	const size_t maxRAM = atoi(argv[4]);
+
 	const size_t db_nitems = 1ULL << expo;
+ size_t RAM_needed = 0;
+ RAM_needed = n_threads *  9 * ((sizeof(__m128i) * db_nitems));
+ //std::cout << "RAM needed = " << RAM_needed << " bytes = " << RAM_needed/1073741824 << " GiB" << std::endl;
+ size_t n_batches = std::ceil(double(RAM_needed)/(1073741824 * maxRAM));
+ //std::cout << "n_batches = " << n_batches << std::endl;
+ size_t thread_per_batch = std::ceil(double(n_threads)/n_batches);
+ //std::cout << "thread_per_batch = " << thread_per_batch << std::endl;
 
 	std::vector<socket_t> socketsPb;
 	for (size_t j = 0; j < number_of_sockets + 1; ++j)
@@ -480,14 +490,14 @@ bool party;
 #endif
 
  
-	__m128i *final_correction_word = (__m128i *)std::aligned_alloc(sizeof(__m256i), n_threads * sizeof(__m128i));
+	__m128i *final_correction_word = (__m128i *)std::aligned_alloc(sizeof(__m256i), thread_per_batch * sizeof(__m128i));
 
 	AES_KEY aeskey;
 
-	__m128i **output = (__m128i **)malloc(sizeof(__m128i *) * n_threads);
-	int8_t **flags = (int8_t **)malloc(sizeof(uint8_t *) * n_threads);
+	__m128i **output = (__m128i **)malloc(sizeof(__m128i *) * thread_per_batch);
+	int8_t **flags = (int8_t **)malloc(sizeof(uint8_t *) * thread_per_batch);
 
-	for (size_t j = 0; j < n_threads; ++j)
+	for (size_t j = 0; j < thread_per_batch; ++j)
 	{
 		output[j] = (__m128i *)std::aligned_alloc(sizeof(node_t), db_nitems * sizeof(__m128i));
 		flags[j] = (int8_t *)std::aligned_alloc(sizeof(node_t), db_nitems * sizeof(uint8_t));
@@ -499,31 +509,35 @@ bool party;
 	const size_t depth = std::ceil(std::log2(db_nitems));
 	const size_t nbits = std::ceil(std::log2(db_nitems));
 	const size_t nodes_in_interval = db_nitems - 1;
+	auto start = std::chrono::steady_clock::now();
 
-	boost::asio::thread_pool pool(n_threads);
 
-	//#ifdef VERBOSE
+#ifdef VERBOSE
 		printf("n_threads = %zu\n\n", n_threads);
-	//#endif
+#endif
  
-	auto start = std::chrono::steady_clock::now();
 
-	uint8_t **target_share_read = new uint8_t *[n_threads];
 
-	generate_random_targets(target_share_read, n_threads, party, expo);
 
-	for (size_t j = 0; j < n_threads; ++j)
-	{
-		boost::asio::post(pool, std::bind(evalfull_mpc, std::ref(nodes_per_leaf), std::ref(depth), std::ref(nbits), std::ref(nodes_in_interval),
-										  std::ref(aeskey), target_share_read[j], std::ref(socketsPb), 0, db_nitems - 1, output[j],
-										  flags[j], std::ref(final_correction_word[j]), party, 5 * j));
-	}
 
-	pool.join();
+ for(size_t iters = 0; iters < n_batches; ++iters)
+{
+   uint8_t **target_share_read = new uint8_t *[thread_per_batch];
+   generate_random_targets(target_share_read, thread_per_batch, party, expo);
+   boost::asio::thread_pool pool(thread_per_batch);
+   for (size_t j = 0; j < thread_per_batch; ++j)
+   {
+    boost::asio::post(pool, std::bind(evalfull_mpc, std::ref(nodes_per_leaf), std::ref(depth), std::ref(nbits), std::ref(nodes_in_interval),
+              std::ref(aeskey), target_share_read[j], std::ref(socketsPb), 0, db_nitems - 1, output[j],
+              flags[j], std::ref(final_correction_word[j]), party, 5 * j));
+   }
+
+   pool.join();
 
 
-	convert_shares(output, flags, n_threads, db_nitems, final_correction_word, socketsPb[0], party);
-	auto end = std::chrono::steady_clock::now();
+   convert_shares(output, flags, thread_per_batch, db_nitems, final_correction_word, socketsPb[0], party);
+}
+ auto end = std::chrono::steady_clock::now();
 	std::chrono::duration<double> elapsed_seconds = end - start;
 	std::cout << "WallClockTime: " << elapsed_seconds.count() << "s" << std::endl<< std::endl;
  std::cout << "CommunicationCost: " << communication_cost/1024 << " KiB" << std::endl;