Browse Source

Support the termination of relays

Ian Goldberg 4 years ago
parent
commit
53d721dad6
2 changed files with 94 additions and 4 deletions
  1. 24 0
      dirauth.py
  2. 70 4
      relay.py

+ 24 - 0
dirauth.py

@@ -100,6 +100,13 @@ class DirAuthUploadDescMsg(DirAuthNetMsg):
     def __init__(self, desc):
         self.desc = desc
 
+class DirAuthDelDescMsg(DirAuthNetMsg):
+    """The subclass of DirAuthNetMsg for deleting a relay
+    descriptor."""
+
+    def __init__(self, desc):
+        self.desc = desc
+
 class DirAuthGetConsensusMsg(DirAuthNetMsg):
     """The subclass of DirAuthNetMsg for fetching the consensus."""
 
@@ -225,6 +232,23 @@ class DirAuth(network.Server):
                 DirAuth.uploadeddescs[epoch][descstr] = \
                     (DirAuth.uploadeddescs[epoch][descstr][0]+1,
                      DirAuth.uploadeddescs[epoch][descstr][1])
+        elif isinstance(msg, DirAuthDelDescMsg):
+            # Check the uploaded descriptor for sanity
+            epoch = msg.desc.descdict['epoch']
+            if epoch != network.thenetwork.getepoch() + 1:
+                return
+            # Remove it from the class-static dict
+            if epoch not in DirAuth.uploadeddescs:
+                return
+            descstr = str(msg.desc)
+            if descstr not in DirAuth.uploadeddescs[epoch]:
+                return
+            elif DirAuth.uploadeddescs[epoch][descstr][0] == 1:
+                del DirAuth.uploadeddescs[epoch][descstr]
+            else:
+                DirAuth.uploadeddescs[epoch][descstr] = \
+                    (DirAuth.uploadeddescs[epoch][descstr][0]-1,
+                     DirAuth.uploadeddescs[epoch][descstr][1])
         elif isinstance(msg, DirAuthGetConsensusMsg):
             client.reply(DirAuthConsensusMsg(DirAuth.consensus))
         elif isinstance(msg, DirAuthGetENDIVEMsg):

+ 70 - 4
relay.py

@@ -47,6 +47,11 @@ class CircuitCellMsg(RelayNetMsg):
         return "C%d:%s" % (self.circid, self.cell)
 
 
+class RelayFallbackTerminationError(Exception):
+    """An exception raised when someone tries to terminate a fallback
+    relay."""
+
+
 class Channel(network.Connection):
     """A class representing a channel between a relay and either a
     client or a relay, transporting cells from various circuits."""
@@ -64,7 +69,7 @@ class Channel(network.Connection):
         self.peer = None
 
     def close(self):
-        if self.peer is not None:
+        if self.peer is not None and self.peer is not self:
             self.peer.closed()
         self.closed()
 
@@ -99,6 +104,14 @@ class CellRelay:
         self.dirauthaddrs = dirauthaddrs
         self.consensus = None
 
+    def terminate(self):
+        """Close all connections we're managing."""
+        while self.channels:
+            channelitems = iter(self.channels.items())
+            addr, channel = next(channelitems)
+            print('closing channel', addr, channel)
+            channel.close()
+
     def get_consensus(self):
         """Download a fresh consensus from a random dirauth."""
         a = random.choice(self.dirauthaddrs)
@@ -116,6 +129,7 @@ class CellRelay:
 
         channel.cellrelay = self
         self.channels[peeraddr] = channel
+        channel.closer = lambda: self.channels.pop(peeraddr)
 
     def get_channel_to(self, addr):
         """Get the Channel connected to the given NetAddr, creating one
@@ -187,6 +201,26 @@ class Relay(network.Server):
 
         self.uploaddesc()
 
+    def terminate(self):
+        """Stop this relay."""
+
+        if self.is_fallbackrelay:
+            # Fallback relays must not (for now) terminate
+            raise RelayFallbackTerminationError(self)
+
+        # Stop listening for epoch ticks
+        network.thenetwork.wantepochticks(self, False, end=True)
+        network.thenetwork.wantepochticks(self, False)
+
+        # Tell the dirauths we're going away
+        self.uploaddesc(False)
+
+        # Close connections to other relays
+        self.cellrelay.terminate()
+
+        # Stop listening to our own bound port
+        self.close()
+
     def set_is_fallbackrelay(self, isfallback = True):
         """Set this relay to be a fallback relay (or unset if passed
         False)."""
@@ -201,8 +235,9 @@ class Relay(network.Server):
     def newepoch(self, epoch):
         self.uploaddesc()
 
-    def uploaddesc(self):
-        # Upload the descriptor for the epoch to come
+    def uploaddesc(self, upload=True):
+        # Upload the descriptor for the epoch to come, or delete a
+        # previous upload if upload=False
         descdict = dict();
         descdict["epoch"] = network.thenetwork.getepoch() + 1
         descdict["idkey"] = self.idkey.verify_key
@@ -214,7 +249,14 @@ class Relay(network.Server):
         desc.sign(self.idkey)
         desc.verify()
 
-        descmsg = dirauth.DirAuthUploadDescMsg(desc)
+        if upload:
+            descmsg = dirauth.DirAuthUploadDescMsg(desc)
+        else:
+            # Note that this relies on signatures being deterministic;
+            # otherwise we'd need to save the descriptor we uploaded
+            # before so we could tell the airauths to delete the exact
+            # one
+            descmsg = dirauth.DirAuthDelDescMsg(desc)
 
         # Upload them
         for a in self.cellrelay.dirauthaddrs:
@@ -281,6 +323,30 @@ if __name__ == '__main__':
 
     relays[3].cellrelay.send_msg(RelayRandomHopMsg(30), relays[5].netaddr)
 
+    # See what channels exist and do a consistency check
+    for r in relays:
+        print("%s: %s" % (r.netaddr, [ str(k) for k in r.cellrelay.channels.keys()]))
+        raddr = r.netaddr
+        for ad, ch in r.cellrelay.channels.items():
+            if ch.peer.cellrelay.myaddr != ad:
+                print('address mismatch:', raddr, ad, ch.peer.cellrelay.myaddr)
+
+            if ch.peer.cellrelay.channels[raddr].peer is not ch:
+                print('asymmetry:', raddr, ad, ch, ch.peer.cellrelay.channels[raddr].peer)
+
+    # Stop some relays
+    relays[3].terminate()
+    relays.remove(relays[3])
+    relays[5].terminate()
+    relays.remove(relays[5])
+    relays[7].terminate()
+    relays.remove(relays[7])
+
+    # Tick the epoch
+    network.thenetwork.nextepoch()
+
+    print(dirauth.DirAuth.consensus)
+
     # See what channels exist and do a consistency check
     for r in relays:
         print("%s: %s" % (r.netaddr, [ str(k) for k in r.cellrelay.channels.keys()]))