1from __future__ import annotations
2
3import asyncio
4import json
5import logging
6import math
7import struct
8from typing import Any
9from urllib.parse import urlparse
10
11import click
12import httpx
13import websockets
14
# Version of the WebSocket wire protocol this client speaks. It is sent to the
# server as the `v` query parameter of the tunnel URL. Increase it whenever the
# protocol changes in a backward-incompatible way: the server rejects clients
# whose version is below its configured minimum.
PROTOCOL_VERSION: int = 1
18
19
class TunnelClient:
    """Client side of an HTTP-over-WebSocket tunnel.

    Keeps a WebSocket open to ``wss://<subdomain>.<tunnel_host>``. Incoming
    HTTP requests arrive as a JSON ``"request"`` metadata frame plus optional
    binary body-chunk frames. Each request is replayed against
    ``destination_url``, and the response is sent back over the same socket
    as a JSON ``"response"`` metadata frame plus binary body-chunk frames.
    """

    def __init__(
        self, *, destination_url: str, subdomain: str, tunnel_host: str, log_level: str
    ) -> None:
        """Configure tunnel URLs and console logging.

        Args:
            destination_url: Base URL requests are forwarded to; the incoming
                request's path and query string are appended to it.
            subdomain: Subdomain of ``tunnel_host`` this tunnel is served on.
            tunnel_host: Host name of the tunnel server.
            log_level: Name of a ``logging`` level such as ``"info"`` or
                ``"debug"`` (case-insensitive).
        """
        self.destination_url = destination_url
        self.subdomain = subdomain
        self.tunnel_host = tunnel_host

        self.tunnel_http_url = f"https://{subdomain}.{tunnel_host}"
        self.tunnel_websocket_url = (
            f"wss://{subdomain}.{tunnel_host}?v={PROTOCOL_VERSION}"
        )

        self.logger = logging.getLogger(__name__)
        level = getattr(logging, log_level.upper())
        self.logger.setLevel(level)
        self.logger.propagate = False
        # getLogger() returns the same logger object for every instance, so
        # attach the console handler only once; otherwise each additional
        # TunnelClient would print every log line one more time.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter("%(message)s"))
            self.logger.addHandler(handler)
        for existing_handler in self.logger.handlers:
            existing_handler.setLevel(level)

        # Requests whose metadata has arrived but whose body is still
        # incomplete, keyed by request ID.
        self.pending_requests: dict[str, dict[str, Any]] = {}
        self.stop_event = asyncio.Event()
        # Strong references to in-flight request tasks. The event loop only
        # keeps weak references, so an unreferenced task could be garbage
        # collected before it finishes.
        self._background_tasks: set[asyncio.Task[None]] = set()

    async def connect(self) -> None:
        """Connect to the tunnel server and serve requests until stopped.

        After a dropped connection or an unexpected error, reconnects with
        exponential backoff (starting at 1s, doubling, capped at 30s). Stops
        when cancelled, when ``stop_event`` is set, or when the server rejects
        the client version with HTTP 426. Other handshake rejections
        (``websockets.InvalidStatus``) are re-raised, not retried.
        """
        retry_delay = 1.0
        max_retry_delay = 30.0
        while not self.stop_event.is_set():
            try:
                self.logger.debug(
                    f"Connecting to WebSocket URL: {self.tunnel_websocket_url}"
                )
                async with websockets.connect(
                    self.tunnel_websocket_url, max_size=None
                ) as websocket:
                    self.logger.debug("WebSocket connection established")
                    click.secho(
                        f"Connected to tunnel {self.tunnel_http_url}", fg="green"
                    )
                    retry_delay = 1.0
                    # Body chunks for requests that were half-received on a
                    # previous connection cannot arrive anymore; drop them
                    # instead of keeping them forever.
                    self.pending_requests.clear()
                    await self.handle_messages(websocket)
            except asyncio.CancelledError:
                self.logger.debug("Connection cancelled")
                break
            except websockets.InvalidStatus as e:
                if e.response.status_code == 426:
                    body = e.response.body.decode() if e.response.body else ""
                    click.secho(
                        body or "Client version too old. Please upgrade plain.tunnel.",
                        fg="red",
                    )
                    break
                raise
            except (websockets.ConnectionClosed, ConnectionError) as e:
                if self.stop_event.is_set():
                    self.logger.debug("Stopping reconnect attempts due to shutdown")
                    break
                click.secho(
                    f"Connection lost: {e}. Retrying in {retry_delay:.0f}s...",
                    fg="yellow",
                )
                await asyncio.sleep(retry_delay)
                retry_delay = min(retry_delay * 2, max_retry_delay)
            except Exception as e:
                if self.stop_event.is_set():
                    self.logger.debug("Stopping reconnect attempts due to shutdown")
                    break
                click.secho(
                    f"Unexpected error: {e}. Retrying in {retry_delay:.0f}s...",
                    fg="yellow",
                )
                await asyncio.sleep(retry_delay)
                retry_delay = min(retry_delay * 2, max_retry_delay)

    async def handle_messages(self, websocket: Any) -> None:
        """Dispatch frames from ``websocket`` until it stops yielding them.

        Text frames are JSON: ``"ping"`` is answered with ``"pong"``, and
        ``"request"`` carries request metadata. Binary frames carry request
        body chunks. Cancellation and errors are re-raised so ``connect()``
        can decide whether to stop or reconnect.
        """
        try:
            async for message in websocket:
                if isinstance(message, str):
                    data = json.loads(message)
                    msg_type = data.get("type")
                    if msg_type == "ping":
                        self.logger.debug("Received heartbeat ping, sending pong")
                        await websocket.send(json.dumps({"type": "pong"}))
                    elif msg_type == "request":
                        self.logger.debug("Received request metadata from worker")
                        await self.handle_request_metadata(websocket, data)
                    else:
                        self.logger.warning(
                            f"Received unknown message type: {msg_type}"
                        )
                elif isinstance(message, bytes):
                    self.logger.debug("Received binary data from worker")
                    await self.handle_request_body_chunk(websocket, message)
                else:
                    self.logger.warning("Received unknown message format")
        except asyncio.CancelledError:
            self.logger.debug("Message handling cancelled")
            # Propagate the cancellation. Swallowing it would make connect()
            # treat shutdown as an ordinary disconnect and reconnect.
            raise
        except Exception as e:
            self.logger.error(f"Error in handle_messages: {e}")
            raise

    async def handle_request_metadata(
        self, websocket: Any, data: dict[str, Any]
    ) -> None:
        """Record a request's metadata and process it if no body is pending.

        ``has_body`` defaults to False and ``totalBodyChunks`` defaults to 0
        when absent from ``data``.
        """
        request_id = data["id"]
        has_body = data.get("has_body", False)
        total_body_chunks = data.get("totalBodyChunks", 0)
        self.pending_requests[request_id] = {
            "metadata": data,
            "body_chunks": {},
            "has_body": has_body,
            "total_body_chunks": total_body_chunks,
        }
        self.logger.debug(
            f"Stored metadata for request ID: {request_id}, has_body: {has_body}"
        )
        await self.check_and_process_request(websocket, request_id)

    async def handle_request_body_chunk(
        self, websocket: Any, chunk_data: bytes
    ) -> None:
        """Store one binary body chunk and process the request once complete.

        Frame layout: ``<u32 LE id_length><request id, UTF-8>``
        ``<u32 LE chunk_index><u32 LE total_chunks><payload>``. Chunks for
        request IDs without stored metadata are logged and dropped.
        """
        (id_length,) = struct.unpack_from("<I", chunk_data, 0)
        request_id = chunk_data[4 : 4 + id_length].decode("utf-8")
        header_end = 4 + id_length + 8
        chunk_index, total_chunks = struct.unpack_from("<II", chunk_data, 4 + id_length)
        body_chunk = chunk_data[header_end:]

        if request_id in self.pending_requests:
            request = self.pending_requests[request_id]
            request["body_chunks"][chunk_index] = body_chunk
            self.logger.debug(
                f"Stored body chunk {chunk_index + 1}/{total_chunks} for request ID: {request_id}"
            )
            await self.check_and_process_request(websocket, request_id)
        else:
            self.logger.warning(
                f"Received body chunk for unknown or completed request ID: {request_id}"
            )

    async def check_and_process_request(self, websocket: Any, request_id: str) -> None:
        """Start processing ``request_id`` once all its body chunks are present.

        When complete, the request is removed from ``pending_requests`` and
        forwarded in a background task, so the message loop is not blocked.
        """
        request_data = self.pending_requests.get(request_id)
        if not request_data:
            return

        has_body = request_data["has_body"]
        total_body_chunks = request_data["total_body_chunks"]
        body_chunks = request_data["body_chunks"]

        all_chunks_received = not has_body or len(body_chunks) == total_body_chunks
        if not all_chunks_received:
            return

        for i in range(total_body_chunks):
            if i not in body_chunks:
                self.logger.error(
                    f"Missing chunk {i + 1}/{total_body_chunks} for request ID: {request_id}"
                )
                return

        self.logger.debug(f"Processing request ID: {request_id}")
        del self.pending_requests[request_id]
        task = asyncio.create_task(
            self.process_request(
                websocket,
                request_data["metadata"],
                body_chunks,
                request_id,
            )
        )
        # Keep a strong reference until the task finishes (see __init__).
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        task.add_done_callback(self._handle_task_exception)

    def _handle_task_exception(self, task: asyncio.Task[None]) -> None:
        """Log the exception of a finished request task, if it raised one."""
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            self.logger.error(f"Error processing request: {exc}")

    async def process_request(
        self,
        websocket: Any,
        request_metadata: dict[str, Any],
        body_chunks: dict[int, bytes],
        request_id: str,
    ) -> None:
        """Forward one tunneled request to ``destination_url`` and send the reply.

        The reply goes back as a JSON ``"response"`` metadata frame followed by
        binary body frames of at most 1,000,000 payload bytes each. If the
        destination cannot be reached or fails, a body-less 502 is sent. If
        it times out, a body-less 504 is sent. Either way the remote side
        always gets a reply.

        Args:
            websocket: Connection the reply is sent on.
            request_metadata: Decoded ``"request"`` message. Reads ``method``,
                ``url``, ``headers`` and the optional ``has_body`` /
                ``totalBodyChunks`` (defaults False / 0, matching
                ``handle_request_metadata``).
            body_chunks: Request body pieces keyed by chunk index.
            request_id: ID echoed back in every reply frame.
        """
        self.logger.debug(
            f"Processing request: {request_id} {request_metadata['method']} {request_metadata['url']}"
        )

        body_data: bytes | None
        if request_metadata.get("has_body", False):
            total_chunks = request_metadata.get("totalBodyChunks", 0)
            body_data = b"".join(body_chunks[i] for i in range(total_chunks))
        else:
            body_data = None

        parsed = urlparse(request_metadata["url"])
        path = parsed.path
        if parsed.query:
            path = f"{path}?{parsed.query}"
        # Strip a trailing slash so "http://host/" + "/x" doesn't become "//x".
        forward_url = f"{self.destination_url.rstrip('/')}{path}"

        self.logger.debug(f"Forwarding request to: {forward_url}")

        response_body = b""
        response_headers: list[tuple[str, str]] = []
        async with httpx.AsyncClient(follow_redirects=False, verify=False) as client:
            try:
                # No timeout is passed, so httpx's default timeout applies.
                response = await client.request(
                    method=request_metadata["method"],
                    url=forward_url,
                    headers=request_metadata["headers"],
                    content=body_data,
                )
            except httpx.ConnectError as e:
                self.logger.error(f"Connection error forwarding request: {e}")
                response_status = 502
            except httpx.TimeoutException as e:
                self.logger.error(f"Timeout forwarding request: {e}")
                response_status = 504
            except httpx.HTTPError as e:
                self.logger.error(f"Error forwarding request: {e}")
                response_status = 502
            else:
                response_body = response.content
                # multi_items() keeps repeated headers (e.g. several Set-Cookie
                # lines) as separate pairs; dict(response.headers) would merge
                # them into one comma-joined value.
                response_headers = response.headers.multi_items()
                response_status = response.status_code
                self.logger.debug(
                    f"Received response with status code: {response_status}"
                )

        self.logger.info(
            f"{click.style(request_metadata['method'], bold=True)} {request_metadata['url']} {response_status}"
        )

        has_body = len(response_body) > 0
        max_chunk_size = 1_000_000
        total_body_chunks = (
            math.ceil(len(response_body) / max_chunk_size) if has_body else 0
        )

        response_metadata = {
            "type": "response",
            "id": request_id,
            "status": response_status,
            "headers": list(response_headers),
            "has_body": has_body,
            "totalBodyChunks": total_body_chunks,
        }

        self.logger.debug(
            f"Sending response metadata for ID: {request_id}, has_body: {has_body}"
        )
        await websocket.send(json.dumps(response_metadata))

        if has_body:
            self.logger.debug(
                f"Sending {total_body_chunks} body chunks for ID: {request_id}"
            )
            id_bytes = request_id.encode("utf-8")
            for i in range(total_body_chunks):
                chunk_start = i * max_chunk_size
                chunk_end = min(chunk_start + max_chunk_size, len(response_body))
                # NOTE(review): outbound frames are <id><u32 index><u32 total>
                # with no u32 ID-length prefix, unlike the inbound frames parsed
                # in handle_request_body_chunk. Presumably the server parses
                # responses this way; confirm against the server.
                header = id_bytes + struct.pack("<II", i, total_body_chunks)
                await websocket.send(header + response_body[chunk_start:chunk_end])
                self.logger.debug(
                    f"Sent body chunk {i + 1}/{total_body_chunks} for ID: {request_id}"
                )

    def run(self) -> None:
        """Run ``connect()`` in a new event loop until it returns or Ctrl-C."""
        try:
            asyncio.run(self.connect())
        except KeyboardInterrupt:
            self.logger.debug("Received exit signal")