relay.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465
#!/usr/bin/env python3
import math
import random # For simulation, not cryptography!

import nacl.encoding
import nacl.public
import nacl.signing
import nacl.utils

import dirauth
import network
  9. class RelayNetMsg(network.NetMsg):
  10. """The subclass of NetMsg for messages between relays and either
  11. relays or clients."""
  12. class RelayGetConsensusMsg(RelayNetMsg):
  13. """The subclass of RelayNetMsg for fetching the consensus."""
  14. class RelayConsensusMsg(RelayNetMsg):
  15. """The subclass of RelayNetMsg for returning the consensus."""
  16. def __init__(self, consensus):
  17. self.consensus = consensus
  18. class RelayRandomHopMsg(RelayNetMsg):
  19. """A message used for testing, that hops from relay to relay
  20. randomly until its TTL expires."""
  21. def __init__(self, ttl):
  22. self.ttl = ttl
  23. def __str__(self):
  24. return "RandomHop TTL=%d" % self.ttl
  25. class VanillaCreateCircuitMsg(RelayNetMsg):
  26. """The message for requesting circuit creation in Vanilla Onion
  27. Routing."""
  28. def __init__(self, circid, ntor_request):
  29. self.circid = circid
  30. self.ntor_request = ntor_request
  31. class VanillaCreatedCircuitMsg(RelayNetMsg):
  32. """The message for responding to circuit creation in Vanilla Onion
  33. Routing."""
  34. def __init__(self, circid, ntor_response):
  35. self.circid = circid
  36. self.ntor_response = ntor_response
  37. class CircuitCellMsg(RelayNetMsg):
  38. """Send a message tagged with a circuit id."""
  39. def __init__(self, circuitid, cell):
  40. self.circid = circuitid
  41. self.cell = cell
  42. def __str__(self):
  43. return "C%d:%s" % (self.circid, self.cell)
  44. class RelayFallbackTerminationError(Exception):
  45. """An exception raised when someone tries to terminate a fallback
  46. relay."""
  47. class CircuitHandler:
  48. """A class for managing sending and receiving encrypted cells on a
  49. particular circuit."""
  50. def __init__(self, channel, circid):
  51. self.channel = channel
  52. self.circid = circid
  53. self.send_cell = self.channel_send_cell
  54. self.received_cell = self.channel_received_cell
  55. def channel_send_cell(self, cell):
  56. """Send a cell on this circuit."""
  57. self.channel.send_msg(CircuitCellMsg(self.circid, cell))
  58. def channel_received_cell(self, cell, peeraddr, peer):
  59. """A cell has been received on this circuit. Forward it to the
  60. channel's received_cell callback."""
  61. self.channel.cellhandler.received_cell(self.circid, cell, peeraddr, peer)
  62. class Channel(network.Connection):
  63. """A class representing a channel between a relay and either a
  64. client or a relay, transporting cells from various circuits."""
  65. def __init__(self):
  66. super().__init__()
  67. # The CellRelay managing this Channel
  68. self.cellhandler = None
  69. # The Channel at the other end
  70. self.peer = None
  71. # The function to call when the connection closes
  72. self.closer = lambda: 0
  73. # The next circuit id to use on this channel. The party that
  74. # opened the channel uses even numbers; the receiving party uses
  75. # odd numbers.
  76. self.next_circid = None
  77. # A map for CircuitHandlers to use for each open circuit on the
  78. # channel
  79. self.circuithandlers = dict()
  80. def closed(self):
  81. self.closer()
  82. self.peer = None
  83. def close(self):
  84. if self.peer is not None and self.peer is not self:
  85. self.peer.closed()
  86. self.closed()
  87. def new_circuit(self):
  88. """Allocate a new circuit on this channel, returning the new
  89. circuit's id."""
  90. circid = self.next_circid
  91. self.next_circid += 2
  92. self.circuithandlers[circid] = CircuitHandler(self, circid)
  93. return circid
  94. def new_circuit_with_circid(self, circid):
  95. """Allocate a new circuit on this channel, with the circuit id
  96. received from our peer."""
  97. self.circuithandlers[circid] = CircuitHandler(self, circid)
  98. def send_cell(self, circid, cell):
  99. """Send the given message on the given circuit, encrypting or
  100. decrypting as needed."""
  101. self.circuithandlers[circid].send_cell(cell)
  102. def send_raw_cell(self, circid, cell):
  103. """Send the given message, tagged for the given circuit id. No
  104. encryption or decryption is done."""
  105. self.send_msg(CircuitCellMsg(self.circid, self.cell))
  106. def send_msg(self, msg):
  107. """Send the given NetMsg on the channel."""
  108. self.cellhandler.perfstats.bytes_sent += msg.size()
  109. self.peer.received(self.cellhandler.myaddr, msg)
  110. def received(self, peeraddr, msg):
  111. """Callback when a message is received from the network."""
  112. self.cellhandler.perfstats.bytes_received += msg.size()
  113. if isinstance(msg, CircuitCellMsg):
  114. circid, cell = msg.circid, msg.cell
  115. self.circuithandlers[circid].received_cell(cell, peeraddr, self.peer)
  116. else:
  117. self.cellhandler.received_msg(msg, peeraddr, self.peer)
  118. class CellHandler:
  119. """The class that manages the channels to other relays and clients.
  120. Relays and clients both use subclasses of this class to both create
  121. on-demand channels to relays, to gracefully handle the closing of
  122. channels, and to handle commands received over the channels."""
  123. def __init__(self, myaddr, dirauthaddrs, perfstats):
  124. # A dictionary of Channels to other hosts, indexed by NetAddr
  125. self.channels = dict()
  126. self.myaddr = myaddr
  127. self.dirauthaddrs = dirauthaddrs
  128. self.consensus = None
  129. self.perfstats = perfstats
  130. def terminate(self):
  131. """Close all connections we're managing."""
  132. while self.channels:
  133. channelitems = iter(self.channels.items())
  134. addr, channel = next(channelitems)
  135. print('closing channel', addr, channel)
  136. channel.close()
  137. def add_channel(self, channel, peeraddr):
  138. """Add the given channel to the list of channels we are
  139. managing. If we are already managing a channel to the same
  140. peer, close it first."""
  141. if peeraddr in self.channels:
  142. self.channels[peeraddr].close()
  143. channel.cellhandler = self
  144. self.channels[peeraddr] = channel
  145. channel.closer = lambda: self.channels.pop(peeraddr)
  146. def get_channel_to(self, addr):
  147. """Get the Channel connected to the given NetAddr, creating one
  148. if none exists right now."""
  149. if addr in self.channels:
  150. return self.channels[addr]
  151. # Create the new channel
  152. newchannel = network.thenetwork.connect(self.myaddr, addr, \
  153. self.perfstats)
  154. self.channels[addr] = newchannel
  155. newchannel.closer = lambda: self.channels.pop(addr)
  156. newchannel.cellhandler = self
  157. return newchannel
  158. def received_msg(self, msg, peeraddr, peer):
  159. """Callback when a NetMsg not specific to a circuit is
  160. received."""
  161. print("CellHandler: Node %s received msg %s from %s" % (self.myaddr, msg, peeraddr))
  162. def received_cell(self, circid, cell, peeraddr, peer):
  163. """Callback with a circuit-specific cell is received."""
  164. print("CellHandler: Node %s received cell on circ %d: %s from %s" % (self.myaddr, circid, cell, peeraddr))
  165. def send_msg(self, msg, peeraddr):
  166. """Send a message to the peer with the given address."""
  167. channel = self.get_channel_to(peeraddr)
  168. channel.send_msg(msg)
  169. def send_cell(self, circid, cell, peeraddr):
  170. """Send a cell on the given circuit to the peer with the given
  171. address."""
  172. channel = self.get_channel_to(peeraddr)
  173. channel.send_cell(circid, cell)
  174. class CellRelay(CellHandler):
  175. """The subclass of CellHandler for relays."""
  176. def __init__(self, myaddr, dirauthaddrs, perfstats):
  177. super().__init__(myaddr, dirauthaddrs, perfstats)
  178. def get_consensus(self):
  179. """Download a fresh consensus from a random dirauth."""
  180. a = random.choice(self.dirauthaddrs)
  181. c = network.thenetwork.connect(self, a, self.perfstats)
  182. self.consensus = c.getconsensus()
  183. dirauth.Consensus.verify(self.consensus, \
  184. network.thenetwork.dirauthkeys(), self.perfstats)
  185. c.close()
  186. def received_msg(self, msg, peeraddr, peer):
  187. """Callback when a NetMsg not specific to a circuit is
  188. received."""
  189. print("CellRelay: Node %s received msg %s from %s" % (self.myaddr, msg, peeraddr))
  190. if isinstance(msg, RelayRandomHopMsg):
  191. if msg.ttl > 0:
  192. # Pick a random next hop from the consensus
  193. nexthop = random.choice(self.consensus.consdict['relays'])
  194. nextaddr = nexthop.descdict['addr']
  195. self.send_msg(RelayRandomHopMsg(msg.ttl-1), nextaddr)
  196. elif isinstance(msg, RelayGetConsensusMsg):
  197. self.send_msg(RelayConsensusMsg(self.consensus), peeraddr)
  198. else:
  199. return super().received_msg(msg, peeraddr, peer)
  200. def received_cell(self, circid, cell, peeraddr, peer):
  201. """Callback with a circuit-specific cell is received."""
  202. print("CellRelay: Node %s received cell on circ %d: %s from %s" % (self.myaddr, circid, cell, peeraddr))
  203. return super().received_cell(circid, cell, peeraddr, peer)
  204. class Relay(network.Server):
  205. """The class representing an onion relay."""
  206. def __init__(self, dirauthaddrs, bw, flags):
  207. # Gather performance statistics
  208. self.perfstats = network.PerfStats(network.EntType.RELAY)
  209. self.perfstats.is_bootstrapping = True
  210. # Create the identity and onion keys
  211. self.idkey = nacl.signing.SigningKey.generate()
  212. self.onionkey = nacl.public.PrivateKey.generate()
  213. self.perfstats.keygens += 2
  214. self.name = self.idkey.verify_key.encode(encoder=nacl.encoding.HexEncoder).decode("ascii")
  215. # Bind to the network to get a network address
  216. self.netaddr = network.thenetwork.bind(self)
  217. self.perfstats.name = "Relay at %s" % self.netaddr
  218. # Our bandwidth and flags
  219. self.bw = bw
  220. self.flags = flags
  221. # Register for epoch change notification
  222. network.thenetwork.wantepochticks(self, True, end=True)
  223. network.thenetwork.wantepochticks(self, True)
  224. # Create the CellRelay connection manager
  225. self.cellhandler = CellRelay(self.netaddr, dirauthaddrs, self.perfstats)
  226. # Initially, we're not a fallback relay
  227. self.is_fallbackrelay = False
  228. self.uploaddesc()
  229. def terminate(self):
  230. """Stop this relay."""
  231. if self.is_fallbackrelay:
  232. # Fallback relays must not (for now) terminate
  233. raise RelayFallbackTerminationError(self)
  234. # Stop listening for epoch ticks
  235. network.thenetwork.wantepochticks(self, False, end=True)
  236. network.thenetwork.wantepochticks(self, False)
  237. # Tell the dirauths we're going away
  238. self.uploaddesc(False)
  239. # Close connections to other relays
  240. self.cellhandler.terminate()
  241. # Stop listening to our own bound port
  242. self.close()
  243. def set_is_fallbackrelay(self, isfallback = True):
  244. """Set this relay to be a fallback relay (or unset if passed
  245. False)."""
  246. self.is_fallbackrelay = isfallback
  247. def epoch_ending(self, epoch):
  248. # Download the new consensus, which will have been created
  249. # already since the dirauths' epoch_ending callbacks happened
  250. # before the relays'.
  251. self.cellhandler.get_consensus()
  252. def newepoch(self, epoch):
  253. self.uploaddesc()
  254. def uploaddesc(self, upload=True):
  255. # Upload the descriptor for the epoch to come, or delete a
  256. # previous upload if upload=False
  257. descdict = dict();
  258. descdict["epoch"] = network.thenetwork.getepoch() + 1
  259. descdict["idkey"] = self.idkey.verify_key
  260. descdict["onionkey"] = self.onionkey.public_key
  261. descdict["addr"] = self.netaddr
  262. descdict["bw"] = self.bw
  263. descdict["flags"] = self.flags
  264. desc = dirauth.RelayDescriptor(descdict)
  265. desc.sign(self.idkey, self.perfstats)
  266. dirauth.RelayDescriptor.verify(desc, self.perfstats)
  267. if upload:
  268. descmsg = dirauth.DirAuthUploadDescMsg(desc)
  269. else:
  270. # Note that this relies on signatures being deterministic;
  271. # otherwise we'd need to save the descriptor we uploaded
  272. # before so we could tell the airauths to delete the exact
  273. # one
  274. descmsg = dirauth.DirAuthDelDescMsg(desc)
  275. # Upload them
  276. for a in self.cellhandler.dirauthaddrs:
  277. c = network.thenetwork.connect(self, a, self.perfstats)
  278. c.sendmsg(descmsg)
  279. c.close()
  280. def connected(self, peer):
  281. """Callback invoked when someone (client or relay) connects to
  282. us. Create a pair of linked Channels and return the peer half
  283. to the peer."""
  284. # Create the linked pair
  285. if peer is self.netaddr:
  286. # A self-loop? We'll allow it.
  287. peerchannel = Channel()
  288. peerchannel.peer = peerchannel
  289. peerchannel.next_circid = 2
  290. return peerchannel
  291. peerchannel = Channel()
  292. ourchannel = Channel()
  293. peerchannel.peer = ourchannel
  294. peerchannel.next_circid = 2
  295. ourchannel.peer = peerchannel
  296. ourchannel.next_circid = 1
  297. # Add our channel to the CellRelay
  298. self.cellhandler.add_channel(ourchannel, peer)
  299. return peerchannel
  300. if __name__ == '__main__':
  301. perfstats = network.PerfStats(network.EntType.NONE)
  302. # Start some dirauths
  303. numdirauths = 9
  304. dirauthaddrs = []
  305. for i in range(numdirauths):
  306. dira = dirauth.DirAuth(i, numdirauths)
  307. dirauthaddrs.append(dira.netaddr)
  308. # Start some relays
  309. numrelays = 10
  310. relays = []
  311. for i in range(numrelays):
  312. # Relay bandwidths (at least the ones fast enough to get used)
  313. # in the live Tor network (as of Dec 2019) are well approximated
  314. # by (200000-(200000-25000)/3*log10(x)) where x is a
  315. # uniform integer in [1,2500]
  316. x = random.randint(1,2500)
  317. bw = int(200000-(200000-25000)/3*math.log10(x))
  318. relays.append(Relay(dirauthaddrs, bw, 0))
  319. # The fallback relays are a hardcoded list of about 5% of the
  320. # relays, used by clients for bootstrapping
  321. numfallbackrelays = int(numrelays * 0.05) + 1
  322. fallbackrelays = random.sample(relays, numfallbackrelays)
  323. for r in fallbackrelays:
  324. r.set_is_fallbackrelay()
  325. network.thenetwork.setfallbackrelays(fallbackrelays)
  326. # Tick the epoch
  327. network.thenetwork.nextepoch()
  328. dirauth.Consensus.verify(dirauth.DirAuth.consensus, \
  329. network.thenetwork.dirauthkeys(), perfstats)
  330. print('ticked; epoch=', network.thenetwork.getepoch())
  331. relays[3].cellhandler.send_msg(RelayRandomHopMsg(30), relays[5].netaddr)
  332. # See what channels exist and do a consistency check
  333. for r in relays:
  334. print("%s: %s" % (r.netaddr, [ str(k) for k in r.cellhandler.channels.keys()]))
  335. raddr = r.netaddr
  336. for ad, ch in r.cellhandler.channels.items():
  337. if ch.peer.cellhandler.myaddr != ad:
  338. print('address mismatch:', raddr, ad, ch.peer.cellhandler.myaddr)
  339. if ch.peer.cellhandler.channels[raddr].peer is not ch:
  340. print('asymmetry:', raddr, ad, ch, ch.peer.cellhandler.channels[raddr].peer)
  341. # Stop some relays
  342. relays[3].terminate()
  343. del relays[3]
  344. relays[5].terminate()
  345. del relays[5]
  346. relays[7].terminate()
  347. del relays[7]
  348. # Tick the epoch
  349. network.thenetwork.nextepoch()
  350. print(dirauth.DirAuth.consensus)
  351. # See what channels exist and do a consistency check
  352. for r in relays:
  353. print("%s: %s" % (r.netaddr, [ str(k) for k in r.cellhandler.channels.keys()]))
  354. raddr = r.netaddr
  355. for ad, ch in r.cellhandler.channels.items():
  356. if ch.peer.cellhandler.myaddr != ad:
  357. print('address mismatch:', raddr, ad, ch.peer.cellhandler.myaddr)
  358. if ch.peer.cellhandler.channels[raddr].peer is not ch:
  359. print('asymmetry:', raddr, ad, ch, ch.peer.cellhandler.channels[raddr].peer)
  360. channel = relays[3].cellhandler.get_channel_to(relays[5].netaddr)
  361. circid = channel.new_circuit()
  362. peerchannel = relays[5].cellhandler.get_channel_to(relays[3].netaddr)
  363. peerchannel.new_circuit_with_circid(circid)
  364. relays[3].cellhandler.send_cell(circid, network.StringNetMsg("test"), relays[5].netaddr)