Browse Source

put a lock around the memoizer stuff, removed segfault

tristangurtler 3 years ago
parent
commit
da7f1d6f7f
4 changed files with 124 additions and 11 deletions
  1. 1 1
      prsona/Makefile
  2. 16 1
      prsona/inc/server.hpp
  3. 105 6
      prsona/src/server.cpp
  4. 2 3
      prsona/src/serverEntity.cpp

+ 1 - 1
prsona/Makefile

@@ -28,7 +28,7 @@ MG_INC_PATH = $(MG_PATH)/include
 MG_OBJ_PATH = $(MG_PATH)/out/src
 
 CPP = g++
-CPPFLAGS = -std=c++14 -Wall -I$(PRSONA_INC_PATH) -I$(BGN_INC_PATH) -I$(666_INC_PATH) -I$(MG_INC_PATH) -O1
+CPPFLAGS = -std=c++14 -Wall -I$(PRSONA_INC_PATH) -I$(BGN_INC_PATH) -I$(666_INC_PATH) -I$(MG_INC_PATH) -O2
 CPPTESTFLAGS = -std=c++14 -Wall -I$(PRSONA_INC_PATH) -I$(BGN_INC_PATH) -I$(666_INC_PATH) -I$(MG_INC_PATH) -g
 LDFLAGS = -lgmp -lgmpxx -lssl -lcrypto -lpthread
 NETWORK_LDFLAGS = -ldl

+ 16 - 1
prsona/inc/server.hpp

@@ -2,6 +2,7 @@
 #define __PRSONA_SERVER_HPP
 
 #include <vector>
+#include <mutex>
 
 #include "BGN.hpp"
 #include "Curvepoint.hpp"
@@ -20,6 +21,16 @@ class PrsonaServer : public PrsonaBase {
             size_t numServers,
             const BGN& other_bgn);
 
+        PrsonaServer(
+            const PrsonaServer& other);
+        PrsonaServer(
+            PrsonaServer&& other);
+        PrsonaServer &operator=(
+            const PrsonaServer& other);
+        PrsonaServer &operator=(
+            PrsonaServer&& other);
+        ~PrsonaServer();
+
         // BASIC PUBLIC SYSTEM INFO GETTERS
         BGNPublicKey get_bgn_public_key() const;
         size_t get_num_clients() const;
@@ -356,9 +367,10 @@ class PrsonaServer : public PrsonaBase {
 
     private:
         // constants for servers
-        const size_t numServers;
+        size_t numServers;
 
         // Identical between all servers (but collaboratively constructed)
+        std::mutex *decryptMtx;
         BGN bgnSystem;
 
         // Private; different for each server
@@ -373,6 +385,9 @@ class PrsonaServer : public PrsonaBase {
         std::vector<EGCiphertext> currentUserEncryptedTallies;
         std::vector<std::vector<TwistBipoint>> voteMatrix;
 
+        void remove();
+        void copy(const PrsonaServer& other);
+
         /**
          * NOTE: voteMatrix structure:
          * Each element represents a vote by <rowID> applied to <colID>.

+ 105 - 6
prsona/src/server.cpp

@@ -17,6 +17,8 @@ PrsonaServer::PrsonaServer(
 : numServers(numServers)
 {
     currentSeed.set_random();
+
+    decryptMtx = new std::mutex();
 }
 
 // Used for all other servers, so they have the same BGN parameters
@@ -26,6 +28,98 @@ PrsonaServer::PrsonaServer(
 : numServers(numServers), bgnSystem(otherBgn)
 {
     currentSeed.set_random();
+
+    decryptMtx = new std::mutex();
+}
+
+PrsonaServer::PrsonaServer(
+    const PrsonaServer& other)
+{
+    copy(other);
+}
+
+PrsonaServer::PrsonaServer(
+    PrsonaServer&& other)
+{
+    numServers = std::move(other.numServers);
+    
+    decryptMtx = other.decryptMtx;
+    other.decryptMtx = NULL;
+
+    bgnSystem = std::move(other.bgnSystem);
+    currentSeed = std::move(other.currentSeed);
+    nextSeed = std::move(other.nextSeed);
+    currentGeneratorProof = std::move(other.currentGeneratorProof);
+    currentFreshGenerator = std::move(other.currentFreshGenerator);
+    previousVoteTallies = std::move(other.previousVoteTallies);
+    currentPseudonyms = std::move(other.currentPseudonyms);
+    currentUserEncryptedTallies = std::move(other.currentUserEncryptedTallies);
+    voteMatrix = std::move(other.voteMatrix);
+}
+
+PrsonaServer &PrsonaServer::operator=(
+    const PrsonaServer& other)
+{
+    if (&other != this)
+    {
+        remove();
+        copy(other);
+    }
+
+    return *this;
+}
+
+PrsonaServer &PrsonaServer::operator=(
+    PrsonaServer&& other)
+{
+    if (&other != this)
+    {
+        remove();
+
+        numServers = std::move(other.numServers);
+    
+        decryptMtx = other.decryptMtx;
+        other.decryptMtx = new std::mutex();
+
+        bgnSystem = std::move(other.bgnSystem);
+        currentSeed = std::move(other.currentSeed);
+        nextSeed = std::move(other.nextSeed);
+        currentGeneratorProof = std::move(other.currentGeneratorProof);
+        currentFreshGenerator = std::move(other.currentFreshGenerator);
+        previousVoteTallies = std::move(other.previousVoteTallies);
+        currentPseudonyms = std::move(other.currentPseudonyms);
+        currentUserEncryptedTallies = std::move(other.currentUserEncryptedTallies);
+        voteMatrix = std::move(other.voteMatrix);
+    }
+
+    return *this;
+}
+
+PrsonaServer::~PrsonaServer()
+{
+    remove();
+}
+
+void PrsonaServer::copy(
+    const PrsonaServer& other)
+{
+    numServers = other.numServers;
+    decryptMtx = new std::mutex();
+    bgnSystem = other.bgnSystem;
+    currentSeed = other.currentSeed;
+    nextSeed = other.nextSeed;
+    currentGeneratorProof = other.currentGeneratorProof;
+    currentFreshGenerator = other.currentFreshGenerator;
+    previousVoteTallies = other.previousVoteTallies;
+    currentPseudonyms = other.currentPseudonyms;
+    currentUserEncryptedTallies = other.currentUserEncryptedTallies;
+    voteMatrix = other.voteMatrix;
+}
+
+void PrsonaServer::remove()
+{
+    delete decryptMtx;
+    decryptMtx = NULL;
 }
 
 /*
@@ -407,14 +501,16 @@ void homomorphic_addition_r(
 void tally_r(
     void *a,
     void *b,
-    const void *c,
+    void *c,
     const void *d,
+    const void *e,
     size_t i)
 {
-    BGN *bgnSystem = (BGN *) a;
-    Scalar *dst = (Scalar *) b;
-    const std::vector<CurveBipoint> *previousVoteTallies = (const std::vector<CurveBipoint> *) c;
-    const std::vector<std::vector<TwistBipoint>> *voteMatrix = (const std::vector<std::vector<TwistBipoint>> *) d;
+    std::mutex *decryptMtx = (std::mutex *) a;
+    BGN *bgnSystem = (BGN *) b;
+    Scalar *dst = (Scalar *) c;
+    const std::vector<CurveBipoint> *previousVoteTallies = (const std::vector<CurveBipoint> *) d;
+    const std::vector<std::vector<TwistBipoint>> *voteMatrix = (const std::vector<std::vector<TwistBipoint>> *) e;
 
     Quadripoint *weightedVotes = new Quadripoint[previousVoteTallies->size()];
     std::vector<std::thread> parallelizedMults;
@@ -429,6 +525,7 @@ void tally_r(
     homomorphic_addition_r(bgnSystem, weightedVotes, previousVoteTallies->size());
 
     // DECRYPT
+    std::unique_lock<std::mutex> lck(*decryptMtx);
     *dst = bgnSystem->decrypt(weightedVotes[0]);
 
     delete [] weightedVotes;
@@ -446,7 +543,7 @@ std::vector<Scalar> PrsonaServer::tally_scores()
     std::vector<std::thread> parallelizedTallies;
 
     for (size_t i = 0; i < voteMatrix.size(); i++)
-        parallelizedTallies.push_back(std::thread(tally_r, &bgnSystem, decryptedTallies + i, &previousVoteTallies, &voteMatrix, i));
+        parallelizedTallies.push_back(std::thread(tally_r, decryptMtx, &bgnSystem, decryptedTallies + i, &previousVoteTallies, &voteMatrix, i));
     for (size_t i = 0; i < parallelizedTallies.size(); i++)
         parallelizedTallies[i].join();
 
@@ -470,6 +567,7 @@ Scalar PrsonaServer::get_max_possible_score()
         currEncryptedVal = bgnSystem.homomorphic_addition_no_rerandomize(currEncryptedVal, previousVoteTallies[i]);
 
     // DECRYPT
+    std::unique_lock<std::mutex> lck(*decryptMtx);
     Scalar retval = bgnSystem.decrypt(currEncryptedVal);
 
     return retval;
@@ -1826,6 +1924,7 @@ bool PrsonaServer::verify_vote_proof(
 void PrsonaServer::print_scores(
     const std::vector<CurveBipoint>& scores)
 {
+    std::unique_lock<std::mutex> lck(*decryptMtx);
     std::cout << "[";
     for (size_t i = 0; i < scores.size(); i++)
         std::cout << bgnSystem.decrypt(scores[i]) << (i == scores.size() - 1 ? "]" : " ");

+ 2 - 3
prsona/src/serverEntity.cpp

@@ -17,11 +17,10 @@ PrsonaServerEntity::PrsonaServerEntity(
         std::cerr << "You have to have at least 1 server. I'm making it anyways." << std::endl;
 
     // Make the first server, which makes the BGN parameters
-    PrsonaServer firstServer(numServers);
-    servers.push_back(firstServer);
+    servers.push_back(PrsonaServer(numServers));
 
     // Make the rest of the servers, which take the BGN parameters
-    const BGN& sharedBGN = firstServer.get_bgn_details();
+    const BGN& sharedBGN = servers[0].get_bgn_details();
     for (size_t i = 1; i < numServers; i++)
         servers.push_back(PrsonaServer(numServers, sharedBGN));