Browse Source

Start on the simulator

Current status: relays can upload descriptors to dirauths
Ian Goldberg 4 years ago
commit
ab375b0775
3 changed files with 409 additions and 0 deletions
  1. 154 0
      dirauth.py
  2. 187 0
      network.py
  3. 68 0
      relay.py

+ 154 - 0
dirauth.py

@@ -0,0 +1,154 @@
+#!/usr/bin/env python3
+
+import nacl.encoding
+import nacl.signing
+import network
+
+# A relay descriptor is a dict containing:
+#  epoch: epoch id
+#  idkey: a public identity key
+#  onionkey: a public onion key
+#  addr: a network address
+#  bw: bandwidth
+#  flags: relay flags
+#  vrfkey: a VRF public key (Single-Pass Walking Onions only)
+#  sig: a signature over the above by the idkey
+class RelayDescriptor:
+    def __init__(self, descdict):
+        self.descdict = descdict
+
+    def __str__(self, withsig = True):
+        res = "RelayDesc[\n"
+        for k in ["epoch", "idkey", "onionkey", "addr", "bw", "flags",
+                    "vrfkey", "sig"]:
+            if k in self.descdict:
+                if k == "idkey" or k == "onionkey":
+                    res += "  " + k + ": " + self.descdict[k].encode(encoder=nacl.encoding.HexEncoder).decode("ascii") + "\n"
+                elif k == "sig":
+                    if withsig:
+                        res += "  " + k + ": " + nacl.encoding.HexEncoder.encode(self.descdict[k]).decode("ascii") + "\n"
+                else:
+                    res += "  " + k + ": " + str(self.descdict[k]) + "\n"
+        res += "]\n"
+        return res
+
+    def sign(self, signingkey):
+        serialized = self.__str__(False)
+        signed = signingkey.sign(serialized.encode("ascii"))
+        self.descdict["sig"] = signed.signature
+
+    def verify(self):
+        serialized = self.__str__(False)
+        self.descdict["idkey"].verify(serialized.encode("ascii"), self.descdict["sig"])
+
+class DirAuthNetMsg(network.NetMsg):
+    """The subclass of NetMsg for messages to and from directory
+    authorities."""
+
+class DirAuthUploadDescMsg(DirAuthNetMsg):
+    """The subclass of DirAuthNetMsg for uploading a relay
+    descriptor."""
+
+    def __init__(self, desc):
+        self.desc = desc
+
+class DirAuthGetConsensusMsg(DirAuthNetMsg):
+    """The subclass of DirAuthNetMsg for fetching the consensus."""
+
+class DirAuthConsensusMsg(DirAuthNetMsg):
+    """The subclass of DirAuthNetMsg for returning the consensus."""
+
+    def __init__(self, consensus):
+        self.consensus = consensus
+
+class DirAuthGetENDIVEMsg(DirAuthNetMsg):
+    """The subclass of DirAuthNetMsg for fetching the ENDIVE."""
+
+class DirAuthENDIVEMsg(DirAuthNetMsg):
+    """The subclass of DirAuthNetMsg for returning the ENDIVE."""
+
+    def __init__(self, endive):
+        self.endive = endive
+
+class DirAuthConnection(network.ClientConnection):
+    """The subclass of Connection for connections to directory
+    authorities."""
+
+    def __init__(self, peer = None):
+        super().__init__(peer)
+
+    def uploaddesc(self, desc):
+        """Upload our RelayDescriptor to the DirAuth."""
+        self.sendmsg(DirAuthUploadDescMeg(desc))
+
+    def getconsensus(self):
+        self.consensus = None
+        self.sendmsg(DirAuthGetConsensusMsg())
+        return self.consensus
+
+    def getENDIVE(self):
+        self.endive = None
+        self.sendmsg(DirAuthGetENDIVEMsg())
+        return self.endive
+
+    def receivedfromserver(self, msg):
+        if isinstance(msg, DirAuthConsensusMsg):
+            self.consensus = msg.consensus
+        elif isinstance(msg, DirAuthENDIVEMsg):
+            self.endive = msg.endive
+        else:
+            raise TypeError('Not a server-originating DirAuthNetMsg', msg)
+    
+class DirAuth(network.Server):
+    """The class representing directory authorities."""
+
+    def __init__(self, me, tot):
+        """Create a new directory authority. me is the index of which
+        dirauth this one is (starting from 0), and tot is the total
+        number of dirauths."""
+        self.me = me
+        self.tot = tot
+        self.name = "Dirauth %d of %d" % (me+1, tot)
+        self.consensus = None
+        self.endive = None
+        network.thenetwork.wantepochticks(self, True)
+
+    def connected(self, client):
+        """Callback invoked when a client connects to us. This callback
+        creates the DirAuthConnection that will be passed to the
+        client."""
+
+        # We don't actually need to keep per-connection state at
+        # dirauths, even in long-lived connections, so this is
+        # particularly simple.
+        return DirAuthConnection(self)
+
+    def newepoch(self, epoch):
+        print('New epoch', epoch, 'for', self)
+
+    def received(self, client, msg):
+        if isinstance(msg, DirAuthUploadDescMsg):
+            print(self.name, 'received descriptor from', client, ":", msg.desc)
+        elif isinstance(msg, DirAuthGetConsensusMsg):
+            client.sendmsg(DirAuthConsensusMsg(self.consensus))
+        elif isinstance(msg, DirAuthGetENDIVEMsg):
+            client.sendmsg(DirAuthENDIVEMsg(self.endive))
+        else:
+            raise TypeError('Not a client-originating DirAuthNetMsg', msg)
+
+    def closed(self):
+        pass
+
+if __name__ == '__main__':
+    # Start some dirauths
+    numdirauths = 9
+    dirauthaddrs = []
+    for i in range(numdirauths):
+        dirauth = DirAuth(i, numdirauths)
+        dirauthaddrs.append(network.thenetwork.bind(dirauth))
+
+    for a in dirauthaddrs:
+        print(a,end=' ')
+    print()
+
+    network.thenetwork.nextepoch()

+ 187 - 0
network.py

@@ -0,0 +1,187 @@
+#!/usr/bin/env python3
+
+class NetAddr:
+    """A class representing a network address"""
+    nextaddr = 1
+
+    def __init__(self):
+        """Generate a fresh network address"""
+        self.addr = NetAddr.nextaddr
+        NetAddr.nextaddr += 1
+
+    def __eq__(self, other):
+        return (isinstance(other, self.__class__)
+            and self.__dict__ == other.__dict__)
+
+    def __hash__(self):
+        return hash(self.addr)
+
+    def __str__(self):
+        return self.addr.__str__()
+
+class Network:
+    """A class representing a simulated network.  Servers can bind()
+    to the network, yielding a NetAddr (network address), and clients
+    can connect() to a NetAddr yielding a Connection."""
+
+    def __init__(self):
+        self.servers = dict()
+        self.epoch = 1
+        self.epochcallbacks = []
+
+    def printservers(self):
+        """Print the list of NetAddrs bound to something."""
+        print("Servers:")
+        for a in self.servers.keys():
+            print(a)
+
+    def getepoch(self):
+        """Return the current epoch."""
+        return self.epoch
+
+    def nextepoch(self):
+        """Increment the current epoch, and return it."""
+        self.epoch += 1
+        for c in self.epochcallbacks:
+            c.newepoch(self.epoch)
+        return self.epoch
+
+    def wantepochticks(self, callback, want):
+        """Register or deregister an object from receiving epoch change
+        callbacks.  If want is True, the callback object's newepoch()
+        method will be called at each epoch change, with an argument of
+        the new epoch.  If want if False, the callback object will be
+        deregistered."""
+        if want:
+            self.epochcallbacks.append(callback)
+        else:
+            self.epochcallbacks.remove(callback)
+
+    def bind(self, server):
+        """Bind a server to a newly generated NetAddr, returning the
+        NetAddr.  The server's bound() callback will also be invoked."""
+        addr = NetAddr()
+        self.servers[addr] = server
+        server.bound(addr, lambda: self.servers.pop(addr))
+        return addr
+
+    def connect(self, client, srvaddr):
+        """Connect the given client to the server bound to addr.  Throw
+        an exception if there is no server bound to that address."""
+        server = self.servers[srvaddr]
+        return server.connected(client)
+
+# The singleton instance of Network
+thenetwork = Network()
+
+class NetMsg:
+    """The parent class of network messages.  Subclass this class to
+    implement specific kinds of network messages."""
+
+class StringNetMsg(NetMsg):
+    """Send an arbitratry string as a NetMsg."""
+    def __init__(self, str):
+        self.data = str
+
+    def __str__(self):
+        return self.data.__str__()
+
+class Connection:
+    def __init__(self, peer = None):
+        """Create a Connection object with the given peer."""
+        self.peer = peer
+
+    def closed(self):
+        print("connection closed with", self.peer)
+        self.peer = None
+
+    def close(self):
+        print("closing connection with", self.peer)
+        self.peer.closed()
+        self.peer = None
+
+class ClientConnection(Connection):
+    """The parent class of client-side network connections.  Subclass
+    this class to do anything more elaborate than just passing arbitrary
+    NetMsgs, which then get ignored."""
+    
+    def __init__(self, peer):
+        """Create a ClientConnection object with the given peer.  The
+        peer must have a received(client, msg) method."""
+        self.peer = peer
+
+    def sendmsg(self, netmsg):
+        assert(isinstance(netmsg, NetMsg))
+        self.peer.received(self, netmsg)
+
+    def received(self, netmsg):
+        print("received", netmsg, "from server")
+
+
+class ServerConnection(Connection):
+    """The parent class of server-side network connections."""
+
+    def __init__(self):
+        self.peer = None
+
+    def sendmsg(self, netmsg):
+        assert(isinstance(netmsg, NetMsg))
+        self.peer.received(netmsg)
+
+    def received(self, client, netmsg):
+        print("received", netmsg, "from client", client)
+
+class Server:
+    """The parent class of network servers.  Subclass this class to
+    implement servers of different kinds.  You will probably only need
+    to override the implementation of connected()."""
+
+    def __init__(self, name):
+        self.name = name
+
+    def __str__(self):
+        return self.name.__str__()
+
+    def bound(self, netaddr, closer):
+        """Callback invoked when the server is successfully bound to a
+        NetAddr.  The parameters are the netaddr to which the server is
+        bound and closer (as in a thing that causes something to close)
+        is a callback function that should be used when the server
+        wishes to stop listening for new connections."""
+        self.closer = closer
+
+    def close(self):
+        """Stop listening for new connections."""
+        self.closer()
+
+    def connected(self, client):
+        """Callback invoked when a client connects to this server.
+        This callback must create the Connection object that will be
+        returned to the client."""
+        print("server", self, "conected to client", client)
+        serverconnection = ServerConnection()
+        clientconnection = ClientConnection(serverconnection)
+        serverconnection.peer = clientconnection
+        return clientconnection
+
+if __name__ == '__main__':
+    n1 = NetAddr()
+    n2 = NetAddr()
+    assert(n1 == n1)
+    assert(not (n1 == n2))
+    assert(n1 != n2)
+    print(n1, n2)
+
+    srv = Server("hello world server")
+
+    thenetwork.printservers()
+    a = thenetwork.bind(srv)
+    thenetwork.printservers()
+    print("in main", a)
+
+    conn = thenetwork.connect("hello world client", a)
+    conn.sendmsg(StringNetMsg("hi"))
+    conn.close()
+
+    srv.close()
+    thenetwork.printservers()

+ 68 - 0
relay.py

@@ -0,0 +1,68 @@
+#!/usr/bin/env python3
+
+import nacl.utils
+from nacl.signing import SigningKey
+from nacl.public import PrivateKey, Box
+
+import network
+import dirauth
+
+class Relay(network.Server):
+    """The class representing an onion relay."""
+
+    def __init__(self, dirauthaddrs, bw, flags):
+        self.consensus = None
+        self.dirauthaddrs = dirauthaddrs
+
+        # Create the identity and onion keys
+        self.idkey = SigningKey.generate()
+        self.onionkey = PrivateKey.generate()
+
+        # Bind to the network to get a network address
+        self.netaddr = network.thenetwork.bind(self)
+
+        # Our bandwidth and flags
+        self.bw = bw
+        self.flags = flags
+
+        # Register for epoch change notification
+        network.thenetwork.wantepochticks(self, True)
+
+        self.uploaddesc()
+
+    def newepoch(self, epoch):
+        self.uploaddesc()
+
+    def uploaddesc(self):
+        # Upload the descriptor for the epoch to come
+        descdict = dict();
+        descdict["epoch"] = network.thenetwork.getepoch() + 1
+        descdict["idkey"] = self.idkey.verify_key
+        descdict["onionkey"] = self.onionkey.public_key
+        descdict["addr"] = self.netaddr
+        descdict["bw"] = self.bw
+        descdict["flags"] = self.flags
+        desc = dirauth.RelayDescriptor(descdict)
+        desc.sign(self.idkey)
+        desc.verify()
+
+        descmsg = dirauth.DirAuthUploadDescMsg(desc)
+
+        # Upload them
+        for a in self.dirauthaddrs:
+            c = network.thenetwork.connect(self, a)
+            c.sendmsg(descmsg)
+            c.close()
+
+if __name__ == '__main__':
+    # Start some dirauths
+    numdirauths = 9
+    dirauthaddrs = []
+    for i in range(numdirauths):
+        dira = dirauth.DirAuth(i, numdirauths)
+        dirauthaddrs.append(network.thenetwork.bind(dira))
+
+    # Start some relays
+    numrelays = 10
+    for i in range(numrelays):
+        Relay(dirauthaddrs, 500, 0)