1from __future__ import annotations
2
3import asyncio
4import base64
5import json
6import logging
7import math
8import struct
9from typing import Any
10from urllib.parse import urlparse
11
12import click
13import httpx
14import websockets
15
# Version of the client <-> tunnel-server WebSocket protocol, sent as the
# ``v`` query parameter of the tunnel URL. Increment it whenever the wire
# protocol changes incompatibly; the server turns away clients older than its
# minimum supported version (the client reports that HTTP 426 as "too old").
PROTOCOL_VERSION: int = 3
19
20
class TunnelClient:
    """Bridge a public tunnel subdomain to a local HTTP/WebSocket server.

    One long-lived WebSocket to the tunnel server carries JSON control messages
    (``ping``, ``request``, ``stream-cancel``, ``ws-open``, ``ws-message``,
    ``ws-close``) and binary request-body chunks. Each complete request is
    forwarded to ``destination_url`` with httpx and its response is sent back
    over the same socket; browser WebSockets are proxied to the local server.
    """

    # Name of the stream handler this class installs on the module logger, so
    # a second instance reuses it instead of stacking another copy.
    _LOG_HANDLER_NAME = "plain-tunnel-client"

    # Browser headers not forwarded on local WebSocket handshakes: hop-by-hop
    # and handshake headers that websockets.connect generates itself
    # (including Host, taken from the URL).
    _WS_SKIP_HEADERS = frozenset(
        {
            "host",
            "connection",
            "upgrade",
            "sec-websocket-key",
            "sec-websocket-version",
            "sec-websocket-extensions",
            "sec-websocket-protocol",
        }
    )

    def __init__(
        self, *, destination_url: str, subdomain: str, tunnel_host: str, log_level: str
    ) -> None:
        """Configure URLs, logging and per-connection state (no network I/O).

        Args:
            destination_url: Base URL of the local server; request paths are
                appended to it verbatim (so presumably no trailing slash —
                TODO confirm with callers).
            subdomain: Tunnel subdomain to serve on ``tunnel_host``.
            tunnel_host: Tunnel server host[:port]. Hosts containing
                "localhost" or "127.0.0.1" are reached over plain http/ws,
                all others over https/wss.
            log_level: Name of a ``logging`` level, e.g. "info" or "DEBUG".
        """
        self.destination_url = destination_url
        self.subdomain = subdomain
        self.tunnel_host = tunnel_host

        is_local = "localhost" in tunnel_host or "127.0.0.1" in tunnel_host
        http_scheme, ws_scheme = ("http", "ws") if is_local else ("https", "wss")
        self.tunnel_http_url = f"{http_scheme}://{subdomain}.{tunnel_host}"
        self.tunnel_websocket_url = (
            f"{ws_scheme}://{subdomain}.{tunnel_host}/__tunnel__?v={PROTOCOL_VERSION}"
        )

        level = getattr(logging, log_level.upper())
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(level)
        self.logger.propagate = False
        # getLogger() returns the same module logger for every instance; reuse
        # our handler so each extra TunnelClient does not duplicate log lines.
        handler = next(
            (
                h
                for h in self.logger.handlers
                if h.get_name() == self._LOG_HANDLER_NAME
            ),
            None,
        )
        if handler is None:
            handler = logging.StreamHandler()
            handler.set_name(self._LOG_HANDLER_NAME)
            handler.setFormatter(logging.Formatter("%(message)s"))
            self.logger.addHandler(handler)
        handler.setLevel(level)

        # request ID -> {"metadata", "body_chunks", "has_body", "total_body_chunks"}
        # for requests still waiting on body chunks.
        self.pending_requests: dict[str, dict[str, Any]] = {}
        # request ID -> event set when the server sends stream-cancel.
        self.active_streams: dict[str, asyncio.Event] = {}
        # WebSocket ID -> open connection to the local server.
        self.proxied_websockets: dict[str, Any] = {}
        # WebSocket ID -> messages received while the local connection is
        # still being opened.
        self.ws_pending_queues: dict[str, asyncio.Queue[dict[str, Any]]] = {}
        # Strong references to spawned tasks: the event loop only keeps weak
        # references, so an unreferenced task may be garbage-collected.
        self._background_tasks: set[asyncio.Task[None]] = set()
        self.stop_event = asyncio.Event()

    async def connect(self) -> None:
        """Keep the tunnel connection open, reconnecting with exponential backoff.

        Returns when cancelled, when ``stop_event`` is set, or when the server
        answers the handshake with HTTP 426 (client too old). Any other
        rejected handshake status is re-raised. If ``handle_messages`` returns
        normally (the server closed the socket cleanly) the loop reconnects
        right away.
        """
        retry_delay = 1.0
        max_retry_delay = 30.0
        while not self.stop_event.is_set():
            try:
                self.logger.debug(
                    f"Connecting to WebSocket URL: {self.tunnel_websocket_url}"
                )
                async with websockets.connect(
                    self.tunnel_websocket_url, max_size=None
                ) as websocket:
                    self.logger.debug("WebSocket connection established")
                    click.secho(
                        f"Connected to tunnel {self.tunnel_http_url}", fg="green"
                    )
                    retry_delay = 1.0
                    try:
                        await self.handle_messages(websocket)
                    finally:
                        await self._reset_connection_state()
            except asyncio.CancelledError:
                self.logger.debug("Connection cancelled")
                break
            except websockets.InvalidStatus as e:
                if e.response.status_code == 426:
                    body = (
                        e.response.body.decode(errors="replace")
                        if e.response.body
                        else ""
                    )
                    click.secho(
                        body or "Client version too old. Please upgrade plain.tunnel.",
                        fg="red",
                    )
                    break
                raise
            except (websockets.ConnectionClosed, ConnectionError) as e:
                if self.stop_event.is_set():
                    self.logger.debug("Stopping reconnect attempts due to shutdown")
                    break
                click.secho(
                    f"Connection lost: {e}. Retrying in {retry_delay:.0f}s...",
                    fg="yellow",
                )
                await asyncio.sleep(retry_delay)
                retry_delay = min(retry_delay * 2, max_retry_delay)
            except Exception as e:
                if self.stop_event.is_set():
                    self.logger.debug("Stopping reconnect attempts due to shutdown")
                    break
                click.secho(
                    f"Unexpected error: {e}. Retrying in {retry_delay:.0f}s...",
                    fg="yellow",
                )
                await asyncio.sleep(retry_delay)
                retry_delay = min(retry_delay * 2, max_retry_delay)

    async def handle_messages(self, websocket: Any) -> None:
        """Dispatch every message from the tunnel socket until it closes.

        Text frames are JSON control messages routed on their ``type``; binary
        frames are request-body chunks. Errors are logged and re-raised;
        cancellation is logged and re-raised as well.
        """
        try:
            async for message in websocket:
                if isinstance(message, str):
                    data = json.loads(message)
                    msg_type = data.get("type")
                    if msg_type == "ping":
                        self.logger.debug("Received heartbeat ping, sending pong")
                        await websocket.send(json.dumps({"type": "pong"}))
                    elif msg_type == "request":
                        self.logger.debug("Received request metadata from worker")
                        await self.handle_request_metadata(websocket, data)
                    elif msg_type == "stream-cancel":
                        request_id = data.get("id")
                        self.logger.debug(
                            f"Received stream-cancel for request ID: {request_id}"
                        )
                        cancel_event = self.active_streams.get(request_id)
                        if cancel_event:
                            cancel_event.set()
                    elif msg_type == "ws-open":
                        self.logger.debug(f"Received ws-open for ID: {data['id']}")
                        # Created before the task starts so ws-message frames
                        # arriving during the local connect are buffered.
                        self.ws_pending_queues[data["id"]] = asyncio.Queue()
                        self._spawn(self._handle_ws_open(websocket, data))
                    elif msg_type == "ws-message":
                        await self._handle_ws_message(data)
                    elif msg_type == "ws-close":
                        await self._handle_ws_close(data)
                    else:
                        self.logger.warning(
                            f"Received unknown message type: {msg_type}"
                        )
                elif isinstance(message, bytes):
                    self.logger.debug("Received binary data from worker")
                    await self.handle_request_body_chunk(websocket, message)
                else:
                    self.logger.warning("Received unknown message format")
        except asyncio.CancelledError:
            self.logger.debug("Message handling cancelled")
            # Must propagate: swallowing it made connect() treat Ctrl-C as a
            # normal disconnect and reconnect instead of shutting down.
            raise
        except Exception as e:
            self.logger.error(f"Error in handle_messages: {e}")
            raise

    async def handle_request_metadata(
        self, websocket: Any, data: dict[str, Any]
    ) -> None:
        """Record a request's metadata; process it now if no body chunks are expected."""
        request_id = data["id"]
        has_body = data.get("has_body", False)
        total_body_chunks = data.get("totalBodyChunks", 0)
        self.pending_requests[request_id] = {
            "metadata": data,
            "body_chunks": {},
            "has_body": has_body,
            "total_body_chunks": total_body_chunks,
        }
        self.logger.debug(
            f"Stored metadata for request ID: {request_id}, has_body: {has_body}"
        )
        await self.check_and_process_request(websocket, request_id)

    async def handle_request_body_chunk(
        self, websocket: Any, chunk_data: bytes
    ) -> None:
        """Store one binary request-body chunk and process the request once complete.

        Frame layout (little-endian): ``u32 id_length | id (UTF-8) |
        u32 chunk_index | u32 total_chunks | body bytes``. Malformed frames,
        chunks for unknown requests, and chunk indices outside the request's
        declared chunk count are logged and dropped.
        """
        try:
            (id_length,) = struct.unpack_from("<I", chunk_data, 0)
            request_id = chunk_data[4 : 4 + id_length].decode("utf-8")
            chunk_index, total_chunks = struct.unpack_from(
                "<II", chunk_data, 4 + id_length
            )
        except (struct.error, UnicodeDecodeError) as e:
            # One bad frame should not tear down the whole tunnel connection.
            self.logger.warning(f"Ignoring malformed body chunk: {e}")
            return
        body_chunk = chunk_data[4 + id_length + 8 :]

        request = self.pending_requests.get(request_id)
        if request is None:
            self.logger.warning(
                f"Received body chunk for unknown or completed request ID: {request_id}"
            )
            return
        if chunk_index >= request["total_body_chunks"]:
            # Storing it would let len(body_chunks) reach the expected count
            # while a real chunk is still missing, stalling the request.
            self.logger.warning(
                f"Ignoring out-of-range body chunk {chunk_index + 1}/{total_chunks} "
                f"for request ID: {request_id}"
            )
            return

        request["body_chunks"][chunk_index] = body_chunk
        self.logger.debug(
            f"Stored body chunk {chunk_index + 1}/{total_chunks} for request ID: {request_id}"
        )
        await self.check_and_process_request(websocket, request_id)

    async def check_and_process_request(self, websocket: Any, request_id: str) -> None:
        """Start forwarding ``request_id`` once its metadata and all body chunks are in."""
        request_data = self.pending_requests.get(request_id)
        if not request_data:
            return

        has_body = request_data["has_body"]
        total_body_chunks = request_data["total_body_chunks"]
        body_chunks = request_data["body_chunks"]

        all_chunks_received = not has_body or len(body_chunks) == total_body_chunks
        if not all_chunks_received:
            return

        for i in range(total_body_chunks):
            if i not in body_chunks:
                self.logger.error(
                    f"Missing chunk {i + 1}/{total_body_chunks} for request ID: {request_id}"
                )
                return

        self.logger.debug(f"Processing request ID: {request_id}")
        del self.pending_requests[request_id]
        self._spawn(
            self.process_request(
                websocket,
                request_data["metadata"],
                body_chunks,
                request_id,
            )
        )

    def _spawn(self, coro: Any) -> asyncio.Task[None]:
        """Run ``coro`` as a background task that is referenced until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        task.add_done_callback(self._handle_task_exception)
        return task

    def _handle_task_exception(self, task: asyncio.Task[None]) -> None:
        """Log the exception of a finished background task, if it raised one."""
        if not task.cancelled() and task.exception():
            self.logger.error("Error processing request", exc_info=task.exception())

    async def _cleanup_proxied_websockets(self) -> None:
        """Close all proxied WebSocket connections on tunnel disconnect."""
        for ws in list(self.proxied_websockets.values()):
            try:
                await ws.close()
            except Exception:
                pass
        self.proxied_websockets.clear()
        # Also signals in-progress _handle_ws_open tasks that their
        # connection was abandoned.
        self.ws_pending_queues.clear()

    async def _reset_connection_state(self) -> None:
        """Drop per-connection state after the tunnel socket is gone."""
        await self._cleanup_proxied_websockets()
        # Requests still waiting for body chunks were received on the old
        # socket; presumably the server does not continue them on a new one
        # (TODO confirm), so they would otherwise sit here forever.
        self.pending_requests.clear()
        # Stop in-flight streams at their next chunk instead of letting them
        # keep writing to the closed socket.
        for cancel_event in self.active_streams.values():
            cancel_event.set()

    async def process_request(
        self,
        websocket: Any,
        request_metadata: dict[str, Any],
        body_chunks: dict[int, bytes],
        request_id: str,
    ) -> None:
        """Forward one complete request to the local server and relay the response.

        ``text/event-stream`` responses are relayed incrementally; all others
        are read fully and sent buffered. If the local request fails before
        anything was sent for ``request_id``, an empty 502 is sent instead.
        """
        self.logger.debug(
            f"Processing request: {request_id} {request_metadata['method']} {request_metadata['url']}"
        )

        if request_metadata.get("has_body", False):
            total_chunks = request_metadata.get("totalBodyChunks", 0)
            body_data = b"".join(body_chunks[i] for i in range(total_chunks))
        else:
            body_data = None

        parsed = urlparse(request_metadata["url"])
        path = f"{parsed.path}?{parsed.query}" if parsed.query else parsed.path
        forward_url = f"{self.destination_url}{path}"

        self.logger.debug(f"Forwarding request to: {forward_url}")

        # httpx's default 5 s timeout also limits every read, which aborts slow
        # local responses and idle event streams; keep 5 s for connect/write/
        # pool only.
        timeout = httpx.Timeout(5.0, read=None)
        response_started = False
        async with httpx.AsyncClient(
            follow_redirects=False, verify=False, timeout=timeout
        ) as client:
            try:
                async with client.stream(
                    method=request_metadata["method"],
                    url=forward_url,
                    headers=request_metadata["headers"],
                    content=body_data,
                ) as response:
                    response_status = response.status_code
                    # multi_items() keeps repeated headers (e.g. several
                    # Set-Cookie) as separate pairs; dict(response.headers)
                    # would comma-join them into one broken value.
                    response_headers = response.headers.multi_items()

                    self.logger.info(
                        f"{click.style(request_metadata['method'], bold=True)} {request_metadata['url']} {response_status}"
                    )

                    # NOTE(review): httpx decodes Content-Encoding when reading,
                    # while the original headers are forwarded unchanged —
                    # confirm the tunnel server accounts for that.
                    if self._is_streaming_response(response):
                        response_started = True
                        await self._handle_streaming_response(
                            websocket,
                            response,
                            request_id,
                            response_status,
                            response_headers,
                        )
                    else:
                        response_body = await response.aread()
                        response_started = True
                        await self._handle_buffered_response(
                            websocket,
                            response_body,
                            request_id,
                            response_status,
                            response_headers,
                        )
            except httpx.HTTPError as e:
                if response_started:
                    # Something was already sent for this ID; a 502 now would
                    # be a second, conflicting response.
                    self.logger.error(
                        f"Error forwarding response for ID {request_id}: {e}"
                    )
                    return
                if isinstance(e, httpx.ConnectError):
                    self.logger.error(f"Connection error forwarding request: {e}")
                else:
                    self.logger.error(f"Error forwarding request: {e}")
                self.logger.info(
                    f"{click.style(request_metadata['method'], bold=True)} {request_metadata['url']} 502"
                )
                await self._handle_buffered_response(
                    websocket, b"", request_id, 502, {}
                )

    def _is_streaming_response(self, response: httpx.Response) -> bool:
        """Return True for Server-Sent Events responses, which must not be buffered."""
        content_type = response.headers.get("content-type", "")
        return "text/event-stream" in content_type

    @staticmethod
    def _header_pairs(
        headers: dict[str, str] | list[tuple[str, str]],
    ) -> list[tuple[str, str]]:
        """Normalize a header dict or a list of (name, value) pairs to a list of pairs."""
        if isinstance(headers, dict):
            return list(headers.items())
        return list(headers)

    async def _handle_buffered_response(
        self,
        websocket: Any,
        response_body: bytes,
        request_id: str,
        response_status: int,
        response_headers: dict[str, str] | list[tuple[str, str]],
    ) -> None:
        """Send a complete response: JSON metadata, then binary body chunks.

        Each chunk carries at most 1,000,000 body bytes and is framed as
        ``id (UTF-8) | u32 chunk_index | u32 total_chunks | bytes``
        (little-endian). NOTE(review): unlike inbound request chunks there is
        no id-length prefix, so the server presumably relies on fixed-length
        request IDs — confirm before changing the ID format.
        """
        has_body = len(response_body) > 0
        max_chunk_size = 1_000_000
        total_body_chunks = (
            math.ceil(len(response_body) / max_chunk_size) if has_body else 0
        )

        response_metadata = {
            "type": "response",
            "id": request_id,
            "status": response_status,
            "headers": self._header_pairs(response_headers),
            "has_body": has_body,
            "totalBodyChunks": total_body_chunks,
        }

        self.logger.debug(
            f"Sending response metadata for ID: {request_id}, has_body: {has_body}"
        )
        await websocket.send(json.dumps(response_metadata))

        if has_body:
            self.logger.debug(
                f"Sending {total_body_chunks} body chunks for ID: {request_id}"
            )
            id_bytes = request_id.encode("utf-8")
            for i in range(total_body_chunks):
                chunk_start = i * max_chunk_size
                chunk_end = min(chunk_start + max_chunk_size, len(response_body))
                header = id_bytes + struct.pack("<II", i, total_body_chunks)
                await websocket.send(header + response_body[chunk_start:chunk_end])
                self.logger.debug(
                    f"Sent body chunk {i + 1}/{total_body_chunks} for ID: {request_id}"
                )

    async def _handle_streaming_response(
        self,
        websocket: Any,
        response: httpx.Response,
        request_id: str,
        response_status: int,
        response_headers: dict[str, str] | list[tuple[str, str]],
    ) -> None:
        """Relay a streaming response: stream-start, ``id + chunk`` frames, stream-end.

        Stops early without stream-end when a stream-cancel arrives; on any
        error a best-effort stream-error is sent. The ``active_streams`` entry
        is always removed.
        """
        cancel_event = asyncio.Event()
        self.active_streams[request_id] = cancel_event
        id_bytes = request_id.encode("utf-8")

        # Everything after registration is inside try so the finally below
        # removes the entry even if sending stream-start fails.
        try:
            stream_start = {
                "type": "stream-start",
                "id": request_id,
                "status": response_status,
                "headers": self._header_pairs(response_headers),
            }
            self.logger.debug(f"Sending stream-start for ID: {request_id}")
            await websocket.send(json.dumps(stream_start))

            async for chunk in response.aiter_bytes():
                if cancel_event.is_set():
                    self.logger.debug(
                        f"Stream cancelled by browser for request ID: {request_id}"
                    )
                    break

                await websocket.send(id_bytes + chunk)
            else:
                # Only send stream-end if the loop completed naturally
                # (not cancelled by the server via stream-cancel)
                stream_end = {
                    "type": "stream-end",
                    "id": request_id,
                }
                self.logger.debug(f"Sending stream-end for ID: {request_id}")
                await websocket.send(json.dumps(stream_end))
        except Exception as e:
            self.logger.error(f"Error streaming response for ID {request_id}: {e}")
            stream_error = {
                "type": "stream-error",
                "id": request_id,
                "error": str(e),
            }
            try:
                await websocket.send(json.dumps(stream_error))
            except Exception:
                pass
        finally:
            self.active_streams.pop(request_id, None)

    @staticmethod
    def _sanitize_close_code(code: int) -> int:
        """Map close codes that must never appear in a Close frame to sendable ones.

        RFC 6455 §7.4.1 reserves 1005 (no status), 1006 (abnormal closure) and
        1015 (TLS failure) for local reporting only.
        """
        if code == 1005:
            return 1000
        if code in (1006, 1015):
            return 1011
        return code

    @staticmethod
    async def _send_to_local(local_ws: Any, data: dict[str, Any]) -> None:
        """Forward one ws-message payload, base64-decoding it when ``binary`` is set."""
        if data.get("binary"):
            await local_ws.send(base64.b64decode(data["data"]))
        else:
            await local_ws.send(data["data"])

    async def _handle_ws_open(self, tunnel_ws: Any, data: dict[str, Any]) -> None:
        """Open a local WebSocket for a browser connection and relay it until closed.

        Messages that arrive while connecting are buffered in
        ``ws_pending_queues[ws_id]`` and flushed in order once connected. If
        that queue disappears first (browser ws-close, or a tunnel reset), the
        local socket is closed and nothing is relayed. When relaying ends, the
        local socket is closed and a ws-close is sent back through the tunnel.
        """
        ws_id = data["id"]
        queue = self.ws_pending_queues.get(ws_id)
        if queue is None:
            self.logger.debug(f"WebSocket {ws_id} was closed before it opened")
            return

        parsed = urlparse(data["url"])
        path = f"{parsed.path}?{parsed.query}" if parsed.query else parsed.path

        # Build local WebSocket URL
        dest_parsed = urlparse(self.destination_url)
        ws_scheme = "wss" if dest_parsed.scheme == "https" else "ws"
        local_ws_url = f"{ws_scheme}://{dest_parsed.netloc}{path}"

        self.logger.debug(f"Opening local WebSocket for {ws_id}: {local_ws_url}")

        # Forward safe browser headers (cookies, auth, origin) to the local server.
        forward_headers = {
            name: value
            for name, value in data.get("headers", {}).items()
            if name.lower() not in self._WS_SKIP_HEADERS
        }

        try:
            ssl_context = None
            if ws_scheme == "wss":
                import ssl

                # Certificate checks are off, matching verify=False on the
                # HTTP path (presumably for self-signed local dev certs).
                ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
                ssl_context.check_hostname = False
                ssl_context.verify_mode = ssl.CERT_NONE
            local_ws = await websockets.connect(
                local_ws_url,
                ssl=ssl_context,
                max_size=None,
                additional_headers=forward_headers,
            )
        except Exception as e:
            self.logger.error(f"Failed to connect local WebSocket for {ws_id}: {e}")
            self.ws_pending_queues.pop(ws_id, None)
            try:
                await tunnel_ws.send(
                    json.dumps(
                        {
                            "type": "ws-close",
                            "id": ws_id,
                            "code": 1011,
                            "reason": str(e),
                        }
                    )
                )
            except Exception:
                pass
            return

        # Flush messages buffered while connecting. The socket stays out of
        # proxied_websockets until the queue is empty, so messages arriving
        # during these awaits queue up behind older ones instead of jumping
        # ahead of them.
        while self.ws_pending_queues.get(ws_id) is queue and not queue.empty():
            queued = queue.get_nowait()
            try:
                await self._send_to_local(local_ws, queued)
            except Exception as e:
                self.logger.error(
                    f"Failed to forward queued message to local WebSocket {ws_id}: {e}"
                )

        if self.ws_pending_queues.get(ws_id) is not queue:
            # ws-close arrived, or the tunnel connection was reset, while we
            # were connecting or flushing.
            self.logger.debug(f"WebSocket {ws_id} was closed while opening")
            try:
                await local_ws.close()
            except Exception:
                pass
            return

        # No await between the final queue check and these two lines, so no
        # message can slip past the queue.
        del self.ws_pending_queues[ws_id]
        self.proxied_websockets[ws_id] = local_ws

        self.logger.info(f"WebSocket proxy opened: {ws_id} -> {local_ws_url}")

        # Relay messages from local server back to the tunnel
        try:
            async for message in local_ws:
                if isinstance(message, str):
                    await tunnel_ws.send(
                        json.dumps({"type": "ws-message", "id": ws_id, "data": message})
                    )
                elif isinstance(message, bytes):
                    await tunnel_ws.send(
                        json.dumps(
                            {
                                "type": "ws-message",
                                "id": ws_id,
                                "data": base64.b64encode(message).decode("ascii"),
                                "binary": True,
                            }
                        )
                    )
        except websockets.ConnectionClosed:
            pass
        except Exception as e:
            self.logger.error(f"Error relaying WebSocket {ws_id}: {e}")
        finally:
            self.proxied_websockets.pop(ws_id, None)
            close_code = self._sanitize_close_code(local_ws.close_code or 1000)
            close_reason = local_ws.close_reason or ""
            # The relay can also stop because sending to the tunnel failed;
            # close the local socket so it is not left open with no reader.
            try:
                await local_ws.close()
            except Exception:
                pass
            try:
                await tunnel_ws.send(
                    json.dumps(
                        {
                            "type": "ws-close",
                            "id": ws_id,
                            "code": close_code,
                            "reason": close_reason,
                        }
                    )
                )
            except Exception:
                pass

    async def _handle_ws_message(self, data: dict[str, Any]) -> None:
        """Forward a browser message to its local socket, or buffer it while connecting."""
        ws_id = data["id"]
        local_ws = self.proxied_websockets.get(ws_id)
        if local_ws is None:
            # Connection still being established — buffer for later
            queue = self.ws_pending_queues.get(ws_id)
            if queue is not None:
                await queue.put(data)
                return
            self.logger.warning(f"Received ws-message for unknown WebSocket: {ws_id}")
            return
        try:
            await self._send_to_local(local_ws, data)
        except Exception as e:
            self.logger.error(
                f"Failed to forward message to local WebSocket {ws_id}: {e}"
            )

    async def _handle_ws_close(self, data: dict[str, Any]) -> None:
        """Close the local socket for a browser-closed WebSocket.

        If the local socket is still being opened, its pending queue is dropped
        instead; _handle_ws_open sees that and closes the socket once connected.
        """
        ws_id = data["id"]
        local_ws = self.proxied_websockets.pop(ws_id, None)
        if local_ws is None:
            self.ws_pending_queues.pop(ws_id, None)
            return
        try:
            await local_ws.close()
        except Exception:
            pass

    def run(self) -> None:
        """Run :meth:`connect` on a new event loop until it returns or Ctrl-C."""
        try:
            asyncio.run(self.connect())
        except KeyboardInterrupt:
            self.logger.debug("Received exit signal")