relay.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556
  1. #!/usr/bin/env python3
  2. import random # For simulation, not cryptography!
  3. import math
  4. import nacl.utils
  5. import nacl.signing
  6. import nacl.public
  7. import nacl.hash
  8. import network
  9. import dirauth
  10. class RelayNetMsg(network.NetMsg):
  11. """The subclass of NetMsg for messages between relays and either
  12. relays or clients."""
  13. class RelayGetConsensusMsg(RelayNetMsg):
  14. """The subclass of RelayNetMsg for fetching the consensus."""
  15. class RelayConsensusMsg(RelayNetMsg):
  16. """The subclass of RelayNetMsg for returning the consensus."""
  17. def __init__(self, consensus):
  18. self.consensus = consensus
  19. class RelayRandomHopMsg(RelayNetMsg):
  20. """A message used for testing, that hops from relay to relay
  21. randomly until its TTL expires."""
  22. def __init__(self, ttl):
  23. self.ttl = ttl
  24. def __str__(self):
  25. return "RandomHop TTL=%d" % self.ttl
  26. class VanillaCreateCircuitMsg(RelayNetMsg):
  27. """The message for requesting circuit creation in Vanilla Onion
  28. Routing."""
  29. def __init__(self, circid, ntor_request):
  30. self.circid = circid
  31. self.ntor_request = ntor_request
  32. class VanillaCreatedCircuitMsg(RelayNetMsg):
  33. """The message for responding to circuit creation in Vanilla Onion
  34. Routing."""
  35. def __init__(self, ntor_reply):
  36. self.ntor_reply = ntor_reply
  37. class CircuitCellMsg(RelayNetMsg):
  38. """Send a message tagged with a circuit id."""
  39. def __init__(self, circuitid, cell):
  40. self.circid = circuitid
  41. self.cell = cell
  42. def __str__(self):
  43. return "C%d:%s" % (self.circid, self.cell)
  44. def size(self):
  45. # circuitids are 4 bytes
  46. return 4 + self.cell.size()
  47. class RelayFallbackTerminationError(Exception):
  48. """An exception raised when someone tries to terminate a fallback
  49. relay."""
  50. class NTor:
  51. """A class implementing the ntor one-way authenticated key agreement
  52. scheme. The details are not exactly the same as either the ntor
  53. paper or Tor's implementation, but it will agree on keys and have
  54. the same number of public key operations."""
  55. def __init__(self, perfstats):
  56. self.perfstats = perfstats
  57. def request(self):
  58. """Create the ntor request message: X = g^x."""
  59. self.client_ephem_key = nacl.public.PrivateKey.generate()
  60. self.perfstats.keygens += 1
  61. return self.client_ephem_key.public_key
  62. @staticmethod
  63. def reply(onion_privkey, idpubkey, client_pubkey, perfstats):
  64. """The server calls this static method to produce the ntor reply
  65. message: (Y = g^y, B = g^b, A = H(M, "verify")) and the shared
  66. secret S = H(M, "secret") for M = (X^y,X^b,ID,B,X,Y)."""
  67. server_ephem_key = nacl.public.PrivateKey.generate()
  68. perfstats.keygens += 1
  69. xykey = nacl.public.Box(server_ephem_key, client_pubkey).shared_key()
  70. xbkey = nacl.public.Box(onion_privkey, client_pubkey).shared_key()
  71. perfstats.dhs += 2
  72. M = xykey + xbkey + \
  73. idpubkey.encode(encoder=nacl.encoding.RawEncoder) + \
  74. onion_privkey.public_key.encode(encoder=nacl.encoding.RawEncoder) + \
  75. server_ephem_key.public_key.encode(encoder=nacl.encoding.RawEncoder)
  76. A = nacl.hash.sha256(M + b'verify', encoder=nacl.encoding.RawEncoder)
  77. S = nacl.hash.sha256(M + b'secret', encoder=nacl.encoding.RawEncoder)
  78. return ((server_ephem_key.public_key, onion_privkey.public_key, A), \
  79. S)
  80. def verify(self, reply, onion_pubkey, idpubkey):
  81. """The client calls this method to verify the ntor reply
  82. message, passing the onion and id public keys for the server
  83. it's expecting to be talking to . Returns the shared secret on
  84. success, or raises ValueError on failure."""
  85. server_ephem_pubkey, server_onion_pubkey, authtag = reply
  86. if onion_pubkey != server_onion_pubkey:
  87. raise ValueError("NTor onion pubkey mismatch")
  88. xykey = nacl.public.Box(self.client_ephem_key, server_ephem_pubkey).shared_key()
  89. xbkey = nacl.public.Box(self.client_ephem_key, onion_pubkey).shared_key()
  90. self.perfstats.dhs += 2
  91. M = xykey + xbkey + \
  92. idpubkey.encode(encoder=nacl.encoding.RawEncoder) + \
  93. onion_pubkey.encode(encoder=nacl.encoding.RawEncoder) + \
  94. server_ephem_pubkey.encode(encoder=nacl.encoding.RawEncoder)
  95. Acheck = nacl.hash.sha256(M + b'verify', encoder=nacl.encoding.RawEncoder)
  96. S = nacl.hash.sha256(M + b'secret', encoder=nacl.encoding.RawEncoder)
  97. if Acheck != authtag:
  98. raise ValueError("NTor auth mismatch")
  99. return S
  100. class CircuitHandler:
  101. """A class for managing sending and receiving encrypted cells on a
  102. particular circuit."""
  103. def __init__(self, channel, circid):
  104. self.channel = channel
  105. self.circid = circid
  106. self.send_cell = self.channel_send_cell
  107. # The list of relay descriptors that form the circuit so far
  108. # (client side only)
  109. self.circuit_descs = []
  110. # The dispatch table is indexed by type, and the values are
  111. # objects with received_cell(circhandler, cell) methods.
  112. self.cell_dispatch_table = dict()
  113. def channel_send_cell(self, cell):
  114. """Send a cell on this circuit."""
  115. self.channel.send_msg(CircuitCellMsg(self.circid, cell))
  116. def received_cell(self, cell):
  117. """A cell has been received on this circuit. Dispatch it
  118. according to its type."""
  119. celltype = type(cell)
  120. if celltype in self.cell_dispatch_table:
  121. self.cell_dispatch_table[celltype].received_cell(self, cell)
  122. class Channel(network.Connection):
  123. """A class representing a channel between a relay and either a
  124. client or a relay, transporting cells from various circuits."""
  125. def __init__(self):
  126. super().__init__()
  127. # The CellRelay managing this Channel
  128. self.cellhandler = None
  129. # The Channel at the other end
  130. self.peer = None
  131. # The function to call when the connection closes
  132. self.closer = lambda: 0
  133. # The next circuit id to use on this channel. The party that
  134. # opened the channel uses even numbers; the receiving party uses
  135. # odd numbers.
  136. self.next_circid = None
  137. # A map for CircuitHandlers to use for each open circuit on the
  138. # channel
  139. self.circuithandlers = dict()
  140. def closed(self):
  141. self.closer()
  142. self.peer = None
  143. def close(self):
  144. if self.peer is not None and self.peer is not self:
  145. self.peer.closed()
  146. self.closed()
  147. def new_circuit(self):
  148. """Allocate a new circuit on this channel, returning the new
  149. circuit's id and the new CircuitHandler."""
  150. circid = self.next_circid
  151. self.next_circid += 2
  152. circuithandler = CircuitHandler(self, circid)
  153. self.circuithandlers[circid] = circuithandler
  154. return circid, circuithandler
  155. def new_circuit_with_circid(self, circid):
  156. """Allocate a new circuit on this channel, with the circuit id
  157. received from our peer. Return the new CircuitHandler"""
  158. circuithandler = CircuitHandler(self, circid)
  159. self.circuithandlers[circid] = circuithandler
  160. return circuithandler
  161. def send_cell(self, circid, cell):
  162. """Send the given message on the given circuit, encrypting or
  163. decrypting as needed."""
  164. self.circuithandlers[circid].send_cell(cell)
  165. def send_raw_cell(self, circid, cell):
  166. """Send the given message, tagged for the given circuit id. No
  167. encryption or decryption is done."""
  168. self.send_msg(CircuitCellMsg(self.circid, self.cell))
  169. def send_msg(self, msg):
  170. """Send the given NetMsg on the channel."""
  171. self.cellhandler.perfstats.bytes_sent += msg.size()
  172. self.peer.received(self.cellhandler.myaddr, msg)
  173. def received(self, peeraddr, msg):
  174. """Callback when a message is received from the network."""
  175. self.cellhandler.perfstats.bytes_received += msg.size()
  176. if isinstance(msg, CircuitCellMsg):
  177. circid, cell = msg.circid, msg.cell
  178. self.circuithandlers[circid].received_cell(cell)
  179. else:
  180. self.cellhandler.received_msg(msg, peeraddr, self)
  181. class CellHandler:
  182. """The class that manages the channels to other relays and clients.
  183. Relays and clients both use subclasses of this class to both create
  184. on-demand channels to relays, to gracefully handle the closing of
  185. channels, and to handle commands received over the channels."""
  186. def __init__(self, myaddr, dirauthaddrs, perfstats):
  187. # A dictionary of Channels to other hosts, indexed by NetAddr
  188. self.channels = dict()
  189. self.myaddr = myaddr
  190. self.dirauthaddrs = dirauthaddrs
  191. self.consensus = None
  192. self.perfstats = perfstats
  193. def terminate(self):
  194. """Close all connections we're managing."""
  195. while self.channels:
  196. channelitems = iter(self.channels.items())
  197. addr, channel = next(channelitems)
  198. print('closing channel', addr, channel)
  199. channel.close()
  200. def add_channel(self, channel, peeraddr):
  201. """Add the given channel to the list of channels we are
  202. managing. If we are already managing a channel to the same
  203. peer, close it first."""
  204. if peeraddr in self.channels:
  205. self.channels[peeraddr].close()
  206. channel.cellhandler = self
  207. self.channels[peeraddr] = channel
  208. channel.closer = lambda: self.channels.pop(peeraddr)
  209. def get_channel_to(self, addr):
  210. """Get the Channel connected to the given NetAddr, creating one
  211. if none exists right now."""
  212. if addr in self.channels:
  213. return self.channels[addr]
  214. # Create the new channel
  215. newchannel = network.thenetwork.connect(self.myaddr, addr, \
  216. self.perfstats)
  217. self.channels[addr] = newchannel
  218. newchannel.closer = lambda: self.channels.pop(addr)
  219. newchannel.cellhandler = self
  220. return newchannel
  221. def received_msg(self, msg, peeraddr, channel):
  222. """Callback when a NetMsg not specific to a circuit is
  223. received."""
  224. print("CellHandler: Node %s received msg %s from %s" % (self.myaddr, msg, peeraddr))
  225. def received_cell(self, circid, cell, peeraddr, channel):
  226. """Callback with a circuit-specific cell is received."""
  227. print("CellHandler: Node %s received cell on circ %d: %s from %s" % (self.myaddr, circid, cell, peeraddr))
  228. def send_msg(self, msg, peeraddr):
  229. """Send a message to the peer with the given address."""
  230. channel = self.get_channel_to(peeraddr)
  231. channel.send_msg(msg)
  232. def send_cell(self, circid, cell, peeraddr):
  233. """Send a cell on the given circuit to the peer with the given
  234. address."""
  235. channel = self.get_channel_to(peeraddr)
  236. channel.send_cell(circid, cell)
  237. class CellRelay(CellHandler):
  238. """The subclass of CellHandler for relays."""
  239. def __init__(self, myaddr, dirauthaddrs, onionprivkey, idpubkey, perfstats):
  240. super().__init__(myaddr, dirauthaddrs, perfstats)
  241. self.onionkey = onionprivkey
  242. self.idpubkey = idpubkey
  243. def get_consensus(self):
  244. """Download a fresh consensus from a random dirauth."""
  245. a = random.choice(self.dirauthaddrs)
  246. c = network.thenetwork.connect(self, a, self.perfstats)
  247. self.consensus = c.getconsensus()
  248. dirauth.Consensus.verify(self.consensus, \
  249. network.thenetwork.dirauthkeys(), self.perfstats)
  250. c.close()
  251. def received_msg(self, msg, peeraddr, channel):
  252. """Callback when a NetMsg not specific to a circuit is
  253. received."""
  254. print("CellRelay: Node %s received msg %s from %s" % (self.myaddr, msg, peeraddr))
  255. if isinstance(msg, RelayRandomHopMsg):
  256. if msg.ttl > 0:
  257. # Pick a random next hop from the consensus
  258. nexthop = random.choice(self.consensus.consdict['relays'])
  259. nextaddr = nexthop.descdict['addr']
  260. self.send_msg(RelayRandomHopMsg(msg.ttl-1), nextaddr)
  261. elif isinstance(msg, RelayGetConsensusMsg):
  262. self.send_msg(RelayConsensusMsg(self.consensus), peeraddr)
  263. elif isinstance(msg, VanillaCreateCircuitMsg):
  264. # A new circuit has arrived
  265. circhandler = channel.new_circuit_with_circid(msg.circid)
  266. # Create the ntor reply
  267. reply, secret = NTor.reply(self.onionkey, self.idpubkey, \
  268. msg.ntor_request, self.perfstats)
  269. # Set up the circuit to use the shared secret
  270. # TODO
  271. print('relay secret=', secret)
  272. # Send the ntor reply
  273. self.send_msg(CircuitCellMsg(msg.circid, VanillaCreatedCircuitMsg(reply)), peeraddr)
  274. else:
  275. return super().received_msg(msg, peeraddr, channel)
  276. def received_cell(self, circid, cell, peeraddr, channel):
  277. """Callback with a circuit-specific cell is received."""
  278. print("CellRelay: Node %s received cell on circ %d: %s from %s" % (self.myaddr, circid, cell, peeraddr))
  279. return super().received_cell(circid, cell, peeraddr, channel)
  280. class Relay(network.Server):
  281. """The class representing an onion relay."""
  282. def __init__(self, dirauthaddrs, bw, flags):
  283. # Gather performance statistics
  284. self.perfstats = network.PerfStats(network.EntType.RELAY)
  285. self.perfstats.is_bootstrapping = True
  286. # Create the identity and onion keys
  287. self.idkey = nacl.signing.SigningKey.generate()
  288. self.onionkey = nacl.public.PrivateKey.generate()
  289. self.perfstats.keygens += 2
  290. self.name = self.idkey.verify_key.encode(encoder=nacl.encoding.HexEncoder).decode("ascii")
  291. # Bind to the network to get a network address
  292. self.netaddr = network.thenetwork.bind(self)
  293. self.perfstats.name = "Relay at %s" % self.netaddr
  294. # Our bandwidth and flags
  295. self.bw = bw
  296. self.flags = flags
  297. # Register for epoch change notification
  298. network.thenetwork.wantepochticks(self, True, end=True)
  299. network.thenetwork.wantepochticks(self, True)
  300. # Create the CellRelay connection manager
  301. self.cellhandler = CellRelay(self.netaddr, dirauthaddrs, \
  302. self.onionkey, self.idkey.verify_key, self.perfstats)
  303. # Initially, we're not a fallback relay
  304. self.is_fallbackrelay = False
  305. self.uploaddesc()
  306. def terminate(self):
  307. """Stop this relay."""
  308. if self.is_fallbackrelay:
  309. # Fallback relays must not (for now) terminate
  310. raise RelayFallbackTerminationError(self)
  311. # Stop listening for epoch ticks
  312. network.thenetwork.wantepochticks(self, False, end=True)
  313. network.thenetwork.wantepochticks(self, False)
  314. # Tell the dirauths we're going away
  315. self.uploaddesc(False)
  316. # Close connections to other relays
  317. self.cellhandler.terminate()
  318. # Stop listening to our own bound port
  319. self.close()
  320. def set_is_fallbackrelay(self, isfallback = True):
  321. """Set this relay to be a fallback relay (or unset if passed
  322. False)."""
  323. self.is_fallbackrelay = isfallback
  324. def epoch_ending(self, epoch):
  325. # Download the new consensus, which will have been created
  326. # already since the dirauths' epoch_ending callbacks happened
  327. # before the relays'.
  328. self.cellhandler.get_consensus()
  329. def newepoch(self, epoch):
  330. self.uploaddesc()
  331. def uploaddesc(self, upload=True):
  332. # Upload the descriptor for the epoch to come, or delete a
  333. # previous upload if upload=False
  334. descdict = dict();
  335. descdict["epoch"] = network.thenetwork.getepoch() + 1
  336. descdict["idkey"] = self.idkey.verify_key
  337. descdict["onionkey"] = self.onionkey.public_key
  338. descdict["addr"] = self.netaddr
  339. descdict["bw"] = self.bw
  340. descdict["flags"] = self.flags
  341. desc = dirauth.RelayDescriptor(descdict)
  342. desc.sign(self.idkey, self.perfstats)
  343. dirauth.RelayDescriptor.verify(desc, self.perfstats)
  344. if upload:
  345. descmsg = dirauth.DirAuthUploadDescMsg(desc)
  346. else:
  347. # Note that this relies on signatures being deterministic;
  348. # otherwise we'd need to save the descriptor we uploaded
  349. # before so we could tell the airauths to delete the exact
  350. # one
  351. descmsg = dirauth.DirAuthDelDescMsg(desc)
  352. # Upload them
  353. for a in self.cellhandler.dirauthaddrs:
  354. c = network.thenetwork.connect(self, a, self.perfstats)
  355. c.sendmsg(descmsg)
  356. c.close()
  357. def connected(self, peer):
  358. """Callback invoked when someone (client or relay) connects to
  359. us. Create a pair of linked Channels and return the peer half
  360. to the peer."""
  361. # Create the linked pair
  362. if peer is self.netaddr:
  363. # A self-loop? We'll allow it.
  364. peerchannel = Channel()
  365. peerchannel.peer = peerchannel
  366. peerchannel.next_circid = 2
  367. return peerchannel
  368. peerchannel = Channel()
  369. ourchannel = Channel()
  370. peerchannel.peer = ourchannel
  371. peerchannel.next_circid = 2
  372. ourchannel.peer = peerchannel
  373. ourchannel.next_circid = 1
  374. # Add our channel to the CellRelay
  375. self.cellhandler.add_channel(ourchannel, peer)
  376. return peerchannel
  377. if __name__ == '__main__':
  378. perfstats = network.PerfStats(network.EntType.NONE)
  379. # Start some dirauths
  380. numdirauths = 9
  381. dirauthaddrs = []
  382. for i in range(numdirauths):
  383. dira = dirauth.DirAuth(i, numdirauths)
  384. dirauthaddrs.append(dira.netaddr)
  385. # Start some relays
  386. numrelays = 10
  387. relays = []
  388. for i in range(numrelays):
  389. # Relay bandwidths (at least the ones fast enough to get used)
  390. # in the live Tor network (as of Dec 2019) are well approximated
  391. # by (200000-(200000-25000)/3*log10(x)) where x is a
  392. # uniform integer in [1,2500]
  393. x = random.randint(1,2500)
  394. bw = int(200000-(200000-25000)/3*math.log10(x))
  395. relays.append(Relay(dirauthaddrs, bw, 0))
  396. # The fallback relays are a hardcoded list of about 5% of the
  397. # relays, used by clients for bootstrapping
  398. numfallbackrelays = int(numrelays * 0.05) + 1
  399. fallbackrelays = random.sample(relays, numfallbackrelays)
  400. for r in fallbackrelays:
  401. r.set_is_fallbackrelay()
  402. network.thenetwork.setfallbackrelays(fallbackrelays)
  403. # Tick the epoch
  404. network.thenetwork.nextepoch()
  405. dirauth.Consensus.verify(dirauth.DirAuth.consensus, \
  406. network.thenetwork.dirauthkeys(), perfstats)
  407. print('ticked; epoch=', network.thenetwork.getepoch())
  408. relays[3].cellhandler.send_msg(RelayRandomHopMsg(30), relays[5].netaddr)
  409. # See what channels exist and do a consistency check
  410. for r in relays:
  411. print("%s: %s" % (r.netaddr, [ str(k) for k in r.cellhandler.channels.keys()]))
  412. raddr = r.netaddr
  413. for ad, ch in r.cellhandler.channels.items():
  414. if ch.peer.cellhandler.myaddr != ad:
  415. print('address mismatch:', raddr, ad, ch.peer.cellhandler.myaddr)
  416. if ch.peer.cellhandler.channels[raddr].peer is not ch:
  417. print('asymmetry:', raddr, ad, ch, ch.peer.cellhandler.channels[raddr].peer)
  418. # Stop some relays
  419. relays[3].terminate()
  420. del relays[3]
  421. relays[5].terminate()
  422. del relays[5]
  423. relays[7].terminate()
  424. del relays[7]
  425. # Tick the epoch
  426. network.thenetwork.nextepoch()
  427. print(dirauth.DirAuth.consensus)
  428. # See what channels exist and do a consistency check
  429. for r in relays:
  430. print("%s: %s" % (r.netaddr, [ str(k) for k in r.cellhandler.channels.keys()]))
  431. raddr = r.netaddr
  432. for ad, ch in r.cellhandler.channels.items():
  433. if ch.peer.cellhandler.myaddr != ad:
  434. print('address mismatch:', raddr, ad, ch.peer.cellhandler.myaddr)
  435. if ch.peer.cellhandler.channels[raddr].peer is not ch:
  436. print('asymmetry:', raddr, ad, ch, ch.peer.cellhandler.channels[raddr].peer)
  437. channel = relays[3].cellhandler.get_channel_to(relays[5].netaddr)
  438. circid, circhandler = channel.new_circuit()
  439. peerchannel = relays[5].cellhandler.get_channel_to(relays[3].netaddr)
  440. peerchannel.new_circuit_with_circid(circid)
  441. relays[3].cellhandler.send_cell(circid, network.StringNetMsg("test"), relays[5].netaddr)
  442. idpubkey = dirauth.DirAuth.consensus.consdict["relays"][1].descdict["idkey"]
  443. onionpubkey = dirauth.DirAuth.consensus.consdict["relays"][1].descdict["onionkey"]
  444. nt = NTor(perfstats)
  445. req = nt.request()
  446. R, S = NTor.reply(relays[1].onionkey, idpubkey, req, perfstats)
  447. S2 = nt.verify(R, onionpubkey, idpubkey)
  448. print(S == S2)