Plain is headed towards 1.0! Subscribe for development updates →

plain.tunnel

  1import asyncio
  2import json
  3import logging
  4import ssl
  5import urllib.error
  6import urllib.request
  7from urllib.parse import urlparse
  8
  9import click
 10import websockets
 11
 12
 13class TunnelClient:
 14    def __init__(self, *, destination_url, subdomain, tunnel_host, log_level):
 15        self.destination_url = destination_url
 16        self.subdomain = subdomain
 17        self.tunnel_host = tunnel_host
 18
 19        self.tunnel_http_url = f"https://{subdomain}.{tunnel_host}"
 20        self.tunnel_websocket_url = f"wss://{subdomain}.{tunnel_host}"
 21
 22        # Set up logging
 23        self.logger = logging.getLogger(__name__)
 24        self.logger.setLevel(getattr(logging, log_level.upper()))
 25        ch = logging.StreamHandler()
 26        ch.setLevel(getattr(logging, log_level.upper()))
 27        formatter = logging.Formatter("%(message)s")
 28        ch.setFormatter(formatter)
 29        self.logger.addHandler(ch)
 30
 31        # Store incoming requests
 32        self.pending_requests = {}
 33
 34        # Create the event loop
 35        self.loop = asyncio.new_event_loop()
 36        asyncio.set_event_loop(self.loop)
 37        self.stop_event = asyncio.Event()
 38
 39    async def connect(self):
 40        retry_count = 0
 41        max_retries = 5
 42        while not self.stop_event.is_set():
 43            if retry_count >= max_retries:
 44                self.logger.error(
 45                    f"Failed to connect after {max_retries} retries. Exiting."
 46                )
 47                break
 48            try:
 49                self.logger.debug(
 50                    f"Connecting to WebSocket URL: {self.tunnel_websocket_url}"
 51                )
 52                async with websockets.connect(
 53                    self.tunnel_websocket_url, max_size=None
 54                ) as websocket:
 55                    self.logger.debug("WebSocket connection established")
 56                    click.secho(
 57                        f"Connected to tunnel {self.tunnel_http_url}", fg="green"
 58                    )
 59                    retry_count = 0  # Reset retry count on successful connection
 60                    await self.forward_request(websocket)
 61            except (websockets.ConnectionClosed, ConnectionError) as e:
 62                if self.stop_event.is_set():
 63                    self.logger.debug("Stopping reconnect attempts due to shutdown")
 64                    break
 65                retry_count += 1
 66                self.logger.warning(
 67                    f"Connection lost: {e}. Retrying in 2 seconds... ({retry_count}/{max_retries})"
 68                )
 69                await asyncio.sleep(2)
 70            except asyncio.CancelledError:
 71                self.logger.debug("Connection cancelled")
 72                break
 73            except Exception as e:
 74                if self.stop_event.is_set():
 75                    self.logger.debug("Stopping reconnect attempts due to shutdown")
 76                    break
 77                retry_count += 1
 78                self.logger.error(
 79                    f"Unexpected error: {e}. Retrying in 2 seconds... ({retry_count}/{max_retries})"
 80                )
 81                await asyncio.sleep(2)
 82
 83    async def forward_request(self, websocket):
 84        try:
 85            async for message in websocket:
 86                if isinstance(message, str):
 87                    # Received text message (metadata)
 88                    self.logger.debug("Received metadata from worker")
 89                    data = json.loads(message)
 90                    await self.handle_request_metadata(websocket, data)
 91                elif isinstance(message, bytes):
 92                    # Received binary message (body chunk)
 93                    self.logger.debug("Received binary data from worker")
 94                    await self.handle_request_body_chunk(websocket, message)
 95                else:
 96                    self.logger.warning("Received unknown message type")
 97        except asyncio.CancelledError:
 98            self.logger.debug("Forward request cancelled")
 99        except Exception as e:
100            self.logger.error(f"Error in forward_request: {e}")
101            raise
102
103    async def handle_request_metadata(self, websocket, data):
104        request_id = data["id"]
105        has_body = data.get("has_body", False)
106        total_body_chunks = data.get("totalBodyChunks", 0)
107        self.pending_requests[request_id] = {
108            "metadata": data,
109            "body_chunks": {},
110            "has_body": has_body,
111            "total_body_chunks": total_body_chunks,
112        }
113        self.logger.debug(
114            f"Stored metadata for request ID: {request_id}, has_body: {has_body}"
115        )
116        await self.check_and_process_request(websocket, request_id)
117
118    async def handle_request_body_chunk(self, websocket, chunk_data):
119        offset = 0
120
121        # Extract id_length
122        id_length = int.from_bytes(chunk_data[offset : offset + 4], byteorder="little")
123        offset += 4
124
125        # Extract request_id
126        request_id = chunk_data[offset : offset + id_length].decode("utf-8")
127        offset += id_length
128
129        # Extract chunk_index
130        chunk_index = int.from_bytes(
131            chunk_data[offset : offset + 4], byteorder="little"
132        )
133        offset += 4
134
135        # Extract total_chunks
136        total_chunks = int.from_bytes(
137            chunk_data[offset : offset + 4], byteorder="little"
138        )
139        offset += 4
140
141        # Extract body_chunk
142        body_chunk = chunk_data[offset:]
143
144        # Continue processing as before
145
146        if request_id in self.pending_requests:
147            request = self.pending_requests[request_id]
148            if "body_chunks" not in request:
149                request["body_chunks"] = {}
150                request["total_body_chunks"] = total_chunks
151            request["body_chunks"][chunk_index] = body_chunk
152            self.logger.debug(
153                f"Stored body chunk {chunk_index + 1}/{total_chunks} for request ID: {request_id}"
154            )
155            await self.check_and_process_request(websocket, request_id)
156        else:
157            self.logger.warning(
158                f"Received body chunk for unknown or completed request ID: {request_id}"
159            )
160
161    async def check_and_process_request(self, websocket, request_id):
162        request_data = self.pending_requests.get(request_id)
163        if request_data and request_data["metadata"]:
164            has_body = request_data["has_body"]
165            total_body_chunks = request_data.get("total_body_chunks", 0)
166            body_chunks = request_data.get("body_chunks", {})
167
168            all_chunks_received = not has_body or (
169                len(body_chunks) == total_body_chunks
170            )
171
172            if all_chunks_received:
173                # Ensure all chunks are present
174                for i in range(total_body_chunks):
175                    if i not in body_chunks:
176                        self.logger.error(
177                            f"Missing chunk {i + 1}/{total_body_chunks} for request ID: {request_id}"
178                        )
179                        return
180
181                self.logger.debug(f"Processing request ID: {request_id}")
182                await self.process_request(
183                    websocket, request_data["metadata"], body_chunks, request_id
184                )
185                del self.pending_requests[request_id]
186
187    async def process_request(
188        self, websocket, request_metadata, body_chunks, request_id
189    ):
190        self.logger.debug(
191            f"Processing request: {request_id} {request_metadata['method']} {request_metadata['url']}"
192        )
193
194        # Reassemble body if present
195        if request_metadata["has_body"]:
196            total_chunks = request_metadata["totalBodyChunks"]
197            body_data = b"".join(body_chunks[i] for i in range(total_chunks))
198        else:
199            body_data = None
200
201        # Parse the original URL to extract the path and query
202        parsed_url = urlparse(request_metadata["url"])
203        path_and_query = parsed_url.path
204        if parsed_url.query:
205            path_and_query += f"?{parsed_url.query}"
206
207        # Construct the new URL by appending path and query to destination_url
208        forward_url = self.destination_url + path_and_query
209
210        self.logger.debug(f"Forwarding request to: {forward_url}")
211
212        # Forward the request to the destination URL using urllib
213        try:
214            # Create a custom SSL context to ignore SSL verification (if needed)
215            ssl_context = ssl.create_default_context()
216            ssl_context.check_hostname = False
217            ssl_context.verify_mode = ssl.CERT_NONE
218
219            # Prepare the request
220            req = urllib.request.Request(
221                url=forward_url,
222                method=request_metadata["method"],
223                data=body_data if body_data else None,
224                headers=request_metadata["headers"],
225            )
226
227            # Make the request
228            with urllib.request.urlopen(req, context=ssl_context) as response:
229                response_body = response.read()
230                response_headers = dict(response.getheaders())
231                response_status = response.getcode()
232                self.logger.debug(
233                    f"Received response with status code: {response_status}"
234                )
235
236        except urllib.error.HTTPError as e:
237            # Non-200 status codes are here (even ones we want)
238            self.logger.debug(f"HTTPError forwarding request: {e}")
239            response_body = e.read()
240            response_headers = dict(e.headers)
241            response_status = e.code
242
243        except urllib.error.URLError as e:
244            self.logger.error(f"URLError forwarding request: {e}")
245            response_body = b""
246            response_headers = {}
247            response_status = 500
248
249        self.logger.info(
250            f"{click.style(request_metadata['method'], bold=True)} {request_metadata['url']} {response_status}"
251        )
252
253        has_body = len(response_body) > 0
254        max_chunk_size = 1000000  # 1,000,000 bytes
255        total_body_chunks = (
256            (len(response_body) + max_chunk_size - 1) // max_chunk_size
257            if has_body
258            else 0
259        )
260
261        response_metadata = {
262            "id": request_id,
263            "status": response_status,
264            "headers": list(response_headers.items()),
265            "has_body": has_body,
266            "totalBodyChunks": total_body_chunks,
267        }
268
269        # Send response metadata
270        response_metadata_json = json.dumps(response_metadata)
271        self.logger.debug(
272            f"Sending response metadata for ID: {request_id}, has_body: {has_body}"
273        )
274        await websocket.send(response_metadata_json)
275
276        # Send response body chunks if present
277        if has_body:
278            self.logger.debug(
279                f"Sending {total_body_chunks} body chunks for ID: {request_id}"
280            )
281            id_bytes = request_id.encode("utf-8")
282            for i in range(total_body_chunks):
283                chunk_start = i * max_chunk_size
284                chunk_end = min(chunk_start + max_chunk_size, len(response_body))
285                body_chunk = response_body[chunk_start:chunk_end]
286
287                # Prepare the binary message
288                chunk_index_bytes = i.to_bytes(4, byteorder="little")
289                total_chunks_bytes = total_body_chunks.to_bytes(4, byteorder="little")
290                message = id_bytes + chunk_index_bytes + total_chunks_bytes + body_chunk
291                await websocket.send(message)
292                self.logger.debug(
293                    f"Sent body chunk {i + 1}/{total_body_chunks} for ID: {request_id}"
294                )
295        else:
296            self.logger.debug(f"No body to send for ID: {request_id}")
297
298    async def shutdown(self):
299        self.stop_event.set()
300        tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
301        if tasks:
302            self.logger.debug(f"Cancelling {len(tasks)} outstanding tasks")
303            for task in tasks:
304                task.cancel()
305            await asyncio.gather(*tasks, return_exceptions=True)
306        await self.loop.shutdown_asyncgens()
307
308    def run(self):
309        try:
310            self.loop.run_until_complete(self.connect())
311        except KeyboardInterrupt:
312            self.logger.debug("Received exit signal")
313        finally:
314            self.logger.debug("Shutting down...")
315            self.loop.run_until_complete(self.shutdown())
316            self.loop.close()