Browse Source

Don't send nacl objects over the wire where reasonable

They pickle unnecessarily large
Ian Goldberg 4 years ago
parent
commit
2b0261df02
2 changed files with 27 additions and 13 deletions
  1. 4 3
      dirauth.py
  2. 23 10
      relay.py

+ 4 - 3
dirauth.py

@@ -30,7 +30,7 @@ class RelayDescriptor:
                     "vrfkey", "sig"]:
             if k in self.descdict:
                 if k == "idkey" or k == "onionkey":
-                    res += "  " + k + ": " + self.descdict[k].encode(encoder=nacl.encoding.HexEncoder).decode("ascii") + "\n"
+                    res += "  " + k + ": " + nacl.encoding.HexEncoder.encode(self.descdict[k]).decode("ascii") + "\n"
                 elif k == "sig":
                     if withsig:
                         res += "  " + k + ": " + nacl.encoding.HexEncoder.encode(self.descdict[k]).decode("ascii") + "\n"
@@ -50,7 +50,8 @@ class RelayDescriptor:
         assert(type(desc) is RelayDescriptor)
         serialized = desc.__str__(False)
         perfstats.verifs += 1
-        desc.descdict["idkey"].verify(serialized.encode("ascii"), desc.descdict["sig"])
+        idkey = nacl.signing.VerifyKey(desc.descdict["idkey"])
+        idkey.verify(serialized.encode("ascii"), desc.descdict["sig"])
 
 
 # A SNIP is a dict containing:
@@ -80,7 +81,7 @@ class SNIP:
                     "vrfkey", "range", "auth"]:
             if k in self.snipdict:
                 if k == "idkey" or k == "onionkey":
-                    res += "  " + k + ": " + self.snipdict[k].encode(encoder=nacl.encoding.HexEncoder).decode("ascii") + "\n"
+                    res += "  " + k + ": " + nacl.encoding.HexEncoder.encode(self.snipdict[k]).decode("ascii") + "\n"
                 elif k == "auth":
                     if withauth:
                         if network.thenetwork.snipauthmode == \

+ 23 - 10
relay.py

@@ -343,7 +343,7 @@ class NTor:
         """Create the ntor request message: X = g^x."""
         self.client_ephem_key = nacl.public.PrivateKey.generate()
         self.perfstats.keygens += 1
-        return self.client_ephem_key.public_key
+        return bytes(self.client_ephem_key.public_key)
 
     @staticmethod
     def reply(onion_privkey, idpubkey, client_pubkey, perfstats,
@@ -353,13 +353,17 @@ class NTor:
         secret S = H(M, "secret") for M = (X^y,X^b,ID,B,X,Y). If
         sphinx_domainsep is not None, also compute and return the Sphinx
         reblinded client request to pass to the next server."""
+        if type(idpubkey) is not bytes:
+            idpubkey = bytes(idpubkey)
+        if type(client_pubkey) is bytes:
+            client_pubkey = nacl.public.PublicKey(client_pubkey)
         server_ephem_key = nacl.public.PrivateKey.generate()
         perfstats.keygens += 1
         xykey = nacl.public.Box(server_ephem_key, client_pubkey).shared_key()
         xbkey = nacl.public.Box(onion_privkey, client_pubkey).shared_key()
         perfstats.dhs += 2
         M = xykey + xbkey + \
-                idpubkey.encode(encoder=nacl.encoding.RawEncoder) + \
+                idpubkey + \
                 onion_privkey.public_key.encode(encoder=nacl.encoding.RawEncoder) + \
                 server_ephem_key.public_key.encode(encoder=nacl.encoding.RawEncoder)
         A = nacl.hash.sha256(M + b'verify', encoder=nacl.encoding.RawEncoder)
@@ -369,11 +373,12 @@ class NTor:
             blindkey = Sphinx.makeblindkey(S, sphinx_domainsep, perfstats)
             blinded_client_pubkey = Sphinx.reblindpubkey(blindkey,
                     client_pubkey, perfstats)
-            return ((server_ephem_key.public_key, onion_privkey.public_key, A),
+            return ((bytes(server_ephem_key.public_key),
+                    bytes(onion_privkey.public_key), A),
                     S), blinded_client_pubkey
         else:
-            return ((server_ephem_key.public_key, onion_privkey.public_key, A),
-                    S)
+            return ((bytes(server_ephem_key.public_key),
+                    bytes(onion_privkey.public_key), A), S)
 
     def verify(self, reply, onion_pubkey, idpubkey, sphinx_domainsep=None):
         """The client calls this method to verify the ntor reply
@@ -383,6 +388,14 @@ class NTor:
         reuse this same NTor object for the next server.  Returns the
         shared secret on success, or raises ValueError on failure."""
         server_ephem_pubkey, server_onion_pubkey, authtag = reply
+        if type(idpubkey) is not bytes:
+            idpubkey = bytes(idpubkey)
+        if type(server_ephem_pubkey) is bytes:
+            server_ephem_pubkey = nacl.public.PublicKey(server_ephem_pubkey)
+        if type(server_onion_pubkey) is bytes:
+            server_onion_pubkey = nacl.public.PublicKey(server_onion_pubkey)
+        if type(onion_pubkey) is bytes:
+            onion_pubkey = nacl.public.PublicKey(onion_pubkey)
         if onion_pubkey != server_onion_pubkey:
             raise ValueError("NTor onion pubkey mismatch")
         # We use the blinding keys if present; if they're not present
@@ -403,7 +416,7 @@ class NTor:
                 reblinded_onion_pubkey).shared_key()
         self.perfstats.dhs += 2
         M = xykey + xbkey + \
-                idpubkey.encode(encoder=nacl.encoding.RawEncoder) + \
+                idpubkey + \
                 onion_pubkey.encode(encoder=nacl.encoding.RawEncoder) + \
                 server_ephem_pubkey.encode(encoder=nacl.encoding.RawEncoder)
         Acheck = nacl.hash.sha256(M + b'verify', encoder=nacl.encoding.RawEncoder)
@@ -456,7 +469,7 @@ class TelescopingExtendCircuitHandler:
 
     def __init__(self, relaypicker, current_relay_idkey):
         self.relaypicker = relaypicker
-        self.current_relay_idkey = current_relay_idkey
+        self.current_relay_idkey = bytes(current_relay_idkey)
 
     def received_cell(self, circhandler, cell):
         # Remove ourselves from handling a second
@@ -980,7 +993,7 @@ class RelayChannelManager(ChannelManager):
             if next_hop == None:
                 logging.debug("Client requested extending the circuit to a relay index that results in None, aborting. my circid: %s", str(circhandler.circid))
                 circhandler.close()
-            elif next_hop.snipdict["idkey"] == self.idpubkey or next_hop.snipdict["addr"] == peeraddr:
+            elif next_hop.snipdict["idkey"] == bytes(self.idpubkey) or next_hop.snipdict["addr"] == peeraddr:
                 logging.debug("Client requested extending the circuit to a relay already in the path; aborting. my circid: %s", str(circhandler.circid))
                 circhandler.close()
 
@@ -1092,8 +1105,8 @@ class Relay(network.Server):
         # previous upload if upload=False
         descdict = dict();
         descdict["epoch"] = network.thenetwork.getepoch() + 1
-        descdict["idkey"] = self.idkey.verify_key
-        descdict["onionkey"] = self.onionkey.public_key
+        descdict["idkey"] = bytes(self.idkey.verify_key)
+        descdict["onionkey"] = bytes(self.onionkey.public_key)
         descdict["addr"] = self.netaddr
         descdict["bw"] = self.bw
         descdict["flags"] = self.flags