Browse Source

Implement Merkle tree SNIP authentication

Ian Goldberg 4 years ago
parent
commit
585f33f72d
3 changed files with 48 additions and 13 deletions
  1. 1 1
      client.py
  2. 41 8
      dirauth.py
  3. 6 4
      relay.py

+ 1 - 1
client.py

@@ -426,7 +426,7 @@ if __name__ == '__main__':
             network.SNIPAuthMode.NONE)
     elif network_mode == network.WOMode.TELESCOPING:
         network.thenetwork.set_wo_style(network.WOMode.TELESCOPING,
-            network.SNIPAuthMode.THRESHSIG)
+            network.SNIPAuthMode.MERKLE)
     # TODO set single-pass
 
     # Start some dirauths

+ 41 - 8
dirauth.py

@@ -6,6 +6,8 @@ import math
 
 import nacl.encoding
 import nacl.signing
+import merklelib
+import hashlib
 import network
 
 # A relay descriptor is a dict containing:
@@ -84,7 +86,7 @@ class SNIP:
                                 network.SNIPAuthMode.THRESHSIG:
                             res += "  " + k + ": " + nacl.encoding.HexEncoder.encode(self.snipdict[k]).decode("ascii") + "\n"
                         else:
-                            raise NotImplementedError("Merkle auth not yet implemented")
+                            res += "  " + k + ": " + str(self.snipdict[k])
                 else:
                     res += "  " + k + ": " + str(self.snipdict[k]) + "\n"
         res += "]\n"
@@ -97,7 +99,7 @@ class SNIP:
             perfstats.sigs += 1
             self.snipdict["auth"] = signed.signature
         else:
-            raise NotImplementedError("Merkle auth not yet implemented")
+            raise ValueError("Merkle auth not valid for SNIP.auth")
 
     @staticmethod
     def verify(snip, consensus, verifykey, perfstats):
@@ -109,7 +111,9 @@ class SNIP:
             verifykey.verify(serialized.encode("ascii"),
                     snip.snipdict["auth"])
         else:
-            raise NotImplementedError("Merkle auth not yet implemented")
+            assert(merklelib.verify_leaf_inclusion(
+                    snip.__str__(False), snip.snipdict["auth"],
+                    merklelib.Hasher(), consensus.consdict["merkleroot"]))
 
 
 # A consensus is a dict containing:
@@ -139,6 +143,10 @@ class Consensus:
         if network.thenetwork.womode == network.WOMode.VANILLA:
             for r in self.consdict['relays']:
                 res += str(r)
+        if network.thenetwork.snipauthmode == network.SNIPAuthMode.MERKLE:
+            for k in ["merkleroot"]:
+                if k in self.consdict:
+                    res += "  " + k + ": " + str(self.consdict[k]) + "\n"
         if withsigs and ('sigs' in self.consdict):
             for s in self.consdict['sigs']:
                 res += "  sig: " + nacl.encoding.HexEncoder.encode(s).decode("ascii") + "\n"
@@ -255,17 +263,24 @@ class RelayPicker:
             # Remove the last item, which should be the sum of all the bws
             self.cdf.pop()
             print('cdf=', self.cdf)
-        elif network.thenetwork.snipauthmode == \
-                network.SNIPAuthMode.THRESHSIG:
+        else:
             # Note that clients will call this with endive = None
             if self.endive is not None:
                 self.cdf = [ s.snipdict['range'][0] \
                         for s in self.endive.enddict['snips'] ]
+                if network.thenetwork.snipauthmode == \
+                        network.SNIPAuthMode.MERKLE:
+                    # Construct the Merkle tree of SNIPs and check the
+                    # root matches the one in the consensus
+                    self.merkletree = merklelib.MerkleTree(
+                            [snip.__str__(False) \
+                                for snip in DirAuth.endive.enddict['snips']],
+                            merklelib.Hasher())
+                    assert(self.consensus.consdict["merkleroot"] == \
+                            self.merkletree.merkle_root)
             else:
                 self.cdf = None
             print('cdf=', self.cdf)
-        else:
-            raise NotImplementedError("Merkle auth not yet implemented")
 
     @staticmethod
     def get(consensus, endive = None):
@@ -292,7 +307,16 @@ class RelayPicker:
 
         # Find the rightmost entry less than or equal to idx
         i = bisect.bisect_right(self.cdf, idx)
-        return relays[i-1]
+        r = relays[i-1]
+        if network.thenetwork.snipauthmode == \
+                network.SNIPAuthMode.MERKLE:
+            # If we haven't yet computed the Merkle path for this SNIP,
+            # do it now, and store it in the SNIP so that the client
+            # will get it.
+            if "auth" not in r.snipdict:
+                r.snipdict["auth"] = \
+                        self.merkletree.get_proof(r.__str__(False))
+        return r
 
     def pick_weighted_relay(self):
         """Select a random relay with probability proportional to its bw
@@ -516,6 +540,15 @@ class DirAuth(network.Server):
                     network.SNIPAuthMode.THRESHSIG:
                 for s in DirAuth.endive.enddict['snips']:
                     s.auth(self.sigkey, self.perfstats)
+            elif network.thenetwork.snipauthmode == \
+                    network.SNIPAuthMode.MERKLE:
+                # Construct the Merkle tree of the SNIPs in the ENDIVE
+                # and put the root in the consensus
+                tree = merklelib.MerkleTree(
+                        [snip.__str__(False) \
+                                for snip in DirAuth.endive.enddict['snips']],
+                        merklelib.Hasher())
+                DirAuth.consensus.consdict["merkleroot"] = tree.merkle_root
         else:
             if network.thenetwork.snipauthmode == \
                     network.SNIPAuthMode.THRESHSIG:

+ 6 - 4
relay.py

@@ -869,7 +869,7 @@ if __name__ == '__main__':
     perfstats = network.PerfStats(network.EntType.NONE)
 
     network.thenetwork.set_wo_style(network.WOMode.TELESCOPING,
-            network.SNIPAuthMode.THRESHSIG)
+            network.SNIPAuthMode.MERKLE)
 
     # Start some dirauths
     numdirauths = 9
@@ -911,9 +911,11 @@ if __name__ == '__main__':
         relaypicker = dirauth.ENDIVE.verify(dirauth.DirAuth.endive,
                 dirauth.DirAuth.consensus,
                 network.thenetwork.dirauthkeys(), perfstats)
-        for s in dirauth.DirAuth.endive.enddict['snips']:
-            dirauth.SNIP.verify(s, dirauth.DirAuth.consensus,
-                    network.thenetwork.dirauthkeys()[0], perfstats)
+        if network.thenetwork.snipauthmode == \
+                network.SNIPAuthMode.THRESHSIG:
+            for s in dirauth.DirAuth.endive.enddict['snips']:
+                dirauth.SNIP.verify(s, dirauth.DirAuth.consensus,
+                        network.thenetwork.dirauthkeys()[0], perfstats)
 
     print('ticked; epoch=', network.thenetwork.getepoch())