1from __future__ import annotations
2
3import asyncio
4import base64
5import json
6import logging
7import math
8import struct
9import time
10from typing import Any
11from urllib.parse import urlparse
12
13import click
14import httpx
15import websockets
16
# Wire-protocol version sent to the tunnel server as the ``v`` query parameter.
# Increment it whenever the WebSocket protocol changes incompatibly: the
# server refuses any client whose version is below its configured minimum.
PROTOCOL_VERSION: int = 3
20
21
class TunnelClient:
    """Expose a local HTTP server through a remote tunnel server.

    The client keeps one WebSocket open to ``{subdomain}.{tunnel_host}``.
    Over it the server sends HTTP requests (JSON metadata plus binary body
    chunks) and WebSocket proxy commands (``ws-open`` / ``ws-message`` /
    ``ws-close``). The client replays them against ``destination_url`` and
    sends the results back over the same tunnel socket.
    """

    def __init__(
        self, *, destination_url: str, subdomain: str, tunnel_host: str, log_level: str
    ) -> None:
        """Configure URLs and logging; no network I/O happens here.

        Args:
            destination_url: Base URL of the local server that requests are
                forwarded to.
            subdomain: Subdomain to claim on the tunnel host.
            tunnel_host: Tunnel server host. Hosts containing "localhost" or
                "127.0.0.1" use plain http/ws; all others use https/wss.
            log_level: Name of a ``logging`` level, e.g. "info" or "DEBUG".

        Raises:
            AttributeError: if ``log_level`` is not a ``logging`` level name.
        """
        self.destination_url = destination_url
        self.subdomain = subdomain
        self.tunnel_host = tunnel_host

        is_local = "localhost" in tunnel_host or "127.0.0.1" in tunnel_host
        http_scheme, ws_scheme = ("http", "ws") if is_local else ("https", "wss")
        self.tunnel_http_url = f"{http_scheme}://{subdomain}.{tunnel_host}"
        self.tunnel_websocket_url = (
            f"{ws_scheme}://{subdomain}.{tunnel_host}/__tunnel__?v={PROTOCOL_VERSION}"
        )

        self.logger = logging.getLogger(__name__)
        level = getattr(logging, log_level.upper())
        self.logger.setLevel(level)
        self.logger.propagate = False
        # The module logger is shared by every instance. Attach the handler
        # only once so a second TunnelClient doesn't print every line twice.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter("%(message)s"))
            self.logger.addHandler(handler)
        for handler in self.logger.handlers:
            handler.setLevel(level)

        # request id -> {"metadata", "body_chunks", "has_body", "total_body_chunks"}
        self.pending_requests: dict[str, dict[str, Any]] = {}
        # request id -> event set when the server sends stream-cancel
        self.active_streams: dict[str, asyncio.Event] = {}
        # ws id -> open connection to the local server
        self.proxied_websockets: dict[str, Any] = {}
        # ws id -> messages received while the local connection is still opening
        self.ws_pending_queues: dict[str, asyncio.Queue[dict[str, Any]]] = {}
        # Strong references to fire-and-forget tasks; the event loop only keeps
        # weak ones, so an unreferenced task may be garbage-collected mid-run.
        self._background_tasks: set[asyncio.Task[None]] = set()
        self.stop_event = asyncio.Event()

    async def connect(self) -> None:
        """Keep the tunnel connected until ``stop_event`` is set or the task is cancelled.

        Reconnects with exponential backoff: 1s, doubling up to 30s. The delay
        resets only after a connection has stayed up for
        ``healthy_connection_seconds``. A 426 handshake response (client too
        old) stops reconnecting. Any other rejected handshake status is
        re-raised to the caller.
        """
        retry_delay = 1.0
        max_retry_delay = 30.0
        # Connection must stay up at least this long to be considered healthy
        # and reset the backoff. Otherwise we keep escalating, which prevents
        # tight reconnect loops when something else (e.g. another client
        # claiming the same subdomain) keeps closing us right after connect.
        healthy_connection_seconds = 5.0
        while not self.stop_event.is_set():
            connection_duration: float | None = None
            try:
                self.logger.debug(
                    f"Connecting to WebSocket URL: {self.tunnel_websocket_url}"
                )
                async with websockets.connect(
                    self.tunnel_websocket_url, max_size=None
                ) as websocket:
                    self.logger.debug("WebSocket connection established")
                    click.secho(
                        f"Connected to tunnel {self.tunnel_http_url}", fg="green"
                    )
                    connected_at = time.monotonic()
                    try:
                        await self.handle_messages(websocket)
                    finally:
                        connection_duration = time.monotonic() - connected_at
                        # Proxied local WebSockets are bound to this tunnel
                        # connection and cannot outlive it.
                        await self._cleanup_proxied_websockets()
                if self.stop_event.is_set():
                    break
                disconnect_message = "Tunnel disconnected by server."
            except asyncio.CancelledError:
                self.logger.debug("Connection cancelled")
                break
            except websockets.InvalidStatus as e:
                # 426: the server rejected our PROTOCOL_VERSION as too old.
                if e.response.status_code == 426:
                    body = e.response.body.decode() if e.response.body else ""
                    click.secho(
                        body or "Client version too old. Please upgrade plain.tunnel.",
                        fg="red",
                    )
                    break
                raise
            except (websockets.ConnectionClosed, ConnectionError) as e:
                if self.stop_event.is_set():
                    self.logger.debug("Stopping reconnect attempts due to shutdown")
                    break
                disconnect_message = f"Connection lost: {e}."
            except Exception as e:
                if self.stop_event.is_set():
                    self.logger.debug("Stopping reconnect attempts due to shutdown")
                    break
                disconnect_message = f"Unexpected error: {e}."

            if (
                connection_duration is not None
                and connection_duration >= healthy_connection_seconds
            ):
                retry_delay = 1.0
            click.secho(
                f"{disconnect_message} Retrying in {retry_delay:.0f}s...",
                fg="yellow",
            )
            await asyncio.sleep(retry_delay)
            retry_delay = min(retry_delay * 2, max_retry_delay)

    async def handle_messages(self, websocket: Any) -> None:
        """Dispatch every frame received on the tunnel until it closes.

        Text frames are JSON control messages routed by their ``type``. Binary
        frames are request-body chunks. Cancellation is logged and propagated
        so that shutdown (e.g. Ctrl-C under ``asyncio.run``) is not mistaken
        for a server disconnect. Other errors are logged and re-raised.
        """
        try:
            async for message in websocket:
                if isinstance(message, str):
                    data = json.loads(message)
                    msg_type = data.get("type")
                    if msg_type == "ping":
                        self.logger.debug("Received heartbeat ping, sending pong")
                        await websocket.send(json.dumps({"type": "pong"}))
                    elif msg_type == "request":
                        self.logger.debug("Received request metadata from worker")
                        await self.handle_request_metadata(websocket, data)
                    elif msg_type == "stream-cancel":
                        request_id = data.get("id")
                        self.logger.debug(
                            f"Received stream-cancel for request ID: {request_id}"
                        )
                        cancel_event = self.active_streams.get(request_id)
                        if cancel_event:
                            cancel_event.set()
                    elif msg_type == "ws-open":
                        self.logger.debug(f"Received ws-open for ID: {data['id']}")
                        # Register the buffer before spawning the opener so
                        # ws-message frames that arrive while it connects are kept.
                        self.ws_pending_queues[data["id"]] = asyncio.Queue()
                        self._spawn(self._handle_ws_open(websocket, data))
                    elif msg_type == "ws-message":
                        await self._handle_ws_message(data)
                    elif msg_type == "ws-close":
                        await self._handle_ws_close(data)
                    else:
                        self.logger.warning(
                            f"Received unknown message type: {msg_type}"
                        )
                elif isinstance(message, bytes):
                    self.logger.debug("Received binary data from worker")
                    await self.handle_request_body_chunk(websocket, message)
                else:
                    self.logger.warning("Received unknown message format")
        except asyncio.CancelledError:
            self.logger.debug("Message handling cancelled")
            # Must propagate. Swallowing it made connect() treat shutdown as a
            # server disconnect and start reconnecting.
            raise
        except Exception as e:
            self.logger.error(f"Error in handle_messages: {e}")
            raise

    async def handle_request_metadata(
        self, websocket: Any, data: dict[str, Any]
    ) -> None:
        """Record an incoming request's metadata and process it once complete."""
        request_id = data["id"]
        has_body = data.get("has_body", False)
        total_body_chunks = data.get("totalBodyChunks", 0)
        self.pending_requests[request_id] = {
            "metadata": data,
            "body_chunks": {},
            "has_body": has_body,
            "total_body_chunks": total_body_chunks,
        }
        self.logger.debug(
            f"Stored metadata for request ID: {request_id}, has_body: {has_body}"
        )
        await self.check_and_process_request(websocket, request_id)

    async def handle_request_body_chunk(
        self, websocket: Any, chunk_data: bytes
    ) -> None:
        """Store one binary body chunk for a pending request.

        Frame layout: little-endian uint32 id length, the UTF-8 request id,
        little-endian uint32 chunk index and chunk count, then the payload.
        """
        (id_length,) = struct.unpack_from("<I", chunk_data, 0)
        request_id = chunk_data[4 : 4 + id_length].decode("utf-8")
        header_end = 4 + id_length + 8
        chunk_index, total_chunks = struct.unpack_from("<II", chunk_data, 4 + id_length)
        body_chunk = chunk_data[header_end:]

        request = self.pending_requests.get(request_id)
        if request is None:
            self.logger.warning(
                f"Received body chunk for unknown or completed request ID: {request_id}"
            )
            return
        request["body_chunks"][chunk_index] = body_chunk
        self.logger.debug(
            f"Stored body chunk {chunk_index + 1}/{total_chunks} for request ID: {request_id}"
        )
        await self.check_and_process_request(websocket, request_id)

    async def check_and_process_request(self, websocket: Any, request_id: str) -> None:
        """Start forwarding ``request_id`` once its metadata and all chunks are in."""
        request_data = self.pending_requests.get(request_id)
        if not request_data:
            return

        has_body = request_data["has_body"]
        total_body_chunks = request_data["total_body_chunks"]
        body_chunks = request_data["body_chunks"]

        all_chunks_received = not has_body or len(body_chunks) == total_body_chunks
        if not all_chunks_received:
            return

        for i in range(total_body_chunks):
            if i not in body_chunks:
                self.logger.error(
                    f"Missing chunk {i + 1}/{total_body_chunks} for request ID: {request_id}"
                )
                return

        self.logger.debug(f"Processing request ID: {request_id}")
        del self.pending_requests[request_id]
        self._spawn(
            self.process_request(
                websocket,
                request_data["metadata"],
                body_chunks,
                request_id,
            )
        )

    def _spawn(self, coro: Any) -> asyncio.Task[None]:
        """Run ``coro`` as a background task, keeping a reference until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        task.add_done_callback(self._handle_task_exception)
        return task

    def _handle_task_exception(self, task: asyncio.Task[None]) -> None:
        """Done-callback: log an exception raised by a background task."""
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            self.logger.error("Error processing request", exc_info=exc)

    async def _cleanup_proxied_websockets(self) -> None:
        """Close all proxied WebSocket connections on tunnel disconnect."""
        for local_ws in list(self.proxied_websockets.values()):
            await self._close_quietly(local_ws)
        self.proxied_websockets.clear()
        # Dropping the queues also tells any _handle_ws_open still connecting
        # to close its local connection instead of registering it.
        self.ws_pending_queues.clear()

    @staticmethod
    async def _close_quietly(ws: Any) -> None:
        """Close ``ws``, ignoring errors; callers are already tearing it down."""
        try:
            await ws.close()
        except Exception:
            pass

    @staticmethod
    def _path_with_query(url: str) -> str:
        """Return the path of ``url`` plus ``?query`` when it has one."""
        parsed = urlparse(url)
        if parsed.query:
            return f"{parsed.path}?{parsed.query}"
        return parsed.path

    @staticmethod
    def _assemble_body(
        request_metadata: dict[str, Any], body_chunks: dict[int, bytes]
    ) -> bytes | None:
        """Join body chunks in index order, or return None for a bodiless request.

        ``has_body`` and ``totalBodyChunks`` default to False and 0, matching
        how handle_request_metadata reads the same keys.
        """
        if not request_metadata.get("has_body", False):
            return None
        total_chunks = request_metadata.get("totalBodyChunks", 0)
        return b"".join(body_chunks[i] for i in range(total_chunks))

    async def process_request(
        self,
        websocket: Any,
        request_metadata: dict[str, Any],
        body_chunks: dict[int, bytes],
        request_id: str,
    ) -> None:
        """Forward one tunneled HTTP request to the local server and reply.

        ``text/event-stream`` responses are relayed incrementally. Everything
        else is buffered. If the local server cannot be reached, the tunnel
        gets a 502 (connect/transport error) or 504 (timeout), so the server
        is never left waiting for a response.
        """
        method = request_metadata["method"]
        url = request_metadata["url"]
        self.logger.debug(f"Processing request: {request_id} {method} {url}")

        body_data = self._assemble_body(request_metadata, body_chunks)
        # rstrip avoids "//" when destination_url ends with a slash.
        forward_url = f"{self.destination_url.rstrip('/')}{self._path_with_query(url)}"

        self.logger.debug(f"Forwarding request to: {forward_url}")

        async with httpx.AsyncClient(
            follow_redirects=False, verify=False, timeout=30
        ) as client:
            try:
                async with client.stream(
                    method=method,
                    url=forward_url,
                    headers=request_metadata["headers"],
                    content=body_data,
                ) as response:
                    response_status = response.status_code
                    # multi_items() keeps repeated headers (e.g. Set-Cookie) as
                    # separate pairs instead of comma-joining them.
                    response_headers = response.headers.multi_items()

                    self.logger.info(
                        f"{click.style(method, bold=True)} {url} {response_status}"
                    )

                    if self._is_streaming_response(response):
                        await self._handle_streaming_response(
                            websocket,
                            response,
                            request_id,
                            response_status,
                            response_headers,
                        )
                    else:
                        await response.aread()
                        await self._handle_buffered_response(
                            websocket,
                            response.content,
                            request_id,
                            response_status,
                            response_headers,
                        )
            # Order matters: ConnectError and TimeoutException both derive
            # from RequestError.
            except httpx.ConnectError as e:
                self.logger.error(f"Connection error forwarding request: {e}")
                await self._send_gateway_error(websocket, method, url, request_id, 502)
            except httpx.TimeoutException as e:
                self.logger.error(f"Timed out forwarding request: {e}")
                await self._send_gateway_error(websocket, method, url, request_id, 504)
            except httpx.RequestError as e:
                self.logger.error(f"Error forwarding request: {e}")
                await self._send_gateway_error(websocket, method, url, request_id, 502)

    async def _send_gateway_error(
        self, websocket: Any, method: str, url: str, request_id: str, status: int
    ) -> None:
        """Log and send an empty-bodied error response for ``request_id``."""
        self.logger.info(f"{click.style(method, bold=True)} {url} {status}")
        await self._handle_buffered_response(websocket, b"", request_id, status, {})

    def _is_streaming_response(self, response: httpx.Response) -> bool:
        """Return True for Server-Sent Events responses, which are relayed unbuffered."""
        content_type = response.headers.get("content-type", "")
        return "text/event-stream" in content_type

    @staticmethod
    def _header_pairs(
        headers: dict[str, str] | list[tuple[str, str]],
    ) -> list[tuple[str, str]]:
        """Normalize a header mapping or a sequence of pairs into a list of pairs."""
        items = headers.items() if hasattr(headers, "items") else headers
        return [(name, value) for name, value in items]

    async def _handle_buffered_response(
        self,
        websocket: Any,
        response_body: bytes,
        request_id: str,
        response_status: int,
        response_headers: dict[str, str] | list[tuple[str, str]],
    ) -> None:
        """Send a complete response: JSON metadata, then binary chunks of at most 1 MB.

        Each binary frame is the UTF-8 ``request_id`` followed by little-endian
        uint32 chunk index and chunk count, then the payload.
        """
        # NOTE(review): unlike inbound frames, there is no id-length prefix;
        # presumably the server knows the id length - confirm before changing.
        has_body = len(response_body) > 0
        max_chunk_size = 1_000_000
        total_body_chunks = (
            math.ceil(len(response_body) / max_chunk_size) if has_body else 0
        )

        response_metadata = {
            "type": "response",
            "id": request_id,
            "status": response_status,
            "headers": self._header_pairs(response_headers),
            "has_body": has_body,
            "totalBodyChunks": total_body_chunks,
        }

        self.logger.debug(
            f"Sending response metadata for ID: {request_id}, has_body: {has_body}"
        )
        await websocket.send(json.dumps(response_metadata))

        if has_body:
            self.logger.debug(
                f"Sending {total_body_chunks} body chunks for ID: {request_id}"
            )
            id_bytes = request_id.encode("utf-8")
            for i in range(total_body_chunks):
                chunk_start = i * max_chunk_size
                chunk_end = min(chunk_start + max_chunk_size, len(response_body))
                header = id_bytes + struct.pack("<II", i, total_body_chunks)
                await websocket.send(header + response_body[chunk_start:chunk_end])
                self.logger.debug(
                    f"Sent body chunk {i + 1}/{total_body_chunks} for ID: {request_id}"
                )

    async def _handle_streaming_response(
        self,
        websocket: Any,
        response: httpx.Response,
        request_id: str,
        response_status: int,
        response_headers: dict[str, str] | list[tuple[str, str]],
    ) -> None:
        """Relay a response incrementally: stream-start, raw chunks, then stream-end.

        Each data frame is the UTF-8 ``request_id`` followed by the chunk. A
        ``stream-cancel`` from the server stops the relay without stream-end.
        Read errors are reported to the server as ``stream-error``.
        """
        cancel_event = asyncio.Event()
        self.active_streams[request_id] = cancel_event

        stream_start = {
            "type": "stream-start",
            "id": request_id,
            "status": response_status,
            "headers": self._header_pairs(response_headers),
        }

        self.logger.debug(f"Sending stream-start for ID: {request_id}")
        await websocket.send(json.dumps(stream_start))

        id_bytes = request_id.encode("utf-8")

        try:
            async for chunk in response.aiter_bytes():
                if cancel_event.is_set():
                    self.logger.debug(
                        f"Stream cancelled by browser for request ID: {request_id}"
                    )
                    break

                await websocket.send(id_bytes + chunk)
            else:
                # Only send stream-end if the loop completed naturally
                # (not cancelled by the server via stream-cancel)
                stream_end = {
                    "type": "stream-end",
                    "id": request_id,
                }
                self.logger.debug(f"Sending stream-end for ID: {request_id}")
                await websocket.send(json.dumps(stream_end))
        except Exception as e:
            self.logger.error(f"Error streaming response for ID {request_id}: {e}")
            stream_error = {
                "type": "stream-error",
                "id": request_id,
                "error": str(e),
            }
            try:
                await websocket.send(json.dumps(stream_error))
            except Exception:
                pass
        finally:
            self.active_streams.pop(request_id, None)

    @staticmethod
    def _sendable_close_code(code: int | None) -> int:
        """Map a locally observed close code to one that may be sent onward.

        RFC 6455 section 7.4.1 forbids putting 1005 (no status), 1006
        (abnormal closure) or 1015 (TLS failure) in a Close frame. They only
        report local conditions.
        """
        if not code or code == 1005:
            return 1000
        if code in (1006, 1015):
            return 1011
        return code

    @staticmethod
    async def _send_to_local(local_ws: Any, data: dict[str, Any]) -> None:
        """Send one ws-message payload to the local server, base64-decoding binary data."""
        if data.get("binary"):
            await local_ws.send(base64.b64decode(data["data"]))
        else:
            await local_ws.send(data["data"])

    async def _handle_ws_open(self, tunnel_ws: Any, data: dict[str, Any]) -> None:
        """Proxy one browser WebSocket (``ws-open``) to the local server.

        Steps: open the local connection, flush messages buffered while it was
        connecting, then relay local -> tunnel until either side closes. The
        close is always reported back as ``ws-close``. If the browser side
        closes, or the tunnel drops before the local connection is ready, the
        local connection is closed instead of being registered.
        """
        ws_id = data["id"]

        # Build local WebSocket URL
        dest_parsed = urlparse(self.destination_url)
        ws_scheme = "wss" if dest_parsed.scheme == "https" else "ws"
        local_ws_url = (
            f"{ws_scheme}://{dest_parsed.netloc}{self._path_with_query(data['url'])}"
        )

        self.logger.debug(f"Opening local WebSocket for {ws_id}: {local_ws_url}")

        # Forward safe browser headers (cookies, auth, origin) to the local
        # server. Skip hop-by-hop and WebSocket handshake headers since
        # websockets.connect generates its own (including Host from the URL).
        skip_headers = frozenset(
            {
                "host",
                "connection",
                "upgrade",
                "sec-websocket-key",
                "sec-websocket-version",
                "sec-websocket-extensions",
                "sec-websocket-protocol",
            }
        )
        forward_headers = {
            name: value
            for name, value in data.get("headers", {}).items()
            if name.lower() not in skip_headers
        }

        ssl_context = None
        if ws_scheme == "wss":
            import ssl

            # Certificate checks are disabled for the local server, matching
            # verify=False on the HTTP forwarding path.
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE

        try:
            local_ws = await websockets.connect(
                local_ws_url,
                ssl=ssl_context,
                max_size=None,
                additional_headers=forward_headers,
            )
        except Exception as e:
            self.logger.error(f"Failed to connect local WebSocket for {ws_id}: {e}")
            self.ws_pending_queues.pop(ws_id, None)
            try:
                await tunnel_ws.send(
                    json.dumps(
                        {
                            "type": "ws-close",
                            "id": ws_id,
                            "code": 1011,
                            "reason": str(e),
                        }
                    )
                )
            except Exception:
                pass
            return

        # Flush messages that arrived while connecting, oldest first. The queue
        # stays registered during the flush, so newer messages are appended to
        # it instead of being sent directly and overtaking queued ones.
        queue = self.ws_pending_queues.get(ws_id)
        if queue is not None:
            while not queue.empty():
                queued = queue.get_nowait()
                try:
                    await self._send_to_local(local_ws, queued)
                except Exception as e:
                    self.logger.error(
                        f"Failed to forward queued message to local WebSocket {ws_id}: {e}"
                    )
        # No await since the final empty() check, so nothing can slip in
        # between here and registration. A missing or replaced queue means
        # ws-close arrived, or the tunnel dropped, while we were connecting.
        if queue is None or self.ws_pending_queues.get(ws_id) is not queue:
            await self._close_quietly(local_ws)
            return
        del self.ws_pending_queues[ws_id]
        self.proxied_websockets[ws_id] = local_ws

        self.logger.info(f"WebSocket proxy opened: {ws_id} -> {local_ws_url}")

        # Relay messages from local server back to the tunnel
        try:
            async for message in local_ws:
                if isinstance(message, str):
                    payload: dict[str, Any] = {
                        "type": "ws-message",
                        "id": ws_id,
                        "data": message,
                    }
                elif isinstance(message, bytes):
                    payload = {
                        "type": "ws-message",
                        "id": ws_id,
                        "data": base64.b64encode(message).decode("ascii"),
                        "binary": True,
                    }
                else:
                    continue
                await tunnel_ws.send(json.dumps(payload))
        except websockets.ConnectionClosed:
            pass
        except Exception as e:
            self.logger.error(f"Error relaying WebSocket {ws_id}: {e}")
        finally:
            self.proxied_websockets.pop(ws_id, None)
            # The loop can also end because sending to the tunnel failed, in
            # which case local_ws is still open. Always release it.
            await self._close_quietly(local_ws)
            close_code = self._sendable_close_code(local_ws.close_code)
            close_reason = local_ws.close_reason or ""
            try:
                await tunnel_ws.send(
                    json.dumps(
                        {
                            "type": "ws-close",
                            "id": ws_id,
                            "code": close_code,
                            "reason": close_reason,
                        }
                    )
                )
            except Exception:
                pass

    async def _handle_ws_message(self, data: dict[str, Any]) -> None:
        """Forward a browser ws-message to its local WebSocket, or buffer it while opening."""
        ws_id = data["id"]
        local_ws = self.proxied_websockets.get(ws_id)
        if not local_ws:
            # Connection still being established — buffer for later
            queue = self.ws_pending_queues.get(ws_id)
            if queue is not None:
                await queue.put(data)
                return
            self.logger.warning(f"Received ws-message for unknown WebSocket: {ws_id}")
            return
        try:
            await self._send_to_local(local_ws, data)
        except Exception as e:
            self.logger.error(
                f"Failed to forward message to local WebSocket {ws_id}: {e}"
            )

    async def _handle_ws_close(self, data: dict[str, Any]) -> None:
        """Close the local side of a proxied WebSocket after the browser closed it."""
        ws_id = data["id"]
        # If the local connection is still opening, removing its queue makes
        # _handle_ws_open close it instead of keeping it alive.
        self.ws_pending_queues.pop(ws_id, None)
        local_ws = self.proxied_websockets.pop(ws_id, None)
        if not local_ws:
            return
        await self._close_quietly(local_ws)

    def run(self) -> None:
        """Run ``connect()`` on a new event loop until it returns or Ctrl-C is pressed."""
        try:
            asyncio.run(self.connect())
        except KeyboardInterrupt:
            self.logger.debug("Received exit signal")