1"""Local side of a portal session.
  2
  3Runs on the developer's machine. `connect` establishes the encrypted
  4tunnel through the relay and starts a background process that listens
  5on a Unix socket. Subsequent commands (exec, pull, push) talk to
  6the background process over the socket.
  7"""
  8
  9from __future__ import annotations
 10
 11import asyncio
 12import json
 13import os
 14import signal
 15import struct
 16import sys
 17import tempfile
 18
 19import websockets.exceptions
 20from websockets.asyncio.client import connect as ws_connect
 21
 22from .codegen import validate_code
 23from .crypto import channel_id, perform_key_exchange
 24from .protocol import (
 25    DEFAULT_RELAY_HOST,
 26    make_ping,
 27    make_relay_url,
 28)
 29
# Fixed rendezvous paths in the system temp directory. The daemon started by
# `connect` listens on SOCKET_PATH and records its PID in PID_PATH; the client
# helpers (`send_command`, `send_exec_streaming`) connect to SOCKET_PATH, and
# `disconnect`/`status` read PID_PATH.
# NOTE(review): the names are not per-user. Where the temp dir is shared
# between users (e.g. /tmp on many Linux systems), two users' sessions would
# collide; consider including os.getuid() in the names — confirm intent.
SOCKET_PATH = os.path.join(tempfile.gettempdir(), "plain-portal.sock")
PID_PATH = os.path.join(tempfile.gettempdir(), "plain-portal.pid")
 32
 33
 34async def _send_framed(writer: asyncio.StreamWriter, data: bytes) -> None:
 35    """Write a length-prefixed message to a stream."""
 36    writer.write(struct.pack("!I", len(data)))
 37    writer.write(data)
 38    await writer.drain()
 39
 40
# Largest frame `_recv_framed` will accept (75 MiB). This is large enough for
# a 50MB file base64-encoded (~67MB) with some headroom, and it stops a bogus
# length prefix from triggering an unbounded allocation.
_MAX_FRAME_SIZE = 75 * 1024 * 1024
 43
 44
 45async def _recv_framed(reader: asyncio.StreamReader) -> bytes:
 46    """Read a length-prefixed message from a stream."""
 47    length_bytes = await reader.readexactly(4)
 48    length = struct.unpack("!I", length_bytes)[0]
 49    if length > _MAX_FRAME_SIZE:
 50        raise ValueError(f"Frame too large: {length} bytes (max {_MAX_FRAME_SIZE})")
 51    return await reader.readexactly(length)
 52
 53
 54async def connect(
 55    code: str,
 56    *,
 57    relay_host: str = DEFAULT_RELAY_HOST,
 58    foreground: bool = False,
 59) -> None:
 60    """Connect to a remote portal session and start the local daemon."""
 61
 62    if not validate_code(code):
 63        print(f"Invalid portal code: {code}", file=sys.stderr)
 64        sys.exit(1)
 65
 66    if os.path.exists(SOCKET_PATH):
 67        print(
 68            "A portal session is already active. Run 'plain portal disconnect' first.",
 69            file=sys.stderr,
 70        )
 71        sys.exit(1)
 72
 73    cid = channel_id(code)
 74    relay_url = make_relay_url(relay_host, cid, "connect")
 75
 76    try:
 77        ws = await ws_connect(relay_url)
 78    except Exception as e:
 79        print(f"Failed to connect to relay: {e}", file=sys.stderr)
 80        sys.exit(1)
 81
 82    encryptor = await perform_key_exchange(ws, code, side="connect")
 83
 84    print("Connected to remote. Session active.")
 85
 86    if not foreground:
 87        pid = os.fork()
 88        if pid > 0:
 89            with open(PID_PATH, "w") as f:
 90                f.write(str(pid))
 91            return
 92        os.setsid()
 93
 94    if foreground:
 95        with open(PID_PATH, "w") as f:
 96            f.write(str(os.getpid()))
 97
 98    # --- daemon logic (runs in child process or foreground) ---
 99
100    try:
101        os.unlink(SOCKET_PATH)
102    except FileNotFoundError:
103        pass
104
105    # Exec requests use queues (for streaming exec_stdout + exec_result).
106    # All other request types use single-shot futures.
107    pending_responses: dict[int, asyncio.Future] = {}
108    pending_queues: dict[int, asyncio.Queue] = {}
109    file_data_accumulators: dict[int, dict] = {}
110    request_counter = 0
111
112    async def handle_local_client(
113        reader: asyncio.StreamReader, writer: asyncio.StreamWriter
114    ) -> None:
115        """Handle a command from a local CLI invocation (exec/pull/push)."""
116        nonlocal request_counter
117        req_id = None
118        is_exec = False
119
120        try:
121            data = await _recv_framed(reader)
122            request = json.loads(data.decode("utf-8"))
123
124            request_counter += 1
125            req_id = request_counter
126            request["_req_id"] = req_id
127            is_exec = request.get("type") == "exec"
128
129            if is_exec:
130                # Exec uses a queue so we can stream exec_stdout messages
131                queue: asyncio.Queue = asyncio.Queue()
132                pending_queues[req_id] = queue
133                await ws.send(encryptor.encrypt_message(request))
134
135                # Read from the queue until we get the final exec_result
136                exec_timeout = request.get("timeout", 120) + 30  # extra margin
137                while True:
138                    msg = await asyncio.wait_for(queue.get(), timeout=exec_timeout)
139                    await _send_framed(writer, json.dumps(msg).encode("utf-8"))
140                    if msg.get("type") != "exec_stdout":
141                        break
142            else:
143                # Non-exec: single request/response via future
144                future: asyncio.Future = asyncio.get_running_loop().create_future()
145                pending_responses[req_id] = future
146                await ws.send(encryptor.encrypt_message(request))
147                response = await asyncio.wait_for(future, timeout=300)
148                await _send_framed(writer, json.dumps(response).encode("utf-8"))
149
150        except TimeoutError:
151            await _send_framed(
152                writer,
153                json.dumps({"error": "Request timed out"}).encode("utf-8"),
154            )
155        except Exception as e:
156            await _send_framed(writer, json.dumps({"error": str(e)}).encode("utf-8"))
157        finally:
158            if req_id is not None:
159                pending_responses.pop(req_id, None)
160                pending_queues.pop(req_id, None)
161                file_data_accumulators.pop(req_id, None)
162            writer.close()
163            await writer.wait_closed()
164
165    async def relay_listener() -> None:
166        """Listen for messages from the remote side via WebSocket."""
167        try:
168            async for raw in ws:
169                if isinstance(raw, str):
170                    continue
171
172                msg = encryptor.decrypt_message(raw)
173                msg_type = msg.get("type")
174
175                if msg_type == "ping":
176                    await ws.send(encryptor.encrypt_message({"type": "pong"}))
177                    continue
178
179                if msg_type == "pong":
180                    continue
181
182                req_id = msg.pop("_req_id", None)
183                if not req_id:
184                    continue
185
186                # Streaming exec messages go through the queue
187                if msg_type in ("exec_stdout", "exec_result"):
188                    if req_id in pending_queues:
189                        await pending_queues[req_id].put(msg)
190                    continue
191
192                # File data accumulation (multiple chunks → single response)
193                if msg_type == "file_data":
194                    if req_id not in pending_responses:
195                        continue
196                    if req_id not in file_data_accumulators:
197                        file_data_accumulators[req_id] = {
198                            "name": msg["name"],
199                            "chunks": msg["chunks"],
200                            "received": {},
201                        }
202                    acc = file_data_accumulators[req_id]
203                    acc["received"][msg["chunk"]] = msg["data"]
204                    if len(acc["received"]) == acc["chunks"]:
205                        all_data = "".join(
206                            acc["received"][i] for i in range(acc["chunks"])
207                        )
208                        del file_data_accumulators[req_id]
209                        pending_responses[req_id].set_result(
210                            {
211                                "type": "file_data",
212                                "name": acc["name"],
213                                "data": all_data,
214                            }
215                        )
216                    continue
217
218                # Everything else resolves the future directly
219                if req_id in pending_responses:
220                    pending_responses[req_id].set_result(msg)
221
222        except websockets.exceptions.ConnectionClosed:
223            pass
224        finally:
225            for future in pending_responses.values():
226                if not future.done():
227                    future.set_result({"error": "Remote disconnected"})
228            for queue in pending_queues.values():
229                await queue.put({"type": "error", "error": "Remote disconnected"})
230            _cleanup()
231
232    # Set restrictive umask so the socket is created owner-only (no TOCTOU window)
233    old_umask = os.umask(0o177)
234    try:
235        server = await asyncio.start_unix_server(handle_local_client, path=SOCKET_PATH)
236    finally:
237        os.umask(old_umask)
238
239    loop = asyncio.get_running_loop()
240
241    def _handle_signal() -> None:
242        _cleanup()
243        loop.stop()
244
245    loop.add_signal_handler(signal.SIGTERM, _handle_signal)
246    loop.add_signal_handler(signal.SIGINT, _handle_signal)
247
248    async def send_keepalive_pings() -> None:
249        while True:
250            await asyncio.sleep(30)
251            await ws.send(encryptor.encrypt_message(make_ping()))
252
253    keepalive_task = asyncio.create_task(send_keepalive_pings())
254
255    try:
256        await relay_listener()
257    finally:
258        keepalive_task.cancel()
259        server.close()
260        await server.wait_closed()
261        _cleanup()
262
263
def _cleanup() -> None:
    """Delete the session's socket and PID files, skipping any that are already gone."""
    for leftover in (SOCKET_PATH, PID_PATH):
        try:
            os.remove(leftover)
        except FileNotFoundError:
            continue
271
272
async def send_command(request: dict) -> dict:
    """Send one request to the background daemon and return its reply.

    Connects to the daemon's Unix socket, writes *request* as a single
    framed JSON message, and decodes the single framed JSON reply. If the
    socket is missing or refuses the connection, prints a hint to stderr
    and exits with status 1. For streaming exec, use send_exec_streaming
    instead.
    """
    try:
        reader, writer = await asyncio.open_unix_connection(SOCKET_PATH)
    except (FileNotFoundError, ConnectionRefusedError):
        print(
            "No active portal session. Run 'plain portal connect <code>' first.",
            file=sys.stderr,
        )
        sys.exit(1)

    try:
        payload = json.dumps(request).encode("utf-8")
        await _send_framed(writer, payload)
        reply = await _recv_framed(reader)
    finally:
        writer.close()
        await writer.wait_closed()
    return json.loads(reply.decode("utf-8"))
294
295
async def send_exec_streaming(
    request: dict,
    on_stdout: callable,  # type: ignore[type-arg]
) -> dict:
    """Send an exec request to the daemon and relay stdout as it streams in.

    Each framed reply whose ``type`` is ``"exec_stdout"`` has its ``data``
    passed to ``on_stdout``. The first reply of any other type (the final
    exec_result, or an error payload) is returned. If no daemon socket is
    reachable, prints a hint to stderr and exits with status 1.
    """
    try:
        reader, writer = await asyncio.open_unix_connection(SOCKET_PATH)
    except (FileNotFoundError, ConnectionRefusedError):
        print(
            "No active portal session. Run 'plain portal connect <code>' first.",
            file=sys.stderr,
        )
        sys.exit(1)

    try:
        await _send_framed(writer, json.dumps(request).encode("utf-8"))
        reply = json.loads((await _recv_framed(reader)).decode("utf-8"))
        while reply.get("type") == "exec_stdout":
            on_stdout(reply["data"])
            reply = json.loads((await _recv_framed(reader)).decode("utf-8"))
        return reply
    finally:
        writer.close()
        await writer.wait_closed()
326
327
def disconnect() -> None:
    """Stop the background daemon (if any) and remove its socket/PID files.

    Sends SIGTERM to the PID recorded in ``PID_PATH``. The PID file is
    reported as stale in three cases: its content does not parse as an
    integer, the process no longer exists, or the process may not be
    signalled by this user. The files are removed whenever a PID file or
    socket was found.
    """
    if os.path.exists(PID_PATH):
        try:
            with open(PID_PATH) as f:
                pid = int(f.read().strip())
            os.kill(pid, signal.SIGTERM)
            print("Portal session disconnected.")
        except (ProcessLookupError, PermissionError, ValueError):
            # PermissionError: the PID now belongs to a process we may not
            # signal. Presumably the PID was reused, so this is not our
            # daemon; treat it as stale rather than crashing.
            print("Portal daemon not running (stale PID file).")
        _cleanup()
    elif os.path.exists(SOCKET_PATH):
        _cleanup()
        print("Cleaned up stale socket.")
    else:
        print("No active portal session.")
344
345
def status() -> None:
    """Print whether a portal daemon is running.

    Reads the PID from ``PID_PATH`` and probes it with signal 0. The PID
    file is reported as stale, and the socket/PID files removed, in three
    cases: it does not parse as an integer, the process no longer exists,
    or the process may not be signalled by this user.
    """
    if os.path.exists(PID_PATH):
        try:
            with open(PID_PATH) as f:
                pid = int(f.read().strip())
            os.kill(pid, 0)  # signal 0: existence/permission check only
            print(f"Portal session active (PID {pid})")
            print(f"Socket: {SOCKET_PATH}")
        except (ProcessLookupError, PermissionError, ValueError):
            # PermissionError: the PID now belongs to a process we may not
            # signal. Presumably the PID was reused, so this is not our
            # daemon.
            print("Portal daemon not running (stale PID file).")
            _cleanup()
    else:
        print("No active portal session.")