Browse Source

Generate and sign (Vanilla Onion Routing) consensuses every epoch

Ian Goldberg 4 years ago
parent
commit
ec464f3f0e
3 changed files with 118 additions and 17 deletions
  1. 87 12
      dirauth.py
  2. 28 5
      network.py
  3. 3 0
      relay.py

+ 87 - 12
dirauth.py

@@ -18,7 +18,7 @@ class RelayDescriptor:
         self.descdict = descdict
 
     def __str__(self, withsig = True):
-        res = "RelayDesc[\n"
+        res = "RelayDesc [\n"
         for k in ["epoch", "idkey", "onionkey", "addr", "bw", "flags",
                     "vrfkey", "sig"]:
             if k in self.descdict:
@@ -41,6 +41,52 @@ class RelayDescriptor:
         serialized = self.__str__(False)
         self.descdict["idkey"].verify(serialized.encode("ascii"), self.descdict["sig"])
 
+# A consensus is a dict containing:
+#  epoch: epoch id
+#  numrelays: total number of relays
+#  totbw: total bandwidth of relays
+#  relays: list of relay descriptors (Vanilla Onion Routing only)
+#  sigs: list of signatures from the dirauths
+class Consensus:
+    def __init__(self, epoch, relays):
+        relays = [ d for d in relays if d.descdict['epoch'] == epoch ]
+        self.consdict = dict()
+        self.consdict['epoch'] = epoch
+        self.consdict['numrelays'] = len(relays)
+        self.consdict['totbw'] = sum([ d.descdict['bw'] for d in relays ])
+        self.consdict['relays'] = relays
+
+    def __str__(self, withsigs = True):
+        res = "Consensus [\n"
+        for k in ["epoch", "numrelays", "totbw"]:
+            if k in self.consdict:
+                res += "  " + k + ": " + str(self.consdict[k]) + "\n"
+        for r in self.consdict['relays']:
+            res += str(r)
+        if withsigs and ('sigs' in self.consdict):
+            for s in self.consdict['sigs']:
+                res += "  sig: " + nacl.encoding.HexEncoder.encode(s).decode("ascii") + "\n"
+        res += "]\n"
+        return res
+
+    def sign(self, signingkey, index):
+        """Use the given signing key to sign the consensus, storing the
+        result in the sigs list at the given index."""
+        serialized = self.__str__(False)
+        signed = signingkey.sign(serialized.encode("ascii"))
+        if 'sigs' not in self.consdict:
+            self.consdict['sigs'] = []
+        if index >= len(self.consdict['sigs']):
+            self.consdict['sigs'].extend([None] * (index+1-len(self.consdict['sigs'])))
+        self.consdict['sigs'][index] = signed.signature
+
+    def verify(self, verifkeylist):
+        """Use the given list of verification keys to check the
+        signatures on the consensus."""
+        serialized = self.__str__(False)
+        for i, vk in enumerate(verifkeylist):
+            vk.verify(serialized.encode("ascii"), self.consdict['sigs'][i])
+
 class DirAuthNetMsg(network.NetMsg):
     """The subclass of NetMsg for messages to and from directory
     authorities."""
@@ -105,9 +151,12 @@ class DirAuth(network.Server):
     # We simulate the act of computing the consensus by keeping a
     # class-static dict that's accessible to all of the dirauths
     # This dict is indexed by epoch, and the value is itself a dict
-    # indexed by the stringified descriptor, with value of the number of
-    # dirauths that saw that descriptor.
+    # indexed by the stringified descriptor, with value a pair of (the
+    # number of dirauths that saw that descriptor, the descriptor
+    # itself).
     uploadeddescs = dict()
+    consensus = None
+    endive = None
 
     def __init__(self, me, tot):
         """Create a new directory authority. me is the index of which
@@ -116,9 +165,12 @@ class DirAuth(network.Server):
         self.me = me
         self.tot = tot
         self.name = "Dirauth %d of %d" % (me+1, tot)
-        self.consensus = None
-        self.endive = None
-        network.thenetwork.wantepochticks(self, True)
+
+        # Create the dirauth signature keypair
+        self.sigkey = nacl.signing.SigningKey.generate()
+
+        network.thenetwork.setdirauthkey(me, self.sigkey.verify_key)
+        network.thenetwork.wantepochticks(self, True, True)
 
     def connected(self, client):
         """Callback invoked when a client connects to us. This callback
@@ -130,8 +182,29 @@ class DirAuth(network.Server):
         # particularly simple.
         return DirAuthConnection(self)
 
-    def newepoch(self, epoch):
-        print('New epoch', epoch, 'for', self)
+    def generate_consensus(self, epoch):
+        """Generate the consensus (and ENDIVE, if using Walking Onions)
+        for the given epoch, which should be the one after the one
+        that's currently about to end."""
+        threshold = int(self.tot/2)+1
+        consensusdescs = []
+        for numseen, desc in DirAuth.uploadeddescs[epoch].values():
+            if numseen >= threshold:
+                consensusdescs.append(desc)
+        DirAuth.consensus = Consensus(epoch, consensusdescs)
+
+
+    def epoch_ending(self, epoch):
+        # Only dirauth 0 actually needs to generate the consensus
+        # because of the shared class-static state, but everyone has to
+        # sign it.  Note that this code relies on dirauth 0's
+        # epoch_ending callback being called before any of the other
+        # dirauths'.
+        if self.me == 0:
+            self.generate_consensus(epoch+1)
+            del DirAuth.uploadeddescs[epoch+1]
+        DirAuth.consensus.sign(self.sigkey, self.me)
+        print(DirAuth.consensus)
 
     def received(self, client, msg):
         if isinstance(msg, DirAuthUploadDescMsg):
@@ -144,13 +217,15 @@ class DirAuth(network.Server):
                 DirAuth.uploadeddescs[epoch] = dict()
             descstr = str(msg.desc)
             if descstr not in DirAuth.uploadeddescs[epoch]:
-                DirAuth.uploadeddescs[epoch][descstr] = 1
+                DirAuth.uploadeddescs[epoch][descstr] = (1, msg.desc)
             else:
-                DirAuth.uploadeddescs[epoch][descstr] += 1
+                DirAuth.uploadeddescs[epoch][descstr] = \
+                    (DirAuth.uploadeddescs[epoch][descstr][0]+1,
+                     DirAuth.uploadeddescs[epoch][descstr][1])
         elif isinstance(msg, DirAuthGetConsensusMsg):
-            client.sendmsg(DirAuthConsensusMsg(self.consensus))
+            client.sendmsg(DirAuthConsensusMsg(DirAuth.consensus))
         elif isinstance(msg, DirAuthGetENDIVEMsg):
-            client.sendmsg(DirAuthENDIVEMsg(self.endive))
+            client.sendmsg(DirAuthENDIVEMsg(DirAuth.endive))
         else:
             raise TypeError('Not a client-originating DirAuthNetMsg', msg)
 

+ 28 - 5
network.py

@@ -28,6 +28,8 @@ class Network:
         self.servers = dict()
         self.epoch = 1
         self.epochcallbacks = []
+        self.epochendingcallbacks = []
+        self.dirauthkeylist = []
 
     def printservers(self):
         """Print the list of NetAddrs bound to something."""
@@ -35,27 +37,48 @@ class Network:
         for a in self.servers.keys():
             print(a)
 
+    def setdirauthkey(self, index, vk):
+        """Set the public verification key for dirauth number index to
+        vk."""
+        if index >= len(self.dirauthkeylist):
+            self.dirauthkeylist.extend([None] * (index+1-len(self.dirauthkeylist)))
+        self.dirauthkeylist[index] = vk
+
+    def dirauthkeys(self):
+        """Return the list of dirauth public verification keys."""
+        return self.dirauthkeylist
+
     def getepoch(self):
         """Return the current epoch."""
         return self.epoch
 
     def nextepoch(self):
         """Increment the current epoch, and return it."""
+        for c in self.epochendingcallbacks:
+            c.epoch_ending(self.epoch)
         self.epoch += 1
         for c in self.epochcallbacks:
             c.newepoch(self.epoch)
         return self.epoch
 
-    def wantepochticks(self, callback, want):
+    def wantepochticks(self, callback, want, end=False):
         """Register or deregister an object from receiving epoch change
         callbacks.  If want is True, the callback object's newepoch()
         method will be called at each epoch change, with an argument of
         the new epoch.  If want if False, the callback object will be
-        deregistered."""
-        if want:
-            self.epochcallbacks.append(callback)
+        deregistered.  If end is True, the callback object's
+        epoch_ending() method will be called instead at the end of the
+        epoch, just _before_ the epoch number change."""
+        if end:
+            if want:
+                self.epochendingcallbacks.append(callback)
+            else:
+                self.epochendingcallbacks.remove(callback)
         else:
-            self.epochcallbacks.remove(callback)
+            if want:
+                self.epochcallbacks.append(callback)
+            else:
+                self.epochcallbacks.remove(callback)
 
     def bind(self, server):
         """Bind a server to a newly generated NetAddr, returning the

+ 3 - 0
relay.py

@@ -70,4 +70,7 @@ if __name__ == '__main__':
 
     # Tick the epoch
     network.thenetwork.nextepoch()
+
+    dirauth.DirAuth.consensus.verify(network.thenetwork.dirauthkeys())
+
     print('ticked; epoch=', network.thenetwork.getepoch())