#!/usr/bin/env python3

import hmac
import random  # For simulation, not cryptography!
import math

import nacl.utils
import nacl.signing
import nacl.public
import nacl.hash
import nacl.encoding

import network
import dirauth


class RelayNetMsg(network.NetMsg):
    """The subclass of NetMsg for messages between relays and either
    relays or clients."""


class RelayGetConsensusMsg(RelayNetMsg):
    """The subclass of RelayNetMsg for fetching the consensus."""


class RelayConsensusMsg(RelayNetMsg):
    """The subclass of RelayNetMsg for returning the consensus."""

    def __init__(self, consensus):
        self.consensus = consensus


class RelayRandomHopMsg(RelayNetMsg):
    """A message used for testing, that hops from relay to relay
    randomly until its TTL expires."""

    def __init__(self, ttl):
        self.ttl = ttl

    def __str__(self):
        return "RandomHop TTL=%d" % self.ttl


class VanillaCreateCircuitMsg(RelayNetMsg):
    """The message for requesting circuit creation in Vanilla Onion
    Routing."""

    def __init__(self, circid, ntor_request):
        self.circid = circid
        self.ntor_request = ntor_request


class VanillaCreatedCircuitMsg(RelayNetMsg):
    """The message for responding to circuit creation in Vanilla Onion
    Routing."""

    def __init__(self, ntor_reply):
        self.ntor_reply = ntor_reply


class CircuitCellMsg(RelayNetMsg):
    """Send a message tagged with a circuit id."""

    def __init__(self, circuitid, cell):
        self.circid = circuitid
        self.cell = cell

    def __str__(self):
        return "C%d:%s" % (self.circid, self.cell)

    def size(self):
        # circuitids are 4 bytes
        return 4 + self.cell.size()


class RelayFallbackTerminationError(Exception):
    """An exception raised when someone tries to terminate a fallback
    relay."""


class NTor:
    """A class implementing the ntor one-way authenticated key agreement
    scheme.  The details are not exactly the same as either the ntor
    paper or Tor's implementation, but it will agree on keys and have the
    same number of public key operations."""

    def __init__(self, perfstats):
        self.perfstats = perfstats

    def request(self):
        """Create the ntor request message: X = g^x.

        The ephemeral private key x is remembered on this object so that
        verify() can later complete the handshake."""
        self.client_ephem_key = nacl.public.PrivateKey.generate()
        self.perfstats.keygens += 1
        return self.client_ephem_key.public_key

    @staticmethod
    def _transcript(xykey, xbkey, idpubkey, onion_pubkey, client_pubkey,
                    server_ephem_pubkey):
        """Build the raw-byte transcript M = (X^y, X^b, ID, B, X, Y).

        Both sides must build it identically or the derived keys will
        not agree."""
        raw = nacl.encoding.RawEncoder
        return (xykey + xbkey +
                idpubkey.encode(encoder=raw) +
                onion_pubkey.encode(encoder=raw) +
                client_pubkey.encode(encoder=raw) +
                server_ephem_pubkey.encode(encoder=raw))

    @staticmethod
    def _derive(M):
        """Return the pair (A, S) = (H(M, "verify"), H(M, "secret"))."""
        A = nacl.hash.sha256(M + b'verify', encoder=nacl.encoding.RawEncoder)
        S = nacl.hash.sha256(M + b'secret', encoder=nacl.encoding.RawEncoder)
        return A, S

    @staticmethod
    def reply(onion_privkey, idpubkey, client_pubkey, perfstats):
        """The server calls this static method to produce the ntor reply
        message: (Y = g^y, B = g^b, A = H(M, "verify")) and the shared
        secret S = H(M, "secret") for M = (X^y,X^b,ID,B,X,Y).

        Returns the pair ((Y, B, A), S)."""
        server_ephem_key = nacl.public.PrivateKey.generate()
        perfstats.keygens += 1
        xykey = nacl.public.Box(server_ephem_key, client_pubkey).shared_key()
        xbkey = nacl.public.Box(onion_privkey, client_pubkey).shared_key()
        perfstats.dhs += 2
        M = NTor._transcript(xykey, xbkey, idpubkey,
                             onion_privkey.public_key, client_pubkey,
                             server_ephem_key.public_key)
        A, S = NTor._derive(M)
        return ((server_ephem_key.public_key, onion_privkey.public_key, A),
                S)

    def verify(self, reply, onion_pubkey, idpubkey):
        """The client calls this method to verify the ntor reply message,
        passing the onion and id public keys for the server it's
        expecting to be talking to.  Must be called after request().

        Returns the shared secret on success, or raises ValueError on
        failure."""
        server_ephem_pubkey, server_onion_pubkey, authtag = reply
        if onion_pubkey != server_onion_pubkey:
            raise ValueError("NTor onion pubkey mismatch")
        xykey = nacl.public.Box(self.client_ephem_key,
                                server_ephem_pubkey).shared_key()
        xbkey = nacl.public.Box(self.client_ephem_key,
                                onion_pubkey).shared_key()
        self.perfstats.dhs += 2
        M = NTor._transcript(xykey, xbkey, idpubkey, onion_pubkey,
                             self.client_ephem_key.public_key,
                             server_ephem_pubkey)
        Acheck, S = NTor._derive(M)
        # Constant-time comparison so the check does not leak how many
        # leading bytes of the tag matched.
        if not hmac.compare_digest(Acheck, authtag):
            raise ValueError("NTor auth mismatch")
        return S


class CircuitHandler:
    """A class for managing sending and receiving encrypted cells on a
    particular circuit."""

    def __init__(self, channel, circid):
        self.channel = channel
        self.circid = circid
        # The function used to send a cell; defaults to sending it
        # unchanged on the underlying channel.
        self.send_cell = self.channel_send_cell
        # The list of relay descriptors that form the circuit so far
        # (client side only)
        self.circuit_descs = []
        # The dispatch table is indexed by type, and the values are
        # objects with received_cell(circhandler, cell) methods.
        self.cell_dispatch_table = dict()

    def channel_send_cell(self, cell):
        """Send a cell on this circuit."""
        self.channel.send_msg(CircuitCellMsg(self.circid, cell))

    def received_cell(self, cell):
        """A cell has been received on this circuit.  Dispatch it
        according to its type; cells of unregistered types are
        silently ignored."""
        celltype = type(cell)
        if celltype in self.cell_dispatch_table:
            self.cell_dispatch_table[celltype].received_cell(self, cell)


class Channel(network.Connection):
    """A class representing a channel between a relay and either a
    client or a relay, transporting cells from various circuits."""

    def __init__(self):
        super().__init__()
        # The CellRelay managing this Channel
        self.cellhandler = None
        # The Channel at the other end
        self.peer = None
        # The function to call when the connection closes
        self.closer = lambda: 0
        # The next circuit id to use on this channel.  The party that
        # opened the channel uses even numbers; the receiving party uses
        # odd numbers.
        self.next_circid = None
        # A map for CircuitHandlers to use for each open circuit on the
        # channel
        self.circuithandlers = dict()

    def closed(self):
        """Callback when this channel has been closed (by us or by our
        peer): run the closer and drop the peer link."""
        self.closer()
        self.peer = None

    def close(self):
        """Close this channel, notifying the peer end first (unless this
        is a self-loop channel that is its own peer)."""
        if self.peer is not None and self.peer is not self:
            self.peer.closed()
        self.closed()

    def new_circuit(self):
        """Allocate a new circuit on this channel, returning the new
        circuit's id and the new CircuitHandler."""
        circid = self.next_circid
        # Step by 2 to stay within our parity class (see __init__)
        self.next_circid += 2
        circuithandler = CircuitHandler(self, circid)
        self.circuithandlers[circid] = circuithandler
        return circid, circuithandler

    def new_circuit_with_circid(self, circid):
        """Allocate a new circuit on this channel, with the circuit id
        received from our peer.  Return the new CircuitHandler"""
        circuithandler = CircuitHandler(self, circid)
        self.circuithandlers[circid] = circuithandler
        return circuithandler

    def send_cell(self, circid, cell):
        """Send the given message on the given circuit, encrypting or
        decrypting as needed."""
        self.circuithandlers[circid].send_cell(cell)

    def send_raw_cell(self, circid, cell):
        """Send the given message, tagged for the given circuit id.  No
        encryption or decryption is done."""
        self.send_msg(CircuitCellMsg(circid, cell))

    def send_msg(self, msg):
        """Send the given NetMsg on the channel."""
        self.cellhandler.perfstats.bytes_sent += msg.size()
        self.peer.received(self.cellhandler.myaddr, msg)

    def received(self, peeraddr, msg):
        """Callback when a message is received from the network.

        Circuit-tagged cells go to the matching CircuitHandler (raising
        KeyError for an unknown circuit id); everything else goes to the
        owning CellHandler."""
        self.cellhandler.perfstats.bytes_received += msg.size()
        if isinstance(msg, CircuitCellMsg):
            circid, cell = msg.circid, msg.cell
            self.circuithandlers[circid].received_cell(cell)
        else:
            self.cellhandler.received_msg(msg, peeraddr, self)


class CellHandler:
    """The class that manages the channels to other relays and clients.
    Relays and clients both use subclasses of this class to both create
    on-demand channels to relays, to gracefully handle the closing of
    channels, and to handle commands received over the channels."""

    def __init__(self, myaddr, dirauthaddrs, perfstats):
        # A dictionary of Channels to other hosts, indexed by NetAddr
        self.channels = dict()
        self.myaddr = myaddr
        self.dirauthaddrs = dirauthaddrs
        self.consensus = None
        self.perfstats = perfstats

    def _forget_channel(self, peeraddr, channel):
        """Remove channel from self.channels, but only if it is still
        the channel registered for peeraddr.  This makes closing a
        channel twice harmless, and stops a stale channel's closer from
        evicting a newer channel to the same peer."""
        if self.channels.get(peeraddr) is channel:
            del self.channels[peeraddr]

    def terminate(self):
        """Close all connections we're managing."""
        while self.channels:
            addr, channel = next(iter(self.channels.items()))
            print('closing channel', addr, channel)
            channel.close()
            # Normally the channel's closer has already removed it; this
            # guarantees the loop makes progress even if it has not.
            self._forget_channel(addr, channel)

    def add_channel(self, channel, peeraddr):
        """Add the given channel to the list of channels we are managing.
        If we are already managing a channel to the same peer, close it
        first."""
        if peeraddr in self.channels:
            self.channels[peeraddr].close()

        channel.cellhandler = self
        self.channels[peeraddr] = channel
        channel.closer = lambda: self._forget_channel(peeraddr, channel)

    def get_channel_to(self, addr):
        """Get the Channel connected to the given NetAddr, creating one
        if none exists right now."""
        if addr in self.channels:
            return self.channels[addr]

        # Create the new channel and register it the same way incoming
        # channels are registered.
        newchannel = network.thenetwork.connect(self.myaddr, addr,
                                                self.perfstats)
        self.add_channel(newchannel, addr)
        return newchannel

    def received_msg(self, msg, peeraddr, channel):
        """Callback when a NetMsg not specific to a circuit is
        received."""
        print("CellHandler: Node %s received msg %s from %s" % (self.myaddr, msg, peeraddr))

    def received_cell(self, circid, cell, peeraddr, channel):
        """Callback when a circuit-specific cell is received."""
        print("CellHandler: Node %s received cell on circ %d: %s from %s" % (self.myaddr, circid, cell, peeraddr))

    def send_msg(self, msg, peeraddr):
        """Send a message to the peer with the given address."""
        channel = self.get_channel_to(peeraddr)
        channel.send_msg(msg)

    def send_cell(self, circid, cell, peeraddr):
        """Send a cell on the given circuit to the peer with the given
        address."""
        channel = self.get_channel_to(peeraddr)
        channel.send_cell(circid, cell)


class CellRelay(CellHandler):
    """The subclass of CellHandler for relays."""

    def __init__(self, myaddr, dirauthaddrs, onionprivkey, idpubkey,
                 perfstats):
        super().__init__(myaddr, dirauthaddrs, perfstats)
        self.onionkey = onionprivkey
        self.idpubkey = idpubkey

    def get_consensus(self):
        """Download a fresh consensus from a random dirauth.

        The consensus is verified before it replaces self.consensus, and
        the dirauth connection is closed even if verification fails."""
        a = random.choice(self.dirauthaddrs)
        c = network.thenetwork.connect(self, a, self.perfstats)
        try:
            consensus = c.getconsensus()
            dirauth.Consensus.verify(consensus,
                                     network.thenetwork.dirauthkeys(),
                                     self.perfstats)
        finally:
            c.close()
        self.consensus = consensus

    def received_msg(self, msg, peeraddr, channel):
        """Callback when a NetMsg not specific to a circuit is
        received."""
        print("CellRelay: Node %s received msg %s from %s" % (self.myaddr, msg, peeraddr))
        if isinstance(msg, RelayRandomHopMsg):
            if msg.ttl > 0:
                # Pick a random next hop from the consensus
                nexthop = random.choice(self.consensus.consdict['relays'])
                nextaddr = nexthop.descdict['addr']
                self.send_msg(RelayRandomHopMsg(msg.ttl-1), nextaddr)
        elif isinstance(msg, RelayGetConsensusMsg):
            self.send_msg(RelayConsensusMsg(self.consensus), peeraddr)
        elif isinstance(msg, VanillaCreateCircuitMsg):
            # A new circuit has arrived; register it on the channel it
            # arrived on (circuit ids are scoped to a channel)
            channel.new_circuit_with_circid(msg.circid)
            # Create the ntor reply
            reply, secret = NTor.reply(self.onionkey, self.idpubkey,
                                       msg.ntor_request, self.perfstats)
            # Set up the circuit to use the shared secret
            # TODO
            print('relay secret=', secret)
            # Send the ntor reply back on the same channel, so it is
            # tagged for the circuit we just registered there
            channel.send_msg(CircuitCellMsg(msg.circid,
                                            VanillaCreatedCircuitMsg(reply)))
        else:
            return super().received_msg(msg, peeraddr, channel)

    def received_cell(self, circid, cell, peeraddr, channel):
        """Callback when a circuit-specific cell is received."""
        print("CellRelay: Node %s received cell on circ %d: %s from %s" % (self.myaddr, circid, cell, peeraddr))
        return super().received_cell(circid, cell, peeraddr, channel)


class Relay(network.Server):
    """The class representing an onion relay."""

    def __init__(self, dirauthaddrs, bw, flags):
        # Gather performance statistics
        self.perfstats = network.PerfStats(network.EntType.RELAY)
        self.perfstats.is_bootstrapping = True

        # Create the identity and onion keys
        self.idkey = nacl.signing.SigningKey.generate()
        self.onionkey = nacl.public.PrivateKey.generate()
        self.perfstats.keygens += 2
        self.name = self.idkey.verify_key.encode(encoder=nacl.encoding.HexEncoder).decode("ascii")

        # Bind to the network to get a network address
        self.netaddr = network.thenetwork.bind(self)
        self.perfstats.name = "Relay at %s" % self.netaddr

        # Our bandwidth and flags
        self.bw = bw
        self.flags = flags

        # Register for epoch change notification
        network.thenetwork.wantepochticks(self, True, end=True)
        network.thenetwork.wantepochticks(self, True)

        # Create the CellRelay connection manager
        self.cellhandler = CellRelay(self.netaddr, dirauthaddrs,
                                     self.onionkey, self.idkey.verify_key,
                                     self.perfstats)

        # Initially, we're not a fallback relay
        self.is_fallbackrelay = False

        self.uploaddesc()

    def terminate(self):
        """Stop this relay.

        Raises RelayFallbackTerminationError for fallback relays."""

        if self.is_fallbackrelay:
            # Fallback relays must not (for now) terminate
            raise RelayFallbackTerminationError(self)

        # Stop listening for epoch ticks
        network.thenetwork.wantepochticks(self, False, end=True)
        network.thenetwork.wantepochticks(self, False)

        # Tell the dirauths we're going away
        self.uploaddesc(False)

        # Close connections to other relays
        self.cellhandler.terminate()

        # Stop listening to our own bound port
        self.close()

    def set_is_fallbackrelay(self, isfallback=True):
        """Set this relay to be a fallback relay (or unset if passed
        False)."""
        self.is_fallbackrelay = isfallback

    def epoch_ending(self, epoch):
        # Download the new consensus, which will have been created
        # already since the dirauths' epoch_ending callbacks happened
        # before the relays'.
        self.cellhandler.get_consensus()

    def newepoch(self, epoch):
        self.uploaddesc()

    def uploaddesc(self, upload=True):
        """Upload the descriptor for the epoch to come, or delete a
        previous upload if upload=False."""
        descdict = {
            "epoch": network.thenetwork.getepoch() + 1,
            "idkey": self.idkey.verify_key,
            "onionkey": self.onionkey.public_key,
            "addr": self.netaddr,
            "bw": self.bw,
            "flags": self.flags,
        }
        desc = dirauth.RelayDescriptor(descdict)
        desc.sign(self.idkey, self.perfstats)
        dirauth.RelayDescriptor.verify(desc, self.perfstats)

        if upload:
            descmsg = dirauth.DirAuthUploadDescMsg(desc)
        else:
            # Note that this relies on signatures being deterministic;
            # otherwise we'd need to save the descriptor we uploaded
            # before so we could tell the dirauths to delete the exact
            # one
            descmsg = dirauth.DirAuthDelDescMsg(desc)

        # Upload them, always closing each dirauth connection
        for a in self.cellhandler.dirauthaddrs:
            c = network.thenetwork.connect(self, a, self.perfstats)
            try:
                c.sendmsg(descmsg)
            finally:
                c.close()

    def connected(self, peer):
        """Callback invoked when someone (client or relay) connects to
        us.  Create a pair of linked Channels and return the peer half
        to the peer."""

        # Create the linked pair
        if peer == self.netaddr:
            # A self-loop?  We'll allow it.
            peerchannel = Channel()
            peerchannel.peer = peerchannel
            peerchannel.next_circid = 2
            return peerchannel

        peerchannel = Channel()
        ourchannel = Channel()
        peerchannel.peer = ourchannel
        # The connecting party uses even circuit ids, we use odd ones
        peerchannel.next_circid = 2
        ourchannel.peer = peerchannel
        ourchannel.next_circid = 1

        # Add our channel to the CellRelay
        self.cellhandler.add_channel(ourchannel, peer)

        return peerchannel


def _check_channels(relays):
    """Print each relay's channels and report any channel whose peer is
    at the wrong address or does not link back to it (demo helper)."""
    for r in relays:
        print("%s: %s" % (r.netaddr, [str(k) for k in r.cellhandler.channels.keys()]))
        raddr = r.netaddr
        for ad, ch in r.cellhandler.channels.items():
            if ch.peer.cellhandler.myaddr != ad:
                print('address mismatch:', raddr, ad, ch.peer.cellhandler.myaddr)

            if ch.peer.cellhandler.channels[raddr].peer is not ch:
                print('asymmetry:', raddr, ad, ch, ch.peer.cellhandler.channels[raddr].peer)


if __name__ == '__main__':
    perfstats = network.PerfStats(network.EntType.NONE)

    # Start some dirauths
    numdirauths = 9
    dirauthaddrs = []
    for i in range(numdirauths):
        dira = dirauth.DirAuth(i, numdirauths)
        dirauthaddrs.append(dira.netaddr)

    # Start some relays
    numrelays = 10
    relays = []
    for i in range(numrelays):
        # Relay bandwidths (at least the ones fast enough to get used)
        # in the live Tor network (as of Dec 2019) are well approximated
        # by (200000-(200000-25000)/3*log10(x)) where x is a
        # uniform integer in [1,2500]
        x = random.randint(1, 2500)
        bw = int(200000-(200000-25000)/3*math.log10(x))
        relays.append(Relay(dirauthaddrs, bw, 0))

    # The fallback relays are a hardcoded list of about 5% of the
    # relays, used by clients for bootstrapping
    numfallbackrelays = int(numrelays * 0.05) + 1
    fallbackrelays = random.sample(relays, numfallbackrelays)
    for r in fallbackrelays:
        r.set_is_fallbackrelay()
    network.thenetwork.setfallbackrelays(fallbackrelays)

    # Tick the epoch
    network.thenetwork.nextepoch()

    dirauth.Consensus.verify(dirauth.DirAuth.consensus,
                             network.thenetwork.dirauthkeys(), perfstats)

    print('ticked; epoch=', network.thenetwork.getepoch())

    relays[3].cellhandler.send_msg(RelayRandomHopMsg(30), relays[5].netaddr)

    # See what channels exist and do a consistency check
    _check_channels(relays)

    # Stop some relays (note: indices shift after each deletion)
    relays[3].terminate()
    del relays[3]
    relays[5].terminate()
    del relays[5]
    relays[7].terminate()
    del relays[7]

    # Tick the epoch
    network.thenetwork.nextepoch()

    print(dirauth.DirAuth.consensus)

    # See what channels exist and do a consistency check
    _check_channels(relays)

    channel = relays[3].cellhandler.get_channel_to(relays[5].netaddr)
    circid, circhandler = channel.new_circuit()
    peerchannel = relays[5].cellhandler.get_channel_to(relays[3].netaddr)
    peerchannel.new_circuit_with_circid(circid)
    relays[3].cellhandler.send_cell(circid, network.StringNetMsg("test"), relays[5].netaddr)

    idpubkey = dirauth.DirAuth.consensus.consdict["relays"][1].descdict["idkey"]
    onionpubkey = dirauth.DirAuth.consensus.consdict["relays"][1].descdict["onionkey"]
    nt = NTor(perfstats)
    req = nt.request()
    R, S = NTor.reply(relays[1].onionkey, idpubkey, req, perfstats)
    S2 = nt.verify(R, onionpubkey, idpubkey)
    print(S == S2)