#!/usr/bin/env python3

import random # For simulation, not cryptography!
import math
import sys
import os
import logging
import resource

import network
import dirauth
import relay
import client

class Simulator:
    """Simulate a Walking Onions network one epoch at a time.

    The constructor builds the initial network (directory authorities,
    relays, fallback relays and clients) and ticks two epochs: one to
    build the first consensus and one to bootstrap the clients.  Each
    call to one_epoch() then simulates circuit creation for the current
    epoch, logs the resulting performance statistics to the stats
    logger, churns relays and clients, and ticks to the next epoch.
    """

    def __init__(self, relaytarget, clienttarget, statslogger):
        """Create the initial network and bootstrap it.

        relaytarget: steady-state (and initial) number of relays
        clienttarget: steady-state (and initial) number of clients
        statslogger: logger that receives the per-epoch statistics
        """
        self.relaytarget = relaytarget
        self.clienttarget = clienttarget
        self.statslogger = statslogger

        # Some (for now) hard-coded parameters

        # The number of directory authorities
        numdirauths = 9

        # The fraction of relays that are fallback relays
        # Taken from the live network in Jan 2020
        fracfallbackrelays = 0.023

        # Mean number of circuits created per client per epoch
        self.gamma = 8.9

        # Churn is controlled by three parameters:
        # newmean: the mean number of new arrivals per epoch
        # newstddev: the stddev number of new arrivals per epoch
        # oldprob: the probability any given existing one leaves per epoch

        # If target is the desired steady state number, then it should
        # be the case that target * oldprob = newmean.  That way, if the
        # current number is below target, on average you add more than
        # you remove, and if the current number is above target, on
        # average you add fewer than you remove.

        # For relays, looking at all the consensuses for Nov and Dec
        # 2019, newmean is about 1.0% of the network size, and newstddev
        # is about 0.3% of the network size.
        self.relay_newmean = 0.010 * self.relaytarget
        self.relay_newstddev = 0.003 * self.relaytarget
        self.relay_oldprob = 0.010

        # For clients, looking at how many clients request a consensus
        # with an if-modified-since date more than 3 hours old (and so
        # we treat them as "new") over several days in late Dec 2019,
        # newmean is about 16% of all clients, and newstddev is about 4%
        # of all clients.

        # If the environment variable WOSIM_CLIENT_CHURN is set to 0,
        # don't churn clients at all.  This allows us to see the effect
        # of client churn on relay bandwidth.
        if os.getenv('WOSIM_CLIENT_CHURN', '1') == '0':
            self.client_newmean = 0
            self.client_newstddev = 0
            self.client_oldprob = 0
        else:
            self.client_newmean = 0.16 * self.clienttarget
            self.client_newstddev = 0.04 * self.clienttarget
            self.client_oldprob = 0.16

        # Start some dirauths
        self.dirauthaddrs = []
        self.dirauths = []
        for i in range(numdirauths):
            dira = dirauth.DirAuth(i, numdirauths)
            self.dirauths.append(dira)
            self.dirauthaddrs.append(dira.netaddr)

        # Start some relays
        self.relays = []
        for _ in range(self.relaytarget):
            self.relays.append(relay.Relay(self.dirauthaddrs,
                    self._random_relay_bw(), 0))

        # The fallback relays are a hardcoded list of a small fraction
        # of the relays, used by clients for bootstrapping
        numfallbackrelays = int(self.relaytarget * fracfallbackrelays) + 1
        fallbackrelays = random.sample(self.relays, numfallbackrelays)
        for r in fallbackrelays:
            r.set_is_fallbackrelay()
        network.thenetwork.setfallbackrelays(fallbackrelays)

        # Tick the epoch to build the first consensus
        network.thenetwork.nextepoch()

        # Start some clients
        self.clients = []
        for _ in range(clienttarget):
            self.clients.append(client.Client(self.dirauthaddrs))

        # Throw away all the performance statistics to this point
        for d in self.dirauths: d.perfstats.reset()
        for r in self.relays: r.perfstats.reset()
        # The clients' stats are already at 0, but they have the
        # "bootstrapping" flag set, which we want to keep, so we
        # won't reset them.

        # Tick the epoch to bootstrap the clients
        network.thenetwork.nextepoch()

    @staticmethod
    def _random_relay_bw():
        """Draw a random relay bandwidth.

        Relay bandwidths (at least the ones fast enough to get used)
        in the live Tor network (as of Dec 2019) are well approximated
        by (200000-(200000-25000)/3*log10(x)) where x is a uniform
        integer in [1,2500].  The result therefore lies between 1786
        (x=2500) and 200000 (x=1).
        """
        x = random.randint(1, 2500)
        return int(200000-(200000-25000)/3*math.log10(x))

    def one_epoch(self):
        """Simulate one epoch.

        Creates this epoch's circuits, logs the performance statistics
        they produced, closes the circuits, clears every node's
        bootstrapping flag, churns relays and clients, resets all
        statistics, and finally ticks the network to the next epoch.
        """
        epoch = network.thenetwork.getepoch()

        allcircs = self._create_circuits(epoch)

        self._log_stats()

        # Close circuits
        for c in allcircs:
            c.close()

        # Clear bootstrapping flag
        for d in self.dirauths: d.perfstats.is_bootstrapping = False
        for r in self.relays: r.perfstats.is_bootstrapping = False
        for c in self.clients: c.perfstats.is_bootstrapping = False

        self._churn_relays()
        self._churn_clients()

        # Reset stats
        for d in self.dirauths: d.perfstats.reset()
        for r in self.relays: r.perfstats.reset()
        for c in self.clients: c.perfstats.reset()

        # Tick the epoch
        network.thenetwork.nextepoch()

    def _create_circuits(self, epoch):
        """Create this epoch's circuits and return them as a list.

        Each client should start a random number of circuits in a
        Poisson distribution with mean gamma.  To randomize the order
        of the clients creating each circuit, we actually run a single
        Poisson process of rate (gamma*num_clients) over the epoch and
        assign each event to a uniformly random client.  (This does in
        fact give the required distribution.)

        Raises ValueError (after logging it to the stats logger) if a
        client's channel manager fails to create a circuit.
        """
        allcircs = []
        numclients = len(self.clients)
        if numclients == 0:
            # No clients means no circuits; random.choice would raise
            # IndexError and expovariate(0) ZeroDivisionError.
            return allcircs
        rate = self.gamma * numclients

        # simtime is the simulated time, measured in epochs (i.e.,
        # 0=start of this epoch; 1=end of this epoch).  The first event
        # happens one interarrival time after the start of the epoch;
        # creating a circuit at time 0 unconditionally would produce
        # one circuit more than the Poisson process calls for.
        simtime = random.expovariate(rate)
        numcircs = 0
        while simtime < 1.0:
            creator = random.choice(self.clients)
            try:
                allcircs.append(creator.channelmgr.new_circuit())
            except ValueError as e:
                self.statslogger.error(str(e))
                raise

            numcircs += 1
            simtime += random.expovariate(rate)
            if numcircs % 100 == 0:
                logging.info("Creating circuits in epoch %s: %d%% (%d circuits)",
                        epoch, int(100*simtime), numcircs)
        return allcircs

    def _log_stats(self):
        """Log this epoch's byte counts and aggregated perf stats.

        Relays are additionally split into fallback (FB), non-fallback
        bootstrapping (B) and non-fallback non-bootstrapping (NB);
        clients into bootstrapping (B) and non-bootstrapping (NB).
        """
        dirasent = 0
        dirarecv = 0
        dirastats = network.PerfStatsStats()
        for d in self.dirauths:
            logging.debug("%s", d.perfstats)
            dirasent += d.perfstats.bytes_sent
            dirarecv += d.perfstats.bytes_received
            dirastats.accum(d.perfstats)

        relaysent = 0
        relayrecv = 0
        relaystats = network.PerfStatsStats(True)
        relaybstats = network.PerfStatsStats(True)
        relaynbstats = network.PerfStatsStats(True)
        relayfbstats = network.PerfStatsStats(True)
        for r in self.relays:
            logging.debug("%s", r.perfstats)
            relaysent += r.perfstats.bytes_sent
            relayrecv += r.perfstats.bytes_received
            relaystats.accum(r.perfstats)
            if r.is_fallbackrelay:
                relayfbstats.accum(r.perfstats)
            elif r.perfstats.is_bootstrapping:
                relaybstats.accum(r.perfstats)
            else:
                relaynbstats.accum(r.perfstats)

        clisent = 0
        clirecv = 0
        clistats = network.PerfStatsStats()
        clibstats = network.PerfStatsStats()
        clinbstats = network.PerfStatsStats()
        for c in self.clients:
            logging.debug("%s", c.perfstats)
            clisent += c.perfstats.bytes_sent
            clirecv += c.perfstats.bytes_received
            clistats.accum(c.perfstats)
            if c.perfstats.is_bootstrapping:
                clibstats.accum(c.perfstats)
            else:
                clinbstats.accum(c.perfstats)

        totsent = dirasent + relaysent + clisent
        totrecv = dirarecv + relayrecv + clirecv

        self.statslogger.info("DirAuths sent=%s recv=%s bytes=%s",
                dirasent, dirarecv, dirasent+dirarecv)
        self.statslogger.info("Relays sent=%s recv=%s bytes=%s",
                relaysent, relayrecv, relaysent+relayrecv)
        self.statslogger.info("Client sent=%s recv=%s bytes=%s",
                clisent, clirecv, clisent+clirecv)
        self.statslogger.info("Total sent=%s recv=%s bytes=%s",
                totsent, totrecv, totsent+totrecv)
        self.statslogger.info("Dirauths %s", dirastats)
        self.statslogger.info("Relays %s", relaystats)
        self.statslogger.info("Relays(FB) %s", relayfbstats)
        self.statslogger.info("Relays(B) %s", relaybstats)
        self.statslogger.info("Relays(NB) %s", relaynbstats)
        self.statslogger.info("Clients %s", clistats)
        self.statslogger.info("Clients(B) %s", clibstats)
        self.statslogger.info("Clients(NB) %s", clinbstats)

    @staticmethod
    def _terminate_some(nodes, oldprob, kind, protected=None):
        """Terminate each unprotected node with probability oldprob.

        nodes: list of objects with a terminate() method
        oldprob: per-node probability of termination
        kind: plural noun ("relays", "clients") used in log messages
        protected: optional predicate; nodes for which it returns True
            are always kept, and no random number is drawn for them

        Returns the surviving nodes in their original order.
        """
        survivors = []
        numnodes = len(nodes)
        numterminated = 0
        lastpercent = 0
        logging.info("Terminating some %s", kind)
        for i, n in enumerate(nodes):
            percent = int(100*(i+1)/numnodes)
            if (protected is None or not protected(n)) and \
                    random.random() < oldprob:
                n.terminate()
                numterminated += 1
            else:
                # Keep this node
                survivors.append(n)
            if percent != lastpercent:
                lastpercent = percent
                logging.info("%d%% %s considered, %d terminated",
                        percent, kind, numterminated)
        return survivors

    def _churn_relays(self):
        """Stop some non-fallback relays, then start some new ones."""
        # Fallback relays are never terminated: the fallback list is set
        # only once, in __init__.
        self.relays = self._terminate_some(self.relays, self.relay_oldprob,
                "relays", protected=lambda r: r.is_fallbackrelay)

        # Start some new relays (a negative draw starts none)
        relays_new = int(random.normalvariate(self.relay_newmean,
                self.relay_newstddev))
        logging.info("Starting %d new relays", relays_new)
        for _ in range(max(relays_new, 0)):
            self.relays.append(relay.Relay(self.dirauthaddrs,
                    self._random_relay_bw(), 0))

    def _churn_clients(self):
        """Stop some clients, then start some new ones.

        Does nothing when client churn is disabled (client_oldprob is 0;
        see WOSIM_CLIENT_CHURN in __init__).
        """
        if self.client_oldprob <= 0:
            return

        self.clients = self._terminate_some(self.clients,
                self.client_oldprob, "clients")

        # Start some new clients (a negative draw starts none)
        clients_new = int(random.normalvariate(self.client_newmean,
                self.client_newstddev))
        logging.info("Starting %d new clients", clients_new)
        for _ in range(max(clients_new, 0)):
            self.clients.append(client.Client(self.dirauthaddrs))


if __name__ == '__main__':
    # Args: womode snipauthmode networkscale numepochs randseed logdir
    if len(sys.argv) != 7:
        sys.stderr.write("Usage: womode snipauthmode networkscale numepochs randseed logdir\n")
        sys.exit(1)

    # Mode names are matched case-insensitively against the enum members;
    # report the valid choices rather than dying with a bare KeyError.
    try:
        womode = network.WOMode[sys.argv[1].upper()]
        snipauthmode = network.SNIPAuthMode[sys.argv[2].upper()]
    except KeyError as e:
        sys.stderr.write("Unknown mode %s (womode: %s; snipauthmode: %s)\n" % (
            e,
            ", ".join(m.name.lower() for m in network.WOMode),
            ", ".join(m.name.lower() for m in network.SNIPAuthMode)))
        sys.exit(1)
    networkscale = float(sys.argv[3])
    if networkscale <= 0:
        # A non-positive scale yields no relays, so choosing the
        # fallback relays (random.sample) would fail.
        sys.stderr.write("networkscale must be positive\n")
        sys.exit(1)
    numepochs = int(sys.argv[4])
    randseed = int(sys.argv[5])
    logfile = os.path.join(sys.argv[6], "%s_%s_%f_%s_%s.log" % (womode.name,
        snipauthmode.name, networkscale, numepochs, randseed))

    # Seed the PRNG.  On Ubuntu 18.04, this in fact makes future calls
    # to (non-cryptographic) random numbers deterministic.  On Ubuntu
    # 16.04, it does not.
    random.seed(randseed)

    loglevel = logging.INFO
    # Uncomment to see all the debug messages
    # loglevel = logging.DEBUG

    logging.basicConfig(level=loglevel,
            format="%(asctime)s:%(levelname)s:%(message)s")

    # The gathered statistics get logged separately (they also still
    # propagate to the root logger configured above)
    statslogger = logging.getLogger("simulator")
    handler = logging.FileHandler(logfile)
    handler.setFormatter(logging.Formatter("%(asctime)s:%(message)s"))
    statslogger.addHandler(handler)
    statslogger.setLevel(logging.INFO)

    statslogger.info("Starting simulation %s", logfile)

    # Set the Walking Onions style to use
    network.thenetwork.set_wo_style(womode, snipauthmode)

    # The steady-state numbers of relays and clients
    relaytarget = math.ceil(6500 * networkscale)
    clienttarget = math.ceil(2500000 * networkscale)

    # Create the simulation (this itself ticks two epochs)
    simulator = Simulator(relaytarget, clienttarget, statslogger)

    for e in range(numepochs):
        # The +3 offset accounts for the two epochs ticked by the
        # Simulator constructor; NOTE(review): assumes epoch numbering
        # starts at 1 — confirm against network.thenetwork.getepoch().
        statslogger.info("Starting epoch %s simulation", e+3)
        simulator.one_epoch()

        # ru_maxrss is reported in KiB on Linux (bytes on macOS), so
        # dividing by 1024 gives MiB on Linux only.
        maxmemmib = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024
        statslogger.info("%d MiB used", maxmemmib)