123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429 |
- #!/usr/bin/env python3
- import random
- import pickle
- import logging
- import math
- from enum import Enum
# When True, NetMsg sizes are recorded as sympy symbols (one symbol per
# message class) instead of pickled byte lengths, so the accumulated byte
# counters read as symbolic sums over the kinds of messages exchanged.
# Turning this on requires the sympy package to be installed.
symbolic_byte_counters = False

if symbolic_byte_counters:
    import sympy
# ---- Network parameters ----

# Average size of a consensus diff, expressed as a fraction of the size of
# a full consensus document.
P_Delta = 0.019
class WOMode(Enum):
    """The different Walking Onion modes"""
    VANILLA = 0      # No Walking Onions
    TELESCOPING = 1  # Telescoping Walking Onions
    SINGLEPASS = 2   # Single-Pass Walking Onions

    @staticmethod
    def string_to_type(type_input):
        """Map a mode name to its WOMode member.

        Recognised names are 'vanilla', 'telescoping' and 'single-pass'.
        Returns the matching WOMode member, or -1 if type_input is not one
        of those names (callers detect an unknown name by comparing the
        result against -1).

        Declared as a staticmethod so it can be called either on the class
        (WOMode.string_to_type(...)) or on a member.
        """
        reprs = {
            'vanilla': WOMode.VANILLA,
            'telescoping': WOMode.TELESCOPING,
            'single-pass': WOMode.SINGLEPASS,
        }
        return reprs.get(type_input, -1)
class SNIPAuthMode(Enum):
    """How SNIPs are authenticated.

    NONE means no SNIPs at all and is only used together with
    WOMode.VANILLA.
    """
    NONE = 0       # no SNIPs (vanilla mode only)
    MERKLE = 1     # authenticated with Merkle trees
    THRESHSIG = 2  # authenticated with threshold signatures
class EntType(Enum):
    """Kinds of entity that take part in the simulation."""
    NONE = 0     # no particular entity type
    DIRAUTH = 1  # directory authority
    RELAY = 2    # relay
    CLIENT = 3   # client
class PerfStats:
    """Performance counters for a single relay, client, or dirauth.

    Records bytes sent and received and counts of public-key operations
    (key generation, signing, verification, Diffie-Hellman). reset() zeroes
    everything; it is intended to be called at the start of every epoch.
    """

    def __init__(self, ent_type):
        self.ent_type = ent_type  # an EntType (DIRAUTH, RELAY, CLIENT, ...)
        self.name = None          # printable label; None until assigned
        self.reset()

    def __str__(self):
        fields = (
            self.name,
            self.ent_type.name,
            self.is_bootstrapping,
            self.bytes_sent,
            self.bytes_received,
            self.keygens,
            self.sigs,
            self.verifs,
            self.dhs,
        )
        return ("%s: type=%s boot=%s sent=%s recv=%s "
                "keygen=%d sig=%d verif=%d dh=%d") % fields

    def reset(self):
        """Zero every counter and clear the bootstrapping flag."""
        # Whether the entity is bootstrapping during this epoch.
        self.is_bootstrapping = False
        # Traffic counters; message sizes added here may be sympy symbols
        # when symbolic byte counting is enabled.
        self.bytes_sent = self.bytes_received = 0
        # Public-key operation counters.
        self.keygens = self.sigs = self.verifs = self.dhs = 0
class PerfStatsStats:
    """Accumulate a number of PerfStats objects to compute the means and
    sample standard deviations of their fields."""

    class SingleStat:
        """Accumulate single numbers to compute their mean and sample
        standard deviation, keeping only a count, a sum, and a sum of
        squares."""

        def __init__(self):
            self.tot = 0
            self.totsq = 0
            self.N = 0

        def accum(self, x):
            """Add one observation x."""
            self.tot += x
            self.totsq += x*x
            self.N += 1

        def __str__(self):
            """Format as "<mean> \\pm <stddev>" using %f for both numbers.

            The mean is reported as nan when there are no observations, and
            the (Bessel-corrected) sample standard deviation as nan when
            there are fewer than two, instead of raising ZeroDivisionError.
            """
            if self.N == 0:
                mean = stddev = math.nan
            else:
                mean = self.tot/self.N
                if self.N < 2:
                    stddev = math.nan
                else:
                    variance = ((self.totsq - self.tot*self.tot/self.N)
                                / (self.N - 1))
                    # Floating-point cancellation can leave a tiny negative
                    # variance when all observations are (nearly) equal;
                    # clamp it so math.sqrt does not raise ValueError.
                    stddev = math.sqrt(max(variance, 0))
            # Raw string: "\p" is not a valid escape sequence.
            return r"%f \pm %f" % (mean, stddev)

    def __init__(self):
        self.bytes_sent = PerfStatsStats.SingleStat()
        self.bytes_received = PerfStatsStats.SingleStat()
        self.bytes_tot = PerfStatsStats.SingleStat()  # sent + received
        self.keygens = PerfStatsStats.SingleStat()
        self.sigs = PerfStatsStats.SingleStat()
        self.verifs = PerfStatsStats.SingleStat()
        self.dhs = PerfStatsStats.SingleStat()
        self.N = 0  # number of PerfStats objects accumulated

    def accum(self, stat):
        """Fold one PerfStats object's counters into the statistics."""
        self.bytes_sent.accum(stat.bytes_sent)
        self.bytes_received.accum(stat.bytes_received)
        self.bytes_tot.accum(stat.bytes_sent + stat.bytes_received)
        self.keygens.accum(stat.keygens)
        self.sigs.accum(stat.sigs)
        self.verifs.accum(stat.verifs)
        self.dhs.accum(stat.dhs)
        self.N += 1

    def __str__(self):
        """Summarise every field, or return "None" if nothing was
        accumulated."""
        if self.N > 0:
            return "sent=%s recv=%s bytes=%s keygen=%s sig=%s verif=%s dh=%s N=%s" % \
                (self.bytes_sent, self.bytes_received, self.bytes_tot,
                 self.keygens, self.sigs, self.verifs, self.dhs, self.N)
        else:
            return "None"
class NetAddr:
    """An opaque simulated network address backed by a unique integer."""
    nextaddr = 1  # class-wide counter: the integer the next address gets

    def __init__(self):
        """Allocate a fresh network address from the class counter."""
        fresh = NetAddr.nextaddr
        NetAddr.nextaddr = fresh + 1
        self.addr = fresh

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return vars(self) == vars(other)

    def __hash__(self):
        return hash(self.addr)

    def __str__(self):
        return str(self.addr)
class NetNoServer(Exception):
    """Raised on an attempt to connect to an address that has no server
    listening on it."""
class Network:
    """A class representing a simulated network. Servers can bind()
    to the network, yielding a NetAddr (network address), and clients
    can connect() to a NetAddr yielding a Connection.

    The network also tracks the current epoch and its subscribers, the
    directory authorities' verification keys, the fallback relays, and the
    Walking Onions / SNIP authentication modes in use."""

    def __init__(self):
        self.servers = dict()           # NetAddr -> bound server
        self.epoch = 1
        self.epochcallbacks = []        # objects with newepoch(epoch)
        self.epochendingcallbacks = []  # objects with epoch_ending(epoch)
        self.dirauthkeylist = []        # dirauth index -> verification key
        self.fallbackrelays = []
        self.womode = WOMode.VANILLA
        self.snipauthmode = SNIPAuthMode.NONE

    def printservers(self):
        """Print the list of NetAddrs bound to something."""
        print("Servers:")
        for a in self.servers:
            print(a)

    def setdirauthkey(self, index, vk):
        """Set the public verification key for dirauth number index to
        vk, growing the key list with None placeholders if needed."""
        if index >= len(self.dirauthkeylist):
            self.dirauthkeylist.extend([None] * (index+1-len(self.dirauthkeylist)))
        self.dirauthkeylist[index] = vk

    def dirauthkeys(self):
        """Return the list of dirauth public verification keys."""
        return self.dirauthkeylist

    def getepoch(self):
        """Return the current epoch."""
        return self.epoch

    def nextepoch(self):
        """End the current epoch, start the next one, and return it.

        First epoch_ending(old_epoch) is called on every end-of-epoch
        subscriber; then the epoch counter is incremented and
        newepoch(new_epoch) is called on every start-of-epoch subscriber,
        logging progress percentages as it goes.

        Each subscriber list is snapshotted before it is walked, so a
        callback may (de)register itself or others via wantepochticks()
        without causing another subscriber to be skipped. Such changes
        take effect from the next epoch change.
        """
        logging.info("Ending epoch %s", self.epoch)
        for c in list(self.epochendingcallbacks):
            c.epoch_ending(self.epoch)
        self.epoch += 1
        logging.info("Starting epoch %s", self.epoch)
        # Snapshot: removing from the live list mid-iteration would shift
        # the remaining entries and silently skip the next subscriber.
        callbacks = list(self.epochcallbacks)
        totcallbacks = len(callbacks)
        lastroundpercent = -1
        for i, c in enumerate(callbacks):
            c.newepoch(self.epoch)
            roundpercent = int(100*(i+1)/totcallbacks)
            if roundpercent != lastroundpercent:
                logging.info("%d%% complete", roundpercent)
                lastroundpercent = roundpercent
        logging.info("Epoch %s started", self.epoch)
        return self.epoch

    def wantepochticks(self, callback, want, end=False):
        """Register or deregister an object from receiving epoch change
        callbacks. If want is True, the callback object's newepoch()
        method will be called at each epoch change, with an argument of
        the new epoch. If want is False, the callback object will be
        deregistered (ValueError if it was not registered). If end is
        True, the callback object's epoch_ending() method will be called
        instead at the end of the epoch, just _before_ the epoch number
        change."""
        callbacks = self.epochendingcallbacks if end else self.epochcallbacks
        if want:
            callbacks.append(callback)
        else:
            callbacks.remove(callback)

    def bind(self, server):
        """Bind a server to a newly generated NetAddr, returning the
        NetAddr. The server's bound(addr, closer) callback is invoked;
        calling closer() removes the binding (calling it a second time
        raises KeyError)."""
        addr = NetAddr()
        self.servers[addr] = server
        server.bound(addr, lambda: self.servers.pop(addr))
        return addr

    def connect(self, client, srvaddr, perfstats):
        """Connect the given client to the server bound to srvaddr and
        return the resulting Connection, with perfstats attached to it.

        Raises NetNoServer if there is no server bound to that address."""
        try:
            server = self.servers[srvaddr]
        except KeyError:
            # The KeyError is an implementation detail; don't chain it.
            raise NetNoServer() from None
        conn = server.connected(client)
        conn.perfstats = perfstats
        return conn

    def setfallbackrelays(self, fallbackrelays):
        """Set the list of globally known fallback relays. Clients use
        these to bootstrap when they know no other relays."""
        self.fallbackrelays = fallbackrelays

    def getfallbackrelays(self):
        """Get the list of globally known fallback relays. Clients use
        these to bootstrap when they know no other relays."""
        return self.fallbackrelays

    def set_wo_style(self, womode, snipauthmode):
        """Set the Walking Onions mode and the SNIP authentication mode
        for the network.

        Raises ValueError unless VANILLA is paired with SNIPAuthMode.NONE
        and every other mode with a SNIP authentication mode other than
        NONE."""
        if (womode == WOMode.VANILLA) != (snipauthmode == SNIPAuthMode.NONE):
            # Incompatible settings
            raise ValueError("Bad argument combination")
        self.womode = womode
        self.snipauthmode = snipauthmode
# Module-level Network instance, shared by every module that imports this
# one and accesses it as thenetwork.
thenetwork = Network()
class NetMsg:
    """Base class for simulated network messages.

    Subclass it for each concrete kind of message. Subclasses may override
    size() when the pickled length is a poor estimate of the real size.
    """

    def size(self):
        """Return this message's size for byte accounting.

        By default this is the length of the pickled message, which
        includes some pickle overhead. When symbolic_byte_counters is set,
        it is instead a sympy symbol named after the message's class, so
        byte totals show how many messages of each type were sent and
        received.
        """
        if not symbolic_byte_counters:
            return len(pickle.dumps(self))
        return sympy.symbols(type(self).__name__)
class StringNetMsg(NetMsg):
    """A NetMsg carrying an arbitrary string payload."""

    def __init__(self, data):
        self.data = data  # the payload

    def __str__(self):
        return str(self.data)
class Connection:
    """One end of a simulated connection; self.peer is the other end, or
    None once the connection is closed."""

    def __init__(self, peer=None):
        """Create a Connection object with the given peer."""
        self.peer = peer

    def closed(self):
        """Callback invoked by the peer when it closes the connection;
        drops our reference to the peer."""
        logging.debug("connection closed with %s", self.peer)
        self.peer = None

    def close(self):
        """Close the connection: notify the peer through its closed()
        callback and drop our reference to it.

        Closing a connection that is already closed (by either end) is a
        no-op rather than an AttributeError on the missing peer."""
        if self.peer is None:
            return
        logging.debug("closing connection with %s", self.peer)
        self.peer.closed()
        self.peer = None
class ClientConnection(Connection):
    """The parent class of client-side network connections. Subclass
    this class to do anything more elaborate than just passing arbitrary
    NetMsgs, which then get ignored. Use subclasses of this class when
    the server requires no per-connection state, such as just fetching
    consensus documents."""

    def __init__(self, peer):
        """Create a ClientConnection object with the given peer. The
        peer must have a received(client, msg) method."""
        self.peer = peer
        # PerfStats charged for this connection's traffic. Network.connect()
        # assigns it; it must be set before sendmsg()/reply() are used.
        self.perfstats = None

    def sendmsg(self, netmsg):
        """Send netmsg to the server, charging its size to bytes_sent."""
        assert(isinstance(netmsg, NetMsg))
        msgsize = netmsg.size()
        self.perfstats.bytes_sent += msgsize
        self.peer.received(self, netmsg)

    def reply(self, netmsg):
        """Deliver a reply from the server: charge its size to
        bytes_received, then pass it to receivedfromserver()."""
        assert(isinstance(netmsg, NetMsg))
        msgsize = netmsg.size()
        self.perfstats.bytes_received += msgsize
        self.receivedfromserver(netmsg)

    def receivedfromserver(self, netmsg):
        """Handle a message from the server. The default implementation
        just logs and ignores it; subclasses override this to act on
        replies. Without it, reply() on a plain ClientConnection raised
        AttributeError."""
        logging.debug("received %s from server %s", netmsg, self.peer)
class ServerConnection(Connection):
    """Base class for the server end of a simulated connection."""

    def __init__(self):
        # The client end is attached after construction.
        self.peer = None

    def sendmsg(self, netmsg):
        """Push netmsg to the peer by calling peer.received(netmsg).

        NOTE(review): the base ClientConnection defines no received()
        method, so this only works if the peer is a subclass providing a
        one-argument received() -- confirm against callers.
        """
        assert isinstance(netmsg, NetMsg)
        peer = self.peer
        peer.received(netmsg)

    def received(self, client, netmsg):
        """Handle netmsg arriving from client; the base class only logs."""
        logging.debug("received %s from client %s", netmsg, client)
class Server:
    """Base class for simulated network servers.

    Subclasses normally only need to override connected().
    """

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return str(self.name)

    def bound(self, netaddr, closer):
        """Called once the server has been bound to netaddr.

        closer is a zero-argument callable that stops the server listening
        for new connections; it is stored so close() can invoke it later.
        """
        self.closer = closer

    def close(self):
        """Stop listening for new connections."""
        closer = self.closer
        closer()

    def connected(self, client):
        """Called when client connects. Builds a linked pair of
        ServerConnection and ClientConnection objects and returns the
        client end."""
        logging.debug("server %s connected to client %s", self, client)
        server_end = ServerConnection()
        client_end = ClientConnection(server_end)
        server_end.peer = client_end
        return client_end
if __name__ == '__main__':
    # Sanity-check NetAddr equality semantics.
    addr1 = NetAddr()
    addr2 = NetAddr()
    assert addr1 == addr1
    assert not (addr1 == addr2)
    assert addr1 != addr2
    print(addr1, addr2)

    # Seed the (non-cryptographic) PRNG so runs are reproducible.
    random.seed(1)

    # Bind a server, connect a client, send one message, then tear down.
    server = Server("hello world server")
    thenetwork.printservers()
    srvaddr = thenetwork.bind(server)
    thenetwork.printservers()
    print("in main", srvaddr)
    stats = PerfStats(EntType.NONE)
    conn = thenetwork.connect("hello world client", srvaddr, stats)
    conn.sendmsg(StringNetMsg("hi"))
    conn.close()
    server.close()
    thenetwork.printservers()
|