dirauth.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  1. #!/usr/bin/env python3
  2. import random # For simulation, not cryptography!
  3. import bisect
  4. import math
  5. import nacl.encoding
  6. import nacl.signing
  7. import network
  8. # A relay descriptor is a dict containing:
  9. # epoch: epoch id
  10. # idkey: a public identity key
  11. # onionkey: a public onion key
  12. # addr: a network address
  13. # bw: bandwidth
  14. # flags: relay flags
  15. # vrfkey: a VRF public key (Single-Pass Walking Onions only)
  16. # sig: a signature over the above by the idkey
  17. class RelayDescriptor:
  18. def __init__(self, descdict):
  19. self.descdict = descdict
  20. def __str__(self, withsig = True):
  21. res = "RelayDesc [\n"
  22. for k in ["epoch", "idkey", "onionkey", "addr", "bw", "flags",
  23. "vrfkey", "sig"]:
  24. if k in self.descdict:
  25. if k == "idkey" or k == "onionkey":
  26. res += " " + k + ": " + self.descdict[k].encode(encoder=nacl.encoding.HexEncoder).decode("ascii") + "\n"
  27. elif k == "sig":
  28. if withsig:
  29. res += " " + k + ": " + nacl.encoding.HexEncoder.encode(self.descdict[k]).decode("ascii") + "\n"
  30. else:
  31. res += " " + k + ": " + str(self.descdict[k]) + "\n"
  32. res += "]\n"
  33. return res
  34. def sign(self, signingkey, perfstats):
  35. serialized = self.__str__(False)
  36. signed = signingkey.sign(serialized.encode("ascii"))
  37. perfstats.sigs += 1
  38. self.descdict["sig"] = signed.signature
  39. @staticmethod
  40. def verify(desc, perfstats):
  41. assert(type(desc) is RelayDescriptor)
  42. serialized = desc.__str__(False)
  43. perfstats.verifs += 1
  44. desc.descdict["idkey"].verify(serialized.encode("ascii"), desc.descdict["sig"])
  45. # A SNIP is a dict containing:
  46. # epoch: epoch id
  47. # idkey: a public identity key
  48. # onionkey: a public onion key
  49. # addr: a network address
  50. # flags: relay flags
  51. # vrfkey: a VRF public key (Single-Pass Walking Onions only)
  52. # range: the (lo,hi) values for the index range (lo is inclusive, hi is
  53. # exclusive; that is, x is in the range if lo <= x < hi).
  54. # lo=hi denotes an empty range.
  55. # auth: either a signature from the authorities over the above
  56. # (Threshold signature case) or a Merkle path to the root
  57. # contained in the consensus (Merkle tree case)
  58. #
  59. # Note that the fields of the SNIP are the same as those of the
  60. # RelayDescriptor, except bw and sig are removed, and range and auth are
  61. # added.
  62. class SNIP:
  63. def __init__(self, snipdict):
  64. self.snipdict = snipdict
  65. def __str__(self, withauth = True):
  66. res = "SNIP [\n"
  67. for k in ["epoch", "idkey", "onionkey", "addr", "flags",
  68. "vrfkey", "range", "auth"]:
  69. if k in self.snipdict:
  70. if k == "idkey" or k == "onionkey":
  71. res += " " + k + ": " + self.snipdict[k].encode(encoder=nacl.encoding.HexEncoder).decode("ascii") + "\n"
  72. elif k == "auth":
  73. if withauth:
  74. if network.thenetwork.snipauthmode == \
  75. network.SNIPAuthMode.THRESHSIG:
  76. res += " " + k + ": " + nacl.encoding.HexEncoder.encode(self.snipdict[k]).decode("ascii") + "\n"
  77. else:
  78. raise NotImplementedError("Merkle auth not yet implemented")
  79. else:
  80. res += " " + k + ": " + str(self.snipdict[k]) + "\n"
  81. res += "]\n"
  82. return res
  83. def auth(self, signingkey, perfstats):
  84. if network.thenetwork.snipauthmode == network.SNIPAuthMode.THRESHSIG:
  85. serialized = self.__str__(False)
  86. signed = signingkey.sign(serialized.encode("ascii"))
  87. perfstats.sigs += 1
  88. self.snipdict["auth"] = signed.signature
  89. else:
  90. raise NotImplementedError("Merkle auth not yet implemented")
  91. @staticmethod
  92. def verify(snip, consensus, verifykey, perfstats):
  93. if network.thenetwork.snipauthmode == network.SNIPAuthMode.THRESHSIG:
  94. assert(type(snip) is SNIP and type(consensus) is Consensus)
  95. serialized = snip.__str__(False)
  96. perfstats.verifs += 1
  97. verifykey.verify(serialized.encode("ascii"),
  98. snip.snipdict["auth"])
  99. else:
  100. raise NotImplementedError("Merkle auth not yet implemented")
  101. # A consensus is a dict containing:
  102. # epoch: epoch id
  103. # numrelays: total number of relays
  104. # totbw: total bandwidth of relays
  105. # merkleroot: the root of the SNIP Merkle tree (Merkle tree auth only)
  106. # relays: list of relay descriptors (Vanilla Onion Routing only)
  107. # sigs: list of signatures from the dirauths
  108. class Consensus:
  109. def __init__(self, epoch, relays):
  110. relays = [ d for d in relays if d.descdict['epoch'] == epoch ]
  111. self.consdict = dict()
  112. self.consdict['epoch'] = epoch
  113. self.consdict['numrelays'] = len(relays)
  114. if network.thenetwork.womode == network.WOMode.VANILLA:
  115. self.consdict['totbw'] = sum([ d.descdict['bw'] for d in relays ])
  116. self.consdict['relays'] = relays
  117. else:
  118. self.consdict['totbw'] = 1<<32
  119. def __str__(self, withsigs = True):
  120. res = "Consensus [\n"
  121. for k in ["epoch", "numrelays", "totbw"]:
  122. if k in self.consdict:
  123. res += " " + k + ": " + str(self.consdict[k]) + "\n"
  124. if network.thenetwork.womode == network.WOMode.VANILLA:
  125. for r in self.consdict['relays']:
  126. res += str(r)
  127. if withsigs and ('sigs' in self.consdict):
  128. for s in self.consdict['sigs']:
  129. res += " sig: " + nacl.encoding.HexEncoder.encode(s).decode("ascii") + "\n"
  130. res += "]\n"
  131. return res
  132. def sign(self, signingkey, index, perfstats):
  133. """Use the given signing key to sign the consensus, storing the
  134. result in the sigs list at the given index."""
  135. serialized = self.__str__(False)
  136. signed = signingkey.sign(serialized.encode("ascii"))
  137. perfstats.sigs += 1
  138. if 'sigs' not in self.consdict:
  139. self.consdict['sigs'] = []
  140. if index >= len(self.consdict['sigs']):
  141. self.consdict['sigs'].extend([None] * (index+1-len(self.consdict['sigs'])))
  142. self.consdict['sigs'][index] = signed.signature
  143. def bw_cdf(self):
  144. """Create the array of cumulative bandwidth values from a consensus.
  145. The array (cdf) will have the same length as the number of relays
  146. in the consensus. cdf[0] = 0, and cdf[i] = cdf[i-1] + relay[i-1].bw."""
  147. cdf = [0]
  148. for r in self.consdict['relays']:
  149. cdf.append(cdf[-1]+r.descdict['bw'])
  150. # Remove the last item, which should be the sum of all the bws
  151. cdf.pop()
  152. print('cdf=', cdf)
  153. return cdf
  154. def select_weighted_relay(self, cdf):
  155. """Use the cdf generated by bw_cdf to select a relay with
  156. probability proportional to its bw weight."""
  157. totbw = self.consdict['totbw']
  158. if totbw < 1:
  159. raise ValueError("No relays to choose from")
  160. val = random.randint(0, totbw-1)
  161. # Find the rightmost entry less than or equal to val
  162. idx = bisect.bisect_right(cdf, val)
  163. return self.consdict['relays'][idx-1]
  164. @staticmethod
  165. def verify(consensus, verifkeylist, perfstats):
  166. """Use the given list of verification keys to check the
  167. signatures on the consensus."""
  168. assert(type(consensus) is Consensus)
  169. serialized = consensus.__str__(False)
  170. for i, vk in enumerate(verifkeylist):
  171. perfstats.verifs += 1
  172. vk.verify(serialized.encode("ascii"), consensus.consdict['sigs'][i])
  173. # An ENDIVE is a dict containing:
  174. # epoch: epoch id
  175. # snips: list of SNIPS (in THRESHSIG mode, these include the auth
  176. # signatures; in MERKLE mode, these do _not_ include auth)
  177. # sigs: list of signatures from the dirauths
  178. class ENDIVE:
  179. def __init__(self, epoch, snips):
  180. snips = [ s for s in snips if s.snipdict['epoch'] == epoch ]
  181. self.enddict = dict()
  182. self.enddict['epoch'] = epoch
  183. self.enddict['snips'] = snips
  184. def __str__(self, withsigs = True):
  185. res = "ENDIVE [\n"
  186. for k in ["epoch"]:
  187. if k in self.enddict:
  188. res += " " + k + ": " + str(self.enddict[k]) + "\n"
  189. for s in self.enddict['snips']:
  190. res += str(s)
  191. if withsigs and ('sigs' in self.enddict):
  192. for s in self.enddict['sigs']:
  193. res += " sig: " + nacl.encoding.HexEncoder.encode(s).decode("ascii") + "\n"
  194. res += "]\n"
  195. return res
  196. def sign(self, signingkey, index, perfstats):
  197. """Use the given signing key to sign the ENDIVE, storing the
  198. result in the sigs list at the given index."""
  199. serialized = self.__str__(False)
  200. signed = signingkey.sign(serialized.encode("ascii"))
  201. perfstats.sigs += 1
  202. if 'sigs' not in self.enddict:
  203. self.enddict['sigs'] = []
  204. if index >= len(self.enddict['sigs']):
  205. self.enddict['sigs'].extend([None] * (index+1-len(self.enddict['sigs'])))
  206. self.enddict['sigs'][index] = signed.signature
  207. def bw_cdf(self):
  208. """Create the array of cumulative bandwidth values from an ENDIVE.
  209. The array (cdf) will have the same length as the number of relays
  210. in the ENDIVE. cdf[0] = 0, and cdf[i] = cdf[i-1] + relay[i-1].bw."""
  211. cdf = [ s.snipdict['range'][0] for s in self.enddict['snips'] ]
  212. print('ENDIVE cdf=', cdf)
  213. return cdf
  214. def select_snip_by_index(self, i, cdf):
  215. """Use the cdf generated by bw_cdf to select the SNIP for which
  216. i is in the index range. Choose i with
  217. random.randint(0, (1<<32)-1)."""
  218. # Find the rightmost entry less than or equal to i
  219. idx = bisect.bisect_right(cdf, i)
  220. return self.enddict['snips'][idx-1]
  221. @staticmethod
  222. def verify(endive, verifkeylist, perfstats):
  223. """Use the given list of verification keys to check the
  224. signatures on the ENDIVE."""
  225. assert(type(endive) is ENDIVE)
  226. serialized = endive.__str__(False)
  227. for i, vk in enumerate(verifkeylist):
  228. perfstats.verifs += 1
  229. vk.verify(serialized.encode("ascii"), endive.enddict['sigs'][i])
  230. class DirAuthNetMsg(network.NetMsg):
  231. """The subclass of NetMsg for messages to and from directory
  232. authorities."""
  233. class DirAuthUploadDescMsg(DirAuthNetMsg):
  234. """The subclass of DirAuthNetMsg for uploading a relay
  235. descriptor."""
  236. def __init__(self, desc):
  237. self.desc = desc
  238. class DirAuthDelDescMsg(DirAuthNetMsg):
  239. """The subclass of DirAuthNetMsg for deleting a relay
  240. descriptor."""
  241. def __init__(self, desc):
  242. self.desc = desc
  243. class DirAuthGetConsensusMsg(DirAuthNetMsg):
  244. """The subclass of DirAuthNetMsg for fetching the consensus."""
  245. class DirAuthConsensusMsg(DirAuthNetMsg):
  246. """The subclass of DirAuthNetMsg for returning the consensus."""
  247. def __init__(self, consensus):
  248. self.consensus = consensus
  249. class DirAuthGetConsensusDiffMsg(DirAuthNetMsg):
  250. """The subclass of DirAuthNetMsg for fetching the consensus, if the
  251. requestor already has the previous consensus."""
  252. class DirAuthConsensusDiffMsg(DirAuthNetMsg):
  253. """The subclass of DirAuthNetMsg for returning the consensus, if the
  254. requestor already has the previous consensus. We don't _actually_
  255. produce the diff at this time; we just charge fewer bytes for this
  256. message."""
  257. def __init__(self, consensus):
  258. self.consensus = consensus
  259. def size(self):
  260. if network.symbolic_byte_counters:
  261. return super().size()
  262. return math.ceil(DirAuthConsensusMsg(self.consensus).size() \
  263. * network.P_Delta)
  264. class DirAuthGetENDIVEMsg(DirAuthNetMsg):
  265. """The subclass of DirAuthNetMsg for fetching the ENDIVE."""
  266. class DirAuthENDIVEMsg(DirAuthNetMsg):
  267. """The subclass of DirAuthNetMsg for returning the ENDIVE."""
  268. def __init__(self, endive):
  269. self.endive = endive
  270. class DirAuthGetENDIVEDiffMsg(DirAuthNetMsg):
  271. """The subclass of DirAuthNetMsg for fetching the ENDIVE, if the
  272. requestor already has the previous ENDIVE."""
  273. class DirAuthENDIVEDiffMsg(DirAuthNetMsg):
  274. """The subclass of DirAuthNetMsg for returning the ENDIVE, if the
  275. requestor already has the previous consensus. We don't _actually_
  276. produce the diff at this time; we just charge fewer bytes for this
  277. message in Merkle mode. In threshold signature mode, we would still
  278. need to download at least the new signatures for every SNIP in the
  279. ENDIVE, so for now, just assume there's no gain from ENDIVE diffs in
  280. threshold signature mode."""
  281. def __init__(self, endive):
  282. self.endive = endive
  283. def size(self):
  284. if network.symbolic_byte_counters:
  285. return super().size()
  286. if network.thenetwork.snipauthmode == \
  287. network.SNIPAuthMode.THRESHSIG:
  288. return DirAuthENDIVEMsg(self.endive).size()
  289. return math.ceil(DirAuthENDIVEMsg(self.endive).size() \
  290. * network.P_Delta)
  291. class DirAuthConnection(network.ClientConnection):
  292. """The subclass of Connection for connections to directory
  293. authorities."""
  294. def __init__(self, peer):
  295. super().__init__(peer)
  296. def uploaddesc(self, desc):
  297. """Upload our RelayDescriptor to the DirAuth."""
  298. self.sendmsg(DirAuthUploadDescMeg(desc))
  299. def getconsensus(self):
  300. self.consensus = None
  301. self.sendmsg(DirAuthGetConsensusMsg())
  302. return self.consensus
  303. def getconsensusdiff(self):
  304. self.consensus = None
  305. self.sendmsg(DirAuthGetConsensusDiffMsg())
  306. return self.consensus
  307. def getENDIVE(self):
  308. self.endive = None
  309. self.sendmsg(DirAuthGetENDIVEMsg())
  310. return self.endive
  311. def receivedfromserver(self, msg):
  312. if isinstance(msg, DirAuthConsensusMsg):
  313. self.consensus = msg.consensus
  314. elif isinstance(msg, DirAuthConsensusDiffMsg):
  315. self.consensus = msg.consensus
  316. elif isinstance(msg, DirAuthENDIVEMsg):
  317. self.endive = msg.endive
  318. else:
  319. raise TypeError('Not a server-originating DirAuthNetMsg', msg)
  320. class DirAuth(network.Server):
  321. """The class representing directory authorities."""
  322. # We simulate the act of computing the consensus by keeping a
  323. # class-static dict that's accessible to all of the dirauths
  324. # This dict is indexed by epoch, and the value is itself a dict
  325. # indexed by the stringified descriptor, with value a pair of (the
  326. # number of dirauths that saw that descriptor, the descriptor
  327. # itself).
  328. uploadeddescs = dict()
  329. consensus = None
  330. endive = None
  331. def __init__(self, me, tot):
  332. """Create a new directory authority. me is the index of which
  333. dirauth this one is (starting from 0), and tot is the total
  334. number of dirauths."""
  335. self.me = me
  336. self.tot = tot
  337. self.name = "Dirauth %d of %d" % (me+1, tot)
  338. self.perfstats = network.PerfStats(network.EntType.DIRAUTH)
  339. self.perfstats.is_bootstrapping = True
  340. # Create the dirauth signature keypair
  341. self.sigkey = nacl.signing.SigningKey.generate()
  342. self.perfstats.keygens += 1
  343. self.netaddr = network.thenetwork.bind(self)
  344. self.perfstats.name = "DirAuth at %s" % self.netaddr
  345. network.thenetwork.setdirauthkey(me, self.sigkey.verify_key)
  346. network.thenetwork.wantepochticks(self, True, True)
  347. def connected(self, client):
  348. """Callback invoked when a client connects to us. This callback
  349. creates the DirAuthConnection that will be passed to the
  350. client."""
  351. # We don't actually need to keep per-connection state at
  352. # dirauths, even in long-lived connections, so this is
  353. # particularly simple.
  354. return DirAuthConnection(self)
  355. def generate_consensus(self, epoch):
  356. """Generate the consensus (and ENDIVE, if using Walking Onions)
  357. for the given epoch, which should be the one after the one
  358. that's currently about to end."""
  359. threshold = int(self.tot/2)+1
  360. consensusdescs = []
  361. for numseen, desc in DirAuth.uploadeddescs[epoch].values():
  362. if numseen >= threshold:
  363. consensusdescs.append(desc)
  364. DirAuth.consensus = Consensus(epoch, consensusdescs)
  365. if network.thenetwork.womode != network.WOMode.VANILLA:
  366. totbw = sum([ d.descdict["bw"] for d in consensusdescs ])
  367. hi = 0
  368. cumbw = 0
  369. snips = []
  370. for d in consensusdescs:
  371. cumbw += d.descdict["bw"]
  372. lo = hi
  373. hi = int((cumbw<<32)/totbw)
  374. snipdict = dict(d.descdict)
  375. del snipdict["bw"]
  376. snipdict["range"] = (lo,hi)
  377. snips.append(SNIP(snipdict))
  378. DirAuth.endive = ENDIVE(epoch, snips)
  379. def epoch_ending(self, epoch):
  380. # Only dirauth 0 actually needs to generate the consensus
  381. # because of the shared class-static state, but everyone has to
  382. # sign it. Note that this code relies on dirauth 0's
  383. # epoch_ending callback being called before any of the other
  384. # dirauths'.
  385. if (epoch+1) not in DirAuth.uploadeddescs:
  386. DirAuth.uploadeddescs[epoch+1] = dict()
  387. if self.me == 0:
  388. self.generate_consensus(epoch+1)
  389. del DirAuth.uploadeddescs[epoch+1]
  390. if network.thenetwork.snipauthmode == \
  391. network.SNIPAuthMode.THRESHSIG:
  392. for s in DirAuth.endive.enddict['snips']:
  393. s.auth(self.sigkey, self.perfstats)
  394. else:
  395. if network.thenetwork.snipauthmode == \
  396. network.SNIPAuthMode.THRESHSIG:
  397. for s in DirAuth.endive.enddict['snips']:
  398. # We're just simulating threshold sigs by having
  399. # only the first dirauth sign, but in reality each
  400. # dirauth would contribute to the signature (at the
  401. # same cost as each one signing), so we'll charge
  402. # their perfstats as well
  403. self.perfstats.sigs += 1
  404. DirAuth.consensus.sign(self.sigkey, self.me, self.perfstats)
  405. if network.thenetwork.womode != network.WOMode.VANILLA:
  406. DirAuth.endive.sign(self.sigkey, self.me, self.perfstats)
  407. def received(self, client, msg):
  408. self.perfstats.bytes_received += msg.size()
  409. if isinstance(msg, DirAuthUploadDescMsg):
  410. # Check the uploaded descriptor for sanity
  411. epoch = msg.desc.descdict['epoch']
  412. if epoch != network.thenetwork.getepoch() + 1:
  413. return
  414. # Store it in the class-static dict
  415. if epoch not in DirAuth.uploadeddescs:
  416. DirAuth.uploadeddescs[epoch] = dict()
  417. descstr = str(msg.desc)
  418. if descstr not in DirAuth.uploadeddescs[epoch]:
  419. DirAuth.uploadeddescs[epoch][descstr] = (1, msg.desc)
  420. else:
  421. DirAuth.uploadeddescs[epoch][descstr] = \
  422. (DirAuth.uploadeddescs[epoch][descstr][0]+1,
  423. DirAuth.uploadeddescs[epoch][descstr][1])
  424. elif isinstance(msg, DirAuthDelDescMsg):
  425. # Check the uploaded descriptor for sanity
  426. epoch = msg.desc.descdict['epoch']
  427. if epoch != network.thenetwork.getepoch() + 1:
  428. return
  429. # Remove it from the class-static dict
  430. if epoch not in DirAuth.uploadeddescs:
  431. return
  432. descstr = str(msg.desc)
  433. if descstr not in DirAuth.uploadeddescs[epoch]:
  434. return
  435. elif DirAuth.uploadeddescs[epoch][descstr][0] == 1:
  436. del DirAuth.uploadeddescs[epoch][descstr]
  437. else:
  438. DirAuth.uploadeddescs[epoch][descstr] = \
  439. (DirAuth.uploadeddescs[epoch][descstr][0]-1,
  440. DirAuth.uploadeddescs[epoch][descstr][1])
  441. elif isinstance(msg, DirAuthGetConsensusMsg):
  442. replymsg = DirAuthConsensusMsg(DirAuth.consensus)
  443. msgsize = replymsg.size()
  444. self.perfstats.bytes_sent += msgsize
  445. client.reply(replymsg)
  446. elif isinstance(msg, DirAuthGetConsensusDiffMsg):
  447. replymsg = DirAuthConsensusDiffMsg(DirAuth.consensus)
  448. msgsize = replymsg.size()
  449. self.perfstats.bytes_sent += msgsize
  450. client.reply(replymsg)
  451. elif isinstance(msg, DirAuthGetENDIVEMsg):
  452. replymsg = DirAuthENDIVEMsg(DirAuth.endive)
  453. msgsize = replymsg.size()
  454. self.perfstats.bytes_sent += msgsize
  455. client.reply(replymsg)
  456. else:
  457. raise TypeError('Not a client-originating DirAuthNetMsg', msg)
  458. def closed(self):
  459. pass
  460. if __name__ == '__main__':
  461. # Start some dirauths
  462. numdirauths = 9
  463. dirauthaddrs = []
  464. for i in range(numdirauths):
  465. dirauth = DirAuth(i, numdirauths)
  466. dirauthaddrs.append(dirauth.netaddr)
  467. for a in dirauthaddrs:
  468. print(a,end=' ')
  469. print()
  470. network.thenetwork.nextepoch()