v0.150.0
  1"""Preflight checks on the database connection and migration state."""
  2
  3from __future__ import annotations
  4
  5from plain.postgres.db import get_connection
  6from plain.preflight import PreflightCheck, PreflightResult, register_check
  7from plain.runtime import settings
  8
  9
 10@register_check("postgres.middleware_installed")
 11class CheckMiddlewareInstalled(PreflightCheck):
 12    """Errors if `DatabaseConnectionMiddleware` isn't in `MIDDLEWARE`.
 13
 14    Without it, pooled connections are only released by GC at the end of
 15    each request — relying on refcount timing under load is a recipe for
 16    pool exhaustion under cyclic refs or delayed finalization.
 17    """
 18
 19    REQUIRED = "plain.postgres.DatabaseConnectionMiddleware"
 20
 21    def run(self) -> list[PreflightResult]:
 22        if self.REQUIRED in settings.MIDDLEWARE:
 23            return []
 24        return [
 25            PreflightResult(
 26                fix=(
 27                    f"Add '{self.REQUIRED}' to MIDDLEWARE so pooled "
 28                    "database connections are returned at the end of each "
 29                    "request. Place it first so its after_response runs "
 30                    "after any middleware that queries the database."
 31                ),
 32                id="postgres.middleware_not_installed",
 33            )
 34        ]
 35
 36
 37@register_check("postgres.postgres_version")
 38class CheckPostgresVersion(PreflightCheck):
 39    """Checks that the PostgreSQL server meets the minimum version requirement."""
 40
 41    MINIMUM_VERSION = 16
 42
 43    def run(self) -> list[PreflightResult]:
 44        conn = get_connection()
 45        conn.ensure_connection()
 46        assert conn.connection is not None
 47        major, minor = divmod(conn.connection.info.server_version, 10000)
 48        if major < self.MINIMUM_VERSION:
 49            return [
 50                PreflightResult(
 51                    fix=f"PostgreSQL {self.MINIMUM_VERSION} or later is required (found {major}.{minor}).",
 52                    id="postgres.postgres_version_too_old",
 53                )
 54            ]
 55        return []
 56
 57
 58@register_check("postgres.database_tables")
 59class CheckDatabaseTables(PreflightCheck):
 60    """Checks for unknown tables in the database when plain.postgres is available."""
 61
 62    def run(self) -> list[PreflightResult]:
 63        from plain.postgres.introspection import get_unknown_tables
 64
 65        unknown_tables = get_unknown_tables()
 66
 67        if not unknown_tables:
 68            return []
 69
 70        table_names = ", ".join(unknown_tables)
 71        return [
 72            PreflightResult(
 73                fix=f"Unknown tables in default database: {table_names}. "
 74                "Tables may be from packages/models that have been uninstalled. "
 75                "Make sure you have a backup, then run `plain postgres drop-unknown-tables` to remove them.",
 76                id="postgres.unknown_database_tables",
 77                warning=True,
 78            )
 79        ]
 80
 81
 82@register_check("postgres.prunable_migrations")
 83class CheckPrunableMigrations(PreflightCheck):
 84    """Warns about stale migration records in the database."""
 85
 86    def run(self) -> list[PreflightResult]:
 87        # Import here to avoid circular import issues
 88        from plain.postgres.migrations.loader import MigrationLoader
 89        from plain.postgres.migrations.recorder import MigrationRecorder
 90
 91        errors = []
 92
 93        # Load migrations from disk and database
 94        conn = get_connection()
 95        loader = MigrationLoader(conn, ignore_no_migrations=True)
 96        recorder = MigrationRecorder(conn)
 97        recorded_migrations = recorder.applied_migrations()
 98
 99        # disk_migrations should not be None after MigrationLoader initialization,
100        # but check to satisfy type checker
101        if loader.disk_migrations is None:
102            return errors
103
104        # Find all prunable migrations (recorded but not on disk)
105        all_prunable = [
106            migration
107            for migration in recorded_migrations
108            if migration not in loader.disk_migrations
109        ]
110
111        if not all_prunable:
112            return errors
113
114        # Separate into existing packages vs orphaned packages
115        existing_packages = set(loader.migrated_packages)
116        prunable_existing: list[tuple[str, str]] = []
117        prunable_orphaned: list[tuple[str, str]] = []
118
119        for migration in all_prunable:
120            package, name = migration
121            if package in existing_packages:
122                prunable_existing.append(migration)
123            else:
124                prunable_orphaned.append(migration)
125
126        # Build the warning message
127        total_count = len(all_prunable)
128        message_parts = [
129            f"Found {total_count} stale migration record{'s' if total_count != 1 else ''} in the database."
130        ]
131
132        if prunable_existing:
133            existing_list = ", ".join(
134                f"{pkg}.{name}" for pkg, name in prunable_existing[:3]
135            )
136            if len(prunable_existing) > 3:
137                existing_list += f" (and {len(prunable_existing) - 3} more)"
138            message_parts.append(f"From existing packages: {existing_list}.")
139
140        if prunable_orphaned:
141            orphaned_list = ", ".join(
142                f"{pkg}.{name}" for pkg, name in prunable_orphaned[:3]
143            )
144            if len(prunable_orphaned) > 3:
145                orphaned_list += f" (and {len(prunable_orphaned) - 3} more)"
146            message_parts.append(f"From removed packages: {orphaned_list}.")
147
148        message_parts.append("Run 'plain migrations prune' to review and remove them.")
149
150        errors.append(
151            PreflightResult(
152                fix=" ".join(message_parts),
153                id="postgres.prunable_migrations",
154                warning=True,
155            )
156        )
157
158        return errors