#!/usr/bin/env python3

import bisect
import hashlib
import logging
import math
import random  # For simulation, not cryptography!

import merklelib
import nacl.encoding
import nacl.signing

import network

# A relay descriptor is a dict containing:
#  epoch: epoch id
#  idkey: a public identity key
#  onionkey: a public onion key
#  addr: a network address
#  bw: bandwidth
#  flags: relay flags
#  pathselkey: a path selection public key (Single-Pass Walking Onions only)
#  vrfkey: a VRF public key (Single-Pass Walking Onions only)
#  sig: a signature over the above by the idkey
class RelayDescriptor:
    """Wrapper around a relay descriptor dict (fields listed above) that
    knows how to serialize, sign, and verify itself."""

    def __init__(self, descdict):
        self.descdict = descdict

    def __str__(self, withsig = True):
        """Serialize the descriptor as a "RelayDesc [ ... ]" block with one
        "name: value" line per present field, in a fixed field order.

        idkey, onionkey, pathselkey and sig are hex-encoded; every other
        field (including vrfkey) is rendered with str().  Passing
        withsig=False omits the sig line; that is the form that gets
        signed and verified."""
        hexfields = ("idkey", "onionkey", "pathselkey")
        pieces = ["RelayDesc [\n"]
        for field in ("epoch", "idkey", "onionkey", "pathselkey", "addr",
                "bw", "flags", "vrfkey", "sig"):
            if field not in self.descdict:
                continue
            value = self.descdict[field]
            if field == "sig":
                if not withsig:
                    continue
                rendered = nacl.encoding.HexEncoder.encode(value).decode("ascii")
            elif field in hexfields:
                rendered = nacl.encoding.HexEncoder.encode(value).decode("ascii")
            else:
                rendered = str(value)
            pieces.append(" " + field + ": " + rendered + "\n")
        pieces.append("]\n")
        return "".join(pieces)

    def sign(self, signingkey, perfstats):
        """Sign the signature-less serialization with signingkey, count
        one signature in perfstats, and store the result under "sig"."""
        tobesigned = self.__str__(False).encode("ascii")
        signedmsg = signingkey.sign(tobesigned)
        perfstats.sigs += 1
        self.descdict["sig"] = signedmsg.signature

    @staticmethod
    def verify(desc, perfstats):
        """Check desc's "sig" field against its own "idkey", counting one
        verification in perfstats.  nacl's VerifyKey.verify raises if the
        signature does not check out."""
        assert type(desc) is RelayDescriptor
        unsigned = desc.__str__(False)
        perfstats.verifs += 1
        verifier = nacl.signing.VerifyKey(desc.descdict["idkey"])
        verifier.verify(unsigned.encode("ascii"), desc.descdict["sig"])

# A SNIP is a dict containing:
#  epoch: epoch id
#  idkey: a public identity key
#  onionkey: a public onion key
#  addr: a network address
#  flags: relay flags
#  pathselkey: a path selection public key (Single-Pass Walking Onions only)
#  range: the (lo,hi) values for the index range (lo is inclusive, hi is
#    exclusive; that is, x is in the range if lo <= x < hi).
#    lo=hi denotes an empty range.
#  auth: either a signature from the authorities over the above
#    (Threshold signature case) or a Merkle path to the root
#    contained in the consensus (Merkle tree case)
#
# Note that the fields of the SNIP are the same as those of the
# RelayDescriptor, except bw and sig are removed, and range and auth are
# added.
class SNIP:
    """Wrapper around a SNIP dict (fields listed above) that knows how to
    serialize, authenticate, and verify itself."""

    def __init__(self, snipdict):
        self.snipdict = snipdict

    def __str__(self, withauth = True):
        """Serialize the SNIP as a "SNIP [ ... ]" block with one
        "name: value" line per present field, in a fixed field order.

        idkey, onionkey and pathselkey are hex-encoded.  The auth field is
        hex-encoded in THRESHSIG mode and str()-ified (a list of Merkle
        path entries) otherwise.  Passing withauth=False omits auth; that
        is the form that is signed or hashed into the Merkle tree."""
        res = "SNIP [\n"
        for k in ["epoch", "idkey", "onionkey", "pathselkey", "addr",
                "flags", "range", "auth"]:
            if k in self.snipdict:
                if k == "idkey" or k == "onionkey" or k == "pathselkey":
                    res += " " + k + ": " + nacl.encoding.HexEncoder.encode(self.snipdict[k]).decode("ascii") + "\n"
                elif k == "auth":
                    if withauth:
                        if network.thenetwork.snipauthmode == \
                                network.SNIPAuthMode.THRESHSIG:
                            res += " " + k + ": " + nacl.encoding.HexEncoder.encode(self.snipdict[k]).decode("ascii") + "\n"
                        else:
                            # Terminate the line like every other field,
                            # so the closing "]" is not glued onto it.
                            res += " " + k + ": " + str(self.snipdict[k]) + "\n"
                else:
                    res += " " + k + ": " + str(self.snipdict[k]) + "\n"
        res += "]\n"
        return res

    def auth(self, signingkey, perfstats):
        """Sign the auth-less serialization with signingkey, count one
        signature in perfstats, and store it under "auth".

        Only valid in THRESHSIG mode; raises ValueError otherwise (in
        Merkle mode the auth path is filled in elsewhere)."""
        if network.thenetwork.snipauthmode == network.SNIPAuthMode.THRESHSIG:
            serialized = self.__str__(False)
            signed = signingkey.sign(serialized.encode("ascii"))
            perfstats.sigs += 1
            self.snipdict["auth"] = signed.signature
        else:
            raise ValueError("Merkle auth not valid for SNIP.auth")

    @staticmethod
    def verify(snip, consensus, verifykey, perfstats):
        """Check snip's "auth" field against the consensus.

        THRESHSIG mode: verify the signature with verifykey (nacl raises
        if it is bad) and count one verification in perfstats.
        Otherwise: check the Merkle path in "auth" leads to the
        consensus's merkleroot, raising AssertionError if it does not."""
        if network.thenetwork.snipauthmode == network.SNIPAuthMode.THRESHSIG:
            assert(type(snip) is SNIP and type(consensus) is Consensus)
            assert(consensus.consdict["epoch"] == snip.snipdict["epoch"])
            serialized = snip.__str__(False)
            perfstats.verifs += 1
            verifykey.verify(serialized.encode("ascii"),
                snip.snipdict["auth"])
        else:
            # Do the check outside an assert statement so it still runs
            # under "python -O"; keep raising AssertionError as before.
            path = [merklelib.AuditNode(p[0], p[1])
                for p in snip.snipdict["auth"]]
            included = merklelib.verify_leaf_inclusion(
                snip.__str__(False), path, merklelib.Hasher(),
                consensus.consdict["merkleroot"])
            if not included:
                raise AssertionError("SNIP Merkle path does not verify")

# A consensus is a dict containing:
#  epoch: epoch id
#  numrelays: total number of relays
#  totbw: total bandwidth of relays
#  merkleroot: the root of the SNIP Merkle tree (Merkle tree auth only)
#  relays: list of relay descriptors (Vanilla Onion Routing only)
#  sigs: list of signatures from the dirauths
class Consensus:
    """The per-epoch consensus document, stored in self.consdict with the
    fields listed above."""

    def __init__(self, epoch, relays):
        """Build the consensus for epoch from the given RelayDescriptors;
        descriptors for any other epoch are dropped."""
        relays = [ d for d in relays if d.descdict['epoch'] == epoch ]
        self.consdict = dict()
        self.consdict['epoch'] = epoch
        self.consdict['numrelays'] = len(relays)
        if network.thenetwork.womode == network.WOMode.VANILLA:
            self.consdict['totbw'] = sum([ d.descdict['bw'] for d in relays ])
            self.consdict['relays'] = relays
        else:
            # Walking Onions: the SNIP index ranges built in
            # DirAuth.generate_consensus are scaled to cover [0, 2^32),
            # so that is the index space clients choose from.
            self.consdict['totbw'] = 1<<32

    def __str__(self, withsigs = True):
        """Serialize the consensus.  withsigs=False omits the dirauth
        signatures; that is the form that gets signed and verified."""
        res = "Consensus [\n"
        for k in ["epoch", "numrelays", "totbw"]:
            if k in self.consdict:
                res += " " + k + ": " + str(self.consdict[k]) + "\n"
        if network.thenetwork.womode == network.WOMode.VANILLA:
            for r in self.consdict['relays']:
                res += str(r)
        if network.thenetwork.snipauthmode == network.SNIPAuthMode.MERKLE:
            for k in ["merkleroot"]:
                if k in self.consdict:
                    res += " " + k + ": " + str(self.consdict[k]) + "\n"
        if withsigs and ('sigs' in self.consdict):
            for sig in self.consdict['sigs']:
                res += " sig: " + nacl.encoding.HexEncoder.encode(sig).decode("ascii") + "\n"
        res += "]\n"
        return res

    def sign(self, signingkey, index, perfstats):
        """Use the given signing key to sign the consensus, storing the
        result in the sigs list at the given index (the list is padded
        with None as needed)."""
        serialized = self.__str__(False)
        signed = signingkey.sign(serialized.encode("ascii"))
        perfstats.sigs += 1
        if 'sigs' not in self.consdict:
            self.consdict['sigs'] = []
        if index >= len(self.consdict['sigs']):
            self.consdict['sigs'].extend(
                [None] * (index+1-len(self.consdict['sigs'])))
        self.consdict['sigs'][index] = signed.signature

    @staticmethod
    def verify(consensus, verifkeylist, perfstats):
        """Use the given list of verification keys to check the
        signatures on the consensus.  Return the RelayPicker if
        successful, or raise an exception otherwise."""
        assert(type(consensus) is Consensus)
        serialized = consensus.__str__(False)
        for i, vk in enumerate(verifkeylist):
            perfstats.verifs += 1
            vk.verify(serialized.encode("ascii"), consensus.consdict['sigs'][i])
        # If we got this far, all is well.  Return the RelayPicker.
        return RelayPicker.get(consensus)

# An ENDIVE is a dict containing:
#  epoch: epoch id
#  snips: list of SNIPS (in THRESHSIG mode, these include the auth
#    signatures; in MERKLE mode, these do _not_ include auth)
#  sigs: list of signatures from the dirauths
class ENDIVE:
    """The per-epoch ENDIVE document, stored in self.enddict with the
    fields listed above."""

    def __init__(self, epoch, snips):
        """Build the ENDIVE for epoch from the given SNIPs; SNIPs for any
        other epoch are dropped."""
        snips = [ s for s in snips if s.snipdict['epoch'] == epoch ]
        self.enddict = dict()
        self.enddict['epoch'] = epoch
        self.enddict['snips'] = snips

    def __str__(self, withsigs = True):
        """Serialize the ENDIVE.  withsigs=False omits the dirauth
        signatures (but each SNIP is still serialized with its auth
        field, if present)."""
        res = "ENDIVE [\n"
        for k in ["epoch"]:
            if k in self.enddict:
                res += " " + k + ": " + str(self.enddict[k]) + "\n"
        for snip in self.enddict['snips']:
            res += str(snip)
        if withsigs and ('sigs' in self.enddict):
            for sig in self.enddict['sigs']:
                res += " sig: " + nacl.encoding.HexEncoder.encode(sig).decode("ascii") + "\n"
        res += "]\n"
        return res

    def sign(self, signingkey, index, perfstats):
        """Use the given signing key to sign the ENDIVE, storing the
        result in the sigs list at the given index (the list is padded
        with None as needed)."""
        serialized = self.__str__(False)
        signed = signingkey.sign(serialized.encode("ascii"))
        perfstats.sigs += 1
        if 'sigs' not in self.enddict:
            self.enddict['sigs'] = []
        if index >= len(self.enddict['sigs']):
            self.enddict['sigs'].extend(
                [None] * (index+1-len(self.enddict['sigs'])))
        self.enddict['sigs'][index] = signed.signature

    @staticmethod
    def verify(endive, consensus, verifkeylist, perfstats):
        """Use the given list of verification keys to check the
        signatures on the ENDIVE and consensus.  Return the RelayPicker
        if successful, or raise an exception otherwise."""
        assert(type(endive) is ENDIVE and type(consensus) is Consensus)
        serializedcons = consensus.__str__(False)
        for i, vk in enumerate(verifkeylist):
            perfstats.verifs += 1
            vk.verify(serializedcons.encode("ascii"), consensus.consdict['sigs'][i])
        serializedend = endive.__str__(False)
        for i, vk in enumerate(verifkeylist):
            perfstats.verifs += 1
            vk.verify(serializedend.encode("ascii"), endive.enddict['sigs'][i])
        # If we got this far, all is well.  Return the RelayPicker.
        return RelayPicker.get(consensus, endive)


class RelayPicker:
    """An instance of this class (which may be a singleton in the
    simulation) is returned by the Consensus.verify() and
    ENDIVE.verify() methods.  It does any necessary precomputation
    and/or caching, and exposes a method to select a random bw-weighted
    relay, either explicitly specifying a uniform random value, or
    letting the choice be done internally."""

    # The singleton instance
    relaypicker = None

    def __init__(self, consensus, endive = None):
        self.epoch = consensus.consdict["epoch"]
        self.totbw = consensus.consdict["totbw"]
        self.consensus = consensus
        self.endive = endive
        assert(endive is None or endive.enddict["epoch"] == self.epoch)

        if network.thenetwork.womode == network.WOMode.VANILLA:
            # Create the array of cumulative bandwidth values from a
            # consensus.  The array (cdf) will have the same length as
            # the number of relays in the consensus.  cdf[0] = 0, and
            # cdf[i] = cdf[i-1] + relay[i-1].bw.
            self.cdf = [0]
            for r in consensus.consdict['relays']:
                self.cdf.append(self.cdf[-1]+r.descdict['bw'])
            # Remove the last item, which should be the sum of all the bws
            self.cdf.pop()
            logging.debug('cdf=%s', self.cdf)
        else:
            # Note that clients will call this with endive = None
            if self.endive is not None:
                self.cdf = [ s.snipdict['range'][0]
                    for s in self.endive.enddict['snips'] ]
                if network.thenetwork.snipauthmode == \
                        network.SNIPAuthMode.MERKLE:
                    # Construct the Merkle tree of the SNIPs in the ENDIVE
                    # we were handed (not the dirauths' class-static copy),
                    # and check the root matches the one in the consensus.
                    # Explicit check so it still runs under "python -O".
                    self.merkletree = merklelib.MerkleTree(
                        [snip.__str__(False)
                            for snip in self.endive.enddict['snips']],
                        merklelib.Hasher())
                    if self.consensus.consdict["merkleroot"] != \
                            self.merkletree.merkle_root:
                        raise AssertionError(
                            "ENDIVE Merkle root does not match consensus")
            else:
                self.cdf = None
            logging.debug('cdf=%s', self.cdf)

    @staticmethod
    def get(consensus, endive = None):
        """Return the singleton RelayPicker for consensus's epoch,
        creating (and caching) a new one if needed."""
        # Return the singleton instance, if it exists for this epoch
        # However, don't use the cached instance if that one has
        # endive=None, but we were passed a real ENDIVE
        if RelayPicker.relaypicker is not None and \
                (RelayPicker.relaypicker.endive is not None or \
                    endive is None) and \
                RelayPicker.relaypicker.epoch == consensus.consdict["epoch"]:
            return RelayPicker.relaypicker

        # Create it otherwise, storing the result as the singleton
        RelayPicker.relaypicker = RelayPicker(consensus, endive)
        return RelayPicker.relaypicker

    def pick_relay_by_uniform_index(self, idx):
        """Pass in a uniform random index random(0,totbw-1) to get a
        relay's descriptor or snip (depending on the network mode)
        selected weighted by bw."""
        if network.thenetwork.womode == network.WOMode.VANILLA:
            relays = self.consensus.consdict['relays']
        else:
            relays = self.endive.enddict['snips']

        # Find the rightmost entry less than or equal to idx
        i = bisect.bisect_right(self.cdf, idx)
        r = relays[i-1]
        if network.thenetwork.snipauthmode == \
                network.SNIPAuthMode.MERKLE:
            # If we haven't yet computed the Merkle path for this SNIP,
            # do it now, and store it in the SNIP so that the client
            # will get it.
            # NOTE(review): r is the same object held in self.endive, and
            # ENDIVE.__str__ serializes each SNIP including its auth, so
            # this changes the ENDIVE's signed serialization after the
            # fact; a later ENDIVE.verify on that object would fail.
            if "auth" not in r.snipdict:
                r.snipdict["auth"] = [ (p.hash, p.type) for p in
                    self.merkletree.get_proof(r.__str__(False))._nodes]

        return r

    def pick_weighted_relay(self):
        """Select a random relay with probability proportional to its bw
        weight."""
        idx = self.pick_weighted_relay_index()
        return self.pick_relay_by_uniform_index(idx)

    def pick_weighted_relay_index(self):
        """Select a random relay index (for use in Walking Onions)
        uniformly, which will result in picking a relay with probability
        proportional to its bw weight.  Raises ValueError if totbw < 1."""
        totbw = self.totbw
        if totbw < 1:
            raise ValueError("No relays to choose from")
        return random.randint(0, totbw-1)


class DirAuthNetMsg(network.NetMsg):
    """The subclass of NetMsg for messages to and from directory
    authorities."""


class DirAuthUploadDescMsg(DirAuthNetMsg):
    """The subclass of DirAuthNetMsg for uploading a relay
    descriptor."""

    def __init__(self, desc):
        self.desc = desc


class DirAuthDelDescMsg(DirAuthNetMsg):
    """The subclass of DirAuthNetMsg for deleting a relay
    descriptor."""

    def __init__(self, desc):
        self.desc = desc


class DirAuthGetConsensusMsg(DirAuthNetMsg):
    """The subclass of DirAuthNetMsg for fetching the consensus."""


class DirAuthConsensusMsg(DirAuthNetMsg):
    """The subclass of DirAuthNetMsg for returning the consensus."""

    def __init__(self, consensus):
        self.consensus = consensus


class DirAuthGetConsensusDiffMsg(DirAuthNetMsg):
    """The subclass of DirAuthNetMsg for fetching the consensus, if the
    requestor already has the previous consensus."""


class DirAuthConsensusDiffMsg(DirAuthNetMsg):
    """The subclass of DirAuthNetMsg for returning the consensus, if the
    requestor already has the previous consensus.  We don't _actually_
    produce the diff at this time; we just charge fewer bytes for this
    message."""

    def __init__(self, consensus):
        self.consensus = consensus

    def size(self):
        # Charge the full consensus message size scaled by P_Delta,
        # unless byte counting is symbolic.
        if network.symbolic_byte_counters:
            return super().size()
        return math.ceil(DirAuthConsensusMsg(self.consensus).size() \
                * network.P_Delta)


class DirAuthGetENDIVEMsg(DirAuthNetMsg):
    """The subclass of DirAuthNetMsg for fetching the ENDIVE."""


class DirAuthENDIVEMsg(DirAuthNetMsg):
    """The subclass of DirAuthNetMsg for returning the ENDIVE."""

    def __init__(self, endive):
        self.endive = endive


class DirAuthGetENDIVEDiffMsg(DirAuthNetMsg):
    """The subclass of DirAuthNetMsg for fetching the ENDIVE, if the
    requestor already has the previous ENDIVE."""


class DirAuthENDIVEDiffMsg(DirAuthNetMsg):
    """The subclass of DirAuthNetMsg for returning the ENDIVE, if the
    requestor already has the previous ENDIVE.  We don't _actually_
    produce the diff at this time; we just charge fewer bytes for this
    message in Merkle mode.  In threshold signature mode, we would still
    need to download at least the new signatures for every SNIP in the
    ENDIVE, so for now, just assume there's no gain from ENDIVE diffs in
    threshold signature mode."""

    def __init__(self, endive):
        self.endive = endive

    def size(self):
        if network.symbolic_byte_counters:
            return super().size()
        if network.thenetwork.snipauthmode == \
                network.SNIPAuthMode.THRESHSIG:
            return DirAuthENDIVEMsg(self.endive).size()
        return math.ceil(DirAuthENDIVEMsg(self.endive).size() \
                * network.P_Delta)


class DirAuthConnection(network.ClientConnection):
    """The subclass of Connection for connections to directory
    authorities."""

    def __init__(self, peer):
        super().__init__(peer)

    def uploaddesc(self, desc):
        """Upload our RelayDescriptor to the DirAuth."""
        self.sendmsg(DirAuthUploadDescMsg(desc))

    def getconsensus(self):
        """Fetch the full consensus; returns whatever the reply set (None
        if no reply was received)."""
        self.consensus = None
        self.sendmsg(DirAuthGetConsensusMsg())
        return self.consensus

    def getconsensusdiff(self):
        """Fetch the consensus as a (simulated) diff; returns whatever
        the reply set (None if no reply was received)."""
        self.consensus = None
        self.sendmsg(DirAuthGetConsensusDiffMsg())
        return self.consensus

    def getendive(self):
        """Fetch the full ENDIVE; returns whatever the reply set (None if
        no reply was received)."""
        self.endive = None
        self.sendmsg(DirAuthGetENDIVEMsg())
        return self.endive

    def getendivediff(self):
        """Fetch the ENDIVE as a (simulated) diff; returns whatever the
        reply set (None if no reply was received)."""
        self.endive = None
        self.sendmsg(DirAuthGetENDIVEDiffMsg())
        return self.endive

    def receivedfromserver(self, msg):
        """Stash the consensus or ENDIVE carried by a server reply."""
        if isinstance(msg, (DirAuthConsensusMsg, DirAuthConsensusDiffMsg)):
            self.consensus = msg.consensus
        elif isinstance(msg, (DirAuthENDIVEMsg, DirAuthENDIVEDiffMsg)):
            self.endive = msg.endive
        else:
            raise TypeError('Not a server-originating DirAuthNetMsg', msg)


class DirAuth(network.Server):
    """The class representing directory authorities."""

    # We simulate the act of computing the consensus by keeping a
    # class-static dict that's accessible to all of the dirauths
    # This dict is indexed by epoch, and the value is itself a dict
    # indexed by the stringified descriptor, with value a pair of (the
    # number of dirauths that saw that descriptor, the descriptor
    # itself).
    uploadeddescs = dict()
    consensus = None
    endive = None

    def __init__(self, me, tot):
        """Create a new directory authority.  me is the index of which
        dirauth this one is (starting from 0), and tot is the total
        number of dirauths."""
        self.me = me
        self.tot = tot
        self.name = "Dirauth %d of %d" % (me+1, tot)
        self.perfstats = network.PerfStats(network.EntType.DIRAUTH)
        self.perfstats.is_bootstrapping = True

        # Create the dirauth signature keypair
        self.sigkey = nacl.signing.SigningKey.generate()
        self.perfstats.keygens += 1

        self.netaddr = network.thenetwork.bind(self)
        self.perfstats.name = "DirAuth at %s" % self.netaddr

        network.thenetwork.setdirauthkey(me, self.sigkey.verify_key)
        network.thenetwork.wantepochticks(self, True, True, True)

    def connected(self, client):
        """Callback invoked when a client connects to us.  This callback
        creates the DirAuthConnection that will be passed to the
        client."""

        # We don't actually need to keep per-connection state at
        # dirauths, even in long-lived connections, so this is
        # particularly simple.
        return DirAuthConnection(self)

    def generate_consensus(self, epoch):
        """Generate the consensus (and ENDIVE, if using Walking Onions)
        for the given epoch, which should be the one after the one that's
        currently about to end.  A descriptor is included only if a
        strict majority of dirauths saw it."""
        threshold = int(self.tot/2)+1
        consensusdescs = []
        for numseen, desc in DirAuth.uploadeddescs[epoch].values():
            if numseen >= threshold:
                consensusdescs.append(desc)

        DirAuth.consensus = Consensus(epoch, consensusdescs)

        if network.thenetwork.womode != network.WOMode.VANILLA:
            # Give each relay the slice of [0, 2^32) proportional to its
            # bandwidth.
            totbw = sum([ d.descdict["bw"] for d in consensusdescs ])
            hi = 0
            cumbw = 0
            snips = []
            for d in consensusdescs:
                cumbw += d.descdict["bw"]
                lo = hi
                # Exact integer division: float division loses precision
                # once cumbw<<32 exceeds 2^53.  If every relay reports
                # zero bandwidth, give them all empty ranges instead of
                # dividing by zero.
                hi = (cumbw<<32) // totbw if totbw > 0 else 0
                snipdict = dict(d.descdict)
                del snipdict["bw"]
                snipdict["range"] = (lo,hi)
                snips.append(SNIP(snipdict))

            DirAuth.endive = ENDIVE(epoch, snips)

    def epoch_ending(self, epoch):
        # Only dirauth 0 actually needs to generate the consensus
        # because of the shared class-static state, but everyone has to
        # sign it.  Note that this code relies on dirauth 0's
        # epoch_ending callback being called before any of the other
        # dirauths'.
        if self.me == 0:
            # Only dirauth 0 creates the (possibly empty) entry for the
            # upcoming epoch, so the other dirauths don't re-create it
            # after it has been consumed and deleted here.
            DirAuth.uploadeddescs.setdefault(epoch+1, dict())
            self.generate_consensus(epoch+1)
            del DirAuth.uploadeddescs[epoch+1]
            if network.thenetwork.snipauthmode == \
                    network.SNIPAuthMode.THRESHSIG:
                for s in DirAuth.endive.enddict['snips']:
                    s.auth(self.sigkey, self.perfstats)
            elif network.thenetwork.snipauthmode == \
                    network.SNIPAuthMode.MERKLE:
                # Construct the Merkle tree of the SNIPs in the ENDIVE
                # and put the root in the consensus
                tree = merklelib.MerkleTree(
                    [snip.__str__(False)
                        for snip in DirAuth.endive.enddict['snips']],
                    merklelib.Hasher())
                DirAuth.consensus.consdict["merkleroot"] = tree.merkle_root
        else:
            if network.thenetwork.snipauthmode == \
                    network.SNIPAuthMode.THRESHSIG:
                # We're just simulating threshold sigs by having
                # only the first dirauth sign, but in reality each
                # dirauth would contribute to the signature (at the
                # same cost as each one signing), so we'll charge
                # their perfstats as well: one signature per SNIP.
                self.perfstats.sigs += len(DirAuth.endive.enddict['snips'])

        DirAuth.consensus.sign(self.sigkey, self.me, self.perfstats)

        if network.thenetwork.womode != network.WOMode.VANILLA:
            DirAuth.endive.sign(self.sigkey, self.me, self.perfstats)

    def _sendreply(self, client, replymsg):
        """Charge replymsg's size to our sent-bytes counter, then send
        it back to client."""
        self.perfstats.bytes_sent += replymsg.size()
        client.reply(replymsg)

    def received(self, client, msg):
        """Handle a client-originating DirAuthNetMsg: record or retract
        an uploaded descriptor, or reply with the current consensus or
        ENDIVE.  Raises TypeError for any other message."""
        self.perfstats.bytes_received += msg.size()

        if isinstance(msg, DirAuthUploadDescMsg):
            # Check the uploaded descriptor for sanity
            epoch = msg.desc.descdict['epoch']
            if epoch != network.thenetwork.getepoch() + 1:
                return
            # Store it in the class-static dict, counting how many
            # dirauths have seen it (the first-stored copy is kept)
            epochdescs = DirAuth.uploadeddescs.setdefault(epoch, dict())
            descstr = str(msg.desc)
            if descstr in epochdescs:
                count, storeddesc = epochdescs[descstr]
                epochdescs[descstr] = (count+1, storeddesc)
            else:
                epochdescs[descstr] = (1, msg.desc)
        elif isinstance(msg, DirAuthDelDescMsg):
            # Check the uploaded descriptor for sanity
            epoch = msg.desc.descdict['epoch']
            if epoch != network.thenetwork.getepoch() + 1:
                return
            # Remove it from the class-static dict
            epochdescs = DirAuth.uploadeddescs.get(epoch)
            if epochdescs is None:
                return
            descstr = str(msg.desc)
            if descstr not in epochdescs:
                return
            count, storeddesc = epochdescs[descstr]
            if count == 1:
                del epochdescs[descstr]
            else:
                epochdescs[descstr] = (count-1, storeddesc)
        elif isinstance(msg, DirAuthGetConsensusMsg):
            self._sendreply(client, DirAuthConsensusMsg(DirAuth.consensus))
        elif isinstance(msg, DirAuthGetConsensusDiffMsg):
            self._sendreply(client,
                DirAuthConsensusDiffMsg(DirAuth.consensus))
        elif isinstance(msg, DirAuthGetENDIVEMsg):
            self._sendreply(client, DirAuthENDIVEMsg(DirAuth.endive))
        elif isinstance(msg, DirAuthGetENDIVEDiffMsg):
            self._sendreply(client, DirAuthENDIVEDiffMsg(DirAuth.endive))
        else:
            raise TypeError('Not a client-originating DirAuthNetMsg', msg)

    def closed(self):
        pass


if __name__ == '__main__':
    # Start some dirauths
    numdirauths = 9
    dirauthaddrs = []
    for i in range(numdirauths):
        dirauth = DirAuth(i, numdirauths)
        dirauthaddrs.append(dirauth.netaddr)

    for a in dirauthaddrs:
        print(a,end=' ')
    print()

    network.thenetwork.nextepoch()