Browse Source

Implement RelayGetDescMsg and RelayDescMsg messages

In Walking Onions (particularly Single-Pass, but also Telescoping
if the onion keys rotate), we need to fetch our guard's current
descriptor at the start of each epoch so that we have its current
onion and (for single-pass) path selection keys.
Ian Goldberg 4 years ago
parent
commit
e2e554939f
2 changed files with 77 additions and 28 deletions
  1. 41 21
      client.py
  2. 36 7
      relay.py

+ 41 - 21
client.py

@@ -71,8 +71,12 @@ class TelescopingCreatedHandler:
     def __init__(self, channelmgr, ntor):
         self.channelmgr = channelmgr
         self.ntor = ntor
-        self.onionkey = self.channelmgr.guard.snipdict["onionkey"]
-        self.idkey = self.channelmgr.guard.snipdict["idkey"]
+        if type(self.channelmgr.guard) is dirauth.RelayDescriptor:
+            guardd = self.channelmgr.guard.descdict
+        else:
+            guardd = self.channelmgr.guard.snipdict
+        self.onionkey = guardd["onionkey"]
+        self.idkey = guardd["idkey"]
 
     def received_cell(self, circhandler, cell):
         logging.debug("Received cell in TelescopingCreatedHandler")
@@ -137,21 +141,10 @@ class TelescopingExtendedHandler:
         # Are we done building the circuit?
         logging.warning("we may need another circhandler structure for snips")
         if len(circhandler.circuit_descs) == 3:
-            logging.debug("Circuit [%s] is long enough; exiting.", [str(x.snipdict['addr']) for x in circhandler.circuit_descs])
             # Yes!
             return
 
-        nexthopidx = None
-        guardrange = circhandler.circuit_descs[0].snipdict["range"]
-        while nexthopidx is None:
-            # Relays make sure that when the extend to a relay, they are not
-            # extending to themselves. So here, we just need to make sure that
-            # this ID is not the same as the guard ID, to protect against the
-            # guard and exit being the same relay
-            nexthopidx = self.channelmgr.relaypicker.pick_weighted_relay_index()
-            if guardrange[0] <= nexthopidx and nexthopidx < guardrange[1]:
-                # We've picked this relay already.  Try again.
-                nexthopidx = None
+        nexthopidx = self.channelmgr.relaypicker.pick_weighted_relay_index()
 
         # Construct the VanillaExtendCircuitCell
         ntor = relay.NTor(self.channelmgr.perfstats)
@@ -184,11 +177,16 @@ class SinglePassCreatedHandler:
         circhandler.circuit_descs.append(self.channelmgr.guard)
 
         # Process each layer of the message
+        blinding_keys = []
         while cell is not None:
             lasthop = circhandler.circuit_descs[-1]
-            onionkey = lasthop.snipdict["onionkey"]
-            idkey = lasthop.snipdict["idkey"]
-            pathselkey = lasthop.snipdict["pathselkey"]
+            if type(lasthop) is dirauth.RelayDescriptor:
+                lasthopd = lasthop.descdict
+            else:
+                lasthopd = lasthop.snipdict
+            onionkey = lasthopd["onionkey"]
+            idkey = lasthopd["idkey"]
+            pathselkey = lasthopd["pathselkey"]
             if cell.enc is None:
                 secret = self.ntor.verify(cell.ntor_reply, onionkey, idkey)
                 enckey = nacl.hash.sha256(secret + b'upstream')
@@ -281,6 +279,17 @@ class ClientChannelManager(relay.ChannelManager):
             if self.guardaddr is not None:
                 break
 
+        # Ensure we have the current descriptor for the guard
+        # Note that self.guard may be a RelayDescriptor or a SNIP,
+        # depending on how we got it
+        if type(self.guard) is dirauth.RelayDescriptor:
+            guardepoch = self.guard.descdict["epoch"]
+        else:
+            guardepoch = self.guard.snipdict["epoch"]
+        if guardepoch != network.thenetwork.getepoch():
+            guardchannel = self.get_channel_to(self.guardaddr)
+            guardchannel.send_msg(relay.RelayGetDescMsg())
+
         logging.debug('chose guard=%s', self.guardaddr)
 
     def ensure_guard(self):
@@ -356,8 +365,12 @@ class ClientChannelManager(relay.ChannelManager):
         # equivalent test is needed here (but should just log a debug,
         # not an error, since the client cannot control the index value
         # selected for the exit.
-        if circhandler.circuit_descs[0].snipdict["addr"] == \
-                circhandler.circuit_descs[2].snipdict["addr"]:
+        guard = circhandler.circuit_descs[0]
+        if type(guard) is dirauth.RelayDescriptor:
+            guardd = guard.descdict
+        else:
+            guardd = guard.snipdict
+        if guardd["addr"] == circhandler.circuit_descs[2].snipdict["addr"]:
             logging.error("CIRCUIT IN A LOOP")
             circhandler.close()
             circhandler = None
@@ -404,8 +417,12 @@ class ClientChannelManager(relay.ChannelManager):
         # circuit got into a loop (guard equals exit); each node will
         # refuse to extend to itself, so this is the only possible loop
         # in a circuit of length 3
-        if circhandler.circuit_descs[0].snipdict["addr"] == \
-                circhandler.circuit_descs[2].snipdict["addr"]:
+        guard = circhandler.circuit_descs[0]
+        if type(guard) is dirauth.RelayDescriptor:
+            guardd = guard.descdict
+        else:
+            guardd = guard.snipdict
+        if guardd["addr"] == circhandler.circuit_descs[2].snipdict["addr"]:
             logging.debug("circuit in a loop")
             circhandler.close()
             circhandler = None
@@ -437,6 +454,9 @@ class ClientChannelManager(relay.ChannelManager):
             self.relaypicker = dirauth.Consensus.verify(msg.consensus,
                     network.thenetwork.dirauthkeys(), self.perfstats)
             self.consensus = msg.consensus
+        elif isinstance(msg, relay.RelayDescMsg):
+            dirauth.RelayDescriptor.verify(msg.desc, self.perfstats)
+            self.guard = msg.desc
         else:
             return super().received_msg(msg, peeraddr, channel)
 

+ 36 - 7
relay.py

@@ -71,6 +71,19 @@ class RelayConsensusMsg(RelayNetMsg):
         self.consensus = consensus
 
 
+class RelayGetDescMsg(RelayNetMsg):
+    """The subclass of RelayNetMsg sent by clients to their guards for
+    retrieving the guard's current descriptor."""
+
+
+class RelayDescMsg(RelayNetMsg):
+    """The subclass of RelayNetMsg sent by guards to clients for
+    reporting their current descriptor."""
+
+    def __init__(self, desc):
+        self.desc = desc
+
+
 class RelayGetConsensusDiffMsg(RelayNetMsg):
     """The subclass of RelayNetMsg for fetching the consensus, if the
     requestor already has the previous consensus. Sent by clients to
@@ -886,15 +899,16 @@ class RelayChannelManager(ChannelManager):
     """The subclass of ChannelManager for relays."""
 
     def __init__(self, myaddr, dirauthaddrs, onionprivkey, idpubkey,
-            path_selection_key, perfstats):
+            desc_getter, path_selection_key_getter, perfstats):
         super().__init__(myaddr, dirauthaddrs, perfstats)
         self.onionkey = onionprivkey
         self.idpubkey = idpubkey
         if network.thenetwork.womode != network.WOMode.VANILLA:
             self.endive = None
+        self.desc_getter = desc_getter
 
         if network.thenetwork.womode == network.WOMode.SINGLEPASS:
-            self.path_selection_key = path_selection_key
+            self.path_selection_key_getter = path_selection_key_getter
 
     def get_consensus(self):
         """Download a fresh consensus (and ENDIVE if using Walking
@@ -938,6 +952,8 @@ class RelayChannelManager(ChannelManager):
             self.send_msg(RelayConsensusMsg(self.consensus), peeraddr)
         elif isinstance(msg, RelayGetConsensusDiffMsg):
             self.send_msg(RelayConsensusDiffMsg(self.consensus), peeraddr)
+        elif isinstance(msg, RelayGetDescMsg):
+            self.send_msg(RelayDescMsg(self.desc_getter()), peeraddr)
         elif isinstance(msg, VanillaCreateCircuitMsg):
             # A new circuit has arrived
             circhandler = channel.new_circuit_with_circid(msg.circid)
@@ -1021,11 +1037,12 @@ class RelayChannelManager(ChannelManager):
                     Sphinx.server(nacl.public.PublicKey(msg.clipathselkey),
                     self.onionkey, b'pathsel', False, self.perfstats)
 
+            pathselkey = self.path_selection_key_getter()
             # Simulate the VRF output for now (but it has the right
             # size, and charges the right number of group operations to
             # the perfstats)
-            vrf_output = VRF.get_output(self.path_selection_key,
-                    pathsel_rand, self.perfstats)
+            vrf_output = VRF.get_output(pathselkey, pathsel_rand,
+                    self.perfstats)
 
             index = int.from_bytes(vrf_output[0][:4], 'big', signed=False)
 
@@ -1096,16 +1113,22 @@ class Relay(network.Server):
         network.thenetwork.wantepochticks(self, True, end=True)
         network.thenetwork.wantepochticks(self, True)
 
+        self.current_desc = None
+        self.next_desc = None
+
         # Create the path selection key for Single-Pass Walking Onions
         if network.thenetwork.womode == network.WOMode.SINGLEPASS:
             self.path_selection_key = nacl.public.PrivateKey.generate()
+            self.next_path_selection_key = self.path_selection_key
             self.perfstats.keygens += 1
         else:
             self.path_selection_key = None
 
         # Create the RelayChannelManager connection manager
         self.channelmgr = RelayChannelManager(self.netaddr, dirauthaddrs,
-                self.onionkey, self.idkey.verify_key, self.path_selection_key, self.perfstats)
+                self.onionkey, self.idkey.verify_key,
+                lambda: self.current_desc, lambda: self.path_selection_key,
+                self.perfstats)
 
         # Initially, we're not a fallback relay
         self.is_fallbackrelay = False
@@ -1146,8 +1169,11 @@ class Relay(network.Server):
     def newepoch(self, epoch):
         # Rotate the path selection key for Single-Pass Walking Onions
         if network.thenetwork.womode == network.WOMode.SINGLEPASS:
-            self.path_selection_key = nacl.public.PrivateKey.generate()
+            self.path_selection_key = self.next_path_selection_key
+            self.next_path_selection_key = nacl.public.PrivateKey.generate()
             self.perfstats.keygens += 1
+        # Upload the descriptor for the *next* epoch (the one after the
+        # one that just started)
         self.uploaddesc()
 
     def uploaddesc(self, upload=True):
@@ -1162,11 +1188,14 @@ class Relay(network.Server):
         descdict["flags"] = self.flags
 
         if network.thenetwork.womode == network.WOMode.SINGLEPASS:
-            descdict["pathselkey"] = bytes(self.path_selection_key)
+            descdict["pathselkey"] = \
+                    bytes(self.next_path_selection_key.public_key)
 
         desc = dirauth.RelayDescriptor(descdict)
         desc.sign(self.idkey, self.perfstats)
         dirauth.RelayDescriptor.verify(desc, self.perfstats)
+        self.current_desc = self.next_desc
+        self.next_desc = desc
 
         if upload:
             descmsg = dirauth.DirAuthUploadDescMsg(desc)