Browse Source

Protocols should always return 'True' when completed.

Steven Engler 5 years ago
parent
commit
d943b47e40
2 changed files with 32 additions and 12 deletions
  1. 25 9
      src/basic_protocols.py
  2. 7 3
      src/throughput_protocols.py

+ 25 - 9
src/basic_protocols.py

@@ -76,9 +76,11 @@ class ChainedProtocol(Protocol):
 			#
 			if self.current_protocol >= len(self.protocols):
 				self.state = self.states.DONE
-				return True
 			#
 		#
+		if self.state is self.states.DONE:
+			return True
+		#
 		return False
 	#
 #
@@ -114,9 +116,11 @@ class Socks4Protocol(Protocol):
 					raise ProtocolException('Could not connect to SOCKS proxy, msg: %x'%(response[1],))
 				#
 				self.state = self.states.DONE
-				return True
 			#
 		#
+		if self.state is self.states.DONE:
+			return True
+		#
 		return False
 	#
 	def socks_cmd(self, addr_port, username=None):
@@ -215,9 +219,11 @@ class PushDataProtocol(Protocol):
 					raise ProtocolException('Did not receive the expected message: {}'.format(response))
 				#
 				self.state = self.states.DONE
-				return True
 			#
 		#
+		if self.state is self.states.DONE:
+			return True
+		#
 		return False
 	#
 #
@@ -289,9 +295,11 @@ class PullDataProtocol(Protocol):
 		if self.state is self.states.SEND_CONFIRMATION:
 			if self.protocol_helper.send(self.socket):
 				self.state = self.states.DONE
-				return True
 			#
 		#
+		if self.state is self.states.DONE:
+			return True
+		#
 		return False
 	#
 	def calc_transfer_rate(self):
@@ -339,9 +347,11 @@ class SendDataProtocol(Protocol):
 					raise ProtocolException('Did not receive the expected message: {}'.format(response))
 				#
 				self.state = self.states.DONE
-				return True
 			#
 		#
+		if self.state is self.states.DONE:
+			return True
+		#
 		return False
 	#
 #
@@ -382,9 +392,11 @@ class ReceiveDataProtocol(Protocol):
 		if self.state is self.states.SEND_CONFIRMATION:
 			if self.protocol_helper.send(self.socket):
 				self.state = self.states.DONE
-				return True
 			#
 		#
+		if self.state is self.states.DONE:
+			return True
+		#
 		return False
 	#
 #
@@ -446,9 +458,11 @@ class SimpleClientConnectionProtocol(Protocol):
 		if self.state is self.states.PUSH_DATA:
 			if self.sub_protocol.run(block=block):
 				self.state = self.states.DONE
-				return True
 			#
 		#
+		if self.state is self.states.DONE:
+			return True
+		#
 		return False
 	#
 #
@@ -470,13 +484,15 @@ class SimpleServerConnectionProtocol(Protocol):
 		#
 		if self.state is self.states.PULL_DATA:
 			if self.sub_protocol.run(block=block):
-				self.state = self.states.DONE
 				if self.bandwidth_callback:
 					self.bandwidth_callback(self.conn_id, self.sub_protocol.data_size, self.sub_protocol.calc_transfer_rate())
 				#
-				return True
+				self.state = self.states.DONE
 			#
 		#
+		if self.state is self.states.DONE:
+			return True
+		#
 		return False
 	#
 #

+ 7 - 3
src/throughput_protocols.py

@@ -41,9 +41,11 @@ class ClientProtocol(basic_protocols.Protocol):
 		if self.state is self.states.PUSH_DATA:
 			if self.sub_protocol.run(block=block):
 				self.state = self.states.DONE
-				return True
 			#
 		#
+		if self.state is self.states.DONE:
+			return True
+		#
 		return False
 	#
 #
@@ -79,13 +81,15 @@ class ServerProtocol(basic_protocols.Protocol):
 		#
 		if self.state is self.states.PULL_DATA:
 			if self.sub_protocol.run(block=block):
-				self.state = self.states.DONE
 				if self.bandwidth_callback:
 					self.bandwidth_callback(self.conn_id, self.sub_protocol.data_size, self.sub_protocol.calc_transfer_rate())
 				#
-				return True
+				self.state = self.states.DONE
 			#
 		#
+		if self.state is self.states.DONE:
+			return True
+		#
 		return False
 	#
 #