1from __future__ import annotations
2
3import inspect
4from collections import defaultdict
5from collections.abc import Callable
6from typing import Any
7
8from plain.models.db import db_connection
9from plain.models.registry import ModelsRegistry, models_registry
10from plain.packages import packages_registry
11from plain.preflight import PreflightCheck, PreflightResult, register_check
12
13
@register_check("models.database_backends")
class CheckDatabaseBackends(PreflightCheck):
    """Run the database backend's own preflight validation.

    All of the work is delegated to ``db_connection.validation.preflight()``;
    its results are returned unchanged.
    """

    def run(self) -> list[PreflightResult]:
        backend_validation = db_connection.validation
        return backend_validation.preflight()
20
21
@register_check("models.all_models")
class CheckAllModels(PreflightCheck):
    """Validates all model definitions for common issues.

    For every registered model this runs the model's own ``preflight()``
    class method. It then performs cross-model checks that no single model
    can do on its own:

    * each ``db_table`` is used by exactly one model
      (``models.duplicate_db_table``);
    * index names are not repeated (``models.index_name_not_unique_single`` /
      ``models.index_name_not_unique_multiple``);
    * constraint names are not repeated
      (``models.constraint_name_not_unique_single`` /
      ``models.constraint_name_not_unique_multiple``).
    """

    def run(self) -> list[PreflightResult]:
        # Each map goes from a db_table / index name / constraint name to the
        # labels of every model declaring it; collisions are reported below.
        db_table_models = defaultdict(list)
        indexes = defaultdict(list)
        constraints = defaultdict(list)
        errors: list[PreflightResult] = []

        for model in models_registry.get_models():
            options = model.model_options
            db_table_models[options.db_table].append(options.label)

            # preflight() must still be a bound class method; if something on
            # the model shadows it, report that instead of calling it.
            if not inspect.ismethod(model.preflight):
                errors.append(
                    PreflightResult(
                        fix=f"The '{model.__name__}.preflight()' class method is currently overridden by {model.preflight!r}.",
                        obj=model,
                        id="models.preflight_method_overridden",
                    )
                )
            else:
                errors.extend(model.preflight())

            for model_index in options.indexes:
                indexes[model_index.name].append(options.label)
            for model_constraint in options.constraints:
                constraints[model_constraint.name].append(options.label)

        for db_table, model_labels in db_table_models.items():
            if len(model_labels) != 1:
                model_labels_str = ", ".join(model_labels)
                errors.append(
                    PreflightResult(
                        fix=f"db_table '{db_table}' is used by multiple models: {model_labels_str}.",
                        obj=db_table,
                        id="models.duplicate_db_table",
                    )
                )

        errors.extend(self._non_unique_name_errors("index", indexes))
        errors.extend(self._non_unique_name_errors("constraint", constraints))
        return errors

    @staticmethod
    def _non_unique_name_errors(
        kind: str, names_to_labels: dict[Any, list[str]]
    ) -> list[PreflightResult]:
        """Report every name declared more than once.

        Args:
            kind: ``"index"`` or ``"constraint"``. It is used both in the
                message and in the result id
                (``models.<kind>_name_not_unique_single|multiple``).
            names_to_labels: Maps each name to the label of every model
                declaring it (one entry per declaration).

        A name repeated within a single model yields the ``_single`` id. A
        name shared by different models yields the ``_multiple`` id.
        """
        results: list[PreflightResult] = []
        for name, labels in names_to_labels.items():
            if len(labels) <= 1:
                continue
            unique_labels = set(labels)
            single_model = len(unique_labels) == 1
            results.append(
                PreflightResult(
                    fix="{} name '{}' is not unique {} {}.".format(
                        kind,
                        name,
                        "for model" if single_model else "among models:",
                        ", ".join(sorted(unique_labels)),
                    ),
                    id=f"models.{kind}_name_not_unique_single"
                    if single_model
                    else f"models.{kind}_name_not_unique_multiple",
                )
            )
        return results
91
92
93def _check_lazy_references(
94 models_registry: ModelsRegistry, packages_registry: Any
95) -> list[PreflightResult]:
96 """
97 Ensure all lazy (i.e. string) model references have been resolved.
98
99 Lazy references are used in various places throughout Plain, primarily in
100 related fields and model signals. Identify those common cases and provide
101 more helpful error messages for them.
102 """
103 pending_models = set(models_registry._pending_operations)
104
105 # Short circuit if there aren't any errors.
106 if not pending_models:
107 return []
108
109 def extract_operation(
110 obj: Any,
111 ) -> tuple[Callable[..., Any], list[Any], dict[str, Any]]:
112 """
113 Take a callable found in Packages._pending_operations and identify the
114 original callable passed to Packages.lazy_model_operation(). If that
115 callable was a partial, return the inner, non-partial function and
116 any arguments and keyword arguments that were supplied with it.
117
118 obj is a callback defined locally in Packages.lazy_model_operation() and
119 annotated there with a `func` attribute so as to imitate a partial.
120 """
121 operation, args, keywords = obj, [], {}
122 while hasattr(operation, "func"):
123 args.extend(getattr(operation, "args", []))
124 keywords.update(getattr(operation, "keywords", {}))
125 operation = operation.func
126 return operation, args, keywords
127
128 def app_model_error(model_key: tuple[str, str]) -> str:
129 try:
130 packages_registry.get_package_config(model_key[0])
131 model_error = "app '{}' doesn't provide model '{}'".format(*model_key)
132 except LookupError:
133 model_error = f"app '{model_key[0]}' isn't installed"
134 return model_error
135
136 # Here are several functions which return CheckMessage instances for the
137 # most common usages of lazy operations throughout Plain. These functions
138 # take the model that was being waited on as an (package_label, modelname)
139 # pair, the original lazy function, and its positional and keyword args as
140 # determined by extract_operation().
141
142 def field_error(
143 model_key: tuple[str, str],
144 func: Callable[..., Any],
145 args: list[Any],
146 keywords: dict[str, Any],
147 ) -> PreflightResult:
148 error_msg = (
149 "The field %(field)s was declared with a lazy reference "
150 "to '%(model)s', but %(model_error)s."
151 )
152 params = {
153 "model": ".".join(model_key),
154 "field": keywords["field"],
155 "model_error": app_model_error(model_key),
156 }
157 return PreflightResult(
158 fix=error_msg % params,
159 obj=keywords["field"],
160 id="fields.lazy_reference_not_resolvable",
161 )
162
163 def default_error(
164 model_key: tuple[str, str],
165 func: Callable[..., Any],
166 args: list[Any],
167 keywords: dict[str, Any],
168 ) -> PreflightResult:
169 error_msg = (
170 "%(op)s contains a lazy reference to %(model)s, but %(model_error)s."
171 )
172 params = {
173 "op": func,
174 "model": ".".join(model_key),
175 "model_error": app_model_error(model_key),
176 }
177 return PreflightResult(
178 fix=error_msg % params,
179 obj=func,
180 id="models.lazy_reference_resolution_failed",
181 )
182
183 # Maps common uses of lazy operations to corresponding error functions
184 # defined above. If a key maps to None, no error will be produced.
185 # default_error() will be used for usages that don't appear in this dict.
186 known_lazy = {
187 ("plain.models.fields.related", "resolve_related_class"): field_error,
188 }
189
190 def build_error(
191 model_key: tuple[str, str],
192 func: Callable[..., Any],
193 args: list[Any],
194 keywords: dict[str, Any],
195 ) -> PreflightResult | None:
196 key = (func.__module__, func.__name__) # type: ignore[attr-defined]
197 error_fn = known_lazy.get(key, default_error)
198 return error_fn(model_key, func, args, keywords) if error_fn else None
199
200 return sorted(
201 filter(
202 None,
203 (
204 build_error(model_key, *extract_operation(func))
205 for model_key in pending_models
206 for func in models_registry._pending_operations[model_key]
207 ),
208 ),
209 key=lambda error: error.message,
210 )
211
212
@register_check("models.lazy_references")
class CheckLazyReferences(PreflightCheck):
    """Report string model references that never resolved to a real model.

    Thin adapter: it hands the global models and packages registries to
    ``_check_lazy_references`` and returns its results.
    """

    def run(self) -> list[PreflightResult]:
        unresolved = _check_lazy_references(
            models_registry=models_registry,
            packages_registry=packages_registry,
        )
        return unresolved
219
220
@register_check("models.database_tables")
class CheckDatabaseTables(PreflightCheck):
    """Checks for unknown tables in the database when plain.models is available.

    A table is "unknown" when database introspection lists it but it is not
    in ``plain_table_names()``. The ``plainmigrations`` table is always
    ignored. Any findings are reported as a single warning.
    """

    def run(self) -> list[PreflightResult]:
        errors: list[PreflightResult] = []

        db_tables = db_connection.introspection.table_names()
        model_tables = db_connection.introspection.plain_table_names()
        unknown_tables = set(db_tables) - set(model_tables)
        unknown_tables.discard("plainmigrations")  # Know this could be there

        if unknown_tables:
            # Sort so the listed tables and the example command are the same
            # on every run; iterating a set of strings depends on hash order.
            sorted_tables = sorted(unknown_tables)
            table_names = ", ".join(sorted_tables)
            specific_fix = f'echo "DROP TABLE IF EXISTS {sorted_tables[0]}" | plain models db-shell'
            errors.append(
                PreflightResult(
                    fix=f"Unknown tables in default database: {table_names}. "
                    "Tables may be from packages/models that have been uninstalled. "
                    "Make sure you have a backup and delete the tables manually "
                    f"(ex. `{specific_fix}`).",
                    id="models.unknown_database_tables",
                    warning=True,
                )
            )

        return errors
247
248
@register_check("models.prunable_migrations")
class CheckPrunableMigrations(PreflightCheck):
    """Warns about stale migration records in the database.

    A record is stale ("prunable") when the migration recorder reports it as
    applied but its key is not in the loader's ``disk_migrations``. Stale
    records are split into two groups, based on whether their package is
    still in ``loader.migrated_packages``: "existing" and "removed". Both
    groups are summarized in one warning.
    """

    # How many migration keys to spell out per group before summarizing the
    # remainder as "(and N more)".
    _MAX_LISTED = 3

    def run(self) -> list[PreflightResult]:
        # Import here to avoid circular import issues
        from plain.models.migrations.loader import MigrationLoader
        from plain.models.migrations.recorder import MigrationRecorder

        # Load migrations from disk and database
        loader = MigrationLoader(db_connection, ignore_no_migrations=True)
        recorder = MigrationRecorder(db_connection)
        recorded_migrations = recorder.applied_migrations()

        # disk_migrations should not be None after MigrationLoader
        # initialization, but check to satisfy the type checker.
        if loader.disk_migrations is None:
            return []

        # Prunable = recorded in the database but not found on disk.
        all_prunable = [
            migration
            for migration in recorded_migrations
            if migration not in loader.disk_migrations
        ]
        if not all_prunable:
            return []

        # Separate into packages that still have migrations vs. orphaned ones.
        existing_packages = set(loader.migrated_packages)
        prunable_existing: list[tuple[str, str]] = []
        prunable_orphaned: list[tuple[str, str]] = []
        for migration in all_prunable:
            package, _name = migration
            if package in existing_packages:
                prunable_existing.append(migration)
            else:
                prunable_orphaned.append(migration)

        total_count = len(all_prunable)
        plural = "s" if total_count != 1 else ""
        message_parts = [
            f"Found {total_count} stale migration record{plural} in the database."
        ]
        if prunable_existing:
            message_parts.append(
                f"From existing packages: {self._format_migration_list(prunable_existing)}."
            )
        if prunable_orphaned:
            message_parts.append(
                f"From removed packages: {self._format_migration_list(prunable_orphaned)}."
            )
        message_parts.append(
            "Run 'plain models prune-migrations' to review and remove them."
        )

        return [
            PreflightResult(
                fix=" ".join(message_parts),
                id="models.prunable_migrations",
                warning=True,
            )
        ]

    @classmethod
    def _format_migration_list(cls, migrations: list[tuple[str, str]]) -> str:
        """Render up to ``_MAX_LISTED`` keys as ``package.name``.

        If any keys are left over, append "(and N more)".
        """
        listed = ", ".join(
            f"{pkg}.{name}" for pkg, name in migrations[: cls._MAX_LISTED]
        )
        remaining = len(migrations) - cls._MAX_LISTED
        if remaining > 0:
            listed += f" (and {remaining} more)"
        return listed