network.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504
  1. #!/usr/bin/env python3
  2. import random
  3. import pickle
  4. import logging
  5. import math
  6. import bisect
  7. from enum import Enum
  8. # Set this to True if you want the bytes sent and received to be added
  9. # symbolically, in terms of the numbers of each type of network message.
  10. # You will need sympy installed for this to work.
  11. symbolic_byte_counters = False
  12. try:
  13. import sympy
  14. except:
  15. pass
  16. # Network parameters
  17. # On average, how large is a consensus diff as compared to a full
  18. # consensus?
  19. P_Delta = 0.019
  20. class WOMode(Enum):
  21. """The different Walking Onion modes"""
  22. VANILLA = 0 # No Walking Onions
  23. TELESCOPING = 1 # Telescoping Walking Onions
  24. SINGLEPASS = 2 # Single-Pass Walking Onions
  25. def string_to_type(type_input):
  26. reprs = {'vanilla': WOMode.VANILLA, 'telescoping': WOMode.TELESCOPING,
  27. 'single-pass': WOMode.SINGLEPASS }
  28. if type_input in reprs.keys():
  29. return reprs[type_input]
  30. return -1
  31. class SNIPAuthMode(Enum):
  32. """The different styles of SNIP authentication"""
  33. NONE = 0 # No SNIPs; only used for WOMode = VANILLA
  34. MERKLE = 1 # Merkle trees
  35. THRESHSIG = 2 # Threshold signatures
  36. # We only need to differentiate between merkle and telescoping on the
  37. # command line input, Vanilla always takes a NONE type but nothing else
  38. # does.
  39. def string_to_type(type_input):
  40. reprs = {'merkle': SNIPAuthMode.MERKLE,
  41. 'telesocping': SNIPAuthMode.THRESHSIG }
  42. if type_input in reprs.keys():
  43. return reprs[type_input]
  44. return -1
  45. class EntType(Enum):
  46. """The different types of entities in the system."""
  47. NONE = 0
  48. DIRAUTH = 1
  49. RELAY = 2
  50. CLIENT = 3
  51. class PerfStats:
  52. """A class to store performance statistics for a relay or client.
  53. We keep track of bytes sent, bytes received, and counts of
  54. public-key operations of various types. We will reset these every
  55. epoch."""
  56. def __init__(self, ent_type, bw=None):
  57. # Which type of entity is this for (DIRAUTH, RELAY, CLIENT)
  58. self.ent_type = ent_type
  59. # A printable name for the entity
  60. self.name = None
  61. # The relay bandwidth, if appropriate
  62. self.bw = bw
  63. # True if bootstrapping this epoch
  64. self.is_bootstrapping = False
  65. self.reset()
  66. def __str__(self):
  67. return "%s: type=%s boot=%s sent=%s recv=%s keygen=%d sig=%d verif=%d dh=%d" % \
  68. (self.name, self.ent_type.name, self.is_bootstrapping, \
  69. self.bytes_sent, self.bytes_received, self.keygens, \
  70. self.sigs, self.verifs, self.dhs)
  71. def reset(self):
  72. """Reset the counters, typically at the beginning of each
  73. epoch."""
  74. # Bytes sent and received
  75. self.bytes_sent = 0
  76. self.bytes_received = 0
  77. # Public-key operations: key generation, signing, verification,
  78. # Diffie-Hellman
  79. self.keygens = 0
  80. self.sigs = 0
  81. self.verifs = 0
  82. self.dhs = 0
  83. class PerfStatsStats:
  84. """Accumulate a number of PerfStats objects to compute the means and
  85. stddevs of their fields."""
  86. class SingleStat:
  87. """Accumulate single numbers to compute their mean and
  88. stddev."""
  89. def __init__(self):
  90. self.tot = 0
  91. self.totsq = 0
  92. self.N = 0
  93. def accum(self, x):
  94. self.tot += x
  95. self.totsq += x*x
  96. self.N += 1
  97. def __str__(self):
  98. mean = self.tot/self.N
  99. if self.N > 1:
  100. stddev = math.sqrt((self.totsq - self.tot*self.tot/self.N) \
  101. / (self.N - 1))
  102. return "%f \pm %f" % (mean, stddev)
  103. else:
  104. return "%f" % mean
  105. def __init__(self, usebw=False):
  106. self.usebw = usebw
  107. self.bytes_sent = PerfStatsStats.SingleStat()
  108. self.bytes_received = PerfStatsStats.SingleStat()
  109. self.bytes_tot = PerfStatsStats.SingleStat()
  110. if self.usebw:
  111. self.bytesperbw_sent = PerfStatsStats.SingleStat()
  112. self.bytesperbw_received = PerfStatsStats.SingleStat()
  113. self.bytesperbw_tot = PerfStatsStats.SingleStat()
  114. self.keygens = PerfStatsStats.SingleStat()
  115. self.sigs = PerfStatsStats.SingleStat()
  116. self.verifs = PerfStatsStats.SingleStat()
  117. self.dhs = PerfStatsStats.SingleStat()
  118. self.N = 0
  119. def accum(self, stat):
  120. self.bytes_sent.accum(stat.bytes_sent)
  121. self.bytes_received.accum(stat.bytes_received)
  122. self.bytes_tot.accum(stat.bytes_sent + stat.bytes_received)
  123. if self.usebw:
  124. self.bytesperbw_sent.accum(stat.bytes_sent/stat.bw)
  125. self.bytesperbw_received.accum(stat.bytes_received/stat.bw)
  126. self.bytesperbw_tot.accum((stat.bytes_sent + stat.bytes_received)/stat.bw)
  127. self.keygens.accum(stat.keygens)
  128. self.sigs.accum(stat.sigs)
  129. self.verifs.accum(stat.verifs)
  130. self.dhs.accum(stat.dhs)
  131. self.N += 1
  132. def __str__(self):
  133. if self.N > 0:
  134. if self.usebw:
  135. return "sent=%s recv=%s bytes=%s sentperbw=%s recvperbw=%s bytesperbw=%s keygen=%s sig=%s verif=%s dh=%s N=%s" % \
  136. (self.bytes_sent, self.bytes_received, self.bytes_tot,
  137. self.bytesperbw_sent, self.bytesperbw_received,
  138. self.bytesperbw_tot,
  139. self.keygens, self.sigs, self.verifs, self.dhs, self.N)
  140. else:
  141. return "sent=%s recv=%s bytes=%s keygen=%s sig=%s verif=%s dh=%s N=%s" % \
  142. (self.bytes_sent, self.bytes_received, self.bytes_tot,
  143. self.keygens, self.sigs, self.verifs, self.dhs, self.N)
  144. else:
  145. return "N=0"
  146. class NetAddr:
  147. """A class representing a network address"""
  148. nextaddr = 1
  149. def __init__(self):
  150. """Generate a fresh network address"""
  151. self.addr = NetAddr.nextaddr
  152. NetAddr.nextaddr += 1
  153. def __eq__(self, other):
  154. return (isinstance(other, self.__class__)
  155. and self.__dict__ == other.__dict__)
  156. def __hash__(self):
  157. return hash(self.addr)
  158. def __str__(self):
  159. return self.addr.__str__()
  160. class NetNoServer(Exception):
  161. """No server is listening on the address someone tried to connect
  162. to."""
  163. class Network:
  164. """A class representing a simulated network. Servers can bind()
  165. to the network, yielding a NetAddr (network address), and clients
  166. can connect() to a NetAddr yielding a Connection."""
  167. def __init__(self):
  168. self.servers = dict()
  169. self.epoch = 1
  170. self.epochprioritycallbacks = []
  171. self.epochpriorityendingcallbacks = []
  172. self.epochcallbacks = []
  173. self.epochendingcallbacks = []
  174. self.dirauthkeylist = []
  175. self.fallbackrelays = []
  176. self.womode = WOMode.VANILLA
  177. self.snipauthmode = SNIPAuthMode.NONE
  178. def printservers(self):
  179. """Print the list of NetAddrs bound to something."""
  180. print("Servers:")
  181. for a in self.servers.keys():
  182. print(a)
  183. def setdirauthkey(self, index, vk):
  184. """Set the public verification key for dirauth number index to
  185. vk."""
  186. if index >= len(self.dirauthkeylist):
  187. self.dirauthkeylist.extend([None] * (index+1-len(self.dirauthkeylist)))
  188. self.dirauthkeylist[index] = vk
  189. def dirauthkeys(self):
  190. """Return the list of dirauth public verification keys."""
  191. return self.dirauthkeylist
  192. def getepoch(self):
  193. """Return the current epoch."""
  194. return self.epoch
  195. def nextepoch(self):
  196. """Increment the current epoch, and return it."""
  197. logging.info("Ending epoch %s", self.epoch)
  198. totendingcallbacks = len(self.epochpriorityendingcallbacks) + \
  199. len(self.epochendingcallbacks)
  200. numendingcalled = 0
  201. lastroundpercent = -1
  202. for l in [ self.epochpriorityendingcallbacks,
  203. self.epochendingcallbacks ]:
  204. for c in l:
  205. c.epoch_ending(self.epoch)
  206. numendingcalled += 1
  207. roundpercent = int(100*numendingcalled/totendingcallbacks)
  208. if roundpercent != lastroundpercent:
  209. logging.info("Ending epoch %s %d%% complete",
  210. self.epoch, roundpercent)
  211. lastroundpercent = roundpercent
  212. self.epoch += 1
  213. logging.info("Starting epoch %s", self.epoch)
  214. totcallbacks = len(self.epochprioritycallbacks) + \
  215. len(self.epochcallbacks)
  216. numcalled = 0
  217. lastroundpercent = -1
  218. for l in [ self.epochprioritycallbacks, self.epochcallbacks ]:
  219. for c in l:
  220. c.newepoch(self.epoch)
  221. numcalled += 1
  222. roundpercent = int(100*numcalled/totcallbacks)
  223. if roundpercent != lastroundpercent:
  224. logging.info("Starting epoch %s %d%% complete",
  225. self.epoch, roundpercent)
  226. lastroundpercent = roundpercent
  227. logging.info("Epoch %s started", self.epoch)
  228. return self.epoch
  229. def wantepochticks(self, callback, want, priority=False, end=False):
  230. """Register or deregister an object from receiving epoch change
  231. callbacks. If want is True, the callback object's newepoch()
  232. method will be called at each epoch change, with an argument of
  233. the new epoch. If want if False, the callback object will be
  234. deregistered. If priority is True, call back this object before
  235. any object with priority=False. If end is True, the callback
  236. object's epoch_ending() method will be called instead at the end
  237. of the epoch, just _before_ the epoch number change."""
  238. if end:
  239. if priority:
  240. l = self.epochpriorityendingcallbacks
  241. else:
  242. l = self.epochendingcallbacks
  243. else:
  244. if priority:
  245. l = self.epochprioritycallbacks
  246. else:
  247. l = self.epochcallbacks
  248. if want:
  249. l.append(callback)
  250. else:
  251. l.remove(callback)
  252. def bind(self, server):
  253. """Bind a server to a newly generated NetAddr, returning the
  254. NetAddr. The server's bound() callback will also be invoked."""
  255. addr = NetAddr()
  256. self.servers[addr] = server
  257. server.bound(addr, lambda: self.servers.pop(addr))
  258. return addr
  259. def connect(self, client, srvaddr, perfstats):
  260. """Connect the given client to the server bound to addr. Throw
  261. an exception if there is no server bound to that address."""
  262. try:
  263. server = self.servers[srvaddr]
  264. except KeyError:
  265. raise NetNoServer()
  266. conn = server.connected(client)
  267. conn.perfstats = perfstats
  268. return conn
  269. def setfallbackrelays(self, fallbackrelays):
  270. """Set the list of globally known fallback relays. Clients use
  271. these to bootstrap when they know no other relays."""
  272. self.fallbackrelays = fallbackrelays
  273. # Construct the CDF of fallback relay bws, so that clients can
  274. # choose a fallback relay weighted by bw
  275. self.fallbackbwcdf = [0]
  276. for r in fallbackrelays:
  277. self.fallbackbwcdf.append(self.fallbackbwcdf[-1]+r.bw)
  278. # Remove the last item, which should be the sum of all the
  279. # relays
  280. self.fallbacktotbw = self.fallbackbwcdf.pop()
  281. def getfallbackrelay(self):
  282. """Get a random one of the globally known fallback relays,
  283. weighted by bw. Clients use these to bootstrap when they know
  284. no other relays."""
  285. idx = random.randint(0, self.fallbacktotbw-1)
  286. i = bisect.bisect_right(self.fallbackbwcdf, idx)
  287. r = self.fallbackrelays[i-1]
  288. return r
  289. def set_wo_style(self, womode, snipauthmode):
  290. """Set the Walking Onions mode and the SNIP authenticate mode
  291. for the network."""
  292. if ((womode == WOMode.VANILLA) \
  293. and (snipauthmode != SNIPAuthMode.NONE)) or \
  294. ((womode != WOMode.VANILLA) and \
  295. (snipauthmode == SNIPAuthMode.NONE)):
  296. # Incompatible settings
  297. raise ValueError("Bad argument combination")
  298. self.womode = womode
  299. self.snipauthmode = snipauthmode
  300. # The singleton instance of Network
  301. thenetwork = Network()
  302. class NetMsg:
  303. """The parent class of network messages. Subclass this class to
  304. implement specific kinds of network messages."""
  305. def size(self):
  306. """Return the size of this network message. For now, just
  307. pickle it and return the length of that. There's some
  308. unnecessary overhead in this method; if you want specific
  309. messages to have more accurate sizes, override this method in
  310. the subclass. Alternately, if symbolic_byte_counters is set,
  311. return a symbolic representation of the message size instead, so
  312. that the total byte counts will clearly show how many of each
  313. message type were sent and received."""
  314. if symbolic_byte_counters:
  315. sz = sympy.symbols(type(self).__name__)
  316. else:
  317. sz = len(pickle.dumps(self))
  318. # logging.info("%s size %d", type(self).__name__, sz)
  319. return sz
  320. class StringNetMsg(NetMsg):
  321. """Send an arbitratry string as a NetMsg."""
  322. def __init__(self, data):
  323. self.data = data
  324. def __str__(self):
  325. return self.data.__str__()
  326. class Connection:
  327. def __init__(self, peer = None):
  328. """Create a Connection object with the given peer."""
  329. self.peer = peer
  330. def closed(self):
  331. logging.debug("connection closed with %s", self.peer)
  332. self.peer = None
  333. def close(self):
  334. logging.debug("closing connection with %s", self.peer)
  335. self.peer.closed()
  336. self.peer = None
  337. class ClientConnection(Connection):
  338. """The parent class of client-side network connections. Subclass
  339. this class to do anything more elaborate than just passing arbitrary
  340. NetMsgs, which then get ignored. Use subclasses of this class when
  341. the server required no per-connection state, such as just fetching
  342. consensus documents."""
  343. def __init__(self, peer):
  344. """Create a ClientConnection object with the given peer. The
  345. peer must have a received(client, msg) method."""
  346. self.peer = peer
  347. self.perfstats = None
  348. def sendmsg(self, netmsg):
  349. assert(isinstance(netmsg, NetMsg))
  350. msgsize = netmsg.size()
  351. self.perfstats.bytes_sent += msgsize
  352. self.peer.received(self, netmsg)
  353. def reply(self, netmsg):
  354. assert(isinstance(netmsg, NetMsg))
  355. msgsize = netmsg.size()
  356. self.perfstats.bytes_received += msgsize
  357. self.receivedfromserver(netmsg)
  358. class ServerConnection(Connection):
  359. """The parent class of server-side network connections."""
  360. def __init__(self):
  361. self.peer = None
  362. def sendmsg(self, netmsg):
  363. assert(isinstance(netmsg, NetMsg))
  364. self.peer.received(netmsg)
  365. def received(self, client, netmsg):
  366. logging.debug("received %s from client %s", netmsg, client)
  367. class Server:
  368. """The parent class of network servers. Subclass this class to
  369. implement servers of different kinds. You will probably only need
  370. to override the implementation of connected()."""
  371. def __init__(self, name):
  372. self.name = name
  373. def __str__(self):
  374. return self.name.__str__()
  375. def bound(self, netaddr, closer):
  376. """Callback invoked when the server is successfully bound to a
  377. NetAddr. The parameters are the netaddr to which the server is
  378. bound and closer (as in a thing that causes something to close)
  379. is a callback function that should be used when the server
  380. wishes to stop listening for new connections."""
  381. self.closer = closer
  382. def close(self):
  383. """Stop listening for new connections."""
  384. self.closer()
  385. def connected(self, client):
  386. """Callback invoked when a client connects to this server.
  387. This callback must create the Connection object that will be
  388. returned to the client."""
  389. logging.debug("server %s connected to client %s", self, client)
  390. serverconnection = ServerConnection()
  391. clientconnection = ClientConnection(serverconnection)
  392. serverconnection.peer = clientconnection
  393. return clientconnection
  394. if __name__ == '__main__':
  395. n1 = NetAddr()
  396. n2 = NetAddr()
  397. assert(n1 == n1)
  398. assert(not (n1 == n2))
  399. assert(n1 != n2)
  400. print(n1, n2)
  401. # Initialize the (non-cryptographic) random seed
  402. random.seed(1)
  403. srv = Server("hello world server")
  404. thenetwork.printservers()
  405. a = thenetwork.bind(srv)
  406. thenetwork.printservers()
  407. print("in main", a)
  408. perfstats = PerfStats(EntType.NONE)
  409. conn = thenetwork.connect("hello world client", a, perfstats)
  410. conn.sendmsg(StringNetMsg("hi"))
  411. conn.close()
  412. srv.close()
  413. thenetwork.printservers()