1from __future__ import annotations
  2
  3import asyncio
  4import base64
  5import json
  6import os
  7import sys
  8
  9import click
 10
 11from plain.cli import register_cli
 12
 13from .protocol import (
 14    DEFAULT_EXEC_TIMEOUT,
 15    DEFAULT_RELAY_HOST,
 16    FILE_CHUNK_SIZE,
 17    MAX_FILE_SIZE,
 18    chunk_count,
 19    make_exec,
 20    make_file_pull,
 21    make_file_push,
 22)
 23
 24
 25def _check_response(response: dict) -> None:
 26    """Exit with an error message if the response indicates failure."""
 27    error = response.get("error")
 28    if error:
 29        print(error, file=sys.stderr)
 30        sys.exit(1)
 31
 32
@register_cli("portal")
@click.group()
def cli() -> None:
    """Remote Python shell and file transfer via encrypted tunnel."""
    # Root click group, registered through ``register_cli("portal")``. The
    # subcommands in this module attach to it with ``@cli.command()``. Click
    # uses the docstring above as this group's help text, so editing it
    # changes user-visible output.
 37
 38
 39@cli.command()
 40@click.option(
 41    "--writable", is_flag=True, help="Allow database writes (default: read-only)."
 42)
 43@click.option(
 44    "--timeout",
 45    default=30,
 46    type=int,
 47    help="Idle timeout in minutes (0 to disable).",
 48)
 49@click.option(
 50    "--relay-host",
 51    envvar="PLAIN_PORTAL_RELAY_HOST",
 52    default=DEFAULT_RELAY_HOST,
 53    hidden=True,
 54)
 55def start(writable: bool, timeout: int, relay_host: str) -> None:
 56    """Start a portal session on the remote machine."""
 57    if writable:
 58        if not click.confirm(
 59            "This session allows writes to the production database. Continue?"
 60        ):
 61            return
 62
 63    from .remote import run_remote
 64
 65    asyncio.run(
 66        run_remote(writable=writable, timeout_minutes=timeout, relay_host=relay_host)
 67    )
 68
 69
 70@cli.command()
 71@click.argument("code")
 72@click.option(
 73    "--foreground",
 74    is_flag=True,
 75    help="Run in foreground instead of backgrounding.",
 76)
 77@click.option(
 78    "--relay-host",
 79    envvar="PLAIN_PORTAL_RELAY_HOST",
 80    default=DEFAULT_RELAY_HOST,
 81    hidden=True,
 82)
 83def connect(code: str, foreground: bool, relay_host: str) -> None:
 84    """Connect to a remote portal session."""
 85    from .local import connect as do_connect
 86
 87    asyncio.run(do_connect(code, relay_host=relay_host, foreground=foreground))
 88
 89
 90@cli.command("exec")
 91@click.argument("code")
 92@click.option("--json", "json_output", is_flag=True, help="Output as JSON.")
 93@click.option(
 94    "--timeout",
 95    default=DEFAULT_EXEC_TIMEOUT,
 96    type=int,
 97    help=f"Execution timeout in seconds (default: {DEFAULT_EXEC_TIMEOUT}).",
 98)
 99def exec_command(code: str, json_output: bool, timeout: int) -> None:
100    """Execute Python code on the remote machine."""
101    from .local import send_exec_streaming
102
103    request = make_exec(code, json_output=json_output, timeout=timeout)
104
105    # Collect streamed stdout for --json mode, print directly otherwise
106    stdout_parts: list[str] = []
107
108    def on_stdout(data: str) -> None:
109        if json_output:
110            stdout_parts.append(data)
111        else:
112            print(data, end="", flush=True)
113
114    response = asyncio.run(send_exec_streaming(request, on_stdout))
115    _check_response(response)
116
117    if json_output:
118        print(
119            json.dumps(
120                {
121                    "stdout": "".join(stdout_parts),
122                    "return_value": response.get("return_value"),
123                    "error": response.get("error"),
124                }
125            )
126        )
127    else:
128        return_value = response.get("return_value")
129        if return_value is not None:
130            print(f"โ†’ {return_value}")
131
132
@cli.command()
@click.argument("remote_path")
@click.argument("local_path")
def pull(remote_path: str, local_path: str) -> None:
    """Pull a file from the remote machine."""
    from .local import send_command

    request = make_file_pull(remote_path)
    response = asyncio.run(send_command(request))
    _check_response(response)

    # The file arrives in a single "file_data" response with base64 content.
    # Any other reply, including one missing the payload, is reported
    # instead of being written to disk.
    if response.get("type") != "file_data" or "data" not in response:
        print(f"Unexpected response: {response}", file=sys.stderr)
        sys.exit(1)

    data = base64.b64decode(response["data"])
    try:
        with open(local_path, "wb") as f:
            f.write(data)
    except OSError as exc:
        # For example, local_path is a directory or is not writable. Report
        # it the same way as other failures instead of raising a traceback.
        print(f"Could not write {local_path}: {exc}", file=sys.stderr)
        sys.exit(1)
    print(f"Pulled {remote_path} → {local_path} ({len(data)} bytes)")
152
153
@cli.command()
@click.argument("local_path")
@click.argument("remote_path")
def push(local_path: str, remote_path: str) -> None:
    """Push a file to the remote machine."""
    from .local import send_command

    # isfile (not exists) rejects directories up front instead of letting
    # open() fail with a traceback later.
    if not os.path.isfile(local_path):
        reason = "Not a file" if os.path.exists(local_path) else "File not found"
        print(f"{reason}: {local_path}", file=sys.stderr)
        sys.exit(1)

    file_size = os.path.getsize(local_path)
    if file_size > MAX_FILE_SIZE:
        print(
            f"File too large: {file_size} bytes (max {MAX_FILE_SIZE})",
            file=sys.stderr,
        )
        sys.exit(1)

    async def _push_all() -> dict:
        # Upload the file chunk by chunk, stopping at the first error.
        # Always send at least one chunk so an empty file still reaches the
        # remote instead of being reported as pushed without any request.
        # NOTE(review): assumes the remote accepts a single empty chunk;
        # confirm against the remote file_push handler.
        chunks = max(chunk_count(file_size), 1)
        response: dict = {}
        with open(local_path, "rb") as f:
            for i in range(chunks):
                data = f.read(FILE_CHUNK_SIZE)
                request = make_file_push(
                    remote_path=remote_path, chunk=i, chunks=chunks, data=data
                )
                response = await send_command(request)
                if response.get("error"):
                    return response
        return response

    response = asyncio.run(_push_all())
    _check_response(response)

    # Prefer the byte count reported by the remote; fall back to the local size.
    total_bytes = response.get("bytes", file_size)
    print(f"Pushed {local_path} → {remote_path} ({total_bytes} bytes)")
192
193
@cli.command()
def disconnect() -> None:
    """Disconnect the active portal session."""
    # The local-side implementation is imported lazily under an alias so it
    # does not shadow this command function.
    from .local import disconnect as local_disconnect

    local_disconnect()
200
201
@cli.command()
def status() -> None:
    """Show portal session status."""
    # The local-side implementation is imported lazily under an alias so it
    # does not shadow this command function.
    from .local import status as local_status

    local_status()