Browse Source

Don't send nacl objects over the wire where reasonable

They pickle unnecessarily large
Ian Goldberg 4 years ago
parent
commit
2b0261df02
2 changed files with 27 additions and 13 deletions
  1. 4 3
      dirauth.py
  2. 23 10
      relay.py

+ 4 - 3
dirauth.py

@@ -30,7 +30,7 @@ class RelayDescriptor:
                     "vrfkey", "sig"]:
                     "vrfkey", "sig"]:
             if k in self.descdict:
             if k in self.descdict:
                 if k == "idkey" or k == "onionkey":
                 if k == "idkey" or k == "onionkey":
-                    res += "  " + k + ": " + self.descdict[k].encode(encoder=nacl.encoding.HexEncoder).decode("ascii") + "\n"
+                    res += "  " + k + ": " + nacl.encoding.HexEncoder.encode(self.descdict[k]).decode("ascii") + "\n"
                 elif k == "sig":
                 elif k == "sig":
                     if withsig:
                     if withsig:
                         res += "  " + k + ": " + nacl.encoding.HexEncoder.encode(self.descdict[k]).decode("ascii") + "\n"
                         res += "  " + k + ": " + nacl.encoding.HexEncoder.encode(self.descdict[k]).decode("ascii") + "\n"
@@ -50,7 +50,8 @@ class RelayDescriptor:
         assert(type(desc) is RelayDescriptor)
         assert(type(desc) is RelayDescriptor)
         serialized = desc.__str__(False)
         serialized = desc.__str__(False)
         perfstats.verifs += 1
         perfstats.verifs += 1
-        desc.descdict["idkey"].verify(serialized.encode("ascii"), desc.descdict["sig"])
+        idkey = nacl.signing.VerifyKey(desc.descdict["idkey"])
+        idkey.verify(serialized.encode("ascii"), desc.descdict["sig"])
 
 
 
 
 # A SNIP is a dict containing:
 # A SNIP is a dict containing:
@@ -80,7 +81,7 @@ class SNIP:
                     "vrfkey", "range", "auth"]:
                     "vrfkey", "range", "auth"]:
             if k in self.snipdict:
             if k in self.snipdict:
                 if k == "idkey" or k == "onionkey":
                 if k == "idkey" or k == "onionkey":
-                    res += "  " + k + ": " + self.snipdict[k].encode(encoder=nacl.encoding.HexEncoder).decode("ascii") + "\n"
+                    res += "  " + k + ": " + nacl.encoding.HexEncoder.encode(self.snipdict[k]).decode("ascii") + "\n"
                 elif k == "auth":
                 elif k == "auth":
                     if withauth:
                     if withauth:
                         if network.thenetwork.snipauthmode == \
                         if network.thenetwork.snipauthmode == \

+ 23 - 10
relay.py

@@ -343,7 +343,7 @@ class NTor:
         """Create the ntor request message: X = g^x."""
         """Create the ntor request message: X = g^x."""
         self.client_ephem_key = nacl.public.PrivateKey.generate()
         self.client_ephem_key = nacl.public.PrivateKey.generate()
         self.perfstats.keygens += 1
         self.perfstats.keygens += 1
-        return self.client_ephem_key.public_key
+        return bytes(self.client_ephem_key.public_key)
 
 
     @staticmethod
     @staticmethod
     def reply(onion_privkey, idpubkey, client_pubkey, perfstats,
     def reply(onion_privkey, idpubkey, client_pubkey, perfstats,
@@ -353,13 +353,17 @@ class NTor:
         secret S = H(M, "secret") for M = (X^y,X^b,ID,B,X,Y). If
         secret S = H(M, "secret") for M = (X^y,X^b,ID,B,X,Y). If
         sphinx_domainsep is not None, also compute and return the Sphinx
         sphinx_domainsep is not None, also compute and return the Sphinx
         reblinded client request to pass to the next server."""
         reblinded client request to pass to the next server."""
+        if type(idpubkey) is not bytes:
+            idpubkey = bytes(idpubkey)
+        if type(client_pubkey) is bytes:
+            client_pubkey = nacl.public.PublicKey(client_pubkey)
         server_ephem_key = nacl.public.PrivateKey.generate()
         server_ephem_key = nacl.public.PrivateKey.generate()
         perfstats.keygens += 1
         perfstats.keygens += 1
         xykey = nacl.public.Box(server_ephem_key, client_pubkey).shared_key()
         xykey = nacl.public.Box(server_ephem_key, client_pubkey).shared_key()
         xbkey = nacl.public.Box(onion_privkey, client_pubkey).shared_key()
         xbkey = nacl.public.Box(onion_privkey, client_pubkey).shared_key()
         perfstats.dhs += 2
         perfstats.dhs += 2
         M = xykey + xbkey + \
         M = xykey + xbkey + \
-                idpubkey.encode(encoder=nacl.encoding.RawEncoder) + \
+                idpubkey + \
                 onion_privkey.public_key.encode(encoder=nacl.encoding.RawEncoder) + \
                 onion_privkey.public_key.encode(encoder=nacl.encoding.RawEncoder) + \
                 server_ephem_key.public_key.encode(encoder=nacl.encoding.RawEncoder)
                 server_ephem_key.public_key.encode(encoder=nacl.encoding.RawEncoder)
         A = nacl.hash.sha256(M + b'verify', encoder=nacl.encoding.RawEncoder)
         A = nacl.hash.sha256(M + b'verify', encoder=nacl.encoding.RawEncoder)
@@ -369,11 +373,12 @@ class NTor:
             blindkey = Sphinx.makeblindkey(S, sphinx_domainsep, perfstats)
             blindkey = Sphinx.makeblindkey(S, sphinx_domainsep, perfstats)
             blinded_client_pubkey = Sphinx.reblindpubkey(blindkey,
             blinded_client_pubkey = Sphinx.reblindpubkey(blindkey,
                     client_pubkey, perfstats)
                     client_pubkey, perfstats)
-            return ((server_ephem_key.public_key, onion_privkey.public_key, A),
+            return ((bytes(server_ephem_key.public_key),
+                    bytes(onion_privkey.public_key), A),
                     S), blinded_client_pubkey
                     S), blinded_client_pubkey
         else:
         else:
-            return ((server_ephem_key.public_key, onion_privkey.public_key, A),
-                    S)
+            return ((bytes(server_ephem_key.public_key),
+                    bytes(onion_privkey.public_key), A), S)
 
 
     def verify(self, reply, onion_pubkey, idpubkey, sphinx_domainsep=None):
     def verify(self, reply, onion_pubkey, idpubkey, sphinx_domainsep=None):
         """The client calls this method to verify the ntor reply
         """The client calls this method to verify the ntor reply
@@ -383,6 +388,14 @@ class NTor:
         reuse this same NTor object for the next server.  Returns the
         reuse this same NTor object for the next server.  Returns the
         shared secret on success, or raises ValueError on failure."""
         shared secret on success, or raises ValueError on failure."""
         server_ephem_pubkey, server_onion_pubkey, authtag = reply
         server_ephem_pubkey, server_onion_pubkey, authtag = reply
+        if type(idpubkey) is not bytes:
+            idpubkey = bytes(idpubkey)
+        if type(server_ephem_pubkey) is bytes:
+            server_ephem_pubkey = nacl.public.PublicKey(server_ephem_pubkey)
+        if type(server_onion_pubkey) is bytes:
+            server_onion_pubkey = nacl.public.PublicKey(server_onion_pubkey)
+        if type(onion_pubkey) is bytes:
+            onion_pubkey = nacl.public.PublicKey(onion_pubkey)
         if onion_pubkey != server_onion_pubkey:
         if onion_pubkey != server_onion_pubkey:
             raise ValueError("NTor onion pubkey mismatch")
             raise ValueError("NTor onion pubkey mismatch")
         # We use the blinding keys if present; if they're not present
         # We use the blinding keys if present; if they're not present
@@ -403,7 +416,7 @@ class NTor:
                 reblinded_onion_pubkey).shared_key()
                 reblinded_onion_pubkey).shared_key()
         self.perfstats.dhs += 2
         self.perfstats.dhs += 2
         M = xykey + xbkey + \
         M = xykey + xbkey + \
-                idpubkey.encode(encoder=nacl.encoding.RawEncoder) + \
+                idpubkey + \
                 onion_pubkey.encode(encoder=nacl.encoding.RawEncoder) + \
                 onion_pubkey.encode(encoder=nacl.encoding.RawEncoder) + \
                 server_ephem_pubkey.encode(encoder=nacl.encoding.RawEncoder)
                 server_ephem_pubkey.encode(encoder=nacl.encoding.RawEncoder)
         Acheck = nacl.hash.sha256(M + b'verify', encoder=nacl.encoding.RawEncoder)
         Acheck = nacl.hash.sha256(M + b'verify', encoder=nacl.encoding.RawEncoder)
@@ -456,7 +469,7 @@ class TelescopingExtendCircuitHandler:
 
 
     def __init__(self, relaypicker, current_relay_idkey):
     def __init__(self, relaypicker, current_relay_idkey):
         self.relaypicker = relaypicker
         self.relaypicker = relaypicker
-        self.current_relay_idkey = current_relay_idkey
+        self.current_relay_idkey = bytes(current_relay_idkey)
 
 
     def received_cell(self, circhandler, cell):
     def received_cell(self, circhandler, cell):
         # Remove ourselves from handling a second
         # Remove ourselves from handling a second
@@ -980,7 +993,7 @@ class RelayChannelManager(ChannelManager):
             if next_hop == None:
             if next_hop == None:
                 logging.debug("Client requested extending the circuit to a relay index that results in None, aborting. my circid: %s", str(circhandler.circid))
                 logging.debug("Client requested extending the circuit to a relay index that results in None, aborting. my circid: %s", str(circhandler.circid))
                 circhandler.close()
                 circhandler.close()
-            elif next_hop.snipdict["idkey"] == self.idpubkey or next_hop.snipdict["addr"] == peeraddr:
+            elif next_hop.snipdict["idkey"] == bytes(self.idpubkey) or next_hop.snipdict["addr"] == peeraddr:
                 logging.debug("Client requested extending the circuit to a relay already in the path; aborting. my circid: %s", str(circhandler.circid))
                 logging.debug("Client requested extending the circuit to a relay already in the path; aborting. my circid: %s", str(circhandler.circid))
                 circhandler.close()
                 circhandler.close()
 
 
@@ -1092,8 +1105,8 @@ class Relay(network.Server):
         # previous upload if upload=False
         # previous upload if upload=False
         descdict = dict();
         descdict = dict();
         descdict["epoch"] = network.thenetwork.getepoch() + 1
         descdict["epoch"] = network.thenetwork.getepoch() + 1
-        descdict["idkey"] = self.idkey.verify_key
-        descdict["onionkey"] = self.onionkey.public_key
+        descdict["idkey"] = bytes(self.idkey.verify_key)
+        descdict["onionkey"] = bytes(self.onionkey.public_key)
         descdict["addr"] = self.netaddr
         descdict["addr"] = self.netaddr
         descdict["bw"] = self.bw
         descdict["bw"] = self.bw
         descdict["flags"] = self.flags
         descdict["flags"] = self.flags