Browse Source

Refactor Traffic.py to use asyncore and asynchat

This refactor probably leaves a lot of needless pieces in place; I
tried to do only what was necessary to move away from a raw select()
loop.  Nonetheless, I think it simplifies things a bit.

(We can't use asyncio yet, since we're keeping compatibility with
older versions of Python.)
Nick Mathewson 5 years ago
parent
commit
b90ff42dfe
1 changed file with 120 additions and 250 deletions
  1. 120 250
      lib/chutney/Traffic.py

+ 120 - 250
lib/chutney/Traffic.py

@@ -24,14 +24,26 @@ from __future__ import print_function
 
 import sys
 import socket
-import select
 import struct
 import errno
 import time
 import os
 
+import asyncore
+import asynchat
+
 from chutney.Debug import debug_flag, debug
 
+def addr_to_family(addr):
+    for family in [socket.AF_INET, socket.AF_INET6]:
+        try:
+            socket.inet_pton(family, addr)
+            return family
+        except (socket.error, OSError):
+            pass
+
+    return socket.AF_INET
+
 def socks_cmd(addr_port):
     """
     Return a SOCKS command for connecting to addr_port.
@@ -54,7 +66,6 @@ def socks_cmd(addr_port):
         dnsname = dnsname.encode("ascii")
     return struct.pack('!BBH', ver, cmd, port) + addr + user + dnsname
 
-
 class TestSuite(object):
 
     """Keep a tab on how many tests are pending, how many have failed
@@ -85,84 +96,50 @@ class TestSuite(object):
     def status(self):
         return('%d/%d/%d' % (self.not_done, self.successes, self.failures))
 
-
-class Peer(object):
-
-    "Base class for Listener, Source and Sink."
-    LISTENER = 1
-    SOURCE = 2
-    SINK = 3
-
-    def __init__(self, ptype, tt, s=None):
-        self.type = ptype
-        self.tt = tt  # TrafficTester
-        if s is not None:
-            self.s = s
-        else:
-            self.s = socket.socket()
-            self.s.setblocking(False)
-
-    def fd(self):
-        return self.s.fileno()
-
-    def is_source(self):
-        return self.type == self.SOURCE
-
-    def is_sink(self):
-        return self.type == self.SINK
-
-
-class Listener(Peer):
-
+class Listener(asyncore.dispatcher):
     "A TCP listener, binding, listening and accepting new connections."
 
     def __init__(self, tt, endpoint):
-        super(Listener, self).__init__(Peer.LISTENER, tt)
-        self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        self.s.bind(endpoint)
-        self.s.listen(0)
-
-    def accept(self):
-        newsock, endpoint = self.s.accept()
-        debug("new client from %s:%s (fd=%d)" %
-              (endpoint[0], endpoint[1], newsock.fileno()))
-        self.tt.add(Sink(self.tt, newsock))
-
-
-class Sink(Peer):
-
+        asyncore.dispatcher.__init__(self)
+        self.create_socket(addr_to_family(endpoint[0]), socket.SOCK_STREAM)
+        self.set_reuse_addr()
+        self.bind(endpoint)
+        self.listen(0)
+        self.tt = tt
+
+    def handle_accept(self):
+        # deprecated in python 3.2
+        pair = self.accept()
+        if pair is not None:
+            newsock, endpoint = pair
+            debug("new client from %s:%s (fd=%d)" %
+                  (endpoint[0], endpoint[1], newsock.fileno()))
+            handler = Sink(newsock, self.tt)
+
+    def fileno(self):
+        return self.socket.fileno()
+
+class Sink(asynchat.async_chat):
     "A data sink, reading from its peer and verifying the data."
-
-    def __init__(self, tt, s):
-        super(Sink, self).__init__(Peer.SINK, tt, s)
-        self.inbuf = b''
-        self.repetitions = self.tt.repetitions
-
-    def on_readable(self):
-        """Invoked when the socket becomes readable.
-        Return 0 on finished, successful verification.
-               -1 on failed verification
-               >0 if more data needs to be read
-        """
-        return self.verify(self.tt.data)
-
-    def verify(self, data):
+    def __init__(self, sock, tt):
+        asynchat.async_chat.__init__(self, sock)
+        self.inbuf = b""
+        self.set_terminator(None)
+        self.tt = tt
+        self.repetitions = tt.repetitions
+
+    def collect_incoming_data(self, inp):
         # shortcut read when we don't ever expect any data
-        if self.repetitions == 0 or len(self.tt.data) == 0:
-            debug("no verification required - no data")
-            return 0
-        inp = self.s.recv(len(data) - len(self.inbuf))
-        debug("Verify: received %d bytes"% len(inp))
-        if len(inp) == 0:
-            debug("EOF on fd %s" % self.fd())
-            return -1
+
         self.inbuf += inp
+        data = self.tt.data
         debug("successfully received (bytes=%d)" % len(self.inbuf))
         while len(self.inbuf) >= len(data):
             assert(len(self.inbuf) <= len(data) or self.repetitions > 1)
             if self.inbuf[:len(data)] != data:
                 debug("receive comparison failed (bytes=%d)" % len(data))
-                return -1  # Failed verification.
+                self.tt.failure()
+                self.close()
             # if we're not debugging, print a dot every dot_repetitions reps
             elif (not debug_flag and self.tt.dot_repetitions > 0 and
                   self.repetitions % self.tt.dot_repetitions == 0):
@@ -176,152 +153,88 @@ class Sink(Peer):
             debug("receive remaining repetitions (reps=%d)" % self.repetitions)
         if self.repetitions == 0 and len(self.inbuf) == 0:
             debug("successful verification")
+            self.close()
+            self.tt.success()
         # calculate the actual length of data remaining, including reps
         debug("receive remaining bytes (bytes=%d)"
               % (self.repetitions*len(data) - len(self.inbuf)))
-        return self.repetitions*len(data) - len(self.inbuf)
 
+    def fileno(self):
+        return self.socket.fileno()
 
-class Source(Peer):
+class CloseSourceProducer:
+    """Helper: when this producer is returned, a source is successful."""
+    def __init__(self, source):
+        self.source = source
 
+    def more(self):
+        self.source.tt.success()
+
+class Source(asynchat.async_chat):
     """A data source, connecting to a TCP server, optionally over a
     SOCKS proxy, sending data."""
-    NOT_CONNECTED = 0
     CONNECTING = 1
     CONNECTING_THROUGH_PROXY = 2
     CONNECTED = 5
 
     def __init__(self, tt, server, buf, proxy=None, repetitions=1):
-        super(Source, self).__init__(Peer.SOURCE, tt)
-        self.state = self.NOT_CONNECTED
+        asynchat.async_chat.__init__(self)
         self.data = buf
         self.outbuf = b''
         self.inbuf = b''
         self.proxy = proxy
+        self.server = server
         self.repetitions = repetitions
         self._sent_no_bytes = 0
+        self.tt = tt
         # sanity checks
         if len(self.data) == 0:
             self.repetitions = 0
         if self.repetitions == 0:
-            self.data = {}
-        self.connect(server)
+            self.data = b""
 
-    def connect(self, endpoint):
-        self.dest = endpoint
+        self.set_terminator(None)
+        dest = (self.proxy or self.server)
+        self.create_socket(addr_to_family(dest[0]), socket.SOCK_STREAM)
+        debug("socket %d connecting to %r..."%(self.fileno(),dest))
         self.state = self.CONNECTING
-        dest = self.proxy or self.dest
-        try:
-            debug("socket %d connecting to %r..."%(self.fd(),dest))
-            self.s.connect(dest)
-        except socket.error as e:
-            if e.errno != errno.EINPROGRESS:
-                raise
-
-    def on_readable(self):
-        """Invoked when the socket becomes readable.
-        Return -1 on failure
-               >0 if more data needs to be read or written
-        """
+        self.connect(dest)
+
+    def handle_connect(self):
+        if self.proxy:
+            self.state = self.CONNECTING_THROUGH_PROXY
+            self.push(socks_cmd(self.server))
+        else:
+            self.state = self.CONNECTED
+            self.push_output()
+
+    def collect_incoming_data(self, data):
+        self.inbuf += data
         if self.state == self.CONNECTING_THROUGH_PROXY:
-            inp = self.s.recv(8 - len(self.inbuf))
-            debug("-- connecting through proxy, got %d bytes"%len(inp))
-            if len(inp) == 0:
-                debug("EOF on fd %d"%self.fd())
-                return -1
-            self.inbuf += inp
-            if len(self.inbuf) == 8:
+            if len(self.inbuf) >= 8:
                 if self.inbuf[:2] == b'\x00\x5a':
-                    debug("proxy handshake successful (fd=%d)" % self.fd())
+                    debug("proxy handshake successful (fd=%d)" % self.fileno())
                     self.state = self.CONNECTED
-                    self.inbuf = b''
-                    debug("successfully connected (fd=%d)" % self.fd())
-                    # if we have no reps or no data, skip sending actual data
-                    if self.want_to_write():
-                        return 1    # Keep us around for writing.
-                    else:
-                        # shortcut write when we don't ever expect any data
-                        debug("no connection required - no data")
-                        return 0
+                    debug("successfully connected (fd=%d)" % self.fileno())
+                    self.inbuf = self.inbuf[8:]
+                    self.push_output()
                 else:
                     debug("proxy handshake failed (0x%x)! (fd=%d)" %
-                          (ord(self.inbuf[1]), self.fd()))
+                          (ord(self.inbuf[1]), self.fileno()))
                     self.state = self.NOT_CONNECTED
-                    return -1
-            assert(8 - len(self.inbuf) > 0)
-            return 8 - len(self.inbuf)
-        return self.want_to_write()  # Keep us around for writing if needed
-
-    def want_to_write(self):
-        if self.state == self.CONNECTING:
-            return True
-        if len(self.outbuf) > 0:
-            return True
-        if (self.state == self.CONNECTED and
-            self.repetitions > 0 and
-            len(self.data) > 0):
-            return True
-        return False
-
-    def on_writable(self):
-        """Invoked when the socket becomes writable.
-        Return 0 when done writing
-               -1 on failure (like connection refused)
-               >0 if more data needs to be written
-        """
-        if self.state == self.CONNECTING:
-            if self.proxy is None:
-                self.state = self.CONNECTED
-                debug("successfully connected (fd=%d)" % self.fd())
-            else:
-                self.state = self.CONNECTING_THROUGH_PROXY
-                self.outbuf = socks_cmd(self.dest)
-                # we write socks_cmd() to the proxy, then read the response
-                # if we get the correct response, we're CONNECTED
-        if self.state == self.CONNECTED:
-            # repeat self.data into self.outbuf if required
-            if (len(self.outbuf) < len(self.data) and self.repetitions > 0):
-                self.outbuf += self.data
-                self.repetitions -= 1
-                debug("adding more data to send (bytes=%d)" % len(self.data))
-                debug("now have data to send (bytes=%d)" % len(self.outbuf))
-                debug("send repetitions remaining (reps=%d)"
-                      % self.repetitions)
-        try:
-            n = self.s.send(self.outbuf)
-        except socket.error as e:
-            if e.errno == errno.ECONNREFUSED:
-                debug("connection refused (fd=%d)" % self.fd())
-                return -1
-            raise
-        # sometimes, this debug statement prints 0
-        # it should print length of the data sent
-        # but the code works as long as this doesn't keep on happening
-        if n > 0:
-            debug("successfully sent (bytes=%d)" % n)
-            self._sent_no_bytes = 0
-        else:
-            debug("BUG: sent no bytes (out of %d; state is %s)"% (len(self.outbuf), self.state))
-            self._sent_no_bytes += 1
-            # We can't retry too fast, otherwise clients burn all their HSDirs
-            if self._sent_no_bytes >= 2:
-                print("Sent no data %d times. Stalled." %
-                      (self._sent_no_bytes))
-                return -1
-            time.sleep(5)
-        self.outbuf = self.outbuf[n:]
-        if self.state == self.CONNECTING_THROUGH_PROXY:
-            return 1  # Keep us around.
-        debug("bytes remaining on outbuf (bytes=%d)" % len(self.outbuf))
-        # calculate the actual length of data remaining, including reps
-        # When 0, we're being removed.
-        debug("bytes remaining overall (bytes=%d)"
-              % (self.repetitions*len(self.data) + len(self.outbuf)))
-        return self.repetitions*len(self.data) + len(self.outbuf)
+                    self.close()
+
+    def push_output(self):
+        for _ in range(self.repetitions):
+            self.push_with_producer(asynchat.simple_producer(self.data))
 
+        self.push_with_producer(CloseSourceProducer(self))
+        self.close_when_done()
 
-class TrafficTester():
+    def fileno(self):
+        return self.socket.fileno()
 
+class TrafficTester(object):
     """
     Hang on select.select() and dispatch to Sources and Sinks.
     Time out after self.timeout seconds.
@@ -332,7 +245,7 @@ class TrafficTester():
 
     def __init__(self,
                  endpoint,
-                 data={},
+                 data=b"",
                  timeout=3,
                  repetitions=1,
                  dot_repetitions=0):
@@ -346,87 +259,44 @@ class TrafficTester():
         if len(self.data) == 0:
             self.repetitions = 0
         if self.repetitions == 0:
-            self.data = {}
+            self.data = b""
         self.dot_repetitions = dot_repetitions
-        debug("listener fd=%d" % self.listener.fd())
-        self.peers = {}  # fd:Peer
+        debug("listener fd=%d" % self.listener.fileno())
 
-    def sinks(self):
-        return self.get_by_ptype(Peer.SINK)
+    def add(self, item):
+        """Register a single item as a test."""
+        # We used to hold on to these items for their fds, but now
+        # asyncore manages them for us.
 
-    def sources(self):
-        return self.get_by_ptype(Peer.SOURCE)
+        self.tests.add()
 
-    def get_by_ptype(self, ptype):
-        return list(filter(lambda p: p.type == ptype, self.peers.values()))
-
-    def add(self, peer):
-        self.peers[peer.fd()] = peer
-        if peer.is_source():
-            self.tests.add()
+    def success(self):
+        """Declare that a single test has passed."""
+        self.tests.success()
 
-    def remove(self, peer):
-        self.peers.pop(peer.fd())
-        self.pending_close.append(peer.s)
+    def failure(self):
+        """Declare that a single test has failed."""
+        self.tests.failure()
 
     def run(self):
-        while not self.tests.all_done() and self.timeout > 0:
-            rset = [self.listener.fd()] + list(self.peers)
-            wset = [p.fd() for p in
-                    filter(lambda x: x.want_to_write(), self.sources())]
-            # debug("rset %s wset %s" % (rset, wset))
-            sets = select.select(rset, wset, [], 1)
-            if all(len(s) == 0 for s in sets):
-                debug("Decrementing timeout.")
-                self.timeout -= 1
-                continue
-
-            for fd in sets[0]:  # readable fd's
-                if fd == self.listener.fd():
-                    self.listener.accept()
-                    continue
-                p = self.peers[fd]
-                n = p.on_readable()
-                debug("On read, fd %d for %s said %d"%(fd, p, n))
-                if n > 0:
-                    # debug("need %d more octets from fd %d" % (n, fd))
-                    pass
-                elif n == 0:  # Success.
-                    self.tests.success()
-                    self.remove(p)
-                else:       # Failure.
-                    debug("Got a failure reading fd %d for %s" % (fd,p))
-                    self.tests.failure()
-                    if p.is_sink():
-                        print("verification failed!")
-                    self.remove(p)
-
-            for fd in sets[1]:  # writable fd's
-                p = self.peers.get(fd)
-                if p is not None:  # Might have been removed above.
-                    n = p.on_writable()
-                    debug("On write, fd %d said %d"%(fd, n))
-                    if n == 0:
-                        self.remove(p)
-                    elif n < 0:
-                        debug("Got a failure writing fd %d for %s" % (fd,p))
-                        self.tests.failure()
-                        self.remove(p)
-
-        for fd in self.peers:
-            peer = self.peers[fd]
-            debug("peer fd=%d never pending close, never read or wrote" % fd)
-            self.pending_close.append(peer.s)
-        self.listener.s.close()
-        for s in self.pending_close:
-            s.close()
+        start = now = time.time()
+        end = time.time() + self.timeout
+        while now < end and not self.tests.all_done():
+            # run only one iteration at a time, with a nice short timeout, so we
+            # can actually detect completion and timeouts.
+            asyncore.loop(0.2, False, None, 1)
+            now = time.time()
+            debug("Test status: %s"%self.tests.status())
+
         if not debug_flag:
             sys.stdout.write('\n')
             sys.stdout.flush()
         debug("Done with run(); all_done == %s and failure_count == %s"
               %(self.tests.all_done(), self.tests.failure_count()))
-        return self.tests.all_done() and self.tests.failure_count() == 0
 
+        self.listener.close()
+
+        return self.tests.all_done() and self.tests.failure_count() == 0
 
 def main():
     """Test the TrafficTester by sending and receiving some data."""