import asyncio
import json
import logging
import ssl
import urllib.error
import urllib.request
from urllib.parse import urlparse

import click
import websockets


class TunnelClient:
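    """Client side of an HTTP tunnel.

    Connects to a remote worker over WebSocket at
    wss://<subdomain>.<tunnel_host>, receives serialized HTTP requests,
    replays them against destination_url, and streams the responses back
    in chunks.
    """
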
    def __init__(self, *, destination_url, subdomain, tunnel_host, log_level):
        self.destination_url = destination_url
        self.subdomain = subdomain
        self.tunnel_host = tunnel_host

        self.tunnel_http_url = f"https://{subdomain}.{tunnel_host}"
        self.tunnel_websocket_url = f"wss://{subdomain}.{tunnel_host}"

        # Set up logging
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(getattr(logging, log_level.upper()))
        ch = logging.StreamHandler()
        ch.setLevel(getattr(logging, log_level.upper()))
        formatter = logging.Formatter("%(message)s")
        ch.setFormatter(formatter)
        self.logger.addHandler(ch)

        # Store incoming requests
        self.pending_requests = {}

        # Create the event loop
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.stop_event = asyncio.Event()

    async def connect(self):
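        """Open the WebSocket connection and keep it alive.

        Reconnects on connection loss, waiting 2 seconds between attempts
        and giving up after 5 consecutive failures. The retry counter
        resets whenever a connection is established.
        """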
        retry_count = 0
        max_retries = 5
        while not self.stop_event.is_set():
            if retry_count >= max_retries:
                self.logger.error(
                    f"Failed to connect after {max_retries} retries. Exiting."
                )
                break
            try:
                self.logger.debug(
                    f"Connecting to WebSocket URL: {self.tunnel_websocket_url}"
                )
                async with websockets.connect(
                    self.tunnel_websocket_url, max_size=None
                ) as websocket:
                    self.logger.debug("WebSocket connection established")
                    click.secho(
                        f"Connected to tunnel {self.tunnel_http_url}", fg="green"
                    )
                    retry_count = 0  # Reset retry count on successful connection
                    await self.forward_request(websocket)
            except (websockets.ConnectionClosed, ConnectionError) as e:
                if self.stop_event.is_set():
                    self.logger.debug("Stopping reconnect attempts due to shutdown")
                    break
                retry_count += 1
                self.logger.warning(
                    f"Connection lost: {e}. Retrying in 2 seconds... ({retry_count}/{max_retries})"
                )
                await asyncio.sleep(2)
            except asyncio.CancelledError:
                self.logger.debug("Connection cancelled")
                break
            except Exception as e:
                if self.stop_event.is_set():
                    self.logger.debug("Stopping reconnect attempts due to shutdown")
                    break
                retry_count += 1
                self.logger.error(
                    f"Unexpected error: {e}. Retrying in 2 seconds... ({retry_count}/{max_retries})"
                )
                await asyncio.sleep(2)

    async def forward_request(self, websocket):
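        """Consume messages from the WebSocket and dispatch by frame type.

        Text frames carry JSON request metadata; binary frames carry
        request body chunks.
        """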
        try:
            async for message in websocket:
                if isinstance(message, str):
                    # Received text message (metadata)
                    self.logger.debug("Received metadata from worker")
                    data = json.loads(message)
                    await self.handle_request_metadata(websocket, data)
                elif isinstance(message, bytes):
                    # Received binary message (body chunk)
                    self.logger.debug("Received binary data from worker")
                    await self.handle_request_body_chunk(websocket, message)
                else:
                    self.logger.warning("Received unknown message type")
        except asyncio.CancelledError:
            self.logger.debug("Forward request cancelled")
        except Exception as e:
            self.logger.error(f"Error in forward_request: {e}")
            raise

    async def handle_request_metadata(self, websocket, data):
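        """Record metadata for an incoming request, then process it if complete."""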
        request_id = data["id"]
        has_body = data.get("has_body", False)
        total_body_chunks = data.get("totalBodyChunks", 0)
        self.pending_requests[request_id] = {
            "metadata": data,
            "body_chunks": {},
            "has_body": has_body,
            "total_body_chunks": total_body_chunks,
        }
        self.logger.debug(
            f"Stored metadata for request ID: {request_id}, has_body: {has_body}"
        )
        await self.check_and_process_request(websocket, request_id)

    async def handle_request_body_chunk(self, websocket, chunk_data):
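        """Parse a binary body-chunk frame and store the chunk.

        Incoming frame layout (all integers little-endian):
            [id_length: 4 bytes][request_id: id_length bytes, UTF-8]
            [chunk_index: 4 bytes][total_chunks: 4 bytes][body chunk]
        """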
        offset = 0

        # Extract id_length
        id_length = int.from_bytes(chunk_data[offset : offset + 4], byteorder="little")
        offset += 4

        # Extract request_id
        request_id = chunk_data[offset : offset + id_length].decode("utf-8")
        offset += id_length

        # Extract chunk_index
        chunk_index = int.from_bytes(
            chunk_data[offset : offset + 4], byteorder="little"
        )
        offset += 4

        # Extract total_chunks
        total_chunks = int.from_bytes(
            chunk_data[offset : offset + 4], byteorder="little"
        )
        offset += 4

        # The remainder of the frame is the body chunk itself
        body_chunk = chunk_data[offset:]

        if request_id in self.pending_requests:
            request = self.pending_requests[request_id]
            if "body_chunks" not in request:
                request["body_chunks"] = {}
            request["total_body_chunks"] = total_chunks
            request["body_chunks"][chunk_index] = body_chunk
            self.logger.debug(
                f"Stored body chunk {chunk_index + 1}/{total_chunks} for request ID: {request_id}"
            )
            await self.check_and_process_request(websocket, request_id)
        else:
            self.logger.warning(
                f"Received body chunk for unknown or completed request ID: {request_id}"
            )

    async def check_and_process_request(self, websocket, request_id):
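        """Forward the request once its metadata and all body chunks have arrived."""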
        request_data = self.pending_requests.get(request_id)
        if request_data and request_data["metadata"]:
            has_body = request_data["has_body"]
            total_body_chunks = request_data.get("total_body_chunks", 0)
            body_chunks = request_data.get("body_chunks", {})

            all_chunks_received = not has_body or (
                len(body_chunks) == total_body_chunks
            )

            if all_chunks_received:
                # Ensure all chunks are present
                for i in range(total_body_chunks):
                    if i not in body_chunks:
                        self.logger.error(
                            f"Missing chunk {i + 1}/{total_body_chunks} for request ID: {request_id}"
                        )
                        return

                self.logger.debug(f"Processing request ID: {request_id}")
                await self.process_request(
                    websocket, request_data["metadata"], body_chunks, request_id
                )
                del self.pending_requests[request_id]

    async def process_request(
        self, websocket, request_metadata, body_chunks, request_id
    ):
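        """Replay a tunneled request against destination_url and stream the response back.

        Reassembles the request body, forwards the request with urllib
        (without following redirects), then sends response metadata as JSON
        followed by the response body in binary chunks.
        """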
        self.logger.debug(
            f"Processing request: {request_id} {request_metadata['method']} {request_metadata['url']}"
        )

        # Reassemble body if present
        if request_metadata.get("has_body", False):
            total_chunks = request_metadata.get("totalBodyChunks", 0)
            body_data = b"".join(body_chunks[i] for i in range(total_chunks))
        else:
            body_data = None

        # Parse the original URL to extract the path and query
        parsed_url = urlparse(request_metadata["url"])
        path_and_query = parsed_url.path
        if parsed_url.query:
            path_and_query += f"?{parsed_url.query}"

        # Construct the new URL by appending path and query to destination_url
        forward_url = self.destination_url + path_and_query

        self.logger.debug(f"Forwarding request to: {forward_url}")

        # Create a custom SSL context that skips certificate verification,
        # so destinations with self-signed certificates still work
        ssl_context = ssl.create_default_context()
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE

        # Prepare the request
        req = urllib.request.Request(
            url=forward_url,
            method=request_metadata["method"],
            data=body_data if body_data else None,
            headers=request_metadata["headers"],  # Headers set directly on the request
        )

        # Override the HTTPErrorProcessor so redirects and error statuses
        # are returned as-is instead of being processed
        class NoRedirectProcessor(urllib.request.HTTPErrorProcessor):
            def http_response(self, request, response):
                return response

            https_response = http_response

        # Create a custom opener that uses the NoRedirectProcessor
        opener = urllib.request.build_opener(
            urllib.request.HTTPHandler(),
            urllib.request.HTTPSHandler(context=ssl_context),  # Use the custom SSL context
            NoRedirectProcessor(),
        )

        try:
            # Make the request using our custom opener
            with opener.open(req) as response:
                response_body = response.read()
                response_headers = dict(response.getheaders())
                response_status = response.getcode()
                self.logger.debug(
                    f"Received response with status code: {response_status}"
                )

        except urllib.error.HTTPError as e:
            # With NoRedirectProcessor installed, non-2xx responses normally
            # come back through the branch above rather than raising; handle
            # HTTPError defensively and treat it as a regular response
            self.logger.debug(f"HTTPError forwarding request: {e}")
            response_body = e.read()
            response_headers = dict(e.headers)
            response_status = e.code

        except urllib.error.URLError as e:
            self.logger.error(f"URLError forwarding request: {e}")
            response_body = b""
            response_headers = {}
            response_status = 500

        self.logger.info(
            f"{click.style(request_metadata['method'], bold=True)} {request_metadata['url']} {response_status}"
        )

        has_body = len(response_body) > 0
        max_chunk_size = 1000000  # 1,000,000 bytes per chunk
        total_body_chunks = (
            (len(response_body) + max_chunk_size - 1) // max_chunk_size
            if has_body
            else 0
        )

        response_metadata = {
            "id": request_id,
            "status": response_status,
            "headers": list(response_headers.items()),
            "has_body": has_body,
            "totalBodyChunks": total_body_chunks,
        }

        # Send response metadata
        response_metadata_json = json.dumps(response_metadata)
        self.logger.debug(
            f"Sending response metadata for ID: {request_id}, has_body: {has_body}"
        )
        await websocket.send(response_metadata_json)

        # Send response body chunks if present
        if has_body:
            self.logger.debug(
                f"Sending {total_body_chunks} body chunks for ID: {request_id}"
            )
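            # Outgoing frame layout: [request_id (UTF-8)][chunk_index: 4 bytes LE]
            # [total_chunks: 4 bytes LE][body chunk]. Unlike the incoming frames,
            # the request ID is not length-prefixed here; the worker is assumed
            # to know the ID's fixed width.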
            id_bytes = request_id.encode("utf-8")
            for i in range(total_body_chunks):
                chunk_start = i * max_chunk_size
                chunk_end = min(chunk_start + max_chunk_size, len(response_body))
                body_chunk = response_body[chunk_start:chunk_end]

                # Prepare the binary message
                chunk_index_bytes = i.to_bytes(4, byteorder="little")
                total_chunks_bytes = total_body_chunks.to_bytes(4, byteorder="little")
                message = id_bytes + chunk_index_bytes + total_chunks_bytes + body_chunk
                await websocket.send(message)
                self.logger.debug(
                    f"Sent body chunk {i + 1}/{total_body_chunks} for ID: {request_id}"
                )
        else:
            self.logger.debug(f"No body to send for ID: {request_id}")

    async def shutdown(self):
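        """Signal shutdown, cancel outstanding tasks, and drain async generators."""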
        self.stop_event.set()
        tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
        if tasks:
            self.logger.debug(f"Cancelling {len(tasks)} outstanding tasks")
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
        await self.loop.shutdown_asyncgens()

    def run(self):
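        """Run the client until interrupted, then shut down cleanly."""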
        try:
            self.loop.run_until_complete(self.connect())
        except KeyboardInterrupt:
            self.logger.debug("Received exit signal")
        finally:
            self.logger.debug("Shutting down...")
            self.loop.run_until_complete(self.shutdown())
            self.loop.close()
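

# Minimal usage sketch (hypothetical wiring; the option names and defaults
# below are illustrative assumptions, not this project's actual CLI):
#
#   @click.command()
#   @click.option("--destination-url", default="http://localhost:8000")
#   @click.option("--subdomain", required=True)
#   @click.option("--tunnel-host", required=True)
#   @click.option("--log-level", default="INFO")
#   def main(destination_url, subdomain, tunnel_host, log_level):
#       TunnelClient(
#           destination_url=destination_url,
#           subdomain=subdomain,
#           tunnel_host=tunnel_host,
#           log_level=log_level,
#       ).run()
#
#   if __name__ == "__main__":
#       main()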