Browse Source

Correctly compute the batch sizes and numbers for preprocessing

Ian Goldberg 1 year ago
parent
commit
e8db098e6d
3 changed files with 48 additions and 39 deletions
  1. 15 7
      2p-preprocessing/preprocessing.cpp
  2. 17 16
      preprocessing/p2preprocessing.cpp
  3. 16 16
      preprocessing/preprocessing.cpp

+ 15 - 7
2p-preprocessing/preprocessing.cpp

@@ -430,13 +430,21 @@ int main(int argc, char *argv[])
 	const size_t maxRAM = atoi(argv[4]);
 
 	const size_t db_nitems = 1ULL << expo;
- size_t RAM_needed = 0;
- RAM_needed = n_threads * 35 * db_nitems;
- std::cout << "RAM needed = " << RAM_needed << " bytes = " << RAM_needed/1073741824 << " GiB" << std::endl;
- size_t n_batches = std::ceil(double(RAM_needed)/(1073741824 * maxRAM));
- std::cout << "n_batches = " << n_batches << std::endl;
- size_t thread_per_batch = std::ceil(double(n_threads)/n_batches);
- std::cout << "thread_per_batch = " << thread_per_batch << std::endl;
+
+      size_t RAM_needed_per_thread = 164 * db_nitems;
+      std::cout << "RAM needed = " << n_threads*RAM_needed_per_thread << " bytes = " << n_threads*RAM_needed_per_thread/1073741824 << " GiB" << std::endl;
+      std::cout << "RAM needed per thread = " << RAM_needed_per_thread << " bytes = " << (RAM_needed_per_thread>>30) << " GiB" << std::endl;
+      size_t thread_per_batch = std::floor(double(maxRAM<<30)/RAM_needed_per_thread);
+      if (thread_per_batch > n_threads) {
+	thread_per_batch = n_threads;
+      }
+      std::cout << "thread_per_batch = " << thread_per_batch << std::endl;
+      if (thread_per_batch < 1) {
+       std::cout << "You need more RAM" << std::endl;
+       exit(0);
+      }
+      size_t n_batches = std::ceil(double(n_threads)/thread_per_batch);
+      std::cout << "n_batches = " << n_batches << std::endl;
 
 	std::vector<socket_t> socketsPb;
 	for (size_t j = 0; j < number_of_sockets + 1; ++j)

+ 17 - 16
preprocessing/p2preprocessing.cpp

@@ -81,24 +81,25 @@ int main(int argc, char* argv[])
  const std::string host1 = (argc < 3) ? "127.0.0.1" : argv[2];  
  const size_t n_threads = atoi(argv[3]);
  const size_t number_of_sockets = 5 * n_threads;
- const size_t db_nitems = 1ULL << atoi(argv[4]);
+ const size_t depth = atoi(argv[4]);
+ const size_t db_nitems = 1ULL << depth;
  const size_t maxRAM = atoi(argv[5]);
 
- std::cout << "maxRAM = "  << maxRAM << std::endl;
- size_t RAM_needed = 0;
- RAM_needed = n_threads *  164 * db_nitems;
- std::cout << "RAM needed = " << RAM_needed << " bytes = " << RAM_needed/1073741824 << " GiB" << std::endl;
- size_t n_batches = std::ceil(double(RAM_needed)/(1073741824 * maxRAM));
- std::cout << "n_batches = " << n_batches << std::endl;
- size_t thread_per_batch = std::ceil(double(n_threads)/n_batches);
- std::cout << "thread_per_batch = " << thread_per_batch << std::endl;
-   if(n_batches > n_threads)
-   {
-    std::cout << "You need more RAM" << std::endl;
-    exit(0);
-   }
- const size_t depth = std::ceil(std::log2(db_nitems));
-    
+      size_t RAM_needed_per_thread = 164 * db_nitems;
+      std::cout << "RAM needed = " << n_threads*RAM_needed_per_thread << " bytes = " << n_threads*RAM_needed_per_thread/1073741824 << " GiB" << std::endl;
+      std::cout << "RAM needed per thread = " << RAM_needed_per_thread << " bytes = " << (RAM_needed_per_thread>>30) << " GiB" << std::endl;
+      size_t thread_per_batch = std::floor(double(maxRAM<<30)/RAM_needed_per_thread);
+      if (thread_per_batch > n_threads) {
+	thread_per_batch = n_threads;
+      }
+      std::cout << "thread_per_batch = " << thread_per_batch << std::endl;
+      if (thread_per_batch < 1) {
+       std::cout << "You need more RAM" << std::endl;
+       exit(0);
+      }
+      size_t n_batches = std::ceil(double(n_threads)/thread_per_batch);
+      std::cout << "n_batches = " << n_batches << std::endl;
+
  std::vector<int> ports2_0;
  for(size_t j = 0; j < number_of_sockets; ++j) 
  {

+ 16 - 16
preprocessing/preprocessing.cpp

@@ -52,6 +52,7 @@ int main(int argc, char * argv[])
    const std::string host2 = (argc < 3) ? "127.0.0.1" : argv[2];
    const size_t n_threads = atoi(argv[3]);
    const size_t expo = atoi(argv[4]);
+   const size_t db_nitems = 1ULL << expo;
    const size_t maxRAM = atoi(argv[5]);
    //std::cout << "n_threads = " << n_threads << std::endl;
  
@@ -67,22 +68,21 @@ int main(int argc, char * argv[])
     /* The function make_connections appears in network.h */
    make_connections(party, host1, host2,  io_context, socketsPb, socketsP2, ports,  ports2_1, ports2_0, number_of_sockets);
  
-   const size_t db_nitems = 1ULL << atoi(argv[4]);
-      //std::cout << "maxRAM = "  << maxRAM << std::endl;
-      size_t RAM_needed = 0;
-      RAM_needed = n_threads * 164 * db_nitems;
-      //std::cout << "RAM needed = " << RAM_needed << " bytes = " << RAM_needed/1073741824 << " GiB" << std::endl;
-       size_t n_batches = std::ceil(double(RAM_needed)/(1073741824 * maxRAM));
-      //std::cout << "n_batches = " << n_batches << std::endl;
-      size_t thread_per_batch = std::ceil(double(n_threads)/n_batches);
-        //std::cout << "thread_per_batch = " << thread_per_batch << std::endl;
-
-   if(n_batches > n_threads)
-   {
-    std::cout << "You need more RAM" << std::endl;
-    exit(0);
-   }
-    
+      size_t RAM_needed_per_thread = 164 * db_nitems;
+      std::cout << "RAM needed = " << n_threads*RAM_needed_per_thread << " bytes = " << n_threads*RAM_needed_per_thread/1073741824 << " GiB" << std::endl;
+      std::cout << "RAM needed per thread = " << RAM_needed_per_thread << " bytes = " << (RAM_needed_per_thread>>30) << " GiB" << std::endl;
+      size_t thread_per_batch = std::floor(double(maxRAM<<30)/RAM_needed_per_thread);
+      if (thread_per_batch > n_threads) {
+	thread_per_batch = n_threads;
+      }
+      std::cout << "thread_per_batch = " << thread_per_batch << std::endl;
+      if (thread_per_batch < 1) {
+       std::cout << "You need more RAM" << std::endl;
+       exit(0);
+      }
+      size_t n_batches = std::ceil(double(n_threads)/thread_per_batch);
+      std::cout << "n_batches = " << n_batches << std::endl;
+
    uint8_t ** target_share_read = new uint8_t*[thread_per_batch];
 
    generate_random_targets(target_share_read,  thread_per_batch, party, expo);