1from __future__ import annotations
2
3import inspect
4from collections import defaultdict
5from collections.abc import Callable
6from typing import Any
7
8from plain.packages import packages_registry
9from plain.postgres.db import get_connection
10from plain.postgres.migrations.recorder import MIGRATION_TABLE_NAME
11from plain.postgres.registry import ModelsRegistry, models_registry
12from plain.preflight import PreflightCheck, PreflightResult, register_check
13
14
@register_check("postgres.all_models")
class CheckAllModels(PreflightCheck):
    """Run every model's own preflight and detect name collisions across models.

    Results are reported in this order:

    1. per model (in registry order): the model's ``preflight()`` results, or an
       error if ``preflight`` is no longer a bound class method on the model;
    2. ``db_table`` values declared by more than one model;
    3. index names that are declared more than once;
    4. constraint names that are declared more than once.
    """

    def run(self) -> list[PreflightResult]:
        results: list[PreflightResult] = []
        # Each mapping: declared name -> labels of every model declaring it
        # (a label repeats when one model declares the same name twice).
        labels_by_table = defaultdict(list)
        labels_by_index = defaultdict(list)
        labels_by_constraint = defaultdict(list)

        for model in models_registry.get_models():
            options = model.model_options
            labels_by_table[options.db_table].append(options.label)

            # Only call preflight() if it is still a bound (class) method;
            # otherwise report what replaced it.
            if inspect.ismethod(model.preflight):
                results.extend(model.preflight())
            else:
                results.append(
                    PreflightResult(
                        fix=f"The '{model.__name__}.preflight()' class method is currently overridden by {model.preflight!r}.",
                        obj=model,
                        id="postgres.preflight_method_overridden",
                    )
                )

            for index in options.indexes:
                labels_by_index[index.name].append(options.label)
            for constraint in options.constraints:
                labels_by_constraint[constraint.name].append(options.label)

        # Every list holds at least one label, so "more than one" means shared.
        for table, labels in labels_by_table.items():
            if len(labels) > 1:
                joined = ", ".join(labels)
                results.append(
                    PreflightResult(
                        fix=f"db_table '{table}' is used by multiple models: {joined}.",
                        obj=table,
                        id="postgres.duplicate_db_table",
                    )
                )

        def report_clashes(kind: str, labels_by_name: dict) -> None:
            """Append one result per ``kind`` name declared more than once."""
            for name, labels in labels_by_name.items():
                if len(labels) <= 1:
                    continue
                distinct = set(labels)
                # A single distinct label means one model reused the name.
                single = len(distinct) == 1
                where = "for model" if single else "among models:"
                joined = ", ".join(sorted(distinct))
                suffix = "single" if single else "multiple"
                results.append(
                    PreflightResult(
                        fix=f"{kind} name '{name}' is not unique {where} {joined}.",
                        id=f"postgres.{kind}_name_not_unique_{suffix}",
                    ),
                )

        report_clashes("index", labels_by_index)
        report_clashes("constraint", labels_by_constraint)
        return results
84
85
def _check_lazy_references(
    models_registry: ModelsRegistry, packages_registry: Any
) -> list[PreflightResult]:
    """
    Ensure all lazy (i.e. string) model references have been resolved.

    Lazy references are used in various places throughout Plain, primarily in
    related fields. Any operation still queued in
    ``models_registry._pending_operations`` when this runs is waiting on a
    model that never got registered. Common cases get a tailored message;
    anything else falls back to a generic one.

    Args:
        models_registry: Registry whose ``_pending_operations`` maps a
            ``(package_label, model_name)`` key to the callbacks waiting on
            that model.
        packages_registry: Used only to tell whether the package named in an
            unresolved reference is installed (``get_package_config`` raising
            ``LookupError`` means it is not).

    Returns:
        One ``PreflightResult`` per unresolved operation, sorted by message so
        the output is deterministic (pending keys are iterated as a set).
    """
    pending_models = set(models_registry._pending_operations)

    # Short circuit if there aren't any errors.
    if not pending_models:
        return []

    def extract_operation(
        obj: Any,
    ) -> tuple[Callable[..., Any], list[Any], dict[str, Any]]:
        """
        Unwrap a callback stored in ``models_registry._pending_operations``.

        Follow the chain of ``func`` attributes (as exposed by
        ``functools.partial`` and by callbacks that imitate one) down to the
        innermost callable, accumulating any ``args`` and ``keywords`` found
        along the way. Returns ``(callable, args, keywords)``.
        """
        operation, args, keywords = obj, [], {}
        while hasattr(operation, "func"):
            args.extend(getattr(operation, "args", []))
            keywords.update(getattr(operation, "keywords", {}))
            operation = operation.func
        return operation, args, keywords

    def app_model_error(model_key: tuple[str, str]) -> str:
        """Explain why ``model_key`` did not resolve: missing model or package."""
        try:
            packages_registry.get_package_config(model_key[0])
            model_error = "app '{}' doesn't provide model '{}'".format(*model_key)
        except LookupError:
            model_error = f"app '{model_key[0]}' isn't installed"
        return model_error

    # The functions below build a PreflightResult for the most common usages
    # of lazy operations throughout Plain. They take the model that was being
    # waited on as a (package_label, model_name) pair, the original lazy
    # function, and its positional and keyword args as determined by
    # extract_operation().

    def field_error(
        model_key: tuple[str, str],
        func: Callable[..., Any],
        args: list[Any],
        keywords: dict[str, Any],
    ) -> PreflightResult:
        error_msg = (
            "The field %(field)s was declared with a lazy reference "
            "to '%(model)s', but %(model_error)s."
        )
        params = {
            "model": ".".join(model_key),
            "field": keywords["field"],
            "model_error": app_model_error(model_key),
        }
        return PreflightResult(
            fix=error_msg % params,
            obj=keywords["field"],
            id="fields.lazy_reference_not_resolvable",
        )

    def default_error(
        model_key: tuple[str, str],
        func: Callable[..., Any],
        args: list[Any],
        keywords: dict[str, Any],
    ) -> PreflightResult:
        error_msg = (
            "%(op)s contains a lazy reference to %(model)s, but %(model_error)s."
        )
        params = {
            "op": func,
            "model": ".".join(model_key),
            "model_error": app_model_error(model_key),
        }
        return PreflightResult(
            fix=error_msg % params,
            obj=func,
            id="postgres.lazy_reference_resolution_failed",
        )

    # Maps (module, function name) of common lazy operations to the error
    # builder above. If a key maps to None, no error will be produced.
    # default_error() is used for usages that don't appear in this dict.
    known_lazy: dict[
        tuple[str, str], Callable[..., PreflightResult] | None
    ] = {
        ("plain.postgres.fields.related", "resolve_related_class"): field_error,
    }

    def build_error(
        model_key: tuple[str, str],
        func: Callable[..., Any],
        args: list[Any],
        keywords: dict[str, Any],
    ) -> PreflightResult | None:
        key = (func.__module__, func.__name__)  # type: ignore[attr-defined]
        error_fn = known_lazy.get(key, default_error)
        return error_fn(model_key, func, args, keywords) if error_fn else None

    candidates = (
        build_error(model_key, *extract_operation(func))
        for model_key in pending_models
        for func in models_registry._pending_operations[model_key]
    )
    # Drop only the explicit "no error" results; don't depend on the
    # truthiness of PreflightResult objects.
    return sorted(
        (error for error in candidates if error is not None),
        key=lambda error: error.fix,
    )
204
205
@register_check("postgres.lazy_references")
class CheckLazyReferences(PreflightCheck):
    """Report string model references that never resolved to a registered model."""

    def run(self) -> list[PreflightResult]:
        # All the work lives in the module-level helper; hand it the global
        # registries explicitly.
        return _check_lazy_references(
            models_registry=models_registry,
            packages_registry=packages_registry,
        )
212
213
@register_check("postgres.postgres_version")
class CheckPostgresVersion(PreflightCheck):
    """Checks that the PostgreSQL server meets the minimum version requirement."""

    # Lowest supported PostgreSQL major version.
    MINIMUM_VERSION = 16

    @staticmethod
    def _format_version(version_num: int) -> str:
        """Render a numeric server version for display.

        From PostgreSQL 10 onward the number is ``major * 10000 + minor``
        (160002 -> "16.2"). Older servers encode
        ``major * 10000 + minor * 100 + patch`` (90605 -> "9.6.5"), so the
        remainder must be split once more or the message would read "9.605".
        """
        major, rest = divmod(version_num, 10000)
        if major >= 10:
            return f"{major}.{rest}"
        minor, patch = divmod(rest, 100)
        return f"{major}.{minor}.{patch}"

    def run(self) -> list[PreflightResult]:
        conn = get_connection()
        # Read once; assumed to be the server_version_num-style integer the
        # original divmod(…, 10000) arithmetic relies on.
        version_num = conn.pg_version
        major = version_num // 10000
        if major >= self.MINIMUM_VERSION:
            return []
        found = self._format_version(version_num)
        return [
            PreflightResult(
                fix=f"PostgreSQL {self.MINIMUM_VERSION} or later is required (found {found}).",
                id="postgres.postgres_version_too_old",
            )
        ]
231
232
@register_check("postgres.database_tables")
class CheckDatabaseTables(PreflightCheck):
    """Warn about database tables that Plain does not account for.

    A table is "unknown" when it is present in ``conn.table_names()`` but is
    neither listed by ``conn.plain_table_names()`` nor the migrations table.
    """

    def run(self) -> list[PreflightResult]:
        conn = get_connection()
        present = set(conn.table_names())
        expected = set(conn.plain_table_names())
        # The migration recorder's own table is always expected.
        expected.add(MIGRATION_TABLE_NAME)

        unknown = present - expected
        if unknown:
            listing = ", ".join(sorted(unknown))
            return [
                PreflightResult(
                    fix=f"Unknown tables in default database: {listing}. "
                    "Tables may be from packages/models that have been uninstalled. "
                    "Make sure you have a backup, then run `plain db drop-unknown-tables` to remove them.",
                    id="postgres.unknown_database_tables",
                    warning=True,
                )
            ]
        return []
258
259
@register_check("postgres.prunable_migrations")
class CheckPrunableMigrations(PreflightCheck):
    """Warn about migration records in the database with no matching file on disk.

    Stale records are split by whether their package still has migrations
    (``loader.migrated_packages``) or not, and a single warning summarising
    both groups is returned.
    """

    def run(self) -> list[PreflightResult]:
        # Imported lazily to avoid circular imports at module load time.
        from plain.postgres.migrations.loader import MigrationLoader
        from plain.postgres.migrations.recorder import MigrationRecorder

        conn = get_connection()
        loader = MigrationLoader(conn, ignore_no_migrations=True)
        applied = MigrationRecorder(conn).applied_migrations()

        on_disk = loader.disk_migrations
        # MigrationLoader fills disk_migrations on init; this guard exists only
        # to narrow the type for the checker.
        if on_disk is None:
            return []

        # Recorded as applied but no longer present on disk.
        stale = [key for key in applied if key not in on_disk]
        if not stale:
            return []

        known_packages = set(loader.migrated_packages)
        from_existing: list[tuple[str, str]] = []
        from_removed: list[tuple[str, str]] = []
        for migration in stale:
            package, _ = migration
            bucket = from_existing if package in known_packages else from_removed
            bucket.append(migration)

        def summarize(migrations: list[tuple[str, str]]) -> str:
            """List up to three ``package.name`` entries, noting how many were left out."""
            text = ", ".join(f"{pkg}.{name}" for pkg, name in migrations[:3])
            remaining = len(migrations) - 3
            if remaining > 0:
                text += f" (and {remaining} more)"
            return text

        count = len(stale)
        plural = "s" if count != 1 else ""
        parts = [f"Found {count} stale migration record{plural} in the database."]
        if from_existing:
            parts.append(f"From existing packages: {summarize(from_existing)}.")
        if from_removed:
            parts.append(f"From removed packages: {summarize(from_removed)}.")
        parts.append("Run 'plain migrations prune' to review and remove them.")

        return [
            PreflightResult(
                fix=" ".join(parts),
                id="postgres.prunable_migrations",
                warning=True,
            )
        ]