v0.148.0
  1from __future__ import annotations
  2
  3import asyncio
  4import base64
  5import json
  6import logging
  7import math
  8import struct
  9import time
 10from typing import Any
 11from urllib.parse import urlparse
 12
 13import click
 14import httpx
 15import websockets
 16
 17# Bump this when making breaking changes to the WebSocket protocol.
 18# The server will reject clients with a version lower than its minimum.
 19PROTOCOL_VERSION = 3
 20
 21
 22class TunnelClient:
 23    def __init__(
 24        self, *, destination_url: str, subdomain: str, tunnel_host: str, log_level: str
 25    ) -> None:
 26        self.destination_url = destination_url
 27        self.subdomain = subdomain
 28        self.tunnel_host = tunnel_host
 29
 30        if "localhost" in tunnel_host or "127.0.0.1" in tunnel_host:
 31            self.tunnel_http_url = f"http://{subdomain}.{tunnel_host}"
 32            self.tunnel_websocket_url = (
 33                f"ws://{subdomain}.{tunnel_host}/__tunnel__?v={PROTOCOL_VERSION}"
 34            )
 35        else:
 36            self.tunnel_http_url = f"https://{subdomain}.{tunnel_host}"
 37            self.tunnel_websocket_url = (
 38                f"wss://{subdomain}.{tunnel_host}/__tunnel__?v={PROTOCOL_VERSION}"
 39            )
 40
 41        self.logger = logging.getLogger(__name__)
 42        level = getattr(logging, log_level.upper())
 43        self.logger.setLevel(level)
 44        self.logger.propagate = False
 45        handler = logging.StreamHandler()
 46        handler.setLevel(level)
 47        handler.setFormatter(logging.Formatter("%(message)s"))
 48        self.logger.addHandler(handler)
 49
 50        self.pending_requests: dict[str, dict[str, Any]] = {}
 51        self.active_streams: dict[str, asyncio.Event] = {}
 52        self.proxied_websockets: dict[str, Any] = {}
 53        self.ws_pending_queues: dict[str, asyncio.Queue[dict[str, Any]]] = {}
 54        self.stop_event = asyncio.Event()
 55
 56    async def connect(self) -> None:
 57        retry_delay = 1.0
 58        max_retry_delay = 30.0
 59        # Connection must stay up at least this long to be considered healthy
 60        # and reset the backoff. Otherwise we keep escalating, which prevents
 61        # tight reconnect loops when something else (e.g. another client
 62        # claiming the same subdomain) keeps closing us right after connect.
 63        healthy_connection_seconds = 5.0
 64        while not self.stop_event.is_set():
 65            connection_duration: float | None = None
 66            try:
 67                self.logger.debug(
 68                    f"Connecting to WebSocket URL: {self.tunnel_websocket_url}"
 69                )
 70                async with websockets.connect(
 71                    self.tunnel_websocket_url, max_size=None
 72                ) as websocket:
 73                    self.logger.debug("WebSocket connection established")
 74                    click.secho(
 75                        f"Connected to tunnel {self.tunnel_http_url}", fg="green"
 76                    )
 77                    connected_at = time.monotonic()
 78                    try:
 79                        await self.handle_messages(websocket)
 80                    finally:
 81                        connection_duration = time.monotonic() - connected_at
 82                        await self._cleanup_proxied_websockets()
 83                if self.stop_event.is_set():
 84                    break
 85                disconnect_message = "Tunnel disconnected by server."
 86            except asyncio.CancelledError:
 87                self.logger.debug("Connection cancelled")
 88                break
 89            except websockets.InvalidStatus as e:
 90                if e.response.status_code == 426:
 91                    body = e.response.body.decode() if e.response.body else ""
 92                    click.secho(
 93                        body or "Client version too old. Please upgrade plain.tunnel.",
 94                        fg="red",
 95                    )
 96                    break
 97                raise
 98            except (websockets.ConnectionClosed, ConnectionError) as e:
 99                if self.stop_event.is_set():
100                    self.logger.debug("Stopping reconnect attempts due to shutdown")
101                    break
102                disconnect_message = f"Connection lost: {e}."
103            except Exception as e:
104                if self.stop_event.is_set():
105                    self.logger.debug("Stopping reconnect attempts due to shutdown")
106                    break
107                disconnect_message = f"Unexpected error: {e}."
108
109            if (
110                connection_duration is not None
111                and connection_duration >= healthy_connection_seconds
112            ):
113                retry_delay = 1.0
114            click.secho(
115                f"{disconnect_message} Retrying in {retry_delay:.0f}s...",
116                fg="yellow",
117            )
118            await asyncio.sleep(retry_delay)
119            retry_delay = min(retry_delay * 2, max_retry_delay)
120
121    async def handle_messages(self, websocket: Any) -> None:
122        try:
123            async for message in websocket:
124                if isinstance(message, str):
125                    data = json.loads(message)
126                    msg_type = data.get("type")
127                    if msg_type == "ping":
128                        self.logger.debug("Received heartbeat ping, sending pong")
129                        await websocket.send(json.dumps({"type": "pong"}))
130                    elif msg_type == "request":
131                        self.logger.debug("Received request metadata from worker")
132                        await self.handle_request_metadata(websocket, data)
133                    elif msg_type == "stream-cancel":
134                        request_id = data.get("id")
135                        self.logger.debug(
136                            f"Received stream-cancel for request ID: {request_id}"
137                        )
138                        cancel_event = self.active_streams.get(request_id)
139                        if cancel_event:
140                            cancel_event.set()
141                    elif msg_type == "ws-open":
142                        self.logger.debug(f"Received ws-open for ID: {data['id']}")
143                        self.ws_pending_queues[data["id"]] = asyncio.Queue()
144                        task = asyncio.create_task(
145                            self._handle_ws_open(websocket, data)
146                        )
147                        task.add_done_callback(self._handle_task_exception)
148                    elif msg_type == "ws-message":
149                        await self._handle_ws_message(data)
150                    elif msg_type == "ws-close":
151                        await self._handle_ws_close(data)
152                    else:
153                        self.logger.warning(
154                            f"Received unknown message type: {msg_type}"
155                        )
156                elif isinstance(message, bytes):
157                    self.logger.debug("Received binary data from worker")
158                    await self.handle_request_body_chunk(websocket, message)
159                else:
160                    self.logger.warning("Received unknown message format")
161        except asyncio.CancelledError:
162            self.logger.debug("Message handling cancelled")
163        except Exception as e:
164            self.logger.error(f"Error in handle_messages: {e}")
165            raise
166
167    async def handle_request_metadata(
168        self, websocket: Any, data: dict[str, Any]
169    ) -> None:
170        request_id = data["id"]
171        has_body = data.get("has_body", False)
172        total_body_chunks = data.get("totalBodyChunks", 0)
173        self.pending_requests[request_id] = {
174            "metadata": data,
175            "body_chunks": {},
176            "has_body": has_body,
177            "total_body_chunks": total_body_chunks,
178        }
179        self.logger.debug(
180            f"Stored metadata for request ID: {request_id}, has_body: {has_body}"
181        )
182        await self.check_and_process_request(websocket, request_id)
183
184    async def handle_request_body_chunk(
185        self, websocket: Any, chunk_data: bytes
186    ) -> None:
187        (id_length,) = struct.unpack_from("<I", chunk_data, 0)
188        request_id = chunk_data[4 : 4 + id_length].decode("utf-8")
189        header_end = 4 + id_length + 8
190        chunk_index, total_chunks = struct.unpack_from("<II", chunk_data, 4 + id_length)
191        body_chunk = chunk_data[header_end:]
192
193        if request_id in self.pending_requests:
194            request = self.pending_requests[request_id]
195            request["body_chunks"][chunk_index] = body_chunk
196            self.logger.debug(
197                f"Stored body chunk {chunk_index + 1}/{total_chunks} for request ID: {request_id}"
198            )
199            await self.check_and_process_request(websocket, request_id)
200        else:
201            self.logger.warning(
202                f"Received body chunk for unknown or completed request ID: {request_id}"
203            )
204
205    async def check_and_process_request(self, websocket: Any, request_id: str) -> None:
206        request_data = self.pending_requests.get(request_id)
207        if not request_data:
208            return
209
210        has_body = request_data["has_body"]
211        total_body_chunks = request_data["total_body_chunks"]
212        body_chunks = request_data["body_chunks"]
213
214        all_chunks_received = not has_body or len(body_chunks) == total_body_chunks
215        if not all_chunks_received:
216            return
217
218        for i in range(total_body_chunks):
219            if i not in body_chunks:
220                self.logger.error(
221                    f"Missing chunk {i + 1}/{total_body_chunks} for request ID: {request_id}"
222                )
223                return
224
225        self.logger.debug(f"Processing request ID: {request_id}")
226        del self.pending_requests[request_id]
227        task = asyncio.create_task(
228            self.process_request(
229                websocket,
230                request_data["metadata"],
231                body_chunks,
232                request_id,
233            )
234        )
235        task.add_done_callback(self._handle_task_exception)
236
237    def _handle_task_exception(self, task: asyncio.Task[None]) -> None:
238        if not task.cancelled() and task.exception():
239            self.logger.error("Error processing request", exc_info=task.exception())
240
241    async def _cleanup_proxied_websockets(self) -> None:
242        """Close all proxied WebSocket connections on tunnel disconnect."""
243        for ws_id, ws in list(self.proxied_websockets.items()):
244            try:
245                await ws.close()
246            except Exception:
247                pass
248        self.proxied_websockets.clear()
249        self.ws_pending_queues.clear()
250
251    async def process_request(
252        self,
253        websocket: Any,
254        request_metadata: dict[str, Any],
255        body_chunks: dict[int, bytes],
256        request_id: str,
257    ) -> None:
258        self.logger.debug(
259            f"Processing request: {request_id} {request_metadata['method']} {request_metadata['url']}"
260        )
261
262        if request_metadata["has_body"]:
263            total_chunks = request_metadata["totalBodyChunks"]
264            body_data = b"".join(body_chunks[i] for i in range(total_chunks))
265        else:
266            body_data = None
267
268        parsed = urlparse(request_metadata["url"])
269        path = parsed.path
270        if parsed.query:
271            path = f"{path}?{parsed.query}"
272        forward_url = f"{self.destination_url}{path}"
273
274        self.logger.debug(f"Forwarding request to: {forward_url}")
275
276        async with httpx.AsyncClient(
277            follow_redirects=False, verify=False, timeout=30
278        ) as client:
279            try:
280                async with client.stream(
281                    method=request_metadata["method"],
282                    url=forward_url,
283                    headers=request_metadata["headers"],
284                    content=body_data,
285                ) as response:
286                    response_status = response.status_code
287                    response_headers = dict(response.headers)
288
289                    self.logger.info(
290                        f"{click.style(request_metadata['method'], bold=True)} {request_metadata['url']} {response_status}"
291                    )
292
293                    if self._is_streaming_response(response):
294                        await self._handle_streaming_response(
295                            websocket,
296                            response,
297                            request_id,
298                            response_status,
299                            response_headers,
300                        )
301                    else:
302                        await response.aread()
303                        await self._handle_buffered_response(
304                            websocket,
305                            response.content,
306                            request_id,
307                            response_status,
308                            response_headers,
309                        )
310            except httpx.ConnectError as e:
311                self.logger.error(f"Connection error forwarding request: {e}")
312                self.logger.info(
313                    f"{click.style(request_metadata['method'], bold=True)} {request_metadata['url']} 502"
314                )
315                await self._handle_buffered_response(
316                    websocket, b"", request_id, 502, {}
317                )
318
319    def _is_streaming_response(self, response: httpx.Response) -> bool:
320        content_type = response.headers.get("content-type", "")
321        return "text/event-stream" in content_type
322
323    async def _handle_buffered_response(
324        self,
325        websocket: Any,
326        response_body: bytes,
327        request_id: str,
328        response_status: int,
329        response_headers: dict[str, str],
330    ) -> None:
331        has_body = len(response_body) > 0
332        max_chunk_size = 1_000_000
333        total_body_chunks = (
334            math.ceil(len(response_body) / max_chunk_size) if has_body else 0
335        )
336
337        response_metadata = {
338            "type": "response",
339            "id": request_id,
340            "status": response_status,
341            "headers": list(response_headers.items()),
342            "has_body": has_body,
343            "totalBodyChunks": total_body_chunks,
344        }
345
346        self.logger.debug(
347            f"Sending response metadata for ID: {request_id}, has_body: {has_body}"
348        )
349        await websocket.send(json.dumps(response_metadata))
350
351        if has_body:
352            self.logger.debug(
353                f"Sending {total_body_chunks} body chunks for ID: {request_id}"
354            )
355            id_bytes = request_id.encode("utf-8")
356            for i in range(total_body_chunks):
357                chunk_start = i * max_chunk_size
358                chunk_end = min(chunk_start + max_chunk_size, len(response_body))
359                header = id_bytes + struct.pack("<II", i, total_body_chunks)
360                await websocket.send(header + response_body[chunk_start:chunk_end])
361                self.logger.debug(
362                    f"Sent body chunk {i + 1}/{total_body_chunks} for ID: {request_id}"
363                )
364
365    async def _handle_streaming_response(
366        self,
367        websocket: Any,
368        response: httpx.Response,
369        request_id: str,
370        response_status: int,
371        response_headers: dict[str, str],
372    ) -> None:
373        cancel_event = asyncio.Event()
374        self.active_streams[request_id] = cancel_event
375
376        stream_start = {
377            "type": "stream-start",
378            "id": request_id,
379            "status": response_status,
380            "headers": list(response_headers.items()),
381        }
382
383        self.logger.debug(f"Sending stream-start for ID: {request_id}")
384        await websocket.send(json.dumps(stream_start))
385
386        id_bytes = request_id.encode("utf-8")
387
388        try:
389            async for chunk in response.aiter_bytes():
390                if cancel_event.is_set():
391                    self.logger.debug(
392                        f"Stream cancelled by browser for request ID: {request_id}"
393                    )
394                    break
395
396                await websocket.send(id_bytes + chunk)
397            else:
398                # Only send stream-end if the loop completed naturally
399                # (not cancelled by the server via stream-cancel)
400                stream_end = {
401                    "type": "stream-end",
402                    "id": request_id,
403                }
404                self.logger.debug(f"Sending stream-end for ID: {request_id}")
405                await websocket.send(json.dumps(stream_end))
406        except Exception as e:
407            self.logger.error(f"Error streaming response for ID {request_id}: {e}")
408            stream_error = {
409                "type": "stream-error",
410                "id": request_id,
411                "error": str(e),
412            }
413            try:
414                await websocket.send(json.dumps(stream_error))
415            except Exception:
416                pass
417        finally:
418            self.active_streams.pop(request_id, None)
419
420    async def _handle_ws_open(self, tunnel_ws: Any, data: dict[str, Any]) -> None:
421        ws_id = data["id"]
422        url = data["url"]
423        parsed = urlparse(url)
424        path = parsed.path
425        if parsed.query:
426            path = f"{path}?{parsed.query}"
427
428        # Build local WebSocket URL
429        dest_parsed = urlparse(self.destination_url)
430        if dest_parsed.scheme == "https":
431            ws_scheme = "wss"
432        else:
433            ws_scheme = "ws"
434        local_ws_url = f"{ws_scheme}://{dest_parsed.netloc}{path}"
435
436        self.logger.debug(f"Opening local WebSocket for {ws_id}: {local_ws_url}")
437
438        # Forward safe browser headers (cookies, auth, origin) to the local
439        # server. Skip hop-by-hop and WebSocket handshake headers since
440        # websockets.connect generates its own (including Host from the URL).
441        skip_headers = frozenset(
442            {
443                "host",
444                "connection",
445                "upgrade",
446                "sec-websocket-key",
447                "sec-websocket-version",
448                "sec-websocket-extensions",
449                "sec-websocket-protocol",
450            }
451        )
452        forward_headers = {}
453        for name, value in data.get("headers", {}).items():
454            if name.lower() not in skip_headers:
455                forward_headers[name] = value
456
457        try:
458            import ssl
459
460            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
461            ssl_context.check_hostname = False
462            ssl_context.verify_mode = ssl.CERT_NONE
463            local_ws = await websockets.connect(
464                local_ws_url,
465                ssl=ssl_context if ws_scheme == "wss" else None,
466                max_size=None,
467                additional_headers=forward_headers,
468            )
469        except Exception as e:
470            self.logger.error(f"Failed to connect local WebSocket for {ws_id}: {e}")
471            self.ws_pending_queues.pop(ws_id, None)
472            try:
473                await tunnel_ws.send(
474                    json.dumps(
475                        {
476                            "type": "ws-close",
477                            "id": ws_id,
478                            "code": 1011,
479                            "reason": str(e),
480                        }
481                    )
482                )
483            except Exception:
484                pass
485            return
486
487        self.proxied_websockets[ws_id] = local_ws
488
489        # Drain any messages that arrived while connecting
490        queue = self.ws_pending_queues.pop(ws_id, None)
491        if queue is not None:
492            while not queue.empty():
493                queued = queue.get_nowait()
494                try:
495                    if queued.get("binary"):
496                        await local_ws.send(base64.b64decode(queued["data"]))
497                    else:
498                        await local_ws.send(queued["data"])
499                except Exception as e:
500                    self.logger.error(
501                        f"Failed to forward queued message to local WebSocket {ws_id}: {e}"
502                    )
503
504        self.logger.info(f"WebSocket proxy opened: {ws_id} -> {local_ws_url}")
505
506        # Relay messages from local server back to the tunnel
507        try:
508            async for message in local_ws:
509                if isinstance(message, str):
510                    await tunnel_ws.send(
511                        json.dumps({"type": "ws-message", "id": ws_id, "data": message})
512                    )
513                elif isinstance(message, bytes):
514                    await tunnel_ws.send(
515                        json.dumps(
516                            {
517                                "type": "ws-message",
518                                "id": ws_id,
519                                "data": base64.b64encode(message).decode("ascii"),
520                                "binary": True,
521                            }
522                        )
523                    )
524        except websockets.ConnectionClosed:
525            pass
526        except Exception as e:
527            self.logger.error(f"Error relaying WebSocket {ws_id}: {e}")
528        finally:
529            self.proxied_websockets.pop(ws_id, None)
530            close_code = local_ws.close_code or 1000
531            close_reason = local_ws.close_reason or ""
532            try:
533                await tunnel_ws.send(
534                    json.dumps(
535                        {
536                            "type": "ws-close",
537                            "id": ws_id,
538                            "code": close_code,
539                            "reason": close_reason,
540                        }
541                    )
542                )
543            except Exception:
544                pass
545
546    async def _handle_ws_message(self, data: dict[str, Any]) -> None:
547        ws_id = data["id"]
548        local_ws = self.proxied_websockets.get(ws_id)
549        if not local_ws:
550            # Connection still being established — buffer for later
551            queue = self.ws_pending_queues.get(ws_id)
552            if queue is not None:
553                await queue.put(data)
554                return
555            self.logger.warning(f"Received ws-message for unknown WebSocket: {ws_id}")
556            return
557        try:
558            if data.get("binary"):
559                await local_ws.send(base64.b64decode(data["data"]))
560            else:
561                await local_ws.send(data["data"])
562        except Exception as e:
563            self.logger.error(
564                f"Failed to forward message to local WebSocket {ws_id}: {e}"
565            )
566
567    async def _handle_ws_close(self, data: dict[str, Any]) -> None:
568        ws_id = data["id"]
569        local_ws = self.proxied_websockets.pop(ws_id, None)
570        if not local_ws:
571            return
572        try:
573            await local_ws.close()
574        except Exception:
575            pass
576
577    def run(self) -> None:
578        try:
579            asyncio.run(self.connect())
580        except KeyboardInterrupt:
581            self.logger.debug("Received exit signal")