1from __future__ import annotations
  2
  3from collections.abc import Callable, Generator
  4from typing import Any, TypeVar
  5
  6from plain.runtime import settings
  7
  8from .results import PreflightResult
  9
# Generic type for the ``register_check`` decorator: the decorated class is
# returned as-is, so its precise type is preserved for type checkers.
T = TypeVar("T")
 11
 12
 13class CheckRegistry:
 14    def __init__(self) -> None:
 15        self.checks: dict[
 16            str, tuple[type[Any], bool]
 17        ] = {}  # name -> (check_class, deploy)
 18
 19    def register_check(
 20        self, check_class: type[Any], name: str, deploy: bool = False
 21    ) -> None:
 22        """Register a check class with a unique name."""
 23        if name in self.checks:
 24            raise ValueError(f"Check {name} already registered")
 25        self.checks[name] = (check_class, deploy)
 26
 27    def run_checks(
 28        self,
 29        include_deploy_checks: bool = False,
 30    ) -> Generator[tuple[type[Any], str, list[PreflightResult]]]:
 31        """
 32        Run all registered checks and yield (check_class, name, results) tuples.
 33        """
 34        # Validate silenced check names
 35        silenced_checks = settings.PREFLIGHT_SILENCED_CHECKS
 36        unknown_silenced = set(silenced_checks) - set(self.checks.keys())
 37        if unknown_silenced:
 38            unknown_names = ", ".join(sorted(unknown_silenced))
 39            raise ValueError(
 40                f"Unknown check names in PREFLIGHT_SILENCED_CHECKS: {unknown_names}. "
 41                "Check for typos or remove outdated check names."
 42            )
 43
 44        for name, (check_class, deploy) in sorted(self.checks.items()):
 45            # Skip silenced checks
 46            if name in silenced_checks:
 47                continue
 48
 49            # Skip deployment checks if not requested
 50            if deploy and not include_deploy_checks:
 51                continue
 52
 53            # Instantiate and run check
 54            check = check_class()
 55            results = check.run()
 56            yield check_class, name, results
 57
 58    def get_checks(
 59        self, include_deploy_checks: bool = False
 60    ) -> list[tuple[type[Any], str]]:
 61        """Get list of (check_class, name) tuples."""
 62        result: list[tuple[type[Any], str]] = []
 63        for name, (check_class, deploy) in self.checks.items():
 64            if deploy and not include_deploy_checks:
 65                continue
 66            result.append((check_class, name))
 67        return result
 68
 69
# Shared module-level registry: the ``register_check`` decorator adds to it and
# the module-level ``run_checks`` alias runs its checks.
checks_registry = CheckRegistry()
 71
 72
 73def register_check(name: str, *, deploy: bool = False) -> Callable[[type[T]], type[T]]:
 74    """
 75    Decorator to register a check class.
 76
 77    Usage:
 78        @register_check("security.secret_key", deploy=True)
 79        class CheckSecretKey(PreflightCheck):
 80            pass
 81
 82        @register_check("files.upload_temp_dir")
 83        class CheckUploadTempDir(PreflightCheck):
 84            pass
 85    """
 86
 87    def wrapper(cls: type[T]) -> type[T]:
 88        checks_registry.register_check(cls, name=name, deploy=deploy)
 89        return cls
 90
 91    return wrapper
 92
 93
# Module-level shortcut bound to the shared registry's ``run_checks`` method,
# so callers can use ``run_checks(...)`` without touching ``checks_registry``.
run_checks = checks_registry.run_checks

# Cached error/warning counts. ``None`` until ``get_check_counts()`` first
# computes them; ``set_check_counts()`` overwrites them (PreflightView calls it
# when the full page is viewed).
_check_counts: dict[str, int] | None = None
 99
100
def get_check_counts() -> dict[str, int]:
    """Return ``{"errors": N, "warnings": N}`` for the registered preflight checks.

    Each check that produced at least one unsilenced result is counted once:
    as an error if any of those results is not a warning, otherwise as a
    warning. Deploy checks are included when ``settings.DEBUG`` is false.

    The result is computed on the first call and cached in the module-level
    ``_check_counts``. Later calls return the cached value until
    ``set_check_counts`` replaces it.
    """
    global _check_counts

    if _check_counts is None:
        from plain.packages import packages_registry

        # Import every package's (and the app's) preflight module so their
        # checks are registered before running them.
        packages_registry.autodiscover_modules("preflight", include_app=True)

        n_errors = 0
        n_warnings = 0
        for _cls, _check_name, check_results in run_checks(
            include_deploy_checks=not settings.DEBUG
        ):
            unsilenced = [res for res in check_results if not res.is_silenced()]
            if not unsilenced:
                continue
            if all(res.warning for res in unsilenced):
                n_warnings += 1
            else:
                n_errors += 1

        _check_counts = {"errors": n_errors, "warnings": n_warnings}

    return _check_counts
128
129
def set_check_counts(*, errors: int, warnings: int) -> None:
    """Replace the cached preflight counts with the given values.

    Called by PreflightView after it has run the full set of checks, so later
    ``get_check_counts()`` calls return these numbers without recomputing.
    """
    global _check_counts
    fresh_counts = {"errors": errors, "warnings": warnings}
    _check_counts = fresh_counts