Browse Source

Clients choose a guard from the bw-weighted relay list

Ian Goldberg 4 years ago
parent
commit
aaee3b812c
4 changed files with 86 additions and 10 deletions
  1. 51 6
      client.py
  2. 23 0
      dirauth.py
  3. 9 1
      network.py
  4. 3 3
      relay.py

+ 51 - 6
client.py

@@ -13,12 +13,39 @@ class CellClient(relay.CellHandler):
     def __init__(self, myaddr, dirauthaddrs):
     def __init__(self, myaddr, dirauthaddrs):
         super().__init__(myaddr, dirauthaddrs)
         super().__init__(myaddr, dirauthaddrs)
         self.guardaddr = None
         self.guardaddr = None
+        self.consensus_cdf = [0]
 
 
     def get_consensus_from_fallbackrelay(self):
     def get_consensus_from_fallbackrelay(self):
         """Download a fresh consensus from a random fallbackrelay."""
         """Download a fresh consensus from a random fallbackrelay."""
         fb = random.choice(network.thenetwork.getfallbackrelays())
         fb = random.choice(network.thenetwork.getfallbackrelays())
         self.send_msg(relay.RelayGetConsensusMsg(), fb.netaddr)
         self.send_msg(relay.RelayGetConsensusMsg(), fb.netaddr)
 
 
+    def ensure_guard_vanilla(self):
+        """Ensure that we have a channel to a guard (Vanilla Onion
+        Routing version)."""
+        while True:
+            if self.guardaddr is None:
+                # Pick a guard from the consensus
+                guard = self.consensus.select_weighted_relay(self.consensus_cdf)
+                self.guardaddr = guard.descdict['addr']
+
+            # Connect to the guard
+            try:
+                self.get_channel_to(self.guardaddr)
+            except network.NetNoServer:
+                # Our guard is gone
+                self.guardaddr = None
+
+            if self.guardaddr is not None:
+                break
+
+        print('chose guard=', self.guardaddr)
+
+    def ensure_guard(self):
+        """Ensure that we have a channel to a guard."""
+        if network.thenetwork.womode == network.WOMode.VANILLA:
+            self.ensure_guard_vanilla()
+
     def received_msg(self, msg, peeraddr, peer):
     def received_msg(self, msg, peeraddr, peer):
         """Callback when a NetMsg not specific to a circuit is
         """Callback when a NetMsg not specific to a circuit is
         received."""
         received."""
@@ -26,6 +53,7 @@ class CellClient(relay.CellHandler):
         if isinstance(msg, relay.RelayConsensusMsg):
         if isinstance(msg, relay.RelayConsensusMsg):
             self.consensus = msg.consensus
             self.consensus = msg.consensus
             dirauth.verify_consensus(self.consensus, network.thenetwork.dirauthkeys())
             dirauth.verify_consensus(self.consensus, network.thenetwork.dirauthkeys())
+            self.consensus_cdf = self.consensus.bw_cdf()
         else:
         else:
             return super().received_msg(msg, peeraddr, peer)
             return super().received_msg(msg, peeraddr, peer)
 
 
@@ -35,7 +63,6 @@ class CellClient(relay.CellHandler):
         return super().received_cell(circid, cell, peeraddr, peer)
         return super().received_cell(circid, cell, peeraddr, peer)
 
 
 
 
-
 class Client:
 class Client:
     """A class representing a Tor client."""
     """A class representing a Tor client."""
 
 
@@ -72,11 +99,12 @@ class Client:
     def newepoch(self, epoch):
     def newepoch(self, epoch):
         """Callback that fires at the start of each epoch"""
         """Callback that fires at the start of each epoch"""
 
 
-        if self.cellhandler.consensus is None or \
-                self.cellhandler.consensus.consdict['epoch'] != \
-                network.thenetwork.getepoch():
-            # We'll need a new consensus
-            self.get_consensus()
+        # We'll need a new consensus
+        self.get_consensus()
+
+        # If we don't have a guard, pick one and make a channel to it
+        self.cellhandler.ensure_guard()
+
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
@@ -146,3 +174,20 @@ if __name__ == '__main__':
 
 
             if ch.peer.cellhandler.channels[caddr].peer is not ch:
             if ch.peer.cellhandler.channels[caddr].peer is not ch:
                 print('asymmetry:', caddr, ad, ch, ch.peer.cellhandler.channels[caddr].peer)
                 print('asymmetry:', caddr, ad, ch, ch.peer.cellhandler.channels[caddr].peer)
+
+    relays[3].terminate()
+    del relays[3]
+
+    # Tick the epoch
+    network.thenetwork.nextepoch()
+
+    # See what channels exist and do a consistency check
+    for c in clients:
+        print("%s: %s" % (c.netaddr, [ str(k) for k in c.cellhandler.channels.keys()]))
+        caddr = c.netaddr
+        for ad, ch in c.cellhandler.channels.items():
+            if ch.peer.cellhandler.myaddr != ad:
+                print('address mismatch:', caddr, ad, ch.peer.cellhandler.myaddr)
+
+            if ch.peer.cellhandler.channels[caddr].peer is not ch:
+                print('asymmetry:', caddr, ad, ch, ch.peer.cellhandler.channels[caddr].peer)

+ 23 - 0
dirauth.py

@@ -1,5 +1,8 @@
 #!/usr/bin/env python3
 #!/usr/bin/env python3
 
 
+import random # For simulation, not cryptography!
+import bisect
+
 import nacl.encoding
 import nacl.encoding
 import nacl.signing
 import nacl.signing
 import network
 import network
@@ -82,6 +85,26 @@ class Consensus:
             self.consdict['sigs'].extend([None] * (index+1-len(self.consdict['sigs'])))
             self.consdict['sigs'].extend([None] * (index+1-len(self.consdict['sigs'])))
         self.consdict['sigs'][index] = signed.signature
         self.consdict['sigs'][index] = signed.signature
 
 
+    def bw_cdf(self):
+        """Create the array of cumulative bandwidth values from a consensus.
+        The array (cdf) will have the same length as the number of relays
+        in the consensus.  cdf[0] = 0, and cdf[i] = cdf[i-1] + relay[i-1].bw."""
+        cdf = [0]
+        for r in self.consdict['relays']:
+            cdf.append(cdf[-1]+r.descdict['bw'])
+        # Remove the last item, which should be the sum of all the bws
+        cdf.pop()
+        print('cdf=', cdf)
+        return cdf
+
+    def select_weighted_relay(self, cdf):
+        """Use the cdf generated by bw_cdf to select a relay with
+        probability proportional to its bw weight."""
+        val = random.randint(1, self.consdict['totbw'])
+        idx = bisect.bisect_left(cdf, val)
+        return self.consdict['relays'][idx-1]
+
+
 
 
 def verify_consensus(consensus, verifkeylist):
 def verify_consensus(consensus, verifkeylist):
     """Use the given list of verification keys to check the
     """Use the given list of verification keys to check the

+ 9 - 1
network.py

@@ -37,6 +37,11 @@ class NetAddr:
         return self.addr.__str__()
         return self.addr.__str__()
 
 
 
 
+class NetNoServer(Exception):
+    """No server is listening on the address someone tried to connect
+    to."""
+
+
 class Network:
 class Network:
     """A class representing a simulated network.  Servers can bind()
     """A class representing a simulated network.  Servers can bind()
     to the network, yielding a NetAddr (network address), and clients
     to the network, yielding a NetAddr (network address), and clients
@@ -112,7 +117,10 @@ class Network:
     def connect(self, client, srvaddr):
     def connect(self, client, srvaddr):
         """Connect the given client to the server bound to addr.  Throw
         """Connect the given client to the server bound to addr.  Throw
         an exception if there is no server bound to that address."""
         an exception if there is no server bound to that address."""
-        server = self.servers[srvaddr]
+        try:
+            server = self.servers[srvaddr]
+        except KeyError:
+            raise NetNoServer()
         return server.connected(client)
         return server.connected(client)
 
 
     def setfallbackrelays(self, fallbackrelays):
     def setfallbackrelays(self, fallbackrelays):

+ 3 - 3
relay.py

@@ -356,11 +356,11 @@ if __name__ == '__main__':
 
 
     # Stop some relays
     # Stop some relays
     relays[3].terminate()
     relays[3].terminate()
-    relays.remove(relays[3])
+    del relays[3]
     relays[5].terminate()
     relays[5].terminate()
-    relays.remove(relays[5])
+    del relays[5]
     relays[7].terminate()
     relays[7].terminate()
-    relays.remove(relays[7])
+    del relays[7]
 
 
     # Tick the epoch
     # Tick the epoch
     network.thenetwork.nextepoch()
     network.thenetwork.nextepoch()