dirauth.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. #!/usr/bin/env python3
  2. import random # For simulation, not cryptography!
  3. import bisect
  4. import math
  5. import nacl.encoding
  6. import nacl.signing
  7. import network
  8. # A relay descriptor is a dict containing:
  9. # epoch: epoch id
  10. # idkey: a public identity key
  11. # onionkey: a public onion key
  12. # addr: a network address
  13. # bw: bandwidth
  14. # flags: relay flags
  15. # vrfkey: a VRF public key (Single-Pass Walking Onions only)
  16. # sig: a signature over the above by the idkey
  17. class RelayDescriptor:
  18. def __init__(self, descdict):
  19. self.descdict = descdict
  20. def __str__(self, withsig = True):
  21. res = "RelayDesc [\n"
  22. for k in ["epoch", "idkey", "onionkey", "addr", "bw", "flags",
  23. "vrfkey", "sig"]:
  24. if k in self.descdict:
  25. if k == "idkey" or k == "onionkey":
  26. res += " " + k + ": " + self.descdict[k].encode(encoder=nacl.encoding.HexEncoder).decode("ascii") + "\n"
  27. elif k == "sig":
  28. if withsig:
  29. res += " " + k + ": " + nacl.encoding.HexEncoder.encode(self.descdict[k]).decode("ascii") + "\n"
  30. else:
  31. res += " " + k + ": " + str(self.descdict[k]) + "\n"
  32. res += "]\n"
  33. return res
  34. def sign(self, signingkey, perfstats):
  35. serialized = self.__str__(False)
  36. signed = signingkey.sign(serialized.encode("ascii"))
  37. perfstats.sigs += 1
  38. self.descdict["sig"] = signed.signature
  39. @staticmethod
  40. def verify(desc, perfstats):
  41. assert(type(desc) is RelayDescriptor)
  42. serialized = desc.__str__(False)
  43. perfstats.verifs += 1
  44. desc.descdict["idkey"].verify(serialized.encode("ascii"), desc.descdict["sig"])
  45. # A consensus is a dict containing:
  46. # epoch: epoch id
  47. # numrelays: total number of relays
  48. # totbw: total bandwidth of relays
  49. # relays: list of relay descriptors (Vanilla Onion Routing only)
  50. # sigs: list of signatures from the dirauths
  51. class Consensus:
  52. def __init__(self, epoch, relays):
  53. relays = [ d for d in relays if d.descdict['epoch'] == epoch ]
  54. self.consdict = dict()
  55. self.consdict['epoch'] = epoch
  56. self.consdict['numrelays'] = len(relays)
  57. self.consdict['totbw'] = sum([ d.descdict['bw'] for d in relays ])
  58. self.consdict['relays'] = relays
  59. def __str__(self, withsigs = True):
  60. res = "Consensus [\n"
  61. for k in ["epoch", "numrelays", "totbw"]:
  62. if k in self.consdict:
  63. res += " " + k + ": " + str(self.consdict[k]) + "\n"
  64. for r in self.consdict['relays']:
  65. res += str(r)
  66. if withsigs and ('sigs' in self.consdict):
  67. for s in self.consdict['sigs']:
  68. res += " sig: " + nacl.encoding.HexEncoder.encode(s).decode("ascii") + "\n"
  69. res += "]\n"
  70. return res
  71. def sign(self, signingkey, index, perfstats):
  72. """Use the given signing key to sign the consensus, storing the
  73. result in the sigs list at the given index."""
  74. serialized = self.__str__(False)
  75. signed = signingkey.sign(serialized.encode("ascii"))
  76. perfstats.sigs += 1
  77. if 'sigs' not in self.consdict:
  78. self.consdict['sigs'] = []
  79. if index >= len(self.consdict['sigs']):
  80. self.consdict['sigs'].extend([None] * (index+1-len(self.consdict['sigs'])))
  81. self.consdict['sigs'][index] = signed.signature
  82. def bw_cdf(self):
  83. """Create the array of cumulative bandwidth values from a consensus.
  84. The array (cdf) will have the same length as the number of relays
  85. in the consensus. cdf[0] = 0, and cdf[i] = cdf[i-1] + relay[i-1].bw."""
  86. cdf = [0]
  87. for r in self.consdict['relays']:
  88. cdf.append(cdf[-1]+r.descdict['bw'])
  89. # Remove the last item, which should be the sum of all the bws
  90. cdf.pop()
  91. print('cdf=', cdf)
  92. return cdf
  93. def select_weighted_relay(self, cdf):
  94. """Use the cdf generated by bw_cdf to select a relay with
  95. probability proportional to its bw weight."""
  96. totbw = self.consdict['totbw']
  97. if totbw < 1:
  98. raise ValueError("No relays to choose from")
  99. val = random.randint(0, totbw-1)
  100. # Find the rightmost entry less than or equal to val
  101. idx = bisect.bisect_right(cdf, val)
  102. return self.consdict['relays'][idx-1]
  103. @staticmethod
  104. def verify(consensus, verifkeylist, perfstats):
  105. """Use the given list of verification keys to check the
  106. signatures on the consensus."""
  107. assert(type(consensus) is Consensus)
  108. serialized = consensus.__str__(False)
  109. for i, vk in enumerate(verifkeylist):
  110. perfstats.verifs += 1
  111. vk.verify(serialized.encode("ascii"), consensus.consdict['sigs'][i])
  112. class DirAuthNetMsg(network.NetMsg):
  113. """The subclass of NetMsg for messages to and from directory
  114. authorities."""
  115. class DirAuthUploadDescMsg(DirAuthNetMsg):
  116. """The subclass of DirAuthNetMsg for uploading a relay
  117. descriptor."""
  118. def __init__(self, desc):
  119. self.desc = desc
  120. class DirAuthDelDescMsg(DirAuthNetMsg):
  121. """The subclass of DirAuthNetMsg for deleting a relay
  122. descriptor."""
  123. def __init__(self, desc):
  124. self.desc = desc
  125. class DirAuthGetConsensusMsg(DirAuthNetMsg):
  126. """The subclass of DirAuthNetMsg for fetching the consensus."""
  127. class DirAuthConsensusMsg(DirAuthNetMsg):
  128. """The subclass of DirAuthNetMsg for returning the consensus."""
  129. def __init__(self, consensus):
  130. self.consensus = consensus
  131. class DirAuthGetConsensusDiffMsg(DirAuthNetMsg):
  132. """The subclass of DirAuthNetMsg for fetching the consensus, if the
  133. requestor already has the previous consensus."""
  134. class DirAuthConsensusDiffMsg(DirAuthNetMsg):
  135. """The subclass of DirAuthNetMsg for returning the consensus, if the
  136. requestor already has the previous consensus. We don't _actually_
  137. produce the diff at this time; we just charge fewer bytes for this
  138. message."""
  139. def __init__(self, consensus):
  140. self.consensus = consensus
  141. def size(self):
  142. if network.symbolic_byte_counters:
  143. return super().size()
  144. return math.ceil(DirAuthConsensusMsg(self.consensus).size() \
  145. * network.P_Delta)
  146. class DirAuthGetENDIVEMsg(DirAuthNetMsg):
  147. """The subclass of DirAuthNetMsg for fetching the ENDIVE."""
  148. class DirAuthENDIVEMsg(DirAuthNetMsg):
  149. """The subclass of DirAuthNetMsg for returning the ENDIVE."""
  150. def __init__(self, endive):
  151. self.endive = endive
  152. class DirAuthConnection(network.ClientConnection):
  153. """The subclass of Connection for connections to directory
  154. authorities."""
  155. def __init__(self, peer):
  156. super().__init__(peer)
  157. def uploaddesc(self, desc):
  158. """Upload our RelayDescriptor to the DirAuth."""
  159. self.sendmsg(DirAuthUploadDescMeg(desc))
  160. def getconsensus(self):
  161. self.consensus = None
  162. self.sendmsg(DirAuthGetConsensusMsg())
  163. return self.consensus
  164. def getconsensusdiff(self):
  165. self.consensus = None
  166. self.sendmsg(DirAuthGetConsensusDiffMsg())
  167. return self.consensus
  168. def getENDIVE(self):
  169. self.endive = None
  170. self.sendmsg(DirAuthGetENDIVEMsg())
  171. return self.endive
  172. def receivedfromserver(self, msg):
  173. if isinstance(msg, DirAuthConsensusMsg):
  174. self.consensus = msg.consensus
  175. elif isinstance(msg, DirAuthConsensusDiffMsg):
  176. self.consensus = msg.consensus
  177. elif isinstance(msg, DirAuthENDIVEMsg):
  178. self.endive = msg.endive
  179. else:
  180. raise TypeError('Not a server-originating DirAuthNetMsg', msg)
  181. class DirAuth(network.Server):
  182. """The class representing directory authorities."""
  183. # We simulate the act of computing the consensus by keeping a
  184. # class-static dict that's accessible to all of the dirauths
  185. # This dict is indexed by epoch, and the value is itself a dict
  186. # indexed by the stringified descriptor, with value a pair of (the
  187. # number of dirauths that saw that descriptor, the descriptor
  188. # itself).
  189. uploadeddescs = dict()
  190. consensus = None
  191. endive = None
  192. def __init__(self, me, tot):
  193. """Create a new directory authority. me is the index of which
  194. dirauth this one is (starting from 0), and tot is the total
  195. number of dirauths."""
  196. self.me = me
  197. self.tot = tot
  198. self.name = "Dirauth %d of %d" % (me+1, tot)
  199. self.perfstats = network.PerfStats(network.EntType.DIRAUTH)
  200. self.perfstats.is_bootstrapping = True
  201. # Create the dirauth signature keypair
  202. self.sigkey = nacl.signing.SigningKey.generate()
  203. self.perfstats.keygens += 1
  204. self.netaddr = network.thenetwork.bind(self)
  205. self.perfstats.name = "DirAuth at %s" % self.netaddr
  206. network.thenetwork.setdirauthkey(me, self.sigkey.verify_key)
  207. network.thenetwork.wantepochticks(self, True, True)
  208. def connected(self, client):
  209. """Callback invoked when a client connects to us. This callback
  210. creates the DirAuthConnection that will be passed to the
  211. client."""
  212. # We don't actually need to keep per-connection state at
  213. # dirauths, even in long-lived connections, so this is
  214. # particularly simple.
  215. return DirAuthConnection(self)
  216. def generate_consensus(self, epoch):
  217. """Generate the consensus (and ENDIVE, if using Walking Onions)
  218. for the given epoch, which should be the one after the one
  219. that's currently about to end."""
  220. threshold = int(self.tot/2)+1
  221. consensusdescs = []
  222. for numseen, desc in DirAuth.uploadeddescs[epoch].values():
  223. if numseen >= threshold:
  224. consensusdescs.append(desc)
  225. DirAuth.consensus = Consensus(epoch, consensusdescs)
  226. def epoch_ending(self, epoch):
  227. # Only dirauth 0 actually needs to generate the consensus
  228. # because of the shared class-static state, but everyone has to
  229. # sign it. Note that this code relies on dirauth 0's
  230. # epoch_ending callback being called before any of the other
  231. # dirauths'.
  232. if (epoch+1) not in DirAuth.uploadeddescs:
  233. DirAuth.uploadeddescs[epoch+1] = dict()
  234. if self.me == 0:
  235. self.generate_consensus(epoch+1)
  236. del DirAuth.uploadeddescs[epoch+1]
  237. DirAuth.consensus.sign(self.sigkey, self.me, self.perfstats)
  238. def received(self, client, msg):
  239. self.perfstats.bytes_received += msg.size()
  240. if isinstance(msg, DirAuthUploadDescMsg):
  241. # Check the uploaded descriptor for sanity
  242. epoch = msg.desc.descdict['epoch']
  243. if epoch != network.thenetwork.getepoch() + 1:
  244. return
  245. # Store it in the class-static dict
  246. if epoch not in DirAuth.uploadeddescs:
  247. DirAuth.uploadeddescs[epoch] = dict()
  248. descstr = str(msg.desc)
  249. if descstr not in DirAuth.uploadeddescs[epoch]:
  250. DirAuth.uploadeddescs[epoch][descstr] = (1, msg.desc)
  251. else:
  252. DirAuth.uploadeddescs[epoch][descstr] = \
  253. (DirAuth.uploadeddescs[epoch][descstr][0]+1,
  254. DirAuth.uploadeddescs[epoch][descstr][1])
  255. elif isinstance(msg, DirAuthDelDescMsg):
  256. # Check the uploaded descriptor for sanity
  257. epoch = msg.desc.descdict['epoch']
  258. if epoch != network.thenetwork.getepoch() + 1:
  259. return
  260. # Remove it from the class-static dict
  261. if epoch not in DirAuth.uploadeddescs:
  262. return
  263. descstr = str(msg.desc)
  264. if descstr not in DirAuth.uploadeddescs[epoch]:
  265. return
  266. elif DirAuth.uploadeddescs[epoch][descstr][0] == 1:
  267. del DirAuth.uploadeddescs[epoch][descstr]
  268. else:
  269. DirAuth.uploadeddescs[epoch][descstr] = \
  270. (DirAuth.uploadeddescs[epoch][descstr][0]-1,
  271. DirAuth.uploadeddescs[epoch][descstr][1])
  272. elif isinstance(msg, DirAuthGetConsensusMsg):
  273. replymsg = DirAuthConsensusMsg(DirAuth.consensus)
  274. msgsize = replymsg.size()
  275. self.perfstats.bytes_sent += msgsize
  276. client.reply(replymsg)
  277. elif isinstance(msg, DirAuthGetConsensusDiffMsg):
  278. replymsg = DirAuthConsensusDiffMsg(DirAuth.consensus)
  279. msgsize = replymsg.size()
  280. self.perfstats.bytes_sent += msgsize
  281. client.reply(replymsg)
  282. elif isinstance(msg, DirAuthGetENDIVEMsg):
  283. replymsg = DirAuthENDIVEMsg(DirAuth.endive)
  284. msgsize = replymsg.size()
  285. self.perfstats.bytes_sent += msgsize
  286. client.reply(replymsg)
  287. else:
  288. raise TypeError('Not a client-originating DirAuthNetMsg', msg)
  289. def closed(self):
  290. pass
  291. if __name__ == '__main__':
  292. # Start some dirauths
  293. numdirauths = 9
  294. dirauthaddrs = []
  295. for i in range(numdirauths):
  296. dirauth = DirAuth(i, numdirauths)
  297. dirauthaddrs.append(dirauth.netaddr)
  298. for a in dirauthaddrs:
  299. print(a,end=' ')
  300. print()
  301. network.thenetwork.nextepoch()