1from __future__ import annotations
  2
  3import asyncio
  4import base64
  5import json
  6import logging
  7import math
  8import struct
  9from typing import Any
 10from urllib.parse import urlparse
 11
 12import click
 13import httpx
 14import websockets
 15
# Version of the client <-> server tunnel WebSocket protocol, sent as the
# ``v`` query parameter of the tunnel URL.
# Bump this when making breaking changes to the WebSocket protocol.
# The server will reject clients with a version lower than its minimum
# (TunnelClient.connect reports an HTTP 426 handshake response as "client too old").
PROTOCOL_VERSION = 3
 19
 20
 21class TunnelClient:
 22    def __init__(
 23        self, *, destination_url: str, subdomain: str, tunnel_host: str, log_level: str
 24    ) -> None:
 25        self.destination_url = destination_url
 26        self.subdomain = subdomain
 27        self.tunnel_host = tunnel_host
 28
 29        if "localhost" in tunnel_host or "127.0.0.1" in tunnel_host:
 30            self.tunnel_http_url = f"http://{subdomain}.{tunnel_host}"
 31            self.tunnel_websocket_url = (
 32                f"ws://{subdomain}.{tunnel_host}/__tunnel__?v={PROTOCOL_VERSION}"
 33            )
 34        else:
 35            self.tunnel_http_url = f"https://{subdomain}.{tunnel_host}"
 36            self.tunnel_websocket_url = (
 37                f"wss://{subdomain}.{tunnel_host}/__tunnel__?v={PROTOCOL_VERSION}"
 38            )
 39
 40        self.logger = logging.getLogger(__name__)
 41        level = getattr(logging, log_level.upper())
 42        self.logger.setLevel(level)
 43        self.logger.propagate = False
 44        handler = logging.StreamHandler()
 45        handler.setLevel(level)
 46        handler.setFormatter(logging.Formatter("%(message)s"))
 47        self.logger.addHandler(handler)
 48
 49        self.pending_requests: dict[str, dict[str, Any]] = {}
 50        self.active_streams: dict[str, asyncio.Event] = {}
 51        self.proxied_websockets: dict[str, Any] = {}
 52        self.ws_pending_queues: dict[str, asyncio.Queue[dict[str, Any]]] = {}
 53        self.stop_event = asyncio.Event()
 54
 55    async def connect(self) -> None:
 56        retry_delay = 1.0
 57        max_retry_delay = 30.0
 58        while not self.stop_event.is_set():
 59            try:
 60                self.logger.debug(
 61                    f"Connecting to WebSocket URL: {self.tunnel_websocket_url}"
 62                )
 63                async with websockets.connect(
 64                    self.tunnel_websocket_url, max_size=None
 65                ) as websocket:
 66                    self.logger.debug("WebSocket connection established")
 67                    click.secho(
 68                        f"Connected to tunnel {self.tunnel_http_url}", fg="green"
 69                    )
 70                    retry_delay = 1.0
 71                    try:
 72                        await self.handle_messages(websocket)
 73                    finally:
 74                        await self._cleanup_proxied_websockets()
 75            except asyncio.CancelledError:
 76                self.logger.debug("Connection cancelled")
 77                break
 78            except websockets.InvalidStatus as e:
 79                if e.response.status_code == 426:
 80                    body = e.response.body.decode() if e.response.body else ""
 81                    click.secho(
 82                        body or "Client version too old. Please upgrade plain.tunnel.",
 83                        fg="red",
 84                    )
 85                    break
 86                raise
 87            except (websockets.ConnectionClosed, ConnectionError) as e:
 88                if self.stop_event.is_set():
 89                    self.logger.debug("Stopping reconnect attempts due to shutdown")
 90                    break
 91                click.secho(
 92                    f"Connection lost: {e}. Retrying in {retry_delay:.0f}s...",
 93                    fg="yellow",
 94                )
 95                await asyncio.sleep(retry_delay)
 96                retry_delay = min(retry_delay * 2, max_retry_delay)
 97            except Exception as e:
 98                if self.stop_event.is_set():
 99                    self.logger.debug("Stopping reconnect attempts due to shutdown")
100                    break
101                click.secho(
102                    f"Unexpected error: {e}. Retrying in {retry_delay:.0f}s...",
103                    fg="yellow",
104                )
105                await asyncio.sleep(retry_delay)
106                retry_delay = min(retry_delay * 2, max_retry_delay)
107
108    async def handle_messages(self, websocket: Any) -> None:
109        try:
110            async for message in websocket:
111                if isinstance(message, str):
112                    data = json.loads(message)
113                    msg_type = data.get("type")
114                    if msg_type == "ping":
115                        self.logger.debug("Received heartbeat ping, sending pong")
116                        await websocket.send(json.dumps({"type": "pong"}))
117                    elif msg_type == "request":
118                        self.logger.debug("Received request metadata from worker")
119                        await self.handle_request_metadata(websocket, data)
120                    elif msg_type == "stream-cancel":
121                        request_id = data.get("id")
122                        self.logger.debug(
123                            f"Received stream-cancel for request ID: {request_id}"
124                        )
125                        cancel_event = self.active_streams.get(request_id)
126                        if cancel_event:
127                            cancel_event.set()
128                    elif msg_type == "ws-open":
129                        self.logger.debug(f"Received ws-open for ID: {data['id']}")
130                        self.ws_pending_queues[data["id"]] = asyncio.Queue()
131                        task = asyncio.create_task(
132                            self._handle_ws_open(websocket, data)
133                        )
134                        task.add_done_callback(self._handle_task_exception)
135                    elif msg_type == "ws-message":
136                        await self._handle_ws_message(data)
137                    elif msg_type == "ws-close":
138                        await self._handle_ws_close(data)
139                    else:
140                        self.logger.warning(
141                            f"Received unknown message type: {msg_type}"
142                        )
143                elif isinstance(message, bytes):
144                    self.logger.debug("Received binary data from worker")
145                    await self.handle_request_body_chunk(websocket, message)
146                else:
147                    self.logger.warning("Received unknown message format")
148        except asyncio.CancelledError:
149            self.logger.debug("Message handling cancelled")
150        except Exception as e:
151            self.logger.error(f"Error in handle_messages: {e}")
152            raise
153
154    async def handle_request_metadata(
155        self, websocket: Any, data: dict[str, Any]
156    ) -> None:
157        request_id = data["id"]
158        has_body = data.get("has_body", False)
159        total_body_chunks = data.get("totalBodyChunks", 0)
160        self.pending_requests[request_id] = {
161            "metadata": data,
162            "body_chunks": {},
163            "has_body": has_body,
164            "total_body_chunks": total_body_chunks,
165        }
166        self.logger.debug(
167            f"Stored metadata for request ID: {request_id}, has_body: {has_body}"
168        )
169        await self.check_and_process_request(websocket, request_id)
170
171    async def handle_request_body_chunk(
172        self, websocket: Any, chunk_data: bytes
173    ) -> None:
174        (id_length,) = struct.unpack_from("<I", chunk_data, 0)
175        request_id = chunk_data[4 : 4 + id_length].decode("utf-8")
176        header_end = 4 + id_length + 8
177        chunk_index, total_chunks = struct.unpack_from("<II", chunk_data, 4 + id_length)
178        body_chunk = chunk_data[header_end:]
179
180        if request_id in self.pending_requests:
181            request = self.pending_requests[request_id]
182            request["body_chunks"][chunk_index] = body_chunk
183            self.logger.debug(
184                f"Stored body chunk {chunk_index + 1}/{total_chunks} for request ID: {request_id}"
185            )
186            await self.check_and_process_request(websocket, request_id)
187        else:
188            self.logger.warning(
189                f"Received body chunk for unknown or completed request ID: {request_id}"
190            )
191
192    async def check_and_process_request(self, websocket: Any, request_id: str) -> None:
193        request_data = self.pending_requests.get(request_id)
194        if not request_data:
195            return
196
197        has_body = request_data["has_body"]
198        total_body_chunks = request_data["total_body_chunks"]
199        body_chunks = request_data["body_chunks"]
200
201        all_chunks_received = not has_body or len(body_chunks) == total_body_chunks
202        if not all_chunks_received:
203            return
204
205        for i in range(total_body_chunks):
206            if i not in body_chunks:
207                self.logger.error(
208                    f"Missing chunk {i + 1}/{total_body_chunks} for request ID: {request_id}"
209                )
210                return
211
212        self.logger.debug(f"Processing request ID: {request_id}")
213        del self.pending_requests[request_id]
214        task = asyncio.create_task(
215            self.process_request(
216                websocket,
217                request_data["metadata"],
218                body_chunks,
219                request_id,
220            )
221        )
222        task.add_done_callback(self._handle_task_exception)
223
224    def _handle_task_exception(self, task: asyncio.Task[None]) -> None:
225        if not task.cancelled() and task.exception():
226            self.logger.error("Error processing request", exc_info=task.exception())
227
228    async def _cleanup_proxied_websockets(self) -> None:
229        """Close all proxied WebSocket connections on tunnel disconnect."""
230        for ws_id, ws in list(self.proxied_websockets.items()):
231            try:
232                await ws.close()
233            except Exception:
234                pass
235        self.proxied_websockets.clear()
236        self.ws_pending_queues.clear()
237
238    async def process_request(
239        self,
240        websocket: Any,
241        request_metadata: dict[str, Any],
242        body_chunks: dict[int, bytes],
243        request_id: str,
244    ) -> None:
245        self.logger.debug(
246            f"Processing request: {request_id} {request_metadata['method']} {request_metadata['url']}"
247        )
248
249        if request_metadata["has_body"]:
250            total_chunks = request_metadata["totalBodyChunks"]
251            body_data = b"".join(body_chunks[i] for i in range(total_chunks))
252        else:
253            body_data = None
254
255        parsed = urlparse(request_metadata["url"])
256        path = parsed.path
257        if parsed.query:
258            path = f"{path}?{parsed.query}"
259        forward_url = f"{self.destination_url}{path}"
260
261        self.logger.debug(f"Forwarding request to: {forward_url}")
262
263        async with httpx.AsyncClient(
264            follow_redirects=False, verify=False, timeout=30
265        ) as client:
266            try:
267                async with client.stream(
268                    method=request_metadata["method"],
269                    url=forward_url,
270                    headers=request_metadata["headers"],
271                    content=body_data,
272                ) as response:
273                    response_status = response.status_code
274                    response_headers = dict(response.headers)
275
276                    self.logger.info(
277                        f"{click.style(request_metadata['method'], bold=True)} {request_metadata['url']} {response_status}"
278                    )
279
280                    if self._is_streaming_response(response):
281                        await self._handle_streaming_response(
282                            websocket,
283                            response,
284                            request_id,
285                            response_status,
286                            response_headers,
287                        )
288                    else:
289                        await response.aread()
290                        await self._handle_buffered_response(
291                            websocket,
292                            response.content,
293                            request_id,
294                            response_status,
295                            response_headers,
296                        )
297            except httpx.ConnectError as e:
298                self.logger.error(f"Connection error forwarding request: {e}")
299                self.logger.info(
300                    f"{click.style(request_metadata['method'], bold=True)} {request_metadata['url']} 502"
301                )
302                await self._handle_buffered_response(
303                    websocket, b"", request_id, 502, {}
304                )
305
306    def _is_streaming_response(self, response: httpx.Response) -> bool:
307        content_type = response.headers.get("content-type", "")
308        return "text/event-stream" in content_type
309
310    async def _handle_buffered_response(
311        self,
312        websocket: Any,
313        response_body: bytes,
314        request_id: str,
315        response_status: int,
316        response_headers: dict[str, str],
317    ) -> None:
318        has_body = len(response_body) > 0
319        max_chunk_size = 1_000_000
320        total_body_chunks = (
321            math.ceil(len(response_body) / max_chunk_size) if has_body else 0
322        )
323
324        response_metadata = {
325            "type": "response",
326            "id": request_id,
327            "status": response_status,
328            "headers": list(response_headers.items()),
329            "has_body": has_body,
330            "totalBodyChunks": total_body_chunks,
331        }
332
333        self.logger.debug(
334            f"Sending response metadata for ID: {request_id}, has_body: {has_body}"
335        )
336        await websocket.send(json.dumps(response_metadata))
337
338        if has_body:
339            self.logger.debug(
340                f"Sending {total_body_chunks} body chunks for ID: {request_id}"
341            )
342            id_bytes = request_id.encode("utf-8")
343            for i in range(total_body_chunks):
344                chunk_start = i * max_chunk_size
345                chunk_end = min(chunk_start + max_chunk_size, len(response_body))
346                header = id_bytes + struct.pack("<II", i, total_body_chunks)
347                await websocket.send(header + response_body[chunk_start:chunk_end])
348                self.logger.debug(
349                    f"Sent body chunk {i + 1}/{total_body_chunks} for ID: {request_id}"
350                )
351
352    async def _handle_streaming_response(
353        self,
354        websocket: Any,
355        response: httpx.Response,
356        request_id: str,
357        response_status: int,
358        response_headers: dict[str, str],
359    ) -> None:
360        cancel_event = asyncio.Event()
361        self.active_streams[request_id] = cancel_event
362
363        stream_start = {
364            "type": "stream-start",
365            "id": request_id,
366            "status": response_status,
367            "headers": list(response_headers.items()),
368        }
369
370        self.logger.debug(f"Sending stream-start for ID: {request_id}")
371        await websocket.send(json.dumps(stream_start))
372
373        id_bytes = request_id.encode("utf-8")
374
375        try:
376            async for chunk in response.aiter_bytes():
377                if cancel_event.is_set():
378                    self.logger.debug(
379                        f"Stream cancelled by browser for request ID: {request_id}"
380                    )
381                    break
382
383                await websocket.send(id_bytes + chunk)
384            else:
385                # Only send stream-end if the loop completed naturally
386                # (not cancelled by the server via stream-cancel)
387                stream_end = {
388                    "type": "stream-end",
389                    "id": request_id,
390                }
391                self.logger.debug(f"Sending stream-end for ID: {request_id}")
392                await websocket.send(json.dumps(stream_end))
393        except Exception as e:
394            self.logger.error(f"Error streaming response for ID {request_id}: {e}")
395            stream_error = {
396                "type": "stream-error",
397                "id": request_id,
398                "error": str(e),
399            }
400            try:
401                await websocket.send(json.dumps(stream_error))
402            except Exception:
403                pass
404        finally:
405            self.active_streams.pop(request_id, None)
406
407    async def _handle_ws_open(self, tunnel_ws: Any, data: dict[str, Any]) -> None:
408        ws_id = data["id"]
409        url = data["url"]
410        parsed = urlparse(url)
411        path = parsed.path
412        if parsed.query:
413            path = f"{path}?{parsed.query}"
414
415        # Build local WebSocket URL
416        dest_parsed = urlparse(self.destination_url)
417        if dest_parsed.scheme == "https":
418            ws_scheme = "wss"
419        else:
420            ws_scheme = "ws"
421        local_ws_url = f"{ws_scheme}://{dest_parsed.netloc}{path}"
422
423        self.logger.debug(f"Opening local WebSocket for {ws_id}: {local_ws_url}")
424
425        # Forward safe browser headers (cookies, auth, origin) to the local
426        # server. Skip hop-by-hop and WebSocket handshake headers since
427        # websockets.connect generates its own (including Host from the URL).
428        skip_headers = frozenset(
429            {
430                "host",
431                "connection",
432                "upgrade",
433                "sec-websocket-key",
434                "sec-websocket-version",
435                "sec-websocket-extensions",
436                "sec-websocket-protocol",
437                "host",
438            }
439        )
440        forward_headers = {}
441        for name, value in data.get("headers", {}).items():
442            if name.lower() not in skip_headers:
443                forward_headers[name] = value
444
445        try:
446            import ssl
447
448            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
449            ssl_context.check_hostname = False
450            ssl_context.verify_mode = ssl.CERT_NONE
451            local_ws = await websockets.connect(
452                local_ws_url,
453                ssl=ssl_context if ws_scheme == "wss" else None,
454                max_size=None,
455                additional_headers=forward_headers,
456            )
457        except Exception as e:
458            self.logger.error(f"Failed to connect local WebSocket for {ws_id}: {e}")
459            self.ws_pending_queues.pop(ws_id, None)
460            try:
461                await tunnel_ws.send(
462                    json.dumps(
463                        {
464                            "type": "ws-close",
465                            "id": ws_id,
466                            "code": 1011,
467                            "reason": str(e),
468                        }
469                    )
470                )
471            except Exception:
472                pass
473            return
474
475        self.proxied_websockets[ws_id] = local_ws
476
477        # Drain any messages that arrived while connecting
478        queue = self.ws_pending_queues.pop(ws_id, None)
479        if queue is not None:
480            while not queue.empty():
481                queued = queue.get_nowait()
482                try:
483                    if queued.get("binary"):
484                        await local_ws.send(base64.b64decode(queued["data"]))
485                    else:
486                        await local_ws.send(queued["data"])
487                except Exception as e:
488                    self.logger.error(
489                        f"Failed to forward queued message to local WebSocket {ws_id}: {e}"
490                    )
491
492        self.logger.info(f"WebSocket proxy opened: {ws_id} -> {local_ws_url}")
493
494        # Relay messages from local server back to the tunnel
495        try:
496            async for message in local_ws:
497                if isinstance(message, str):
498                    await tunnel_ws.send(
499                        json.dumps({"type": "ws-message", "id": ws_id, "data": message})
500                    )
501                elif isinstance(message, bytes):
502                    await tunnel_ws.send(
503                        json.dumps(
504                            {
505                                "type": "ws-message",
506                                "id": ws_id,
507                                "data": base64.b64encode(message).decode("ascii"),
508                                "binary": True,
509                            }
510                        )
511                    )
512        except websockets.ConnectionClosed:
513            pass
514        except Exception as e:
515            self.logger.error(f"Error relaying WebSocket {ws_id}: {e}")
516        finally:
517            self.proxied_websockets.pop(ws_id, None)
518            close_code = local_ws.close_code or 1000
519            close_reason = local_ws.close_reason or ""
520            try:
521                await tunnel_ws.send(
522                    json.dumps(
523                        {
524                            "type": "ws-close",
525                            "id": ws_id,
526                            "code": close_code,
527                            "reason": close_reason,
528                        }
529                    )
530                )
531            except Exception:
532                pass
533
534    async def _handle_ws_message(self, data: dict[str, Any]) -> None:
535        ws_id = data["id"]
536        local_ws = self.proxied_websockets.get(ws_id)
537        if not local_ws:
538            # Connection still being established — buffer for later
539            queue = self.ws_pending_queues.get(ws_id)
540            if queue is not None:
541                await queue.put(data)
542                return
543            self.logger.warning(f"Received ws-message for unknown WebSocket: {ws_id}")
544            return
545        try:
546            if data.get("binary"):
547                await local_ws.send(base64.b64decode(data["data"]))
548            else:
549                await local_ws.send(data["data"])
550        except Exception as e:
551            self.logger.error(
552                f"Failed to forward message to local WebSocket {ws_id}: {e}"
553            )
554
555    async def _handle_ws_close(self, data: dict[str, Any]) -> None:
556        ws_id = data["id"]
557        local_ws = self.proxied_websockets.pop(ws_id, None)
558        if not local_ws:
559            return
560        try:
561            await local_ws.close()
562        except Exception:
563            pass
564
565    def run(self) -> None:
566        try:
567            asyncio.run(self.connect())
568        except KeyboardInterrupt:
569            self.logger.debug("Received exit signal")