Browse Source

Made ClientProtocol more flexible.

Instead of ClientProtocol containing the proxy protocol, the proxy protocol and ClientProtocol are combined using a ChainedProtocol. This means that the ClientProtocol doesn't need to know anything about the proxy, and we can easily swap the SOCKS proxy out with others in the future.
Steven Engler 5 years ago
parent
commit
32a3d9a996
3 changed files with 53 additions and 38 deletions
  1. 24 0
      src/basic_protocols.py
  2. 23 9
      src/throughput_client.py
  3. 6 29
      src/throughput_protocols.py

+ 24 - 0
src/basic_protocols.py

@@ -58,6 +58,30 @@ class Protocol():
 		#
 	#
 #
+class ChainedProtocol(Protocol):
+	def __init__(self, protocols):
+		self.protocols = protocols
+		self.current_protocol = 0
+		#
+		self.states = enum.Enum('CHAIN_STATES', 'READY_TO_BEGIN RUNNING DONE')
+		self.state = self.states.READY_TO_BEGIN
+	#
+	def _run_iteration(self, block=True):
+		if self.state is self.states.READY_TO_BEGIN:
+			self.state = self.states.RUNNING
+		#
+		if self.state is self.states.RUNNING:
+			if self.protocols[self.current_protocol] is None or self.protocols[self.current_protocol].run(block=block):
+				self.current_protocol += 1
+			#
+			if self.current_protocol >= len(self.protocols):
+				self.state = self.states.DONE
+				return True
+			#
+		#
+		return False
+	#
+#
 class Socks4Protocol(Protocol):
 	def __init__(self, socket, addr_port, username=None):
 		self.socket = socket

+ 23 - 9
src/throughput_client.py

@@ -1,10 +1,12 @@
 #!/usr/bin/python3
 #
 import throughput_protocols
+import basic_protocols
 import useful
 import os
 import argparse
 import logging
+import socket
 #
 if __name__ == '__main__':
 	logging.basicConfig(level=logging.DEBUG)
@@ -22,18 +24,30 @@ if __name__ == '__main__':
 	parser.add_argument('--no-accel', action='store_true', help='don\'t use C acceleration (use pure Python)')
 	args = parser.parse_args()
 	#
-	endpoint = (args.ip, args.port)
-	proxy = None
 	#
-	if args.proxy is not None:
-		proxy = (args.proxy[0], int(args.proxy[1]))
+	endpoint = (args.ip, args.port)
+	client_socket = socket.socket()
+	protocols = []
 	#
-	username = bytes([x for x in os.urandom(12) if x != 0])
-	#username = None
+	if args.proxy is None:
+		logging.debug('Socket %d connecting to endpoint %r...', client_socket.fileno(), endpoint)
+		client_socket.connect(endpoint)
+	else:
+		proxy_username = bytes([x for x in os.urandom(12) if x != 0])
+		proxy_endpoint = (args.proxy[0], int(args.proxy[1]))
+		#
+		logging.debug('Socket %d connecting to proxy %r...', client_socket.fileno(), proxy_endpoint)
+		client_socket.connect(proxy_endpoint)
+		#
+		proxy_protocol = basic_protocols.Socks4Protocol(client_socket, endpoint, username=proxy_username)
+		protocols.append(proxy_protocol)
 	#
-	client = throughput_protocols.ClientProtocol(endpoint, args.num_bytes, proxy=proxy,
-	                                             username=username, wait_until=args.wait,
+	throughput_protocol = throughput_protocols.ClientProtocol(client_socket, args.num_bytes,
+	                                             wait_until=args.wait,
 	                                             send_buffer_len=args.buffer_len,
 	                                             use_acceleration=(not args.no_accel))
-	client.run()
+	protocols.append(throughput_protocol)
+	#
+	combined_protocol = basic_protocols.ChainedProtocol(protocols)
+	combined_protocol.run()
 #

+ 6 - 29
src/throughput_protocols.py

@@ -7,48 +7,25 @@ import time
 import socket
 #
 class ClientProtocol(basic_protocols.Protocol):
-	def __init__(self, endpoint, total_bytes, proxy=None, username=None, wait_until=None, send_buffer_len=None, use_acceleration=None):
-		self.endpoint = endpoint
+	def __init__(self, socket, total_bytes, wait_until=None, send_buffer_len=None, use_acceleration=None):
+		self.socket = socket
 		self.total_bytes = total_bytes
-		self.proxy = proxy
-		self.username = username
 		self.wait_until = wait_until
 		self.send_buffer_len = send_buffer_len
 		self.use_acceleration = use_acceleration
 		#
-		self.states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN CONNECT_TO_PROXY SEND_GROUP_ID PUSH_DATA DONE')
+		self.states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN SEND_GROUP_ID PUSH_DATA DONE')
 		self.state = self.states.READY_TO_BEGIN
 		#
-		self.socket = socket.socket()
 		self.sub_protocol = None
 		self.group_id = int(self.wait_until*1000) if self.wait_until is not None else 0
 		# a group id of 0 means no group
-		#
-		if self.proxy is None:
-			logging.debug('Socket %d connecting to endpoint %r...', self.socket.fileno(), self.endpoint)
-			self.socket.connect(self.endpoint)
-		else:
-			logging.debug('Socket %d connecting to proxy %r...', self.socket.fileno(), self.proxy)
-			self.socket.connect(self.proxy)
-		#
 	#
 	def _run_iteration(self, block=True):
 		if self.state is self.states.READY_TO_BEGIN:
-			if self.proxy is None:
-				group_id_bytes = self.group_id.to_bytes(8, byteorder='big', signed=False)
-				self.sub_protocol = basic_protocols.SendDataProtocol(self.socket, group_id_bytes)
-				self.state = self.states.SEND_GROUP_ID
-			else:
-				self.sub_protocol = basic_protocols.Socks4Protocol(self.socket, self.endpoint, username=self.username)
-				self.state = self.states.CONNECT_TO_PROXY
-			#
-		#
-		if self.state is self.states.CONNECT_TO_PROXY:
-			if self.sub_protocol.run(block=block):
-				group_id_bytes = self.group_id.to_bytes(8, byteorder='big', signed=False)
-				self.sub_protocol = basic_protocols.SendDataProtocol(self.socket, group_id_bytes)
-				self.state = self.states.SEND_GROUP_ID
-			#
+			group_id_bytes = self.group_id.to_bytes(8, byteorder='big', signed=False)
+			self.sub_protocol = basic_protocols.SendDataProtocol(self.socket, group_id_bytes)
+			self.state = self.states.SEND_GROUP_ID
 		#
 		if self.state is self.states.SEND_GROUP_ID:
 			if block and self.wait_until is not None: