from __future__ import annotations

import asyncio
import base64
import json
import logging
import math
import ssl
import struct
from typing import Any
from urllib.parse import urlparse

import click
import httpx
import websockets
 15
 16# Bump this when making breaking changes to the WebSocket protocol.
 17# The server will reject clients with a version lower than its minimum.
 18PROTOCOL_VERSION = 3
 19
 20
 21class TunnelClient:
 22    def __init__(
 23        self, *, destination_url: str, subdomain: str, tunnel_host: str, log_level: str
 24    ) -> None:
 25        self.destination_url = destination_url
 26        self.subdomain = subdomain
 27        self.tunnel_host = tunnel_host
 28
 29        if "localhost" in tunnel_host or "127.0.0.1" in tunnel_host:
 30            self.tunnel_http_url = f"http://{subdomain}.{tunnel_host}"
 31            self.tunnel_websocket_url = (
 32                f"ws://{subdomain}.{tunnel_host}/__tunnel__?v={PROTOCOL_VERSION}"
 33            )
 34        else:
 35            self.tunnel_http_url = f"https://{subdomain}.{tunnel_host}"
 36            self.tunnel_websocket_url = (
 37                f"wss://{subdomain}.{tunnel_host}/__tunnel__?v={PROTOCOL_VERSION}"
 38            )
 39
 40        self.logger = logging.getLogger(__name__)
 41        level = getattr(logging, log_level.upper())
 42        self.logger.setLevel(level)
 43        self.logger.propagate = False
 44        handler = logging.StreamHandler()
 45        handler.setLevel(level)
 46        handler.setFormatter(logging.Formatter("%(message)s"))
 47        self.logger.addHandler(handler)
 48
 49        self.pending_requests: dict[str, dict[str, Any]] = {}
 50        self.active_streams: dict[str, asyncio.Event] = {}
 51        self.proxied_websockets: dict[str, Any] = {}
 52        self.ws_pending_queues: dict[str, asyncio.Queue[dict[str, Any]]] = {}
 53        self.stop_event = asyncio.Event()
 54
 55    async def connect(self) -> None:
 56        retry_delay = 1.0
 57        max_retry_delay = 30.0
 58        while not self.stop_event.is_set():
 59            try:
 60                self.logger.debug(
 61                    f"Connecting to WebSocket URL: {self.tunnel_websocket_url}"
 62                )
 63                async with websockets.connect(
 64                    self.tunnel_websocket_url, max_size=None
 65                ) as websocket:
 66                    self.logger.debug("WebSocket connection established")
 67                    click.secho(
 68                        f"Connected to tunnel {self.tunnel_http_url}", fg="green"
 69                    )
 70                    retry_delay = 1.0
 71                    try:
 72                        await self.handle_messages(websocket)
 73                    finally:
 74                        await self._cleanup_proxied_websockets()
 75            except asyncio.CancelledError:
 76                self.logger.debug("Connection cancelled")
 77                break
 78            except websockets.InvalidStatus as e:
 79                if e.response.status_code == 426:
 80                    body = e.response.body.decode() if e.response.body else ""
 81                    click.secho(
 82                        body or "Client version too old. Please upgrade plain.tunnel.",
 83                        fg="red",
 84                    )
 85                    break
 86                raise
 87            except (websockets.ConnectionClosed, ConnectionError) as e:
 88                if self.stop_event.is_set():
 89                    self.logger.debug("Stopping reconnect attempts due to shutdown")
 90                    break
 91                click.secho(
 92                    f"Connection lost: {e}. Retrying in {retry_delay:.0f}s...",
 93                    fg="yellow",
 94                )
 95                await asyncio.sleep(retry_delay)
 96                retry_delay = min(retry_delay * 2, max_retry_delay)
 97            except Exception as e:
 98                if self.stop_event.is_set():
 99                    self.logger.debug("Stopping reconnect attempts due to shutdown")
100                    break
101                click.secho(
102                    f"Unexpected error: {e}. Retrying in {retry_delay:.0f}s...",
103                    fg="yellow",
104                )
105                await asyncio.sleep(retry_delay)
106                retry_delay = min(retry_delay * 2, max_retry_delay)
107
108    async def handle_messages(self, websocket: Any) -> None:
109        try:
110            async for message in websocket:
111                if isinstance(message, str):
112                    data = json.loads(message)
113                    msg_type = data.get("type")
114                    if msg_type == "ping":
115                        self.logger.debug("Received heartbeat ping, sending pong")
116                        await websocket.send(json.dumps({"type": "pong"}))
117                    elif msg_type == "request":
118                        self.logger.debug("Received request metadata from worker")
119                        await self.handle_request_metadata(websocket, data)
120                    elif msg_type == "stream-cancel":
121                        request_id = data.get("id")
122                        self.logger.debug(
123                            f"Received stream-cancel for request ID: {request_id}"
124                        )
125                        cancel_event = self.active_streams.get(request_id)
126                        if cancel_event:
127                            cancel_event.set()
128                    elif msg_type == "ws-open":
129                        self.logger.debug(f"Received ws-open for ID: {data['id']}")
130                        self.ws_pending_queues[data["id"]] = asyncio.Queue()
131                        task = asyncio.create_task(
132                            self._handle_ws_open(websocket, data)
133                        )
134                        task.add_done_callback(self._handle_task_exception)
135                    elif msg_type == "ws-message":
136                        await self._handle_ws_message(data)
137                    elif msg_type == "ws-close":
138                        await self._handle_ws_close(data)
139                    else:
140                        self.logger.warning(
141                            f"Received unknown message type: {msg_type}"
142                        )
143                elif isinstance(message, bytes):
144                    self.logger.debug("Received binary data from worker")
145                    await self.handle_request_body_chunk(websocket, message)
146                else:
147                    self.logger.warning("Received unknown message format")
148        except asyncio.CancelledError:
149            self.logger.debug("Message handling cancelled")
150        except Exception as e:
151            self.logger.error(f"Error in handle_messages: {e}")
152            raise
153
154    async def handle_request_metadata(
155        self, websocket: Any, data: dict[str, Any]
156    ) -> None:
157        request_id = data["id"]
158        has_body = data.get("has_body", False)
159        total_body_chunks = data.get("totalBodyChunks", 0)
160        self.pending_requests[request_id] = {
161            "metadata": data,
162            "body_chunks": {},
163            "has_body": has_body,
164            "total_body_chunks": total_body_chunks,
165        }
166        self.logger.debug(
167            f"Stored metadata for request ID: {request_id}, has_body: {has_body}"
168        )
169        await self.check_and_process_request(websocket, request_id)
170
171    async def handle_request_body_chunk(
172        self, websocket: Any, chunk_data: bytes
173    ) -> None:
174        (id_length,) = struct.unpack_from("<I", chunk_data, 0)
175        request_id = chunk_data[4 : 4 + id_length].decode("utf-8")
176        header_end = 4 + id_length + 8
177        chunk_index, total_chunks = struct.unpack_from("<II", chunk_data, 4 + id_length)
178        body_chunk = chunk_data[header_end:]
179
180        if request_id in self.pending_requests:
181            request = self.pending_requests[request_id]
182            request["body_chunks"][chunk_index] = body_chunk
183            self.logger.debug(
184                f"Stored body chunk {chunk_index + 1}/{total_chunks} for request ID: {request_id}"
185            )
186            await self.check_and_process_request(websocket, request_id)
187        else:
188            self.logger.warning(
189                f"Received body chunk for unknown or completed request ID: {request_id}"
190            )
191
192    async def check_and_process_request(self, websocket: Any, request_id: str) -> None:
193        request_data = self.pending_requests.get(request_id)
194        if not request_data:
195            return
196
197        has_body = request_data["has_body"]
198        total_body_chunks = request_data["total_body_chunks"]
199        body_chunks = request_data["body_chunks"]
200
201        all_chunks_received = not has_body or len(body_chunks) == total_body_chunks
202        if not all_chunks_received:
203            return
204
205        for i in range(total_body_chunks):
206            if i not in body_chunks:
207                self.logger.error(
208                    f"Missing chunk {i + 1}/{total_body_chunks} for request ID: {request_id}"
209                )
210                return
211
212        self.logger.debug(f"Processing request ID: {request_id}")
213        del self.pending_requests[request_id]
214        task = asyncio.create_task(
215            self.process_request(
216                websocket,
217                request_data["metadata"],
218                body_chunks,
219                request_id,
220            )
221        )
222        task.add_done_callback(self._handle_task_exception)
223
224    def _handle_task_exception(self, task: asyncio.Task[None]) -> None:
225        if not task.cancelled() and task.exception():
226            self.logger.error("Error processing request", exc_info=task.exception())
227
228    async def _cleanup_proxied_websockets(self) -> None:
229        """Close all proxied WebSocket connections on tunnel disconnect."""
230        for ws_id, ws in list(self.proxied_websockets.items()):
231            try:
232                await ws.close()
233            except Exception:
234                pass
235        self.proxied_websockets.clear()
236        self.ws_pending_queues.clear()
237
238    async def process_request(
239        self,
240        websocket: Any,
241        request_metadata: dict[str, Any],
242        body_chunks: dict[int, bytes],
243        request_id: str,
244    ) -> None:
245        self.logger.debug(
246            f"Processing request: {request_id} {request_metadata['method']} {request_metadata['url']}"
247        )
248
249        if request_metadata["has_body"]:
250            total_chunks = request_metadata["totalBodyChunks"]
251            body_data = b"".join(body_chunks[i] for i in range(total_chunks))
252        else:
253            body_data = None
254
255        parsed = urlparse(request_metadata["url"])
256        path = parsed.path
257        if parsed.query:
258            path = f"{path}?{parsed.query}"
259        forward_url = f"{self.destination_url}{path}"
260
261        self.logger.debug(f"Forwarding request to: {forward_url}")
262
263        async with httpx.AsyncClient(follow_redirects=False, verify=False) as client:
264            try:
265                async with client.stream(
266                    method=request_metadata["method"],
267                    url=forward_url,
268                    headers=request_metadata["headers"],
269                    content=body_data,
270                ) as response:
271                    response_status = response.status_code
272                    response_headers = dict(response.headers)
273
274                    self.logger.info(
275                        f"{click.style(request_metadata['method'], bold=True)} {request_metadata['url']} {response_status}"
276                    )
277
278                    if self._is_streaming_response(response):
279                        await self._handle_streaming_response(
280                            websocket,
281                            response,
282                            request_id,
283                            response_status,
284                            response_headers,
285                        )
286                    else:
287                        await response.aread()
288                        await self._handle_buffered_response(
289                            websocket,
290                            response.content,
291                            request_id,
292                            response_status,
293                            response_headers,
294                        )
295            except httpx.ConnectError as e:
296                self.logger.error(f"Connection error forwarding request: {e}")
297                self.logger.info(
298                    f"{click.style(request_metadata['method'], bold=True)} {request_metadata['url']} 502"
299                )
300                await self._handle_buffered_response(
301                    websocket, b"", request_id, 502, {}
302                )
303
304    def _is_streaming_response(self, response: httpx.Response) -> bool:
305        content_type = response.headers.get("content-type", "")
306        return "text/event-stream" in content_type
307
308    async def _handle_buffered_response(
309        self,
310        websocket: Any,
311        response_body: bytes,
312        request_id: str,
313        response_status: int,
314        response_headers: dict[str, str],
315    ) -> None:
316        has_body = len(response_body) > 0
317        max_chunk_size = 1_000_000
318        total_body_chunks = (
319            math.ceil(len(response_body) / max_chunk_size) if has_body else 0
320        )
321
322        response_metadata = {
323            "type": "response",
324            "id": request_id,
325            "status": response_status,
326            "headers": list(response_headers.items()),
327            "has_body": has_body,
328            "totalBodyChunks": total_body_chunks,
329        }
330
331        self.logger.debug(
332            f"Sending response metadata for ID: {request_id}, has_body: {has_body}"
333        )
334        await websocket.send(json.dumps(response_metadata))
335
336        if has_body:
337            self.logger.debug(
338                f"Sending {total_body_chunks} body chunks for ID: {request_id}"
339            )
340            id_bytes = request_id.encode("utf-8")
341            for i in range(total_body_chunks):
342                chunk_start = i * max_chunk_size
343                chunk_end = min(chunk_start + max_chunk_size, len(response_body))
344                header = id_bytes + struct.pack("<II", i, total_body_chunks)
345                await websocket.send(header + response_body[chunk_start:chunk_end])
346                self.logger.debug(
347                    f"Sent body chunk {i + 1}/{total_body_chunks} for ID: {request_id}"
348                )
349
350    async def _handle_streaming_response(
351        self,
352        websocket: Any,
353        response: httpx.Response,
354        request_id: str,
355        response_status: int,
356        response_headers: dict[str, str],
357    ) -> None:
358        cancel_event = asyncio.Event()
359        self.active_streams[request_id] = cancel_event
360
361        stream_start = {
362            "type": "stream-start",
363            "id": request_id,
364            "status": response_status,
365            "headers": list(response_headers.items()),
366        }
367
368        self.logger.debug(f"Sending stream-start for ID: {request_id}")
369        await websocket.send(json.dumps(stream_start))
370
371        id_bytes = request_id.encode("utf-8")
372
373        try:
374            async for chunk in response.aiter_bytes():
375                if cancel_event.is_set():
376                    self.logger.debug(
377                        f"Stream cancelled by browser for request ID: {request_id}"
378                    )
379                    break
380
381                await websocket.send(id_bytes + chunk)
382            else:
383                # Only send stream-end if the loop completed naturally
384                # (not cancelled by the server via stream-cancel)
385                stream_end = {
386                    "type": "stream-end",
387                    "id": request_id,
388                }
389                self.logger.debug(f"Sending stream-end for ID: {request_id}")
390                await websocket.send(json.dumps(stream_end))
391        except Exception as e:
392            self.logger.error(f"Error streaming response for ID {request_id}: {e}")
393            stream_error = {
394                "type": "stream-error",
395                "id": request_id,
396                "error": str(e),
397            }
398            try:
399                await websocket.send(json.dumps(stream_error))
400            except Exception:
401                pass
402        finally:
403            self.active_streams.pop(request_id, None)
404
405    async def _handle_ws_open(self, tunnel_ws: Any, data: dict[str, Any]) -> None:
406        ws_id = data["id"]
407        url = data["url"]
408        parsed = urlparse(url)
409        path = parsed.path
410        if parsed.query:
411            path = f"{path}?{parsed.query}"
412
413        # Build local WebSocket URL
414        dest_parsed = urlparse(self.destination_url)
415        if dest_parsed.scheme == "https":
416            ws_scheme = "wss"
417        else:
418            ws_scheme = "ws"
419        local_ws_url = f"{ws_scheme}://{dest_parsed.netloc}{path}"
420
421        self.logger.debug(f"Opening local WebSocket for {ws_id}: {local_ws_url}")
422
423        # Forward safe browser headers (cookies, auth, origin) to the local
424        # server. Skip hop-by-hop and WebSocket handshake headers since
425        # websockets.connect generates its own (including Host from the URL).
426        skip_headers = frozenset(
427            {
428                "host",
429                "connection",
430                "upgrade",
431                "sec-websocket-key",
432                "sec-websocket-version",
433                "sec-websocket-extensions",
434                "sec-websocket-protocol",
435                "host",
436            }
437        )
438        forward_headers = {}
439        for name, value in data.get("headers", {}).items():
440            if name.lower() not in skip_headers:
441                forward_headers[name] = value
442
443        try:
444            import ssl
445
446            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
447            ssl_context.check_hostname = False
448            ssl_context.verify_mode = ssl.CERT_NONE
449            local_ws = await websockets.connect(
450                local_ws_url,
451                ssl=ssl_context if ws_scheme == "wss" else None,
452                max_size=None,
453                additional_headers=forward_headers,
454            )
455        except Exception as e:
456            self.logger.error(f"Failed to connect local WebSocket for {ws_id}: {e}")
457            self.ws_pending_queues.pop(ws_id, None)
458            try:
459                await tunnel_ws.send(
460                    json.dumps(
461                        {
462                            "type": "ws-close",
463                            "id": ws_id,
464                            "code": 1011,
465                            "reason": str(e),
466                        }
467                    )
468                )
469            except Exception:
470                pass
471            return
472
473        self.proxied_websockets[ws_id] = local_ws
474
475        # Drain any messages that arrived while connecting
476        queue = self.ws_pending_queues.pop(ws_id, None)
477        if queue is not None:
478            while not queue.empty():
479                queued = queue.get_nowait()
480                try:
481                    if queued.get("binary"):
482                        await local_ws.send(base64.b64decode(queued["data"]))
483                    else:
484                        await local_ws.send(queued["data"])
485                except Exception as e:
486                    self.logger.error(
487                        f"Failed to forward queued message to local WebSocket {ws_id}: {e}"
488                    )
489
490        self.logger.info(f"WebSocket proxy opened: {ws_id} -> {local_ws_url}")
491
492        # Relay messages from local server back to the tunnel
493        try:
494            async for message in local_ws:
495                if isinstance(message, str):
496                    await tunnel_ws.send(
497                        json.dumps({"type": "ws-message", "id": ws_id, "data": message})
498                    )
499                elif isinstance(message, bytes):
500                    await tunnel_ws.send(
501                        json.dumps(
502                            {
503                                "type": "ws-message",
504                                "id": ws_id,
505                                "data": base64.b64encode(message).decode("ascii"),
506                                "binary": True,
507                            }
508                        )
509                    )
510        except websockets.ConnectionClosed:
511            pass
512        except Exception as e:
513            self.logger.error(f"Error relaying WebSocket {ws_id}: {e}")
514        finally:
515            self.proxied_websockets.pop(ws_id, None)
516            close_code = local_ws.close_code or 1000
517            close_reason = local_ws.close_reason or ""
518            try:
519                await tunnel_ws.send(
520                    json.dumps(
521                        {
522                            "type": "ws-close",
523                            "id": ws_id,
524                            "code": close_code,
525                            "reason": close_reason,
526                        }
527                    )
528                )
529            except Exception:
530                pass
531
532    async def _handle_ws_message(self, data: dict[str, Any]) -> None:
533        ws_id = data["id"]
534        local_ws = self.proxied_websockets.get(ws_id)
535        if not local_ws:
536            # Connection still being established — buffer for later
537            queue = self.ws_pending_queues.get(ws_id)
538            if queue is not None:
539                await queue.put(data)
540                return
541            self.logger.warning(f"Received ws-message for unknown WebSocket: {ws_id}")
542            return
543        try:
544            if data.get("binary"):
545                await local_ws.send(base64.b64decode(data["data"]))
546            else:
547                await local_ws.send(data["data"])
548        except Exception as e:
549            self.logger.error(
550                f"Failed to forward message to local WebSocket {ws_id}: {e}"
551            )
552
553    async def _handle_ws_close(self, data: dict[str, Any]) -> None:
554        ws_id = data["id"]
555        local_ws = self.proxied_websockets.pop(ws_id, None)
556        if not local_ws:
557            return
558        try:
559            await local_ws.close()
560        except Exception:
561            pass
562
563    def run(self) -> None:
564        try:
565            asyncio.run(self.connect())
566        except KeyboardInterrupt:
567            self.logger.debug("Received exit signal")