Browse Source

Refactor the random picking of relays into a RelayPicker class

Ian Goldberg 4 years ago
parent
commit
575073ac49
3 changed files with 116 additions and 65 deletions
  1. 4 9
      client.py
  2. 94 45
      dirauth.py
  3. 18 11
      relay.py

+ 4 - 9
client.py

@@ -42,8 +42,7 @@ class VanillaCreatedExtendedHandler:
 
         nexthop = None
         while nexthop is None:
-            nexthop = self.channelmgr.consensus.select_weighted_relay(
-                    self.channelmgr.consensus_cdf)
+            nexthop = self.channelmgr.relaypicker.pick_weighted_relay()
             if nexthop.descdict['addr'] in \
                     [ desc.descdict['addr'] \
                         for desc in circhandler.circuit_descs ]:
@@ -71,8 +70,6 @@ class ClientChannelManager(relay.ChannelManager):
         super().__init__(myaddr, dirauthaddrs, perfstats)
         self.guardaddr = None
         self.guard = None
-        if network.thenetwork.womode == network.WOMode.VANILLA:
-            self.consensus_cdf = []
 
     def get_consensus_from_fallbackrelay(self):
         """Download a fresh consensus from a random fallbackrelay."""
@@ -89,7 +86,7 @@ class ClientChannelManager(relay.ChannelManager):
         while True:
             if self.guardaddr is None:
                 # Pick a guard from the consensus
-                self.guard = self.consensus.select_weighted_relay(self.consensus_cdf)
+                self.guard = self.relaypicker.pick_weighted_relay()
                 self.guardaddr = self.guard.descdict['addr']
 
             # Connect to the guard
@@ -146,11 +143,9 @@ class ClientChannelManager(relay.ChannelManager):
         print("Client %s received msg %s from %s" % (self.myaddr, msg, peeraddr))
         if isinstance(msg, relay.RelayConsensusMsg) or \
                 isinstance(msg, relay.RelayConsensusDiffMsg):
-            dirauth.Consensus.verify(msg.consensus,
+            self.relaypicker = dirauth.Consensus.verify(msg.consensus,
                     network.thenetwork.dirauthkeys(), self.perfstats)
             self.consensus = msg.consensus
-            if network.thenetwork.womode == network.WOMode.VANILLA:
-                self.consensus_cdf = self.consensus.bw_cdf()
         else:
             return super().received_msg(msg, peeraddr, channel)
 
@@ -306,7 +301,7 @@ if __name__ == '__main__':
     # Pick a bunch of bw-weighted random relays and look at the
     # distribution
     for i in range(100):
-        r = clients[0].channelmgr.consensus.select_weighted_relay(clients[0].channelmgr.consensus_cdf)
+        r = clients[0].channelmgr.relaypicker.pick_weighted_relay()
         print("relay",r.descdict["addr"])
 
     relays[3].terminate()

+ 94 - 45
dirauth.py

@@ -156,38 +156,18 @@ class Consensus:
             self.consdict['sigs'].extend([None] * (index+1-len(self.consdict['sigs'])))
         self.consdict['sigs'][index] = signed.signature
 
-    def bw_cdf(self):
-        """Create the array of cumulative bandwidth values from a consensus.
-        The array (cdf) will have the same length as the number of relays
-        in the consensus.  cdf[0] = 0, and cdf[i] = cdf[i-1] + relay[i-1].bw."""
-        cdf = [0]
-        for r in self.consdict['relays']:
-            cdf.append(cdf[-1]+r.descdict['bw'])
-        # Remove the last item, which should be the sum of all the bws
-        cdf.pop()
-        print('cdf=', cdf)
-        return cdf
-
-    def select_weighted_relay(self, cdf):
-        """Use the cdf generated by bw_cdf to select a relay with
-        probability proportional to its bw weight."""
-        totbw = self.consdict['totbw']
-        if totbw < 1:
-            raise ValueError("No relays to choose from")
-        val = random.randint(0, totbw-1)
-        # Find the rightmost entry less than or equal to val
-        idx = bisect.bisect_right(cdf, val)
-        return self.consdict['relays'][idx-1]
-
     @staticmethod
     def verify(consensus, verifkeylist, perfstats):
         """Use the given list of verification keys to check the
-        signatures on the consensus."""
+        signatures on the consensus.  Return the RelayPicker if
+        successful, or raise an exception otherwise."""
         assert(type(consensus) is Consensus)
         serialized = consensus.__str__(False)
         for i, vk in enumerate(verifkeylist):
             perfstats.verifs += 1
             vk.verify(serialized.encode("ascii"), consensus.consdict['sigs'][i])
+        # If we got this far, all is well.  Return the RelayPicker.
+        return RelayPicker.get(consensus)
 
 
 # An ENDIVE is a dict containing:
@@ -227,31 +207,100 @@ class ENDIVE:
             self.enddict['sigs'].extend([None] * (index+1-len(self.enddict['sigs'])))
         self.enddict['sigs'][index] = signed.signature
 
-    def bw_cdf(self):
-        """Create the array of cumulative bandwidth values from an ENDIVE.
-        The array (cdf) will have the same length as the number of relays
-        in the ENDIVE.  cdf[0] = 0, and cdf[i] = cdf[i-1] + relay[i-1].bw."""
-        cdf = [ s.snipdict['range'][0] for s in self.enddict['snips'] ]
-        print('ENDIVE cdf=', cdf)
-        return cdf
-
-    def select_snip_by_index(self, i, cdf):
-        """Use the cdf generated by bw_cdf to select the SNIP for which
-        i is in the index range.  Choose i with
-        random.randint(0, (1<<32)-1)."""
-        # Find the rightmost entry less than or equal to i
-        idx = bisect.bisect_right(cdf, i)
-        return self.enddict['snips'][idx-1]
-
     @staticmethod
-    def verify(endive, verifkeylist, perfstats):
+    def verify(endive, consensus, verifkeylist, perfstats):
         """Use the given list of verification keys to check the
-        signatures on the ENDIVE."""
-        assert(type(endive) is ENDIVE)
-        serialized = endive.__str__(False)
+        signatures on the ENDIVE and consensus.  Return the RelayPicker
+        if successful, or raise an exception otherwise."""
+        assert(type(endive) is ENDIVE and type(consensus) is Consensus)
+        serializedcons = consensus.__str__(False)
         for i, vk in enumerate(verifkeylist):
             perfstats.verifs += 1
-            vk.verify(serialized.encode("ascii"), endive.enddict['sigs'][i])
+            vk.verify(serializedcons.encode("ascii"), consensus.consdict['sigs'][i])
+        serializedend = endive.__str__(False)
+        for i, vk in enumerate(verifkeylist):
+            perfstats.verifs += 1
+            vk.verify(serializedend.encode("ascii"), endive.enddict['sigs'][i])
+        # If we got this far, all is well.  Return the RelayPicker.
+        return RelayPicker.get(consensus, endive)
+
+
+class RelayPicker:
+    """An instance of this class (which may be a singleton in the
+    simulation) is returned by the Consensus.verify() and
+    ENDIVE.verify() methods.  It does any necessary precomputation
+    and/or caching, and exposes a method to select a random bw-weighted
+    relay, either explicitly specifying a uniform random value, or
+    letting the choice be done internally."""
+
+    # The singleton instance
+    relaypicker = None
+
+    def __init__(self, consensus, endive = None):
+        self.epoch = consensus.consdict["epoch"]
+        self.totbw = consensus.consdict["totbw"]
+        self.consensus = consensus
+        self.endive = endive
+        assert(endive is None or endive.enddict["epoch"] == self.epoch)
+
+        if network.thenetwork.womode == network.WOMode.VANILLA:
+            # Create the array of cumulative bandwidth values from a
+            # consensus.  The array (cdf) will have the same length as
+            # the number of relays in the consensus.  cdf[0] = 0, and
+            # cdf[i] = cdf[i-1] + relay[i-1].bw.
+            self.cdf = [0]
+            for r in consensus.consdict['relays']:
+                self.cdf.append(self.cdf[-1]+r.descdict['bw'])
+            # Remove the last item, which should be the sum of all the bws
+            self.cdf.pop()
+            print('cdf=', self.cdf)
+        elif network.thenetwork.snipauthmode == \
+                network.SNIPAuthMode.THRESHSIG:
+            # Note that clients will call this with endive = None
+            if self.endive is not None:
+                self.cdf = [ s.snipdict['range'][0] \
+                        for s in self.endive.enddict['snips'] ]
+            else:
+                self.cdf = None
+            print('cdf=', self.cdf)
+        else:
+            raise NotImplementedError("Merkle auth not yet implemented")
+
+    @staticmethod
+    def get(consensus, endive = None):
+        # Return the singleton instance, if it exists for this epoch
+        # However, don't use the cached instance if that one has
+        # endive=None, but we were passed a real ENDIVE
+        if RelayPicker.relaypicker is not None and \
+                (RelayPicker.relaypicker.endive is not None or \
+                        endive is None) and \
+                RelayPicker.relaypicker.epoch == consensus.consdict["epoch"]:
+            return RelayPicker.relaypicker
+
+        # Create it otherwise, storing the result as the singleton
+        RelayPicker.relaypicker = RelayPicker(consensus, endive)
+        return RelayPicker.relaypicker
+
+    def pick_relay_by_uniform_index(self, idx):
+        """Pass in a uniform random index random(0,totbw-1) to get a
+        relay selected weighted by bw."""
+        if network.thenetwork.womode == network.WOMode.VANILLA:
+            relays = self.consensus.consdict['relays']
+        else:
+            relays = self.endive.enddict['snips']
+
+        # Find the rightmost entry less than or equal to idx
+        i = bisect.bisect_right(self.cdf, idx)
+        return relays[i-1]
+
+    def pick_weighted_relay(self):
+        """Select a random relay with probability proportional to its bw
+        weight."""
+        totbw = self.totbw
+        if totbw < 1:
+            raise ValueError("No relays to choose from")
+        idx = random.randint(0, totbw-1)
+        return self.pick_relay_by_uniform_index(idx)
 
 
 class DirAuthNetMsg(network.NetMsg):

+ 18 - 11
relay.py

@@ -474,6 +474,7 @@ class ChannelManager:
         self.myaddr = myaddr
         self.dirauthaddrs = dirauthaddrs
         self.consensus = None
+        self.relaypicker = None
         self.perfstats = perfstats
 
     def terminate(self):
@@ -542,16 +543,20 @@ class RelayChannelManager(ChannelManager):
         self.idpubkey = idpubkey
 
     def get_consensus(self):
-        """Download a fresh consensus from a random dirauth."""
+        """Download a fresh consensus (and ENDIVE if using Walking
+        Onions) from a random dirauth."""
         a = random.choice(self.dirauthaddrs)
         c = network.thenetwork.connect(self, a, self.perfstats)
-        if self.consensus is not None and \
-                len(self.consensus.consdict['relays']) > 0:
-            self.consensus = c.getconsensusdiff()
+        if network.thenetwork.womode == network.WOMode.VANILLA:
+            if self.consensus is not None and \
+                    len(self.consensus.consdict['relays']) > 0:
+                self.consensus = c.getconsensusdiff()
+            else:
+                self.consensus = c.getconsensus()
+            self.relaypicker = dirauth.Consensus.verify(self.consensus,
+                    network.thenetwork.dirauthkeys(), self.perfstats)
         else:
-            self.consensus = c.getconsensus()
-        dirauth.Consensus.verify(self.consensus,
-                network.thenetwork.dirauthkeys(), self.perfstats)
+            raise NotImplementedError("Walking Onions not yet implemented")
         c.close()
 
     def received_msg(self, msg, peeraddr, channel):
@@ -756,10 +761,12 @@ if __name__ == '__main__':
     print(dirauth.DirAuth.consensus)
     print(dirauth.DirAuth.endive)
 
-    dirauth.Consensus.verify(dirauth.DirAuth.consensus,
-            network.thenetwork.dirauthkeys(), perfstats)
-    if network.thenetwork.womode != network.WOMode.VANILLA:
-        dirauth.ENDIVE.verify(dirauth.DirAuth.endive,
+    if network.thenetwork.womode == network.WOMode.VANILLA:
+        relaypicker = dirauth.Consensus.verify(dirauth.DirAuth.consensus,
+                network.thenetwork.dirauthkeys(), perfstats)
+    else:
+        relaypicker = dirauth.ENDIVE.verify(dirauth.DirAuth.endive,
+                dirauth.DirAuth.consensus,
                 network.thenetwork.dirauthkeys(), perfstats)
         for s in dirauth.DirAuth.endive.enddict['snips']:
             dirauth.SNIP.verify(s, dirauth.DirAuth.consensus,