Browse Source

Make the functions to verify relay descriptors and consensuses into static methods

These static methods check the exact type of their arguments.  We still
want to avoid someone passing us a "consensus" object that has the wrong
type, and a "verify" method that just returns true.  This way, we call
the definitely correct Consensus.verify method, passing it the claimed
consensus.  The method will check the type and verify the signature.
Ian Goldberg 4 years ago
parent
commit
e8b2e741c6
3 changed files with 18 additions and 17 deletions
  1. 2 2
      client.py
  2. 13 12
      dirauth.py
  3. 3 3
      relay.py

+ 2 - 2
client.py

@@ -52,7 +52,7 @@ class CellClient(relay.CellHandler):
         print("Client %s received msg %s from %s" % (self.myaddr, msg, peeraddr))
         if isinstance(msg, relay.RelayConsensusMsg):
             self.consensus = msg.consensus
-            dirauth.verify_consensus(self.consensus, network.thenetwork.dirauthkeys())
+            dirauth.Consensus.verify(self.consensus, network.thenetwork.dirauthkeys())
             self.consensus_cdf = self.consensus.bw_cdf()
         else:
             return super().received_msg(msg, peeraddr, peer)
@@ -138,7 +138,7 @@ if __name__ == '__main__':
     # Tick the epoch
     network.thenetwork.nextepoch()
 
-    dirauth.verify_consensus(dirauth.DirAuth.consensus, network.thenetwork.dirauthkeys())
+    dirauth.Consensus.verify(dirauth.DirAuth.consensus, network.thenetwork.dirauthkeys())
 
     print('ticked; epoch=', network.thenetwork.getepoch())
 

+ 13 - 12
dirauth.py

@@ -40,10 +40,11 @@ class RelayDescriptor:
         signed = signingkey.sign(serialized.encode("ascii"))
         self.descdict["sig"] = signed.signature
 
-
-def verify_relaydesc(desc):
-    serialized = desc.__str__(False)
-    desc.descdict["idkey"].verify(serialized.encode("ascii"), desc.descdict["sig"])
+    @staticmethod
+    def verify(desc):
+        assert(type(desc) is RelayDescriptor)
+        serialized = desc.__str__(False)
+        desc.descdict["idkey"].verify(serialized.encode("ascii"), desc.descdict["sig"])
 
 
 # A consensus is a dict containing:
@@ -108,14 +109,14 @@ class Consensus:
         idx = bisect.bisect_right(cdf, val)
         return self.consdict['relays'][idx-1]
 
-
-
-def verify_consensus(consensus, verifkeylist):
-    """Use the given list of verification keys to check the
-    signatures on the consensus."""
-    serialized = consensus.__str__(False)
-    for i, vk in enumerate(verifkeylist):
-        vk.verify(serialized.encode("ascii"), consensus.consdict['sigs'][i])
+    @staticmethod
+    def verify(consensus, verifkeylist):
+        """Use the given list of verification keys to check the
+        signatures on the consensus."""
+        assert(type(consensus) is Consensus)
+        serialized = consensus.__str__(False)
+        for i, vk in enumerate(verifkeylist):
+            vk.verify(serialized.encode("ascii"), consensus.consdict['sigs'][i])
 
 
 class DirAuthNetMsg(network.NetMsg):

+ 3 - 3
relay.py

@@ -169,7 +169,7 @@ class CellRelay(CellHandler):
         a = random.choice(self.dirauthaddrs)
         c = network.thenetwork.connect(self, a)
         self.consensus = c.getconsensus()
-        dirauth.verify_consensus(self.consensus, network.thenetwork.dirauthkeys())
+        dirauth.Consensus.verify(self.consensus, network.thenetwork.dirauthkeys())
         c.close()
 
     def received_msg(self, msg, peeraddr, peer):
@@ -267,7 +267,7 @@ class Relay(network.Server):
         descdict["flags"] = self.flags
         desc = dirauth.RelayDescriptor(descdict)
         desc.sign(self.idkey)
-        dirauth.verify_relaydesc(desc)
+        dirauth.RelayDescriptor.verify(desc)
 
         if upload:
             descmsg = dirauth.DirAuthUploadDescMsg(desc)
@@ -337,7 +337,7 @@ if __name__ == '__main__':
     # Tick the epoch
     network.thenetwork.nextepoch()
 
-    dirauth.verify_consensus(dirauth.DirAuth.consensus, network.thenetwork.dirauthkeys())
+    dirauth.Consensus.verify(dirauth.DirAuth.consensus, network.thenetwork.dirauthkeys())
 
     print('ticked; epoch=', network.thenetwork.getepoch())