Browse Source

Complete the update protocol

Ian Goldberg 1 year ago
parent
commit
0de416dfa9
7 changed files with 80 additions and 22 deletions
  1. 1 1
      duoram.hpp
  2. 35 9
      duoram.tcc
  3. 8 3
      online.cpp
  4. 0 2
      preproc.cpp
  5. 1 1
      rdpf.cpp
  6. 1 0
      rdpf.tcc
  7. 34 6
      types.hpp

+ 1 - 1
duoram.hpp

@@ -145,7 +145,7 @@ protected:
     // duoram.p0_blind[indexmap(idx)], etc.)
     inline std::tuple<T&,T&> get_server(size_t idx) const {
         size_t physaddr = indexmap(idx);
-        return std::make_tuple(
+        return std::tie(
             duoram.p0_blind[physaddr],
             duoram.p1_blind[physaddr]);
     }

+ 35 - 9
duoram.tcc

@@ -139,24 +139,50 @@ typename Duoram<T>::Shape::MemRefAS
         auto Mshift = combine(Moffset, peerMoffset);
 
         // Evaluate the DPFs and add them to the database
-        StreamEval ev(dt, indshift, shape.tio.aes_ops());
+        StreamEval ev(dt, -indshift, shape.tio.aes_ops());
         for (size_t i=0; i<shape.shape_size; ++i) {
             auto L = ev.next();
-            shape.get_comp(i) += dt.scaled_as(L) + dt.unit_as(L) * Mshift;
+            // The values from the three DPFs
+            auto [V0, V1, V2] = dt.scaled_as(L) + dt.unit_as(L) * Mshift;
+            // References to the appropriate cells in our database, our
+            // blind, and our copy of the peer's blinded database
+            auto [DB, BL, PBD] = shape.get_comp(i);
+            DB += V0;
+            if (player == 0) {
+                BL -= V1;
+                PBD += V2-V0;
+            } else {
+                BL -= V2;
+                PBD += V1-V0;
+            }
         }
     } else {
         // The server does this
+
         RDPFPair dp = shape.tio.rdpfpair(shape.addr_size);
         RegAS p0indoffset, p1indoffset;
-        RegAS p0Moffset[2];
-        RegAS p1Moffset[2];
+        std::tuple<RegAS,RegAS> p0Moffset, p1Moffset;
+
+        // Receive the index and message offsets from the computational
+        // players and combine them
         shape.tio.recv_p0(&p0indoffset, BITBYTES(shape.addr_size));
-        shape.tio.iostream_p0() >> p0Moffset[0] >> p0Moffset[1];
+        shape.tio.iostream_p0() >> p0Moffset;
         shape.tio.recv_p1(&p1indoffset, BITBYTES(shape.addr_size));
-        shape.tio.iostream_p1() >> p1Moffset[0] >> p1Moffset[1];
-        p0indoffset += p1indoffset;
-        p0Moffset[0] += p1Moffset[0];
-        p0Moffset[1] += p1Moffset[1];
+        shape.tio.iostream_p1() >> p1Moffset;
+        auto indshift = combine(p0indoffset, p1indoffset);
+        auto Mshift = combine(p0Moffset, p1Moffset);
+
+        // Evaluate the DPFs and subtract them from the blinds
+        StreamEval ev(dp, -indshift, shape.tio.aes_ops());
+        for (size_t i=0; i<shape.shape_size; ++i) {
+            auto L = ev.next();
+            // The values from the two DPFs
+            auto V = dp.scaled_as(L) + dp.unit_as(L) * Mshift;
+            // shape.get_server(i) returns a pair of references to the
+            // appropriate cells in the two blinded databases, so we can
+            // subtract the pair directly.
+            shape.get_server(i) -= V;
+        }
     }
     return *this;
 }

+ 8 - 3
online.cpp

@@ -418,24 +418,29 @@ static void duoram_test(MPCIO &mpcio, yield_t &yield,
     const PRACOptions &opts, char **args)
 {
     nbits_t depth=6;
+    address_t share=arc4random();
 
     if (*args) {
         depth = atoi(*args);
         ++args;
     }
+    if (*args) {
+        share = atoi(*args);
+        ++args;
+    }
+    share &= ((address_t(1)<<depth)-1);
 
     int num_threads = opts.num_threads;
     boost::asio::thread_pool pool(num_threads);
     for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
-        boost::asio::post(pool, [&mpcio, &yield, thread_num, depth] {
+        boost::asio::post(pool, [&mpcio, &yield, thread_num, depth, share] {
             size_t size = size_t(1)<<depth;
             MPCTIO tio(mpcio, thread_num);
             // size_t &op_counter = tio.aes_ops();
             Duoram<RegAS> oram(mpcio.player, size);
-            printf("%ld\n", oram.size());
             auto A = oram.flat(tio, yield);
             RegAS aidx;
-            aidx.randomize(depth);
+            aidx.ashare = share;
             RegAS M;
             if (tio.player() == 0) {
                 M.ashare = 0xbabb0000;

+ 0 - 2
preproc.cpp

@@ -224,10 +224,8 @@ void preprocessing_server(MPCServerIO &mpcsrvio, const PRACOptions &opts, char *
                                 printf("sxr1 = %016lx\n", rdpfpair.dpf[1].scaled_xor.xshare);
                                 printf("dep1 = %d\n", rdpfpair.dpf[1].depth());
                                     if (opts.expand_rdpfs) {
-                                        printf("Expanding\n");
                                         rdpfpair.dpf[0].expand(stio.aes_ops());
                                         rdpfpair.dpf[1].expand(stio.aes_ops());
-                                        printf("Expanded\n");
                                     }
                                     pairfile.os() << rdpfpair;
                                 });

+ 1 - 1
rdpf.cpp

@@ -378,7 +378,7 @@ RDPFTriple::RDPFTriple(MPCTIO &tio, yield_t &yield,
         coroutines.emplace_back(
             [&, i](yield_t &yield) {
                 dpf[i] = RDPF(tio, yield, xs_target, depth,
-                save_expansion);
+                    save_expansion);
             });
     }
     coroutines.emplace_back(

+ 1 - 0
rdpf.tcc

@@ -17,6 +17,7 @@ StreamEval<T>::StreamEval(const T &rdpf, address_t start,
     } else {
         indexmask = ~0;
     }
+    start &= indexmask;
     // Record that we haven't actually output the leaf for index start
     // itself yet
     nextindex = start;

+ 34 - 6
types.hpp

@@ -467,16 +467,44 @@ struct RDPFTripleName { static constexpr const char *name = "r"; };
 
 // Default I/O for various types
 
-// Otherwise the comma is treated as an argument separator
-#define COMMA ,
 DEFAULT_IO(RegBS)
 DEFAULT_IO(RegAS)
-DEFAULT_IO(std::tuple<RegAS COMMA RegAS>)
-DEFAULT_IO(std::tuple<RegAS COMMA RegAS COMMA RegAS>)
 DEFAULT_IO(RegXS)
-DEFAULT_IO(std::tuple<RegXS COMMA RegXS>)
-DEFAULT_IO(std::tuple<RegXS COMMA RegXS COMMA RegXS>)
 DEFAULT_IO(MultTriple)
 DEFAULT_IO(HalfTriple)
 
+// And for pairs and triples
+
+#define DEFAULT_TUPLE_IO(CLASSNAME)                                  \
+    template <typename T>                                            \
+    T& operator>>(T& is, std::tuple<CLASSNAME, CLASSNAME> &x)        \
+    {                                                                \
+        is >> std::get<0>(x) >> std::get<1>(x);                      \
+        return is;                                                   \
+    }                                                                \
+                                                                     \
+    template <typename T>                                            \
+    T& operator<<(T& os, const std::tuple<CLASSNAME, CLASSNAME> &x)  \
+    {                                                                \
+        os << std::get<0>(x) << std::get<1>(x);                      \
+        return os;                                                   \
+    }                                                                \
+                                                                     \
+    template <typename T>                                            \
+    T& operator>>(T& is, std::tuple<CLASSNAME, CLASSNAME, CLASSNAME> &x) \
+    {                                                                \
+        is >> std::get<0>(x) >> std::get<1>(x) >> std::get<2>(x);    \
+        return is;                                                   \
+    }                                                                \
+                                                                     \
+    template <typename T>                                            \
+    T& operator<<(T& os, const std::tuple<CLASSNAME, CLASSNAME, CLASSNAME> &x) \
+    {                                                                \
+        os << std::get<0>(x) << std::get<1>(x) << std::get<2>(x);    \
+        return os;                                                   \
+    }
+
+DEFAULT_TUPLE_IO(RegAS)
+DEFAULT_TUPLE_IO(RegXS)
+
 #endif