1from __future__ import annotations
2
3import inspect
4from collections import defaultdict
5from collections.abc import Callable
6from typing import Any
7
8from plain.models.db import db_connection
9from plain.models.migrations.recorder import MIGRATION_TABLE_NAME
10from plain.models.registry import ModelsRegistry, models_registry
11from plain.packages import packages_registry
12from plain.preflight import PreflightCheck, PreflightResult, register_check
13
14
@register_check("models.all_models")
class CheckAllModels(PreflightCheck):
    """Validates all model definitions for common issues.

    For every registered model this runs the model's own ``preflight()``
    class method, then reports conflicts that only show up across models:

    * a ``db_table`` used by more than one model,
    * an index name declared more than once,
    * a constraint name declared more than once.
    """

    @staticmethod
    def _duplicate_name_errors(
        kind: str, labels_by_name: dict[str, list[str]]
    ) -> list[PreflightResult]:
        """Report every ``kind`` name ("index" or "constraint") declared more than once.

        ``labels_by_name`` maps a name to the label of each model that declared
        it, one entry per declaration, so a model that repeats a name appears
        more than once. A name repeated inside a single model gets the
        ``*_single`` id; a name shared between different models gets
        ``*_multiple``.
        """
        results = []
        for name, model_labels in labels_by_name.items():
            if len(model_labels) <= 1:
                continue
            # Collapse repeats so each offending model is listed only once.
            unique_labels = set(model_labels)
            single_model = len(unique_labels) == 1
            results.append(
                PreflightResult(
                    fix="{} name '{}' is not unique {} {}.".format(
                        kind,
                        name,
                        "for model" if single_model else "among models:",
                        ", ".join(sorted(unique_labels)),
                    ),
                    id=(
                        f"models.{kind}_name_not_unique_single"
                        if single_model
                        else f"models.{kind}_name_not_unique_multiple"
                    ),
                )
            )
        return results

    def run(self) -> list[PreflightResult]:
        # Each dict maps a table/index/constraint name to the labels of the
        # models that use it (one entry per declaration).
        db_table_models = defaultdict(list)
        indexes = defaultdict(list)
        constraints = defaultdict(list)
        errors = []

        for model in models_registry.get_models():
            options = model.model_options
            db_table_models[options.db_table].append(options.label)

            # preflight() is meant to be a class method. If a model has replaced
            # it with something that is not a bound method (e.g. a plain
            # function or attribute), report that instead of calling it.
            if inspect.ismethod(model.preflight):
                errors.extend(model.preflight())
            else:
                errors.append(
                    PreflightResult(
                        fix=f"The '{model.__name__}.preflight()' class method is currently overridden by {model.preflight!r}.",
                        obj=model,
                        id="models.preflight_method_overridden",
                    )
                )

            for model_index in options.indexes:
                indexes[model_index.name].append(options.label)
            for model_constraint in options.constraints:
                constraints[model_constraint.name].append(options.label)

        # Every list is non-empty (entries are only created by append), so
        # "> 1" is exactly "shared by more than one model".
        for db_table, model_labels in db_table_models.items():
            if len(model_labels) > 1:
                model_labels_str = ", ".join(model_labels)
                errors.append(
                    PreflightResult(
                        fix=f"db_table '{db_table}' is used by multiple models: {model_labels_str}.",
                        obj=db_table,
                        id="models.duplicate_db_table",
                    )
                )

        errors.extend(self._duplicate_name_errors("index", indexes))
        errors.extend(self._duplicate_name_errors("constraint", constraints))
        return errors
84
85
86def _check_lazy_references(
87 models_registry: ModelsRegistry, packages_registry: Any
88) -> list[PreflightResult]:
89 """
90 Ensure all lazy (i.e. string) model references have been resolved.
91
92 Lazy references are used in various places throughout Plain, primarily in
93 related fields and model signals. Identify those common cases and provide
94 more helpful error messages for them.
95 """
96 pending_models = set(models_registry._pending_operations)
97
98 # Short circuit if there aren't any errors.
99 if not pending_models:
100 return []
101
102 def extract_operation(
103 obj: Any,
104 ) -> tuple[Callable[..., Any], list[Any], dict[str, Any]]:
105 """
106 Take a callable found in Packages._pending_operations and identify the
107 original callable passed to Packages.lazy_model_operation(). If that
108 callable was a partial, return the inner, non-partial function and
109 any arguments and keyword arguments that were supplied with it.
110
111 obj is a callback defined locally in Packages.lazy_model_operation() and
112 annotated there with a `func` attribute so as to imitate a partial.
113 """
114 operation, args, keywords = obj, [], {}
115 while hasattr(operation, "func"):
116 args.extend(getattr(operation, "args", []))
117 keywords.update(getattr(operation, "keywords", {}))
118 operation = operation.func
119 return operation, args, keywords
120
121 def app_model_error(model_key: tuple[str, str]) -> str:
122 try:
123 packages_registry.get_package_config(model_key[0])
124 model_error = "app '{}' doesn't provide model '{}'".format(*model_key)
125 except LookupError:
126 model_error = f"app '{model_key[0]}' isn't installed"
127 return model_error
128
129 # Here are several functions which return CheckMessage instances for the
130 # most common usages of lazy operations throughout Plain. These functions
131 # take the model that was being waited on as an (package_label, modelname)
132 # pair, the original lazy function, and its positional and keyword args as
133 # determined by extract_operation().
134
135 def field_error(
136 model_key: tuple[str, str],
137 func: Callable[..., Any],
138 args: list[Any],
139 keywords: dict[str, Any],
140 ) -> PreflightResult:
141 error_msg = (
142 "The field %(field)s was declared with a lazy reference "
143 "to '%(model)s', but %(model_error)s."
144 )
145 params = {
146 "model": ".".join(model_key),
147 "field": keywords["field"],
148 "model_error": app_model_error(model_key),
149 }
150 return PreflightResult(
151 fix=error_msg % params,
152 obj=keywords["field"],
153 id="fields.lazy_reference_not_resolvable",
154 )
155
156 def default_error(
157 model_key: tuple[str, str],
158 func: Callable[..., Any],
159 args: list[Any],
160 keywords: dict[str, Any],
161 ) -> PreflightResult:
162 error_msg = (
163 "%(op)s contains a lazy reference to %(model)s, but %(model_error)s."
164 )
165 params = {
166 "op": func,
167 "model": ".".join(model_key),
168 "model_error": app_model_error(model_key),
169 }
170 return PreflightResult(
171 fix=error_msg % params,
172 obj=func,
173 id="models.lazy_reference_resolution_failed",
174 )
175
176 # Maps common uses of lazy operations to corresponding error functions
177 # defined above. If a key maps to None, no error will be produced.
178 # default_error() will be used for usages that don't appear in this dict.
179 known_lazy = {
180 ("plain.models.fields.related", "resolve_related_class"): field_error,
181 }
182
183 def build_error(
184 model_key: tuple[str, str],
185 func: Callable[..., Any],
186 args: list[Any],
187 keywords: dict[str, Any],
188 ) -> PreflightResult | None:
189 key = (func.__module__, func.__name__) # type: ignore[attr-defined]
190 error_fn = known_lazy.get(key, default_error)
191 return error_fn(model_key, func, args, keywords) if error_fn else None
192
193 return sorted(
194 filter(
195 None,
196 (
197 build_error(model_key, *extract_operation(func))
198 for model_key in pending_models
199 for func in models_registry._pending_operations[model_key]
200 ),
201 ),
202 key=lambda error: error.fix,
203 )
204
205
@register_check("models.lazy_references")
class CheckLazyReferences(PreflightCheck):
    """Report lazy (string) model references that never got resolved."""

    def run(self) -> list[PreflightResult]:
        # Hand the module-level registries to the helper explicitly so the
        # helper itself does not read global state.
        results = _check_lazy_references(
            models_registry=models_registry,
            packages_registry=packages_registry,
        )
        return results
212
213
@register_check("models.database_tables")
class CheckDatabaseTables(PreflightCheck):
    """Warn about database tables that Plain does not account for.

    A table counts as unknown when ``db_connection.table_names()`` lists it
    but it is neither in ``db_connection.plain_table_names()`` nor the
    migration-recording table (``MIGRATION_TABLE_NAME``).
    """

    def run(self) -> list[PreflightResult]:
        present = set(db_connection.table_names())
        expected = set(db_connection.plain_table_names())
        expected.add(MIGRATION_TABLE_NAME)
        unknown_tables = present.difference(expected)

        if unknown_tables:
            # Sorted so the warning text is stable between runs.
            table_names = ", ".join(sorted(unknown_tables))
            warning = PreflightResult(
                fix=f"Unknown tables in default database: {table_names}. "
                "Tables may be from packages/models that have been uninstalled. "
                "Make sure you have a backup, then run `plain db drop-unknown-tables` to remove them.",
                id="models.unknown_database_tables",
                warning=True,
            )
            return [warning]

        return []
238
239
@register_check("models.prunable_migrations")
class CheckPrunableMigrations(PreflightCheck):
    """Warns about stale migration records in the database.

    A record is stale ("prunable") when the migration recorder lists it as
    applied but the migration loader found no migration with that key on
    disk. Stale records are split by whether their package is in
    ``loader.migrated_packages``, and one warning summarizes both groups.
    """

    # Maximum number of migrations named per group before the remainder is
    # summarized as "(and N more)".
    _MAX_LISTED = 3

    @staticmethod
    def _format_migrations(migrations: list[tuple[str, str]], limit: int) -> str:
        """Return "pkg.name, ..." for up to ``limit`` migrations.

        When some migrations are left out, the text " (and N more)" is
        appended.
        """
        listed = ", ".join(f"{pkg}.{name}" for pkg, name in migrations[:limit])
        remaining = len(migrations) - limit
        if remaining > 0:
            listed += f" (and {remaining} more)"
        return listed

    def run(self) -> list[PreflightResult]:
        # Import here to avoid circular import issues
        from plain.models.migrations.loader import MigrationLoader
        from plain.models.migrations.recorder import MigrationRecorder

        # Load migrations from disk and database
        loader = MigrationLoader(db_connection, ignore_no_migrations=True)
        recorder = MigrationRecorder(db_connection)
        recorded_migrations = recorder.applied_migrations()

        # disk_migrations should not be None after MigrationLoader
        # initialization, but check to satisfy the type checker. The local
        # alias keeps that narrowing in effect for the comprehension below.
        disk_migrations = loader.disk_migrations
        if disk_migrations is None:
            return []

        # Prunable = recorded as applied but not present on disk.
        all_prunable = [
            migration
            for migration in recorded_migrations
            if migration not in disk_migrations
        ]
        if not all_prunable:
            return []

        # Separate into existing packages vs orphaned packages
        existing_packages = set(loader.migrated_packages)
        prunable_existing: list[tuple[str, str]] = []
        prunable_orphaned: list[tuple[str, str]] = []
        for migration in all_prunable:
            package, _name = migration
            if package in existing_packages:
                prunable_existing.append(migration)
            else:
                prunable_orphaned.append(migration)

        # Build the warning message
        total_count = len(all_prunable)
        message_parts = [
            f"Found {total_count} stale migration record{'s' if total_count != 1 else ''} in the database."
        ]
        if prunable_existing:
            existing_list = self._format_migrations(
                prunable_existing, self._MAX_LISTED
            )
            message_parts.append(f"From existing packages: {existing_list}.")
        if prunable_orphaned:
            orphaned_list = self._format_migrations(
                prunable_orphaned, self._MAX_LISTED
            )
            message_parts.append(f"From removed packages: {orphaned_list}.")
        message_parts.append("Run 'plain migrations prune' to review and remove them.")

        return [
            PreflightResult(
                fix=" ".join(message_parts),
                id="models.prunable_migrations",
                warning=True,
            )
        ]