Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import asyncio
  4import json
  5import logging
  6import ssl
  7import urllib.error
  8import urllib.request
  9from typing import Any
 10from urllib.parse import urlparse
 11
 12import click
 13import websockets
 14
 15
 16class TunnelClient:
 17    def __init__(
 18        self, *, destination_url: str, subdomain: str, tunnel_host: str, log_level: str
 19    ) -> None:
 20        self.destination_url = destination_url
 21        self.subdomain = subdomain
 22        self.tunnel_host = tunnel_host
 23
 24        self.tunnel_http_url = f"https://{subdomain}.{tunnel_host}"
 25        self.tunnel_websocket_url = f"wss://{subdomain}.{tunnel_host}"
 26
 27        # Set up logging
 28        self.logger = logging.getLogger(__name__)
 29        self.logger.setLevel(getattr(logging, log_level.upper()))
 30        self.logger.propagate = False  # Prevent propagation to root logger
 31        ch = logging.StreamHandler()
 32        ch.setLevel(getattr(logging, log_level.upper()))
 33        formatter = logging.Formatter("%(message)s")
 34        ch.setFormatter(formatter)
 35        self.logger.addHandler(ch)
 36
 37        # Store incoming requests
 38        self.pending_requests = {}
 39
 40        # Create the event loop
 41        self.loop = asyncio.new_event_loop()
 42        asyncio.set_event_loop(self.loop)
 43        self.stop_event = asyncio.Event()
 44
 45    async def connect(self) -> None:
 46        retry_count = 0
 47        max_retries = 5
 48        while not self.stop_event.is_set():
 49            if retry_count >= max_retries:
 50                self.logger.error(
 51                    f"Failed to connect after {max_retries} retries. Exiting."
 52                )
 53                break
 54            try:
 55                self.logger.debug(
 56                    f"Connecting to WebSocket URL: {self.tunnel_websocket_url}"
 57                )
 58                async with websockets.connect(
 59                    self.tunnel_websocket_url, max_size=None
 60                ) as websocket:
 61                    self.logger.debug("WebSocket connection established")
 62                    click.secho(
 63                        f"Connected to tunnel {self.tunnel_http_url}", fg="green"
 64                    )
 65                    retry_count = 0  # Reset retry count on successful connection
 66                    await self.forward_request(websocket)
 67            except (websockets.ConnectionClosed, ConnectionError) as e:
 68                if self.stop_event.is_set():
 69                    self.logger.debug("Stopping reconnect attempts due to shutdown")
 70                    break
 71                retry_count += 1
 72                self.logger.warning(
 73                    f"Connection lost: {e}. Retrying in 2 seconds... ({retry_count}/{max_retries})"
 74                )
 75                await asyncio.sleep(2)
 76            except asyncio.CancelledError:
 77                self.logger.debug("Connection cancelled")
 78                break
 79            except Exception as e:
 80                if self.stop_event.is_set():
 81                    self.logger.debug("Stopping reconnect attempts due to shutdown")
 82                    break
 83                retry_count += 1
 84                self.logger.error(
 85                    f"Unexpected error: {e}. Retrying in 2 seconds... ({retry_count}/{max_retries})"
 86                )
 87                await asyncio.sleep(2)
 88
 89    async def forward_request(self, websocket: Any) -> None:
 90        try:
 91            async for message in websocket:
 92                if isinstance(message, str):
 93                    # Received text message (metadata)
 94                    self.logger.debug("Received metadata from worker")
 95                    data = json.loads(message)
 96                    await self.handle_request_metadata(websocket, data)
 97                elif isinstance(message, bytes):
 98                    # Received binary message (body chunk)
 99                    self.logger.debug("Received binary data from worker")
100                    await self.handle_request_body_chunk(websocket, message)
101                else:
102                    self.logger.warning("Received unknown message type")
103        except asyncio.CancelledError:
104            self.logger.debug("Forward request cancelled")
105        except Exception as e:
106            self.logger.error(f"Error in forward_request: {e}")
107            raise
108
109    async def handle_request_metadata(
110        self, websocket: Any, data: dict[str, Any]
111    ) -> None:
112        request_id = data["id"]
113        has_body = data.get("has_body", False)
114        total_body_chunks = data.get("totalBodyChunks", 0)
115        self.pending_requests[request_id] = {
116            "metadata": data,
117            "body_chunks": {},
118            "has_body": has_body,
119            "total_body_chunks": total_body_chunks,
120        }
121        self.logger.debug(
122            f"Stored metadata for request ID: {request_id}, has_body: {has_body}"
123        )
124        await self.check_and_process_request(websocket, request_id)
125
126    async def handle_request_body_chunk(
127        self, websocket: Any, chunk_data: bytes
128    ) -> None:
129        offset = 0
130
131        # Extract id_length
132        id_length = int.from_bytes(chunk_data[offset : offset + 4], byteorder="little")
133        offset += 4
134
135        # Extract request_id
136        request_id = chunk_data[offset : offset + id_length].decode("utf-8")
137        offset += id_length
138
139        # Extract chunk_index
140        chunk_index = int.from_bytes(
141            chunk_data[offset : offset + 4], byteorder="little"
142        )
143        offset += 4
144
145        # Extract total_chunks
146        total_chunks = int.from_bytes(
147            chunk_data[offset : offset + 4], byteorder="little"
148        )
149        offset += 4
150
151        # Extract body_chunk
152        body_chunk = chunk_data[offset:]
153
154        # Continue processing as before
155
156        if request_id in self.pending_requests:
157            request = self.pending_requests[request_id]
158            if "body_chunks" not in request:
159                request["body_chunks"] = {}
160                request["total_body_chunks"] = total_chunks
161            request["body_chunks"][chunk_index] = body_chunk
162            self.logger.debug(
163                f"Stored body chunk {chunk_index + 1}/{total_chunks} for request ID: {request_id}"
164            )
165            await self.check_and_process_request(websocket, request_id)
166        else:
167            self.logger.warning(
168                f"Received body chunk for unknown or completed request ID: {request_id}"
169            )
170
171    async def check_and_process_request(self, websocket: Any, request_id: str) -> None:
172        request_data = self.pending_requests.get(request_id)
173        if request_data and request_data["metadata"]:
174            has_body = request_data["has_body"]
175            total_body_chunks = request_data.get("total_body_chunks", 0)
176            body_chunks = request_data.get("body_chunks", {})
177
178            all_chunks_received = not has_body or (
179                len(body_chunks) == total_body_chunks
180            )
181
182            if all_chunks_received:
183                # Ensure all chunks are present
184                for i in range(total_body_chunks):
185                    if i not in body_chunks:
186                        self.logger.error(
187                            f"Missing chunk {i + 1}/{total_body_chunks} for request ID: {request_id}"
188                        )
189                        return
190
191                self.logger.debug(f"Processing request ID: {request_id}")
192                await self.process_request(
193                    websocket, request_data["metadata"], body_chunks, request_id
194                )
195                del self.pending_requests[request_id]
196
197    async def process_request(
198        self,
199        websocket: Any,
200        request_metadata: dict[str, Any],
201        body_chunks: dict[int, bytes],
202        request_id: str,
203    ) -> None:
204        self.logger.debug(
205            f"Processing request: {request_id} {request_metadata['method']} {request_metadata['url']}"
206        )
207
208        # Reassemble body if present
209        if request_metadata["has_body"]:
210            total_chunks = request_metadata["totalBodyChunks"]
211            body_data = b"".join(body_chunks[i] for i in range(total_chunks))
212        else:
213            body_data = None
214
215        # Parse the original URL to extract the path and query
216        parsed_url = urlparse(request_metadata["url"])
217        path_and_query = parsed_url.path
218        if parsed_url.query:
219            path_and_query += f"?{parsed_url.query}"
220
221        # Construct the new URL by appending path and query to destination_url
222        forward_url = self.destination_url + path_and_query
223
224        self.logger.debug(f"Forwarding request to: {forward_url}")
225
226        # Create a custom SSL context to ignore SSL verification (if needed)
227        ssl_context = ssl.create_default_context()
228        ssl_context.check_hostname = False
229        ssl_context.verify_mode = ssl.CERT_NONE
230
231        # Prepare the request
232        req = urllib.request.Request(
233            url=forward_url,
234            method=request_metadata["method"],
235            data=body_data if body_data else None,
236            headers=request_metadata["headers"],  # Headers set directly on the request
237        )
238
239        # Override the HTTPErrorProcessor to stop processing redirects
240        class NoRedirectProcessor(urllib.request.HTTPErrorProcessor):
241            def http_response(
242                self, request: urllib.request.Request, response: Any
243            ) -> Any:
244                return response
245
246            https_response = http_response
247
248        # Create a custom opener that uses the NoRedirectProcessor
249        opener = urllib.request.build_opener(
250            urllib.request.HTTPHandler(),
251            urllib.request.HTTPSHandler(
252                context=ssl_context
253            ),  # Pass the SSL context here
254            NoRedirectProcessor(),
255        )
256
257        try:
258            # Make the request using our custom opener
259            with opener.open(req) as response:
260                response_body = response.read()
261                response_headers = dict(response.getheaders())
262                response_status = response.getcode()
263                self.logger.debug(
264                    f"Received response with status code: {response_status}"
265                )
266
267        except urllib.error.HTTPError as e:
268            # Non-200 status codes are here (even ones we want)
269            self.logger.debug(f"HTTPError forwarding request: {e}")
270            response_body = e.read()
271            response_headers = dict(e.headers)
272            response_status = e.code
273
274        except urllib.error.URLError as e:
275            self.logger.error(f"URLError forwarding request: {e}")
276            response_body = b""
277            response_headers = {}
278            response_status = 500
279
280        self.logger.info(
281            f"{click.style(request_metadata['method'], bold=True)} {request_metadata['url']} {response_status}"
282        )
283
284        has_body = len(response_body) > 0
285        max_chunk_size = 1000000  # 1,000,000 bytes
286        total_body_chunks = (
287            (len(response_body) + max_chunk_size - 1) // max_chunk_size
288            if has_body
289            else 0
290        )
291
292        response_metadata = {
293            "id": request_id,
294            "status": response_status,
295            "headers": list(response_headers.items()),
296            "has_body": has_body,
297            "totalBodyChunks": total_body_chunks,
298        }
299
300        # Send response metadata
301        response_metadata_json = json.dumps(response_metadata)
302        self.logger.debug(
303            f"Sending response metadata for ID: {request_id}, has_body: {has_body}"
304        )
305        await websocket.send(response_metadata_json)
306
307        # Send response body chunks if present
308        if has_body:
309            self.logger.debug(
310                f"Sending {total_body_chunks} body chunks for ID: {request_id}"
311            )
312            id_bytes = request_id.encode("utf-8")
313            for i in range(total_body_chunks):
314                chunk_start = i * max_chunk_size
315                chunk_end = min(chunk_start + max_chunk_size, len(response_body))
316                body_chunk = response_body[chunk_start:chunk_end]
317
318                # Prepare the binary message
319                chunk_index_bytes = i.to_bytes(4, byteorder="little")
320                total_chunks_bytes = total_body_chunks.to_bytes(4, byteorder="little")
321                message = id_bytes + chunk_index_bytes + total_chunks_bytes + body_chunk
322                await websocket.send(message)
323                self.logger.debug(
324                    f"Sent body chunk {i + 1}/{total_body_chunks} for ID: {request_id}"
325                )
326        else:
327            self.logger.debug(f"No body to send for ID: {request_id}")
328
329    async def shutdown(self) -> None:
330        self.stop_event.set()
331        tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
332        if tasks:
333            self.logger.debug(f"Cancelling {len(tasks)} outstanding tasks")
334            for task in tasks:
335                task.cancel()
336            await asyncio.gather(*tasks, return_exceptions=True)
337        await self.loop.shutdown_asyncgens()
338
339    def run(self) -> None:
340        try:
341            self.loop.run_until_complete(self.connect())
342        except KeyboardInterrupt:
343            self.logger.debug("Received exit signal")
344        finally:
345            self.logger.debug("Shutting down...")
346            self.loop.run_until_complete(self.shutdown())
347            self.loop.close()