Plain is headed towards 1.0! Subscribe for development updates →

  1import json
  2import multiprocessing
  3import os
  4import platform
  5import signal
  6import socket
  7import subprocess
  8import sys
  9import tomllib
 10from importlib.metadata import entry_points
 11from importlib.util import find_spec
 12from pathlib import Path
 13
 14import click
 15from rich.columns import Columns
 16from rich.console import Console
 17from rich.text import Text
 18
 19from plain.runtime import APP_PATH, PLAIN_TEMP_PATH
 20
 21from .mkcert import MkcertManager
 22from .process import ProcessManager
 23from .services import ServicesProcess
 24from .utils import has_pyproject_toml
 25
# Entry-point group that installed packages can register under to add their
# own processes to `plain dev`; each entry point found in this group is run
# as `plain dev entrypoint <name>` (see DevProcess.add_entrypoints).
ENTRYPOINT_GROUP = "plain.dev"
 27
 28
 29class DevProcess(ProcessManager):
 30    pidfile = PLAIN_TEMP_PATH / "dev" / "dev.pid"
 31    log_dir = PLAIN_TEMP_PATH / "dev" / "logs" / "run"
 32
 33    def __init__(self, *, port, hostname, log_level):
 34        super().__init__()
 35
 36        if not hostname:
 37            project_name = os.path.basename(
 38                os.getcwd()
 39            )  # Use directory name by default
 40
 41            if has_pyproject_toml(APP_PATH.parent):
 42                with open(Path(APP_PATH.parent, "pyproject.toml"), "rb") as f:
 43                    pyproject = tomllib.load(f)
 44                    project_name = pyproject.get("project", {}).get(
 45                        "name", project_name
 46                    )
 47
 48            hostname = f"{project_name}.localhost"
 49
 50        self.hostname = hostname
 51        self.log_level = log_level
 52
 53        self.pid_value = self.pid
 54        self.prepare_log()
 55
 56        if port:
 57            self.port = int(port)
 58            if not self._port_available(self.port):
 59                click.secho(f"Port {self.port} in use", fg="red")
 60                raise SystemExit(1)
 61        else:
 62            self.port = self._find_open_port(8443)
 63            if self.port != 8443:
 64                click.secho(f"Port 8443 in use, using {self.port}", fg="yellow")
 65
 66        self.ssl_key_path = None
 67        self.ssl_cert_path = None
 68
 69        self.url = f"https://{self.hostname}:{self.port}"
 70        self.tunnel_url = os.environ.get("PLAIN_DEV_TUNNEL_URL", "")
 71
 72        self.plain_env = {
 73            "PYTHONUNBUFFERED": "true",
 74            "PLAIN_DEV": "true",
 75            **os.environ,
 76        }
 77
 78        if log_level:
 79            self.plain_env["PLAIN_LOG_LEVEL"] = log_level.upper()
 80            self.plain_env["APP_LOG_LEVEL"] = log_level.upper()
 81
 82        self.custom_process_env = {
 83            **self.plain_env,
 84            "PORT": str(self.port),
 85            "PLAIN_DEV_URL": self.url,
 86        }
 87
 88        if self.tunnel_url:
 89            status_bar = Columns(
 90                [
 91                    Text.from_markup(
 92                        f"[bold]Tunnel[/bold] [underline][link={self.tunnel_url}]{self.tunnel_url}[/link][/underline]"
 93                    ),
 94                    Text.from_markup(
 95                        f"[dim][bold]Server[/bold] [link={self.url}]{self.url}[/link][/dim]"
 96                    ),
 97                    Text.from_markup(
 98                        "[dim][bold]Ctrl+C[/bold] to stop[/dim]",
 99                        justify="right",
100                    ),
101                ],
102                expand=True,
103            )
104        else:
105            status_bar = Columns(
106                [
107                    Text.from_markup(
108                        f"[bold]Server[/bold] [underline][link={self.url}]{self.url}[/link][/underline]"
109                    ),
110                    Text.from_markup(
111                        "[dim][bold]Ctrl+C[/bold] to stop[/dim]", justify="right"
112                    ),
113                ],
114                expand=True,
115            )
116        self.console = Console(markup=False, highlight=False)
117        self.console_status = self.console.status(status_bar)
118
119        self.init_poncho(self.console.out)
120
121    def _find_open_port(self, start_port):
122        port = start_port
123        while not self._port_available(port):
124            port += 1
125        return port
126
127    def _port_available(self, port):
128        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
129            sock.settimeout(0.5)
130            result = sock.connect_ex(("127.0.0.1", port))
131        return result != 0
132
133    def run(self):
134        self.write_pidfile()
135        click.secho(f"PID: {self.pid_value}", dim=True)
136        mkcert_manager = MkcertManager()
137        mkcert_manager.setup_mkcert(install_path=Path.home() / ".plain" / "dev")
138        self.ssl_cert_path, self.ssl_key_path = mkcert_manager.generate_certs(
139            domain=self.hostname,
140            storage_path=Path(PLAIN_TEMP_PATH) / "dev" / "certs",
141        )
142
143        self.symlink_plain_src()
144        self.modify_hosts_file()
145        self.set_allowed_hosts()
146        self.run_preflight()
147
148        # If we start services ourselves, we should manage the pidfile
149        services_proc = None
150
151        # Services start first (or are already running from a separate command)
152        if running_pid := ServicesProcess().running_pid():
153            click.secho(f"Services already running (pid={running_pid})", fg="yellow")
154        elif services := ServicesProcess.get_services(APP_PATH.parent):
155            click.secho("\nStarting services...", italic=True, dim=True)
156            services_proc = ServicesProcess()
157            services_proc.write_pidfile()
158
159            for name, data in services.items():
160                env = {
161                    **os.environ,
162                    "PYTHONUNBUFFERED": "true",
163                    **data.get("env", {}),
164                }
165                self.poncho.add_process(name, data["cmd"], env=env)
166
167        # If plain.models is installed (common) then we
168        # will do a couple extra things before starting all of the app-related
169        # processes (this way they don't all have to db-wait or anything)
170        process = None
171        if find_spec("plain.models") is not None:
172            # Use a custom signal to tell the main thread to add
173            # the app processes once the db is ready
174            signal.signal(signal.SIGUSR1, self.start_app)
175
176            process = multiprocessing.Process(
177                target=_process_task, args=(self.plain_env,)
178            )
179            process.start()
180
181            # If there are no poncho processes, then let this process finish before
182            # continuing (vs running in parallel)
183            if self.poncho.num_processes() == 0:
184                # Wait for the process to finish
185                process.join()
186        else:
187            # Start the app processes immediately
188            self.start_app(None, None)
189
190        try:
191            # Start processes we know about and block the main thread
192            self.poncho.loop()
193
194            # Remove the status bar
195            self.console_status.stop()
196        finally:
197            self.rm_pidfile()
198            # Make sure the services pid gets removed if we set it
199            if services_proc:
200                services_proc.rm_pidfile()
201
202            self.close()
203
204            # Make sure the process is terminated if it is still running
205            if process and process.is_alive():
206                os.killpg(os.getpgid(process.pid), signal.SIGTERM)
207                process.join(timeout=3)
208                if process.is_alive():
209                    os.killpg(os.getpgid(process.pid), signal.SIGKILL)
210                    process.join()
211
212        return self.poncho.returncode
213
214    def start_app(self, signum, frame):
215        # This runs in the main thread when SIGUSR1 is received
216        # (or called directly if no thread).
217        click.secho("\nStarting app...", italic=True, dim=True)
218
219        # Manually start the status bar now so it isn't bungled by
220        # another thread checking db stuff...
221        self.console_status.start()
222
223        self.add_gunicorn()
224        self.add_entrypoints()
225        self.add_pyproject_run()
226
227    def symlink_plain_src(self):
228        """Symlink the plain package into .plain so we can look at it easily"""
229        plain_path = Path(find_spec("plain.runtime").origin).parent.parent
230        if not PLAIN_TEMP_PATH.exists():
231            PLAIN_TEMP_PATH.mkdir()
232
233        symlink_path = PLAIN_TEMP_PATH / "src"
234
235        # The symlink is broken
236        if symlink_path.is_symlink() and not symlink_path.exists():
237            symlink_path.unlink()
238
239        # The symlink exists but points to the wrong place
240        if (
241            symlink_path.is_symlink()
242            and symlink_path.exists()
243            and symlink_path.resolve() != plain_path
244        ):
245            symlink_path.unlink()
246
247        if plain_path.exists() and not symlink_path.exists():
248            symlink_path.symlink_to(plain_path)
249
250    def modify_hosts_file(self):
251        """Modify the hosts file to map the custom domain to 127.0.0.1."""
252        entry_identifier = "# Added by plain"
253        hosts_entry = f"127.0.0.1 {self.hostname}  {entry_identifier}"
254
255        if platform.system() == "Windows":
256            hosts_path = Path(r"C:\Windows\System32\drivers\etc\hosts")
257            try:
258                with hosts_path.open("r") as f:
259                    content = f.read()
260
261                if hosts_entry in content:
262                    return  # Entry already exists; no action needed
263
264                # Entry does not exist; add it
265                with hosts_path.open("a") as f:
266                    f.write(f"{hosts_entry}\n")
267                click.secho(f"Added {self.hostname} to {hosts_path}", bold=True)
268            except PermissionError:
269                click.secho(
270                    "Permission denied while modifying hosts file. Please run the script as an administrator.",
271                    fg="red",
272                )
273                sys.exit(1)
274        else:
275            # For macOS and Linux
276            hosts_path = Path("/etc/hosts")
277            try:
278                with hosts_path.open("r") as f:
279                    content = f.read()
280
281                if hosts_entry in content:
282                    return  # Entry already exists; no action needed
283
284                # Entry does not exist; append it using sudo
285                click.secho(
286                    f"Adding {self.hostname} to /etc/hosts file. You may be prompted for your password.\n",
287                    bold=True,
288                )
289                cmd = f"echo '{hosts_entry}' | sudo tee -a {hosts_path} >/dev/null"
290                subprocess.run(cmd, shell=True, check=True)
291                click.secho(f"Added {self.hostname} to {hosts_path}\n", bold=True)
292            except PermissionError:
293                click.secho(
294                    "Permission denied while accessing hosts file.",
295                    fg="red",
296                )
297                sys.exit(1)
298            except subprocess.CalledProcessError:
299                click.secho(
300                    "Failed to modify hosts file. Please ensure you have sudo privileges.",
301                    fg="red",
302                )
303                sys.exit(1)
304
305    def set_allowed_hosts(self):
306        if "PLAIN_ALLOWED_HOSTS" not in os.environ:
307            hostnames = [self.hostname]
308            if self.tunnel_url:
309                # Add the tunnel URL to the allowed hosts
310                hostnames.append(self.tunnel_url.split("://")[1])
311            allowed_hosts = json.dumps(hostnames)
312            self.plain_env["PLAIN_ALLOWED_HOSTS"] = allowed_hosts
313            self.custom_process_env["PLAIN_ALLOWED_HOSTS"] = allowed_hosts
314            click.secho(
315                f"Automatically set PLAIN_ALLOWED_HOSTS={allowed_hosts}", dim=True
316            )
317
318    def run_preflight(self):
319        click.echo()
320        if subprocess.run(["plain", "preflight"], env=self.plain_env).returncode:
321            click.secho("Preflight check failed!", fg="red")
322            sys.exit(1)
323
324    def add_gunicorn(self):
325        # Watch .env files for reload
326        extra_watch_files = []
327        for f in os.listdir(APP_PATH.parent):
328            if f.startswith(".env"):
329                # Needs to be absolute or "./" for inotify to work on Linux...
330                # https://github.com/dropseed/plain/issues/26
331                extra_watch_files.append(str(Path(APP_PATH.parent) / f))
332
333        reload_extra = " ".join(f"--reload-extra-file {f}" for f in extra_watch_files)
334        gunicorn_cmd = [
335            "gunicorn",
336            "--bind",
337            f"{self.hostname}:{self.port}",
338            "--certfile",
339            str(self.ssl_cert_path),
340            "--keyfile",
341            str(self.ssl_key_path),
342            "--threads",
343            "4",
344            "--reload",
345            "plain.wsgi:app",
346            "--timeout",
347            "60",
348            "--log-level",
349            self.log_level or "info",
350            "--access-logfile",
351            "-",
352            "--error-logfile",
353            "-",
354            *reload_extra.split(),
355            "--access-logformat",
356            "'\"%(r)s\" status=%(s)s length=%(b)s time=%(M)sms'",
357            "--log-config-json",
358            str(Path(__file__).parent / "gunicorn_logging.json"),
359        ]
360        gunicorn = " ".join(gunicorn_cmd)
361
362        self.poncho.add_process("plain", gunicorn, env=self.plain_env)
363
364    def add_entrypoints(self):
365        for entry_point in entry_points().select(group=ENTRYPOINT_GROUP):
366            self.poncho.add_process(
367                entry_point.name,
368                f"plain dev entrypoint {entry_point.name}",
369                env=self.plain_env,
370            )
371
372    def add_pyproject_run(self):
373        """Additional processes that only run during `plain dev`."""
374        if not has_pyproject_toml(APP_PATH.parent):
375            return
376
377        with open(Path(APP_PATH.parent, "pyproject.toml"), "rb") as f:
378            pyproject = tomllib.load(f)
379
380        run_commands = (
381            pyproject.get("tool", {}).get("plain", {}).get("dev", {}).get("run", {})
382        )
383        for name, data in run_commands.items():
384            env = {
385                **self.custom_process_env,
386                **data.get("env", {}),
387            }
388            self.poncho.add_process(name, data["cmd"], env=env)
389
390
def _process_task(env):
    """Wait for the database, apply migrations, then notify the parent with SIGUSR1.

    Runs in a separate ``multiprocessing`` process. Each command is run with
    ``check=True``, so a failing step raises ``CalledProcessError`` and the
    parent is never signalled.
    """
    # Start a new session so this process leads its own process group; the
    # parent can then kill it and its children together if it doesn't finish.
    os.setsid()

    setup_steps = (
        ["plain", "models", "db-wait"],
        ["plain", "migrate", "--backup"],
    )
    for step in setup_steps:
        subprocess.run(step, env=env, check=True)

    # TODO: run preflight checks that need the db here?

    # Tell the parent the db is ready; its SIGUSR1 handler starts the app
    os.kill(os.getppid(), signal.SIGUSR1)