network.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. #!/usr/bin/env python3
  2. import random
  3. import pickle
  4. from enum import Enum
  5. class WOMode(Enum):
  6. """The different Walking Onion modes"""
  7. VANILLA = 0 # No Walking Onions
  8. TELESCOPING = 1 # Telescoping Walking Onions
  9. SINGLEPASS = 2 # Single-Pass Walking Onions
  10. class SNIPAuthMode(Enum):
  11. """The different styles of SNIP authentication"""
  12. NONE = 0 # No SNIPs; only used for WOMode = VANILLA
  13. MERKLE = 1 # Merkle trees
  14. THRESHSIG = 2 # Threshold signatures
  15. class EntType(Enum):
  16. """The different types of entities in the system."""
  17. NONE = 0
  18. DIRAUTH = 1
  19. RELAY = 2
  20. CLIENT = 3
  21. class PerfStats:
  22. """A class to store performance statistics for a relay or client.
  23. We keep track of bytes sent, bytes received, and counts of
  24. public-key operations of various types. We will reset these every
  25. epoch."""
  26. def __init__(self, ent_type):
  27. # Which type of entity is this for (DIRAUTH, RELAY, CLIENT)
  28. self.ent_type = ent_type
  29. # A printable name for the entity
  30. self.name = None
  31. # True if bootstrapping this epoch
  32. self.is_bootstrapping = False
  33. # Bytes sent and received
  34. self.bytes_sent = 0
  35. self.bytes_received = 0
  36. # Public-key operations: key generation, signing, verification,
  37. # Diffie-Hellman
  38. self.keygens = 0
  39. self.sigs = 0
  40. self.verifs = 0
  41. self.dhs = 0
  42. def __str__(self):
  43. return "%s: type=%s boot=%s sent=%d recv=%d keygen=%d sig=%d verif=%d dh=%d" % \
  44. (self.name, self.ent_type.name, self.is_bootstrapping, \
  45. self.bytes_sent, self.bytes_received, self.keygens, \
  46. self.sigs, self.verifs, self.dhs)
  47. class NetAddr:
  48. """A class representing a network address"""
  49. nextaddr = 1
  50. def __init__(self):
  51. """Generate a fresh network address"""
  52. self.addr = NetAddr.nextaddr
  53. NetAddr.nextaddr += 1
  54. def __eq__(self, other):
  55. return (isinstance(other, self.__class__)
  56. and self.__dict__ == other.__dict__)
  57. def __hash__(self):
  58. return hash(self.addr)
  59. def __str__(self):
  60. return self.addr.__str__()
  61. class NetNoServer(Exception):
  62. """No server is listening on the address someone tried to connect
  63. to."""
  64. class Network:
  65. """A class representing a simulated network. Servers can bind()
  66. to the network, yielding a NetAddr (network address), and clients
  67. can connect() to a NetAddr yielding a Connection."""
  68. def __init__(self):
  69. self.servers = dict()
  70. self.epoch = 1
  71. self.epochcallbacks = []
  72. self.epochendingcallbacks = []
  73. self.dirauthkeylist = []
  74. self.fallbackrelays = []
  75. self.womode = WOMode.VANILLA
  76. self.snipauthmode = SNIPAuthMode.NONE
  77. def printservers(self):
  78. """Print the list of NetAddrs bound to something."""
  79. print("Servers:")
  80. for a in self.servers.keys():
  81. print(a)
  82. def setdirauthkey(self, index, vk):
  83. """Set the public verification key for dirauth number index to
  84. vk."""
  85. if index >= len(self.dirauthkeylist):
  86. self.dirauthkeylist.extend([None] * (index+1-len(self.dirauthkeylist)))
  87. self.dirauthkeylist[index] = vk
  88. def dirauthkeys(self):
  89. """Return the list of dirauth public verification keys."""
  90. return self.dirauthkeylist
  91. def getepoch(self):
  92. """Return the current epoch."""
  93. return self.epoch
  94. def nextepoch(self):
  95. """Increment the current epoch, and return it."""
  96. for c in self.epochendingcallbacks:
  97. c.epoch_ending(self.epoch)
  98. self.epoch += 1
  99. for c in self.epochcallbacks:
  100. c.newepoch(self.epoch)
  101. return self.epoch
  102. def wantepochticks(self, callback, want, end=False):
  103. """Register or deregister an object from receiving epoch change
  104. callbacks. If want is True, the callback object's newepoch()
  105. method will be called at each epoch change, with an argument of
  106. the new epoch. If want if False, the callback object will be
  107. deregistered. If end is True, the callback object's
  108. epoch_ending() method will be called instead at the end of the
  109. epoch, just _before_ the epoch number change."""
  110. if end:
  111. if want:
  112. self.epochendingcallbacks.append(callback)
  113. else:
  114. self.epochendingcallbacks.remove(callback)
  115. else:
  116. if want:
  117. self.epochcallbacks.append(callback)
  118. else:
  119. self.epochcallbacks.remove(callback)
  120. def bind(self, server):
  121. """Bind a server to a newly generated NetAddr, returning the
  122. NetAddr. The server's bound() callback will also be invoked."""
  123. addr = NetAddr()
  124. self.servers[addr] = server
  125. server.bound(addr, lambda: self.servers.pop(addr))
  126. return addr
  127. def connect(self, client, srvaddr, perfstats):
  128. """Connect the given client to the server bound to addr. Throw
  129. an exception if there is no server bound to that address."""
  130. try:
  131. server = self.servers[srvaddr]
  132. except KeyError:
  133. raise NetNoServer()
  134. conn = server.connected(client)
  135. conn.perfstats = perfstats
  136. return conn
  137. def setfallbackrelays(self, fallbackrelays):
  138. """Set the list of globally known fallback relays. Clients use
  139. these to bootstrap when they know no other relays."""
  140. self.fallbackrelays = fallbackrelays
  141. def getfallbackrelays(self):
  142. """Get the list of globally known fallback relays. Clients use
  143. these to bootstrap when they know no other relays."""
  144. return self.fallbackrelays
  145. def set_wo_style(self, womode, snipauthmode):
  146. """Set the Walking Onions mode and the SNIP authenticate mode
  147. for the network."""
  148. if ((womode == WOMode.VANILLA) \
  149. and (snipauthmode != SNIPAuthMode.NONE)) or \
  150. ((womode != WOMode.VANILLA) and \
  151. (snipauthmode == SNIPAuthMode.NONE)):
  152. # Incompatible settings
  153. raise ValueError("Bad argument combination")
  154. self.womode = womode
  155. self.snipauthmode = snipauthmode
  156. # The singleton instance of Network
  157. thenetwork = Network()
  158. # Initialize the (non-cryptographic) random seed
  159. random.seed(1)
  160. class NetMsg:
  161. """The parent class of network messages. Subclass this class to
  162. implement specific kinds of network messages."""
  163. def size(self):
  164. """Return the size of this network message. For now, just
  165. pickle it and return the length of that. There's some
  166. unnecessary overhead in this method; if you want specific
  167. messages to have more accurate sizes, override this method in
  168. the subclass."""
  169. sz = len(pickle.dumps(self))
  170. print('size',sz,type(self))
  171. return sz
  172. class StringNetMsg(NetMsg):
  173. """Send an arbitratry string as a NetMsg."""
  174. def __init__(self, str):
  175. self.data = str
  176. def __str__(self):
  177. return self.data.__str__()
  178. class Connection:
  179. def __init__(self, peer = None):
  180. """Create a Connection object with the given peer."""
  181. self.peer = peer
  182. def closed(self):
  183. print("connection closed with", self.peer)
  184. self.peer = None
  185. def close(self):
  186. print("closing connection with", self.peer)
  187. self.peer.closed()
  188. self.peer = None
  189. class ClientConnection(Connection):
  190. """The parent class of client-side network connections. Subclass
  191. this class to do anything more elaborate than just passing arbitrary
  192. NetMsgs, which then get ignored. Use subclasses of this class when
  193. the server required no per-connection state, such as just fetching
  194. consensus documents."""
  195. def __init__(self, peer):
  196. """Create a ClientConnection object with the given peer. The
  197. peer must have a received(client, msg) method."""
  198. self.peer = peer
  199. self.perfstats = None
  200. def sendmsg(self, netmsg):
  201. assert(isinstance(netmsg, NetMsg))
  202. msgsize = netmsg.size()
  203. self.perfstats.bytes_sent += msgsize
  204. self.peer.received(self, netmsg)
  205. def reply(self, netmsg):
  206. assert(isinstance(netmsg, NetMsg))
  207. msgsize = netmsg.size()
  208. self.perfstats.bytes_received += msgsize
  209. self.receivedfromserver(netmsg)
  210. class ServerConnection(Connection):
  211. """The parent class of server-side network connections."""
  212. def __init__(self):
  213. self.peer = None
  214. def sendmsg(self, netmsg):
  215. assert(isinstance(netmsg, NetMsg))
  216. self.peer.received(netmsg)
  217. def received(self, client, netmsg):
  218. print("received", netmsg, "from client", client)
  219. class Server:
  220. """The parent class of network servers. Subclass this class to
  221. implement servers of different kinds. You will probably only need
  222. to override the implementation of connected()."""
  223. def __init__(self, name):
  224. self.name = name
  225. def __str__(self):
  226. return self.name.__str__()
  227. def bound(self, netaddr, closer):
  228. """Callback invoked when the server is successfully bound to a
  229. NetAddr. The parameters are the netaddr to which the server is
  230. bound and closer (as in a thing that causes something to close)
  231. is a callback function that should be used when the server
  232. wishes to stop listening for new connections."""
  233. self.closer = closer
  234. def close(self):
  235. """Stop listening for new connections."""
  236. self.closer()
  237. def connected(self, client):
  238. """Callback invoked when a client connects to this server.
  239. This callback must create the Connection object that will be
  240. returned to the client."""
  241. print("server", self, "conected to client", client)
  242. serverconnection = ServerConnection()
  243. clientconnection = ClientConnection(serverconnection)
  244. serverconnection.peer = clientconnection
  245. return clientconnection
  246. if __name__ == '__main__':
  247. n1 = NetAddr()
  248. n2 = NetAddr()
  249. assert(n1 == n1)
  250. assert(not (n1 == n2))
  251. assert(n1 != n2)
  252. print(n1, n2)
  253. srv = Server("hello world server")
  254. thenetwork.printservers()
  255. a = thenetwork.bind(srv)
  256. thenetwork.printservers()
  257. print("in main", a)
  258. perfstats = PerfStats(EntType.NONE)
  259. conn = thenetwork.connect("hello world client", a, perfstats)
  260. conn.sendmsg(StringNetMsg("hi"))
  261. conn.close()
  262. srv.close()
  263. thenetwork.printservers()