1from __future__ import annotations
2
3import asyncio
4import base64
5import json
6import logging
7import math
8import struct
9from typing import Any
10from urllib.parse import urlparse
11
12import click
13import httpx
14import websockets
15
# Version of the client <-> server WebSocket protocol, sent as the ``v`` query
# parameter of the tunnel URL. Increment it for breaking protocol changes: the
# server turns away clients whose version is below its supported minimum.
PROTOCOL_VERSION: int = 3
19
20
class TunnelClient:
    """Expose a local server through a public tunnel URL.

    The client holds a WebSocket open to the tunnel server. HTTP requests
    arrive as a JSON ``request`` message followed by binary body chunks; each
    is replayed against ``destination_url`` with httpx and the response is sent
    back over the same WebSocket. Browser WebSocket sessions (``ws-open`` /
    ``ws-message`` / ``ws-close``) are proxied to the destination over separate
    local WebSocket connections.
    """

    # Largest body slice sent per binary frame for buffered responses.
    BODY_CHUNK_SIZE = 1_000_000

    # Browser headers not forwarded on proxied WebSocket handshakes: hop-by-hop
    # and handshake headers that websockets.connect generates itself
    # (including Host, which it derives from the URL).
    _WS_SKIP_HEADERS = frozenset(
        {
            "host",
            "connection",
            "upgrade",
            "sec-websocket-key",
            "sec-websocket-version",
            "sec-websocket-extensions",
            "sec-websocket-protocol",
        }
    )

    def __init__(
        self, *, destination_url: str, subdomain: str, tunnel_host: str, log_level: str
    ) -> None:
        """Configure tunnel URLs and logging; no connection is made here.

        Args:
            destination_url: Base URL (scheme + host[:port]) requests are sent to.
            subdomain: Subdomain of ``tunnel_host`` this client serves.
            tunnel_host: Tunnel server host; hosts containing ``localhost`` or
                ``127.0.0.1`` are reached over plain http/ws, others over TLS.
            log_level: Name of a ``logging`` level, e.g. ``"info"``.
        """
        self.destination_url = destination_url
        self.subdomain = subdomain
        self.tunnel_host = tunnel_host

        is_local = "localhost" in tunnel_host or "127.0.0.1" in tunnel_host
        http_scheme, ws_scheme = ("http", "ws") if is_local else ("https", "wss")
        self.tunnel_http_url = f"{http_scheme}://{subdomain}.{tunnel_host}"
        self.tunnel_websocket_url = (
            f"{ws_scheme}://{subdomain}.{tunnel_host}/__tunnel__?v={PROTOCOL_VERSION}"
        )

        self.logger = self._setup_logger(log_level)

        # request id -> {"metadata", "body_chunks", "has_body", "total_body_chunks"}
        self.pending_requests: dict[str, dict[str, Any]] = {}
        # request id -> event set when the server sends "stream-cancel"
        self.active_streams: dict[str, asyncio.Event] = {}
        # ws id -> open local WebSocket connection
        self.proxied_websockets: dict[str, Any] = {}
        # ws id -> messages received while the local WebSocket is connecting
        self.ws_pending_queues: dict[str, asyncio.Queue[dict[str, Any]]] = {}
        # Strong references to fire-and-forget tasks; the event loop only keeps
        # weak references to tasks (see asyncio.create_task documentation).
        self._background_tasks: set[asyncio.Task[None]] = set()
        self.stop_event = asyncio.Event()

    def _setup_logger(self, log_level: str) -> logging.Logger:
        """Configure and return this module's logger at *log_level*.

        The logger is shared by every instance, so the stream handler is only
        attached once; otherwise each new client would duplicate every line.
        """
        logger = logging.getLogger(__name__)
        level = getattr(logging, log_level.upper())
        logger.setLevel(level)
        logger.propagate = False
        if not logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter("%(message)s"))
            logger.addHandler(handler)
        return logger

    async def connect(self) -> None:
        """Connect to the tunnel and serve it, reconnecting after failures.

        Connection loss and unexpected errors are retried with exponential
        backoff (1s, doubling, capped at 30s; reset after each successful
        connect). Reconnecting stops once ``stop_event`` is set, when the task
        is cancelled, or when the server answers the handshake with HTTP 426
        (client protocol too old). Other rejected handshake statuses are
        re-raised.
        """
        retry_delay = 1.0
        max_retry_delay = 30.0
        while not self.stop_event.is_set():
            try:
                self.logger.debug(
                    f"Connecting to WebSocket URL: {self.tunnel_websocket_url}"
                )
                async with websockets.connect(
                    self.tunnel_websocket_url, max_size=None
                ) as websocket:
                    self.logger.debug("WebSocket connection established")
                    click.secho(
                        f"Connected to tunnel {self.tunnel_http_url}", fg="green"
                    )
                    retry_delay = 1.0
                    try:
                        # If this returns normally (the socket iteration ended),
                        # the loop reconnects straight away without backoff.
                        await self.handle_messages(websocket)
                    finally:
                        # Local sockets proxied over this tunnel connection
                        # cannot outlive it.
                        await self._cleanup_proxied_websockets()
            except asyncio.CancelledError:
                self.logger.debug("Connection cancelled")
                break
            except websockets.InvalidStatus as e:
                if e.response.status_code == 426:
                    body = e.response.body.decode() if e.response.body else ""
                    click.secho(
                        body or "Client version too old. Please upgrade plain.tunnel.",
                        fg="red",
                    )
                    break
                raise
            except (websockets.ConnectionClosed, ConnectionError) as e:
                if self.stop_event.is_set():
                    self.logger.debug("Stopping reconnect attempts due to shutdown")
                    break
                click.secho(
                    f"Connection lost: {e}. Retrying in {retry_delay:.0f}s...",
                    fg="yellow",
                )
                await asyncio.sleep(retry_delay)
                retry_delay = min(retry_delay * 2, max_retry_delay)
            except Exception as e:
                if self.stop_event.is_set():
                    self.logger.debug("Stopping reconnect attempts due to shutdown")
                    break
                click.secho(
                    f"Unexpected error: {e}. Retrying in {retry_delay:.0f}s...",
                    fg="yellow",
                )
                await asyncio.sleep(retry_delay)
                retry_delay = min(retry_delay * 2, max_retry_delay)

    async def handle_messages(self, websocket: Any) -> None:
        """Dispatch every frame received on the tunnel WebSocket.

        Text frames are JSON control messages; binary frames are request body
        chunks. Cancellation is logged and re-raised so callers (``connect``)
        can stop instead of treating it as a normal end of stream.
        """
        try:
            async for message in websocket:
                if isinstance(message, str):
                    await self._handle_text_message(websocket, message)
                elif isinstance(message, bytes):
                    self.logger.debug("Received binary data from worker")
                    await self.handle_request_body_chunk(websocket, message)
                else:
                    self.logger.warning("Received unknown message format")
        except asyncio.CancelledError:
            self.logger.debug("Message handling cancelled")
            raise
        except Exception as e:
            self.logger.error(f"Error in handle_messages: {e}")
            raise

    async def _handle_text_message(self, websocket: Any, message: str) -> None:
        """Parse one JSON control message and route it by its ``type`` field."""
        try:
            data = json.loads(message)
        except json.JSONDecodeError as e:
            # One bad frame should not tear down every in-flight request.
            self.logger.warning(f"Ignoring malformed control message: {e}")
            return
        if not isinstance(data, dict):
            self.logger.warning("Ignoring control message that is not a JSON object")
            return

        msg_type = data.get("type")
        if msg_type == "ping":
            self.logger.debug("Received heartbeat ping, sending pong")
            await websocket.send(json.dumps({"type": "pong"}))
        elif msg_type == "request":
            self.logger.debug("Received request metadata from worker")
            await self.handle_request_metadata(websocket, data)
        elif msg_type == "stream-cancel":
            request_id = data.get("id")
            self.logger.debug(f"Received stream-cancel for request ID: {request_id}")
            cancel_event = self.active_streams.get(request_id)
            if cancel_event:
                cancel_event.set()
        elif msg_type == "ws-open":
            self.logger.debug(f"Received ws-open for ID: {data['id']}")
            # Messages for this socket are buffered here until it connects.
            self.ws_pending_queues[data["id"]] = asyncio.Queue()
            self._spawn(self._handle_ws_open(websocket, data))
        elif msg_type == "ws-message":
            await self._handle_ws_message(data)
        elif msg_type == "ws-close":
            await self._handle_ws_close(data)
        else:
            self.logger.warning(f"Received unknown message type: {msg_type}")

    async def handle_request_metadata(
        self, websocket: Any, data: dict[str, Any]
    ) -> None:
        """Record a ``request`` message and process it if no body is pending."""
        request_id = data["id"]
        has_body = data.get("has_body", False)
        total_body_chunks = data.get("totalBodyChunks", 0)
        self.pending_requests[request_id] = {
            "metadata": data,
            "body_chunks": {},
            "has_body": has_body,
            "total_body_chunks": total_body_chunks,
        }
        self.logger.debug(
            f"Stored metadata for request ID: {request_id}, has_body: {has_body}"
        )
        await self.check_and_process_request(websocket, request_id)

    async def handle_request_body_chunk(
        self, websocket: Any, chunk_data: bytes
    ) -> None:
        """Store one binary body chunk for a pending request.

        Frame layout: little-endian uint32 id length, the UTF-8 request id,
        little-endian uint32 chunk index and uint32 total chunk count, then
        the chunk bytes.
        """
        (id_length,) = struct.unpack_from("<I", chunk_data, 0)
        request_id = chunk_data[4 : 4 + id_length].decode("utf-8")
        header_end = 4 + id_length + 8
        chunk_index, total_chunks = struct.unpack_from("<II", chunk_data, 4 + id_length)
        body_chunk = chunk_data[header_end:]

        request = self.pending_requests.get(request_id)
        if request is None:
            self.logger.warning(
                f"Received body chunk for unknown or completed request ID: {request_id}"
            )
            return
        request["body_chunks"][chunk_index] = body_chunk
        self.logger.debug(
            f"Stored body chunk {chunk_index + 1}/{total_chunks} for request ID: {request_id}"
        )
        await self.check_and_process_request(websocket, request_id)

    async def check_and_process_request(self, websocket: Any, request_id: str) -> None:
        """Start forwarding *request_id* once all of its body chunks are present."""
        request_data = self.pending_requests.get(request_id)
        if not request_data:
            return

        has_body = request_data["has_body"]
        total_body_chunks = request_data["total_body_chunks"]
        body_chunks = request_data["body_chunks"]

        all_chunks_received = not has_body or len(body_chunks) == total_body_chunks
        if not all_chunks_received:
            return

        # The count can match while an index is missing (e.g. an out-of-range
        # index was received), so check every expected index explicitly.
        for i in range(total_body_chunks):
            if i not in body_chunks:
                self.logger.error(
                    f"Missing chunk {i + 1}/{total_body_chunks} for request ID: {request_id}"
                )
                return

        self.logger.debug(f"Processing request ID: {request_id}")
        del self.pending_requests[request_id]
        self._spawn(
            self.process_request(
                websocket,
                request_data["metadata"],
                body_chunks,
                request_id,
            )
        )

    def _spawn(self, coro: Any) -> None:
        """Run *coro* as a background task, holding a reference until it ends."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        task.add_done_callback(self._handle_task_exception)

    def _handle_task_exception(self, task: asyncio.Task[None]) -> None:
        """Done-callback: log the exception of a failed background task."""
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            self.logger.error("Error processing request", exc_info=exc)

    async def _close_quietly(self, ws: Any) -> None:
        """Close *ws* as a best effort, logging instead of raising on failure."""
        try:
            await ws.close()
        except Exception as e:
            self.logger.debug(f"Ignoring error while closing WebSocket: {e}")

    async def _cleanup_proxied_websockets(self) -> None:
        """Close all proxied WebSocket connections on tunnel disconnect."""
        # Iterate over a copy: relay tasks remove themselves while we await.
        for local_ws in list(self.proxied_websockets.values()):
            await self._close_quietly(local_ws)
        self.proxied_websockets.clear()
        self.ws_pending_queues.clear()

    @staticmethod
    def _path_and_query(url: str) -> str:
        """Return the path of *url* plus ``?query`` when it has a query string."""
        parsed = urlparse(url)
        return f"{parsed.path}?{parsed.query}" if parsed.query else parsed.path

    @staticmethod
    def _header_pairs(headers: Any) -> list[tuple[str, str]]:
        """Return *headers* as a list of ``(name, value)`` pairs.

        Accepts a mapping, an httpx ``Headers`` object (via ``multi_items()``,
        which keeps repeated names such as ``Set-Cookie`` separate), or an
        iterable of pairs.
        """
        multi_items = getattr(headers, "multi_items", None)
        if multi_items is not None:
            return list(multi_items())
        if hasattr(headers, "items"):
            return list(headers.items())
        return list(headers)

    async def process_request(
        self,
        websocket: Any,
        request_metadata: dict[str, Any],
        body_chunks: dict[int, bytes],
        request_id: str,
    ) -> None:
        """Forward one reassembled request to ``destination_url`` and relay the reply.

        ``text/event-stream`` responses are streamed; everything else is read
        fully and sent in chunks. Failures to reach the destination are
        answered with 502 (504 for timeouts) so the browser is not left hanging.
        """
        method = request_metadata["method"]
        url = request_metadata["url"]
        self.logger.debug(f"Processing request: {request_id} {method} {url}")

        body_data: bytes | None = None
        if request_metadata["has_body"]:
            total_chunks = request_metadata["totalBodyChunks"]
            body_data = b"".join(body_chunks[i] for i in range(total_chunks))

        forward_url = f"{self.destination_url}{self._path_and_query(url)}"
        self.logger.debug(f"Forwarding request to: {forward_url}")

        async with httpx.AsyncClient(
            follow_redirects=False, verify=False, timeout=30
        ) as client:
            # Streaming responses catch their own errors after stream-start,
            # and buffered responses finish reading before anything is sent,
            # so no reply has started when the handlers below run.
            try:
                async with client.stream(
                    method=method,
                    url=forward_url,
                    headers=request_metadata["headers"],
                    content=body_data,
                ) as response:
                    response_status = response.status_code
                    # multi_items() keeps repeated headers (e.g. several
                    # Set-Cookie) separate; dict(response.headers) would
                    # comma-join them into one invalid value.
                    response_headers = response.headers.multi_items()

                    self.logger.info(
                        f"{click.style(method, bold=True)} {url} {response_status}"
                    )

                    if self._is_streaming_response(response):
                        await self._handle_streaming_response(
                            websocket,
                            response,
                            request_id,
                            response_status,
                            response_headers,
                        )
                    else:
                        await response.aread()
                        await self._handle_buffered_response(
                            websocket,
                            response.content,
                            request_id,
                            response_status,
                            response_headers,
                        )
            except httpx.ConnectError as e:
                self.logger.error(f"Connection error forwarding request: {e}")
                await self._send_gateway_error(
                    websocket, request_metadata, request_id, 502
                )
            except httpx.TimeoutException as e:
                self.logger.error(f"Timeout forwarding request: {e}")
                await self._send_gateway_error(
                    websocket, request_metadata, request_id, 504
                )
            except httpx.HTTPError as e:
                self.logger.error(f"Error forwarding request: {e}")
                await self._send_gateway_error(
                    websocket, request_metadata, request_id, 502
                )

    async def _send_gateway_error(
        self,
        websocket: Any,
        request_metadata: dict[str, Any],
        request_id: str,
        status: int,
    ) -> None:
        """Log and send an empty-bodied *status* response for a failed forward."""
        self.logger.info(
            f"{click.style(request_metadata['method'], bold=True)} {request_metadata['url']} {status}"
        )
        await self._handle_buffered_response(websocket, b"", request_id, status, {})

    def _is_streaming_response(self, response: httpx.Response) -> bool:
        """Return True for Server-Sent Events responses, which must be streamed."""
        content_type = response.headers.get("content-type", "")
        return "text/event-stream" in content_type

    async def _handle_buffered_response(
        self,
        websocket: Any,
        response_body: bytes,
        request_id: str,
        response_status: int,
        response_headers: dict[str, str] | list[tuple[str, str]],
    ) -> None:
        """Send a complete response: JSON metadata, then binary body chunks.

        Each body frame is the UTF-8 request id, little-endian uint32 chunk
        index and uint32 chunk count, then at most ``BODY_CHUNK_SIZE`` bytes.
        """
        max_chunk_size = self.BODY_CHUNK_SIZE
        has_body = len(response_body) > 0
        total_body_chunks = (
            math.ceil(len(response_body) / max_chunk_size) if has_body else 0
        )

        response_metadata = {
            "type": "response",
            "id": request_id,
            "status": response_status,
            "headers": self._header_pairs(response_headers),
            "has_body": has_body,
            "totalBodyChunks": total_body_chunks,
        }

        self.logger.debug(
            f"Sending response metadata for ID: {request_id}, has_body: {has_body}"
        )
        await websocket.send(json.dumps(response_metadata))

        if not has_body:
            return

        self.logger.debug(
            f"Sending {total_body_chunks} body chunks for ID: {request_id}"
        )
        id_bytes = request_id.encode("utf-8")
        for i in range(total_body_chunks):
            chunk_start = i * max_chunk_size
            chunk = response_body[chunk_start : chunk_start + max_chunk_size]
            header = id_bytes + struct.pack("<II", i, total_body_chunks)
            await websocket.send(header + chunk)
            self.logger.debug(
                f"Sent body chunk {i + 1}/{total_body_chunks} for ID: {request_id}"
            )

    async def _handle_streaming_response(
        self,
        websocket: Any,
        response: httpx.Response,
        request_id: str,
        response_status: int,
        response_headers: dict[str, str] | list[tuple[str, str]],
    ) -> None:
        """Relay a response incrementally: stream-start, raw chunks, stream-end.

        Each chunk is sent as the UTF-8 request id followed by the bytes. A
        ``stream-cancel`` from the server stops relaying at the next chunk
        (no stream-end is sent); errors are reported as ``stream-error``.
        """
        cancel_event = asyncio.Event()
        self.active_streams[request_id] = cancel_event

        stream_start = {
            "type": "stream-start",
            "id": request_id,
            "status": response_status,
            "headers": self._header_pairs(response_headers),
        }

        self.logger.debug(f"Sending stream-start for ID: {request_id}")
        await websocket.send(json.dumps(stream_start))

        id_bytes = request_id.encode("utf-8")

        try:
            async for chunk in response.aiter_bytes():
                if cancel_event.is_set():
                    self.logger.debug(
                        f"Stream cancelled by browser for request ID: {request_id}"
                    )
                    break

                await websocket.send(id_bytes + chunk)
            else:
                # Only send stream-end if the loop completed naturally
                # (not cancelled by the server via stream-cancel)
                stream_end = {
                    "type": "stream-end",
                    "id": request_id,
                }
                self.logger.debug(f"Sending stream-end for ID: {request_id}")
                await websocket.send(json.dumps(stream_end))
        except Exception as e:
            self.logger.error(f"Error streaming response for ID {request_id}: {e}")
            stream_error = {
                "type": "stream-error",
                "id": request_id,
                "error": str(e),
            }
            try:
                await websocket.send(json.dumps(stream_error))
            except Exception:
                # The tunnel itself is probably gone; nothing left to notify.
                pass
        finally:
            self.active_streams.pop(request_id, None)

    async def _handle_ws_open(self, tunnel_ws: Any, data: dict[str, Any]) -> None:
        """Proxy one browser WebSocket session to the local destination.

        Opens a local WebSocket (reporting ``ws-close`` code 1011 if that
        fails), replays messages queued while connecting, then relays local
        messages back through the tunnel until either side closes, after which
        the local socket is closed and ``ws-close`` is sent to the tunnel.
        """
        import ssl

        ws_id = data["id"]
        path = self._path_and_query(data["url"])

        dest_parsed = urlparse(self.destination_url)
        ws_scheme = "wss" if dest_parsed.scheme == "https" else "ws"
        local_ws_url = f"{ws_scheme}://{dest_parsed.netloc}{path}"

        self.logger.debug(f"Opening local WebSocket for {ws_id}: {local_ws_url}")

        # Forward safe browser headers (cookies, auth, origin) to the local
        # server; handshake and hop-by-hop headers are left to websockets.
        forward_headers = {
            name: value
            for name, value in data.get("headers", {}).items()
            if name.lower() not in self._WS_SKIP_HEADERS
        }

        try:
            ssl_context = None
            if ws_scheme == "wss":
                # Certificate checks are disabled, matching verify=False on
                # the HTTP forwarding path.
                ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
                ssl_context.check_hostname = False
                ssl_context.verify_mode = ssl.CERT_NONE
            local_ws = await websockets.connect(
                local_ws_url,
                ssl=ssl_context,
                max_size=None,
                additional_headers=forward_headers,
            )
        except Exception as e:
            self.logger.error(f"Failed to connect local WebSocket for {ws_id}: {e}")
            self.ws_pending_queues.pop(ws_id, None)
            try:
                await tunnel_ws.send(
                    json.dumps(
                        {
                            "type": "ws-close",
                            "id": ws_id,
                            "code": 1011,
                            "reason": str(e),
                        }
                    )
                )
            except Exception:
                pass
            return

        # Replay messages that arrived while connecting. The socket is not yet
        # in proxied_websockets, so messages arriving during these awaits keep
        # landing in the queue behind the ones being replayed.
        queue = self.ws_pending_queues.get(ws_id)
        closed_early = False
        while queue is not None and not queue.empty():
            queued = queue.get_nowait()
            if queued.get("type") == "ws-close":
                # The browser side closed before the local socket was ready.
                closed_early = True
                break
            try:
                await local_ws.send(self._ws_payload(queued))
            except Exception as e:
                self.logger.error(
                    f"Failed to forward queued message to local WebSocket {ws_id}: {e}"
                )

        # No await since the last empty() check, so switching from the queue
        # to direct forwarding cannot reorder messages.
        if self.ws_pending_queues.pop(ws_id, None) is None:
            # _cleanup_proxied_websockets() cleared the queue: the tunnel
            # connection this session belonged to is gone.
            self.logger.debug(f"Tunnel closed while opening WebSocket {ws_id}")
            await self._close_quietly(local_ws)
            return
        if closed_early:
            await self._close_quietly(local_ws)
        else:
            self.proxied_websockets[ws_id] = local_ws

        self.logger.info(f"WebSocket proxy opened: {ws_id} -> {local_ws_url}")

        # Relay messages from local server back to the tunnel
        try:
            async for message in local_ws:
                if isinstance(message, str):
                    await tunnel_ws.send(
                        json.dumps({"type": "ws-message", "id": ws_id, "data": message})
                    )
                elif isinstance(message, bytes):
                    await tunnel_ws.send(
                        json.dumps(
                            {
                                "type": "ws-message",
                                "id": ws_id,
                                "data": base64.b64encode(message).decode("ascii"),
                                "binary": True,
                            }
                        )
                    )
        except websockets.ConnectionClosed:
            # Either the local socket or the tunnel closed.
            pass
        except Exception as e:
            self.logger.error(f"Error relaying WebSocket {ws_id}: {e}")
        finally:
            self.proxied_websockets.pop(ws_id, None)
            # Read the close details before closing ourselves, which would
            # otherwise replace them with our own normal closure.
            close_code = local_ws.close_code or 1000
            close_reason = local_ws.close_reason or ""
            # The relay can also stop because the tunnel send failed, leaving
            # the local socket open; always release it.
            await self._close_quietly(local_ws)
            try:
                await tunnel_ws.send(
                    json.dumps(
                        {
                            "type": "ws-close",
                            "id": ws_id,
                            "code": close_code,
                            "reason": close_reason,
                        }
                    )
                )
            except Exception:
                pass

    @staticmethod
    def _ws_payload(message: dict[str, Any]) -> str | bytes:
        """Return the frame to send locally: base64-decoded bytes if ``binary``."""
        if message.get("binary"):
            return base64.b64decode(message["data"])
        return message["data"]

    async def _handle_ws_message(self, data: dict[str, Any]) -> None:
        """Forward a browser ``ws-message`` to its local WebSocket (or queue it)."""
        ws_id = data["id"]
        local_ws = self.proxied_websockets.get(ws_id)
        if local_ws is None:
            # Connection still being established — buffer for later
            queue = self.ws_pending_queues.get(ws_id)
            if queue is not None:
                await queue.put(data)
                return
            self.logger.warning(f"Received ws-message for unknown WebSocket: {ws_id}")
            return
        try:
            await local_ws.send(self._ws_payload(data))
        except Exception as e:
            self.logger.error(
                f"Failed to forward message to local WebSocket {ws_id}: {e}"
            )

    async def _handle_ws_close(self, data: dict[str, Any]) -> None:
        """Close the local WebSocket for a browser ``ws-close``.

        If the local socket is still connecting, the close is queued so
        ``_handle_ws_open`` closes it once connected instead of losing it.
        """
        ws_id = data["id"]
        local_ws = self.proxied_websockets.pop(ws_id, None)
        if local_ws is not None:
            await self._close_quietly(local_ws)
            return
        queue = self.ws_pending_queues.get(ws_id)
        if queue is not None:
            await queue.put(data)

    def run(self) -> None:
        """Run ``connect()`` on a new event loop until it ends or Ctrl-C."""
        try:
            asyncio.run(self.connect())
        except KeyboardInterrupt:
            self.logger.debug("Received exit signal")