Browse Source

Make the preprocessing communication cost accounting thread-safe

Ian Goldberg 1 year ago
parent
commit
0bb2a5291b
2 changed files with 19 additions and 8 deletions
  1. 5 5
      preprocessing/dpfgen.h
  2. 14 3
      preprocessing/preprocessing.cpp

+ 5 - 5
preprocessing/dpfgen.h

@@ -46,7 +46,7 @@ struct BlindsCW
 
  
 
-void compute_CW(cw_construction computecw_array, size_t ind, size_t layer,tcp::socket& sout,  __m128i L, __m128i R, uint8_t bit, __m128i & CW, uint8_t &cwt_L, uint8_t &cwt_R)
+void compute_CW(cw_construction computecw_array, size_t ind, size_t layer,tcp::socket& sout,  __m128i L, __m128i R, uint8_t bit, __m128i & CW, uint8_t &cwt_L, uint8_t &cwt_R, size_t &thread_communication_cost)
 {
 
  	reconstructioncw cwsent, cwrecv;
@@ -66,7 +66,7 @@ void compute_CW(cw_construction computecw_array, size_t ind, size_t layer,tcp::s
 
 	//exchange blinded shares for OSWAP.
   boost::asio::write(sout, boost::asio::buffer(&blinds_sent, sizeof(BlindsCW)));
-  communication_cost += sizeof(BlindsCW);
+  thread_communication_cost += sizeof(BlindsCW);
 
 	boost::asio::read(sout, boost::asio::buffer(&blinds_recv, sizeof(BlindsCW)));
 	
@@ -90,7 +90,7 @@ void compute_CW(cw_construction computecw_array, size_t ind, size_t layer,tcp::s
 
  boost::asio::write(sout, boost::asio::buffer(&cwsent, sizeof(cwsent)));
  boost::asio::read(sout, boost::asio::buffer(&cwrecv, sizeof(cwrecv)));
- communication_cost += sizeof(cwsent);
+ thread_communication_cost += sizeof(cwsent);
  cwrecv.cw ^= cwsent.cw;
  cwrecv.cwbit[0] ^= (cwsent.cwbit[0] ^ 1);
  cwrecv.cwbit[1] ^= (cwsent.cwbit[1]);
@@ -162,7 +162,7 @@ static inline void traverse(const prgkey_t & prgkey, const node_t & seed,	node_t
 inline void create_dpfs (bool reading,  size_t db_nitems, const AES_KEY& prgkey,  
                          uint8_t target_share[64], std::vector<socket_t>& socketsPb, std::vector<socket_t>& socketsP2, const size_t from, const size_t to, __m128i * output, int8_t * _t, __m128i& final_correction_word,  
 						                   cw_construction computecw_array, dpfP2 * dpf_instance, 
-                         bool party, size_t socket_no, size_t ind = 0)
+                         bool party, size_t socket_no, size_t ind, size_t &thread_communication_cost)
 { 
 	const size_t bits_per_leaf = std::is_same<leaf_t, bool>::value ? 1 : sizeof(leaf_t) * CHAR_BIT;
 	const bool  is_packed = (sizeof(leaf_t) < sizeof(node_t));
@@ -245,7 +245,7 @@ inline void create_dpfs (bool reading,  size_t db_nitems, const AES_KEY& prgkey,
   		uint8_t cwt_L, cwt_R;
 			
 			// Computes the correction word using OSWAP
-			compute_CW(computecw_array, ind, layer,  socketsPb[socket_no],  L,  R, target_share[layer], CW[layer], cwt_L,  cwt_R);
+			compute_CW(computecw_array, ind, layer,  socketsPb[socket_no],  L,  R, target_share[layer], CW[layer], cwt_L,  cwt_R, thread_communication_cost);
 			
 			#ifdef DEBUG
 				if(ind == 0) 

+ 14 - 3
preprocessing/preprocessing.cpp

@@ -129,6 +129,7 @@ int main(int argc, char * argv[])
 
       
 
+     size_t *thread_communication_costs = new size_t[thread_per_batch];
      for(size_t iter = 0; iter < n_batches; ++iter)
      { 
         if (n_batches > 1) {
@@ -137,11 +138,21 @@ int main(int argc, char * argv[])
         boost::asio::thread_pool pool(thread_per_batch);
         for(size_t j = 0; j < thread_per_batch; ++j)
         {
-         boost::asio::post(pool,  std::bind(create_dpfs, reading,  db_nitems,	std::ref(aeskey),  target_share_read[j],  std::ref(socketsPb), std::ref(socketsP2), 0, db_nitems-1, 
-                                               output[j],  flags[j], std::ref(final_correction_word[j]), computecw_array, std::ref(dpf_instance),  party, 5 * j, j));	 	  
+	  thread_communication_costs[j] = 0; 
+	  boost::asio::post(pool,
+	    std::bind(create_dpfs, reading,  db_nitems, std::ref(aeskey),
+		target_share_read[j], std::ref(socketsPb), std::ref(socketsP2),
+		0, db_nitems-1, output[j],  flags[j],
+		std::ref(final_correction_word[j]), computecw_array,
+		std::ref(dpf_instance), party, 5 * j, j,
+		std::ref(thread_communication_costs[j])));
         }    
-        pool.join();  
+        pool.join();
+        for(size_t j = 0; j < thread_per_batch; ++j) {
+	  communication_cost += thread_communication_costs[j];
+	}
      }
+     delete[] thread_communication_costs;
       
      boost::asio::write(socketsP2[0], boost::asio::buffer(dpf_instance, n_threads * sizeof(dpfP2))); // do this in parallel.
      communication_cost += (n_threads * sizeof(dpfP2));