from __future__ import annotations

import asyncio
import http.client
import json
import logging
import ssl
import urllib.error
import urllib.request
from typing import Any
from urllib.parse import urlparse

import click
import websockets
14
15
class TunnelClient:
    """Relay HTTP requests received over a tunnel websocket to a local destination.

    The tunnel peer (called the "worker" in log messages) sends each request as
    a JSON text frame containing its metadata, optionally followed by binary
    body-chunk frames. Each complete request is replayed against
    ``destination_url``; the response is sent back as a JSON metadata frame
    followed by binary body-chunk frames.
    """

    # Maximum number of response-body bytes placed in one outgoing binary frame.
    MAX_RESPONSE_CHUNK_SIZE = 1_000_000

    def __init__(
        self, *, destination_url: str, subdomain: str, tunnel_host: str, log_level: str
    ) -> None:
        """Configure URLs, logging, request bookkeeping and a private event loop.

        Args:
            destination_url: Base URL that request paths are appended to.
            subdomain: Tunnel subdomain; combined with ``tunnel_host`` to build
                the public ``https://`` and ``wss://`` URLs.
            tunnel_host: Host name of the tunnel service.
            log_level: Name of a ``logging`` level, case-insensitive
                (e.g. ``"info"``).
        """
        self.destination_url = destination_url
        self.subdomain = subdomain
        self.tunnel_host = tunnel_host

        self.tunnel_http_url = f"https://{subdomain}.{tunnel_host}"
        self.tunnel_websocket_url = f"wss://{subdomain}.{tunnel_host}"

        # Set up logging. The logger is shared module-wide, so attach the
        # stream handler only once; adding one per instance duplicated output.
        level = getattr(logging, log_level.upper())
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(level)
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter("%(message)s"))
            self.logger.addHandler(handler)
        for handler in self.logger.handlers:
            handler.setLevel(level)

        # Requests whose metadata has arrived, keyed by request ID, while their
        # body chunks (if any) are still being collected.
        self.pending_requests: dict[str, dict[str, Any]] = {}

        # Built once and reused for every forwarded request.
        self._opener = self._build_opener()

        # Create a dedicated event loop, make it current, and create the
        # shutdown flag checked by connect().
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.stop_event = asyncio.Event()

    async def connect(self) -> None:
        """Keep a websocket connection to the tunnel open until told to stop.

        Each connection is handed to forward_request(). After a connection
        error the client waits 2 seconds and retries, giving up after 5
        consecutive failures; the counter resets on every successful connect.
        When forward_request() returns normally, the loop reconnects at once.
        """
        retry_count = 0
        max_retries = 5
        while not self.stop_event.is_set():
            if retry_count >= max_retries:
                self.logger.error(
                    f"Failed to connect after {max_retries} retries. Exiting."
                )
                break
            try:
                self.logger.debug(
                    f"Connecting to WebSocket URL: {self.tunnel_websocket_url}"
                )
                async with websockets.connect(
                    self.tunnel_websocket_url, max_size=None
                ) as websocket:
                    self.logger.debug("WebSocket connection established")
                    click.secho(
                        f"Connected to tunnel {self.tunnel_http_url}", fg="green"
                    )
                    retry_count = 0  # Reset retry count on successful connection
                    # Drop partial requests left over from a previous
                    # connection so they do not accumulate across reconnects.
                    self.pending_requests.clear()
                    await self.forward_request(websocket)
            except (websockets.ConnectionClosed, ConnectionError) as e:
                if self.stop_event.is_set():
                    self.logger.debug("Stopping reconnect attempts due to shutdown")
                    break
                retry_count += 1
                self.logger.warning(
                    f"Connection lost: {e}. Retrying in 2 seconds... ({retry_count}/{max_retries})"
                )
                await asyncio.sleep(2)
            except asyncio.CancelledError:
                self.logger.debug("Connection cancelled")
                break
            except Exception as e:
                if self.stop_event.is_set():
                    self.logger.debug("Stopping reconnect attempts due to shutdown")
                    break
                retry_count += 1
                self.logger.error(
                    f"Unexpected error: {e}. Retrying in 2 seconds... ({retry_count}/{max_retries})"
                )
                await asyncio.sleep(2)

    async def forward_request(self, websocket: Any) -> None:
        """Dispatch every frame received on ``websocket`` until it stops yielding.

        Text frames are JSON request metadata; binary frames are body chunks.
        Errors and cancellation are logged and re-raised so connect() can
        decide whether to reconnect.
        """
        try:
            async for message in websocket:
                if isinstance(message, str):
                    # Received text message (metadata)
                    self.logger.debug("Received metadata from worker")
                    data = json.loads(message)
                    await self.handle_request_metadata(websocket, data)
                elif isinstance(message, bytes):
                    # Received binary message (body chunk)
                    self.logger.debug("Received binary data from worker")
                    await self.handle_request_body_chunk(websocket, message)
                else:
                    self.logger.warning("Received unknown message type")
        except asyncio.CancelledError:
            self.logger.debug("Forward request cancelled")
            # Propagate: swallowing the cancellation made connect() treat it as
            # a normal close and reconnect.
            raise
        except Exception as e:
            self.logger.error(f"Error in forward_request: {e}")
            raise

    async def handle_request_metadata(
        self, websocket: Any, data: dict[str, Any]
    ) -> None:
        """Register a request from its JSON metadata and process it if complete.

        ``data`` must contain ``"id"``; ``"has_body"`` defaults to False and
        ``"totalBodyChunks"`` to 0 when absent.
        """
        request_id = data["id"]
        has_body = data.get("has_body", False)
        total_body_chunks = data.get("totalBodyChunks", 0)
        self.pending_requests[request_id] = {
            "metadata": data,
            "body_chunks": {},
            "has_body": has_body,
            "total_body_chunks": total_body_chunks,
        }
        self.logger.debug(
            f"Stored metadata for request ID: {request_id}, has_body: {has_body}"
        )
        await self.check_and_process_request(websocket, request_id)

    async def handle_request_body_chunk(
        self, websocket: Any, chunk_data: bytes
    ) -> None:
        """Store one binary body chunk and process its request once complete.

        Frame layout (integers are 4-byte little-endian, unsigned):
        ``id_length | request_id (UTF-8, id_length bytes) | chunk_index |
        total_chunks | body bytes``. Truncated frames or undecodable IDs are
        logged and ignored instead of breaking the connection.
        """
        if len(chunk_data) < 4:
            self.logger.warning(
                f"Ignoring malformed body chunk ({len(chunk_data)} bytes)"
            )
            return

        id_length = int.from_bytes(chunk_data[0:4], byteorder="little")
        id_end = 4 + id_length
        body_start = id_end + 8  # chunk_index + total_chunks
        if len(chunk_data) < body_start:
            self.logger.warning(
                f"Ignoring malformed body chunk ({len(chunk_data)} bytes)"
            )
            return

        try:
            request_id = chunk_data[4:id_end].decode("utf-8")
        except UnicodeDecodeError:
            self.logger.warning("Ignoring body chunk with undecodable request ID")
            return

        chunk_index = int.from_bytes(chunk_data[id_end : id_end + 4], byteorder="little")
        total_chunks = int.from_bytes(
            chunk_data[id_end + 4 : body_start], byteorder="little"
        )
        body_chunk = chunk_data[body_start:]

        request = self.pending_requests.get(request_id)
        if request is None:
            self.logger.warning(
                f"Received body chunk for unknown or completed request ID: {request_id}"
            )
            return

        request["total_body_chunks"] = total_chunks
        request.setdefault("body_chunks", {})[chunk_index] = body_chunk
        self.logger.debug(
            f"Stored body chunk {chunk_index + 1}/{total_chunks} for request ID: {request_id}"
        )
        await self.check_and_process_request(websocket, request_id)

    async def check_and_process_request(self, websocket: Any, request_id: str) -> None:
        """Forward ``request_id`` once its metadata and every body chunk are present.

        The pending entry is removed after processing, even if processing
        raises, so failed requests do not linger in ``pending_requests``.
        """
        request_data = self.pending_requests.get(request_id)
        if not request_data or not request_data["metadata"]:
            return

        total_body_chunks = request_data.get("total_body_chunks", 0)
        body_chunks = request_data.get("body_chunks", {})
        if request_data["has_body"] and len(body_chunks) != total_body_chunks:
            return  # still waiting for more body chunks

        # The count matches; also require the indices to be exactly 0..N-1.
        for i in range(total_body_chunks):
            if i not in body_chunks:
                self.logger.error(
                    f"Missing chunk {i + 1}/{total_body_chunks} for request ID: {request_id}"
                )
                return

        self.logger.debug(f"Processing request ID: {request_id}")
        try:
            await self.process_request(
                websocket, request_data["metadata"], body_chunks, request_id
            )
        finally:
            self.pending_requests.pop(request_id, None)

    async def process_request(
        self,
        websocket: Any,
        request_metadata: dict[str, Any],
        body_chunks: dict[int, bytes],
        request_id: str,
    ) -> None:
        """Forward one complete request to ``destination_url`` and relay the reply.

        Args:
            websocket: Connection the response frames are sent on.
            request_metadata: Must contain ``"method"``, ``"url"`` and
                ``"headers"`` (passed to ``urllib.request.Request``, which
                expects a mapping); ``"has_body"`` is optional.
            body_chunks: Body pieces keyed by chunk index.
            request_id: ID echoed back in the response frames.
        """
        method = request_metadata["method"]
        url = request_metadata["url"]
        self.logger.debug(f"Processing request: {request_id} {method} {url}")

        body_data = self._reassemble_body(request_metadata, body_chunks)
        forward_url = self._build_forward_url(url)
        self.logger.debug(f"Forwarding request to: {forward_url}")

        req = urllib.request.Request(
            url=forward_url,
            method=method,
            data=body_data or None,
            headers=request_metadata["headers"],  # Headers set directly on the request
        )

        # urllib blocks; run it on a worker thread so the event loop keeps
        # servicing the websocket meanwhile. Messages are still handled one at
        # a time because forward_request awaits this coroutine. A thread cannot
        # be interrupted, so an in-flight request runs to completion on shutdown.
        loop = asyncio.get_running_loop()
        response_status, response_headers, response_body = await loop.run_in_executor(
            None, self._send_request, req
        )

        self.logger.info(f"{click.style(method, bold=True)} {url} {response_status}")

        await self._send_response(
            websocket, request_id, response_status, response_headers, response_body
        )

    @staticmethod
    def _reassemble_body(
        request_metadata: dict[str, Any], body_chunks: dict[int, bytes]
    ) -> bytes | None:
        """Join body chunks in index order, or return None when there is no body."""
        if not request_metadata.get("has_body", False):
            return None
        return b"".join(body_chunks[i] for i in sorted(body_chunks))

    def _build_forward_url(self, url: str) -> str:
        """Map a public tunnel URL onto ``destination_url``.

        Keeps the path (including ``;params``) and query string. Avoids a
        doubled slash when ``destination_url`` ends with "/".
        """
        parsed_url = urlparse(url)
        path_and_query = parsed_url.path
        if parsed_url.params:
            # urlparse splits ";params" off the last path segment.
            path_and_query += f";{parsed_url.params}"
        if parsed_url.query:
            path_and_query += f"?{parsed_url.query}"
        if self.destination_url.endswith("/") and path_and_query.startswith("/"):
            path_and_query = path_and_query[1:]
        return self.destination_url + path_and_query

    @staticmethod
    def _build_opener() -> urllib.request.OpenerDirector:
        """Return an opener that skips TLS verification and never follows redirects."""
        # Custom SSL context that ignores certificate verification (if needed).
        ssl_context = ssl.create_default_context()
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE

        class NoRedirectProcessor(urllib.request.HTTPErrorProcessor):
            """Return every response unchanged.

            This skips urllib's error/redirect handling, so 3xx responses are
            relayed as-is instead of being followed.
            """

            def http_response(
                self, request: urllib.request.Request, response: Any
            ) -> Any:
                return response

            https_response = http_response

        return urllib.request.build_opener(
            urllib.request.HTTPHandler(),
            urllib.request.HTTPSHandler(context=ssl_context),
            NoRedirectProcessor(),
        )

    def _send_request(
        self, req: urllib.request.Request
    ) -> tuple[int, list[tuple[str, str]], bytes]:
        """Perform ``req`` synchronously and return ``(status, headers, body)``.

        process_request calls this through ``run_in_executor``. Headers are a
        list of pairs, so repeated headers (e.g. several Set-Cookie lines) are
        preserved. A failure to reach or talk to the destination is reported
        as status 500 with no headers and an empty body.
        """
        try:
            with self._opener.open(req) as response:
                response_body = response.read()
                response_headers = response.getheaders()
                response_status = response.getcode()
                self.logger.debug(
                    f"Received response with status code: {response_status}"
                )
        except urllib.error.HTTPError as e:
            # With NoRedirectProcessor installed, non-2xx responses should
            # normally be returned above; an HTTPError still carries one.
            self.logger.debug(f"HTTPError forwarding request: {e}")
            response_body = e.read()
            response_headers = list(e.headers.items())
            response_status = e.code
        except urllib.error.URLError as e:
            self.logger.error(f"URLError forwarding request: {e}")
            return 500, [], b""
        except (OSError, http.client.HTTPException) as e:
            # e.g. the destination dropped the connection mid-response. Report
            # it as a 500 instead of letting it tear down the tunnel.
            self.logger.error(f"Error forwarding request: {e}")
            return 500, [], b""
        return response_status, response_headers, response_body

    async def _send_response(
        self,
        websocket: Any,
        request_id: str,
        response_status: int,
        response_headers: list[tuple[str, str]],
        response_body: bytes,
    ) -> None:
        """Send the response metadata as JSON, then the body in binary chunks.

        Each binary frame is ``request_id (UTF-8) | chunk_index | total_chunks |
        body bytes``, with 4-byte little-endian integers.
        """
        chunk_size = self.MAX_RESPONSE_CHUNK_SIZE
        has_body = len(response_body) > 0
        # Ceiling division; 0 when the body is empty.
        total_body_chunks = (len(response_body) + chunk_size - 1) // chunk_size

        response_metadata = {
            "id": request_id,
            "status": response_status,
            "headers": list(response_headers),
            "has_body": has_body,
            "totalBodyChunks": total_body_chunks,
        }

        # Send response metadata
        self.logger.debug(
            f"Sending response metadata for ID: {request_id}, has_body: {has_body}"
        )
        await websocket.send(json.dumps(response_metadata))

        if not has_body:
            self.logger.debug(f"No body to send for ID: {request_id}")
            return

        self.logger.debug(
            f"Sending {total_body_chunks} body chunks for ID: {request_id}"
        )
        # NOTE(review): unlike incoming chunks, outgoing frames carry no
        # id_length prefix; confirm this is the layout the worker parses.
        id_bytes = request_id.encode("utf-8")
        total_chunks_bytes = total_body_chunks.to_bytes(4, byteorder="little")
        for i in range(total_body_chunks):
            start = i * chunk_size
            body_chunk = response_body[start : start + chunk_size]
            chunk_index_bytes = i.to_bytes(4, byteorder="little")
            await websocket.send(
                id_bytes + chunk_index_bytes + total_chunks_bytes + body_chunk
            )
            self.logger.debug(
                f"Sent body chunk {i + 1}/{total_body_chunks} for ID: {request_id}"
            )

    async def shutdown(self) -> None:
        """Set stop_event, cancel and await all other tasks, then close async generators."""
        self.stop_event.set()
        tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
        if tasks:
            self.logger.debug(f"Cancelling {len(tasks)} outstanding tasks")
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
        await self.loop.shutdown_asyncgens()

    def run(self) -> None:
        """Run connect() until it returns or Ctrl-C, then shut down and close the loop."""
        try:
            self.loop.run_until_complete(self.connect())
        except KeyboardInterrupt:
            self.logger.debug("Received exit signal")
        finally:
            self.logger.debug("Shutting down...")
            self.loop.run_until_complete(self.shutdown())
            self.loop.close()