1"""Remote side of a portal session.
  2
  3Runs on the production machine. Connects to the relay, prints a portal
  4code, waits for the local side to connect, then executes commands as
  5they arrive through the encrypted tunnel.
  6"""
  7
  8from __future__ import annotations
  9
 10import ast
 11import asyncio
 12import base64
 13import contextlib
 14import json
 15import os
 16import sys
 17import traceback
 18from contextlib import redirect_stderr, redirect_stdout
 19from datetime import datetime
 20
 21from websockets.asyncio.client import ClientConnection
 22from websockets.asyncio.client import connect as ws_connect
 23
 24from .codegen import generate_code
 25from .crypto import PortalEncryptor, channel_id, perform_key_exchange
 26from .protocol import (
 27    DEFAULT_EXEC_TIMEOUT,
 28    DEFAULT_RELAY_HOST,
 29    FILE_CHUNK_SIZE,
 30    MAX_FILE_SIZE,
 31    chunk_count,
 32    make_error,
 33    make_exec_result,
 34    make_exec_stdout,
 35    make_file_data,
 36    make_file_push_result,
 37    make_ping,
 38    make_pong,
 39    make_relay_url,
 40)
 41
# Handle on the process's stdout as it was at import time. Executed code runs
# with sys.stdout/sys.stderr redirected into the tunnel, so _log() writes here
# to keep host-side logging on the real terminal.
_real_stdout = sys.stdout
 43
 44
 45def _log(msg: str) -> None:
 46    ts = datetime.now().strftime("%H:%M:%S")
 47    _real_stdout.write(f"[{ts}] {msg}\n")
 48    _real_stdout.flush()
 49
 50
 51async def _send_error(
 52    ws: ClientConnection,
 53    encryptor: PortalEncryptor,
 54    req_id: int | None,
 55    error_text: str,
 56) -> None:
 57    """Send an error response back through the tunnel."""
 58    msg = make_error(error_text)
 59    msg["_req_id"] = req_id
 60    await ws.send(encryptor.encrypt_message(msg))
 61
 62
 63class _TunnelWriter:
 64    """File-like that streams writes through the tunnel as exec_stdout messages."""
 65
 66    def __init__(
 67        self,
 68        loop: asyncio.AbstractEventLoop,
 69        ws: ClientConnection,
 70        encryptor: PortalEncryptor,
 71        req_id: int | None,
 72    ) -> None:
 73        self._loop = loop
 74        self._ws = ws
 75        self._encryptor = encryptor
 76        self._req_id = req_id
 77        self._buffer = ""
 78
 79    def write(self, s: str) -> int:
 80        self._buffer += s
 81        while "\n" in self._buffer:
 82            line, self._buffer = self._buffer.split("\n", 1)
 83            self._send(line + "\n")
 84        return len(s)
 85
 86    def flush(self) -> None:
 87        if self._buffer:
 88            self._send(self._buffer)
 89            self._buffer = ""
 90
 91    def _send(self, data: str) -> None:
 92        msg = make_exec_stdout(data)
 93        msg["_req_id"] = self._req_id
 94        future = asyncio.run_coroutine_threadsafe(
 95            self._ws.send(self._encryptor.encrypt_message(msg)),
 96            self._loop,
 97        )
 98        try:
 99            future.result(timeout=30)
100        except Exception:
101            pass  # Don't crash exec for a send failure
102
103
async def run_remote(
    *,
    writable: bool = False,
    timeout_minutes: int = 30,
    relay_host: str = DEFAULT_RELAY_HOST,
) -> None:
    """Start the remote side of a portal session.

    Generates and prints a portal code, connects to the relay channel derived
    from it, performs the key exchange, then serves encrypted ``exec``,
    ``file_pull``, ``file_push`` and ``ping`` requests until the connection
    closes or the session idles out.

    Args:
        writable: When False, executed code runs inside
            ``plain.postgres.connections.read_only()`` (if that can be
            imported and entered).
        timeout_minutes: Close the session after this many minutes without an
            incoming message; ``<= 0`` disables the idle timeout.
        relay_host: Relay host used to build the websocket URL.
    """

    code = generate_code()

    mode = "writable" if writable else "read-only"
    print(f"Portal code: {code}")
    print(f"Session mode: {mode}")
    print("Waiting for connection...")
    print()

    cid = channel_id(code)
    relay_url = make_relay_url(relay_host, cid, "start")

    # 1 MB — truncate return values to prevent massive relay payloads.
    max_output = 1024 * 1024
    # Resolve once (macOS: /tmp → /private/tmp).
    tmp_prefix = os.path.realpath("/tmp")

    def execute_code(
        code_str: str,
        *,
        json_output: bool = False,
        output_writer: _TunnelWriter,
    ) -> dict:
        """Execute Python code, streaming stdout/stderr through the tunnel.

        Each execution gets a fresh namespace. If the code ends with an
        expression statement, that expression is evaluated separately and its
        value captured as the return value (like the interactive REPL).

        Returns:
            ``{"return_value": str | None, "error": str | None}``.
            ``return_value`` is the ``repr`` (or ``json.dumps`` when
            *json_output*) of a non-None trailing expression, truncated with a
            note when longer than ``max_output`` characters. ``error`` is the
            formatted traceback if anything raised (including ``SystemExit``).
        """
        namespace: dict = {}
        return_value = None
        error = None

        try:
            tree = ast.parse(code_str, mode="exec")

            last_expr: ast.Expr | None = None
            if tree.body and isinstance(tree.body[-1], ast.Expr):
                last_expr = tree.body.pop()  # type: ignore[assignment]

            # Process-global redirect — safe because _log() uses _real_stdout.
            # NOTE(review): because the redirect is process-wide, an exec that
            # timed out but is still running in its thread can interleave with
            # (and mis-restore) a later exec's redirect.
            with contextlib.ExitStack() as ctx:
                ctx.enter_context(redirect_stdout(output_writer))
                ctx.enter_context(redirect_stderr(output_writer))
                if not writable:
                    try:
                        from plain.postgres.connections import read_only

                        ctx.enter_context(read_only())
                    except Exception:
                        pass  # No DB configured or plain-postgres not installed

                if tree.body:
                    compiled = compile(tree, "<portal>", "exec")
                    exec(compiled, namespace)  # noqa: S102

                if last_expr is not None:
                    expr_code = compile(
                        ast.Expression(last_expr.value), "<portal>", "eval"
                    )
                    result = eval(expr_code, namespace)  # noqa: S307
                    if result is not None:
                        if json_output:
                            try:
                                return_value = json.dumps(result)
                            except (TypeError, ValueError):
                                return_value = repr(result)
                        else:
                            return_value = repr(result)

        except BaseException:
            error = traceback.format_exc()
        finally:
            # Flush any remaining buffered output
            output_writer.flush()
            # Close DB connection to prevent leaks across to_thread calls
            try:
                from plain.postgres.connections import get_connection, has_connection

                if has_connection():
                    get_connection().close()
            except Exception:
                pass

        if return_value and len(return_value) > max_output:
            return_value = (
                return_value[:max_output]
                + f"\n... truncated ({len(return_value)} bytes total)"
            )

        return {
            "return_value": return_value,
            "error": error,
        }

    async def handle_file_pull(remote_path: str, req_id: int | None) -> None:
        """Read a file from disk and send it to the client in encrypted chunks.

        Replies with an error message instead when the file is missing,
        unreadable, or larger than MAX_FILE_SIZE.
        """
        try:
            file_size = os.path.getsize(remote_path)
            if file_size > MAX_FILE_SIZE:
                await _send_error(
                    ws,
                    encryptor,
                    req_id,
                    f"File too large: {file_size} bytes (max {MAX_FILE_SIZE})",
                )
                return

            name = os.path.basename(remote_path)
            # Always send at least one chunk so an empty file still produces a
            # reply (guards against chunk_count(0) returning 0).
            chunks = max(chunk_count(file_size), 1)

            _log(f"       sending {name} ({file_size} bytes, {chunks} chunks)")

            with open(remote_path, "rb") as f:
                for i in range(chunks):
                    data = f.read(FILE_CHUNK_SIZE)
                    msg = make_file_data(name=name, chunk=i, chunks=chunks, data=data)
                    msg["_req_id"] = req_id
                    await ws.send(encryptor.encrypt_message(msg))

        except FileNotFoundError:
            await _send_error(ws, encryptor, req_id, f"File not found: {remote_path}")
        except OSError as e:  # PermissionError, IsADirectoryError, ...
            await _send_error(ws, encryptor, req_id, f"{type(e).__name__}: {e}")

    async def handle_file_push(msg: dict) -> None:
        """Receive one file chunk and write it under /tmp.

        Chunk 0 truncates/creates the file and later chunks append. Every
        chunk is acknowledged; the last one gets a ``file_push_result`` with
        the final size.
        """
        req_id = msg.get("_req_id")
        remote_path = msg["remote_path"]
        chunk_idx = msg["chunk"]
        chunks = msg["chunks"]
        data = base64.b64decode(msg["data"])

        resolved = os.path.realpath(remote_path)
        if not resolved.startswith(tmp_prefix + "/"):
            await _send_error(
                ws,
                encryptor,
                req_id,
                f"Push restricted to /tmp/. Got: {remote_path} (resolved: {resolved})",
            )
            return

        if chunk_idx == 0:
            _log(f"push: {remote_path} ({chunks} chunks)")

        # Write to the path that passed the /tmp check rather than the raw one:
        # re-resolving symlinks at open() time could escape /tmp. This narrows,
        # but does not fully close, the check/use window.
        file_mode = "wb" if chunk_idx == 0 else "ab"
        try:
            with open(resolved, file_mode) as f:
                f.write(data)
                total_bytes = f.tell()  # position after write == file size
        except OSError as e:
            await _send_error(ws, encryptor, req_id, f"{type(e).__name__}: {e}")
            return

        # Ack every chunk so the sender doesn't block waiting
        if chunk_idx == chunks - 1:
            _log(f"       received {total_bytes} bytes")
            result = make_file_push_result(path=remote_path, total_bytes=total_bytes)
        else:
            result = {"type": "file_push_ack", "chunk": chunk_idx}
        result["_req_id"] = req_id
        await ws.send(encryptor.encrypt_message(result))

    async with ws_connect(relay_url) as ws:
        encryptor = await perform_key_exchange(ws, code, side="start")
        _log("Connected from remote client.")

        last_activity = asyncio.get_running_loop().time()

        async def check_timeout() -> None:
            """Warn, then close the connection, once the session is idle too long."""
            if timeout_minutes <= 0:
                return
            limit = timeout_minutes * 60
            while True:
                await asyncio.sleep(60)
                idle = asyncio.get_running_loop().time() - last_activity
                remaining = limit - idle
                # Print to the real stdout: during an exec, sys.stdout is a
                # _TunnelWriter, and writing to it from the event-loop thread
                # would block the loop waiting on its own send.
                if 0 < remaining <= 60:
                    print(
                        f"\nWarning: session will timeout in {int(remaining)} seconds due to inactivity.",
                        file=_real_stdout,
                        flush=True,
                    )
                if idle >= limit:
                    print(
                        "\nSession timed out due to inactivity.",
                        file=_real_stdout,
                        flush=True,
                    )
                    await ws.close()
                    return

        timeout_task = asyncio.create_task(check_timeout())

        async def send_keepalive_pings() -> None:
            while True:
                await asyncio.sleep(30)
                await ws.send(encryptor.encrypt_message(make_ping()))

        keepalive_task = asyncio.create_task(send_keepalive_pings())

        try:
            async for raw in ws:
                last_activity = asyncio.get_running_loop().time()

                if isinstance(raw, str):
                    continue

                msg = encryptor.decrypt_message(raw)
                msg_type = msg.get("type")

                if msg_type == "ping":
                    await ws.send(encryptor.encrypt_message(make_pong()))

                elif msg_type == "pong":
                    pass

                elif msg_type == "exec":
                    req_id = msg.get("_req_id")
                    code_str = msg["code"]
                    json_output = msg.get("json_output", False)
                    exec_timeout = msg.get("timeout", DEFAULT_EXEC_TIMEOUT)
                    _log(
                        f"exec: {code_str[:200]}{'...' if len(code_str) > 200 else ''}"
                    )
                    # Create a writer that streams stdout through the tunnel
                    tunnel_writer = _TunnelWriter(
                        asyncio.get_running_loop(), ws, encryptor, req_id
                    )
                    try:
                        result = await asyncio.wait_for(
                            asyncio.to_thread(
                                execute_code,
                                code_str,
                                json_output=json_output,
                                output_writer=tunnel_writer,
                            ),
                            timeout=exec_timeout,
                        )
                    except TimeoutError:
                        result = {
                            "return_value": None,
                            "error": f"Execution timed out ({exec_timeout} seconds). The code may still be running in the background.",
                        }
                    return_value = result.get("return_value")
                    error = result.get("error")
                    display = return_value or error or ""
                    if display:
                        _log(
                            f"       → {display[:200]}{'...' if len(display) > 200 else ''}"
                        )
                    # Send final result — stdout was already streamed
                    response = make_exec_result(
                        return_value=return_value,
                        error=error,
                    )
                    response["_req_id"] = req_id
                    await ws.send(encryptor.encrypt_message(response))

                elif msg_type == "file_pull":
                    req_id = msg.get("_req_id")
                    remote_path = msg["remote_path"]
                    _log(f"pull: {remote_path}")
                    await handle_file_pull(remote_path, req_id)

                elif msg_type == "file_push":
                    await handle_file_push(msg)

                else:
                    _log(f"Unknown message type: {msg_type}")

        finally:
            timeout_task.cancel()
            keepalive_task.cancel()
            # Wait for both tasks so any exception they ended with (e.g. a
            # keepalive send on a closed socket) is retrieved, not reported
            # later as "Task exception was never retrieved".
            await asyncio.gather(timeout_task, keepalive_task, return_exceptions=True)

    _log("Client disconnected.")