Selaa lähdekoodia

Multithread RDPF creation when it makes sense

Ian Goldberg 1 vuosi sitten
vanhempi
commit
0fa58457c5
1 muutettua tiedostoa jossa 282 lisäystä ja 26 poistoa
  1. 282 26
      rdpf.cpp

+ 282 - 26
rdpf.cpp

@@ -20,6 +20,101 @@ static value_t inverse_value_t(value_t x)
     return xe;
 }
 
+#undef RDPF_MTGEN_TIMING_1
+
+#ifdef RDPF_MTGEN_TIMING_1
+// Timing tests for multithreaded generation of RDPFs.
+// nthreads = 0 to not launch threads at all.
+// Run for num_iters iterations, and output the total number of
+// milliseconds for all of the iterations.
+//
+// Results: roughly 50 µs to launch the thread pool with 1 thread, and
+// roughly 30 additional µs for each additional thread.  Each iteration
+// of the inner loop takes about 4 to 5 ns.  This works out to around
+// level 19 where it starts being worth it to multithread, and you
+// should use at most sqrt(2^{level}/6000) threads.
+static void mtgen_timetest_1(nbits_t level, int nthreads,
+    size_t num_iters, const DPFnode *curlevel,
+    DPFnode *nextlevel, size_t &aes_ops)
+{
+    // Always do at least one iteration so the elapsed time and AES-op
+    // delta below are meaningful
+    if (num_iters == 0) {
+        num_iters = 1;
+    }
+    size_t prev_aes_ops = aes_ops;
+    DPFnode L = _mm_setzero_si128();
+    DPFnode R = _mm_setzero_si128();
+    // The tweak causes us to compute something slightly different every
+    // iteration of the loop, so that the compiler doesn't notice we're
+    // doing the same thing num_iters times and optimize it away
+    DPFnode tweak = _mm_setzero_si128();
+    auto start = boost::chrono::steady_clock::now();
+    for(size_t iter=0;iter<num_iters;++iter) {
+        tweak += 1;  // This actually adds the 128-bit value whose high
+                     // and low 64-bits words are both 1, but that's
+                     // fine.
+        size_t curlevel_size = size_t(1)<<level;
+        if (nthreads == 0) {
+            // Single-threaded baseline: expand every node of the
+            // current level directly, accumulating the XOR of all left
+            // (L) and right (R) children
+            size_t laes_ops = 0;
+            for(size_t i=0;i<curlevel_size;++i) {
+                DPFnode lchild, rchild;
+                prgboth(lchild, rchild, curlevel[i]^tweak, laes_ops);
+                L = (L ^ lchild);
+                R = (R ^ rchild);
+                nextlevel[2*i] = lchild;
+                nextlevel[2*i+1] = rchild;
+            }
+            aes_ops += laes_ops;
+        } else {
+            // Per-thread partial results, combined after the join.
+            // NOTE(review): variable-length arrays are a GCC/Clang
+            // extension, not standard C++
+            DPFnode tL[nthreads];
+            DPFnode tR[nthreads];
+            size_t taes_ops[nthreads];
+            // Split curlevel_size into nthreads contiguous chunks; the
+            // first threadextra chunks get one extra element
+            size_t threadstart = 0;
+            size_t threadchunk = curlevel_size / nthreads;
+            size_t threadextra = curlevel_size % nthreads;
+            boost::asio::thread_pool pool(nthreads);
+            for (int t=0;t<nthreads;++t) {
+                size_t threadsize = threadchunk + (size_t(t) < threadextra);
+                size_t threadend = threadstart + threadsize;
+                boost::asio::post(pool,
+                    [t, &tL, &tR, &taes_ops, threadstart, threadend,
+                    &curlevel, &nextlevel, tweak] {
+                        // Thread-local accumulators; written back to
+                        // the per-thread slots (disjoint indices, so no
+                        // synchronization needed)
+                        DPFnode L = _mm_setzero_si128();
+                        DPFnode R = _mm_setzero_si128();
+                        size_t aes_ops = 0;
+                        for(size_t i=threadstart;i<threadend;++i) {
+                            DPFnode lchild, rchild;
+                            prgboth(lchild, rchild, curlevel[i]^tweak, aes_ops);
+                            L = (L ^ lchild);
+                            R = (R ^ rchild);
+                            nextlevel[2*i] = lchild;
+                            nextlevel[2*i+1] = rchild;
+                        }
+                        tL[t] = L;
+                        tR[t] = R;
+                        taes_ops[t] = aes_ops;
+                    });
+                threadstart = threadend;
+            }
+            pool.join();
+            // Fold the per-thread partial results into the totals
+            for (int t=0;t<nthreads;++t) {
+                L ^= tL[t];
+                R ^= tR[t];
+                aes_ops += taes_ops[t];
+            }
+        }
+    }
+    auto elapsed =
+        boost::chrono::steady_clock::now() - start;
+    std::cout << "timetest_1 " << int(level) << " " << nthreads << " "
+        << num_iters << " " << boost::chrono::duration_cast
+        <boost::chrono::milliseconds>(elapsed) << " " <<
+        (aes_ops-prev_aes_ops) << " AES\n";
+    // Output the accumulated values so the computation stays observable
+    // (and so cannot be optimized away)
+    dump_node(L);
+    dump_node(R);
+}
+
+#endif
+
 // Construct a DPF with the given (XOR-shared) target location, and
 // of the given depth, to be used for random-access memory reads and
 // writes.  The DPF is constructed collaboratively by P0 and P1,
@@ -70,15 +165,80 @@ RDPF::RDPF(MPCTIO &tio, yield_t &yield,
         // need to execute mpc_reconstruct_choice so that it sends
         // the AndTriples at the appropriate time.
         if (player < 2) {
-            for(size_t i=0;i<curlevel_size;++i) {
-                DPFnode lchild, rchild;
-                prgboth(lchild, rchild, curlevel[i], aes_ops);
-                L = (L ^ lchild);
-                R = (R ^ rchild);
-                if (nextlevel) {
+#ifdef RDPF_MTGEN_TIMING_1
+            if (player == 0) {
+                mtgen_timetest_1(level, 0, (1<<23)>>level, curlevel,
+                    nextlevel, aes_ops);
+                size_t niters = 2048;
+                if (level > 8) niters = (1<<20)>>level;
+                for(int t=1;t<=8;++t) {
+                    mtgen_timetest_1(level, t, niters, curlevel,
+                        nextlevel, aes_ops);
+                }
+                mtgen_timetest_1(level, 0, (1<<23)>>level, curlevel,
+                    nextlevel, aes_ops);
+            }
+#endif
+            // Using the timing results gathered above, decide whether
+            // to multithread, and if so, how many threads to use.
+            // tio.cpu_nthreads() is the maximum number we have
+            // available.
+            int max_nthreads = tio.cpu_nthreads();
+            if (max_nthreads == 1 || level < 19) {
+                // No threading
+                size_t laes_ops = 0;
+                for(size_t i=0;i<curlevel_size;++i) {
+                    DPFnode lchild, rchild;
+                    prgboth(lchild, rchild, curlevel[i], laes_ops);
+                    L = (L ^ lchild);
+                    R = (R ^ rchild);
                     nextlevel[2*i] = lchild;
                     nextlevel[2*i+1] = rchild;
                 }
+                aes_ops += laes_ops;
+            } else {
+                size_t curlevel_size = size_t(1)<<level;
+                int nthreads =
+                    int(ceil(sqrt(double(curlevel_size/6000))));
+                if (nthreads > max_nthreads) {
+                    nthreads = max_nthreads;
+                }
+                DPFnode tL[nthreads];
+                DPFnode tR[nthreads];
+                size_t taes_ops[nthreads];
+                size_t threadstart = 0;
+                size_t threadchunk = curlevel_size / nthreads;
+                size_t threadextra = curlevel_size % nthreads;
+                boost::asio::thread_pool pool(nthreads);
+                for (int t=0;t<nthreads;++t) {
+                    size_t threadsize = threadchunk + (size_t(t) < threadextra);
+                    size_t threadend = threadstart + threadsize;
+                    boost::asio::post(pool,
+                        [t, &tL, &tR, &taes_ops, threadstart, threadend,
+                        &curlevel, &nextlevel] {
+                            DPFnode L = _mm_setzero_si128();
+                            DPFnode R = _mm_setzero_si128();
+                            size_t aes_ops = 0;
+                            for(size_t i=threadstart;i<threadend;++i) {
+                                DPFnode lchild, rchild;
+                                prgboth(lchild, rchild, curlevel[i], aes_ops);
+                                L = (L ^ lchild);
+                                R = (R ^ rchild);
+                                nextlevel[2*i] = lchild;
+                                nextlevel[2*i+1] = rchild;
+                            }
+                            tL[t] = L;
+                            tR[t] = R;
+                            taes_ops[t] = aes_ops;
+                        });
+                    threadstart = threadend;
+                }
+                pool.join();
+                for (int t=0;t<nthreads;++t) {
+                    L ^= tL[t];
+                    R ^= tR[t];
+                    aes_ops += taes_ops[t];
+                }
             }
         }
         // If we're going left (bs_choice = 0), we want the correction
@@ -148,11 +308,47 @@ RDPF::RDPF(MPCTIO &tio, yield_t &yield,
         cfbits |= (value_t(parity_bit)<<level);
         DPFnode CWR = CW ^ lsb128_mask[parity_bit];
         if (player < 2) {
+            // The timing of each iteration of the inner loop is
+            // comparable to the above, so just use the same
+            // computations.  All of this could be tuned, of course.
+
             if (level < depth-1) {
-                for(size_t i=0;i<curlevel_size;++i) {
-                    bool flag = get_lsb(curlevel[i]);
-                    nextlevel[2*i] = xor_if(nextlevel[2*i], CW, flag);
-                    nextlevel[2*i+1] = xor_if(nextlevel[2*i+1], CWR, flag);
+                // Using the timing results gathered above, decide whether
+                // to multithread, and if so, how many threads to use.
+                // tio.cpu_nthreads() is the maximum number we have
+                // available.
+                int max_nthreads = tio.cpu_nthreads();
+                if (max_nthreads == 1 || level < 19) {
+                    // No threading
+                    for(size_t i=0;i<curlevel_size;++i) {
+                        bool flag = get_lsb(curlevel[i]);
+                        nextlevel[2*i] = xor_if(nextlevel[2*i], CW, flag);
+                        nextlevel[2*i+1] = xor_if(nextlevel[2*i+1], CWR, flag);
+                    }
+                } else {
+                    int nthreads =
+                        int(ceil(sqrt(double(curlevel_size/6000))));
+                    if (nthreads > max_nthreads) {
+                        nthreads = max_nthreads;
+                    }
+                    size_t threadstart = 0;
+                    size_t threadchunk = curlevel_size / nthreads;
+                    size_t threadextra = curlevel_size % nthreads;
+                    boost::asio::thread_pool pool(nthreads);
+                    for (int t=0;t<nthreads;++t) {
+                        size_t threadsize = threadchunk + (size_t(t) < threadextra);
+                        size_t threadend = threadstart + threadsize;
+                        boost::asio::post(pool, [CW, CWR, threadstart, threadend,
+                            &curlevel, &nextlevel] {
+                                for(size_t i=threadstart;i<threadend;++i) {
+                                    bool flag = get_lsb(curlevel[i]);
+                                    nextlevel[2*i] = xor_if(nextlevel[2*i], CW, flag);
+                                    nextlevel[2*i+1] = xor_if(nextlevel[2*i+1], CWR, flag);
+                                }
+                        });
+                        threadstart = threadend;
+                    }
+                    pool.join();
                 }
             } else {
                 // Recall there are four potentially useful vectors that
@@ -196,23 +392,83 @@ RDPF::RDPF(MPCTIO &tio, yield_t &yield,
                 value_t low_sum = 0;
                 value_t high_sum = 0;
                 value_t high_xor = 0;
-                for(size_t i=0;i<curlevel_size;++i) {
-                    bool flag = get_lsb(curlevel[i]);
-                    DPFnode leftchild = xor_if(nextlevel[2*i], CW, flag);
-                    DPFnode rightchild = xor_if(nextlevel[2*i+1], CWR, flag);
-                    if (save_expansion) {
-                        nextlevel[2*i] = leftchild;
-                        nextlevel[2*i+1] = rightchild;
+                // Using the timing results gathered above, decide whether
+                // to multithread, and if so, how many threads to use.
+                // tio.cpu_nthreads() is the maximum number we have
+                // available.
+                int max_nthreads = tio.cpu_nthreads();
+                if (max_nthreads == 1 || level < 19) {
+                    // No threading
+                    for(size_t i=0;i<curlevel_size;++i) {
+                        bool flag = get_lsb(curlevel[i]);
+                        DPFnode leftchild = xor_if(nextlevel[2*i], CW, flag);
+                        DPFnode rightchild = xor_if(nextlevel[2*i+1], CWR, flag);
+                        if (save_expansion) {
+                            nextlevel[2*i] = leftchild;
+                            nextlevel[2*i+1] = rightchild;
+                        }
+                        value_t leftlow = value_t(_mm_cvtsi128_si64x(leftchild));
+                        value_t rightlow = value_t(_mm_cvtsi128_si64x(rightchild));
+                        value_t lefthigh =
+                            value_t(_mm_cvtsi128_si64x(_mm_srli_si128(leftchild,8)));
+                        value_t righthigh =
+                            value_t(_mm_cvtsi128_si64x(_mm_srli_si128(rightchild,8)));
+                        low_sum += (leftlow + rightlow);
+                        high_sum += (lefthigh + righthigh);
+                        high_xor ^= (lefthigh ^ righthigh);
+                    }
+                } else {
+                    int nthreads =
+                        int(ceil(sqrt(double(curlevel_size/6000))));
+                    if (nthreads > max_nthreads) {
+                        nthreads = max_nthreads;
+                    }
+                    value_t tlow_sum[nthreads];
+                    value_t thigh_sum[nthreads];
+                    value_t thigh_xor[nthreads];
+                    size_t threadstart = 0;
+                    size_t threadchunk = curlevel_size / nthreads;
+                    size_t threadextra = curlevel_size % nthreads;
+                    boost::asio::thread_pool pool(nthreads);
+                    for (int t=0;t<nthreads;++t) {
+                        size_t threadsize = threadchunk + (size_t(t) < threadextra);
+                        size_t threadend = threadstart + threadsize;
+                        boost::asio::post(pool,
+                            [t, &tlow_sum, &thigh_sum, &thigh_xor, threadstart, threadend,
+                            &curlevel, &nextlevel, CW, CWR, save_expansion] {
+                                value_t low_sum = 0;
+                                value_t high_sum = 0;
+                                value_t high_xor = 0;
+                                for(size_t i=threadstart;i<threadend;++i) {
+                                    bool flag = get_lsb(curlevel[i]);
+                                    DPFnode leftchild = xor_if(nextlevel[2*i], CW, flag);
+                                    DPFnode rightchild = xor_if(nextlevel[2*i+1], CWR, flag);
+                                    if (save_expansion) {
+                                        nextlevel[2*i] = leftchild;
+                                        nextlevel[2*i+1] = rightchild;
+                                    }
+                                    value_t leftlow = value_t(_mm_cvtsi128_si64x(leftchild));
+                                    value_t rightlow = value_t(_mm_cvtsi128_si64x(rightchild));
+                                    value_t lefthigh =
+                                        value_t(_mm_cvtsi128_si64x(_mm_srli_si128(leftchild,8)));
+                                    value_t righthigh =
+                                        value_t(_mm_cvtsi128_si64x(_mm_srli_si128(rightchild,8)));
+                                    low_sum += (leftlow + rightlow);
+                                    high_sum += (lefthigh + righthigh);
+                                    high_xor ^= (lefthigh ^ righthigh);
+                                }
+                                tlow_sum[t] = low_sum;
+                                thigh_sum[t] = high_sum;
+                                thigh_xor[t] = high_xor;
+                            });
+                        threadstart = threadend;
+                    }
+                    pool.join();
+                    for (int t=0;t<nthreads;++t) {
+                        low_sum += tlow_sum[t];
+                        high_sum += thigh_sum[t];
+                        high_xor ^= thigh_xor[t];
                     }
-                    value_t leftlow = value_t(_mm_cvtsi128_si64x(leftchild));
-                    value_t rightlow = value_t(_mm_cvtsi128_si64x(rightchild));
-                    value_t lefthigh =
-                        value_t(_mm_cvtsi128_si64x(_mm_srli_si128(leftchild,8)));
-                    value_t righthigh =
-                        value_t(_mm_cvtsi128_si64x(_mm_srli_si128(rightchild,8)));
-                    low_sum += (leftlow + rightlow);
-                    high_sum += (lefthigh + righthigh);
-                    high_xor ^= (lefthigh ^ righthigh);
                 }
                 if (player == 1) {
                     low_sum = -low_sum;