|
- #!/usr/bin/env python3
- import random
- import pickle
- import logging
- import math
- import bisect
- from enum import Enum
# Set this to True if you want the bytes sent and received to be added
# symbolically, in terms of the numbers of each type of network message.
# You will need sympy installed for this to work.
symbolic_byte_counters = False
try:
    import sympy
except ImportError:
    # sympy is optional: it is only used when symbolic_byte_counters is
    # True.  Catch only ImportError so genuine errors are not hidden.
    pass

# Network parameters

# On average, how large is a consensus diff as compared to a full
# consensus?
P_Delta = 0.019
class WOMode(Enum):
    """Which Walking Onions variant, if any, is being simulated."""
    VANILLA = 0      # Walking Onions not used at all
    TELESCOPING = 1  # Telescoping Walking Onions
    SINGLEPASS = 2   # Single-Pass Walking Onions

    def string_to_type(type_input):
        """Translate a command-line mode name into a WOMode member.

        Recognized names are 'vanilla', 'telescoping' and
        'single-pass'; any other input yields -1."""
        by_name = {
            'vanilla': WOMode.VANILLA,
            'telescoping': WOMode.TELESCOPING,
            'single-pass': WOMode.SINGLEPASS,
        }
        return by_name.get(type_input, -1)
class SNIPAuthMode(Enum):
    """The different styles of SNIP authentication"""
    NONE = 0        # No SNIPs; only used for WOMode = VANILLA
    MERKLE = 1      # Merkle trees
    THRESHSIG = 2   # Threshold signatures

    # Only the SNIP-using modes are selectable on the command line:
    # WOMode.VANILLA always takes NONE, and no other WOMode does.
    def string_to_type(type_input):
        """Map a command-line name to an SNIPAuthMode member, or return
        -1 if the name is not recognized.

        'merkle' selects MERKLE.  'threshsig' is the canonical name for
        THRESHSIG; 'telescoping' and the old misspelling 'telesocping'
        are still accepted so existing command lines keep working."""
        reprs = {
            'merkle': SNIPAuthMode.MERKLE,
            'threshsig': SNIPAuthMode.THRESHSIG,
            'telescoping': SNIPAuthMode.THRESHSIG,
            'telesocping': SNIPAuthMode.THRESHSIG,  # legacy typo
        }
        return reprs.get(type_input, -1)
class EntType(Enum):
    """Kinds of entity that take part in the simulation."""
    NONE = 0     # no particular entity kind
    DIRAUTH = 1  # dirauth
    RELAY = 2    # relay
    CLIENT = 3   # client
class PerfStats:
    """Performance counters for a single dirauth, relay, or client.

    Tracks bytes sent and received plus counts of public-key
    operations (key generation, signing, verification and
    Diffie-Hellman).  reset() zeroes the counters; callers do this
    every epoch."""

    def __init__(self, ent_type, bw=None):
        # The kind of entity these counters describe
        self.ent_type = ent_type
        # Printable label for the entity; None until a caller sets it
        self.name = None
        # The relay's bandwidth, where that is meaningful
        self.bw = bw
        # Whether the entity is bootstrapping in the current epoch
        self.is_bootstrapping = False
        self.reset()

    def __str__(self):
        fields = (self.name, self.ent_type.name, self.is_bootstrapping,
                  self.bytes_sent, self.bytes_received, self.keygens,
                  self.sigs, self.verifs, self.dhs)
        return ("%s: type=%s boot=%s sent=%s recv=%s "
                "keygen=%d sig=%d verif=%d dh=%d") % fields

    def reset(self):
        """Zero all traffic and public-key-operation counters; the
        identity fields (ent_type, name, bw, is_bootstrapping) are left
        as they are."""
        # Traffic counters (may hold sympy expressions when symbolic
        # byte counting is enabled elsewhere in the module)
        self.bytes_sent = self.bytes_received = 0
        # keygen / sign / verify / DH operation counts
        self.keygens = self.sigs = self.verifs = self.dhs = 0
class PerfStatsStats:
    """Accumulate a number of PerfStats objects to compute the means and
    stddevs of their fields."""

    class SingleStat:
        """Accumulate single numbers to compute their mean and sample
        standard deviation."""

        def __init__(self):
            # Running sum, running sum of squares, and sample count
            self.tot = 0
            self.totsq = 0
            self.N = 0

        def accum(self, x):
            """Add the observation x."""
            self.tot += x
            self.totsq += x*x
            self.N += 1

        def __str__(self):
            """Render as "mean \\pm stddev" (LaTeX-style), or just the
            mean when only one observation was accumulated.  Raises
            ZeroDivisionError if nothing has been accumulated."""
            mean = self.tot/self.N
            if self.N > 1:
                # Sample variance from the running sums.  Floating-point
                # cancellation can leave this very slightly negative
                # when all observations are (nearly) equal, which would
                # make math.sqrt raise ValueError; clamp it at zero.
                var = (self.totsq - self.tot*self.tot/self.N) \
                    / (self.N - 1)
                if var < 0:
                    var = 0
                stddev = math.sqrt(var)
                # Raw string: "\p" is not a valid escape sequence
                return r"%f \pm %f" % (mean, stddev)
            else:
                return "%f" % mean

    def __init__(self, usebw=False):
        """If usebw is True, also accumulate byte counts divided by
        each PerfStats' bw (so every accumulated stat needs a nonzero
        bw)."""
        self.usebw = usebw
        self.bytes_sent = PerfStatsStats.SingleStat()
        self.bytes_received = PerfStatsStats.SingleStat()
        self.bytes_tot = PerfStatsStats.SingleStat()
        if self.usebw:
            self.bytesperbw_sent = PerfStatsStats.SingleStat()
            self.bytesperbw_received = PerfStatsStats.SingleStat()
            self.bytesperbw_tot = PerfStatsStats.SingleStat()
        self.keygens = PerfStatsStats.SingleStat()
        self.sigs = PerfStatsStats.SingleStat()
        self.verifs = PerfStatsStats.SingleStat()
        self.dhs = PerfStatsStats.SingleStat()
        # Number of PerfStats objects accumulated so far
        self.N = 0

    def accum(self, stat):
        """Fold the fields of one PerfStats object into the running
        statistics."""
        self.bytes_sent.accum(stat.bytes_sent)
        self.bytes_received.accum(stat.bytes_received)
        self.bytes_tot.accum(stat.bytes_sent + stat.bytes_received)
        if self.usebw:
            self.bytesperbw_sent.accum(stat.bytes_sent/stat.bw)
            self.bytesperbw_received.accum(stat.bytes_received/stat.bw)
            self.bytesperbw_tot.accum(
                (stat.bytes_sent + stat.bytes_received)/stat.bw)
        self.keygens.accum(stat.keygens)
        self.sigs.accum(stat.sigs)
        self.verifs.accum(stat.verifs)
        self.dhs.accum(stat.dhs)
        self.N += 1

    def __str__(self):
        if self.N > 0:
            if self.usebw:
                return "sent=%s recv=%s bytes=%s sentperbw=%s recvperbw=%s bytesperbw=%s keygen=%s sig=%s verif=%s dh=%s N=%s" % \
                    (self.bytes_sent, self.bytes_received, self.bytes_tot,
                     self.bytesperbw_sent, self.bytesperbw_received,
                     self.bytesperbw_tot,
                     self.keygens, self.sigs, self.verifs, self.dhs, self.N)
            else:
                return "sent=%s recv=%s bytes=%s keygen=%s sig=%s verif=%s dh=%s N=%s" % \
                    (self.bytes_sent, self.bytes_received, self.bytes_tot,
                     self.keygens, self.sigs, self.verifs, self.dhs, self.N)
        else:
            return "N=0"
class NetAddr:
    """A simulated network address, drawn from a class-wide counter."""

    # The value the next freshly created NetAddr will receive
    nextaddr = 1

    def __init__(self):
        """Allocate the next unused network address."""
        self.addr = NetAddr.nextaddr
        NetAddr.nextaddr = self.addr + 1

    def __eq__(self, other):
        # Only objects of the same class can be equal; those compare
        # by their full attribute dictionaries.
        if not isinstance(other, self.__class__):
            return False
        return self.__dict__ == other.__dict__

    def __hash__(self):
        return hash(self.addr)

    def __str__(self):
        return str(self.addr)
class NetNoServer(Exception):
    """Raised on an attempt to connect to an address that has no
    server listening on it."""
class Network:
    """A class representing a simulated network.  Servers can bind()
    to the network, yielding a NetAddr (network address), and clients
    can connect() to a NetAddr yielding a Connection."""

    def __init__(self):
        # Map from NetAddr to the server bound at that address
        self.servers = dict()
        # The current epoch number
        self.epoch = 1
        # Objects registered via wantepochticks(), one list per
        # (priority, end) combination
        self.epochprioritycallbacks = []
        self.epochpriorityendingcallbacks = []
        self.epochcallbacks = []
        self.epochendingcallbacks = []
        # Dirauth public verification keys, indexed by dirauth number
        self.dirauthkeylist = []
        # Globally known fallback relays and the bw CDF used to choose
        # among them; filled in by setfallbackrelays()
        self.fallbackrelays = []
        self.fallbackbwcdf = []
        self.fallbacktotbw = 0
        self.womode = WOMode.VANILLA
        self.snipauthmode = SNIPAuthMode.NONE

    def printservers(self):
        """Print the list of NetAddrs bound to something."""
        print("Servers:")
        for a in self.servers:
            print(a)

    def setdirauthkey(self, index, vk):
        """Set the public verification key for dirauth number index to
        vk, padding the key list with None entries if it is not yet
        long enough."""
        if index >= len(self.dirauthkeylist):
            self.dirauthkeylist.extend(
                [None] * (index+1-len(self.dirauthkeylist)))
        self.dirauthkeylist[index] = vk

    def dirauthkeys(self):
        """Return the list of dirauth public verification keys."""
        return self.dirauthkeylist

    def getepoch(self):
        """Return the current epoch."""
        return self.epoch

    def _run_epoch_callbacks(self, lists, method, progressfmt):
        """Call the named method with self.epoch on every object in
        each of lists, in order, logging progressfmt % (epoch, percent)
        whenever the integer completion percentage changes."""
        total = sum(len(l) for l in lists)
        numcalled = 0
        lastroundpercent = -1
        for l in lists:
            # Iterate the registration lists themselves, not copies
            for c in l:
                getattr(c, method)(self.epoch)
                numcalled += 1
                roundpercent = int(100*numcalled/total)
                if roundpercent != lastroundpercent:
                    logging.info(progressfmt, self.epoch, roundpercent)
                    lastroundpercent = roundpercent

    def nextepoch(self):
        """Increment the current epoch, and return it.

        First each registered ending callback (priority ones first) has
        its epoch_ending() method called with the old epoch number;
        then the epoch is incremented and each registered callback
        (priority ones first) has its newepoch() method called with the
        new epoch number."""
        logging.info("Ending epoch %s", self.epoch)
        self._run_epoch_callbacks(
            [self.epochpriorityendingcallbacks, self.epochendingcallbacks],
            'epoch_ending', "Ending epoch %s %d%% complete")
        self.epoch += 1
        logging.info("Starting epoch %s", self.epoch)
        self._run_epoch_callbacks(
            [self.epochprioritycallbacks, self.epochcallbacks],
            'newepoch', "Starting epoch %s %d%% complete")
        logging.info("Epoch %s started", self.epoch)
        return self.epoch

    def wantepochticks(self, callback, want, priority=False, end=False):
        """Register or deregister an object from receiving epoch change
        callbacks.  If want is True, the callback object's newepoch()
        method will be called at each epoch change, with an argument of
        the new epoch.  If want is False, the callback object will be
        deregistered (list.remove raises ValueError if it is not
        registered in the matching list).  If priority is True, call
        back this object before any object with priority=False.  If end
        is True, the callback object's epoch_ending() method will be
        called instead at the end of the epoch, just _before_ the epoch
        number change, with the ending epoch as its argument."""
        if end:
            l = self.epochpriorityendingcallbacks if priority \
                else self.epochendingcallbacks
        else:
            l = self.epochprioritycallbacks if priority \
                else self.epochcallbacks
        if want:
            l.append(callback)
        else:
            l.remove(callback)

    def bind(self, server):
        """Bind a server to a newly generated NetAddr, returning the
        NetAddr.  The server's bound() callback will also be invoked,
        with a closer function that removes the binding."""
        addr = NetAddr()
        self.servers[addr] = server
        server.bound(addr, lambda: self.servers.pop(addr))
        return addr

    def connect(self, client, srvaddr, perfstats):
        """Connect the given client to the server bound to srvaddr and
        return the Connection the server creates, with its perfstats
        attribute set to perfstats.

        Raises NetNoServer if no server is bound to that address."""
        try:
            server = self.servers[srvaddr]
        except KeyError:
            # The KeyError is an implementation detail; don't chain it
            raise NetNoServer() from None
        conn = server.connected(client)
        conn.perfstats = perfstats
        return conn

    def setfallbackrelays(self, fallbackrelays):
        """Set the list of globally known fallback relays.  Clients use
        these to bootstrap when they know no other relays."""
        self.fallbackrelays = fallbackrelays
        # Construct the CDF of fallback relay bws, so that clients can
        # choose a fallback relay weighted by bw: cdf[i] is the total bw
        # of the relays before fallbackrelays[i].
        cdf = [0]
        for r in fallbackrelays:
            cdf.append(cdf[-1] + r.bw)
        # The last item is the sum of all the relays' bws
        self.fallbacktotbw = cdf.pop()
        self.fallbackbwcdf = cdf

    def getfallbackrelay(self):
        """Get a random one of the globally known fallback relays,
        weighted by bw.  Clients use these to bootstrap when they know
        no other relays.  Relay bws should be integers, since the
        choice is made with random.randint.

        Raises ValueError if there is no fallback relay with positive
        bw (including when setfallbackrelays() was never called)."""
        if self.fallbacktotbw <= 0:
            raise ValueError("no fallback relays with positive bandwidth")
        idx = random.randint(0, self.fallbacktotbw-1)
        # bisect_right skips past relays whose CDF entry equals their
        # successor's (zero bw), so those are never chosen
        i = bisect.bisect_right(self.fallbackbwcdf, idx)
        return self.fallbackrelays[i-1]

    def set_wo_style(self, womode, snipauthmode):
        """Set the Walking Onions mode and the SNIP authentication mode
        for the network.  WOMode.VANILLA must be paired with
        SNIPAuthMode.NONE, and every other WOMode with a mode other
        than NONE; any other combination raises ValueError."""
        if (womode == WOMode.VANILLA) != (snipauthmode == SNIPAuthMode.NONE):
            # Incompatible settings
            raise ValueError("Bad argument combination")
        self.womode = womode
        self.snipauthmode = snipauthmode
thenetwork = Network()  # the single shared Network instance for this module
class NetMsg:
    """Base class for simulated network messages; subclass it for each
    concrete kind of message."""

    def size(self):
        """Return the size this message counts for in the byte counters.

        By default that is the length of the message's pickled form,
        which includes some unnecessary overhead; subclasses wanting a
        more accurate figure can override this method.  When the
        module-level symbolic_byte_counters flag is set, a sympy symbol
        named after the message's class is returned instead, so the
        accumulated byte totals show how many messages of each type were
        sent and received."""
        if not symbolic_byte_counters:
            return len(pickle.dumps(self))
        return sympy.symbols(type(self).__name__)
class StringNetMsg(NetMsg):
    """A NetMsg whose payload is an arbitrary string."""

    def __init__(self, data):
        # The payload, stored as given
        self.data = data

    def __str__(self):
        return str(self.data)
class Connection:
    """Base class for one end of a simulated connection.  self.peer is
    the other end, or None once the connection has been closed."""

    def __init__(self, peer = None):
        """Create a Connection object with the given peer."""
        self.peer = peer

    def closed(self):
        """Callback invoked by the peer when it closes the connection;
        forget the peer."""
        logging.debug("connection closed with %s", self.peer)
        self.peer = None

    def close(self):
        """Close the connection: notify the peer through its closed()
        callback, then forget it.  Closing a connection that is already
        closed (by either end) does nothing."""
        if self.peer is None:
            # Already closed; there is no peer left to notify
            return
        logging.debug("closing connection with %s", self.peer)
        self.peer.closed()
        self.peer = None
class ClientConnection(Connection):
    """The parent class of client-side network connections.  Subclass
    this class to do anything more elaborate than just passing arbitrary
    NetMsgs, which then get ignored.  Use subclasses of this class when
    the server requires no per-connection state, such as just fetching
    consensus documents."""

    def __init__(self, peer):
        """Create a ClientConnection object with the given peer.  The
        peer must have a received(client, msg) method."""
        self.peer = peer
        # PerfStats charged for this connection's traffic; Network.connect()
        # sets it on the connection it returns
        self.perfstats = None

    def sendmsg(self, netmsg):
        """Send netmsg to the server, counting its size as bytes sent."""
        assert(isinstance(netmsg, NetMsg))
        msgsize = netmsg.size()
        self.perfstats.bytes_sent += msgsize
        self.peer.received(self, netmsg)

    def reply(self, netmsg):
        """Deliver a reply netmsg from the server, counting its size as
        bytes received, and hand it to receivedfromserver()."""
        assert(isinstance(netmsg, NetMsg))
        msgsize = netmsg.size()
        self.perfstats.bytes_received += msgsize
        self.receivedfromserver(netmsg)

    def receivedfromserver(self, netmsg):
        """Handle a message from the server.  The base class ignores it
        (apart from a debug log); subclasses override this to act on
        replies."""
        logging.debug("received %s from server %s", netmsg, self.peer)
class ServerConnection(Connection):
    """Base class for the server end of a simulated connection."""

    def __init__(self):
        # No peer yet; Server.connected() links in the client end
        self.peer = None

    def sendmsg(self, netmsg):
        """Pass netmsg directly to the peer's received() method."""
        # NOTE(review): the peer is normally a ClientConnection, which
        # defines reply() but no received(); confirm whether this should
        # call self.peer.reply(netmsg) or is always overridden.
        assert isinstance(netmsg, NetMsg)
        self.peer.received(netmsg)

    def received(self, client, netmsg):
        """Handle netmsg arriving from client; the base class only logs it."""
        logging.debug("received %s from client %s", netmsg, client)
class Server:
    """Base class for simulated network servers.  Subclasses usually
    only need to override connected()."""

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return str(self.name)

    def bound(self, netaddr, closer):
        """Called once the server has been bound to netaddr.  closer is
        a zero-argument callable that close() invokes when the server
        wants to stop listening for new connections."""
        self.closer = closer

    def close(self):
        """Stop listening for new connections."""
        self.closer()

    def connected(self, client):
        """Called when client connects.  Builds a linked pair of
        server-side and client-side connections and returns the client
        side, which is what the client receives."""
        logging.debug("server %s connected to client %s", self, client)
        server_end = ServerConnection()
        client_end = ClientConnection(server_end)
        server_end.peer = client_end
        return client_end
if __name__ == '__main__':
    # Sanity-check NetAddr equality semantics
    addr1 = NetAddr()
    addr2 = NetAddr()
    assert addr1 == addr1
    assert not (addr1 == addr2)
    assert addr1 != addr2
    print(addr1, addr2)

    # Seed the (non-cryptographic) PRNG so runs are deterministic
    random.seed(1)

    # Bind a server, connect a client, send one message, then tear down
    server = Server("hello world server")
    thenetwork.printservers()
    srvaddr = thenetwork.bind(server)
    thenetwork.printservers()
    print("in main", srvaddr)
    stats = PerfStats(EntType.NONE)
    clientconn = thenetwork.connect("hello world client", srvaddr, stats)
    clientconn.sendmsg(StringNetMsg("hi"))
    clientconn.close()
    server.close()
    thenetwork.printservers()
|