network.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. #!/usr/bin/env python3
  2. import random
  3. from enum import Enum
  4. class WOMode(Enum):
  5. """The different Walking Onion modes"""
  6. VANILLA = 0 # No Walking Onions
  7. TELESCOPING = 1 # Telescoping Walking Onions
  8. SINGLEPASS = 2 # Single-Pass Walking Onions
  9. class SNIPAuthMode(Enum):
  10. """The different styles of SNIP authentication"""
  11. NONE = 0 # No SNIPs; only used for WOMode = VANILLA
  12. MERKLE = 1 # Merkle trees
  13. THRESHSIG = 2 # Threshold signatures
  14. class NetAddr:
  15. """A class representing a network address"""
  16. nextaddr = 1
  17. def __init__(self):
  18. """Generate a fresh network address"""
  19. self.addr = NetAddr.nextaddr
  20. NetAddr.nextaddr += 1
  21. def __eq__(self, other):
  22. return (isinstance(other, self.__class__)
  23. and self.__dict__ == other.__dict__)
  24. def __hash__(self):
  25. return hash(self.addr)
  26. def __str__(self):
  27. return self.addr.__str__()
  28. class Network:
  29. """A class representing a simulated network. Servers can bind()
  30. to the network, yielding a NetAddr (network address), and clients
  31. can connect() to a NetAddr yielding a Connection."""
  32. def __init__(self):
  33. self.servers = dict()
  34. self.epoch = 1
  35. self.epochcallbacks = []
  36. self.epochendingcallbacks = []
  37. self.dirauthkeylist = []
  38. self.fallbackrelays = []
  39. self.womode = WOMode.VANILLA
  40. self.snipauthmode = SNIPAuthMode.NONE
  41. def printservers(self):
  42. """Print the list of NetAddrs bound to something."""
  43. print("Servers:")
  44. for a in self.servers.keys():
  45. print(a)
  46. def setdirauthkey(self, index, vk):
  47. """Set the public verification key for dirauth number index to
  48. vk."""
  49. if index >= len(self.dirauthkeylist):
  50. self.dirauthkeylist.extend([None] * (index+1-len(self.dirauthkeylist)))
  51. self.dirauthkeylist[index] = vk
  52. def dirauthkeys(self):
  53. """Return the list of dirauth public verification keys."""
  54. return self.dirauthkeylist
  55. def getepoch(self):
  56. """Return the current epoch."""
  57. return self.epoch
  58. def nextepoch(self):
  59. """Increment the current epoch, and return it."""
  60. for c in self.epochendingcallbacks:
  61. c.epoch_ending(self.epoch)
  62. self.epoch += 1
  63. for c in self.epochcallbacks:
  64. c.newepoch(self.epoch)
  65. return self.epoch
  66. def wantepochticks(self, callback, want, end=False):
  67. """Register or deregister an object from receiving epoch change
  68. callbacks. If want is True, the callback object's newepoch()
  69. method will be called at each epoch change, with an argument of
  70. the new epoch. If want if False, the callback object will be
  71. deregistered. If end is True, the callback object's
  72. epoch_ending() method will be called instead at the end of the
  73. epoch, just _before_ the epoch number change."""
  74. if end:
  75. if want:
  76. self.epochendingcallbacks.append(callback)
  77. else:
  78. self.epochendingcallbacks.remove(callback)
  79. else:
  80. if want:
  81. self.epochcallbacks.append(callback)
  82. else:
  83. self.epochcallbacks.remove(callback)
  84. def bind(self, server):
  85. """Bind a server to a newly generated NetAddr, returning the
  86. NetAddr. The server's bound() callback will also be invoked."""
  87. addr = NetAddr()
  88. self.servers[addr] = server
  89. server.bound(addr, lambda: self.servers.pop(addr))
  90. return addr
  91. def connect(self, client, srvaddr):
  92. """Connect the given client to the server bound to addr. Throw
  93. an exception if there is no server bound to that address."""
  94. server = self.servers[srvaddr]
  95. return server.connected(client)
  96. def setfallbackrelays(self, fallbackrelays):
  97. """Set the list of globally known fallback relays. Clients use
  98. these to bootstrap when they know no other relays."""
  99. self.fallbackrelays = fallbackrelays
  100. def getfallbackrelays(self):
  101. """Get the list of globally known fallback relays. Clients use
  102. these to bootstrap when they know no other relays."""
  103. return self.fallbackrelays
  104. def set_wo_style(self, womode, snipauthmode):
  105. """Set the Walking Onions mode and the SNIP authenticate mode
  106. for the network."""
  107. if ((womode == WOMode.VANILLA) \
  108. and (snipauthmode != SNIPAuthMode.NONE)) or \
  109. ((womode != WOMode.VANILLA) and \
  110. (snipauthmode == SNIPAuthMode.NONE)):
  111. # Incompatible settings
  112. raise ValueError("Bad argument combination")
  113. self.womode = womode
  114. self.snipauthmode = snipauthmode
  115. # The singleton instance of Network
  116. thenetwork = Network()
  117. # Initialize the (non-cryptographic) random seed
  118. random.seed(1)
  119. class NetMsg:
  120. """The parent class of network messages. Subclass this class to
  121. implement specific kinds of network messages."""
  122. class StringNetMsg(NetMsg):
  123. """Send an arbitratry string as a NetMsg."""
  124. def __init__(self, str):
  125. self.data = str
  126. def __str__(self):
  127. return self.data.__str__()
  128. class Connection:
  129. def __init__(self, peer = None):
  130. """Create a Connection object with the given peer."""
  131. self.peer = peer
  132. def closed(self):
  133. print("connection closed with", self.peer)
  134. self.peer = None
  135. def close(self):
  136. print("closing connection with", self.peer)
  137. self.peer.closed()
  138. self.peer = None
  139. class ClientConnection(Connection):
  140. """The parent class of client-side network connections. Subclass
  141. this class to do anything more elaborate than just passing arbitrary
  142. NetMsgs, which then get ignored. Use subclasses of this class when
  143. the server required no per-connection state, such as just fetching
  144. consensus documents."""
  145. def __init__(self, peer):
  146. """Create a ClientConnection object with the given peer. The
  147. peer must have a received(client, msg) method."""
  148. self.peer = peer
  149. def sendmsg(self, netmsg):
  150. assert(isinstance(netmsg, NetMsg))
  151. self.peer.received(self, netmsg)
  152. def reply(self, netmsg):
  153. assert(isinstance(netmsg, NetMsg))
  154. self.receivedfromserver(netmsg)
  155. def received(self, netmsg):
  156. print("received", netmsg, "from server")
  157. class ServerConnection(Connection):
  158. """The parent class of server-side network connections."""
  159. def __init__(self):
  160. self.peer = None
  161. def sendmsg(self, netmsg):
  162. assert(isinstance(netmsg, NetMsg))
  163. self.peer.received(netmsg)
  164. def received(self, client, netmsg):
  165. print("received", netmsg, "from client", client)
  166. class Server:
  167. """The parent class of network servers. Subclass this class to
  168. implement servers of different kinds. You will probably only need
  169. to override the implementation of connected()."""
  170. def __init__(self, name):
  171. self.name = name
  172. def __str__(self):
  173. return self.name.__str__()
  174. def bound(self, netaddr, closer):
  175. """Callback invoked when the server is successfully bound to a
  176. NetAddr. The parameters are the netaddr to which the server is
  177. bound and closer (as in a thing that causes something to close)
  178. is a callback function that should be used when the server
  179. wishes to stop listening for new connections."""
  180. self.closer = closer
  181. def close(self):
  182. """Stop listening for new connections."""
  183. self.closer()
  184. def connected(self, client):
  185. """Callback invoked when a client connects to this server.
  186. This callback must create the Connection object that will be
  187. returned to the client."""
  188. print("server", self, "conected to client", client)
  189. serverconnection = ServerConnection()
  190. clientconnection = ClientConnection(serverconnection)
  191. serverconnection.peer = clientconnection
  192. return clientconnection
  193. if __name__ == '__main__':
  194. n1 = NetAddr()
  195. n2 = NetAddr()
  196. assert(n1 == n1)
  197. assert(not (n1 == n2))
  198. assert(n1 != n2)
  199. print(n1, n2)
  200. srv = Server("hello world server")
  201. thenetwork.printservers()
  202. a = thenetwork.bind(srv)
  203. thenetwork.printservers()
  204. print("in main", a)
  205. conn = thenetwork.connect("hello world client", a)
  206. conn.sendmsg(StringNetMsg("hi"))
  207. conn.close()
  208. srv.close()
  209. thenetwork.printservers()