Browse Source

Batching the preprocessing

avadapal 1 year ago
parent
commit
e548f53816
3 changed files with 109 additions and 65 deletions
  1. 3 3
      preprocessing/network.h
  2. 37 20
      preprocessing/p2preprocessing.cpp
  3. 69 42
      preprocessing/preprocessing.cpp

+ 3 - 3
preprocessing/network.h

@@ -33,21 +33,21 @@ void make_connections(bool& party, const std::string host1,
 
 	for(size_t j = 0; j < number_of_sockets; ++j) 
 	{
-		int port = 6000;
+		int port = 8000;
 		ports.push_back(port + j);
 	}
 	
 
 	for(size_t j = 0; j < number_of_sockets; ++j) 
 	{
-		int port = 20000;
+		int port = 22000;
 		ports2_0.push_back(port + j);
 	}
 
 
 	for(size_t j = 0; j < number_of_sockets; ++j) 
 	{
-		int port = 40000;
+		int port = 42000;
 		ports2_1.push_back(port + j);
 	}
 

+ 37 - 20
preprocessing/p2preprocessing.cpp

@@ -82,20 +82,34 @@ int main(int argc, char* argv[])
  const size_t n_threads = atoi(argv[3]);
  const size_t number_of_sockets = 5 * n_threads;
  const size_t db_nitems = 1ULL << atoi(argv[4]);
- 
+ const size_t maxRAM = atoi(argv[5]);
+
+ std::cout << "maxRAM = "  << maxRAM << std::endl;
+ size_t RAM_needed = 0;
+ RAM_needed = n_threads *  9 * ((sizeof(__m128i) * db_nitems));
+ std::cout << "RAM needed = " << RAM_needed << " bytes = " << RAM_needed/1073741824 << " GiB" << std::endl;
+ size_t n_batches = std::ceil(double(RAM_needed)/(1073741824 * maxRAM));
+ std::cout << "n_batches = " << n_batches << std::endl;
+ size_t thread_per_batch = std::ceil(double(n_threads)/n_batches);
+ std::cout << "thread_per_batch = " << thread_per_batch << std::endl;
+   if(n_batches > n_threads)
+   {
+    std::cout << "You need more RAM" << std::endl;
+    exit(0);
+   }
  const size_t depth = std::ceil(std::log2(db_nitems));
     
  std::vector<int> ports2_0;
  for(size_t j = 0; j < number_of_sockets; ++j) 
  {
-   int port = 20000;
+   int port = 22000;
    ports2_0.push_back(port + j);
  }
 
  std::vector<int> ports2_1;
  for(size_t j = 0; j < number_of_sockets; ++j) 
  {
-   int port = 40000;
+   int port = 42000;
    ports2_1.push_back(port + j);
  }
 
@@ -119,7 +133,8 @@ int main(int argc, char* argv[])
 
  pool2.join();
 
- boost::asio::thread_pool pool(n_threads);  
+
+
 
 	__m128i ** output0 = (__m128i ** ) malloc(sizeof(__m128i *) * n_threads);
 	int8_t  ** flags0  = (int8_t ** ) malloc(sizeof(uint8_t *) * n_threads);
@@ -179,34 +194,36 @@ int main(int argc, char* argv[])
     // }
 
 
-
  
+  boost::asio::thread_pool pool(thread_per_batch);  
+
   boost::asio::write(sockets0[0], boost::asio::buffer(&computecw0_array,  sizeof(computecw0_array)));
   boost::asio::write(sockets1[0], boost::asio::buffer(&computecw1_array,  sizeof(computecw1_array)));
 
   boost::asio::read(sockets0[0], boost::asio::buffer(dpf_instance0, n_threads * sizeof(dpfP2)));
   boost::asio::read(sockets1[0], boost::asio::buffer(dpf_instance1, n_threads * sizeof(dpfP2))); 
-  for(size_t jj = 0; jj < 16; ++jj)
-  {
-    for(size_t j = 0; j < n_threads; ++j)
+
+
+    for(size_t j = 0; j < thread_per_batch; ++j)
     {
      boost::asio::post(pool, std::bind(mpc_gen,  std::ref(depth), std::ref(aeskey), db_nitems, n_threads,  std::ref(sockets0), std::ref(sockets1), 
                                      output0, flags0,  output1, flags1,  std::ref(dpf_instance0), std::ref(dpf_instance1), j, 5 * j));    
     }  
-  }
-  pool.join();
 
+  pool.join();
+ 
 
 
 
-  boost::asio::thread_pool pool3(n_threads); 
+  boost::asio::thread_pool pool3(thread_per_batch); 
   
- int64_t ** leaves0    = (int64_t ** ) malloc(sizeof(int64_t *) * n_threads);
- int64_t ** leafbits0  = (int64_t ** ) malloc(sizeof(int64_t *) * n_threads); 
- int64_t ** leaves1    = (int64_t ** ) malloc(sizeof(int64_t *) * n_threads);
- int64_t ** leafbits1  = (int64_t ** ) malloc(sizeof(int64_t *) * n_threads); 
+ int64_t ** leaves0    = (int64_t ** ) malloc(sizeof(int64_t *) * thread_per_batch);
+ int64_t ** leafbits0  = (int64_t ** ) malloc(sizeof(int64_t *) * thread_per_batch); 
+ int64_t ** leaves1    = (int64_t ** ) malloc(sizeof(int64_t *) * thread_per_batch);
+ int64_t ** leafbits1  = (int64_t ** ) malloc(sizeof(int64_t *) * thread_per_batch); 
+
 
- for(size_t j = 0; j < n_threads; ++j)
+ for(size_t j = 0; j < thread_per_batch; ++j)
  {
   leaves0[j]    = (int64_t *)std::aligned_alloc(sizeof(node_t), db_nitems * sizeof(int64_t));
   leafbits0[j]  = (int64_t *)std::aligned_alloc(sizeof(node_t), db_nitems * sizeof(int64_t));
@@ -215,7 +232,7 @@ int main(int argc, char* argv[])
  }
 
  /* The function convert_sharesP2 appears in share-conversion.h */
-  for(size_t j = 0; j < n_threads; ++j)
+  for(size_t j = 0; j < thread_per_batch; ++j)
   {
    boost::asio::post(pool3, std::bind(convert_sharesP2, db_nitems,  output0, flags0,  output1, flags1, leaves0, leafbits0, leaves1, leafbits1,  std::ref(sockets0), std::ref(sockets1), j, j));    
   }
@@ -223,8 +240,8 @@ int main(int argc, char* argv[])
   pool3.join(); 
   
   /* The function P2_xor_to_additive appears in share-conversion.h */
-  boost::asio::thread_pool pool4(n_threads); 
-  for(size_t j = 0; j < n_threads; ++j)
+  boost::asio::thread_pool pool4(thread_per_batch); 
+  for(size_t j = 0; j < thread_per_batch; ++j)
   {
    boost::asio::post(pool4,  std::bind(P2_xor_to_additive, std::ref(sockets0[j]), std::ref(sockets1[j]), j));
   }
@@ -233,7 +250,7 @@ int main(int argc, char* argv[])
 
 
 
-  for(size_t i = 0; i < n_threads; ++i)
+  for(size_t i = 0; i < thread_per_batch; ++i)
   {
    P2_write_evalfull_outs_into_a_file(false, i, db_nitems,  flags0[i], 	leaves0[i]);
    P2_write_evalfull_outs_into_a_file(true,  i, db_nitems,  flags1[i], 	leaves1[i]);

+ 69 - 42
preprocessing/preprocessing.cpp

@@ -53,7 +53,11 @@ int main(int argc, char * argv[])
    const size_t n_threads = atoi(argv[3]);
    const size_t expo = atoi(argv[4]);
    const size_t op = atoi(argv[5]);
+   const size_t maxRAM = atoi(argv[6]);
+   std::cout << "n_threads = " << n_threads << std::endl;
  
+
+
    const size_t number_of_sockets = 5 * n_threads;
 
    std::vector<socket_t> socketsPb, socketsP2;
@@ -65,28 +69,43 @@ int main(int argc, char * argv[])
    make_connections(party, host1, host2,  io_context, socketsPb, socketsP2, ports,  ports2_1, ports2_0, number_of_sockets);
  
    const size_t db_nitems = 1ULL << atoi(argv[4]);
-
-
+      std::cout << "maxRAM = "  << maxRAM << std::endl;
+      size_t RAM_needed = 0;
+      RAM_needed = n_threads *  9 * ((sizeof(__m128i) * db_nitems));
+      std::cout << "RAM needed = " << RAM_needed << " bytes = " << RAM_needed/1073741824 << " GiB" << std::endl;
+       size_t n_batches = std::ceil(double(RAM_needed)/(1073741824 * maxRAM));
+      std::cout << "n_batches = " << n_batches << std::endl;
+      size_t thread_per_batch = std::ceil(double(n_threads)/n_batches);
+        std::cout << "thread_per_batch = " << thread_per_batch << std::endl;
+
+   if(n_batches > n_threads)
+   {
+    std::cout << "You need more RAM" << std::endl;
+    exit(0);
+   }
     
-   uint8_t ** target_share_read = new uint8_t*[n_threads];
+   uint8_t ** target_share_read = new uint8_t*[thread_per_batch];
 
-   generate_random_targets(target_share_read,  n_threads, party, expo);
+   generate_random_targets(target_share_read,  thread_per_batch, party, expo);
    
    AES_KEY aeskey;
    
-   __m128i * final_correction_word = (__m128i *) std::aligned_alloc(sizeof(__m256i), n_threads * sizeof(__m128i));
-   __m128i ** output = (__m128i ** ) malloc(sizeof(__m128i *) * n_threads);
+   auto start = std::chrono::steady_clock::now(); 
+   
+
+   __m128i * final_correction_word = (__m128i *) std::aligned_alloc(sizeof(__m256i), thread_per_batch * sizeof(__m128i));
+   __m128i ** output = (__m128i ** ) malloc(sizeof(__m128i *) * thread_per_batch);
     
-   int8_t  ** flags  = (int8_t ** ) malloc(sizeof(uint8_t *) * n_threads);
+   int8_t  ** flags  = (int8_t ** ) malloc(sizeof(uint8_t *) * thread_per_batch);
      
-   for(size_t j = 0; j < n_threads; ++j)
+   for(size_t j = 0; j < thread_per_batch; ++j)
    {
      output[j] = (__m128i *)std::aligned_alloc(sizeof(node_t), db_nitems * sizeof(__m128i));
      flags[j]  = (int8_t *)std::aligned_alloc(sizeof(node_t), db_nitems * sizeof(uint8_t));
    }
      
-  
-   boost::asio::thread_pool pool_share_conversion(n_threads);
+ 
+   boost::asio::thread_pool pool_share_conversion(thread_per_batch);
     
 
     
@@ -99,7 +118,7 @@ int main(int argc, char * argv[])
 
 
   cw_construction computecw_array;
-  auto start = std::chrono::steady_clock::now(); 
+
  
      boost::asio::read(socketsP2[0], boost::asio::buffer(&computecw_array, sizeof(computecw_array)));
      communication_cost += sizeof(computecw_array);
@@ -109,34 +128,44 @@ int main(int argc, char * argv[])
 
       /* The function create_dpfs appears in dpf-gen.h*/
       bool reading = true;
-      boost::asio::thread_pool pool(n_threads);
 
-      for(size_t j = 0; j < n_threads; ++j)
-      {
-       boost::asio::post(pool,  std::bind(create_dpfs, reading,  db_nitems,	std::ref(aeskey),  target_share_read[j],  std::ref(socketsPb), std::ref(socketsP2), 0, db_nitems-1, 
-                                             output[j],  flags[j], std::ref(final_correction_word[j]), computecw_array, std::ref(dpf_instance),  party, 5 * j, j));	 	  
-      }    
-      pool.join();  
       
-      bool interleaved = false;
+
+     for(size_t iter = 0; iter < n_batches; ++iter)
+     { 
+        boost::asio::thread_pool pool(thread_per_batch);
+        for(size_t j = 0; j < thread_per_batch; ++j)
+        {
+         boost::asio::post(pool,  std::bind(create_dpfs, reading,  db_nitems,	std::ref(aeskey),  target_share_read[j],  std::ref(socketsPb), std::ref(socketsP2), 0, db_nitems-1, 
+                                               output[j],  flags[j], std::ref(final_correction_word[j]), computecw_array, std::ref(dpf_instance),  party, 5 * j, j));	 	  
+        }    
+        pool.join();  
+     }
+      
+     bool interleaved = false;
       
-      if(op == 1) interleaved = true;
+     if(op == 1) interleaved = true;
       
-      if(interleaved)
-      {
-       boost::asio::thread_pool pool2(n_threads);
-       for(size_t j = 0; j < n_threads; ++j)
-       {
+     if(interleaved)
+     {
+      for(size_t iter = 0; iter < n_batches; ++iter)
+      { 
+       std::cout << "iter = " << iter << std::endl;
+        boost::asio::thread_pool pool2(thread_per_batch);
+        for(size_t j = 0; j < thread_per_batch; ++j)
+        {
          boost::asio::post(pool2,  std::bind(create_dpfs, reading,  db_nitems,	std::ref    (aeskey),  target_share_read[j],  std::ref(socketsPb), std::ref(socketsP2), 0, db_nitems-1, 
                                              output[j],  flags[j], std::ref(final_correction_word[j]), computecw_array, std::ref(dpf_instance),  party, 5 * j, j));	 	  
+        }
+         pool2.join();  
        }
-       pool2.join();  
       }
+     
      boost::asio::write(socketsP2[0], boost::asio::buffer(dpf_instance, n_threads * sizeof(dpfP2))); // do this in parallel.
      communication_cost += (n_threads * sizeof(dpfP2));
  
    #ifdef DEBUG
-   
+
     for(size_t j = 0; j < n_threads; ++j)
     {
       std::cout << "n_threads = " << j << std::endl;
@@ -167,9 +196,9 @@ int main(int argc, char * argv[])
      leaves is a additive shares of the outputs (leaves of the DPF)
      leafbits is the additive shares of flag bits of the DPFs
     */
-   int64_t ** leaves = (int64_t ** ) malloc(sizeof(int64_t *) * n_threads);
-   int64_t ** leafbits  = (int64_t ** ) malloc(sizeof(int64_t *) * n_threads); 
-   for(size_t j = 0; j < n_threads; ++j)
+   int64_t ** leaves = (int64_t ** ) malloc(sizeof(int64_t *) * thread_per_batch);
+   int64_t ** leafbits  = (int64_t ** ) malloc(sizeof(int64_t *) * thread_per_batch); 
+   for(size_t j = 0; j < thread_per_batch; ++j)
    {
     leaves[j] = (int64_t *)std::aligned_alloc(sizeof(node_t), db_nitems * sizeof(int64_t));
     leafbits[j]  = (int64_t *)std::aligned_alloc(sizeof(node_t), db_nitems * sizeof(int64_t));
@@ -178,7 +207,7 @@ int main(int argc, char * argv[])
 
 
     /* The function convert_shares appears in share-conversion.h */
-   for(size_t j = 0; j < n_threads; ++j)
+   for(size_t j = 0; j < thread_per_batch; ++j)
    {
      boost::asio::post(pool_share_conversion,  std::bind(convert_shares, j, output, flags, n_threads, db_nitems, final_correction_word, 	leaves, leafbits, 
                                                           std::ref(socketsPb), std::ref(socketsP2), party));	 	
@@ -186,30 +215,28 @@ int main(int argc, char * argv[])
     
     pool_share_conversion.join();
 
-    boost::asio::thread_pool pool_xor_to_additive(n_threads); 
+    boost::asio::thread_pool pool_xor_to_additive(thread_per_batch); 
 
     std::array<int64_t, 128> additve_shares; 
-    for(size_t j = 0; j < n_threads; ++j)
+    for(size_t j = 0; j < thread_per_batch; ++j)
     {
      boost::asio::post(pool_xor_to_additive, std::bind(xor_to_additive, party, target_share_read[j], std::ref(socketsPb[j]), std::ref(socketsP2[j]), expo, std::ref(additve_shares[j])));
     }
 
     pool_xor_to_additive.join();
     
-    auto end = std::chrono::steady_clock::now();
-    std::chrono::duration<double> elapsed_seconds = end-start;
-    //std::cout << "time to generate and evaluate " << n_threads << " dpfs of size 2^" << atoi(argv[4]) << " is: " << elapsed_seconds.count() << "s\n";
-    std::cout << "WallClockTime: "  << elapsed_seconds.count() << std::endl;
-    start = std::chrono::steady_clock::now();
+ 
     
-    for(size_t i = 0; i < n_threads; ++i)
+    for(size_t i = 0; i < thread_per_batch; ++i)
     {
      write_evalfull_outs_into_a_file(party, i, db_nitems, flags[i],  leaves[i], final_correction_word[i], additve_shares[i]); 
     }
-    
-    end = std::chrono::steady_clock::now();
-    elapsed_seconds = end-start;
 
+    auto end = std::chrono::steady_clock::now();
+    std::chrono::duration<double> elapsed_seconds = end-start;
+    //std::cout << "time to generate and evaluate " << n_threads << " dpfs of size 2^" << atoi(argv[4]) << " is: " << elapsed_seconds.count() << "s\n";
+    std::cout << "WallClockTime: "  << elapsed_seconds.count() << std::endl;
+ 
     // std::cout << "elapsed_ FIO = " << elapsed_seconds.count() << std::endl;
 
     std::cout << "CommunicationCost: " << communication_cost/1024 << " KiB" << std::endl;