Browse Source

Relays must tick their epoch before clients in Single-Pass Walking Onions

This was the cause of the mysterious VRF failure.  If a client's guard
churned out, and its new one was a relay that started after it did,
then the guard's descriptor (containing its new path selection (VRF)
key) would not yet have been updated to the current epoch by the time
the client asked for it.

Rather than special-casing Single-Pass Walking Onions, we just make it
true always that dirauths and relays get their epoch ticks called before
clients.
Ian Goldberg 4 years ago
parent
commit
5151a304ed
3 changed files with 48 additions and 31 deletions
  1. 1 1
      dirauth.py
  2. 43 26
      network.py
  3. 4 4
      relay.py

+ 1 - 1
dirauth.py

@@ -493,7 +493,7 @@ class DirAuth(network.Server):
         self.netaddr = network.thenetwork.bind(self)
         self.perfstats.name = "DirAuth at %s" % self.netaddr
         network.thenetwork.setdirauthkey(me, self.sigkey.verify_key)
-        network.thenetwork.wantepochticks(self, True, True)
+        network.thenetwork.wantepochticks(self, True, True, True)
 
     def connected(self, client):
         """Callback invoked when a client connects to us. This callback

+ 43 - 26
network.py

@@ -208,6 +208,8 @@ class Network:
     def __init__(self):
         self.servers = dict()
         self.epoch = 1
+        self.epochprioritycallbacks = []
+        self.epochpriorityendingcallbacks = []
         self.epochcallbacks = []
         self.epochendingcallbacks = []
         self.dirauthkeylist = []
@@ -239,47 +241,62 @@ class Network:
     def nextepoch(self):
         """Increment the current epoch, and return it."""
         logging.info("Ending epoch %s", self.epoch)
-        totendingcallbacks = len(self.epochendingcallbacks)
+        totendingcallbacks = len(self.epochpriorityendingcallbacks) + \
+                len(self.epochendingcallbacks)
+        numendingcalled = 0
         lastroundpercent = -1
-        for i, c in enumerate(self.epochendingcallbacks):
-            c.epoch_ending(self.epoch)
-            roundpercent = int(100*(i+1)/totendingcallbacks)
-            if roundpercent != lastroundpercent:
-                logging.info("Ending epoch %s %d%% complete",
-                        self.epoch, roundpercent)
-                lastroundpercent = roundpercent
+        for l in [ self.epochpriorityendingcallbacks,
+                self.epochendingcallbacks ]:
+            for c in l:
+                c.epoch_ending(self.epoch)
+                numendingcalled += 1
+                roundpercent = int(100*numendingcalled/totendingcallbacks)
+                if roundpercent != lastroundpercent:
+                    logging.info("Ending epoch %s %d%% complete",
+                            self.epoch, roundpercent)
+                    lastroundpercent = roundpercent
         self.epoch += 1
         logging.info("Starting epoch %s", self.epoch)
-        totcallbacks = len(self.epochcallbacks)
+        totcallbacks = len(self.epochprioritycallbacks) + \
+                len(self.epochcallbacks)
+        numcalled = 0
         lastroundpercent = -1
-        for i, c in enumerate(self.epochcallbacks):
-            c.newepoch(self.epoch)
-            roundpercent = int(100*(i+1)/totcallbacks)
-            if roundpercent != lastroundpercent:
-                logging.info("Starting epoch %s %d%% complete",
-                        self.epoch, roundpercent)
-                lastroundpercent = roundpercent
+        for l in [ self.epochprioritycallbacks, self.epochcallbacks ]:
+            for c in l:
+                c.newepoch(self.epoch)
+                numcalled += 1
+                roundpercent = int(100*numcalled/totcallbacks)
+                if roundpercent != lastroundpercent:
+                    logging.info("Starting epoch %s %d%% complete",
+                            self.epoch, roundpercent)
+                    lastroundpercent = roundpercent
         logging.info("Epoch %s started", self.epoch)
         return self.epoch
 
-    def wantepochticks(self, callback, want, end=False):
+    def wantepochticks(self, callback, want, priority=False, end=False):
         """Register or deregister an object from receiving epoch change
         callbacks.  If want is True, the callback object's newepoch()
         method will be called at each epoch change, with an argument of
         the new epoch.  If want if False, the callback object will be
-        deregistered.  If end is True, the callback object's
-        epoch_ending() method will be called instead at the end of the
-        epoch, just _before_ the epoch number change."""
+        deregistered.  If priority is True, call back this object before
+        any object with priority=False.  If end is True, the callback
+        object's epoch_ending() method will be called instead at the end
+        of the epoch, just _before_ the epoch number change."""
         if end:
-            if want:
-                self.epochendingcallbacks.append(callback)
+            if priority:
+                l = self.epochpriorityendingcallbacks
             else:
-                self.epochendingcallbacks.remove(callback)
+                l = self.epochendingcallbacks
         else:
-            if want:
-                self.epochcallbacks.append(callback)
+            if priority:
+                l = self.epochprioritycallbacks
             else:
-                self.epochcallbacks.remove(callback)
+                l = self.epochcallbacks
+
+        if want:
+            l.append(callback)
+        else:
+            l.remove(callback)
 
     def bind(self, server):
         """Bind a server to a newly generated NetAddr, returning the

+ 4 - 4
relay.py

@@ -1120,8 +1120,8 @@ class Relay(network.Server):
         self.flags = flags
 
         # Register for epoch change notification
-        network.thenetwork.wantepochticks(self, True, end=True)
-        network.thenetwork.wantepochticks(self, True)
+        network.thenetwork.wantepochticks(self, True, priority=True, end=True)
+        network.thenetwork.wantepochticks(self, True, priority=True)
 
         self.current_desc = None
         self.next_desc = None
@@ -1153,8 +1153,8 @@ class Relay(network.Server):
             raise RelayFallbackTerminationError(self)
 
         # Stop listening for epoch ticks
-        network.thenetwork.wantepochticks(self, False, end=True)
-        network.thenetwork.wantepochticks(self, False)
+        network.thenetwork.wantepochticks(self, False, priority=True, end=True)
+        network.thenetwork.wantepochticks(self, False, priority=True)
 
         # Tell the dirauths we're going away
         self.uploaddesc(False)