Plain is headed towards 1.0! Subscribe for development updates →

import importlib
import json
import multiprocessing
import os
import platform
import shlex
import shutil
import signal
import subprocess
import sys
import time
import tomllib
from importlib.metadata import entry_points
from importlib.util import find_spec
from pathlib import Path
from urllib.parse import urlsplit

import click
from rich.columns import Columns
from rich.console import Console
from rich.text import Text

from plain.cli import register_cli
from plain.runtime import APP_PATH, settings

from .mkcert import MkcertManager
from .poncho.manager import Manager as PonchoManager
from .poncho.printer import Printer
from .services import Services, ServicesPid
from .utils import has_pyproject_toml
 28
# Entry-point group scanned for extra dev processes. `plain dev` launches each
# registered entry point as `plain dev entrypoint <name>` next to the server,
# and `plain dev entrypoint --list` prints their names.
ENTRYPOINT_GROUP = "plain.dev"
 30
 31
 32@register_cli("dev")
 33@click.group(invoke_without_command=True)
 34@click.pass_context
 35@click.option(
 36    "--port",
 37    "-p",
 38    default=8443,
 39    type=int,
 40    help="Port to run the web server on",
 41)
 42@click.option(
 43    "--hostname",
 44    "-h",
 45    default=None,
 46    type=str,
 47    help="Hostname to run the web server on",
 48)
 49@click.option(
 50    "--log-level",
 51    "-l",
 52    default="",
 53    type=click.Choice(["debug", "info", "warning", "error", "critical", ""]),
 54    help="Log level",
 55)
 56def cli(ctx, port, hostname, log_level):
 57    """Start local development"""
 58
 59    if ctx.invoked_subcommand:
 60        return
 61
 62    if not hostname:
 63        project_name = os.path.basename(
 64            os.getcwd()
 65        )  # Use the directory name by default
 66
 67        if has_pyproject_toml(APP_PATH.parent):
 68            with open(Path(APP_PATH.parent, "pyproject.toml"), "rb") as f:
 69                pyproject = tomllib.load(f)
 70                project_name = pyproject.get("project", {}).get("name", project_name)
 71
 72        hostname = f"{project_name}.localhost"
 73
 74    returncode = Dev(port=port, hostname=hostname, log_level=log_level).run()
 75    if returncode:
 76        sys.exit(returncode)
 77
 78
 79@cli.command()
 80def debug():
 81    """Connect to the remote debugger"""
 82
 83    def _connect():
 84        if subprocess.run(["which", "nc"], capture_output=True).returncode == 0:
 85            return subprocess.run(["nc", "-C", "localhost", "4444"])
 86        else:
 87            raise OSError("nc not found")
 88
 89    result = _connect()
 90
 91    # Try again once without a message
 92    if result.returncode == 1:
 93        time.sleep(1)
 94        result = _connect()
 95
 96    # Keep trying...
 97    while result.returncode == 1:
 98        click.secho(
 99            "Failed to connect. Make sure remote pdb is ready. Retrying...", fg="red"
100        )
101        result = _connect()
102        time.sleep(1)
103
104
@cli.command()
def services():
    """Start additional services defined in pyproject.toml"""
    manager = Services()
    # Only start them when they aren't already up (e.g. from another `plain dev`)
    if not manager.are_running():
        manager.run()
    else:
        click.secho("Services already running", fg="yellow")
113
114
@cli.command()
@click.option(
    "--list", "-l", "show_list", is_flag=True, help="List available entrypoints"
)
@click.argument("entrypoint", required=False)
def entrypoint(show_list, entrypoint):
    """Entrypoints registered under plain.dev"""
    if not show_list and not entrypoint:
        click.secho("Please provide an entrypoint name or use --list", fg="red")
        sys.exit(1)

    registered = entry_points().select(group=ENTRYPOINT_GROUP)

    # --list takes precedence over a name, as before
    if show_list:
        for entry_point in registered:
            click.echo(entry_point.name)
        return

    # Run every entry point with a matching name (normally exactly one)
    matches = [ep for ep in registered if ep.name == entrypoint]
    if not matches:
        # Previously an unknown name silently did nothing and exited 0
        click.secho(
            f"Entrypoint {entrypoint!r} not found in {ENTRYPOINT_GROUP}", fg="red"
        )
        sys.exit(1)

    for entry_point in matches:
        entry_point.load()()
131
132
133class Dev:
134    def __init__(self, *, port, hostname, log_level):
135        self.port = port
136        self.hostname = hostname
137        self.log_level = log_level
138
139        self.ssl_key_path = None
140        self.ssl_cert_path = None
141
142        self.url = f"https://{self.hostname}:{self.port}"
143        self.tunnel_url = os.environ.get("PLAIN_DEV_TUNNEL_URL", "")
144
145        self.plain_env = {
146            "PYTHONUNBUFFERED": "true",
147            "PLAIN_DEV": "true",
148            **os.environ,
149        }
150
151        if log_level:
152            self.plain_env["PLAIN_LOG_LEVEL"] = log_level.upper()
153            self.plain_env["APP_LOG_LEVEL"] = log_level.upper()
154
155        self.custom_process_env = {
156            **self.plain_env,
157            "PORT": str(self.port),
158            "PLAIN_DEV_URL": self.url,
159        }
160
161        if self.tunnel_url:
162            status_bar = Columns(
163                [
164                    Text.from_markup(
165                        f"[bold]Tunnel[/bold] [underline][link={self.tunnel_url}]{self.tunnel_url}[/link][/underline]"
166                    ),
167                    Text.from_markup(
168                        f"[dim][bold]Server[/bold] [link={self.url}]{self.url}[/link][/dim]"
169                    ),
170                    Text.from_markup(
171                        "[dim][bold]Ctrl+C[/bold] to stop[/dim]",
172                        justify="right",
173                    ),
174                ],
175                expand=True,
176            )
177        else:
178            status_bar = Columns(
179                [
180                    Text.from_markup(
181                        f"[bold]Server[/bold] [underline][link={self.url}]{self.url}[/link][/underline]"
182                    ),
183                    Text.from_markup(
184                        "[dim][bold]Ctrl+C[/bold] to stop[/dim]", justify="right"
185                    ),
186                ],
187                expand=True,
188            )
189        self.console = Console(markup=False, highlight=False)
190        self.console_status = self.console.status(status_bar)
191
192        self.poncho = PonchoManager(printer=Printer(lambda s: self.console.out(s)))
193
194    def run(self):
195        mkcert_manager = MkcertManager()
196        mkcert_manager.setup_mkcert(install_path=Path.home() / ".plain" / "dev")
197        self.ssl_cert_path, self.ssl_key_path = mkcert_manager.generate_certs(
198            domain=self.hostname,
199            storage_path=Path(settings.PLAIN_TEMP_PATH) / "dev" / "certs",
200        )
201
202        self.symlink_plain_src()
203        self.modify_hosts_file()
204        self.set_allowed_hosts()
205        self.run_preflight()
206
207        # If we start services ourselves, we should manage the pidfile
208        services_pid = None
209
210        # Services start first (or are already running from a separate command)
211        if Services.are_running():
212            click.secho("Services already running", fg="yellow")
213        elif services := Services.get_services(APP_PATH.parent):
214            click.secho("\nStarting services...", italic=True, dim=True)
215            services_pid = ServicesPid()
216            services_pid.write()
217
218            for name, data in services.items():
219                env = {
220                    **os.environ,
221                    "PYTHONUNBUFFERED": "true",
222                    **data.get("env", {}),
223                }
224                self.poncho.add_process(name, data["cmd"], env=env)
225
226        # If plain.models is installed (common) then we
227        # will do a couple extra things before starting all of the app-related
228        # processes (this way they don't all have to db-wait or anything)
229        process = None
230        if find_spec("plain.models") is not None:
231            # Use a custom signal to tell the main thread to add
232            # the app processes once the db is ready
233            signal.signal(signal.SIGUSR1, self.start_app)
234
235            process = multiprocessing.Process(
236                target=_process_task, args=(self.plain_env,)
237            )
238            process.start()
239        else:
240            # Start the app processes immediately
241            self.start_app(None, None)
242
243        try:
244            # Start processes we know about and block the main thread
245            self.poncho.loop()
246
247            # Remove the status bar
248            self.console_status.stop()
249        finally:
250            # Make sure the services pid gets removed if we set it
251            if services_pid:
252                services_pid.rm()
253
254            # Make sure the process is terminated if it is still running
255            if process and process.is_alive():
256                os.killpg(os.getpgid(process.pid), signal.SIGTERM)
257                process.join(timeout=3)
258                if process.is_alive():
259                    os.killpg(os.getpgid(process.pid), signal.SIGKILL)
260                    process.join()
261
262        return self.poncho.returncode
263
264    def start_app(self, signum, frame):
265        # This runs in the main thread when SIGUSR1 is received
266        # (or called directly if no thread).
267        click.secho("\nStarting app...", italic=True, dim=True)
268
269        # Manually start the status bar now so it isn't bungled by
270        # another thread checking db stuff...
271        self.console_status.start()
272
273        self.add_gunicorn()
274        self.add_entrypoints()
275        self.add_pyproject_run()
276
277    def symlink_plain_src(self):
278        """Symlink the plain package into .plain so we can look at it easily"""
279        plain_path = Path(
280            importlib.util.find_spec("plain.runtime").origin
281        ).parent.parent
282        if not settings.PLAIN_TEMP_PATH.exists():
283            settings.PLAIN_TEMP_PATH.mkdir()
284
285        symlink_path = settings.PLAIN_TEMP_PATH / "src"
286
287        # The symlink is broken
288        if symlink_path.is_symlink() and not symlink_path.exists():
289            symlink_path.unlink()
290
291        # The symlink exists but points to the wrong place
292        if (
293            symlink_path.is_symlink()
294            and symlink_path.exists()
295            and symlink_path.resolve() != plain_path
296        ):
297            symlink_path.unlink()
298
299        if plain_path.exists() and not symlink_path.exists():
300            symlink_path.symlink_to(plain_path)
301
302    def modify_hosts_file(self):
303        """Modify the hosts file to map the custom domain to 127.0.0.1."""
304        entry_identifier = "# Added by plain"
305        hosts_entry = f"127.0.0.1 {self.hostname}  {entry_identifier}"
306
307        if platform.system() == "Windows":
308            hosts_path = Path(r"C:\Windows\System32\drivers\etc\hosts")
309            try:
310                with hosts_path.open("r") as f:
311                    content = f.read()
312
313                if hosts_entry in content:
314                    return  # Entry already exists; no action needed
315
316                # Entry does not exist; add it
317                with hosts_path.open("a") as f:
318                    f.write(f"{hosts_entry}\n")
319                click.secho(f"Added {self.hostname} to {hosts_path}", bold=True)
320            except PermissionError:
321                click.secho(
322                    "Permission denied while modifying hosts file. Please run the script as an administrator.",
323                    fg="red",
324                )
325                sys.exit(1)
326        else:
327            # For macOS and Linux
328            hosts_path = Path("/etc/hosts")
329            try:
330                with hosts_path.open("r") as f:
331                    content = f.read()
332
333                if hosts_entry in content:
334                    return  # Entry already exists; no action needed
335
336                # Entry does not exist; append it using sudo
337                click.secho(
338                    f"Adding {self.hostname} to /etc/hosts file. You may be prompted for your password.\n",
339                    bold=True,
340                )
341                cmd = f"echo '{hosts_entry}' | sudo tee -a {hosts_path} >/dev/null"
342                subprocess.run(cmd, shell=True, check=True)
343                click.secho(f"Added {self.hostname} to {hosts_path}\n", bold=True)
344            except PermissionError:
345                click.secho(
346                    "Permission denied while accessing hosts file.",
347                    fg="red",
348                )
349                sys.exit(1)
350            except subprocess.CalledProcessError:
351                click.secho(
352                    "Failed to modify hosts file. Please ensure you have sudo privileges.",
353                    fg="red",
354                )
355                sys.exit(1)
356
357    def set_allowed_hosts(self):
358        if "PLAIN_ALLOWED_HOSTS" not in os.environ:
359            hostnames = [self.hostname]
360            if self.tunnel_url:
361                # Add the tunnel URL to the allowed hosts
362                hostnames.append(self.tunnel_url.split("://")[1])
363            allowed_hosts = json.dumps(hostnames)
364            self.plain_env["PLAIN_ALLOWED_HOSTS"] = allowed_hosts
365            self.custom_process_env["PLAIN_ALLOWED_HOSTS"] = allowed_hosts
366            click.secho(
367                f"Automatically set PLAIN_ALLOWED_HOSTS={allowed_hosts}", dim=True
368            )
369
370    def run_preflight(self):
371        click.echo()
372        if subprocess.run(["plain", "preflight"], env=self.plain_env).returncode:
373            click.secho("Preflight check failed!", fg="red")
374            sys.exit(1)
375
376    def add_gunicorn(self):
377        # Watch .env files for reload
378        extra_watch_files = []
379        for f in os.listdir(APP_PATH.parent):
380            if f.startswith(".env"):
381                # Needs to be absolute or "./" for inotify to work on Linux...
382                # https://github.com/dropseed/plain/issues/26
383                extra_watch_files.append(str(Path(APP_PATH.parent) / f))
384
385        reload_extra = " ".join(f"--reload-extra-file {f}" for f in extra_watch_files)
386        gunicorn_cmd = [
387            "gunicorn",
388            "--bind",
389            f"{self.hostname}:{self.port}",
390            "--certfile",
391            str(self.ssl_cert_path),
392            "--keyfile",
393            str(self.ssl_key_path),
394            "--threads",
395            "4",
396            "--reload",
397            "plain.wsgi:app",
398            "--timeout",
399            "60",
400            "--log-level",
401            self.log_level or "info",
402            "--access-logfile",
403            "-",
404            "--error-logfile",
405            "-",
406            *reload_extra.split(),
407            "--access-logformat",
408            "'\"%(r)s\" status=%(s)s length=%(b)s time=%(M)sms'",
409            "--log-config-json",
410            str(Path(__file__).parent / "gunicorn_logging.json"),
411        ]
412        gunicorn = " ".join(gunicorn_cmd)
413
414        self.poncho.add_process("plain", gunicorn, env=self.plain_env)
415
416    def add_entrypoints(self):
417        for entry_point in entry_points().select(group=ENTRYPOINT_GROUP):
418            self.poncho.add_process(
419                entry_point.name,
420                f"plain dev entrypoint {entry_point.name}",
421                env=self.plain_env,
422            )
423
424    def add_pyproject_run(self):
425        """Additional processes that only run during `plain dev`."""
426        if not has_pyproject_toml(APP_PATH.parent):
427            return
428
429        with open(Path(APP_PATH.parent, "pyproject.toml"), "rb") as f:
430            pyproject = tomllib.load(f)
431
432        run_commands = (
433            pyproject.get("tool", {}).get("plain", {}).get("dev", {}).get("run", {})
434        )
435        for name, data in run_commands.items():
436            env = {
437                **self.custom_process_env,
438                **data.get("env", {}),
439            }
440            self.poncho.add_process(name, data["cmd"], env=env)
441
442
def _process_task(env):
    """Wait for the database, run migrations, then tell the parent to start the app.

    Runs in the child process started by ``Dev.run``. On success it sends
    SIGUSR1 to the parent, whose handler (``Dev.start_app``) adds the app
    processes. If either command fails, it prints an error and exits with
    that command's return code, so the app is never started.
    """
    # Make this process the leader of a new group which can be killed together if it doesn't finish
    os.setsid()

    # Remember who started us. If that process dies we are re-parented (to init
    # or a subreaper), and SIGUSR1's default action would terminate the new parent.
    parent_pid = os.getppid()

    try:
        subprocess.run(["plain", "models", "db-wait"], env=env, check=True)
        subprocess.run(["plain", "migrate", "--backup"], env=env, check=True)
    except subprocess.CalledProcessError as e:
        # Explain why the app never starts instead of dumping a bare traceback
        click.secho(
            f"{' '.join(e.cmd)} failed (exit code {e.returncode}); not starting the app.",
            fg="red",
        )
        sys.exit(e.returncode)

    # preflight with db?

    # The original parent is gone; there is nobody sensible to signal
    if os.getppid() != parent_pid:
        return

    # Send SIGUSR1 to the parent process so the parent's handler is invoked
    os.kill(parent_pid, signal.SIGUSR1)