瀏覽代碼

Start creating the circuit

The first CREATE/CREATED pair (to the guard) is working, and both sides
end up with the same shared secret.
Ian Goldberg 5 年之前
父節點
當前提交
16793ed8fa
共有 2 個文件被更改,包括 127 次插入34 次删除
  1. 76 10
      client.py
  2. 51 24
      relay.py

+ 76 - 10
client.py

@@ -7,12 +7,28 @@ import network
 import dirauth
 import relay
 
+class CreatedHandler:
+    """A handler for VanillaCreatedCircuitMsg cells."""
+
+    def __init__(self, ntor, expecteddesc):
+        self.ntor = ntor
+        self.expecteddesc = expecteddesc
+        self.onionkey = expecteddesc.descdict['onionkey']
+        self.idkey = expecteddesc.descdict['idkey']
+
+    def received_cell(self, circhandler, cell):
+        secret = self.ntor.verify(cell.ntor_reply, self.onionkey, self.idkey)
+        circhandler.circuit_descs.append(self.expecteddesc)
+        print('client secret=',secret)
+
+
 class CellClient(relay.CellHandler):
     """The subclass of CellHandler for clients."""
 
     def __init__(self, myaddr, dirauthaddrs, perfstats):
         super().__init__(myaddr, dirauthaddrs, perfstats)
         self.guardaddr = None
+        self.guard = None
         if network.thenetwork.womode == network.WOMode.VANILLA:
             self.consensus_cdf = []
 
@@ -27,8 +43,8 @@ class CellClient(relay.CellHandler):
         while True:
             if self.guardaddr is None:
                 # Pick a guard from the consensus
-                guard = self.consensus.select_weighted_relay(self.consensus_cdf)
-                self.guardaddr = guard.descdict['addr']
+                self.guard = self.consensus.select_weighted_relay(self.consensus_cdf)
+                self.guardaddr = self.guard.descdict['addr']
 
             # Connect to the guard
             try:
@@ -36,6 +52,7 @@ class CellClient(relay.CellHandler):
             except network.NetNoServer:
                 # Our guard is gone
                 self.guardaddr = None
+                self.guard = None
 
             if self.guardaddr is not None:
                 break
@@ -51,28 +68,62 @@ class CellClient(relay.CellHandler):
         """Create a new circuit from this client. (Vanilla Onion Routing
         version)"""
 
+        # Get our channel to the guard
+        guardchannel = self.get_channel_to(self.guardaddr)
+
+        # Allocate a new circuit id on it
+        circid, circhandler = guardchannel.new_circuit()
+
+        # Construct the VanillaCreateCircuitMsg
+        ntor = relay.NTor(self.perfstats)
+        ntor_request = ntor.request()
+        circcreatemsg = relay.VanillaCreateCircuitMsg(circid, ntor_request)
+
+        # Set up the reply handler
+        circhandler.cell_dispatch_table[relay.VanillaCreatedCircuitMsg] = \
+                CreatedHandler(ntor, self.guard)
+
+        # Send the message
+        guardchannel.send_msg(circcreatemsg)
+
+        # We have a guard already at this point, so choose a middle and
+        # an exit.  They must all be different.
+        middle = None
+        while middle is None:
+            middle = self.consensus.select_weighted_relay(self.consensus_cdf)
+            if middle.descdict['addr'] == self.guardaddr:
+                middle = None
+
+        exit = None
+        while exit is None:
+            exit = self.consensus.select_weighted_relay(self.consensus_cdf)
+            if exit.descdict['addr'] == self.guardaddr or \
+                    exit.descdict['addr'] == middle.descdict['addr']:
+                middle = None
+
+
     def new_circuit(self):
         """Create a new circuit from this client."""
         if network.thenetwork.womode == network.WOMode.VANILLA:
             self.new_circuit_vanilla()
 
-    def received_msg(self, msg, peeraddr, peer):
+    def received_msg(self, msg, peeraddr, channel):
         """Callback when a NetMsg not specific to a circuit is
         received."""
         print("Client %s received msg %s from %s" % (self.myaddr, msg, peeraddr))
         if isinstance(msg, relay.RelayConsensusMsg):
-            self.consensus = msg.consensus
-            dirauth.Consensus.verify(self.consensus, \
+            dirauth.Consensus.verify(msg.consensus, \
                     network.thenetwork.dirauthkeys(), self.perfstats)
+            self.consensus = msg.consensus
             if network.thenetwork.womode == network.WOMode.VANILLA:
                 self.consensus_cdf = self.consensus.bw_cdf()
         else:
-            return super().received_msg(msg, peeraddr, peer)
+            return super().received_msg(msg, peeraddr, channel)
 
-    def received_cell(self, circid, cell, peeraddr, peer):
+    def received_cell(self, circid, cell, peeraddr, channel):
         """Callback with a circuit-specific cell is received."""
         print("Client %s received cell on circ %d: %s from %s" % (self.myaddr, circid, cell, peeraddr))
-        return super().received_cell(circid, cell, peeraddr, peer)
+        return super().received_cell(circid, cell, peeraddr, channel)
 
 
 class Client:
@@ -216,9 +267,22 @@ if __name__ == '__main__':
     # Tick the epoch
     network.thenetwork.nextepoch()
 
+    clients[0].cellhandler.new_circuit()
+
+    # See what channels exist and do a consistency check
+    for r in relays:
+        print("%s: %s" % (r.netaddr, [ str(k) + str([ck for ck in r.cellhandler.channels[k].circuithandlers.keys()]) for k in r.cellhandler.channels.keys()]))
+        raddr = r.netaddr
+        for ad, ch in r.cellhandler.channels.items():
+            if ch.peer.cellhandler.myaddr != ad:
+                print('address mismatch:', raddr, ad, ch.peer.cellhandler.myaddr)
+
+            if ch.peer.cellhandler.channels[raddr].peer is not ch:
+                print('asymmetry:', raddr, ad, ch, ch.peer.cellhandler.channels[raddr].peer)
+
     # See what channels exist and do a consistency check
     for c in clients:
-        print("%s: %s" % (c.netaddr, [ str(k) for k in c.cellhandler.channels.keys()]))
+        print("%s: %s" % (c.netaddr, [ str(k) + str([ck for ck in c.cellhandler.channels[k].circuithandlers.keys()]) for k in c.cellhandler.channels.keys()]))
         caddr = c.netaddr
         for ad, ch in c.cellhandler.channels.items():
             if ch.peer.cellhandler.myaddr != ad:
@@ -227,7 +291,9 @@ if __name__ == '__main__':
             if ch.peer.cellhandler.channels[caddr].peer is not ch:
                 print('asymmetry:', caddr, ad, ch, ch.peer.cellhandler.channels[caddr].peer)
 
-    clients[0].cellhandler.new_circuit()
+            if ch.circuithandlers.keys() != \
+                    ch.peer.cellhandler.channels[caddr].circuithandlers.keys():
+                print('circuit asymmetry:', caddr, ad, ch.peer.cellhandler.myaddr)
 
     for d in dirauths:
         print(d.perfstats)

+ 51 - 24
relay.py

@@ -51,9 +51,8 @@ class VanillaCreatedCircuitMsg(RelayNetMsg):
     """The message for responding to circuit creation in Vanilla Onion
     Routing."""
 
-    def __init__(self, circid, ntor_response):
-        self.circid = circid
-        self.ntor_response = ntor_response
+    def __init__(self, ntor_reply):
+        self.ntor_reply = ntor_reply
 
 
 class CircuitCellMsg(RelayNetMsg):
@@ -65,6 +64,10 @@ class CircuitCellMsg(RelayNetMsg):
     def __str__(self):
         return "C%d:%s" % (self.circid, self.cell)
 
+    def size(self):
+        # circuitids are 4 bytes
+        return 4 + self.cell.size()
+
 
 class RelayFallbackTerminationError(Exception):
     """An exception raised when someone tries to terminate a fallback
@@ -134,16 +137,23 @@ class CircuitHandler:
         self.channel = channel
         self.circid = circid
         self.send_cell = self.channel_send_cell
-        self.received_cell = self.channel_received_cell
+        # The list of relay descriptors that form the circuit so far
+        # (client side only)
+        self.circuit_descs = []
+        # The dispatch table is indexed by type, and the values are
+        # objects with received_cell(circhandler, cell) methods.
+        self.cell_dispatch_table = dict()
 
     def channel_send_cell(self, cell):
         """Send a cell on this circuit."""
         self.channel.send_msg(CircuitCellMsg(self.circid, cell))
 
-    def channel_received_cell(self, cell, peeraddr, peer):
-        """A cell has been received on this circuit.  Forward it to the
-        channel's received_cell callback."""
-        self.channel.cellhandler.received_cell(self.circid, cell, peeraddr, peer)
+    def received_cell(self, cell):
+        """A cell has been received on this circuit.  Dispatch it
+        according to its type."""
+        celltype = type(cell)
+        if celltype in self.cell_dispatch_table:
+            self.cell_dispatch_table[celltype].received_cell(self, cell)
 
 
 class Channel(network.Connection):
@@ -176,16 +186,19 @@ class Channel(network.Connection):
 
     def new_circuit(self):
         """Allocate a new circuit on this channel, returning the new
-        circuit's id."""
+        circuit's id and the new CircuitHandler."""
         circid = self.next_circid
         self.next_circid += 2
-        self.circuithandlers[circid] = CircuitHandler(self, circid)
-        return circid
+        circuithandler = CircuitHandler(self, circid)
+        self.circuithandlers[circid] = circuithandler
+        return circid, circuithandler
 
     def new_circuit_with_circid(self, circid):
         """Allocate a new circuit on this channel, with the circuit id
-        received from our peer."""
-        self.circuithandlers[circid] = CircuitHandler(self, circid)
+        received from our peer.  Return the new CircuitHandler"""
+        circuithandler = CircuitHandler(self, circid)
+        self.circuithandlers[circid] = circuithandler
+        return circuithandler
 
     def send_cell(self, circid, cell):
         """Send the given message on the given circuit, encrypting or
@@ -207,9 +220,9 @@ class Channel(network.Connection):
         self.cellhandler.perfstats.bytes_received += msg.size()
         if isinstance(msg, CircuitCellMsg):
             circid, cell = msg.circid, msg.cell
-            self.circuithandlers[circid].received_cell(cell, peeraddr, self.peer)
+            self.circuithandlers[circid].received_cell(cell)
         else:
-            self.cellhandler.received_msg(msg, peeraddr, self.peer)
+            self.cellhandler.received_msg(msg, peeraddr, self)
 
 
 class CellHandler:
@@ -260,12 +273,12 @@ class CellHandler:
 
         return newchannel
 
-    def received_msg(self, msg, peeraddr, peer):
+    def received_msg(self, msg, peeraddr, channel):
         """Callback when a NetMsg not specific to a circuit is
         received."""
         print("CellHandler: Node %s received msg %s from %s" % (self.myaddr, msg, peeraddr))
 
-    def received_cell(self, circid, cell, peeraddr, peer):
+    def received_cell(self, circid, cell, peeraddr, channel):
         """Callback with a circuit-specific cell is received."""
         print("CellHandler: Node %s received cell on circ %d: %s from %s" % (self.myaddr, circid, cell, peeraddr))
 
@@ -284,8 +297,10 @@ class CellHandler:
 class CellRelay(CellHandler):
     """The subclass of CellHandler for relays."""
 
-    def __init__(self, myaddr, dirauthaddrs, perfstats):
+    def __init__(self, myaddr, dirauthaddrs, onionprivkey, idpubkey, perfstats):
         super().__init__(myaddr, dirauthaddrs, perfstats)
+        self.onionkey = onionprivkey
+        self.idpubkey = idpubkey
 
     def get_consensus(self):
         """Download a fresh consensus from a random dirauth."""
@@ -296,7 +311,7 @@ class CellRelay(CellHandler):
                 network.thenetwork.dirauthkeys(), self.perfstats)
         c.close()
 
-    def received_msg(self, msg, peeraddr, peer):
+    def received_msg(self, msg, peeraddr, channel):
         """Callback when a NetMsg not specific to a circuit is
         received."""
         print("CellRelay: Node %s received msg %s from %s" % (self.myaddr, msg, peeraddr))
@@ -308,13 +323,24 @@ class CellRelay(CellHandler):
                 self.send_msg(RelayRandomHopMsg(msg.ttl-1), nextaddr)
         elif isinstance(msg, RelayGetConsensusMsg):
             self.send_msg(RelayConsensusMsg(self.consensus), peeraddr)
+        elif isinstance(msg, VanillaCreateCircuitMsg):
+            # A new circuit has arrived
+            circhandler = channel.new_circuit_with_circid(msg.circid)
+            # Create the ntor reply
+            reply, secret = NTor.reply(self.onionkey, self.idpubkey, \
+                    msg.ntor_request, self.perfstats)
+            # Set up the circuit to use the shared secret
+            # TODO
+            print('relay secret=', secret)
+            # Send the ntor reply
+            self.send_msg(CircuitCellMsg(msg.circid, VanillaCreatedCircuitMsg(reply)), peeraddr)
         else:
-            return super().received_msg(msg, peeraddr, peer)
+            return super().received_msg(msg, peeraddr, channel)
 
-    def received_cell(self, circid, cell, peeraddr, peer):
+    def received_cell(self, circid, cell, peeraddr, channel):
         """Callback with a circuit-specific cell is received."""
         print("CellRelay: Node %s received cell on circ %d: %s from %s" % (self.myaddr, circid, cell, peeraddr))
-        return super().received_cell(circid, cell, peeraddr, peer)
+        return super().received_cell(circid, cell, peeraddr, channel)
 
 
 class Relay(network.Server):
@@ -344,7 +370,8 @@ class Relay(network.Server):
         network.thenetwork.wantepochticks(self, True)
 
         # Create the CellRelay connection manager
-        self.cellhandler = CellRelay(self.netaddr, dirauthaddrs, self.perfstats)
+        self.cellhandler = CellRelay(self.netaddr, dirauthaddrs, \
+                self.onionkey, self.idkey.verify_key, self.perfstats)
 
         # Initially, we're not a fallback relay
         self.is_fallbackrelay = False
@@ -515,7 +542,7 @@ if __name__ == '__main__':
                 print('asymmetry:', raddr, ad, ch, ch.peer.cellhandler.channels[raddr].peer)
 
     channel = relays[3].cellhandler.get_channel_to(relays[5].netaddr)
-    circid = channel.new_circuit()
+    circid, circhandler = channel.new_circuit()
     peerchannel = relays[5].cellhandler.get_channel_to(relays[3].netaddr)
     peerchannel.new_circuit_with_circid(circid)
     relays[3].cellhandler.send_cell(circid, network.StringNetMsg("test"), relays[5].netaddr)