#!/usr/bin/env python3
class NetAddr:
    """A network address in the simulated network.

    Each new instance takes the next value of a class-wide integer
    counter, so addresses handed out during one run are distinct."""

    # Class-wide counter: the address the next NetAddr() will receive.
    nextaddr = 1

    def __init__(self):
        """Allocate the next unused network address."""
        allocated = NetAddr.nextaddr
        NetAddr.nextaddr = allocated + 1
        self.addr = allocated

    def __eq__(self, other):
        """Equal when other is an instance of this class carrying
        identical attributes."""
        if not isinstance(other, self.__class__):
            return False
        return self.__dict__ == other.__dict__

    def __hash__(self):
        """Hash on the integer address (equal objects share an addr)."""
        return hash(self.addr)

    def __str__(self):
        """Render the address as its integer value."""
        return str(self.addr)
class Network:
    """A class representing a simulated network.

    Servers can bind() to the network, yielding a NetAddr (network
    address), and clients can connect() to a NetAddr yielding a
    Connection.  The network also tracks the current epoch, notifies
    registered callback objects when the epoch changes, and stores the
    list of dirauth public verification keys."""

    def __init__(self):
        # Maps NetAddr -> server object currently bound to that address.
        self.servers = dict()
        # Current epoch number; starts at 1 and is advanced by nextepoch().
        self.epoch = 1
        # Objects whose newepoch(epoch) is called after each epoch change.
        self.epochcallbacks = []
        # Objects whose epoch_ending(epoch) is called just before it.
        self.epochendingcallbacks = []
        # Dirauth verification keys indexed by dirauth number; slots that
        # have not been set hold None.
        self.dirauthkeylist = []

    def printservers(self):
        """Print the list of NetAddrs bound to something."""
        print("Servers:")
        for a in self.servers:
            print(a)

    def setdirauthkey(self, index, vk):
        """Set the public verification key for dirauth number index to
        vk, padding the key list with None as needed."""
        if index >= len(self.dirauthkeylist):
            self.dirauthkeylist.extend(
                [None] * (index + 1 - len(self.dirauthkeylist)))
        self.dirauthkeylist[index] = vk

    def dirauthkeys(self):
        """Return the list of dirauth public verification keys.

        This is the network's own list object, not a copy."""
        return self.dirauthkeylist

    def getepoch(self):
        """Return the current epoch."""
        return self.epoch

    def nextepoch(self):
        """Increment the current epoch, and return it.

        First every epoch-ending callback's epoch_ending() is called with
        the old epoch; then the epoch is incremented and every epoch
        callback's newepoch() is called with the new epoch.

        Each callback list is snapshotted before it is walked, so a
        callback may (de)register itself or others via wantepochticks()
        while being notified without causing other callbacks in that
        list to be skipped; changes to a list being walked take effect
        at the next epoch change."""
        for c in list(self.epochendingcallbacks):
            c.epoch_ending(self.epoch)
        self.epoch += 1
        for c in list(self.epochcallbacks):
            c.newepoch(self.epoch)
        return self.epoch

    def wantepochticks(self, callback, want, end=False):
        """Register or deregister an object from receiving epoch change
        callbacks.

        If want is True, the callback object's newepoch() method will be
        called at each epoch change, with an argument of the new epoch.
        If want is False, the callback object will be deregistered
        (list.remove raises ValueError if it was not registered).  If end
        is True, the callback object's epoch_ending() method will be
        called instead at the end of the epoch, just _before_ the epoch
        number change, with the epoch that is ending."""
        callbacks = self.epochendingcallbacks if end else self.epochcallbacks
        if want:
            callbacks.append(callback)
        else:
            callbacks.remove(callback)

    def bind(self, server):
        """Bind a server to a newly generated NetAddr, returning the
        NetAddr.

        The server's bound(addr, closer) callback is also invoked; closer
        removes the binding.  Calling that closer a second time raises
        KeyError, since the address has already been removed."""
        addr = NetAddr()
        self.servers[addr] = server
        server.bound(addr, lambda: self.servers.pop(addr))
        return addr

    def connect(self, client, srvaddr):
        """Connect the given client to the server bound to srvaddr.

        Returns whatever the server's connected(client) callback returns.
        Raises KeyError if there is no server bound to srvaddr."""
        server = self.servers[srvaddr]
        return server.connected(client)


# The singleton instance of Network
thenetwork = Network()
class NetMsg:
    """The parent class of network messages. Subclass this class to
    implement specific kinds of network messages."""


class StringNetMsg(NetMsg):
    """Send an arbitrary string as a NetMsg.

    The payload is stored unchanged in the data attribute; str() of the
    message is str() of that payload."""

    def __init__(self, str):
        # NOTE: this parameter shadows the builtin str inside __init__;
        # the name is kept so existing keyword callers (str=...) still work.
        self.data = str

    def __str__(self):
        return str(self.data)
class Connection:
    """One end of a simulated network connection.

    The peer attribute holds the object at the other end of the
    connection, or None once the connection has been closed."""

    def __init__(self, peer=None):
        """Create a Connection object with the given peer."""
        self.peer = peer

    def closed(self):
        """Callback invoked by the peer when it closes the connection."""
        print("connection closed with", self.peer)
        self.peer = None

    def close(self):
        """Close the connection, notifying the peer via its closed()
        method.

        Closing a connection that is already closed (peer is None) does
        nothing, so close() is safe to call more than once."""
        if self.peer is None:
            return
        print("closing connection with", self.peer)
        self.peer.closed()
        self.peer = None


class ClientConnection(Connection):
    """The parent class of client-side network connections.

    The base implementation only passes NetMsgs to the peer; the base
    ServerConnection.received() merely prints them.  Subclass this class
    to do anything more elaborate."""

    def __init__(self, peer):
        """Create a ClientConnection object with the given peer. The
        peer must have a received(client, msg) method."""
        super().__init__(peer)

    def sendmsg(self, netmsg):
        """Deliver netmsg to the peer, passing this connection as the
        client argument of the peer's received()."""
        assert(isinstance(netmsg, NetMsg))
        self.peer.received(self, netmsg)

    def received(self, netmsg):
        """Callback invoked with a message from the server; the base
        implementation just prints it."""
        print("received", netmsg, "from server")


class ServerConnection(Connection):
    """The parent class of server-side network connections.

    It is created without a peer; the creator must assign the peer
    attribute (a client connection with a received(msg) method) before
    sending."""

    def __init__(self):
        super().__init__()

    def sendmsg(self, netmsg):
        """Deliver netmsg to the peer's received() method."""
        assert(isinstance(netmsg, NetMsg))
        self.peer.received(netmsg)

    def received(self, client, netmsg):
        """Callback invoked with a message from client; the base
        implementation just prints it."""
        print("received", netmsg, "from client", client)
class Server:
    """The parent class of network servers. Subclass this class to
    implement servers of different kinds. You will probably only need
    to override the implementation of connected()."""

    def __init__(self, name):
        self.name = name
        # Set by bound(); None while the server is not listening.
        self.closer = None

    def __str__(self):
        return str(self.name)

    def bound(self, netaddr, closer):
        """Callback invoked when the server is successfully bound to a
        NetAddr. The parameters are the netaddr to which the server is
        bound and closer (as in a thing that causes something to close)
        is a callback function that should be used when the server
        wishes to stop listening for new connections."""
        self.closer = closer

    def close(self):
        """Stop listening for new connections.

        Does nothing if the server is not currently bound (never bound,
        or already closed), so calling close() more than once is safe."""
        # getattr: tolerate subclasses whose __init__ skips ours.
        closer = getattr(self, "closer", None)
        if closer is None:
            return
        closer()
        # Forget the closer only after it succeeded, so a closer that
        # must not run twice (e.g. one that pops a dict entry) is not
        # re-invoked by a later close().
        self.closer = None

    def connected(self, client):
        """Callback invoked when a client connects to this server.
        This callback must create the Connection object that will be
        returned to the client."""
        print("server", self, "connected to client", client)
        serverconnection = ServerConnection()
        clientconnection = ClientConnection(serverconnection)
        serverconnection.peer = clientconnection
        return clientconnection
if __name__ == "__main__":
    # Sanity-check NetAddr equality: an address equals itself and
    # differs from a freshly allocated one.
    first_addr, second_addr = NetAddr(), NetAddr()
    assert first_addr == first_addr
    assert not first_addr == second_addr
    assert first_addr != second_addr
    print(first_addr, second_addr)

    # Bind a demo server, showing the server table before and after.
    server = Server("hello world server")
    thenetwork.printservers()
    server_addr = thenetwork.bind(server)
    thenetwork.printservers()
    print("in main", server_addr)

    # Connect a client, send one message, then tear everything down.
    client_conn = thenetwork.connect("hello world client", server_addr)
    client_conn.sendmsg(StringNetMsg("hi"))
    client_conn.close()
    server.close()
    thenetwork.printservers()
|