1"""Remote side of a portal session.
2
3Runs on the production machine. Connects to the relay, prints a portal
4code, waits for the local side to connect, then executes commands as
5they arrive through the encrypted tunnel.
6"""
7
8from __future__ import annotations
9
10import ast
11import asyncio
12import base64
13import contextlib
14import json
15import os
16import sys
17import traceback
18from contextlib import redirect_stderr, redirect_stdout
19from datetime import datetime
20
21from websockets.asyncio.client import ClientConnection
22from websockets.asyncio.client import connect as ws_connect
23
24from .codegen import generate_code
25from .crypto import PortalEncryptor, channel_id, perform_key_exchange
26from .protocol import (
27 DEFAULT_EXEC_TIMEOUT,
28 DEFAULT_RELAY_HOST,
29 FILE_CHUNK_SIZE,
30 MAX_FILE_SIZE,
31 chunk_count,
32 make_error,
33 make_exec_result,
34 make_exec_stdout,
35 make_file_data,
36 make_file_push_result,
37 make_ping,
38 make_pong,
39 make_relay_url,
40)
41
# Reference to the process's stdout, taken once when this module is imported.
# Executed code runs under a process-global redirect_stdout()/redirect_stderr()
# into the tunnel, so _log() writes here to keep reaching the operator's
# terminal. (Assumes the module is not first imported while a redirect is
# active — TODO confirm.)
_real_stdout = sys.stdout
43
44
def _log(msg: str) -> None:
    """Write ``msg`` to the operator's terminal prefixed with an HH:MM:SS stamp.

    Goes through ``_real_stdout`` (not ``sys.stdout``) so the line is not
    captured by the stdout redirection used while executing client code, and
    flushes immediately so it appears right away.
    """
    now = datetime.now()
    line = f"[{now:%H:%M:%S}] {msg}\n"
    _real_stdout.write(line)
    _real_stdout.flush()
49
50
async def _send_error(
    ws: ClientConnection,
    encryptor: PortalEncryptor,
    req_id: int | None,
    error_text: str,
) -> None:
    """Report a failure for request ``req_id`` back to the client.

    Builds a protocol error message from ``error_text``, tags it with the
    request id under ``"_req_id"`` so the client can correlate it, then
    encrypts it and sends it over ``ws``.
    """
    payload = make_error(error_text)
    payload.update(_req_id=req_id)
    encrypted = encryptor.encrypt_message(payload)
    await ws.send(encrypted)
61
62
class _TunnelWriter:
    """File-like that streams writes through the tunnel as exec_stdout messages.

    Used as the target of ``redirect_stdout``/``redirect_stderr`` while client
    code runs in a worker thread. Output is line-buffered: every complete line
    is sent as its own message, and a trailing partial line is held until the
    next newline or ``flush()``.

    Delivery is best-effort: failing to send output (connection gone, event
    loop closed, slow peer) must never raise into the code being executed.
    """

    # Minimal text-stream surface so code that probes sys.stdout
    # (``sys.stdout.encoding``, ``sys.stdout.isatty()``) keeps working while
    # output is redirected here.
    encoding = "utf-8"

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        ws: ClientConnection,
        encryptor: PortalEncryptor,
        req_id: int | None,
    ) -> None:
        self._loop = loop
        self._ws = ws
        self._encryptor = encryptor
        self._req_id = req_id
        self._buffer = ""
        # Strong references to fire-and-forget send tasks (see _send); the
        # event loop only keeps weak references to tasks.
        self._pending: set[asyncio.Task] = set()

    def write(self, s: str) -> int:
        """Buffer ``s`` and send each complete line; return ``len(s)``."""
        self._buffer += s
        while "\n" in self._buffer:
            line, self._buffer = self._buffer.split("\n", 1)
            self._send(line + "\n")
        return len(s)

    def flush(self) -> None:
        """Send any buffered partial line."""
        if self._buffer:
            self._send(self._buffer)
            self._buffer = ""

    def isatty(self) -> bool:
        """Output goes over the network, never to a terminal."""
        return False

    def writable(self) -> bool:
        return True

    def _send(self, data: str) -> None:
        """Deliver ``data`` as one exec_stdout message, swallowing send failures."""
        msg = make_exec_stdout(data)
        msg["_req_id"] = self._req_id
        coro = self._ws.send(self._encryptor.encrypt_message(msg))

        try:
            current_loop = asyncio.get_running_loop()
        except RuntimeError:
            current_loop = None

        if current_loop is self._loop:
            # Called on the event-loop thread itself (something on the loop
            # wrote to sys.stdout while an exec holds the process-global
            # redirect). Blocking on the result here would stall the very loop
            # that has to run the send, so schedule it and return.
            task = self._loop.create_task(coro)
            self._pending.add(task)
            task.add_done_callback(self._on_send_done)
            return

        try:
            future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        except Exception:
            # Loop closed or unusable (e.g. the session ended while timed-out
            # code keeps printing). Close the coroutine so it is not reported
            # as never awaited, and drop the output.
            coro.close()
            return

        try:
            future.result(timeout=30)
        except Exception:
            # Don't crash exec for a send failure. If we merely timed out,
            # cancel the queued send rather than leaving it pending on the loop.
            future.cancel()

    def _on_send_done(self, task: asyncio.Task) -> None:
        """Forget a finished fire-and-forget send, marking any error as retrieved."""
        self._pending.discard(task)
        if not task.cancelled():
            task.exception()  # failures are deliberately ignored (best-effort)
102
103
async def run_remote(
    *,
    writable: bool = False,
    timeout_minutes: int = 30,
    relay_host: str = DEFAULT_RELAY_HOST,
) -> None:
    """Start the remote side of a portal session.

    Prints a freshly generated portal code, connects to the relay as the
    "start" side, performs the key exchange, then serves encrypted requests
    from the client until the connection closes:

    - ``exec``: run Python code in a worker thread, streaming stdout/stderr
      back and replying with the last expression's value or a traceback.
    - ``file_pull``: send a local file to the client in chunks.
    - ``file_push``: receive file chunks; writes are restricted to /tmp.
    - ``ping``/``pong``: keepalive traffic (not counted as activity).

    Args:
        writable: When False, executed code runs inside plain-postgres's
            ``read_only()`` context, if that package can be imported.
        timeout_minutes: Close the session after this many minutes without a
            client request. ``0`` or a negative value disables the timeout.
        relay_host: Relay host used to build the websocket URL.
    """

    code = generate_code()

    mode = "writable" if writable else "read-only"
    print(f"Portal code: {code}")
    print(f"Session mode: {mode}")
    print("Waiting for connection...")
    print()

    cid = channel_id(code)
    relay_url = make_relay_url(relay_host, cid, "start")

    # 1MB — truncate return values to prevent massive relay payloads.
    max_output = 1024 * 1024
    # Resolve once (macOS: /tmp → /private/tmp).
    tmp_prefix = os.path.realpath("/tmp")

    def execute_code(
        code_str: str,
        *,
        json_output: bool = False,
        output_writer: _TunnelWriter,
    ) -> dict:
        """Execute Python code, streaming stdout/stderr through the tunnel.

        Each execution gets a fresh namespace. If the last statement is an
        expression, its value is captured as the return value (like the
        interactive REPL): ``json.dumps`` of it when ``json_output`` is set and
        the value is JSON-serializable, otherwise its ``repr``. A ``None``
        value produces no return value.

        Returns:
            ``{"return_value": str | None, "error": str | None}``; ``error`` is
            the formatted traceback of anything raised, including
            ``SystemExit``/``KeyboardInterrupt`` (``BaseException`` is caught).
        """
        namespace: dict = {}
        return_value = None
        error = None

        try:
            tree = ast.parse(code_str, mode="exec")

            # Split off a trailing expression so it can be eval'd for its value.
            last_expr: ast.Expr | None = None
            if tree.body and isinstance(tree.body[-1], ast.Expr):
                last_expr = tree.body.pop()  # type: ignore[assignment]

            # Use the ExitStack as a `with` from the start so the redirects are
            # always undone, even if a later setup step raises.
            with contextlib.ExitStack() as ctx:
                # Process-global redirect — safe because _log() uses _real_stdout
                ctx.enter_context(redirect_stdout(output_writer))
                ctx.enter_context(redirect_stderr(output_writer))
                if not writable:
                    try:
                        from plain.postgres.connections import read_only

                        ctx.enter_context(read_only())
                    except Exception:
                        # No DB configured or plain-postgres not installed.
                        # NOTE(review): any other failure entering read_only()
                        # is swallowed too, and the code then runs without the
                        # read-only guard — confirm that is acceptable.
                        pass

                if tree.body:
                    compiled = compile(tree, "<portal>", "exec")
                    exec(compiled, namespace)  # noqa: S102

                if last_expr is not None:
                    expr_code = compile(
                        ast.Expression(last_expr.value), "<portal>", "eval"
                    )
                    result = eval(expr_code, namespace)  # noqa: S307
                    if result is not None:
                        if json_output:
                            try:
                                return_value = json.dumps(result)
                            except (TypeError, ValueError):
                                return_value = repr(result)
                        else:
                            return_value = repr(result)

        except BaseException:
            error = traceback.format_exc()
        finally:
            # Flush any remaining buffered output
            output_writer.flush()
            # Close DB connection to prevent leaks across to_thread calls
            try:
                from plain.postgres.connections import get_connection, has_connection

                if has_connection():
                    get_connection().close()
            except Exception:
                pass

        # NOTE(review): the limit and the reported size count characters, not
        # bytes, despite the wording of the message.
        if return_value and len(return_value) > max_output:
            return_value = (
                return_value[:max_output]
                + f"\n... truncated ({len(return_value)} bytes total)"
            )

        return {
            "return_value": return_value,
            "error": error,
        }

    async def handle_file_pull(remote_path: str, req_id: int | None) -> None:
        """Read a file from disk and send it to the client in chunks.

        Failures (missing file, file too large, other OS errors) are reported
        to the client as error messages tagged with ``req_id`` instead of
        being raised.
        """
        try:
            file_size = os.path.getsize(remote_path)
            if file_size > MAX_FILE_SIZE:
                await _send_error(
                    ws,
                    encryptor,
                    req_id,
                    f"File too large: {file_size} bytes (max {MAX_FILE_SIZE})",
                )
                return

            name = os.path.basename(remote_path)
            chunks = chunk_count(file_size)

            _log(f" sending {name} ({file_size} bytes, {chunks} chunks)")

            with open(remote_path, "rb") as f:
                for i in range(chunks):
                    data = f.read(FILE_CHUNK_SIZE)
                    msg = make_file_data(name=name, chunk=i, chunks=chunks, data=data)
                    msg["_req_id"] = req_id
                    await ws.send(encryptor.encrypt_message(msg))

        except FileNotFoundError:
            await _send_error(ws, encryptor, req_id, f"File not found: {remote_path}")
        except OSError as e:
            # PermissionError, IsADirectoryError, ... are all OSError subclasses.
            await _send_error(ws, encryptor, req_id, f"{type(e).__name__}: {e}")

    async def handle_file_push(msg: dict) -> None:
        """Receive one file chunk and write it to disk under /tmp.

        Chunk 0 truncates/creates the file; later chunks append. Every chunk
        is acknowledged, and the last one is answered with the final size.
        """
        req_id = msg.get("_req_id")
        remote_path = msg["remote_path"]
        chunk_idx = msg["chunk"]
        chunks = msg["chunks"]
        data = base64.b64decode(msg["data"])

        resolved = os.path.realpath(remote_path)
        if not resolved.startswith(tmp_prefix + "/"):
            await _send_error(
                ws,
                encryptor,
                req_id,
                f"Push restricted to /tmp/. Got: {remote_path} (resolved: {resolved})",
            )
            return

        if chunk_idx == 0:
            _log(f"push: {remote_path} ({chunks} chunks)")

        try:
            mode = "wb" if chunk_idx == 0 else "ab"
            # Write to the path that was actually validated, not the
            # caller-supplied one, which could resolve elsewhere if symlinks
            # change after the check.
            with open(resolved, mode) as f:
                f.write(data)
        except OSError as e:
            await _send_error(ws, encryptor, req_id, f"{type(e).__name__}: {e}")
            return

        # Ack every chunk so the sender doesn't block waiting
        if chunk_idx == chunks - 1:
            total_bytes = os.path.getsize(resolved)
            _log(f" received {total_bytes} bytes")
            result = make_file_push_result(path=remote_path, total_bytes=total_bytes)
        else:
            result = {"type": "file_push_ack", "chunk": chunk_idx}
        result["_req_id"] = req_id
        await ws.send(encryptor.encrypt_message(result))

    async with ws_connect(relay_url) as ws:
        encryptor = await perform_key_exchange(ws, code, side="start")
        _log("Connected from remote client.")

        loop = asyncio.get_running_loop()
        last_activity = loop.time()

        async def check_timeout() -> None:
            """Close the session once it has been idle for timeout_minutes."""
            if timeout_minutes <= 0:
                return
            limit = timeout_minutes * 60
            while True:
                await asyncio.sleep(60)
                idle = loop.time() - last_activity
                remaining = limit - idle
                # Print to the real terminal: during an exec, sys.stdout is the
                # tunnel writer (process-global redirect), so a plain print from
                # this loop thread would go to the client and block the loop.
                if 0 < remaining <= 60:
                    print(
                        f"\nWarning: session will timeout in {int(remaining)} seconds due to inactivity.",
                        file=_real_stdout,
                        flush=True,
                    )
                if idle >= limit:
                    print(
                        "\nSession timed out due to inactivity.",
                        file=_real_stdout,
                        flush=True,
                    )
                    await ws.close()
                    return

        timeout_task = asyncio.create_task(check_timeout())

        async def send_keepalive_pings() -> None:
            """Ping the peer every 30 seconds until the connection goes away."""
            # Once the socket closes, ws.send raises; stop quietly and let the
            # receive loop below observe the closure.
            with contextlib.suppress(Exception):
                while True:
                    await asyncio.sleep(30)
                    await ws.send(encryptor.encrypt_message(make_ping()))

        keepalive_task = asyncio.create_task(send_keepalive_pings())

        try:
            async for raw in ws:
                if isinstance(raw, str):
                    continue

                msg = encryptor.decrypt_message(raw)
                msg_type = msg.get("type")

                # Keepalive traffic must not count as activity: the pongs
                # answering our own pings would otherwise reset the idle timer
                # every 30 seconds and the inactivity timeout could never fire.
                is_keepalive = msg_type in ("ping", "pong")
                if not is_keepalive:
                    last_activity = loop.time()

                if msg_type == "ping":
                    await ws.send(encryptor.encrypt_message(make_pong()))

                elif msg_type == "pong":
                    pass

                elif msg_type == "exec":
                    req_id = msg.get("_req_id")
                    code_str = msg["code"]
                    json_output = msg.get("json_output", False)
                    exec_timeout = msg.get("timeout", DEFAULT_EXEC_TIMEOUT)
                    _log(
                        f"exec: {code_str[:200]}{'...' if len(code_str) > 200 else ''}"
                    )
                    # Create a writer that streams stdout through the tunnel
                    tunnel_writer = _TunnelWriter(loop, ws, encryptor, req_id)
                    try:
                        result = await asyncio.wait_for(
                            asyncio.to_thread(
                                execute_code,
                                code_str,
                                json_output=json_output,
                                output_writer=tunnel_writer,
                            ),
                            timeout=exec_timeout,
                        )
                    except (TimeoutError, asyncio.TimeoutError):
                        # Before Python 3.11, asyncio.TimeoutError is not the
                        # builtin TimeoutError.
                        result = {
                            "return_value": None,
                            "error": f"Execution timed out ({exec_timeout} seconds). The code may still be running in the background.",
                        }
                    return_value = result.get("return_value")
                    error = result.get("error")
                    display = return_value or error or ""
                    if display:
                        _log(
                            f" → {display[:200]}{'...' if len(display) > 200 else ''}"
                        )
                    # Send final result — stdout was already streamed
                    response = make_exec_result(
                        return_value=return_value,
                        error=error,
                    )
                    response["_req_id"] = req_id
                    await ws.send(encryptor.encrypt_message(response))

                elif msg_type == "file_pull":
                    req_id = msg.get("_req_id")
                    remote_path = msg["remote_path"]
                    _log(f"pull: {remote_path}")
                    await handle_file_pull(remote_path, req_id)

                elif msg_type == "file_push":
                    await handle_file_push(msg)

                else:
                    _log(f"Unknown message type: {msg_type}")

                if not is_keepalive:
                    # Time spent serving a request (long exec, big transfer)
                    # should not count against the idle budget.
                    last_activity = loop.time()

        finally:
            timeout_task.cancel()
            keepalive_task.cancel()

    _log("Client disconnected.")