import contextlib
import importlib
import json
import multiprocessing
import os
import platform
import shlex
import shutil
import signal
import socket
import subprocess
import sys
import time
import tomllib
from importlib.metadata import entry_points
from importlib.util import find_spec
from pathlib import Path
from urllib.parse import urlsplit
15
16import click
17from rich.columns import Columns
18from rich.console import Console
19from rich.text import Text
20
21from plain.cli import register_cli
22from plain.runtime import APP_PATH, PLAIN_TEMP_PATH
23
24from .dev_pid import DevPid
25from .mkcert import MkcertManager
26from .poncho.manager import Manager as PonchoManager
27from .poncho.printer import Printer
28from .services import Services, ServicesPid
29from .utils import has_pyproject_toml
30
# Entry point group for extra dev processes: `plain dev entrypoint` lists or
# runs entries from this group, and `plain dev` adds one poncho process per
# registered entry point.
ENTRYPOINT_GROUP = "plain.dev"
32
33
@register_cli("dev")
@click.group(invoke_without_command=True)
@click.pass_context
@click.option(
    "--port",
    "-p",
    default="",
    type=str,
    help=(
        "Port to run the web server on. If omitted, tries 8443 and "
        "picks the next free port"
    ),
)
@click.option(
    "--hostname",
    "-h",
    default=None,
    type=str,
    help="Hostname to run the web server on",
)
@click.option(
    "--log-level",
    "-l",
    default="",
    type=click.Choice(["debug", "info", "warning", "error", "critical", ""]),
    help="Log level",
)
def cli(ctx, port, hostname, log_level):
    """Start local development"""
    # When a subcommand was given (e.g. `plain dev services`), let it run on
    # its own; the body below only handles a bare `plain dev`.
    if ctx.invoked_subcommand:
        return

    # Refuse to start a second session while the pid file is present.
    if DevPid().exists():
        click.secho("`plain dev` already running", fg="yellow")
        sys.exit(1)

    if not hostname:
        # Default to "<name>.localhost". <name> is [project].name from
        # pyproject.toml when available, otherwise the current directory's name.
        default_name = os.path.basename(os.getcwd())
        if has_pyproject_toml(APP_PATH.parent):
            pyproject_file = Path(APP_PATH.parent, "pyproject.toml")
            with open(pyproject_file, "rb") as fh:
                project_table = tomllib.load(fh).get("project", {})
            default_name = project_table.get("name", default_name)
        hostname = f"{default_name}.localhost"

    session = Dev(port=port, hostname=hostname, log_level=log_level)
    exit_code = session.run()
    # Only exit explicitly on failure; a falsy code falls through normally.
    if exit_code:
        sys.exit(exit_code)
86
87
@cli.command()
def debug():
    """Connect to the remote debugger"""
    # Resolve nc once, up front, instead of spawning `which` before every
    # attempt. shutil.which also works where no `which` binary exists.
    nc_path = shutil.which("nc")
    if nc_path is None:
        raise OSError("nc not found")

    def _connect():
        # Attach the terminal to the debugger listening on localhost:4444.
        return subprocess.run([nc_path, "-C", "localhost", "4444"])

    result = _connect()

    # Try again once without a message
    if result.returncode == 1:
        time.sleep(1)
        result = _connect()

    # Keep trying... Sleep *before* each reconnect, so a session that ends
    # normally returns immediately rather than after a trailing 1s delay.
    while result.returncode == 1:
        click.secho(
            "Failed to connect. Make sure remote pdb is ready. Retrying...", fg="red"
        )
        time.sleep(1)
        result = _connect()
112
113
@cli.command()
def services():
    """Start additional services defined in pyproject.toml"""
    runner = Services()
    # Don't start a second copy when services are already up (either from
    # another `plain dev services` or from a running `plain dev`).
    if not runner.are_running():
        runner.run()
        return
    click.secho("Services already running", fg="yellow")
122
123
@cli.command()
@click.option(
    "--list", "-l", "show_list", is_flag=True, help="List available entrypoints"
)
@click.argument("entrypoint", required=False)
def entrypoint(show_list, entrypoint):
    """Entrypoints registered under plain.dev"""
    # Exits with status 1 when neither a name nor --list is given, or when
    # the named entrypoint is not registered in ENTRYPOINT_GROUP.
    if not show_list and not entrypoint:
        click.secho("Please provide an entrypoint name or use --list", fg="red")
        sys.exit(1)

    found = False
    for entry_point in entry_points().select(group=ENTRYPOINT_GROUP):
        if show_list:
            click.echo(entry_point.name)
        elif entrypoint == entry_point.name:
            found = True
            # Load the registered callable and invoke it with no arguments.
            entry_point.load()()

    # Previously an unknown name was silently ignored (exit 0); make typos visible.
    if not show_list and not found:
        click.secho(
            f"No entrypoint named {entrypoint!r} registered under {ENTRYPOINT_GROUP}",
            fg="red",
        )
        sys.exit(1)
140
141
142class Dev:
143 def __init__(self, *, port, hostname, log_level):
144 self.hostname = hostname
145 self.log_level = log_level
146
147 self.pid = DevPid()
148
149 if port:
150 self.port = int(port)
151 if not self._port_available(self.port):
152 click.secho(f"Port {self.port} in use", fg="red")
153 raise SystemExit(1)
154 else:
155 self.port = self._find_open_port(8443)
156 if self.port != 8443:
157 click.secho(f"Port 8443 in use, using {self.port}", fg="yellow")
158
159 self.ssl_key_path = None
160 self.ssl_cert_path = None
161
162 self.url = f"https://{self.hostname}:{self.port}"
163 self.tunnel_url = os.environ.get("PLAIN_DEV_TUNNEL_URL", "")
164
165 self.plain_env = {
166 "PYTHONUNBUFFERED": "true",
167 "PLAIN_DEV": "true",
168 **os.environ,
169 }
170
171 if log_level:
172 self.plain_env["PLAIN_LOG_LEVEL"] = log_level.upper()
173 self.plain_env["APP_LOG_LEVEL"] = log_level.upper()
174
175 self.custom_process_env = {
176 **self.plain_env,
177 "PORT": str(self.port),
178 "PLAIN_DEV_URL": self.url,
179 }
180
181 if self.tunnel_url:
182 status_bar = Columns(
183 [
184 Text.from_markup(
185 f"[bold]Tunnel[/bold] [underline][link={self.tunnel_url}]{self.tunnel_url}[/link][/underline]"
186 ),
187 Text.from_markup(
188 f"[dim][bold]Server[/bold] [link={self.url}]{self.url}[/link][/dim]"
189 ),
190 Text.from_markup(
191 "[dim][bold]Ctrl+C[/bold] to stop[/dim]",
192 justify="right",
193 ),
194 ],
195 expand=True,
196 )
197 else:
198 status_bar = Columns(
199 [
200 Text.from_markup(
201 f"[bold]Server[/bold] [underline][link={self.url}]{self.url}[/link][/underline]"
202 ),
203 Text.from_markup(
204 "[dim][bold]Ctrl+C[/bold] to stop[/dim]", justify="right"
205 ),
206 ],
207 expand=True,
208 )
209 self.console = Console(markup=False, highlight=False)
210 self.console_status = self.console.status(status_bar)
211
212 self.poncho = PonchoManager(printer=Printer(lambda s: self.console.out(s)))
213
214 def _find_open_port(self, start_port):
215 port = start_port
216 while not self._port_available(port):
217 port += 1
218 return port
219
220 def _port_available(self, port):
221 with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
222 sock.settimeout(0.5)
223 result = sock.connect_ex(("127.0.0.1", port))
224 return result != 0
225
226 def run(self):
227 self.pid.write()
228 mkcert_manager = MkcertManager()
229 mkcert_manager.setup_mkcert(install_path=Path.home() / ".plain" / "dev")
230 self.ssl_cert_path, self.ssl_key_path = mkcert_manager.generate_certs(
231 domain=self.hostname,
232 storage_path=Path(PLAIN_TEMP_PATH) / "dev" / "certs",
233 )
234
235 self.symlink_plain_src()
236 self.modify_hosts_file()
237 self.set_allowed_hosts()
238 self.run_preflight()
239
240 # If we start services ourselves, we should manage the pidfile
241 services_pid = None
242
243 # Services start first (or are already running from a separate command)
244 if Services.are_running():
245 click.secho("Services already running", fg="yellow")
246 elif services := Services.get_services(APP_PATH.parent):
247 click.secho("\nStarting services...", italic=True, dim=True)
248 services_pid = ServicesPid()
249 services_pid.write()
250
251 for name, data in services.items():
252 env = {
253 **os.environ,
254 "PYTHONUNBUFFERED": "true",
255 **data.get("env", {}),
256 }
257 self.poncho.add_process(name, data["cmd"], env=env)
258
259 # If plain.models is installed (common) then we
260 # will do a couple extra things before starting all of the app-related
261 # processes (this way they don't all have to db-wait or anything)
262 process = None
263 if find_spec("plain.models") is not None:
264 # Use a custom signal to tell the main thread to add
265 # the app processes once the db is ready
266 signal.signal(signal.SIGUSR1, self.start_app)
267
268 process = multiprocessing.Process(
269 target=_process_task, args=(self.plain_env,)
270 )
271 process.start()
272
273 # If there are no poncho processes, then let this process finish before
274 # continuing (vs running in parallel)
275 if self.poncho.num_processes() == 0:
276 # Wait for the process to finish
277 process.join()
278 else:
279 # Start the app processes immediately
280 self.start_app(None, None)
281
282 try:
283 # Start processes we know about and block the main thread
284 self.poncho.loop()
285
286 # Remove the status bar
287 self.console_status.stop()
288 finally:
289 self.pid.rm()
290 # Make sure the services pid gets removed if we set it
291 if services_pid:
292 services_pid.rm()
293
294 # Make sure the process is terminated if it is still running
295 if process and process.is_alive():
296 os.killpg(os.getpgid(process.pid), signal.SIGTERM)
297 process.join(timeout=3)
298 if process.is_alive():
299 os.killpg(os.getpgid(process.pid), signal.SIGKILL)
300 process.join()
301
302 return self.poncho.returncode
303
304 def start_app(self, signum, frame):
305 # This runs in the main thread when SIGUSR1 is received
306 # (or called directly if no thread).
307 click.secho("\nStarting app...", italic=True, dim=True)
308
309 # Manually start the status bar now so it isn't bungled by
310 # another thread checking db stuff...
311 self.console_status.start()
312
313 self.add_gunicorn()
314 self.add_entrypoints()
315 self.add_pyproject_run()
316
317 def symlink_plain_src(self):
318 """Symlink the plain package into .plain so we can look at it easily"""
319 plain_path = Path(
320 importlib.util.find_spec("plain.runtime").origin
321 ).parent.parent
322 if not PLAIN_TEMP_PATH.exists():
323 PLAIN_TEMP_PATH.mkdir()
324
325 symlink_path = PLAIN_TEMP_PATH / "src"
326
327 # The symlink is broken
328 if symlink_path.is_symlink() and not symlink_path.exists():
329 symlink_path.unlink()
330
331 # The symlink exists but points to the wrong place
332 if (
333 symlink_path.is_symlink()
334 and symlink_path.exists()
335 and symlink_path.resolve() != plain_path
336 ):
337 symlink_path.unlink()
338
339 if plain_path.exists() and not symlink_path.exists():
340 symlink_path.symlink_to(plain_path)
341
342 def modify_hosts_file(self):
343 """Modify the hosts file to map the custom domain to 127.0.0.1."""
344 entry_identifier = "# Added by plain"
345 hosts_entry = f"127.0.0.1 {self.hostname} {entry_identifier}"
346
347 if platform.system() == "Windows":
348 hosts_path = Path(r"C:\Windows\System32\drivers\etc\hosts")
349 try:
350 with hosts_path.open("r") as f:
351 content = f.read()
352
353 if hosts_entry in content:
354 return # Entry already exists; no action needed
355
356 # Entry does not exist; add it
357 with hosts_path.open("a") as f:
358 f.write(f"{hosts_entry}\n")
359 click.secho(f"Added {self.hostname} to {hosts_path}", bold=True)
360 except PermissionError:
361 click.secho(
362 "Permission denied while modifying hosts file. Please run the script as an administrator.",
363 fg="red",
364 )
365 sys.exit(1)
366 else:
367 # For macOS and Linux
368 hosts_path = Path("/etc/hosts")
369 try:
370 with hosts_path.open("r") as f:
371 content = f.read()
372
373 if hosts_entry in content:
374 return # Entry already exists; no action needed
375
376 # Entry does not exist; append it using sudo
377 click.secho(
378 f"Adding {self.hostname} to /etc/hosts file. You may be prompted for your password.\n",
379 bold=True,
380 )
381 cmd = f"echo '{hosts_entry}' | sudo tee -a {hosts_path} >/dev/null"
382 subprocess.run(cmd, shell=True, check=True)
383 click.secho(f"Added {self.hostname} to {hosts_path}\n", bold=True)
384 except PermissionError:
385 click.secho(
386 "Permission denied while accessing hosts file.",
387 fg="red",
388 )
389 sys.exit(1)
390 except subprocess.CalledProcessError:
391 click.secho(
392 "Failed to modify hosts file. Please ensure you have sudo privileges.",
393 fg="red",
394 )
395 sys.exit(1)
396
397 def set_allowed_hosts(self):
398 if "PLAIN_ALLOWED_HOSTS" not in os.environ:
399 hostnames = [self.hostname]
400 if self.tunnel_url:
401 # Add the tunnel URL to the allowed hosts
402 hostnames.append(self.tunnel_url.split("://")[1])
403 allowed_hosts = json.dumps(hostnames)
404 self.plain_env["PLAIN_ALLOWED_HOSTS"] = allowed_hosts
405 self.custom_process_env["PLAIN_ALLOWED_HOSTS"] = allowed_hosts
406 click.secho(
407 f"Automatically set PLAIN_ALLOWED_HOSTS={allowed_hosts}", dim=True
408 )
409
410 def run_preflight(self):
411 click.echo()
412 if subprocess.run(["plain", "preflight"], env=self.plain_env).returncode:
413 click.secho("Preflight check failed!", fg="red")
414 sys.exit(1)
415
416 def add_gunicorn(self):
417 # Watch .env files for reload
418 extra_watch_files = []
419 for f in os.listdir(APP_PATH.parent):
420 if f.startswith(".env"):
421 # Needs to be absolute or "./" for inotify to work on Linux...
422 # https://github.com/dropseed/plain/issues/26
423 extra_watch_files.append(str(Path(APP_PATH.parent) / f))
424
425 reload_extra = " ".join(f"--reload-extra-file {f}" for f in extra_watch_files)
426 gunicorn_cmd = [
427 "gunicorn",
428 "--bind",
429 f"{self.hostname}:{self.port}",
430 "--certfile",
431 str(self.ssl_cert_path),
432 "--keyfile",
433 str(self.ssl_key_path),
434 "--threads",
435 "4",
436 "--reload",
437 "plain.wsgi:app",
438 "--timeout",
439 "60",
440 "--log-level",
441 self.log_level or "info",
442 "--access-logfile",
443 "-",
444 "--error-logfile",
445 "-",
446 *reload_extra.split(),
447 "--access-logformat",
448 "'\"%(r)s\" status=%(s)s length=%(b)s time=%(M)sms'",
449 "--log-config-json",
450 str(Path(__file__).parent / "gunicorn_logging.json"),
451 ]
452 gunicorn = " ".join(gunicorn_cmd)
453
454 self.poncho.add_process("plain", gunicorn, env=self.plain_env)
455
456 def add_entrypoints(self):
457 for entry_point in entry_points().select(group=ENTRYPOINT_GROUP):
458 self.poncho.add_process(
459 entry_point.name,
460 f"plain dev entrypoint {entry_point.name}",
461 env=self.plain_env,
462 )
463
464 def add_pyproject_run(self):
465 """Additional processes that only run during `plain dev`."""
466 if not has_pyproject_toml(APP_PATH.parent):
467 return
468
469 with open(Path(APP_PATH.parent, "pyproject.toml"), "rb") as f:
470 pyproject = tomllib.load(f)
471
472 run_commands = (
473 pyproject.get("tool", {}).get("plain", {}).get("dev", {}).get("run", {})
474 )
475 for name, data in run_commands.items():
476 env = {
477 **self.custom_process_env,
478 **data.get("env", {}),
479 }
480 self.poncho.add_process(name, data["cmd"], env=env)
481
482
def _process_task(env):
    """Wait for the database, run migrations, then signal the parent.

    Runs as the target of the ``multiprocessing.Process`` started in
    ``Dev.run``. On success it sends SIGUSR1 to the parent, whose handler
    adds the app processes. If a step fails, this process exits with that
    step's return code and the parent is not signalled.

    Args:
        env: Environment for the `plain` subprocesses.
    """
    # Make this process the leader of a new group which can be killed together if it doesn't finish
    os.setsid()

    # Remember who started us. If that process dies while we wait,
    # getppid() changes (re-parenting), and we must not signal the new parent.
    parent_pid = os.getppid()

    for cmd in (["plain", "models", "db-wait"], ["plain", "migrate", "--backup"]):
        result = subprocess.run(cmd, env=env)
        if result.returncode:
            # Presumably the command already printed its own error; exit
            # with its code rather than raising CalledProcessError with a
            # traceback.
            sys.exit(result.returncode)

    # preflight with db?

    if os.getppid() != parent_pid:
        return

    # Send SIGUSR1 to the parent process so the parent's handler is invoked
    os.kill(parent_pid, signal.SIGUSR1)