Plain is headed towards 1.0! Subscribe for development updates →

import contextlib
import importlib
import json
import multiprocessing
import os
import platform
import shlex
import shutil
import signal
import socket
import subprocess
import sys
import time
import tomllib
from importlib.metadata import entry_points
from importlib.util import find_spec
from pathlib import Path
from urllib.parse import urlsplit

import click
from rich.columns import Columns
from rich.console import Console
from rich.text import Text

from plain.cli import register_cli
from plain.runtime import APP_PATH, PLAIN_TEMP_PATH

from .dev_pid import DevPid
from .mkcert import MkcertManager
from .poncho.manager import Manager as PonchoManager
from .poncho.printer import Printer
from .services import Services, ServicesPid
from .utils import has_pyproject_toml
 30
 31ENTRYPOINT_GROUP = "plain.dev"
 32
 33
 34@register_cli("dev")
 35@click.group(invoke_without_command=True)
 36@click.pass_context
 37@click.option(
 38    "--port",
 39    "-p",
 40    default="",
 41    type=str,
 42    help=(
 43        "Port to run the web server on. If omitted, tries 8443 and "
 44        "picks the next free port"
 45    ),
 46)
 47@click.option(
 48    "--hostname",
 49    "-h",
 50    default=None,
 51    type=str,
 52    help="Hostname to run the web server on",
 53)
 54@click.option(
 55    "--log-level",
 56    "-l",
 57    default="",
 58    type=click.Choice(["debug", "info", "warning", "error", "critical", ""]),
 59    help="Log level",
 60)
 61def cli(ctx, port, hostname, log_level):
 62    """Start local development"""
 63
 64    if ctx.invoked_subcommand:
 65        return
 66
 67    if DevPid().exists():
 68        click.secho("`plain dev` already running", fg="yellow")
 69        sys.exit(1)
 70
 71    if not hostname:
 72        project_name = os.path.basename(
 73            os.getcwd()
 74        )  # Use the directory name by default
 75
 76        if has_pyproject_toml(APP_PATH.parent):
 77            with open(Path(APP_PATH.parent, "pyproject.toml"), "rb") as f:
 78                pyproject = tomllib.load(f)
 79                project_name = pyproject.get("project", {}).get("name", project_name)
 80
 81        hostname = f"{project_name}.localhost"
 82
 83    returncode = Dev(port=port, hostname=hostname, log_level=log_level).run()
 84    if returncode:
 85        sys.exit(returncode)
 86
 87
 88@cli.command()
 89def debug():
 90    """Connect to the remote debugger"""
 91
 92    def _connect():
 93        if subprocess.run(["which", "nc"], capture_output=True).returncode == 0:
 94            return subprocess.run(["nc", "-C", "localhost", "4444"])
 95        else:
 96            raise OSError("nc not found")
 97
 98    result = _connect()
 99
100    # Try again once without a message
101    if result.returncode == 1:
102        time.sleep(1)
103        result = _connect()
104
105    # Keep trying...
106    while result.returncode == 1:
107        click.secho(
108            "Failed to connect. Make sure remote pdb is ready. Retrying...", fg="red"
109        )
110        result = _connect()
111        time.sleep(1)
112
113
@cli.command()
def services():
    """Start additional services defined in pyproject.toml"""
    # Avoid starting a second copy when another command already runs them.
    manager = Services()
    if manager.are_running():
        click.secho("Services already running", fg="yellow")
    else:
        manager.run()
122
123
@cli.command()
@click.option(
    "--list", "-l", "show_list", is_flag=True, help="List available entrypoints"
)
@click.argument("entrypoint", required=False)
def entrypoint(show_list, entrypoint):
    """Entrypoints registered under plain.dev"""
    if not show_list and not entrypoint:
        click.secho("Please provide an entrypoint name or use --list", fg="red")
        sys.exit(1)

    # Tracks whether any registered entry point matched the requested name
    # (several distributions may register the same name; all are run).
    matched = False
    for entry_point in entry_points().select(group=ENTRYPOINT_GROUP):
        if show_list:
            click.echo(entry_point.name)
        elif entrypoint == entry_point.name:
            matched = True
            entry_point.load()()

    # Report unknown names instead of silently succeeding, so typos show up.
    if not show_list and not matched:
        click.secho(
            f"No entrypoint named {entrypoint!r} in {ENTRYPOINT_GROUP}", fg="red"
        )
        sys.exit(1)
140
141
142class Dev:
143    def __init__(self, *, port, hostname, log_level):
144        self.hostname = hostname
145        self.log_level = log_level
146
147        self.pid = DevPid()
148
149        if port:
150            self.port = int(port)
151            if not self._port_available(self.port):
152                click.secho(f"Port {self.port} in use", fg="red")
153                raise SystemExit(1)
154        else:
155            self.port = self._find_open_port(8443)
156            if self.port != 8443:
157                click.secho(f"Port 8443 in use, using {self.port}", fg="yellow")
158
159        self.ssl_key_path = None
160        self.ssl_cert_path = None
161
162        self.url = f"https://{self.hostname}:{self.port}"
163        self.tunnel_url = os.environ.get("PLAIN_DEV_TUNNEL_URL", "")
164
165        self.plain_env = {
166            "PYTHONUNBUFFERED": "true",
167            "PLAIN_DEV": "true",
168            **os.environ,
169        }
170
171        if log_level:
172            self.plain_env["PLAIN_LOG_LEVEL"] = log_level.upper()
173            self.plain_env["APP_LOG_LEVEL"] = log_level.upper()
174
175        self.custom_process_env = {
176            **self.plain_env,
177            "PORT": str(self.port),
178            "PLAIN_DEV_URL": self.url,
179        }
180
181        if self.tunnel_url:
182            status_bar = Columns(
183                [
184                    Text.from_markup(
185                        f"[bold]Tunnel[/bold] [underline][link={self.tunnel_url}]{self.tunnel_url}[/link][/underline]"
186                    ),
187                    Text.from_markup(
188                        f"[dim][bold]Server[/bold] [link={self.url}]{self.url}[/link][/dim]"
189                    ),
190                    Text.from_markup(
191                        "[dim][bold]Ctrl+C[/bold] to stop[/dim]",
192                        justify="right",
193                    ),
194                ],
195                expand=True,
196            )
197        else:
198            status_bar = Columns(
199                [
200                    Text.from_markup(
201                        f"[bold]Server[/bold] [underline][link={self.url}]{self.url}[/link][/underline]"
202                    ),
203                    Text.from_markup(
204                        "[dim][bold]Ctrl+C[/bold] to stop[/dim]", justify="right"
205                    ),
206                ],
207                expand=True,
208            )
209        self.console = Console(markup=False, highlight=False)
210        self.console_status = self.console.status(status_bar)
211
212        self.poncho = PonchoManager(printer=Printer(lambda s: self.console.out(s)))
213
214    def _find_open_port(self, start_port):
215        port = start_port
216        while not self._port_available(port):
217            port += 1
218        return port
219
220    def _port_available(self, port):
221        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
222            sock.settimeout(0.5)
223            result = sock.connect_ex(("127.0.0.1", port))
224        return result != 0
225
226    def run(self):
227        self.pid.write()
228        mkcert_manager = MkcertManager()
229        mkcert_manager.setup_mkcert(install_path=Path.home() / ".plain" / "dev")
230        self.ssl_cert_path, self.ssl_key_path = mkcert_manager.generate_certs(
231            domain=self.hostname,
232            storage_path=Path(PLAIN_TEMP_PATH) / "dev" / "certs",
233        )
234
235        self.symlink_plain_src()
236        self.modify_hosts_file()
237        self.set_allowed_hosts()
238        self.run_preflight()
239
240        # If we start services ourselves, we should manage the pidfile
241        services_pid = None
242
243        # Services start first (or are already running from a separate command)
244        if Services.are_running():
245            click.secho("Services already running", fg="yellow")
246        elif services := Services.get_services(APP_PATH.parent):
247            click.secho("\nStarting services...", italic=True, dim=True)
248            services_pid = ServicesPid()
249            services_pid.write()
250
251            for name, data in services.items():
252                env = {
253                    **os.environ,
254                    "PYTHONUNBUFFERED": "true",
255                    **data.get("env", {}),
256                }
257                self.poncho.add_process(name, data["cmd"], env=env)
258
259        # If plain.models is installed (common) then we
260        # will do a couple extra things before starting all of the app-related
261        # processes (this way they don't all have to db-wait or anything)
262        process = None
263        if find_spec("plain.models") is not None:
264            # Use a custom signal to tell the main thread to add
265            # the app processes once the db is ready
266            signal.signal(signal.SIGUSR1, self.start_app)
267
268            process = multiprocessing.Process(
269                target=_process_task, args=(self.plain_env,)
270            )
271            process.start()
272
273            # If there are no poncho processes, then let this process finish before
274            # continuing (vs running in parallel)
275            if self.poncho.num_processes() == 0:
276                # Wait for the process to finish
277                process.join()
278        else:
279            # Start the app processes immediately
280            self.start_app(None, None)
281
282        try:
283            # Start processes we know about and block the main thread
284            self.poncho.loop()
285
286            # Remove the status bar
287            self.console_status.stop()
288        finally:
289            self.pid.rm()
290            # Make sure the services pid gets removed if we set it
291            if services_pid:
292                services_pid.rm()
293
294            # Make sure the process is terminated if it is still running
295            if process and process.is_alive():
296                os.killpg(os.getpgid(process.pid), signal.SIGTERM)
297                process.join(timeout=3)
298                if process.is_alive():
299                    os.killpg(os.getpgid(process.pid), signal.SIGKILL)
300                    process.join()
301
302        return self.poncho.returncode
303
304    def start_app(self, signum, frame):
305        # This runs in the main thread when SIGUSR1 is received
306        # (or called directly if no thread).
307        click.secho("\nStarting app...", italic=True, dim=True)
308
309        # Manually start the status bar now so it isn't bungled by
310        # another thread checking db stuff...
311        self.console_status.start()
312
313        self.add_gunicorn()
314        self.add_entrypoints()
315        self.add_pyproject_run()
316
317    def symlink_plain_src(self):
318        """Symlink the plain package into .plain so we can look at it easily"""
319        plain_path = Path(
320            importlib.util.find_spec("plain.runtime").origin
321        ).parent.parent
322        if not PLAIN_TEMP_PATH.exists():
323            PLAIN_TEMP_PATH.mkdir()
324
325        symlink_path = PLAIN_TEMP_PATH / "src"
326
327        # The symlink is broken
328        if symlink_path.is_symlink() and not symlink_path.exists():
329            symlink_path.unlink()
330
331        # The symlink exists but points to the wrong place
332        if (
333            symlink_path.is_symlink()
334            and symlink_path.exists()
335            and symlink_path.resolve() != plain_path
336        ):
337            symlink_path.unlink()
338
339        if plain_path.exists() and not symlink_path.exists():
340            symlink_path.symlink_to(plain_path)
341
342    def modify_hosts_file(self):
343        """Modify the hosts file to map the custom domain to 127.0.0.1."""
344        entry_identifier = "# Added by plain"
345        hosts_entry = f"127.0.0.1 {self.hostname}  {entry_identifier}"
346
347        if platform.system() == "Windows":
348            hosts_path = Path(r"C:\Windows\System32\drivers\etc\hosts")
349            try:
350                with hosts_path.open("r") as f:
351                    content = f.read()
352
353                if hosts_entry in content:
354                    return  # Entry already exists; no action needed
355
356                # Entry does not exist; add it
357                with hosts_path.open("a") as f:
358                    f.write(f"{hosts_entry}\n")
359                click.secho(f"Added {self.hostname} to {hosts_path}", bold=True)
360            except PermissionError:
361                click.secho(
362                    "Permission denied while modifying hosts file. Please run the script as an administrator.",
363                    fg="red",
364                )
365                sys.exit(1)
366        else:
367            # For macOS and Linux
368            hosts_path = Path("/etc/hosts")
369            try:
370                with hosts_path.open("r") as f:
371                    content = f.read()
372
373                if hosts_entry in content:
374                    return  # Entry already exists; no action needed
375
376                # Entry does not exist; append it using sudo
377                click.secho(
378                    f"Adding {self.hostname} to /etc/hosts file. You may be prompted for your password.\n",
379                    bold=True,
380                )
381                cmd = f"echo '{hosts_entry}' | sudo tee -a {hosts_path} >/dev/null"
382                subprocess.run(cmd, shell=True, check=True)
383                click.secho(f"Added {self.hostname} to {hosts_path}\n", bold=True)
384            except PermissionError:
385                click.secho(
386                    "Permission denied while accessing hosts file.",
387                    fg="red",
388                )
389                sys.exit(1)
390            except subprocess.CalledProcessError:
391                click.secho(
392                    "Failed to modify hosts file. Please ensure you have sudo privileges.",
393                    fg="red",
394                )
395                sys.exit(1)
396
397    def set_allowed_hosts(self):
398        if "PLAIN_ALLOWED_HOSTS" not in os.environ:
399            hostnames = [self.hostname]
400            if self.tunnel_url:
401                # Add the tunnel URL to the allowed hosts
402                hostnames.append(self.tunnel_url.split("://")[1])
403            allowed_hosts = json.dumps(hostnames)
404            self.plain_env["PLAIN_ALLOWED_HOSTS"] = allowed_hosts
405            self.custom_process_env["PLAIN_ALLOWED_HOSTS"] = allowed_hosts
406            click.secho(
407                f"Automatically set PLAIN_ALLOWED_HOSTS={allowed_hosts}", dim=True
408            )
409
410    def run_preflight(self):
411        click.echo()
412        if subprocess.run(["plain", "preflight"], env=self.plain_env).returncode:
413            click.secho("Preflight check failed!", fg="red")
414            sys.exit(1)
415
416    def add_gunicorn(self):
417        # Watch .env files for reload
418        extra_watch_files = []
419        for f in os.listdir(APP_PATH.parent):
420            if f.startswith(".env"):
421                # Needs to be absolute or "./" for inotify to work on Linux...
422                # https://github.com/dropseed/plain/issues/26
423                extra_watch_files.append(str(Path(APP_PATH.parent) / f))
424
425        reload_extra = " ".join(f"--reload-extra-file {f}" for f in extra_watch_files)
426        gunicorn_cmd = [
427            "gunicorn",
428            "--bind",
429            f"{self.hostname}:{self.port}",
430            "--certfile",
431            str(self.ssl_cert_path),
432            "--keyfile",
433            str(self.ssl_key_path),
434            "--threads",
435            "4",
436            "--reload",
437            "plain.wsgi:app",
438            "--timeout",
439            "60",
440            "--log-level",
441            self.log_level or "info",
442            "--access-logfile",
443            "-",
444            "--error-logfile",
445            "-",
446            *reload_extra.split(),
447            "--access-logformat",
448            "'\"%(r)s\" status=%(s)s length=%(b)s time=%(M)sms'",
449            "--log-config-json",
450            str(Path(__file__).parent / "gunicorn_logging.json"),
451        ]
452        gunicorn = " ".join(gunicorn_cmd)
453
454        self.poncho.add_process("plain", gunicorn, env=self.plain_env)
455
456    def add_entrypoints(self):
457        for entry_point in entry_points().select(group=ENTRYPOINT_GROUP):
458            self.poncho.add_process(
459                entry_point.name,
460                f"plain dev entrypoint {entry_point.name}",
461                env=self.plain_env,
462            )
463
464    def add_pyproject_run(self):
465        """Additional processes that only run during `plain dev`."""
466        if not has_pyproject_toml(APP_PATH.parent):
467            return
468
469        with open(Path(APP_PATH.parent, "pyproject.toml"), "rb") as f:
470            pyproject = tomllib.load(f)
471
472        run_commands = (
473            pyproject.get("tool", {}).get("plain", {}).get("dev", {}).get("run", {})
474        )
475        for name, data in run_commands.items():
476            env = {
477                **self.custom_process_env,
478                **data.get("env", {}),
479            }
480            self.poncho.add_process(name, data["cmd"], env=env)
481
482
def _process_task(env):
    """Wait for the database and migrate it, then signal the parent process.

    Intended as the target of the multiprocessing child started by
    ``Dev.run``. On success it sends SIGUSR1 to the parent, whose handler
    registers the app processes. If either command fails, a message is
    printed and the child exits with a positive status (the command's return
    code, or 1 if it was killed by a signal) instead of dying with a
    CalledProcessError traceback; SIGUSR1 is not sent in that case.

    Args:
        env: Environment mapping passed to the ``plain`` subprocesses.
    """
    # Make this process the leader of a new group which can be killed together if it doesn't finish
    os.setsid()

    for cmd in (
        ["plain", "models", "db-wait"],
        ["plain", "migrate", "--backup"],
    ):
        completed = subprocess.run(cmd, env=env)
        if completed.returncode:
            command_text = " ".join(cmd)
            click.secho(f"`{command_text}` failed, not starting the app", fg="red")
            sys.exit(completed.returncode if completed.returncode > 0 else 1)

    # TODO: consider running preflight checks that need the database here

    # Send SIGUSR1 to the parent process so the parent's handler is invoked
    os.kill(os.getppid(), signal.SIGUSR1)