Browse Source

Implement SNIPs and ENDIVEs (threshold signature mode only so far)

The threshold signatures are simulated by just a single signature from
the first dirauth.  All dirauths' perfstats are charged for the
signatures, however.
Ian Goldberg 4 years ago
parent
commit
162e75adce
2 changed files with 201 additions and 5 deletions
  1. 188 4
      dirauth.py
  2. 13 1
      relay.py

+ 188 - 4
dirauth.py

@@ -50,10 +50,72 @@ class RelayDescriptor:
         desc.descdict["idkey"].verify(serialized.encode("ascii"), desc.descdict["sig"])
 
 
+# A SNIP is a dict containing:
+#  epoch: epoch id
+#  idkey: a public identity key
+#  onionkey: a public onion key
+#  addr: a network address
+#  flags: relay flags
+#  vrfkey: a VRF public key (Single-Pass Walking Onions only)
+#  range: the (lo,hi) values for the index range (lo is inclusive, hi is
+#         exclusive; that is, x is in the range if lo <= x < hi).
+#         lo=hi denotes an empty range.
+#  auth: either a signature from the authorities over the above
+#         (Threshold signature case) or a Merkle path to the root
+#         contained in the consensus (Merkle tree case)
+#
+# Note that the fields of the SNIP are the same as those of the
+# RelayDescriptor, except bw and sig are removed, and range and auth are
+# added.
+class SNIP:
+    def __init__(self, snipdict):
+        self.snipdict = snipdict
+
+    def __str__(self, withauth = True):
+        res = "SNIP [\n"
+        for k in ["epoch", "idkey", "onionkey", "addr", "flags",
+                    "vrfkey", "range", "auth"]:
+            if k in self.snipdict:
+                if k == "idkey" or k == "onionkey":
+                    res += "  " + k + ": " + self.snipdict[k].encode(encoder=nacl.encoding.HexEncoder).decode("ascii") + "\n"
+                elif k == "auth":
+                    if withauth:
+                        if network.thenetwork.snipauthmode == \
+                                network.SNIPAuthMode.THRESHSIG:
+                            res += "  " + k + ": " + nacl.encoding.HexEncoder.encode(self.snipdict[k]).decode("ascii") + "\n"
+                        else:
+                            raise NotImplementedError("Merkle auth not yet implemented")
+                else:
+                    res += "  " + k + ": " + str(self.snipdict[k]) + "\n"
+        res += "]\n"
+        return res
+
+    def auth(self, signingkey, perfstats):
+        if network.thenetwork.snipauthmode == network.SNIPAuthMode.THRESHSIG:
+            serialized = self.__str__(False)
+            signed = signingkey.sign(serialized.encode("ascii"))
+            perfstats.sigs += 1
+            self.snipdict["auth"] = signed.signature
+        else:
+            raise NotImplementedError("Merkle auth not yet implemented")
+
+    @staticmethod
+    def verify(snip, consensus, verifykey, perfstats):
+        if network.thenetwork.snipauthmode == network.SNIPAuthMode.THRESHSIG:
+            assert(type(snip) is SNIP and type(consensus) is Consensus)
+            serialized = snip.__str__(False)
+            perfstats.verifs += 1
+            verifykey.verify(serialized.encode("ascii"),
+                    snip.snipdict["auth"])
+        else:
+            raise NotImplementedError("Merkle auth not yet implemented")
+
+
 # A consensus is a dict containing:
 #  epoch: epoch id
 #  numrelays: total number of relays
 #  totbw: total bandwidth of relays
+#  merkleroot: the root of the SNIP Merkle tree (Merkle tree auth only)
 #  relays: list of relay descriptors (Vanilla Onion Routing only)
 #  sigs: list of signatures from the dirauths
 class Consensus:
@@ -62,16 +124,20 @@ class Consensus:
         self.consdict = dict()
         self.consdict['epoch'] = epoch
         self.consdict['numrelays'] = len(relays)
-        self.consdict['totbw'] = sum([ d.descdict['bw'] for d in relays ])
-        self.consdict['relays'] = relays
+        if network.thenetwork.womode == network.WOMode.VANILLA:
+            self.consdict['totbw'] = sum([ d.descdict['bw'] for d in relays ])
+            self.consdict['relays'] = relays
+        else:
+            self.consdict['totbw'] = 1<<32
 
     def __str__(self, withsigs = True):
         res = "Consensus [\n"
         for k in ["epoch", "numrelays", "totbw"]:
             if k in self.consdict:
                 res += "  " + k + ": " + str(self.consdict[k]) + "\n"
-        for r in self.consdict['relays']:
-            res += str(r)
+        if network.thenetwork.womode == network.WOMode.VANILLA:
+            for r in self.consdict['relays']:
+                res += str(r)
         if withsigs and ('sigs' in self.consdict):
             for s in self.consdict['sigs']:
                 res += "  sig: " + nacl.encoding.HexEncoder.encode(s).decode("ascii") + "\n"
@@ -124,6 +190,70 @@ class Consensus:
             vk.verify(serialized.encode("ascii"), consensus.consdict['sigs'][i])
 
 
+# An ENDIVE is a dict containing:
+#  epoch: epoch id
+#  snips: list of SNIPS (in THRESHSIG mode, these include the auth
+#         signatures; in MERKLE mode, these do _not_ include auth)
+#  sigs: list of signatures from the dirauths
+class ENDIVE:
+    def __init__(self, epoch, snips):
+        snips = [ s for s in snips if s.snipdict['epoch'] == epoch ]
+        self.enddict = dict()
+        self.enddict['epoch'] = epoch
+        self.enddict['snips'] = snips
+
+    def __str__(self, withsigs = True):
+        res = "ENDIVE [\n"
+        for k in ["epoch"]:
+            if k in self.enddict:
+                res += "  " + k + ": " + str(self.enddict[k]) + "\n"
+        for s in self.enddict['snips']:
+            res += str(s)
+        if withsigs and ('sigs' in self.enddict):
+            for s in self.enddict['sigs']:
+                res += "  sig: " + nacl.encoding.HexEncoder.encode(s).decode("ascii") + "\n"
+        res += "]\n"
+        return res
+
+    def sign(self, signingkey, index, perfstats):
+        """Use the given signing key to sign the ENDIVE, storing the
+        result in the sigs list at the given index."""
+        serialized = self.__str__(False)
+        signed = signingkey.sign(serialized.encode("ascii"))
+        perfstats.sigs += 1
+        if 'sigs' not in self.enddict:
+            self.enddict['sigs'] = []
+        if index >= len(self.enddict['sigs']):
+            self.enddict['sigs'].extend([None] * (index+1-len(self.enddict['sigs'])))
+        self.enddict['sigs'][index] = signed.signature
+
+    def bw_cdf(self):
+        """Create the array of cumulative bandwidth values from an ENDIVE.
+        The array (cdf) will have the same length as the number of relays
+        in the ENDIVE.  cdf[0] = 0, and cdf[i] = cdf[i-1] + relay[i-1].bw."""
+        cdf = [ s.snipdict['range'][0] for s in self.enddict['snips'] ]
+        print('ENDIVE cdf=', cdf)
+        return cdf
+
+    def select_snip_by_index(self, i, cdf):
+        """Use the cdf generated by bw_cdf to select the SNIP for which
+        i is in the index range.  Choose i with
+        random.randint(0, (1<<32)-1)."""
+        # Find the rightmost entry less than or equal to i
+        idx = bisect.bisect_right(cdf, i)
+        return self.enddict['snips'][idx-1]
+
+    @staticmethod
+    def verify(endive, verifkeylist, perfstats):
+        """Use the given list of verification keys to check the
+        signatures on the ENDIVE."""
+        assert(type(endive) is ENDIVE)
+        serialized = endive.__str__(False)
+        for i, vk in enumerate(verifkeylist):
+            perfstats.verifs += 1
+            vk.verify(serialized.encode("ascii"), endive.enddict['sigs'][i])
+
+
 class DirAuthNetMsg(network.NetMsg):
     """The subclass of NetMsg for messages to and from directory
     authorities."""
@@ -179,6 +309,30 @@ class DirAuthENDIVEMsg(DirAuthNetMsg):
     def __init__(self, endive):
         self.endive = endive
 
+class DirAuthGetENDIVEDiffMsg(DirAuthNetMsg):
+    """The subclass of DirAuthNetMsg for fetching the ENDIVE, if the
+    requestor already has the previous ENDIVE."""
+
+class DirAuthENDIVEDiffMsg(DirAuthNetMsg):
+    """The subclass of DirAuthNetMsg for returning the ENDIVE, if the
+    requestor already has the previous consensus.  We don't _actually_
+    produce the diff at this time; we just charge fewer bytes for this
+    message in Merkle mode.  In threshold signature mode, we would still
+    need to download at least the new signatures for every SNIP in the
+    ENDIVE, so for now, just assume there's no gain from ENDIVE diffs in
+    threshold signature mode."""
+
+    def __init__(self, endive):
+        self.endive = endive
+
+    def size(self):
+        if network.symbolic_byte_counters:
+            return super().size()
+        if network.thenetwork.snipauthmode == \
+                network.SNIPAuthMode.THRESHSIG:
+            return DirAuthENDIVEMsg(self.endive).size()
+        return math.ceil(DirAuthENDIVEMsg(self.endive).size() \
+                            * network.P_Delta)
 
 class DirAuthConnection(network.ClientConnection):
     """The subclass of Connection for connections to directory
@@ -269,6 +423,20 @@ class DirAuth(network.Server):
             if numseen >= threshold:
                 consensusdescs.append(desc)
         DirAuth.consensus = Consensus(epoch, consensusdescs)
+        if network.thenetwork.womode != network.WOMode.VANILLA:
+            totbw = sum([ d.descdict["bw"] for d in consensusdescs ])
+            hi = 0
+            cumbw = 0
+            snips = []
+            for d in consensusdescs:
+                cumbw += d.descdict["bw"]
+                lo = hi
+                hi = int((cumbw<<32)/totbw)
+                snipdict = dict(d.descdict)
+                del snipdict["bw"]
+                snipdict["range"] = (lo,hi)
+                snips.append(SNIP(snipdict))
+            DirAuth.endive = ENDIVE(epoch, snips)
 
     def epoch_ending(self, epoch):
         # Only dirauth 0 actually needs to generate the consensus
@@ -281,7 +449,23 @@ class DirAuth(network.Server):
         if self.me == 0:
             self.generate_consensus(epoch+1)
             del DirAuth.uploadeddescs[epoch+1]
+            if network.thenetwork.snipauthmode == \
+                    network.SNIPAuthMode.THRESHSIG:
+                for s in DirAuth.endive.enddict['snips']:
+                    s.auth(self.sigkey, self.perfstats)
+        else:
+            if network.thenetwork.snipauthmode == \
+                    network.SNIPAuthMode.THRESHSIG:
+                for s in DirAuth.endive.enddict['snips']:
+                    # We're just simulating threshold sigs by having
+                    # only the first dirauth sign, but in reality each
+                    # dirauth would contribute to the signature (at the
+                    # same cost as each one signing), so we'll charge
+                    # their perfstats as well
+                    self.perfstats.sigs += 1
         DirAuth.consensus.sign(self.sigkey, self.me, self.perfstats)
+        if network.thenetwork.womode != network.WOMode.VANILLA:
+            DirAuth.endive.sign(self.sigkey, self.me, self.perfstats)
 
     def received(self, client, msg):
         self.perfstats.bytes_received += msg.size()

+ 13 - 1
relay.py

@@ -720,6 +720,9 @@ class Relay(network.Server):
 if __name__ == '__main__':
     perfstats = network.PerfStats(network.EntType.NONE)
 
+    network.thenetwork.set_wo_style(network.WOMode.TELESCOPING,
+            network.SNIPAuthMode.THRESHSIG)
+
     # Start some dirauths
     numdirauths = 9
     dirauthaddrs = []
@@ -750,12 +753,20 @@ if __name__ == '__main__':
     # Tick the epoch
     network.thenetwork.nextepoch()
 
+    print(dirauth.DirAuth.consensus)
+    print(dirauth.DirAuth.endive)
+
     dirauth.Consensus.verify(dirauth.DirAuth.consensus,
             network.thenetwork.dirauthkeys(), perfstats)
+    dirauth.ENDIVE.verify(dirauth.DirAuth.endive,
+            network.thenetwork.dirauthkeys(), perfstats)
+    for s in dirauth.DirAuth.endive.enddict['snips']:
+        dirauth.SNIP.verify(s, dirauth.DirAuth.consensus,
+                network.thenetwork.dirauthkeys()[0], perfstats)
 
     print('ticked; epoch=', network.thenetwork.getepoch())
 
-    relays[3].channelmgr.send_msg(RelayRandomHopMsg(30), relays[5].netaddr)
+    # relays[3].channelmgr.send_msg(RelayRandomHopMsg(30), relays[5].netaddr)
 
     # See what channels exist and do a consistency check
     for r in relays:
@@ -780,6 +791,7 @@ if __name__ == '__main__':
     network.thenetwork.nextepoch()
 
     print(dirauth.DirAuth.consensus)
+    print(dirauth.DirAuth.endive)
 
     # See what channels exist and do a consistency check
     for r in relays: