v0.150.0
  1import os
  2import platform
  3import socket
  4import subprocess
  5import sys
  6import time
  7import tomllib
  8from importlib.metadata import entry_points
  9from importlib.util import find_spec
 10from pathlib import Path
 11
 12import click
 13from rich.columns import Columns
 14from rich.console import Console
 15from rich.text import Text
 16
 17from plain.cli.print import print_event
 18from plain.runtime import APP_PATH, PLAIN_TEMP_PATH
 19
 20from .backups.core import DatabaseBackups
 21from .mkcert import MkcertManager
 22from .process import Supervisor
 23from .utils import has_pyproject_toml
 24
# Entry point group scanned by DevSupervisor.add_entrypoints; each entry point
# in this group is started as an extra process under `plain dev`.
ENTRYPOINT_GROUP = "plain.dev"
 26
 27
 28class DevSupervisor(Supervisor):
 29    pidfile = PLAIN_TEMP_PATH / "dev" / "dev.pid"
 30    log_dir = PLAIN_TEMP_PATH / "dev" / "logs" / "run"
 31    background_command = ["dev"]
 32    display_name = "`plain dev`"
 33
 34    def setup(
 35        self, *, port: int | None, hostname: str | None, log_level: str | None
 36    ) -> None:
 37        if not hostname:
 38            project_name = os.path.basename(
 39                os.getcwd()
 40            )  # Use directory name by default
 41
 42            if has_pyproject_toml(APP_PATH.parent):
 43                with open(Path(APP_PATH.parent, "pyproject.toml"), "rb") as f:
 44                    pyproject = tomllib.load(f)
 45                    project_name = pyproject.get("project", {}).get(
 46                        "name", project_name
 47                    )
 48
 49            hostname = f"{project_name.lower()}.localhost"
 50
 51        self.hostname = hostname
 52        self.log_level = log_level
 53
 54        self.prepare_log()
 55
 56        if port:
 57            self.port = int(port)
 58            if not self._port_available(self.port):
 59                click.secho(f"Port {self.port} in use", fg="red")
 60                raise SystemExit(1)
 61        else:
 62            self.port = self._find_open_port(8443)
 63            if self.port != 8443:
 64                click.secho(f"Port 8443 in use, using {self.port}", fg="yellow")
 65
 66        self.ssl_key_path = None
 67        self.ssl_cert_path = None
 68
 69        self.url = f"https://{self.hostname}:{self.port}"
 70        self.tunnel_url = os.environ.get("DEV_TUNNEL_URL", "")
 71
 72        self.plain_env = {
 73            "PYTHONUNBUFFERED": "true",
 74            "PLAIN_SERVER_ACCESS_LOG_FIELDS": '["method", "url", "status", "duration_ms", "size"]',
 75            "FORCE_COLOR": "1",
 76            "PYTHONWARNINGS": "default::DeprecationWarning,default::PendingDeprecationWarning",
 77            **os.environ,
 78        }
 79
 80        if log_level:
 81            self.plain_env["PLAIN_FRAMEWORK_LOG_LEVEL"] = log_level.upper()
 82            self.plain_env["PLAIN_LOG_LEVEL"] = log_level.upper()
 83
 84        self.custom_process_env = {
 85            **self.plain_env,
 86            "PORT": str(self.port),
 87            "DEV_URL": self.url,
 88        }
 89
 90        if self.tunnel_url:
 91            status_bar = Columns(
 92                [
 93                    Text.from_markup(
 94                        f"[bold]Tunnel[/bold] [underline][link={self.tunnel_url}]{self.tunnel_url}[/link][/underline]"
 95                    ),
 96                    Text.from_markup(
 97                        f"[dim][bold]Server[/bold] [link={self.url}]{self.url}[/link][/dim]"
 98                    ),
 99                    Text.from_markup(
100                        "[dim][bold]Ctrl+C[/bold] to stop[/dim]",
101                        justify="right",
102                    ),
103                ],
104                expand=True,
105            )
106        else:
107            status_bar = Columns(
108                [
109                    Text.from_markup(
110                        f"[bold]Server[/bold] [underline][link={self.url}]{self.url}[/link][/underline]"
111                    ),
112                    Text.from_markup(
113                        "[dim][bold]Ctrl+C[/bold] to stop[/dim]", justify="right"
114                    ),
115                ],
116                expand=True,
117            )
118        self.console = Console(markup=False, highlight=False)
119        self.console_status = self.console.status(status_bar)
120
121        self.init_poncho(self.console.out)
122
123    def _find_open_port(self, start_port: int) -> int:
124        port = start_port
125        while not self._port_available(port):
126            port += 1
127        return port
128
129    def _port_available(self, port: int) -> bool:
130        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
131            sock.settimeout(0.5)
132            result = sock.connect_ex(("127.0.0.1", port))
133        return result != 0
134
135    def run(self, *, reinstall_ssl: bool = False) -> int:
136        if not self.acquire():
137            click.secho(self.already_running_message(self.read_pidfile()), fg="yellow")
138            return 1
139
140        mkcert_manager = MkcertManager()
141        mkcert_manager.setup_mkcert(
142            install_path=Path.home() / ".plain" / "dev",
143            force_reinstall=reinstall_ssl,
144        )
145        self.ssl_cert_path, self.ssl_key_path = mkcert_manager.generate_certs(
146            domain=self.hostname,
147            storage_path=Path(PLAIN_TEMP_PATH) / "dev" / "certs",
148            force_regenerate=reinstall_ssl,
149        )
150
151        self.symlink_plain_src()
152        self.install_agent()
153        self.modify_hosts_file()
154
155        print_event("Running preflight checks...", newline=False)
156        self.run_preflight()
157
158        if find_spec("plain.postgres"):
159            print_event("Waiting for database...", newline=False)
160            subprocess.run(
161                [sys.executable, "-m", "plain", "postgres", "wait"],
162                env=self.plain_env,
163                check=True,
164            )
165
166            # Backup before syncing if sync would make changes
167            check_result = subprocess.run(
168                [sys.executable, "-m", "plain", "postgres", "sync", "--check"],
169                env=self.plain_env,
170                capture_output=True,
171            )
172
173            if check_result.returncode != 0:
174                backup_name = time.strftime("%Y%m%d_%H%M%S")
175                print_event(f"Backing up database ({backup_name})...", newline=False)
176                try:
177                    DatabaseBackups().create(
178                        backup_name,
179                        source="dev",
180                        pg_dump=os.environ.get("PG_DUMP", "pg_dump"),
181                    )
182                    click.secho("✔", fg="green")
183                except Exception:
184                    click.secho("skipped", dim=True)
185
186            print_event("Syncing database...")
187            subprocess.run(
188                [sys.executable, "-m", "plain", "postgres", "sync"],
189                env=self.plain_env,
190                check=True,
191            )
192
193        print_event("Starting app...")
194
195        # Manually start the status bar now so it isn't bungled by
196        # another thread checking db stuff...
197        self.console_status.start()
198
199        assert self.poncho is not None, "poncho should be initialized"
200
201        self.add_server()
202        self.add_entrypoints()
203        self.add_pyproject_run()
204
205        try:
206            # Start processes we know about and block the main thread
207            self.poncho.loop()
208
209            # Remove the status bar
210            self.console_status.stop()
211        finally:
212            self.release()
213            self.close()
214
215        assert self.poncho.returncode is not None, "returncode should be set after loop"
216        return self.poncho.returncode
217
218    def symlink_plain_src(self) -> None:
219        """Symlink the plain package into .plain so we can look at it easily"""
220        spec = find_spec("plain.runtime")
221        if spec is None or spec.origin is None:
222            return None
223        plain_path = Path(spec.origin).parent.parent
224        if not PLAIN_TEMP_PATH.exists():
225            PLAIN_TEMP_PATH.mkdir()
226
227        symlink_path = PLAIN_TEMP_PATH / "src"
228
229        # The symlink is broken
230        if symlink_path.is_symlink() and not symlink_path.exists():
231            symlink_path.unlink()
232
233        # The symlink exists but points to the wrong place
234        if (
235            symlink_path.is_symlink()
236            and symlink_path.exists()
237            and symlink_path.resolve() != plain_path
238        ):
239            symlink_path.unlink()
240
241        if plain_path.exists() and not symlink_path.exists():
242            symlink_path.symlink_to(plain_path)
243
244    def install_agent(self) -> None:
245        """Install AI agent skills and hooks."""
246        try:
247            result = subprocess.run(
248                [sys.executable, "-m", "plain", "agent", "install"],
249                check=False,
250                capture_output=True,
251                text=True,
252            )
253            if result.returncode != 0 and result.stderr:
254                click.secho(
255                    f"Warning: Failed to install agent: {result.stderr}",
256                    fg="yellow",
257                    err=True,
258                )
259        except Exception as e:
260            click.secho(
261                f"Warning: Failed to install agent: {e}",
262                fg="yellow",
263                err=True,
264            )
265
266    def modify_hosts_file(self) -> None:
267        """Modify the hosts file to map the custom domain to 127.0.0.1."""
268        # Check if the hostname already resolves to loopback (e.g., *.localhost on modern OS)
269        try:
270            results = socket.getaddrinfo(self.hostname, None)
271            addrs = {r[4][0] for r in results}
272            if addrs <= {"127.0.0.1", "::1"}:
273                return
274        except socket.gaierror:
275            pass  # Doesn't resolve; fall through to modify hosts file
276
277        entry_identifier = "# Added by plain"
278        hosts_entry = f"127.0.0.1 {self.hostname}  {entry_identifier}"
279
280        if platform.system() == "Windows":
281            hosts_path = Path(r"C:\Windows\System32\drivers\etc\hosts")
282            try:
283                with hosts_path.open("r") as f:
284                    content = f.read()
285
286                if hosts_entry in content:
287                    return  # Entry already exists; no action needed
288
289                # Entry does not exist; add it
290                with hosts_path.open("a") as f:
291                    f.write(f"{hosts_entry}\n")
292                click.secho(f"Added {self.hostname} to {hosts_path}", bold=True)
293            except PermissionError:
294                click.secho(
295                    "Permission denied while modifying hosts file. Please run the script as an administrator.",
296                    fg="red",
297                )
298                sys.exit(1)
299        else:
300            # For macOS and Linux
301            hosts_path = Path("/etc/hosts")
302            try:
303                with hosts_path.open("r") as f:
304                    content = f.read()
305
306                if hosts_entry in content:
307                    return  # Entry already exists; no action needed
308
309                # Entry does not exist; append it using sudo
310                click.secho(
311                    f"Adding {self.hostname} to /etc/hosts file. You may be prompted for your password.\n",
312                    bold=True,
313                )
314                cmd = f"echo '{hosts_entry}' | sudo tee -a {hosts_path} >/dev/null"
315                subprocess.run(cmd, shell=True, check=True)
316                click.secho(f"Added {self.hostname} to {hosts_path}\n", bold=True)
317            except PermissionError:
318                click.secho(
319                    "Permission denied while accessing hosts file.",
320                    fg="red",
321                )
322                sys.exit(1)
323            except subprocess.CalledProcessError:
324                click.secho(
325                    "Failed to modify hosts file. Please ensure you have sudo privileges.",
326                    fg="red",
327                )
328                sys.exit(1)
329
330    def run_preflight(self) -> None:
331        if subprocess.run(
332            ["plain", "preflight", "--quiet"], env=self.plain_env
333        ).returncode:
334            click.secho("Preflight check failed!", fg="red")
335            sys.exit(1)
336
337    def add_server(self) -> None:
338        """Add the Plain HTTP server process."""
339        assert self.poncho is not None
340        server_cmd = [
341            sys.executable,
342            "-m",
343            "plain",
344            "server",
345            "--bind",
346            f"{self.hostname}:{self.port}",
347            "--certfile",
348            str(self.ssl_cert_path),
349            "--keyfile",
350            str(self.ssl_key_path),
351            "--threads",
352            "4",
353            "--timeout",
354            "60",
355            "--workers",
356            "1",
357            "--reload",  # Enable auto-reload for development
358        ]
359
360        server = " ".join(server_cmd)
361        self.poncho.add_process("plain", server, env=self.plain_env)
362
363    def add_entrypoints(self) -> None:
364        assert self.poncho is not None
365        for entry_point in entry_points().select(group=ENTRYPOINT_GROUP):
366            self.poncho.add_process(
367                entry_point.name,
368                f"plain dev entrypoint {entry_point.name}",
369                env=self.plain_env,
370            )
371
372    def add_pyproject_run(self) -> None:
373        """Additional processes that only run during `plain dev`."""
374        assert self.poncho is not None
375        if not has_pyproject_toml(APP_PATH.parent):
376            return
377
378        with open(Path(APP_PATH.parent, "pyproject.toml"), "rb") as f:
379            pyproject = tomllib.load(f)
380
381        run_commands = (
382            pyproject.get("tool", {}).get("plain", {}).get("dev", {}).get("run", {})
383        )
384        for name, data in run_commands.items():
385            env = {
386                **self.custom_process_env,
387                **data.get("env", {}),
388            }
389            self.poncho.add_process(name, data["cmd"], env=env)