1"""Local side of a portal session.
2
3Runs on the developer's machine. `connect` establishes the encrypted
4tunnel through the relay and starts a background process that listens
5on a Unix socket. Subsequent commands (exec, pull, push) talk to
6the background process over the socket.
7"""
8
9from __future__ import annotations
10
11import asyncio
12import json
13import os
14import signal
15import struct
16import sys
17import tempfile
18
19import websockets.exceptions
20from websockets.asyncio.client import connect as ws_connect
21
22from .codegen import validate_code
23from .crypto import channel_id, perform_key_exchange
24from .protocol import (
25 DEFAULT_RELAY_HOST,
26 make_ping,
27 make_relay_url,
28)
29
# Well-known filesystem locations shared by the daemon and the CLI commands.
# Both live in the system temp directory returned by tempfile.gettempdir();
# `connect` refuses to start while SOCKET_PATH already exists.

# Unix socket the background daemon listens on for local commands.
SOCKET_PATH = os.path.join(tempfile.gettempdir(), "plain-portal.sock")
# File holding the daemon's process ID; read by `disconnect` and `status`.
PID_PATH = os.path.join(tempfile.gettempdir(), "plain-portal.pid")
32
33
async def _send_framed(writer: asyncio.StreamWriter, data: bytes) -> None:
    """Send *data* on *writer*, preceded by its length.

    The length goes out first as a 4-byte unsigned big-endian integer,
    followed by the payload itself. The call returns once the writer has
    been drained.
    """
    header = struct.pack("!I", len(data))
    for part in (header, data):
        writer.write(part)
    await writer.drain()
39
40
# Upper bound for a single frame: 75MB. A 50MB file grows to roughly 67MB once
# base64-encoded, so this leaves headroom while still capping how much memory
# a single length prefix can make us allocate.
_MAX_FRAME_SIZE = 75 * 1024 * 1024


async def _recv_framed(reader: asyncio.StreamReader) -> bytes:
    """Receive one length-prefixed message from *reader*.

    Reads a 4-byte unsigned big-endian length, then exactly that many
    payload bytes.

    Raises:
        ValueError: If the announced length exceeds ``_MAX_FRAME_SIZE``.
    """
    header = await reader.readexactly(4)
    (size,) = struct.unpack("!I", header)
    if size > _MAX_FRAME_SIZE:
        raise ValueError(f"Frame too large: {size} bytes (max {_MAX_FRAME_SIZE})")
    payload = await reader.readexactly(size)
    return payload
52
53
async def connect(
    code: str,
    *,
    relay_host: str = DEFAULT_RELAY_HOST,
    foreground: bool = False,
) -> None:
    """Connect to a remote portal session and start the local daemon.

    Validates the code, opens a WebSocket to the relay, performs the key
    exchange and then serves local CLI requests on a Unix socket at
    ``SOCKET_PATH`` until the relay connection ends or SIGTERM/SIGINT
    arrives.

    Args:
        code: Portal code; checked with ``validate_code`` and used to derive
            the relay channel and to run the key exchange.
        relay_host: Relay host passed to ``make_relay_url``.
        foreground: When false (the default) the process forks: the parent
            records the child's PID in ``PID_PATH`` and returns, and the
            child becomes the daemon. When true the calling process itself
            runs the daemon and records its own PID.

    The process exits with status 1 if the code is invalid, if
    ``SOCKET_PATH`` already exists, or if the relay connection fails.
    """

    if not validate_code(code):
        print(f"Invalid portal code: {code}", file=sys.stderr)
        sys.exit(1)

    # An existing socket file is taken to mean another session is running.
    if os.path.exists(SOCKET_PATH):
        print(
            "A portal session is already active. Run 'plain portal disconnect' first.",
            file=sys.stderr,
        )
        sys.exit(1)

    cid = channel_id(code)
    relay_url = make_relay_url(relay_host, cid, "connect")

    try:
        ws = await ws_connect(relay_url)
    except Exception as e:
        print(f"Failed to connect to relay: {e}", file=sys.stderr)
        sys.exit(1)

    encryptor = await perform_key_exchange(ws, code, side="connect")

    print("Connected to remote. Session active.")

    if not foreground:
        pid = os.fork()
        if pid > 0:
            # Parent: record the daemon's PID and hand control back to the CLI.
            with open(PID_PATH, "w") as f:
                f.write(str(pid))
            return
        # Child: detach into its own session and carry on as the daemon.
        os.setsid()

    if foreground:
        with open(PID_PATH, "w") as f:
            f.write(str(os.getpid()))

    # --- daemon logic (runs in child process or foreground) ---

    try:
        os.unlink(SOCKET_PATH)
    except FileNotFoundError:
        pass

    # Exec requests use queues (for streaming exec_stdout + exec_result).
    # All other request types use single-shot futures.
    pending_responses: dict[int, asyncio.Future] = {}
    pending_queues: dict[int, asyncio.Queue] = {}
    file_data_accumulators: dict[int, dict] = {}
    request_counter = 0

    def resolve_response(req_id: int, result: dict) -> None:
        """Complete the future waiting on *req_id*, if it is still waiting.

        A reply can arrive after the local handler has timed out (which
        cancels the future) or can be a duplicate. Calling ``set_result`` on
        a finished future raises ``InvalidStateError``, which would otherwise
        kill the relay listener and with it the whole session.
        """
        future = pending_responses.get(req_id)
        if future is not None and not future.done():
            future.set_result(result)

    async def handle_local_client(
        reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Handle a command from a local CLI invocation (exec/pull/push).

        Each local connection carries exactly one framed JSON request. The
        request is tagged with a fresh ``_req_id``, encrypted and forwarded
        over the WebSocket; the reply (or stream of replies for exec) is
        written back to the local client as framed JSON.
        """
        nonlocal request_counter
        req_id = None

        try:
            data = await _recv_framed(reader)
            request = json.loads(data.decode("utf-8"))

            request_counter += 1
            req_id = request_counter
            request["_req_id"] = req_id
            is_exec = request.get("type") == "exec"

            if is_exec:
                # Exec uses a queue so we can stream exec_stdout messages
                queue: asyncio.Queue = asyncio.Queue()
                pending_queues[req_id] = queue
                await ws.send(encryptor.encrypt_message(request))

                # Read from the queue until we get the final exec_result
                exec_timeout = request.get("timeout", 120) + 30  # extra margin
                while True:
                    msg = await asyncio.wait_for(queue.get(), timeout=exec_timeout)
                    await _send_framed(writer, json.dumps(msg).encode("utf-8"))
                    # Anything other than a stdout chunk ends the stream.
                    if msg.get("type") != "exec_stdout":
                        break
            else:
                # Non-exec: single request/response via future
                future: asyncio.Future = asyncio.get_running_loop().create_future()
                pending_responses[req_id] = future
                await ws.send(encryptor.encrypt_message(request))
                response = await asyncio.wait_for(future, timeout=300)
                await _send_framed(writer, json.dumps(response).encode("utf-8"))

        except TimeoutError:
            await _send_framed(
                writer,
                json.dumps({"error": "Request timed out"}).encode("utf-8"),
            )
        except Exception as e:
            await _send_framed(writer, json.dumps({"error": str(e)}).encode("utf-8"))
        finally:
            # Forget the request so late replies for it are ignored.
            if req_id is not None:
                pending_responses.pop(req_id, None)
                pending_queues.pop(req_id, None)
                file_data_accumulators.pop(req_id, None)
            writer.close()
            await writer.wait_closed()

    async def relay_listener() -> None:
        """Listen for messages from the remote side via WebSocket.

        Decrypts each binary message and routes it by ``_req_id`` to the
        queue or future registered by ``handle_local_client``. When the loop
        ends, every still-pending request is told the remote disconnected.
        """
        try:
            async for raw in ws:
                # Text frames are ignored; only binary frames are decrypted.
                if isinstance(raw, str):
                    continue

                msg = encryptor.decrypt_message(raw)
                msg_type = msg.get("type")

                if msg_type == "ping":
                    await ws.send(encryptor.encrypt_message({"type": "pong"}))
                    continue

                if msg_type == "pong":
                    continue

                req_id = msg.pop("_req_id", None)
                if not req_id:
                    continue

                # Streaming exec messages go through the queue
                if msg_type in ("exec_stdout", "exec_result"):
                    if req_id in pending_queues:
                        await pending_queues[req_id].put(msg)
                    continue

                # File data accumulation (multiple chunks → single response)
                if msg_type == "file_data":
                    if req_id not in pending_responses:
                        continue
                    if req_id not in file_data_accumulators:
                        file_data_accumulators[req_id] = {
                            "name": msg["name"],
                            "chunks": msg["chunks"],
                            "received": {},
                        }
                    acc = file_data_accumulators[req_id]
                    acc["received"][msg["chunk"]] = msg["data"]
                    if len(acc["received"]) == acc["chunks"]:
                        # All chunks are in: stitch them together in index order.
                        all_data = "".join(
                            acc["received"][i] for i in range(acc["chunks"])
                        )
                        del file_data_accumulators[req_id]
                        resolve_response(
                            req_id,
                            {
                                "type": "file_data",
                                "name": acc["name"],
                                "data": all_data,
                            },
                        )
                    continue

                # Everything else resolves the future directly
                resolve_response(req_id, msg)

        except websockets.exceptions.ConnectionClosed:
            pass
        finally:
            for future in pending_responses.values():
                if not future.done():
                    future.set_result({"error": "Remote disconnected"})
            for queue in pending_queues.values():
                await queue.put({"type": "error", "error": "Remote disconnected"})
            _cleanup()

    # Set restrictive umask so the socket is created owner-only (no TOCTOU window)
    old_umask = os.umask(0o177)
    try:
        server = await asyncio.start_unix_server(handle_local_client, path=SOCKET_PATH)
    finally:
        os.umask(old_umask)

    loop = asyncio.get_running_loop()

    def _handle_signal() -> None:
        """Remove the session files and stop the event loop."""
        _cleanup()
        loop.stop()

    loop.add_signal_handler(signal.SIGTERM, _handle_signal)
    loop.add_signal_handler(signal.SIGINT, _handle_signal)

    async def send_keepalive_pings() -> None:
        """Send an encrypted ping every 30 seconds until the link drops."""
        try:
            while True:
                await asyncio.sleep(30)
                await ws.send(encryptor.encrypt_message(make_ping()))
        except websockets.exceptions.ConnectionClosed:
            # relay_listener sees the same closed connection and drives the
            # shutdown; exiting quietly avoids a never-retrieved task error.
            return

    keepalive_task = asyncio.create_task(send_keepalive_pings())

    try:
        await relay_listener()
    finally:
        keepalive_task.cancel()
        server.close()
        await server.wait_closed()
        _cleanup()
        # Release the relay connection; close() does nothing if the
        # connection is already closed.
        await ws.close()
262
263
def _cleanup() -> None:
    """Delete the socket file and the PID file, skipping whichever is absent."""
    for leftover in (SOCKET_PATH, PID_PATH):
        try:
            os.remove(leftover)
        except FileNotFoundError:
            continue
271
272
async def send_command(request: dict) -> dict:
    """Forward one request to the background daemon and return its reply.

    The request is sent as a framed JSON message over the Unix socket at
    ``SOCKET_PATH`` and exactly one framed JSON reply is read back. Exits
    the process with status 1 when no daemon is listening. Use
    ``send_exec_streaming`` for exec requests that stream output.
    """
    try:
        streams = await asyncio.open_unix_connection(SOCKET_PATH)
    except (FileNotFoundError, ConnectionRefusedError):
        print(
            "No active portal session. Run 'plain portal connect <code>' first.",
            file=sys.stderr,
        )
        sys.exit(1)
    reader, writer = streams

    try:
        payload = json.dumps(request).encode("utf-8")
        await _send_framed(writer, payload)
        reply = await _recv_framed(reader)
        return json.loads(reply.decode("utf-8"))
    finally:
        # Always release the socket, whether or not the exchange succeeded.
        writer.close()
        await writer.wait_closed()
294
295
async def send_exec_streaming(
    request: dict,
    on_stdout: callable,  # type: ignore[type-arg]
) -> dict:
    """Run an exec request through the daemon, streaming its output.

    Every ``exec_stdout`` message received has its ``data`` passed to
    ``on_stdout``. The first message of any other type ends the stream and
    is returned to the caller. Exits the process with status 1 when no
    daemon is listening on ``SOCKET_PATH``.
    """
    try:
        streams = await asyncio.open_unix_connection(SOCKET_PATH)
    except (FileNotFoundError, ConnectionRefusedError):
        print(
            "No active portal session. Run 'plain portal connect <code>' first.",
            file=sys.stderr,
        )
        sys.exit(1)
    reader, writer = streams

    try:
        await _send_framed(writer, json.dumps(request).encode("utf-8"))
        while True:
            frame = await _recv_framed(reader)
            msg = json.loads(frame.decode("utf-8"))
            if msg.get("type") != "exec_stdout":
                # Final message: hand it back and stop reading.
                return msg
            on_stdout(msg["data"])
    finally:
        writer.close()
        await writer.wait_closed()
326
327
def disconnect() -> None:
    """Kill the background daemon and clean up.

    Reads the daemon's PID from ``PID_PATH`` and sends it SIGTERM. If the
    PID file is unreadable as a positive integer, or no such process exists,
    the file is treated as stale and both session files are removed. With no
    PID file, a leftover socket file is removed on its own.
    """
    if os.path.exists(PID_PATH):
        try:
            with open(PID_PATH) as f:
                pid = int(f.read().strip())
            if pid <= 0:
                # os.kill() gives 0 and negative values special meanings
                # (a process group, or every process we may signal), so a
                # corrupt or planted PID file must never reach it.
                raise ValueError(f"invalid PID in PID file: {pid}")
            os.kill(pid, signal.SIGTERM)
            print("Portal session disconnected.")
        except (ProcessLookupError, ValueError):
            print("Portal daemon not running (stale PID file).")
            _cleanup()
    elif os.path.exists(SOCKET_PATH):
        _cleanup()
        print("Cleaned up stale socket.")
    else:
        print("No active portal session.")
344
345
def status() -> None:
    """Show portal session status.

    Reads the daemon's PID from ``PID_PATH`` and probes it with signal 0,
    which checks for the process without delivering anything. A PID file
    that does not hold a positive integer, or that names a process that no
    longer exists, is reported as stale and the session files are removed.
    """
    if os.path.exists(PID_PATH):
        try:
            with open(PID_PATH) as f:
                pid = int(f.read().strip())
            if pid <= 0:
                # For os.kill(), 0 and negative values address process
                # groups rather than one process, so the probe would
                # succeed and report a session that does not exist.
                raise ValueError(f"invalid PID in PID file: {pid}")
            os.kill(pid, 0)
            print(f"Portal session active (PID {pid})")
            print(f"Socket: {SOCKET_PATH}")
        except (ProcessLookupError, ValueError):
            print("Portal daemon not running (stale PID file).")
            _cleanup()
    else:
        print("No active portal session.")