#!/usr/bin/python
# Copyright 2005-2006 Nick Mathewson
# See the LICENSE file in the Tor distribution for licensing information.

# Requires Python 2.2 or later.

"""
 exitlist -- Given a Tor directory on stdin, lists the Tor servers
 that accept connections to given addresses.

 example usage:

    cat ~/.tor/cached-descriptors* | python exitlist 18.244.0.188:80

 You should look at the "FetchUselessDescriptors" and "FetchDirInfoEarly"
 config options in the man page.

 Note that this script won't give you a perfect list of IP addresses
 that might connect to you using Tor.
 False negatives:
   - Some Tor servers might exit from other addresses than the one they
     publish in their descriptor.
 False positives:
   - This script just looks at the descriptor lists, so it counts relays
     that were running a day in the past and aren't running now (or are
     now running at a different address).

 See https://check.torproject.org/ for an alternative (more accurate!)
  approach.

"""

#
# Change this to True if you want more verbose output.  By default, we
# only print the IPs of the servers that accept any of the listed
# addresses, one per line.  With verbose output, each line holds the IP,
# a tab, and the server's nickname.  (The -v command-line option sets
# this too.)
#
VERBOSE = False

#
# Change this to True if you want to reverse the output, and list the
# servers that accept *none* of the listed addresses.  (The -x
# command-line option sets this too.)
#
INVERSE = False

#
# Change this list to contain all of the target services you are interested
# in.  It must contain one entry per line, each consisting of an IPv4 address,
# a colon, and a port number.  Blank lines and surrounding whitespace are
# ignored.  This default is only used if we don't learn about any addresses
# from the command-line.
#
ADDRESSES_OF_INTEREST = """
    1.2.3.4:80
"""


#
# YOU DO NOT NEED TO EDIT AFTER THIS POINT.
#

import sys
import re
import getopt
import socket
import struct
import time

# Refuse to run on interpreters older than the stated minimum (Python 2.2).
# NOTE(review): assert statements are skipped under "python -O", so this
# check is best-effort only.
assert sys.version_info >= (2,2)


def maskIP(ip,mask):
    """Return the character-by-character bitwise AND of `ip` and `mask`.

    Both arguments are strings of raw characters (for example the output
    of socket.inet_aton).  The result has one character per input pair, so
    it is as long as the shorter of the two arguments.
    """
    masked = []
    for ipChar, maskChar in zip(ip, mask):
        masked.append(chr(ord(ipChar) & ord(maskChar)))
    return "".join(masked)

def maskFromLong(lng):
    """Pack the integer `lng` into a 4-byte string.

    The "!L" format means an unsigned 32-bit value in network byte order,
    which is the same layout socket.inet_aton produces for an address.
    """
    packed = struct.pack("!L", lng)
    return packed

def maskByBits(n):
    return maskFromLong(0xffffffffl ^ ((1L<<(32-n))-1))

class Pattern:
    """
    An address and port-range pattern of the form ADDR[/MASK][:PORTS].

    ADDR is a dotted-quad IPv4 address or "*"; MASK is either a dotted-quad
    netmask or a number of leading bits; PORTS is "*", a single port, or a
    MIN-MAX range.  A missing MASK means "this exact address" and a missing
    PORTS means "every port" (1-65535).

    >>> import socket
    >>> ip1 = socket.inet_aton("192.169.64.11")
    >>> ip2 = socket.inet_aton("192.168.64.11")
    >>> ip3 = socket.inet_aton("18.244.0.188")

    >>> print Pattern.parse("18.244.0.188")
    18.244.0.188/255.255.255.255:1-65535
    >>> print Pattern.parse("18.244.0.188/16:*")
    18.244.0.0/255.255.0.0:1-65535
    >>> print Pattern.parse("18.244.0.188/2.2.2.2:80")
    2.0.0.0/2.2.2.2:80-80
    >>> print Pattern.parse("192.168.0.1/255.255.00.0:22-25")
    192.168.0.0/255.255.0.0:22-25
    >>> p1 = Pattern.parse("192.168.0.1/255.255.00.0:22-25")
    >>> p1.appliesTo(ip1, 22)
    False
    >>> p1.appliesTo(ip2, 22)
    True
    >>> p1.appliesTo(ip2, 25)
    True
    >>> p1.appliesTo(ip2, 26)
    False
    """
    def __init__(self, ip, mask, portMin, portMax):
        """Build a pattern from a packed address, a packed netmask and an
        inclusive port range.  The address is stored with the mask already
        applied, so appliesTo() can compare masked values directly."""
        self.mask = mask
        self.ip = maskIP(ip,mask)
        self.portMin, self.portMax = portMin, portMax

    def __str__(self):
        """Render as "ip/mask:minport-maxport" with dotted-quad addresses."""
        fields = (socket.inet_ntoa(self.ip), socket.inet_ntoa(self.mask),
                  self.portMin, self.portMax)
        return "%s/%s:%s-%s" % fields

    def parse(s):
        """Parse the string `s` (see the class docstring for the syntax)
        and return the corresponding Pattern.  Malformed input surfaces as
        whatever socket.inet_aton() or int() raises for it."""
        # Everything after the first colon is the port part; with no colon
        # the pattern covers every port.
        colon = s.find(":")
        if colon >= 0:
            addrspec, portspec = s[:colon], s[colon+1:]
        else:
            addrspec, portspec = s, "*"

        if addrspec == '*':
            # An all-zero mask makes every address match.
            ip = mask = "\x00\x00\x00\x00"
        elif '/' in addrspec:
            ipPart, maskPart = addrspec.split("/",1)
            ip = socket.inet_aton(ipPart)
            if "." in maskPart:
                # Dotted-quad netmask.
                mask = socket.inet_aton(maskPart)
            else:
                # Number of leading bits.
                mask = maskByBits(int(maskPart))
        else:
            # A bare address matches only itself.
            ip = socket.inet_aton(addrspec)
            mask = "\xff\xff\xff\xff"

        if portspec == '*':
            portMin, portMax = 1, 65535
        elif '-' in portspec:
            low, high = portspec.split("-",1)
            portMin, portMax = int(low), int(high)
        else:
            portMin = portMax = int(portspec)

        return Pattern(ip,mask,portMin,portMax)

    parse = staticmethod(parse)

    def appliesTo(self, ip, port):
        """Return true if the packed address `ip` falls inside this
        pattern's network and `port` lies in its inclusive port range."""
        if maskIP(ip,self.mask) != self.ip:
            return False
        return self.portMin <= port <= self.portMax

class Policy:
    """
    An ordered list of accept/reject rules.  The first rule whose pattern
    applies to an address and port decides the outcome; when no rule
    applies, the connection is accepted.

    >>> import socket
    >>> ip1 = socket.inet_aton("192.169.64.11")
    >>> ip2 = socket.inet_aton("192.168.64.11")
    >>> ip3 = socket.inet_aton("18.244.0.188")

    >>> pol = Policy.parseLines(["reject *:80","accept 18.244.0.188:*"])
    >>> print str(pol).strip()
    reject 0.0.0.0/0.0.0.0:80-80
    accept 18.244.0.188/255.255.255.255:1-65535
    >>> pol.accepts(ip1,80)
    False
    >>> pol.accepts(ip3,80)
    False
    >>> pol.accepts(ip3,81)
    True
    """

    def __init__(self, lst):
        """`lst` is a list of (pattern, accept) pairs in order of
        precedence, where `accept` is true for accept rules and false for
        reject rules."""
        self.lst = lst

    def parseLines(lines):
        """Build a Policy from a sequence of "accept PATTERN" and
        "reject PATTERN" strings.

        Raises ValueError if a line has no pattern after the action or if
        the action is neither "accept" nor "reject".
        """
        r = []
        for item in lines:
            parts = item.split(" ",1)
            if len(parts) != 2:
                # Report the offending line instead of failing with an
                # opaque tuple-unpacking error.
                raise ValueError("Malformed policy line %r" % (item,))
            a,p = parts
            if a == 'accept':
                a = True
            elif a == 'reject':
                a = False
            else:
                # Apply the format with %; passing the value as a second
                # argument to ValueError leaves the message unformatted.
                raise ValueError("Unrecognized action %r" % (a,))
            p = Pattern.parse(p)
            r.append((p,a))
        return Policy(r)

    parseLines = staticmethod(parseLines)

    def __str__(self):
        """Render one "accept PATTERN" or "reject PATTERN" line per rule,
        each terminated by a newline."""
        r = []
        for pat, accept in self.lst:
            if accept:
                rule = "accept"
            else:
                rule = "reject"
            r.append("%s %s\n"%(rule,pat))
        return "".join(r)

    def accepts(self, ip, port):
        """Return the verdict of the first rule that applies to the packed
        address `ip` and `port`, or True if no rule applies."""
        for pattern,accept in self.lst:
            if pattern.appliesTo(ip,port):
                return accept
        # No rule matched: accept by default.
        return True

class Server:
    """Plain record holding what was read from one router descriptor."""
    def __init__(self, name, ip, policy, published, fingerprint):
        """Store the given values unchanged as attributes of the same
        names: name, ip, policy, published and fingerprint."""
        # Who the server is.
        self.name, self.ip = name, ip
        self.fingerprint = fingerprint
        # What it allows, and how recent this information is.
        self.policy = policy
        self.published = published

def uniq_sort(lst):
    """Return a new sorted list of the distinct items in `lst`.

    Items must be hashable and mutually comparable.  The input sequence is
    not modified.
    """
    d = {}
    for item in lst:
        d[item] = 1
    # Copy the keys with list(d) rather than d.keys(): keys() does not
    # return a sortable list on every Python version.
    uniq = list(d)
    uniq.sort()
    return uniq

def run():
    """
    Read router descriptors from stdin and print, one per line, the IPs of
    the servers whose policy accepts at least one address of interest (or,
    when INVERSE is set, the IPs of the servers that accept none of them).

    The command-line options -v and -x set VERBOSE and INVERSE; positional
    host:port arguments replace ADDRESSES_OF_INTEREST.  An unrecognized
    option prints a usage message to stderr and exits with status 2.
    """
    global VERBOSE
    global INVERSE
    global ADDRESSES_OF_INTEREST

    if len(sys.argv) > 1:
        try:
            opts, pargs = getopt.getopt(sys.argv[1:], "vx")
        except getopt.GetoptError:
            usage = """
usage: cat ~/.tor/cached-routers* | %s [-v] [-x] [host:port [host:port [...]]]
    -v  verbose output
    -x  invert results
""" % sys.argv[0]
            # A usage error is a failure: report it on stderr and exit
            # with a non-zero status so that callers can detect it.
            sys.stderr.write(usage + "\n")
            sys.exit(2)

        for o, a in opts:
            if o == "-v":
                VERBOSE = True
            if o == "-x":
                INVERSE = True
        if pargs:
            ADDRESSES_OF_INTEREST = "\n".join(pargs)

    # Pass 1: collect one Server per "router" entry.  The fields of the
    # entry being read are accumulated in these locals and flushed into
    # `servers` when the next "router" line (or the end of input) is seen.
    servers = []
    policy = []
    name = ip = None
    published = 0
    fp = ""
    for line in sys.stdin:
        if line.startswith('router '):
            if name:
                servers.append(Server(name, ip, Policy.parseLines(policy),
                                      published, fp))
            # The nickname and address are the second and third fields of
            # the "router" line; the remaining fields are not used.
            _, name, ip, rest = line.split(" ", 3)
            policy = []
            published = 0
            fp = ""
        elif line.startswith('fingerprint') or \
                 line.startswith('opt fingerprint'):
            elts = line.strip().split()
            if elts[0] == 'opt': del elts[0]
            assert elts[0] == 'fingerprint'
            del elts[0]
            # The fingerprint is written in space-separated groups; join
            # them into a single string.
            fp = "".join(elts)
        elif line.startswith('accept ') or line.startswith('reject '):
            policy.append(line.strip())
        elif line.startswith('published '):
            date = time.strptime(line[len('published '):].strip(),
                                 "%Y-%m-%d %H:%M:%S")
            # Only used below to order descriptors relative to each other.
            published = time.mktime(date)

    if name:
        servers.append(Server(name, ip, Policy.parseLines(policy), published,
                              fp))

    # Each target is a (packed ip, port) pair.  For a target written as a
    # port range, only the lowest port of the range is tested.
    targets = []
    for line in ADDRESSES_OF_INTEREST.split("\n"):
        line = line.strip()
        if not line: continue
        p = Pattern.parse(line)
        targets.append((p.ip, p.portMin))

    # Keep only the most recently published descriptor of each server.
    # Servers are identified by fingerprint; a descriptor without one is
    # identified by its (nickname, address) pair instead, so that such
    # descriptors are not all merged into a single entry.
    latest = {}
    for s in servers:
        key = s.fingerprint or (s.name, s.ip)
        if key not in latest or s.published > latest[key].published:
            latest[key] = s
    servers = latest.values()

    accepters, rejecters = {}, {}
    for s in servers:
        for ip,port in targets:
            if s.policy.accepts(ip,port):
                accepters[s.ip] = s
                break
        else:
            # The loop finished without a break: no target was accepted.
            rejecters[s.ip] = s

    # If any server at IP foo accepts, the IP does not reject.
    for k in accepters:
        if k in rejecters:
            del rejecters[k]

    if INVERSE:
        printlist = rejecters.values()
    else:
        printlist = accepters.values()

    if VERBOSE:
        ents = uniq_sort([ "%s\t%s"%(s.ip,s.name) for s in printlist ])
    else:
        ents = uniq_sort([ s.ip for s in printlist ])
    for e in ents:
        sys.stdout.write("%s\n" % e)

def _test():
    """Run the doctests embedded in this module and return doctest's
    result."""
    import doctest, sys
    # Test the module this function lives in instead of importing it under
    # a hard-coded name: that import fails unless the file is installed
    # under exactly that name, and it would run the module's top-level
    # code again.
    return doctest.testmod(sys.modules[__name__])
#_test()

# Only start processing when executed as a script, so that importing this
# module (for example to run its doctests) does not begin reading stdin.
if __name__ == "__main__":
    run()