#!/usr/bin/env python3

import random
import pickle
import logging
import math
from enum import Enum

# Set this to True if you want the bytes sent and received to be added
# symbolically, in terms of the numbers of each type of network message.
# You will need sympy installed for this to work.
symbolic_byte_counters = False

if symbolic_byte_counters:
    import sympy

# Network parameters

# On average, how large is a consensus diff as compared to a full
# consensus?
P_Delta = 0.019


class WOMode(Enum):
    """The different Walking Onion modes"""
    VANILLA = 0      # No Walking Onions
    TELESCOPING = 1  # Telescoping Walking Onions
    SINGLEPASS = 2   # Single-Pass Walking Onions

    @staticmethod
    def string_to_type(type_input):
        """Map a command-line mode name to a WOMode.

        Accepts 'vanilla', 'telescoping' or 'single-pass'; returns -1
        for any other input."""
        reprs = {'vanilla': WOMode.VANILLA,
                 'telescoping': WOMode.TELESCOPING,
                 'single-pass': WOMode.SINGLEPASS}
        return reprs.get(type_input, -1)


class SNIPAuthMode(Enum):
    """The different styles of SNIP authentication"""
    NONE = 0       # No SNIPs; only used for WOMode = VANILLA
    MERKLE = 1     # Merkle trees
    THRESHSIG = 2  # Threshold signatures

    # Only MERKLE and THRESHSIG need to be selectable from the command
    # line: VANILLA mode always takes NONE, and no other mode does.
    @staticmethod
    def string_to_type(type_input):
        """Map a command-line SNIP authentication name to a SNIPAuthMode.

        Accepts 'merkle' or 'threshsig'. 'telescoping' and its historical
        misspelling 'telesocping' (the only key this table used to have
        for THRESHSIG) are kept as aliases so existing invocations keep
        working. Returns -1 for any other input."""
        reprs = {'merkle': SNIPAuthMode.MERKLE,
                 'threshsig': SNIPAuthMode.THRESHSIG,
                 'telescoping': SNIPAuthMode.THRESHSIG,
                 'telesocping': SNIPAuthMode.THRESHSIG}
        return reprs.get(type_input, -1)


class EntType(Enum):
    """The different types of entities in the system."""
    NONE = 0
    DIRAUTH = 1
    RELAY = 2
    CLIENT = 3


class PerfStats:
    """A class to store performance statistics for a relay or client.
    We keep track of bytes sent, bytes received, and counts of
    public-key operations of various types.  We will reset these every
    epoch."""

    def __init__(self, ent_type):
        # Which type of entity is this for (DIRAUTH, RELAY, CLIENT)
        self.ent_type = ent_type
        # A printable name for the entity
        self.name = None
        self.reset()

    def __str__(self):
        return "%s: type=%s boot=%s sent=%s recv=%s keygen=%d sig=%d verif=%d dh=%d" % \
            (self.name, self.ent_type.name, self.is_bootstrapping,
             self.bytes_sent, self.bytes_received, self.keygens,
             self.sigs, self.verifs, self.dhs)

    def reset(self):
        """Reset the counters, typically at the beginning of each
        epoch."""
        # True if bootstrapping this epoch
        self.is_bootstrapping = False
        # Bytes sent and received
        self.bytes_sent = 0
        self.bytes_received = 0
        # Public-key operations: key generation, signing, verification,
        # Diffie-Hellman
        self.keygens = 0
        self.sigs = 0
        self.verifs = 0
        self.dhs = 0


class PerfStatsStats:
    """Accumulate a number of PerfStats objects to compute the means and
    stddevs of their fields."""

    class SingleStat:
        """Accumulate single numbers to compute their mean and
        stddev."""

        def __init__(self):
            self.tot = 0
            self.totsq = 0
            self.N = 0

        def accum(self, x):
            self.tot += x
            self.totsq += x*x
            self.N += 1

        def __str__(self):
            """Format as "mean \\pm sample-stddev".

            With no samples both are reported as nan; with a single
            sample the (undefined) sample stddev is reported as 0 rather
            than dividing by zero."""
            if self.N == 0:
                mean = stddev = float('nan')
            else:
                mean = self.tot/self.N
                if self.N > 1:
                    # Sum-of-squares identity for the sample variance;
                    # clamp tiny negative floating-point round-off so
                    # math.sqrt cannot raise a domain error.
                    var = (self.totsq - self.tot*self.tot/self.N) \
                        / (self.N - 1)
                    stddev = math.sqrt(max(var, 0))
                else:
                    stddev = 0.0
            # Raw string: "\p" is not a valid escape sequence.
            return r"%f \pm %f" % (mean, stddev)

    def __init__(self):
        self.bytes_sent = PerfStatsStats.SingleStat()
        self.bytes_received = PerfStatsStats.SingleStat()
        self.bytes_tot = PerfStatsStats.SingleStat()
        self.keygens = PerfStatsStats.SingleStat()
        self.sigs = PerfStatsStats.SingleStat()
        self.verifs = PerfStatsStats.SingleStat()
        self.dhs = PerfStatsStats.SingleStat()
        self.N = 0

    def accum(self, stat):
        """Fold one PerfStats object into the running statistics."""
        self.bytes_sent.accum(stat.bytes_sent)
        self.bytes_received.accum(stat.bytes_received)
        self.bytes_tot.accum(stat.bytes_sent + stat.bytes_received)
        self.keygens.accum(stat.keygens)
        self.sigs.accum(stat.sigs)
        self.verifs.accum(stat.verifs)
        self.dhs.accum(stat.dhs)
        self.N += 1

    def __str__(self):
        if self.N > 0:
            return "sent=%s recv=%s bytes=%s keygen=%s sig=%s verif=%s dh=%s N=%s" % \
                (self.bytes_sent, self.bytes_received, self.bytes_tot,
                 self.keygens, self.sigs, self.verifs, self.dhs, self.N)
        else:
            return "N=0"


class NetAddr:
    """A class representing a network address"""
    nextaddr = 1

    def __init__(self):
        """Generate a fresh network address"""
        self.addr = NetAddr.nextaddr
        NetAddr.nextaddr += 1

    def __eq__(self, other):
        return (isinstance(other, self.__class__)
                and self.__dict__ == other.__dict__)

    def __hash__(self):
        return hash(self.addr)

    def __str__(self):
        return self.addr.__str__()


class NetNoServer(Exception):
    """No server is listening on the address someone tried to connect
    to."""


class Network:
    """A class representing a simulated network.  Servers can bind()
    to the network, yielding a NetAddr (network address), and clients
    can connect() to a NetAddr yielding a Connection."""

    def __init__(self):
        self.servers = dict()
        self.epoch = 1
        self.epochcallbacks = []
        self.epochendingcallbacks = []
        self.dirauthkeylist = []
        self.fallbackrelays = []
        self.womode = WOMode.VANILLA
        self.snipauthmode = SNIPAuthMode.NONE

    def printservers(self):
        """Print the list of NetAddrs bound to something."""
        print("Servers:")
        for a in self.servers.keys():
            print(a)

    def setdirauthkey(self, index, vk):
        """Set the public verification key for dirauth number index to
        vk."""
        if index >= len(self.dirauthkeylist):
            self.dirauthkeylist.extend(
                [None] * (index+1-len(self.dirauthkeylist)))
        self.dirauthkeylist[index] = vk

    def dirauthkeys(self):
        """Return the list of dirauth public verification keys."""
        return self.dirauthkeylist

    def getepoch(self):
        """Return the current epoch."""
        return self.epoch

    def _run_callbacks(self, callbacks, invoke, progressmsg):
        """Call invoke(c) for each c in a snapshot of callbacks, logging
        progressmsg (with the current epoch and percent done) each time
        the whole-number completion percentage changes.

        A snapshot is iterated so that a callback which (de)registers
        itself via wantepochticks() cannot cause another callback to be
        skipped."""
        snapshot = list(callbacks)
        total = len(snapshot)
        lastroundpercent = -1
        for i, c in enumerate(snapshot):
            invoke(c)
            roundpercent = int(100*(i+1)/total)
            if roundpercent != lastroundpercent:
                logging.info(progressmsg, self.epoch, roundpercent)
                lastroundpercent = roundpercent

    def nextepoch(self):
        """Increment the current epoch, and return it.

        Registered end-of-epoch callbacks get epoch_ending(old_epoch)
        before the increment; registered new-epoch callbacks get
        newepoch(new_epoch) after it."""
        logging.info("Ending epoch %s", self.epoch)
        self._run_callbacks(self.epochendingcallbacks,
                            lambda c: c.epoch_ending(self.epoch),
                            "Ending epoch %s %d%% complete")
        self.epoch += 1
        logging.info("Starting epoch %s", self.epoch)
        self._run_callbacks(self.epochcallbacks,
                            lambda c: c.newepoch(self.epoch),
                            "Starting epoch %s %d%% complete")
        logging.info("Epoch %s started", self.epoch)
        return self.epoch

    def wantepochticks(self, callback, want, end=False):
        """Register or deregister an object from receiving epoch change
        callbacks.  If want is True, the callback object's newepoch()
        method will be called at each epoch change, with an argument of
        the new epoch.  If want is False, the callback object will be
        deregistered.  If end is True, the callback object's
        epoch_ending() method will be called instead at the end of the
        epoch, just _before_ the epoch number change."""
        if end:
            if want:
                self.epochendingcallbacks.append(callback)
            else:
                self.epochendingcallbacks.remove(callback)
        else:
            if want:
                self.epochcallbacks.append(callback)
            else:
                self.epochcallbacks.remove(callback)

    def bind(self, server):
        """Bind a server to a newly generated NetAddr, returning the
        NetAddr.  The server's bound() callback will also be invoked
        with a closer that unbinds it; calling the closer more than once
        is harmless."""
        addr = NetAddr()
        self.servers[addr] = server
        server.bound(addr, lambda: self.servers.pop(addr, None))
        return addr

    def connect(self, client, srvaddr, perfstats):
        """Connect the given client to the server bound to srvaddr,
        returning the client-side Connection with its perfstats set.

        Raises NetNoServer if no server is bound to that address."""
        try:
            server = self.servers[srvaddr]
        except KeyError:
            raise NetNoServer(srvaddr) from None
        conn = server.connected(client)
        conn.perfstats = perfstats
        return conn

    def setfallbackrelays(self, fallbackrelays):
        """Set the list of globally known fallback relays.  Clients use
        these to bootstrap when they know no other relays."""
        self.fallbackrelays = fallbackrelays

    def getfallbackrelays(self):
        """Get the list of globally known fallback relays.  Clients use
        these to bootstrap when they know no other relays."""
        return self.fallbackrelays

    def set_wo_style(self, womode, snipauthmode):
        """Set the Walking Onions mode and the SNIP authenticate mode
        for the network.

        Raises ValueError unless VANILLA is paired with NONE, or a
        non-VANILLA mode is paired with a non-NONE SNIP mode."""
        if (womode == WOMode.VANILLA) != \
                (snipauthmode == SNIPAuthMode.NONE):
            # Incompatible settings
            raise ValueError("Bad argument combination")

        self.womode = womode
        self.snipauthmode = snipauthmode


# The singleton instance of Network
thenetwork = Network()


class NetMsg:
    """The parent class of network messages.  Subclass this class to
    implement specific kinds of network messages."""

    def size(self):
        """Return the size of this network message.  For now, just
        pickle it and return the length of that.  There's some
        unnecessary overhead in this method; if you want specific
        messages to have more accurate sizes, override this method in
        the subclass.  Alternately, if symbolic_byte_counters is set,
        return a symbolic representation of the message size instead, so
        that the total byte counts will clearly show how many of each
        message type were sent and received."""
        if symbolic_byte_counters:
            sz = sympy.symbols(type(self).__name__)
        else:
            sz = len(pickle.dumps(self))
            # logging.info("%s size %d", type(self).__name__, sz)
        return sz


class StringNetMsg(NetMsg):
    """Send an arbitrary string as a NetMsg."""

    def __init__(self, data):
        self.data = data

    def __str__(self):
        return self.data.__str__()


class Connection:
    """Base class for one end of a simulated connection."""

    def __init__(self, peer=None):
        """Create a Connection object with the given peer."""
        self.peer = peer

    def closed(self):
        """Callback invoked when the peer closes the connection."""
        logging.debug("connection closed with %s", self.peer)
        self.peer = None

    def close(self):
        """Close the connection, notifying the peer.  Closing an
        already-closed connection (by either end) is a no-op."""
        if self.peer is None:
            return
        logging.debug("closing connection with %s", self.peer)
        self.peer.closed()
        self.peer = None


class ClientConnection(Connection):
    """The parent class of client-side network connections.  Subclass
    this class to do anything more elaborate than just passing arbitrary
    NetMsgs, which then get ignored.  Use subclasses of this class when
    the server requires no per-connection state, such as just fetching
    consensus documents."""

    def __init__(self, peer):
        """Create a ClientConnection object with the given peer.  The
        peer must have a received(client, msg) method."""
        self.peer = peer
        # Set by Network.connect(); counts bytes on this client's behalf.
        self.perfstats = None

    def sendmsg(self, netmsg):
        """Send netmsg to the server, charging its size to bytes_sent."""
        assert(isinstance(netmsg, NetMsg))
        msgsize = netmsg.size()
        self.perfstats.bytes_sent += msgsize
        self.peer.received(self, netmsg)

    def reply(self, netmsg):
        """Deliver a server's reply to this client, charging its size to
        bytes_received, then hand it to receivedfromserver()."""
        assert(isinstance(netmsg, NetMsg))
        msgsize = netmsg.size()
        self.perfstats.bytes_received += msgsize
        self.receivedfromserver(netmsg)

    def receivedfromserver(self, netmsg):
        """Handle a message from the server.  The base class just logs
        and ignores it; override this in a subclass to act on replies."""
        logging.debug("client received %s from server", netmsg)


class ServerConnection(Connection):
    """The parent class of server-side network connections."""

    def __init__(self):
        self.peer = None

    def sendmsg(self, netmsg):
        """Send netmsg to the connected client.  Goes through the
        client's reply() entry point (the only server-to-client entry a
        ClientConnection has), so the client's byte counters are
        updated."""
        assert(isinstance(netmsg, NetMsg))
        self.peer.reply(netmsg)

    def received(self, client, netmsg):
        logging.debug("received %s from client %s", netmsg, client)


class Server:
    """The parent class of network servers.  Subclass this class to
    implement servers of different kinds.  You will probably only need
    to override the implementation of connected()."""

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name.__str__()

    def bound(self, netaddr, closer):
        """Callback invoked when the server is successfully bound to a
        NetAddr.  The parameters are the netaddr to which the server is
        bound and closer (as in a thing that causes something to close)
        is a callback function that should be used when the server
        wishes to stop listening for new connections."""
        self.closer = closer

    def close(self):
        """Stop listening for new connections."""
        self.closer()

    def connected(self, client):
        """Callback invoked when a client connects to this server.  This
        callback must create the Connection object that will be returned
        to the client."""
        logging.debug("server %s connected to client %s", self, client)
        serverconnection = ServerConnection()
        clientconnection = ClientConnection(serverconnection)
        serverconnection.peer = clientconnection
        return clientconnection


if __name__ == '__main__':
    n1 = NetAddr()
    n2 = NetAddr()
    assert(n1 == n1)
    assert(not (n1 == n2))
    assert(n1 != n2)
    print(n1, n2)

    # Initialize the (non-cryptographic) random seed
    random.seed(1)

    srv = Server("hello world server")

    thenetwork.printservers()
    a = thenetwork.bind(srv)
    thenetwork.printservers()
    print("in main", a)

    perfstats = PerfStats(EntType.NONE)

    conn = thenetwork.connect("hello world client", a, perfstats)
    conn.sendmsg(StringNetMsg("hi"))
    conn.close()

    srv.close()
    thenetwork.printservers()