1from __future__ import annotations
  2
  3import asyncio
  4import json
  5import logging
  6import math
  7import struct
  8from typing import Any
  9from urllib.parse import urlparse
 10
 11import click
 12import httpx
 13import websockets
 14
# Version of the tunnel WebSocket protocol spoken by this client. It is sent
# to the server as the ``v`` query parameter of the WebSocket URL.
# Bump this when making breaking changes to the WebSocket protocol.
# The server will reject clients with a version lower than its minimum.
PROTOCOL_VERSION = 1
 18
 19
 20class TunnelClient:
 21    def __init__(
 22        self, *, destination_url: str, subdomain: str, tunnel_host: str, log_level: str
 23    ) -> None:
 24        self.destination_url = destination_url
 25        self.subdomain = subdomain
 26        self.tunnel_host = tunnel_host
 27
 28        self.tunnel_http_url = f"https://{subdomain}.{tunnel_host}"
 29        self.tunnel_websocket_url = (
 30            f"wss://{subdomain}.{tunnel_host}?v={PROTOCOL_VERSION}"
 31        )
 32
 33        self.logger = logging.getLogger(__name__)
 34        level = getattr(logging, log_level.upper())
 35        self.logger.setLevel(level)
 36        self.logger.propagate = False
 37        handler = logging.StreamHandler()
 38        handler.setLevel(level)
 39        handler.setFormatter(logging.Formatter("%(message)s"))
 40        self.logger.addHandler(handler)
 41
 42        self.pending_requests: dict[str, dict[str, Any]] = {}
 43        self.stop_event = asyncio.Event()
 44
 45    async def connect(self) -> None:
 46        retry_delay = 1.0
 47        max_retry_delay = 30.0
 48        while not self.stop_event.is_set():
 49            try:
 50                self.logger.debug(
 51                    f"Connecting to WebSocket URL: {self.tunnel_websocket_url}"
 52                )
 53                async with websockets.connect(
 54                    self.tunnel_websocket_url, max_size=None
 55                ) as websocket:
 56                    self.logger.debug("WebSocket connection established")
 57                    click.secho(
 58                        f"Connected to tunnel {self.tunnel_http_url}", fg="green"
 59                    )
 60                    retry_delay = 1.0
 61                    await self.handle_messages(websocket)
 62            except asyncio.CancelledError:
 63                self.logger.debug("Connection cancelled")
 64                break
 65            except websockets.InvalidStatus as e:
 66                if e.response.status_code == 426:
 67                    body = e.response.body.decode() if e.response.body else ""
 68                    click.secho(
 69                        body or "Client version too old. Please upgrade plain.tunnel.",
 70                        fg="red",
 71                    )
 72                    break
 73                raise
 74            except (websockets.ConnectionClosed, ConnectionError) as e:
 75                if self.stop_event.is_set():
 76                    self.logger.debug("Stopping reconnect attempts due to shutdown")
 77                    break
 78                click.secho(
 79                    f"Connection lost: {e}. Retrying in {retry_delay:.0f}s...",
 80                    fg="yellow",
 81                )
 82                await asyncio.sleep(retry_delay)
 83                retry_delay = min(retry_delay * 2, max_retry_delay)
 84            except Exception as e:
 85                if self.stop_event.is_set():
 86                    self.logger.debug("Stopping reconnect attempts due to shutdown")
 87                    break
 88                click.secho(
 89                    f"Unexpected error: {e}. Retrying in {retry_delay:.0f}s...",
 90                    fg="yellow",
 91                )
 92                await asyncio.sleep(retry_delay)
 93                retry_delay = min(retry_delay * 2, max_retry_delay)
 94
 95    async def handle_messages(self, websocket: Any) -> None:
 96        try:
 97            async for message in websocket:
 98                if isinstance(message, str):
 99                    data = json.loads(message)
100                    msg_type = data.get("type")
101                    if msg_type == "ping":
102                        self.logger.debug("Received heartbeat ping, sending pong")
103                        await websocket.send(json.dumps({"type": "pong"}))
104                    elif msg_type == "request":
105                        self.logger.debug("Received request metadata from worker")
106                        await self.handle_request_metadata(websocket, data)
107                    else:
108                        self.logger.warning(
109                            f"Received unknown message type: {msg_type}"
110                        )
111                elif isinstance(message, bytes):
112                    self.logger.debug("Received binary data from worker")
113                    await self.handle_request_body_chunk(websocket, message)
114                else:
115                    self.logger.warning("Received unknown message format")
116        except asyncio.CancelledError:
117            self.logger.debug("Message handling cancelled")
118        except Exception as e:
119            self.logger.error(f"Error in handle_messages: {e}")
120            raise
121
122    async def handle_request_metadata(
123        self, websocket: Any, data: dict[str, Any]
124    ) -> None:
125        request_id = data["id"]
126        has_body = data.get("has_body", False)
127        total_body_chunks = data.get("totalBodyChunks", 0)
128        self.pending_requests[request_id] = {
129            "metadata": data,
130            "body_chunks": {},
131            "has_body": has_body,
132            "total_body_chunks": total_body_chunks,
133        }
134        self.logger.debug(
135            f"Stored metadata for request ID: {request_id}, has_body: {has_body}"
136        )
137        await self.check_and_process_request(websocket, request_id)
138
139    async def handle_request_body_chunk(
140        self, websocket: Any, chunk_data: bytes
141    ) -> None:
142        (id_length,) = struct.unpack_from("<I", chunk_data, 0)
143        request_id = chunk_data[4 : 4 + id_length].decode("utf-8")
144        header_end = 4 + id_length + 8
145        chunk_index, total_chunks = struct.unpack_from("<II", chunk_data, 4 + id_length)
146        body_chunk = chunk_data[header_end:]
147
148        if request_id in self.pending_requests:
149            request = self.pending_requests[request_id]
150            request["body_chunks"][chunk_index] = body_chunk
151            self.logger.debug(
152                f"Stored body chunk {chunk_index + 1}/{total_chunks} for request ID: {request_id}"
153            )
154            await self.check_and_process_request(websocket, request_id)
155        else:
156            self.logger.warning(
157                f"Received body chunk for unknown or completed request ID: {request_id}"
158            )
159
160    async def check_and_process_request(self, websocket: Any, request_id: str) -> None:
161        request_data = self.pending_requests.get(request_id)
162        if not request_data:
163            return
164
165        has_body = request_data["has_body"]
166        total_body_chunks = request_data["total_body_chunks"]
167        body_chunks = request_data["body_chunks"]
168
169        all_chunks_received = not has_body or len(body_chunks) == total_body_chunks
170        if not all_chunks_received:
171            return
172
173        for i in range(total_body_chunks):
174            if i not in body_chunks:
175                self.logger.error(
176                    f"Missing chunk {i + 1}/{total_body_chunks} for request ID: {request_id}"
177                )
178                return
179
180        self.logger.debug(f"Processing request ID: {request_id}")
181        del self.pending_requests[request_id]
182        task = asyncio.create_task(
183            self.process_request(
184                websocket,
185                request_data["metadata"],
186                body_chunks,
187                request_id,
188            )
189        )
190        task.add_done_callback(self._handle_task_exception)
191
192    def _handle_task_exception(self, task: asyncio.Task[None]) -> None:
193        if not task.cancelled() and task.exception():
194            self.logger.error(f"Error processing request: {task.exception()}")
195
196    async def process_request(
197        self,
198        websocket: Any,
199        request_metadata: dict[str, Any],
200        body_chunks: dict[int, bytes],
201        request_id: str,
202    ) -> None:
203        self.logger.debug(
204            f"Processing request: {request_id} {request_metadata['method']} {request_metadata['url']}"
205        )
206
207        if request_metadata["has_body"]:
208            total_chunks = request_metadata["totalBodyChunks"]
209            body_data = b"".join(body_chunks[i] for i in range(total_chunks))
210        else:
211            body_data = None
212
213        parsed = urlparse(request_metadata["url"])
214        path = parsed.path
215        if parsed.query:
216            path = f"{path}?{parsed.query}"
217        forward_url = f"{self.destination_url}{path}"
218
219        self.logger.debug(f"Forwarding request to: {forward_url}")
220
221        async with httpx.AsyncClient(follow_redirects=False, verify=False) as client:
222            try:
223                response = await client.request(
224                    method=request_metadata["method"],
225                    url=forward_url,
226                    headers=request_metadata["headers"],
227                    content=body_data,
228                )
229                response_body = response.content
230                response_headers = dict(response.headers)
231                response_status = response.status_code
232                self.logger.debug(
233                    f"Received response with status code: {response_status}"
234                )
235            except httpx.ConnectError as e:
236                self.logger.error(f"Connection error forwarding request: {e}")
237                response_body = b""
238                response_headers = {}
239                response_status = 502
240
241        self.logger.info(
242            f"{click.style(request_metadata['method'], bold=True)} {request_metadata['url']} {response_status}"
243        )
244
245        has_body = len(response_body) > 0
246        max_chunk_size = 1_000_000
247        total_body_chunks = (
248            math.ceil(len(response_body) / max_chunk_size) if has_body else 0
249        )
250
251        response_metadata = {
252            "type": "response",
253            "id": request_id,
254            "status": response_status,
255            "headers": list(response_headers.items()),
256            "has_body": has_body,
257            "totalBodyChunks": total_body_chunks,
258        }
259
260        self.logger.debug(
261            f"Sending response metadata for ID: {request_id}, has_body: {has_body}"
262        )
263        await websocket.send(json.dumps(response_metadata))
264
265        if has_body:
266            self.logger.debug(
267                f"Sending {total_body_chunks} body chunks for ID: {request_id}"
268            )
269            id_bytes = request_id.encode("utf-8")
270            for i in range(total_body_chunks):
271                chunk_start = i * max_chunk_size
272                chunk_end = min(chunk_start + max_chunk_size, len(response_body))
273                header = id_bytes + struct.pack("<II", i, total_body_chunks)
274                await websocket.send(header + response_body[chunk_start:chunk_end])
275                self.logger.debug(
276                    f"Sent body chunk {i + 1}/{total_body_chunks} for ID: {request_id}"
277                )
278
279    def run(self) -> None:
280        try:
281            asyncio.run(self.connect())
282        except KeyboardInterrupt:
283            self.logger.debug("Received exit signal")