Browse Source

Clients choose a guard from the bw-weighted relay list

Ian Goldberg 4 years ago
parent
commit
aaee3b812c
4 changed files with 86 additions and 10 deletions
  1. 51 6
      client.py
  2. 23 0
      dirauth.py
  3. 9 1
      network.py
  4. 3 3
      relay.py

+ 51 - 6
client.py

@@ -13,12 +13,39 @@ class CellClient(relay.CellHandler):
     def __init__(self, myaddr, dirauthaddrs):
         super().__init__(myaddr, dirauthaddrs)
         self.guardaddr = None
+        self.consensus_cdf = [0]
 
     def get_consensus_from_fallbackrelay(self):
         """Download a fresh consensus from a random fallbackrelay."""
         fb = random.choice(network.thenetwork.getfallbackrelays())
         self.send_msg(relay.RelayGetConsensusMsg(), fb.netaddr)
 
+    def ensure_guard_vanilla(self):
+        """Ensure that we have a channel to a guard (Vanilla Onion
+        Routing version)."""
+        while True:
+            if self.guardaddr is None:
+                # Pick a guard from the consensus
+                guard = self.consensus.select_weighted_relay(self.consensus_cdf)
+                self.guardaddr = guard.descdict['addr']
+
+            # Connect to the guard
+            try:
+                self.get_channel_to(self.guardaddr)
+            except network.NetNoServer:
+                # Our guard is gone
+                self.guardaddr = None
+
+            if self.guardaddr is not None:
+                break
+
+        print('chose guard=', self.guardaddr)
+
+    def ensure_guard(self):
+        """Ensure that we have a channel to a guard."""
+        if network.thenetwork.womode == network.WOMode.VANILLA:
+            self.ensure_guard_vanilla()
+
     def received_msg(self, msg, peeraddr, peer):
         """Callback when a NetMsg not specific to a circuit is
         received."""
@@ -26,6 +53,7 @@ class CellClient(relay.CellHandler):
         if isinstance(msg, relay.RelayConsensusMsg):
             self.consensus = msg.consensus
             dirauth.verify_consensus(self.consensus, network.thenetwork.dirauthkeys())
+            self.consensus_cdf = self.consensus.bw_cdf()
         else:
             return super().received_msg(msg, peeraddr, peer)
 
@@ -35,7 +63,6 @@ class CellClient(relay.CellHandler):
         return super().received_cell(circid, cell, peeraddr, peer)
 
 
-
 class Client:
     """A class representing a Tor client."""
 
@@ -72,11 +99,12 @@ class Client:
     def newepoch(self, epoch):
         """Callback that fires at the start of each epoch"""
 
-        if self.cellhandler.consensus is None or \
-                self.cellhandler.consensus.consdict['epoch'] != \
-                network.thenetwork.getepoch():
-            # We'll need a new consensus
-            self.get_consensus()
+        # We'll need a new consensus
+        self.get_consensus()
+
+        # If we don't have a guard, pick one and make a channel to it
+        self.cellhandler.ensure_guard()
+
 
 
 if __name__ == '__main__':
@@ -146,3 +174,20 @@ if __name__ == '__main__':
 
             if ch.peer.cellhandler.channels[caddr].peer is not ch:
                 print('asymmetry:', caddr, ad, ch, ch.peer.cellhandler.channels[caddr].peer)
+
+    relays[3].terminate()
+    del relays[3]
+
+    # Tick the epoch
+    network.thenetwork.nextepoch()
+
+    # See what channels exist and do a consistency check
+    for c in clients:
+        print("%s: %s" % (c.netaddr, [ str(k) for k in c.cellhandler.channels.keys()]))
+        caddr = c.netaddr
+        for ad, ch in c.cellhandler.channels.items():
+            if ch.peer.cellhandler.myaddr != ad:
+                print('address mismatch:', caddr, ad, ch.peer.cellhandler.myaddr)
+
+            if ch.peer.cellhandler.channels[caddr].peer is not ch:
+                print('asymmetry:', caddr, ad, ch, ch.peer.cellhandler.channels[caddr].peer)

+ 23 - 0
dirauth.py

@@ -1,5 +1,8 @@
 #!/usr/bin/env python3
 
+import random # For simulation, not cryptography!
+import bisect
+
 import nacl.encoding
 import nacl.signing
 import network
@@ -82,6 +85,26 @@ class Consensus:
             self.consdict['sigs'].extend([None] * (index+1-len(self.consdict['sigs'])))
         self.consdict['sigs'][index] = signed.signature
 
+    def bw_cdf(self):
+        """Create the array of cumulative bandwidth values from a consensus.
+        The array (cdf) will have the same length as the number of relays
+        in the consensus.  cdf[0] = 0, and cdf[i] = cdf[i-1] + relay[i-1].bw."""
+        cdf = [0]
+        for r in self.consdict['relays']:
+            cdf.append(cdf[-1]+r.descdict['bw'])
+        # Remove the last item, which should be the sum of all the bws
+        cdf.pop()
+        print('cdf=', cdf)
+        return cdf
+
+    def select_weighted_relay(self, cdf):
+        """Use the cdf generated by bw_cdf to select a relay with
+        probability proportional to its bw weight."""
+        val = random.randint(1, self.consdict['totbw'])
+        idx = bisect.bisect_left(cdf, val)
+        return self.consdict['relays'][idx-1]
+
+
 
 def verify_consensus(consensus, verifkeylist):
     """Use the given list of verification keys to check the

+ 9 - 1
network.py

@@ -37,6 +37,11 @@ class NetAddr:
         return self.addr.__str__()
 
 
+class NetNoServer(Exception):
+    """No server is listening on the address someone tried to connect
+    to."""
+
+
 class Network:
     """A class representing a simulated network.  Servers can bind()
     to the network, yielding a NetAddr (network address), and clients
@@ -112,7 +117,10 @@ class Network:
     def connect(self, client, srvaddr):
         """Connect the given client to the server bound to addr.  Throw
         an exception if there is no server bound to that address."""
-        server = self.servers[srvaddr]
+        try:
+            server = self.servers[srvaddr]
+        except KeyError:
+            raise NetNoServer()
         return server.connected(client)
 
     def setfallbackrelays(self, fallbackrelays):

+ 3 - 3
relay.py

@@ -356,11 +356,11 @@ if __name__ == '__main__':
 
     # Stop some relays
     relays[3].terminate()
-    relays.remove(relays[3])
+    del relays[3]
     relays[5].terminate()
-    relays.remove(relays[5])
+    del relays[5]
     relays[7].terminate()
-    relays.remove(relays[7])
+    del relays[7]
 
     # Tick the epoch
     network.thenetwork.nextepoch()