from __future__ import annotations

import asyncio
import json
import logging
import ssl
import urllib.error
import urllib.request
from typing import Any
from urllib.parse import urlparse

import click
import websockets


class TunnelClient:
    def __init__(
        self, *, destination_url: str, subdomain: str, tunnel_host: str, log_level: str
    ) -> None:
        self.destination_url = destination_url
        self.subdomain = subdomain
        self.tunnel_host = tunnel_host

        self.tunnel_http_url = f"https://{subdomain}.{tunnel_host}"
        self.tunnel_websocket_url = f"wss://{subdomain}.{tunnel_host}"

        # Set up logging
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(getattr(logging, log_level.upper()))
        ch = logging.StreamHandler()
        ch.setLevel(getattr(logging, log_level.upper()))
        formatter = logging.Formatter("%(message)s")
        ch.setFormatter(formatter)
        self.logger.addHandler(ch)

        # Incoming requests keyed by request ID, kept until all body chunks arrive
        self.pending_requests: dict[str, dict[str, Any]] = {}

        # Create the event loop
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.stop_event = asyncio.Event()

    async def connect(self) -> None:
        retry_count = 0
        max_retries = 5
        while not self.stop_event.is_set():
            if retry_count >= max_retries:
                self.logger.error(
                    f"Failed to connect after {max_retries} retries. Exiting."
                )
                break
            try:
                self.logger.debug(
                    f"Connecting to WebSocket URL: {self.tunnel_websocket_url}"
                )
                async with websockets.connect(
                    self.tunnel_websocket_url, max_size=None
                ) as websocket:
                    self.logger.debug("WebSocket connection established")
                    click.secho(
                        f"Connected to tunnel {self.tunnel_http_url}", fg="green"
                    )
                    retry_count = 0  # Reset retry count on successful connection
                    await self.forward_request(websocket)
            except (websockets.ConnectionClosed, ConnectionError) as e:
                if self.stop_event.is_set():
                    self.logger.debug("Stopping reconnect attempts due to shutdown")
                    break
                retry_count += 1
                self.logger.warning(
                    f"Connection lost: {e}. Retrying in 2 seconds... ({retry_count}/{max_retries})"
                )
                await asyncio.sleep(2)
            except asyncio.CancelledError:
                self.logger.debug("Connection cancelled")
                break
            except Exception as e:
                if self.stop_event.is_set():
                    self.logger.debug("Stopping reconnect attempts due to shutdown")
                    break
                retry_count += 1
                self.logger.error(
                    f"Unexpected error: {e}. Retrying in 2 seconds... ({retry_count}/{max_retries})"
                )
                await asyncio.sleep(2)

    async def forward_request(self, websocket: Any) -> None:
        try:
            async for message in websocket:
                if isinstance(message, str):
                    # Received text message (metadata)
                    self.logger.debug("Received metadata from worker")
                    data = json.loads(message)
                    await self.handle_request_metadata(websocket, data)
                elif isinstance(message, bytes):
                    # Received binary message (body chunk)
                    self.logger.debug("Received binary data from worker")
                    await self.handle_request_body_chunk(websocket, message)
                else:
                    self.logger.warning("Received unknown message type")
        except asyncio.CancelledError:
            self.logger.debug("Forward request cancelled")
        except Exception as e:
            self.logger.error(f"Error in forward_request: {e}")
            raise

    async def handle_request_metadata(
        self, websocket: Any, data: dict[str, Any]
    ) -> None:
        request_id = data["id"]
        has_body = data.get("has_body", False)
        total_body_chunks = data.get("totalBodyChunks", 0)
        self.pending_requests[request_id] = {
            "metadata": data,
            "body_chunks": {},
            "has_body": has_body,
            "total_body_chunks": total_body_chunks,
        }
        self.logger.debug(
            f"Stored metadata for request ID: {request_id}, has_body: {has_body}"
        )
        await self.check_and_process_request(websocket, request_id)

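    # Binary body-chunk frames from the worker are unpacked below as:
    #   [4-byte little-endian id_length][request_id (UTF-8, id_length bytes)]
    #   [4-byte little-endian chunk_index][4-byte little-endian total_chunks][body bytes]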
    async def handle_request_body_chunk(
        self, websocket: Any, chunk_data: bytes
    ) -> None:
        offset = 0

        # Extract id_length
        id_length = int.from_bytes(chunk_data[offset : offset + 4], byteorder="little")
        offset += 4

        # Extract request_id
        request_id = chunk_data[offset : offset + id_length].decode("utf-8")
        offset += id_length

        # Extract chunk_index
        chunk_index = int.from_bytes(
            chunk_data[offset : offset + 4], byteorder="little"
        )
        offset += 4

        # Extract total_chunks
        total_chunks = int.from_bytes(
            chunk_data[offset : offset + 4], byteorder="little"
        )
        offset += 4

        # Extract body_chunk
        body_chunk = chunk_data[offset:]

        if request_id in self.pending_requests:
            request = self.pending_requests[request_id]
            if "body_chunks" not in request:
                request["body_chunks"] = {}
                request["total_body_chunks"] = total_chunks
            request["body_chunks"][chunk_index] = body_chunk
            self.logger.debug(
                f"Stored body chunk {chunk_index + 1}/{total_chunks} for request ID: {request_id}"
            )
            await self.check_and_process_request(websocket, request_id)
        else:
            self.logger.warning(
                f"Received body chunk for unknown or completed request ID: {request_id}"
            )

    async def check_and_process_request(self, websocket: Any, request_id: str) -> None:
        request_data = self.pending_requests.get(request_id)
        if request_data and request_data["metadata"]:
            has_body = request_data["has_body"]
            total_body_chunks = request_data.get("total_body_chunks", 0)
            body_chunks = request_data.get("body_chunks", {})

            all_chunks_received = not has_body or (
                len(body_chunks) == total_body_chunks
            )

            if all_chunks_received:
                # Ensure all chunks are present
                for i in range(total_body_chunks):
                    if i not in body_chunks:
                        self.logger.error(
                            f"Missing chunk {i + 1}/{total_body_chunks} for request ID: {request_id}"
                        )
                        return

                self.logger.debug(f"Processing request ID: {request_id}")
                await self.process_request(
                    websocket, request_data["metadata"], body_chunks, request_id
                )
                del self.pending_requests[request_id]

    async def process_request(
        self,
        websocket: Any,
        request_metadata: dict[str, Any],
        body_chunks: dict[int, bytes],
        request_id: str,
    ) -> None:
        self.logger.debug(
            f"Processing request: {request_id} {request_metadata['method']} {request_metadata['url']}"
        )

        # Reassemble body if present
        if request_metadata["has_body"]:
            total_chunks = request_metadata["totalBodyChunks"]
            body_data = b"".join(body_chunks[i] for i in range(total_chunks))
        else:
            body_data = None

        # Parse the original URL to extract the path and query
        parsed_url = urlparse(request_metadata["url"])
        path_and_query = parsed_url.path
        if parsed_url.query:
            path_and_query += f"?{parsed_url.query}"

        # Construct the new URL by appending path and query to destination_url
        forward_url = self.destination_url + path_and_query

        self.logger.debug(f"Forwarding request to: {forward_url}")

        # Create a custom SSL context to ignore SSL verification (if needed)
        ssl_context = ssl.create_default_context()
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE

        # Prepare the request
        req = urllib.request.Request(
            url=forward_url,
            method=request_metadata["method"],
            data=body_data if body_data else None,
            headers=request_metadata["headers"],  # Headers set directly on the request
        )

        # Override the HTTPErrorProcessor so redirects are not followed and
        # non-2xx responses are returned instead of raised
        class NoRedirectProcessor(urllib.request.HTTPErrorProcessor):
            def http_response(
                self, request: urllib.request.Request, response: Any
            ) -> Any:
                return response

            https_response = http_response

        # Create a custom opener that uses the NoRedirectProcessor
        opener = urllib.request.build_opener(
            urllib.request.HTTPHandler(),
            urllib.request.HTTPSHandler(
                context=ssl_context
            ),  # Pass the SSL context here
            NoRedirectProcessor(),
        )

        try:
            # Make the request using our custom opener
            with opener.open(req) as response:
                response_body = response.read()
                response_headers = dict(response.getheaders())
                response_status = response.getcode()
                self.logger.debug(
                    f"Received response with status code: {response_status}"
                )

        except urllib.error.HTTPError as e:
            # Non-2xx responses can still surface as HTTPError; forward them as-is
            self.logger.debug(f"HTTPError forwarding request: {e}")
            response_body = e.read()
            response_headers = dict(e.headers)
            response_status = e.code

        except urllib.error.URLError as e:
            self.logger.error(f"URLError forwarding request: {e}")
            response_body = b""
            response_headers = {}
            response_status = 500

        self.logger.info(
            f"{click.style(request_metadata['method'], bold=True)} {request_metadata['url']} {response_status}"
        )

        has_body = len(response_body) > 0
        max_chunk_size = 1000000  # 1,000,000 bytes
        total_body_chunks = (
            (len(response_body) + max_chunk_size - 1) // max_chunk_size
            if has_body
            else 0
        )

        response_metadata = {
            "id": request_id,
            "status": response_status,
            "headers": list(response_headers.items()),
            "has_body": has_body,
            "totalBodyChunks": total_body_chunks,
        }

        # Send response metadata
        response_metadata_json = json.dumps(response_metadata)
        self.logger.debug(
            f"Sending response metadata for ID: {request_id}, has_body: {has_body}"
        )
        await websocket.send(response_metadata_json)

        # Send response body chunks if present
        if has_body:
            self.logger.debug(
                f"Sending {total_body_chunks} body chunks for ID: {request_id}"
            )
            id_bytes = request_id.encode("utf-8")
            for i in range(total_body_chunks):
                chunk_start = i * max_chunk_size
                chunk_end = min(chunk_start + max_chunk_size, len(response_body))
                body_chunk = response_body[chunk_start:chunk_end]

                # Prepare the binary message: request_id + chunk_index + total_chunks + body
                chunk_index_bytes = i.to_bytes(4, byteorder="little")
                total_chunks_bytes = total_body_chunks.to_bytes(4, byteorder="little")
                message = id_bytes + chunk_index_bytes + total_chunks_bytes + body_chunk
                await websocket.send(message)
                self.logger.debug(
                    f"Sent body chunk {i + 1}/{total_body_chunks} for ID: {request_id}"
                )
        else:
            self.logger.debug(f"No body to send for ID: {request_id}")

    async def shutdown(self) -> None:
        self.stop_event.set()
        tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
        if tasks:
            self.logger.debug(f"Cancelling {len(tasks)} outstanding tasks")
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
        await self.loop.shutdown_asyncgens()

    def run(self) -> None:
        try:
            self.loop.run_until_complete(self.connect())
        except KeyboardInterrupt:
            self.logger.debug("Received exit signal")
        finally:
            self.logger.debug("Shutting down...")
            self.loop.run_until_complete(self.shutdown())
            self.loop.close()
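

# Minimal usage sketch (illustrative only; the real CLI entry point that builds
# this client lives elsewhere in the package, and the argument values below are
# placeholder assumptions):
#
#   client = TunnelClient(
#       destination_url="http://localhost:8000",
#       subdomain="my-app",
#       tunnel_host="example.com",
#       log_level="INFO",
#   )
#   client.run()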