1"""Preflight checks on the database connection and migration state."""
2
3from __future__ import annotations
4
5from plain.postgres.db import get_connection
6from plain.preflight import PreflightCheck, PreflightResult, register_check
7from plain.runtime import settings
8
9
@register_check("postgres.middleware_installed")
class CheckMiddlewareInstalled(PreflightCheck):
    """Report an error when `DatabaseConnectionMiddleware` is missing from `MIDDLEWARE`.

    The middleware hands pooled connections back at the end of every
    request. Without it, release is left to garbage collection, and under
    load reference cycles or delayed finalization can exhaust the pool.
    """

    REQUIRED = "plain.postgres.DatabaseConnectionMiddleware"

    def run(self) -> list[PreflightResult]:
        installed = self.REQUIRED in settings.MIDDLEWARE
        if installed:
            return []

        # The middleware should sit first in the list so that its
        # after_response hook runs after any middleware that queries the DB.
        hint = (
            f"Add '{self.REQUIRED}' to MIDDLEWARE so pooled "
            "database connections are returned at the end of each "
            "request. Place it first so its after_response runs "
            "after any middleware that queries the database."
        )
        # No ``warning=True`` here: a missing middleware is reported as an error.
        return [PreflightResult(fix=hint, id="postgres.middleware_not_installed")]
35
36
@register_check("postgres.postgres_version")
class CheckPostgresVersion(PreflightCheck):
    """Checks that the PostgreSQL server meets the minimum version requirement.

    Compares the server's major version (from the driver's integer
    ``server_version``) against ``MINIMUM_VERSION`` and returns an
    error-level result when the server is older.
    """

    MINIMUM_VERSION = 16

    def run(self) -> list[PreflightResult]:
        conn = get_connection()
        conn.ensure_connection()
        # ensure_connection() is expected to have opened the connection; the
        # assert narrows the Optional type for the type checker.
        assert conn.connection is not None
        server_version = conn.connection.info.server_version
        # The major version is the leading digits in both the pre-10 and the
        # 10+ integer encodings, so the comparison works for either.
        major = server_version // 10000
        if major < self.MINIMUM_VERSION:
            found = self._format_server_version(server_version)
            return [
                PreflightResult(
                    fix=f"PostgreSQL {self.MINIMUM_VERSION} or later is required (found {found}).",
                    id="postgres.postgres_version_too_old",
                )
            ]
        return []

    @staticmethod
    def _format_server_version(version: int) -> str:
        """Render an integer server version as a human-readable string.

        libpq encodes PostgreSQL >= 10 as ``major * 10000 + minor`` (e.g.
        160002 -> "16.2"). Older servers use ``major * 10000 + minor * 100 +
        patch`` (e.g. 90624 -> "9.6.24"). Treating the old form like the new
        one would print a misleading "9.624".
        """
        major, rest = divmod(version, 10000)
        if major >= 10:
            return f"{major}.{rest}"
        minor, patch = divmod(rest, 100)
        return f"{major}.{minor}.{patch}"
56
57
@register_check("postgres.database_tables")
class CheckDatabaseTables(PreflightCheck):
    """Warn about database tables reported by ``get_unknown_tables()``.

    Such tables may be left over from uninstalled packages/models. The
    result is a warning, not an error, and points the user at
    `plain postgres drop-unknown-tables`.
    """

    def run(self) -> list[PreflightResult]:
        # Deferred import, kept local to the check as in the original module.
        from plain.postgres.introspection import get_unknown_tables

        leftovers = get_unknown_tables()
        if leftovers:
            listing = ", ".join(leftovers)
            return [
                PreflightResult(
                    fix=(
                        f"Unknown tables in default database: {listing}. "
                        "Tables may be from packages/models that have been uninstalled. "
                        "Make sure you have a backup, then run `plain postgres drop-unknown-tables` to remove them."
                    ),
                    id="postgres.unknown_database_tables",
                    warning=True,
                )
            ]
        return []
80
81
@register_check("postgres.prunable_migrations")
class CheckPrunableMigrations(PreflightCheck):
    """Warns about stale migration records in the database.

    A record is stale ("prunable") when the migration recorder lists it as
    applied but the migration loader found no such migration on disk. Stale
    records are split by whether their package still appears in
    ``loader.migrated_packages``. That split lets the warning separate
    leftovers from existing packages from leftovers of removed packages.
    """

    def run(self) -> list[PreflightResult]:
        # Import here to avoid circular import issues
        from plain.postgres.migrations.loader import MigrationLoader
        from plain.postgres.migrations.recorder import MigrationRecorder

        # Load migrations from disk and database
        conn = get_connection()
        loader = MigrationLoader(conn, ignore_no_migrations=True)
        recorder = MigrationRecorder(conn)
        recorded_migrations = recorder.applied_migrations()

        # disk_migrations should not be None after MigrationLoader
        # initialization, but check to satisfy the type checker.
        disk_migrations = loader.disk_migrations
        if disk_migrations is None:
            return []

        # Prunable = recorded as applied but no longer present on disk.
        all_prunable = [
            migration
            for migration in recorded_migrations
            if migration not in disk_migrations
        ]
        if not all_prunable:
            return []

        # Separate into existing packages vs orphaned packages
        existing_packages = set(loader.migrated_packages)
        prunable_existing: list[tuple[str, str]] = []
        prunable_orphaned: list[tuple[str, str]] = []
        for migration in all_prunable:
            package, _ = migration
            if package in existing_packages:
                prunable_existing.append(migration)
            else:
                prunable_orphaned.append(migration)

        # Build the warning message
        total_count = len(all_prunable)
        message_parts = [
            f"Found {total_count} stale migration record{'s' if total_count != 1 else ''} in the database."
        ]
        if prunable_existing:
            message_parts.append(
                f"From existing packages: {self._summarize(prunable_existing)}."
            )
        if prunable_orphaned:
            message_parts.append(
                f"From removed packages: {self._summarize(prunable_orphaned)}."
            )
        message_parts.append("Run 'plain migrations prune' to review and remove them.")

        return [
            PreflightResult(
                fix=" ".join(message_parts),
                id="postgres.prunable_migrations",
                warning=True,
            )
        ]

    @staticmethod
    def _summarize(migrations: list[tuple[str, str]], limit: int = 3) -> str:
        """Join up to *limit* migrations as ``pkg.name`` and note how many were omitted.

        Example: five migrations with ``limit=3`` yield
        ``"a.1, a.2, a.3 (and 2 more)"``.
        """
        summary = ", ".join(f"{pkg}.{name}" for pkg, name in migrations[:limit])
        if len(migrations) > limit:
            summary += f" (and {len(migrations) - limit} more)"
        return summary