1from __future__ import annotations
2
3import asyncio
4import base64
5import json
6import os
7import sys
8
9import click
10
11from plain.cli import register_cli
12
13from .protocol import (
14 DEFAULT_EXEC_TIMEOUT,
15 DEFAULT_RELAY_HOST,
16 FILE_CHUNK_SIZE,
17 MAX_FILE_SIZE,
18 chunk_count,
19 make_exec,
20 make_file_pull,
21 make_file_push,
22)
23
24
25def _check_response(response: dict) -> None:
26 """Exit with an error message if the response indicates failure."""
27 error = response.get("error")
28 if error:
29 print(error, file=sys.stderr)
30 sys.exit(1)
31
32
@register_cli("portal")
@click.group()
def cli() -> None:
    """Remote Python shell and file transfer via encrypted tunnel."""
    # Click command group: every subcommand in this module attaches to it via
    # @cli.command(). The docstring above is the group's --help text, so it is
    # user-facing output. register_cli("portal") presumably exposes this group
    # under the `portal` command name — confirm in plain.cli.
37
38
@cli.command()
@click.option(
    "--writable",
    is_flag=True,
    help="Allow database writes (default: read-only).",
)
@click.option(
    "--timeout",
    type=int,
    default=30,
    help="Idle timeout in minutes (0 to disable).",
)
@click.option(
    "--relay-host",
    hidden=True,
    default=DEFAULT_RELAY_HOST,
    envvar="PLAIN_PORTAL_RELAY_HOST",
)
def start(writable: bool, timeout: int, relay_host: str) -> None:
    """Start a portal session on the remote machine."""
    # Writable sessions require an explicit interactive "yes"; declining
    # returns without starting anything.
    warning = "This session allows writes to the production database. Continue?"
    if writable and not click.confirm(warning):
        return

    # Deferred import: .remote is only loaded when this command actually runs.
    from .remote import run_remote

    session = run_remote(
        writable=writable,
        timeout_minutes=timeout,
        relay_host=relay_host,
    )
    asyncio.run(session)
68
69
@cli.command()
@click.argument("code")
@click.option(
    "--foreground",
    help="Run in foreground instead of backgrounding.",
    is_flag=True,
)
@click.option(
    "--relay-host",
    hidden=True,
    default=DEFAULT_RELAY_HOST,
    envvar="PLAIN_PORTAL_RELAY_HOST",
)
def connect(code: str, foreground: bool, relay_host: str) -> None:
    """Connect to a remote portal session."""
    # The connection logic lives in the sibling `local` module; this command
    # only forwards the CLI arguments and drives the coroutine to completion.
    from . import local

    session = local.connect(code, relay_host=relay_host, foreground=foreground)
    asyncio.run(session)
88
89
@cli.command("exec")
@click.argument("code")
@click.option("--json", "json_output", is_flag=True, help="Output as JSON.")
@click.option(
    "--timeout",
    default=DEFAULT_EXEC_TIMEOUT,
    type=int,
    help=f"Execution timeout in seconds (default: {DEFAULT_EXEC_TIMEOUT}).",
)
def exec_command(code: str, json_output: bool, timeout: int) -> None:
    """Execute Python code on the remote machine."""
    # Plain mode streams remote stdout straight to the terminal. --json mode
    # buffers it so it can be embedded in the single JSON document printed at
    # the end.
    from .local import send_exec_streaming

    request = make_exec(code, json_output=json_output, timeout=timeout)

    stdout_parts: list[str] = []

    def on_stdout(data: str) -> None:
        if json_output:
            stdout_parts.append(data)
        else:
            print(data, end="", flush=True)

    response = asyncio.run(send_exec_streaming(request, on_stdout))

    if json_output:
        # Print the JSON document BEFORE checking for errors. Otherwise the
        # "error" field would always be None and buffered stdout would be
        # dropped on failure. _check_response then keeps the original
        # behaviour of reporting the error on stderr and exiting 1.
        print(
            json.dumps(
                {
                    "stdout": "".join(stdout_parts),
                    "return_value": response.get("return_value"),
                    "error": response.get("error"),
                }
            )
        )
        _check_response(response)
        return

    _check_response(response)
    return_value = response.get("return_value")
    if return_value is not None:
        print(f"→ {return_value}")
131
132
@cli.command()
@click.argument("remote_path")
@click.argument("local_path")
def pull(remote_path: str, local_path: str) -> None:
    """Pull a file from the remote machine."""
    # Sends one file_pull request, then writes the base64-decoded payload to
    # local_path. The file is opened with "wb", so an existing file there is
    # overwritten.
    from .local import send_command

    request = make_file_pull(remote_path)
    response = asyncio.run(send_command(request))
    _check_response(response)

    if response.get("type") != "file_data":
        print(f"Unexpected response: {response}", file=sys.stderr)
        sys.exit(1)

    data = base64.b64decode(response["data"])
    with open(local_path, "wb") as f:
        f.write(data)
    print(f"Pulled {remote_path} → {local_path} ({len(data)} bytes)")
152
153
@cli.command()
@click.argument("local_path")
@click.argument("remote_path")
def push(local_path: str, remote_path: str) -> None:
    """Push a file to the remote machine."""
    from .local import send_command

    if not os.path.exists(local_path):
        print(f"File not found: {local_path}", file=sys.stderr)
        sys.exit(1)
    if not os.path.isfile(local_path):
        # e.g. a directory: os.path.getsize would succeed, but open() below
        # would raise instead of giving a clean error.
        print(f"Not a regular file: {local_path}", file=sys.stderr)
        sys.exit(1)

    file_size = os.path.getsize(local_path)
    if file_size > MAX_FILE_SIZE:
        print(
            f"File too large: {file_size} bytes (max {MAX_FILE_SIZE})",
            file=sys.stderr,
        )
        sys.exit(1)

    async def _push_all() -> dict:
        """Send the file as sequential chunks; stop at the first error response."""
        # Send at least one chunk so a zero-byte file is still transferred
        # rather than silently skipped while "Pushed ... (0 bytes)" is printed.
        # NOTE(review): assumes chunk_count(0) may return 0 and that the remote
        # accepts a single empty chunk — confirm against protocol/remote.
        chunks = max(1, chunk_count(file_size))
        response: dict = {}
        with open(local_path, "rb") as f:
            for i in range(chunks):
                data = f.read(FILE_CHUNK_SIZE)
                request = make_file_push(
                    remote_path=remote_path, chunk=i, chunks=chunks, data=data
                )
                response = await send_command(request)
                if response.get("error"):
                    break
        return response

    response = asyncio.run(_push_all())
    _check_response(response)

    # Prefer the byte count reported by the remote; fall back to the local size.
    total_bytes = response.get("bytes", file_size)
    print(f"Pushed {local_path} → {remote_path} ({total_bytes} bytes)")
192
193
@cli.command()
def disconnect() -> None:
    """Disconnect the active portal session."""
    # Thin dispatcher: the teardown itself is implemented in the `local` module.
    from . import local

    local.disconnect()
200
201
@cli.command()
def status() -> None:
    """Show portal session status."""
    # Thin dispatcher: status reporting is implemented in the `local` module.
    from . import local

    local.status()