dirauth.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. #!/usr/bin/env python3
  2. import random # For simulation, not cryptography!
  3. import bisect
  4. import nacl.encoding
  5. import nacl.signing
  6. import network
  7. # A relay descriptor is a dict containing:
  8. # epoch: epoch id
  9. # idkey: a public identity key
  10. # onionkey: a public onion key
  11. # addr: a network address
  12. # bw: bandwidth
  13. # flags: relay flags
  14. # vrfkey: a VRF public key (Single-Pass Walking Onions only)
  15. # sig: a signature over the above by the idkey
  16. class RelayDescriptor:
  17. def __init__(self, descdict):
  18. self.descdict = descdict
  19. def __str__(self, withsig = True):
  20. res = "RelayDesc [\n"
  21. for k in ["epoch", "idkey", "onionkey", "addr", "bw", "flags",
  22. "vrfkey", "sig"]:
  23. if k in self.descdict:
  24. if k == "idkey" or k == "onionkey":
  25. res += " " + k + ": " + self.descdict[k].encode(encoder=nacl.encoding.HexEncoder).decode("ascii") + "\n"
  26. elif k == "sig":
  27. if withsig:
  28. res += " " + k + ": " + nacl.encoding.HexEncoder.encode(self.descdict[k]).decode("ascii") + "\n"
  29. else:
  30. res += " " + k + ": " + str(self.descdict[k]) + "\n"
  31. res += "]\n"
  32. return res
  33. def sign(self, signingkey):
  34. serialized = self.__str__(False)
  35. signed = signingkey.sign(serialized.encode("ascii"))
  36. self.descdict["sig"] = signed.signature
  37. @staticmethod
  38. def verify(desc):
  39. assert(type(desc) is RelayDescriptor)
  40. serialized = desc.__str__(False)
  41. desc.descdict["idkey"].verify(serialized.encode("ascii"), desc.descdict["sig"])
  42. # A consensus is a dict containing:
  43. # epoch: epoch id
  44. # numrelays: total number of relays
  45. # totbw: total bandwidth of relays
  46. # relays: list of relay descriptors (Vanilla Onion Routing only)
  47. # sigs: list of signatures from the dirauths
  48. class Consensus:
  49. def __init__(self, epoch, relays):
  50. relays = [ d for d in relays if d.descdict['epoch'] == epoch ]
  51. self.consdict = dict()
  52. self.consdict['epoch'] = epoch
  53. self.consdict['numrelays'] = len(relays)
  54. self.consdict['totbw'] = sum([ d.descdict['bw'] for d in relays ])
  55. self.consdict['relays'] = relays
  56. def __str__(self, withsigs = True):
  57. res = "Consensus [\n"
  58. for k in ["epoch", "numrelays", "totbw"]:
  59. if k in self.consdict:
  60. res += " " + k + ": " + str(self.consdict[k]) + "\n"
  61. for r in self.consdict['relays']:
  62. res += str(r)
  63. if withsigs and ('sigs' in self.consdict):
  64. for s in self.consdict['sigs']:
  65. res += " sig: " + nacl.encoding.HexEncoder.encode(s).decode("ascii") + "\n"
  66. res += "]\n"
  67. return res
  68. def sign(self, signingkey, index):
  69. """Use the given signing key to sign the consensus, storing the
  70. result in the sigs list at the given index."""
  71. serialized = self.__str__(False)
  72. signed = signingkey.sign(serialized.encode("ascii"))
  73. if 'sigs' not in self.consdict:
  74. self.consdict['sigs'] = []
  75. if index >= len(self.consdict['sigs']):
  76. self.consdict['sigs'].extend([None] * (index+1-len(self.consdict['sigs'])))
  77. self.consdict['sigs'][index] = signed.signature
  78. def bw_cdf(self):
  79. """Create the array of cumulative bandwidth values from a consensus.
  80. The array (cdf) will have the same length as the number of relays
  81. in the consensus. cdf[0] = 0, and cdf[i] = cdf[i-1] + relay[i-1].bw."""
  82. cdf = [0]
  83. for r in self.consdict['relays']:
  84. cdf.append(cdf[-1]+r.descdict['bw'])
  85. # Remove the last item, which should be the sum of all the bws
  86. cdf.pop()
  87. print('cdf=', cdf)
  88. return cdf
  89. def select_weighted_relay(self, cdf):
  90. """Use the cdf generated by bw_cdf to select a relay with
  91. probability proportional to its bw weight."""
  92. totbw = self.consdict['totbw']
  93. if totbw < 1:
  94. raise ValueError("No relays to choose from")
  95. val = random.randint(0, totbw-1)
  96. # Find the rightmost entry less than or equal to val
  97. idx = bisect.bisect_right(cdf, val)
  98. return self.consdict['relays'][idx-1]
  99. @staticmethod
  100. def verify(consensus, verifkeylist):
  101. """Use the given list of verification keys to check the
  102. signatures on the consensus."""
  103. assert(type(consensus) is Consensus)
  104. serialized = consensus.__str__(False)
  105. for i, vk in enumerate(verifkeylist):
  106. vk.verify(serialized.encode("ascii"), consensus.consdict['sigs'][i])
  107. class DirAuthNetMsg(network.NetMsg):
  108. """The subclass of NetMsg for messages to and from directory
  109. authorities."""
  110. class DirAuthUploadDescMsg(DirAuthNetMsg):
  111. """The subclass of DirAuthNetMsg for uploading a relay
  112. descriptor."""
  113. def __init__(self, desc):
  114. self.desc = desc
  115. class DirAuthDelDescMsg(DirAuthNetMsg):
  116. """The subclass of DirAuthNetMsg for deleting a relay
  117. descriptor."""
  118. def __init__(self, desc):
  119. self.desc = desc
  120. class DirAuthGetConsensusMsg(DirAuthNetMsg):
  121. """The subclass of DirAuthNetMsg for fetching the consensus."""
  122. class DirAuthConsensusMsg(DirAuthNetMsg):
  123. """The subclass of DirAuthNetMsg for returning the consensus."""
  124. def __init__(self, consensus):
  125. self.consensus = consensus
  126. class DirAuthGetENDIVEMsg(DirAuthNetMsg):
  127. """The subclass of DirAuthNetMsg for fetching the ENDIVE."""
  128. class DirAuthENDIVEMsg(DirAuthNetMsg):
  129. """The subclass of DirAuthNetMsg for returning the ENDIVE."""
  130. def __init__(self, endive):
  131. self.endive = endive
  132. class DirAuthConnection(network.ClientConnection):
  133. """The subclass of Connection for connections to directory
  134. authorities."""
  135. def __init__(self, peer = None):
  136. super().__init__(peer)
  137. def uploaddesc(self, desc):
  138. """Upload our RelayDescriptor to the DirAuth."""
  139. self.sendmsg(DirAuthUploadDescMeg(desc))
  140. def getconsensus(self):
  141. self.consensus = None
  142. self.sendmsg(DirAuthGetConsensusMsg())
  143. return self.consensus
  144. def getENDIVE(self):
  145. self.endive = None
  146. self.sendmsg(DirAuthGetENDIVEMsg())
  147. return self.endive
  148. def receivedfromserver(self, msg):
  149. if isinstance(msg, DirAuthConsensusMsg):
  150. self.consensus = msg.consensus
  151. elif isinstance(msg, DirAuthENDIVEMsg):
  152. self.endive = msg.endive
  153. else:
  154. raise TypeError('Not a server-originating DirAuthNetMsg', msg)
  155. class DirAuth(network.Server):
  156. """The class representing directory authorities."""
  157. # We simulate the act of computing the consensus by keeping a
  158. # class-static dict that's accessible to all of the dirauths
  159. # This dict is indexed by epoch, and the value is itself a dict
  160. # indexed by the stringified descriptor, with value a pair of (the
  161. # number of dirauths that saw that descriptor, the descriptor
  162. # itself).
  163. uploadeddescs = dict()
  164. consensus = None
  165. endive = None
  166. def __init__(self, me, tot):
  167. """Create a new directory authority. me is the index of which
  168. dirauth this one is (starting from 0), and tot is the total
  169. number of dirauths."""
  170. self.me = me
  171. self.tot = tot
  172. self.name = "Dirauth %d of %d" % (me+1, tot)
  173. # Create the dirauth signature keypair
  174. self.sigkey = nacl.signing.SigningKey.generate()
  175. network.thenetwork.setdirauthkey(me, self.sigkey.verify_key)
  176. network.thenetwork.wantepochticks(self, True, True)
  177. def connected(self, client):
  178. """Callback invoked when a client connects to us. This callback
  179. creates the DirAuthConnection that will be passed to the
  180. client."""
  181. # We don't actually need to keep per-connection state at
  182. # dirauths, even in long-lived connections, so this is
  183. # particularly simple.
  184. return DirAuthConnection(self)
  185. def generate_consensus(self, epoch):
  186. """Generate the consensus (and ENDIVE, if using Walking Onions)
  187. for the given epoch, which should be the one after the one
  188. that's currently about to end."""
  189. threshold = int(self.tot/2)+1
  190. consensusdescs = []
  191. for numseen, desc in DirAuth.uploadeddescs[epoch].values():
  192. if numseen >= threshold:
  193. consensusdescs.append(desc)
  194. DirAuth.consensus = Consensus(epoch, consensusdescs)
  195. def epoch_ending(self, epoch):
  196. # Only dirauth 0 actually needs to generate the consensus
  197. # because of the shared class-static state, but everyone has to
  198. # sign it. Note that this code relies on dirauth 0's
  199. # epoch_ending callback being called before any of the other
  200. # dirauths'.
  201. if self.me == 0:
  202. self.generate_consensus(epoch+1)
  203. del DirAuth.uploadeddescs[epoch+1]
  204. DirAuth.consensus.sign(self.sigkey, self.me)
  205. def received(self, client, msg):
  206. if isinstance(msg, DirAuthUploadDescMsg):
  207. # Check the uploaded descriptor for sanity
  208. epoch = msg.desc.descdict['epoch']
  209. if epoch != network.thenetwork.getepoch() + 1:
  210. return
  211. # Store it in the class-static dict
  212. if epoch not in DirAuth.uploadeddescs:
  213. DirAuth.uploadeddescs[epoch] = dict()
  214. descstr = str(msg.desc)
  215. if descstr not in DirAuth.uploadeddescs[epoch]:
  216. DirAuth.uploadeddescs[epoch][descstr] = (1, msg.desc)
  217. else:
  218. DirAuth.uploadeddescs[epoch][descstr] = \
  219. (DirAuth.uploadeddescs[epoch][descstr][0]+1,
  220. DirAuth.uploadeddescs[epoch][descstr][1])
  221. elif isinstance(msg, DirAuthDelDescMsg):
  222. # Check the uploaded descriptor for sanity
  223. epoch = msg.desc.descdict['epoch']
  224. if epoch != network.thenetwork.getepoch() + 1:
  225. return
  226. # Remove it from the class-static dict
  227. if epoch not in DirAuth.uploadeddescs:
  228. return
  229. descstr = str(msg.desc)
  230. if descstr not in DirAuth.uploadeddescs[epoch]:
  231. return
  232. elif DirAuth.uploadeddescs[epoch][descstr][0] == 1:
  233. del DirAuth.uploadeddescs[epoch][descstr]
  234. else:
  235. DirAuth.uploadeddescs[epoch][descstr] = \
  236. (DirAuth.uploadeddescs[epoch][descstr][0]-1,
  237. DirAuth.uploadeddescs[epoch][descstr][1])
  238. elif isinstance(msg, DirAuthGetConsensusMsg):
  239. client.reply(DirAuthConsensusMsg(DirAuth.consensus))
  240. elif isinstance(msg, DirAuthGetENDIVEMsg):
  241. client.reply(DirAuthENDIVEMsg(DirAuth.endive))
  242. else:
  243. raise TypeError('Not a client-originating DirAuthNetMsg', msg)
  244. def closed(self):
  245. pass
  246. if __name__ == '__main__':
  247. # Start some dirauths
  248. numdirauths = 9
  249. dirauthaddrs = []
  250. for i in range(numdirauths):
  251. dirauth = DirAuth(i, numdirauths)
  252. dirauthaddrs.append(network.thenetwork.bind(dirauth))
  253. for a in dirauthaddrs:
  254. print(a,end=' ')
  255. print()
  256. network.thenetwork.nextepoch()