dirauth.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. #!/usr/bin/env python3
  2. import random # For simulation, not cryptography!
  3. import bisect
  4. import nacl.encoding
  5. import nacl.signing
  6. import network
  7. from enum import Enum
  8. class EntType(Enum):
  9. """The different types of entities in the system."""
  10. NONE = 0
  11. DIRAUTH = 1
  12. RELAY = 2
  13. CLIENT = 3
  14. class PerfStats:
  15. """A class to store performance statistics for a relay or client.
  16. We keep track of bytes sent, bytes received, and counts of
  17. public-key operations of various types. We will reset these every
  18. epoch."""
  19. def __init__(self, ent_type):
  20. # Which type of entity is this for (DIRAUTH, RELAY, CLIENT)
  21. self.ent_type = ent_type
  22. # A printable name for the entity
  23. self.name = None
  24. # True if bootstrapping this epoch
  25. self.is_bootstrapping = False
  26. # Bytes sent and received
  27. self.bytes_sent = 0
  28. self.bytes_received = 0
  29. # Public-key operations: key generation, signing, verification,
  30. # Diffie-Hellman
  31. self.keygens = 0
  32. self.sigs = 0
  33. self.verifs = 0
  34. self.dhs = 0
  35. def __str__(self):
  36. return "%s: type=%s boot=%s sent=%d recv=%d keygen=%d sig=%d verif=%d dh=%d" % \
  37. (self.name, self.ent_type.name, self.is_bootstrapping, \
  38. self.bytes_sent, self.bytes_received, self.keygens, \
  39. self.sigs, self.verifs, self.dhs)
  40. # A relay descriptor is a dict containing:
  41. # epoch: epoch id
  42. # idkey: a public identity key
  43. # onionkey: a public onion key
  44. # addr: a network address
  45. # bw: bandwidth
  46. # flags: relay flags
  47. # vrfkey: a VRF public key (Single-Pass Walking Onions only)
  48. # sig: a signature over the above by the idkey
  49. class RelayDescriptor:
  50. def __init__(self, descdict):
  51. self.descdict = descdict
  52. def __str__(self, withsig = True):
  53. res = "RelayDesc [\n"
  54. for k in ["epoch", "idkey", "onionkey", "addr", "bw", "flags",
  55. "vrfkey", "sig"]:
  56. if k in self.descdict:
  57. if k == "idkey" or k == "onionkey":
  58. res += " " + k + ": " + self.descdict[k].encode(encoder=nacl.encoding.HexEncoder).decode("ascii") + "\n"
  59. elif k == "sig":
  60. if withsig:
  61. res += " " + k + ": " + nacl.encoding.HexEncoder.encode(self.descdict[k]).decode("ascii") + "\n"
  62. else:
  63. res += " " + k + ": " + str(self.descdict[k]) + "\n"
  64. res += "]\n"
  65. return res
  66. def sign(self, signingkey, perfstats):
  67. serialized = self.__str__(False)
  68. signed = signingkey.sign(serialized.encode("ascii"))
  69. perfstats.sigs += 1
  70. self.descdict["sig"] = signed.signature
  71. @staticmethod
  72. def verify(desc, perfstats):
  73. assert(type(desc) is RelayDescriptor)
  74. serialized = desc.__str__(False)
  75. perfstats.verifs += 1
  76. desc.descdict["idkey"].verify(serialized.encode("ascii"), desc.descdict["sig"])
  77. # A consensus is a dict containing:
  78. # epoch: epoch id
  79. # numrelays: total number of relays
  80. # totbw: total bandwidth of relays
  81. # relays: list of relay descriptors (Vanilla Onion Routing only)
  82. # sigs: list of signatures from the dirauths
  83. class Consensus:
  84. def __init__(self, epoch, relays):
  85. relays = [ d for d in relays if d.descdict['epoch'] == epoch ]
  86. self.consdict = dict()
  87. self.consdict['epoch'] = epoch
  88. self.consdict['numrelays'] = len(relays)
  89. self.consdict['totbw'] = sum([ d.descdict['bw'] for d in relays ])
  90. self.consdict['relays'] = relays
  91. def __str__(self, withsigs = True):
  92. res = "Consensus [\n"
  93. for k in ["epoch", "numrelays", "totbw"]:
  94. if k in self.consdict:
  95. res += " " + k + ": " + str(self.consdict[k]) + "\n"
  96. for r in self.consdict['relays']:
  97. res += str(r)
  98. if withsigs and ('sigs' in self.consdict):
  99. for s in self.consdict['sigs']:
  100. res += " sig: " + nacl.encoding.HexEncoder.encode(s).decode("ascii") + "\n"
  101. res += "]\n"
  102. return res
  103. def sign(self, signingkey, index, perfstats):
  104. """Use the given signing key to sign the consensus, storing the
  105. result in the sigs list at the given index."""
  106. serialized = self.__str__(False)
  107. signed = signingkey.sign(serialized.encode("ascii"))
  108. perfstats.sigs += 1
  109. if 'sigs' not in self.consdict:
  110. self.consdict['sigs'] = []
  111. if index >= len(self.consdict['sigs']):
  112. self.consdict['sigs'].extend([None] * (index+1-len(self.consdict['sigs'])))
  113. self.consdict['sigs'][index] = signed.signature
  114. def bw_cdf(self):
  115. """Create the array of cumulative bandwidth values from a consensus.
  116. The array (cdf) will have the same length as the number of relays
  117. in the consensus. cdf[0] = 0, and cdf[i] = cdf[i-1] + relay[i-1].bw."""
  118. cdf = [0]
  119. for r in self.consdict['relays']:
  120. cdf.append(cdf[-1]+r.descdict['bw'])
  121. # Remove the last item, which should be the sum of all the bws
  122. cdf.pop()
  123. print('cdf=', cdf)
  124. return cdf
  125. def select_weighted_relay(self, cdf):
  126. """Use the cdf generated by bw_cdf to select a relay with
  127. probability proportional to its bw weight."""
  128. totbw = self.consdict['totbw']
  129. if totbw < 1:
  130. raise ValueError("No relays to choose from")
  131. val = random.randint(0, totbw-1)
  132. # Find the rightmost entry less than or equal to val
  133. idx = bisect.bisect_right(cdf, val)
  134. return self.consdict['relays'][idx-1]
  135. @staticmethod
  136. def verify(consensus, verifkeylist, perfstats):
  137. """Use the given list of verification keys to check the
  138. signatures on the consensus."""
  139. assert(type(consensus) is Consensus)
  140. serialized = consensus.__str__(False)
  141. for i, vk in enumerate(verifkeylist):
  142. perfstats.verifs += 1
  143. vk.verify(serialized.encode("ascii"), consensus.consdict['sigs'][i])
  144. class DirAuthNetMsg(network.NetMsg):
  145. """The subclass of NetMsg for messages to and from directory
  146. authorities."""
  147. class DirAuthUploadDescMsg(DirAuthNetMsg):
  148. """The subclass of DirAuthNetMsg for uploading a relay
  149. descriptor."""
  150. def __init__(self, desc):
  151. self.desc = desc
  152. class DirAuthDelDescMsg(DirAuthNetMsg):
  153. """The subclass of DirAuthNetMsg for deleting a relay
  154. descriptor."""
  155. def __init__(self, desc):
  156. self.desc = desc
  157. class DirAuthGetConsensusMsg(DirAuthNetMsg):
  158. """The subclass of DirAuthNetMsg for fetching the consensus."""
  159. class DirAuthConsensusMsg(DirAuthNetMsg):
  160. """The subclass of DirAuthNetMsg for returning the consensus."""
  161. def __init__(self, consensus):
  162. self.consensus = consensus
  163. class DirAuthGetENDIVEMsg(DirAuthNetMsg):
  164. """The subclass of DirAuthNetMsg for fetching the ENDIVE."""
  165. class DirAuthENDIVEMsg(DirAuthNetMsg):
  166. """The subclass of DirAuthNetMsg for returning the ENDIVE."""
  167. def __init__(self, endive):
  168. self.endive = endive
  169. class DirAuthConnection(network.ClientConnection):
  170. """The subclass of Connection for connections to directory
  171. authorities."""
  172. def __init__(self, peer = None):
  173. super().__init__(peer)
  174. def uploaddesc(self, desc):
  175. """Upload our RelayDescriptor to the DirAuth."""
  176. self.sendmsg(DirAuthUploadDescMeg(desc))
  177. def getconsensus(self):
  178. self.consensus = None
  179. self.sendmsg(DirAuthGetConsensusMsg())
  180. return self.consensus
  181. def getENDIVE(self):
  182. self.endive = None
  183. self.sendmsg(DirAuthGetENDIVEMsg())
  184. return self.endive
  185. def receivedfromserver(self, msg):
  186. if isinstance(msg, DirAuthConsensusMsg):
  187. self.consensus = msg.consensus
  188. elif isinstance(msg, DirAuthENDIVEMsg):
  189. self.endive = msg.endive
  190. else:
  191. raise TypeError('Not a server-originating DirAuthNetMsg', msg)
  192. class DirAuth(network.Server):
  193. """The class representing directory authorities."""
  194. # We simulate the act of computing the consensus by keeping a
  195. # class-static dict that's accessible to all of the dirauths
  196. # This dict is indexed by epoch, and the value is itself a dict
  197. # indexed by the stringified descriptor, with value a pair of (the
  198. # number of dirauths that saw that descriptor, the descriptor
  199. # itself).
  200. uploadeddescs = dict()
  201. consensus = None
  202. endive = None
  203. def __init__(self, me, tot):
  204. """Create a new directory authority. me is the index of which
  205. dirauth this one is (starting from 0), and tot is the total
  206. number of dirauths."""
  207. self.me = me
  208. self.tot = tot
  209. self.name = "Dirauth %d of %d" % (me+1, tot)
  210. self.perfstats = PerfStats(EntType.DIRAUTH)
  211. self.perfstats.is_bootstrapping = True
  212. # Create the dirauth signature keypair
  213. self.sigkey = nacl.signing.SigningKey.generate()
  214. self.perfstats.keygens += 1
  215. self.netaddr = network.thenetwork.bind(self)
  216. self.perfstats.name = "DirAuth at %s" % self.netaddr
  217. network.thenetwork.setdirauthkey(me, self.sigkey.verify_key)
  218. network.thenetwork.wantepochticks(self, True, True)
  219. def connected(self, client):
  220. """Callback invoked when a client connects to us. This callback
  221. creates the DirAuthConnection that will be passed to the
  222. client."""
  223. # We don't actually need to keep per-connection state at
  224. # dirauths, even in long-lived connections, so this is
  225. # particularly simple.
  226. return DirAuthConnection(self)
  227. def generate_consensus(self, epoch):
  228. """Generate the consensus (and ENDIVE, if using Walking Onions)
  229. for the given epoch, which should be the one after the one
  230. that's currently about to end."""
  231. threshold = int(self.tot/2)+1
  232. consensusdescs = []
  233. for numseen, desc in DirAuth.uploadeddescs[epoch].values():
  234. if numseen >= threshold:
  235. consensusdescs.append(desc)
  236. DirAuth.consensus = Consensus(epoch, consensusdescs)
  237. def epoch_ending(self, epoch):
  238. # Only dirauth 0 actually needs to generate the consensus
  239. # because of the shared class-static state, but everyone has to
  240. # sign it. Note that this code relies on dirauth 0's
  241. # epoch_ending callback being called before any of the other
  242. # dirauths'.
  243. if (epoch+1) not in DirAuth.uploadeddescs:
  244. DirAuth.uploadeddescs[epoch+1] = dict()
  245. if self.me == 0:
  246. self.generate_consensus(epoch+1)
  247. del DirAuth.uploadeddescs[epoch+1]
  248. DirAuth.consensus.sign(self.sigkey, self.me, self.perfstats)
  249. def received(self, client, msg):
  250. if isinstance(msg, DirAuthUploadDescMsg):
  251. # Check the uploaded descriptor for sanity
  252. epoch = msg.desc.descdict['epoch']
  253. if epoch != network.thenetwork.getepoch() + 1:
  254. return
  255. # Store it in the class-static dict
  256. if epoch not in DirAuth.uploadeddescs:
  257. DirAuth.uploadeddescs[epoch] = dict()
  258. descstr = str(msg.desc)
  259. if descstr not in DirAuth.uploadeddescs[epoch]:
  260. DirAuth.uploadeddescs[epoch][descstr] = (1, msg.desc)
  261. else:
  262. DirAuth.uploadeddescs[epoch][descstr] = \
  263. (DirAuth.uploadeddescs[epoch][descstr][0]+1,
  264. DirAuth.uploadeddescs[epoch][descstr][1])
  265. elif isinstance(msg, DirAuthDelDescMsg):
  266. # Check the uploaded descriptor for sanity
  267. epoch = msg.desc.descdict['epoch']
  268. if epoch != network.thenetwork.getepoch() + 1:
  269. return
  270. # Remove it from the class-static dict
  271. if epoch not in DirAuth.uploadeddescs:
  272. return
  273. descstr = str(msg.desc)
  274. if descstr not in DirAuth.uploadeddescs[epoch]:
  275. return
  276. elif DirAuth.uploadeddescs[epoch][descstr][0] == 1:
  277. del DirAuth.uploadeddescs[epoch][descstr]
  278. else:
  279. DirAuth.uploadeddescs[epoch][descstr] = \
  280. (DirAuth.uploadeddescs[epoch][descstr][0]-1,
  281. DirAuth.uploadeddescs[epoch][descstr][1])
  282. elif isinstance(msg, DirAuthGetConsensusMsg):
  283. client.reply(DirAuthConsensusMsg(DirAuth.consensus))
  284. elif isinstance(msg, DirAuthGetENDIVEMsg):
  285. client.reply(DirAuthENDIVEMsg(DirAuth.endive))
  286. else:
  287. raise TypeError('Not a client-originating DirAuthNetMsg', msg)
  288. def closed(self):
  289. pass
  290. if __name__ == '__main__':
  291. # Start some dirauths
  292. numdirauths = 9
  293. dirauthaddrs = []
  294. for i in range(numdirauths):
  295. dirauth = DirAuth(i, numdirauths)
  296. dirauthaddrs.append(dirauth.netaddr)
  297. for a in dirauthaddrs:
  298. print(a,end=' ')
  299. print()
  300. network.thenetwork.nextepoch()