"""Tunnel client: keeps a WebSocket connection open to the tunnel host and
forwards the HTTP requests it receives to a local destination URL, streaming
request and response bodies as binary chunks."""

from __future__ import annotations

import asyncio
import json
import logging
import ssl
import urllib.error
import urllib.request
from typing import Any
from urllib.parse import urlparse

import click
import websockets


class TunnelClient:
    def __init__(
        self, *, destination_url: str, subdomain: str, tunnel_host: str, log_level: str
    ) -> None:
        self.destination_url = destination_url
        self.subdomain = subdomain
        self.tunnel_host = tunnel_host

        self.tunnel_http_url = f"https://{subdomain}.{tunnel_host}"
        self.tunnel_websocket_url = f"wss://{subdomain}.{tunnel_host}"

        # Set up logging
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(getattr(logging, log_level.upper()))
        self.logger.propagate = False  # Prevent propagation to root logger
        ch = logging.StreamHandler()
        ch.setLevel(getattr(logging, log_level.upper()))
        formatter = logging.Formatter("%(message)s")
        ch.setFormatter(formatter)
        self.logger.addHandler(ch)

        # Incoming requests awaiting reassembly, keyed by request ID
        self.pending_requests: dict[str, dict[str, Any]] = {}

        # Create the event loop
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.stop_event = asyncio.Event()

    async def connect(self) -> None:
        retry_count = 0
        max_retries = 5
        while not self.stop_event.is_set():
            if retry_count >= max_retries:
                self.logger.error(
                    f"Failed to connect after {max_retries} retries. Exiting."
                )
                break
            try:
                self.logger.debug(
                    f"Connecting to WebSocket URL: {self.tunnel_websocket_url}"
                )
                async with websockets.connect(
                    self.tunnel_websocket_url, max_size=None
                ) as websocket:
                    self.logger.debug("WebSocket connection established")
                    click.secho(
                        f"Connected to tunnel {self.tunnel_http_url}", fg="green"
                    )
                    retry_count = 0  # Reset retry count on successful connection
                    await self.forward_request(websocket)
            except (websockets.ConnectionClosed, ConnectionError) as e:
                if self.stop_event.is_set():
                    self.logger.debug("Stopping reconnect attempts due to shutdown")
                    break
                retry_count += 1
                self.logger.warning(
                    f"Connection lost: {e}. Retrying in 2 seconds... ({retry_count}/{max_retries})"
                )
                await asyncio.sleep(2)
            except asyncio.CancelledError:
                self.logger.debug("Connection cancelled")
                break
            except Exception as e:
                if self.stop_event.is_set():
                    self.logger.debug("Stopping reconnect attempts due to shutdown")
                    break
                retry_count += 1
                self.logger.error(
                    f"Unexpected error: {e}. Retrying in 2 seconds... ({retry_count}/{max_retries})"
                )
                await asyncio.sleep(2)

    async def forward_request(self, websocket: Any) -> None:
        try:
            async for message in websocket:
                if isinstance(message, str):
                    # Received text message (metadata)
                    self.logger.debug("Received metadata from worker")
                    data = json.loads(message)
                    await self.handle_request_metadata(websocket, data)
                elif isinstance(message, bytes):
                    # Received binary message (body chunk)
                    self.logger.debug("Received binary data from worker")
                    await self.handle_request_body_chunk(websocket, message)
                else:
                    self.logger.warning("Received unknown message type")
        except asyncio.CancelledError:
            self.logger.debug("Forward request cancelled")
        except Exception as e:
            self.logger.error(f"Error in forward_request: {e}")
            raise

    async def handle_request_metadata(
        self, websocket: Any, data: dict[str, Any]
    ) -> None:
        request_id = data["id"]
        has_body = data.get("has_body", False)
        total_body_chunks = data.get("totalBodyChunks", 0)
        self.pending_requests[request_id] = {
            "metadata": data,
            "body_chunks": {},
            "has_body": has_body,
            "total_body_chunks": total_body_chunks,
        }
        self.logger.debug(
            f"Stored metadata for request ID: {request_id}, has_body: {has_body}"
        )
        await self.check_and_process_request(websocket, request_id)

    async def handle_request_body_chunk(
        self, websocket: Any, chunk_data: bytes
    ) -> None:
        # Incoming frame layout: id_length (4 bytes LE) | request_id (id_length bytes)
        # | chunk_index (4 bytes LE) | total_chunks (4 bytes LE) | body chunk
        offset = 0

        # Extract id_length
        id_length = int.from_bytes(chunk_data[offset : offset + 4], byteorder="little")
        offset += 4

        # Extract request_id
        request_id = chunk_data[offset : offset + id_length].decode("utf-8")
        offset += id_length

        # Extract chunk_index
        chunk_index = int.from_bytes(
            chunk_data[offset : offset + 4], byteorder="little"
        )
        offset += 4

        # Extract total_chunks
        total_chunks = int.from_bytes(
            chunk_data[offset : offset + 4], byteorder="little"
        )
        offset += 4

        # Extract body_chunk
        body_chunk = chunk_data[offset:]

        # Store the chunk only if metadata for this request has already arrived
        if request_id in self.pending_requests:
            request = self.pending_requests[request_id]
            if "body_chunks" not in request:
                request["body_chunks"] = {}
            request["total_body_chunks"] = total_chunks
            request["body_chunks"][chunk_index] = body_chunk
            self.logger.debug(
                f"Stored body chunk {chunk_index + 1}/{total_chunks} for request ID: {request_id}"
            )
            await self.check_and_process_request(websocket, request_id)
        else:
            self.logger.warning(
                f"Received body chunk for unknown or completed request ID: {request_id}"
            )

    async def check_and_process_request(self, websocket: Any, request_id: str) -> None:
        request_data = self.pending_requests.get(request_id)
        if request_data and request_data["metadata"]:
            has_body = request_data["has_body"]
            total_body_chunks = request_data.get("total_body_chunks", 0)
            body_chunks = request_data.get("body_chunks", {})

            all_chunks_received = not has_body or (
                len(body_chunks) == total_body_chunks
            )

            if all_chunks_received:
                # Ensure all chunks are present
                for i in range(total_body_chunks):
                    if i not in body_chunks:
                        self.logger.error(
                            f"Missing chunk {i + 1}/{total_body_chunks} for request ID: {request_id}"
                        )
                        return

                self.logger.debug(f"Processing request ID: {request_id}")
                await self.process_request(
                    websocket, request_data["metadata"], body_chunks, request_id
                )
                del self.pending_requests[request_id]

    async def process_request(
        self,
        websocket: Any,
        request_metadata: dict[str, Any],
        body_chunks: dict[int, bytes],
        request_id: str,
    ) -> None:
        self.logger.debug(
            f"Processing request: {request_id} {request_metadata['method']} {request_metadata['url']}"
        )

        # Reassemble the body if present
        if request_metadata["has_body"]:
            total_chunks = request_metadata["totalBodyChunks"]
            body_data = b"".join(body_chunks[i] for i in range(total_chunks))
        else:
            body_data = None

        # Parse the original URL to extract the path and query
        parsed_url = urlparse(request_metadata["url"])
        path_and_query = parsed_url.path
        if parsed_url.query:
            path_and_query += f"?{parsed_url.query}"

        # Construct the new URL by appending the path and query to destination_url
        forward_url = self.destination_url + path_and_query

        self.logger.debug(f"Forwarding request to: {forward_url}")

        # Custom SSL context that skips certificate verification (e.g. for local
        # destinations using self-signed certificates)
        ssl_context = ssl.create_default_context()
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE

        # Prepare the request
        req = urllib.request.Request(
            url=forward_url,
            method=request_metadata["method"],
            data=body_data if body_data else None,
            headers=request_metadata["headers"],  # Headers set directly on the request
        )

        # Override HTTPErrorProcessor so responses (including redirects and error
        # statuses) are returned as-is instead of being processed locally
        class NoRedirectProcessor(urllib.request.HTTPErrorProcessor):
            def http_response(
                self, request: urllib.request.Request, response: Any
            ) -> Any:
                return response

            https_response = http_response

        # Create a custom opener that uses the NoRedirectProcessor
        opener = urllib.request.build_opener(
            urllib.request.HTTPHandler(),
            urllib.request.HTTPSHandler(
                context=ssl_context
            ),  # Pass the SSL context here
            NoRedirectProcessor(),
        )

        try:
            # Make the request using our custom opener
            with opener.open(req) as response:
                response_body = response.read()
                response_headers = dict(response.getheaders())
                response_status = response.getcode()
                self.logger.debug(
                    f"Received response with status code: {response_status}"
                )

        except urllib.error.HTTPError as e:
            # Non-2xx responses surfaced as HTTPError are still valid responses
            # that we want to forward
            self.logger.debug(f"HTTPError forwarding request: {e}")
            response_body = e.read()
            response_headers = dict(e.headers)
            response_status = e.code

        except urllib.error.URLError as e:
            self.logger.error(f"URLError forwarding request: {e}")
            response_body = b""
            response_headers = {}
            response_status = 500

        self.logger.info(
            f"{click.style(request_metadata['method'], bold=True)} {request_metadata['url']} {response_status}"
        )

        has_body = len(response_body) > 0
        max_chunk_size = 1000000  # 1,000,000 bytes per chunk
        total_body_chunks = (
            (len(response_body) + max_chunk_size - 1) // max_chunk_size
            if has_body
            else 0
        )

        response_metadata = {
            "id": request_id,
            "status": response_status,
            "headers": list(response_headers.items()),
            "has_body": has_body,
            "totalBodyChunks": total_body_chunks,
        }

        # Send response metadata
        response_metadata_json = json.dumps(response_metadata)
        self.logger.debug(
            f"Sending response metadata for ID: {request_id}, has_body: {has_body}"
        )
        await websocket.send(response_metadata_json)

        # Send response body chunks if present
        if has_body:
            self.logger.debug(
                f"Sending {total_body_chunks} body chunks for ID: {request_id}"
            )
            id_bytes = request_id.encode("utf-8")
            for i in range(total_body_chunks):
                chunk_start = i * max_chunk_size
                chunk_end = min(chunk_start + max_chunk_size, len(response_body))
                body_chunk = response_body[chunk_start:chunk_end]

                # Prepare the binary message: request ID bytes, then chunk index and
                # total chunks (4-byte little-endian each), then the body chunk
                chunk_index_bytes = i.to_bytes(4, byteorder="little")
                total_chunks_bytes = total_body_chunks.to_bytes(4, byteorder="little")
                message = id_bytes + chunk_index_bytes + total_chunks_bytes + body_chunk
                await websocket.send(message)
                self.logger.debug(
                    f"Sent body chunk {i + 1}/{total_body_chunks} for ID: {request_id}"
                )
        else:
            self.logger.debug(f"No body to send for ID: {request_id}")

    async def shutdown(self) -> None:
        self.stop_event.set()
        tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
        if tasks:
            self.logger.debug(f"Cancelling {len(tasks)} outstanding tasks")
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
        await self.loop.shutdown_asyncgens()

    def run(self) -> None:
        try:
            self.loop.run_until_complete(self.connect())
        except KeyboardInterrupt:
            self.logger.debug("Received exit signal")
        finally:
            self.logger.debug("Shutting down...")
            self.loop.run_until_complete(self.shutdown())
            self.loop.close()
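

# Illustrative usage sketch (not part of the module itself; the actual CLI entry
# point is expected to live elsewhere, and the values below are placeholders):
#
#     client = TunnelClient(
#         destination_url="http://localhost:8000",
#         subdomain="my-subdomain",
#         tunnel_host="tunnel.example.com",
#         log_level="INFO",
#     )
#     client.run()  # blocks until interrupted, then shuts the event loop down cleanly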