1from __future__ import annotations
2
3import inspect
4from collections import defaultdict
5from collections.abc import Callable
6from typing import Any
7
8from plain.models.db import db_connection
9from plain.models.registry import ModelsRegistry, models_registry
10from plain.packages import packages_registry
11from plain.preflight import PreflightCheck, PreflightResult, register_check
12
13
@register_check("models.database_backends")
class CheckDatabaseBackends(PreflightCheck):
    """
    Validates database backend configuration when plain.models is available.

    The work is delegated entirely to the connection's ``validation``
    object; its ``preflight()`` results are returned unchanged.
    """

    def run(self) -> list[PreflightResult]:
        backend_validation = db_connection.validation
        return backend_validation.preflight()
20
21
@register_check("models.all_models")
class CheckAllModels(PreflightCheck):
    """
    Validates all model definitions for common issues.

    Collects each registered model's own ``preflight()`` results, then looks
    for problems that only appear across models:

    * a ``db_table`` used by more than one model,
    * index names that are reused,
    * constraint names that are reused.

    Results are ordered: per-model results (in registry order), then table
    clashes, then index-name clashes, then constraint-name clashes.
    """

    def run(self) -> list[PreflightResult]:
        results = []
        labels_by_table = defaultdict(list)
        labels_by_index = defaultdict(list)
        labels_by_constraint = defaultdict(list)

        for model in models_registry.get_models():
            options = model.model_options
            labels_by_table[options.db_table].append(options.label)

            # preflight() is expected to be a bound (class) method. If it has
            # been replaced by something else, report what replaced it rather
            # than calling it.
            if inspect.ismethod(model.preflight):
                results.extend(model.preflight())
            else:
                results.append(
                    PreflightResult(
                        fix=f"The '{model.__name__}.preflight()' class method is currently overridden by {model.preflight!r}.",
                        obj=model,
                        id="models.preflight_method_overridden",
                    )
                )

            for index in options.indexes:
                labels_by_index[index.name].append(options.label)
            for constraint in options.constraints:
                labels_by_constraint[constraint.name].append(options.label)

        for db_table, labels in labels_by_table.items():
            if len(labels) == 1:
                continue
            joined_labels = ", ".join(labels)
            results.append(
                PreflightResult(
                    fix=f"db_table '{db_table}' is used by multiple models: {joined_labels}.",
                    obj=db_table,
                    id="models.duplicate_db_table",
                )
            )

        def report_reused_names(
            kind: str, labels_by_name: dict[Any, list[str]]
        ) -> None:
            """Append one result per ``kind`` name seen more than once.

            The "_single" id is used when every occurrence belongs to the same
            model; otherwise the "_multiple" id lists all models involved.
            """
            for name, labels in labels_by_name.items():
                if len(labels) <= 1:
                    continue
                distinct_labels = set(labels)
                one_model = len(distinct_labels) == 1
                scope = "for model" if one_model else "among models:"
                variant = "single" if one_model else "multiple"
                listing = ", ".join(sorted(distinct_labels))
                results.append(
                    PreflightResult(
                        fix=f"{kind} name '{name}' is not unique {scope} {listing}.",
                        id=f"models.{kind}_name_not_unique_{variant}",
                    )
                )

        report_reused_names("index", labels_by_index)
        report_reused_names("constraint", labels_by_constraint)
        return results
91
92
93def _check_lazy_references(
94 models_registry: ModelsRegistry, packages_registry: Any
95) -> list[PreflightResult]:
96 """
97 Ensure all lazy (i.e. string) model references have been resolved.
98
99 Lazy references are used in various places throughout Plain, primarily in
100 related fields and model signals. Identify those common cases and provide
101 more helpful error messages for them.
102 """
103 pending_models = set(models_registry._pending_operations)
104
105 # Short circuit if there aren't any errors.
106 if not pending_models:
107 return []
108
109 def extract_operation(
110 obj: Any,
111 ) -> tuple[Callable[..., Any], list[Any], dict[str, Any]]:
112 """
113 Take a callable found in Packages._pending_operations and identify the
114 original callable passed to Packages.lazy_model_operation(). If that
115 callable was a partial, return the inner, non-partial function and
116 any arguments and keyword arguments that were supplied with it.
117
118 obj is a callback defined locally in Packages.lazy_model_operation() and
119 annotated there with a `func` attribute so as to imitate a partial.
120 """
121 operation, args, keywords = obj, [], {}
122 while hasattr(operation, "func"):
123 args.extend(getattr(operation, "args", []))
124 keywords.update(getattr(operation, "keywords", {}))
125 operation = operation.func
126 return operation, args, keywords
127
128 def app_model_error(model_key: tuple[str, str]) -> str:
129 try:
130 packages_registry.get_package_config(model_key[0])
131 model_error = "app '{}' doesn't provide model '{}'".format(*model_key)
132 except LookupError:
133 model_error = f"app '{model_key[0]}' isn't installed"
134 return model_error
135
136 # Here are several functions which return CheckMessage instances for the
137 # most common usages of lazy operations throughout Plain. These functions
138 # take the model that was being waited on as an (package_label, modelname)
139 # pair, the original lazy function, and its positional and keyword args as
140 # determined by extract_operation().
141
142 def field_error(
143 model_key: tuple[str, str],
144 func: Callable[..., Any],
145 args: list[Any],
146 keywords: dict[str, Any],
147 ) -> PreflightResult:
148 error_msg = (
149 "The field %(field)s was declared with a lazy reference "
150 "to '%(model)s', but %(model_error)s."
151 )
152 params = {
153 "model": ".".join(model_key),
154 "field": keywords["field"],
155 "model_error": app_model_error(model_key),
156 }
157 return PreflightResult(
158 fix=error_msg % params,
159 obj=keywords["field"],
160 id="fields.lazy_reference_not_resolvable",
161 )
162
163 def default_error(
164 model_key: tuple[str, str],
165 func: Callable[..., Any],
166 args: list[Any],
167 keywords: dict[str, Any],
168 ) -> PreflightResult:
169 error_msg = (
170 "%(op)s contains a lazy reference to %(model)s, but %(model_error)s."
171 )
172 params = {
173 "op": func,
174 "model": ".".join(model_key),
175 "model_error": app_model_error(model_key),
176 }
177 return PreflightResult(
178 fix=error_msg % params,
179 obj=func,
180 id="models.lazy_reference_resolution_failed",
181 )
182
183 # Maps common uses of lazy operations to corresponding error functions
184 # defined above. If a key maps to None, no error will be produced.
185 # default_error() will be used for usages that don't appear in this dict.
186 known_lazy = {
187 ("plain.models.fields.related", "resolve_related_class"): field_error,
188 }
189
190 def build_error(
191 model_key: tuple[str, str],
192 func: Callable[..., Any],
193 args: list[Any],
194 keywords: dict[str, Any],
195 ) -> PreflightResult | None:
196 key = (func.__module__, func.__name__) # type: ignore[attr-defined]
197 error_fn = known_lazy.get(key, default_error)
198 return error_fn(model_key, func, args, keywords) if error_fn else None
199
200 return sorted(
201 filter(
202 None,
203 (
204 build_error(model_key, *extract_operation(func))
205 for model_key in pending_models
206 for func in models_registry._pending_operations[model_key]
207 ),
208 ),
209 key=lambda error: error.fix,
210 )
211
212
@register_check("models.lazy_references")
class CheckLazyReferences(PreflightCheck):
    """
    Ensures all lazy (string) model references have been resolved.

    Thin adapter that runs ``_check_lazy_references`` against the global
    models and packages registries.
    """

    def run(self) -> list[PreflightResult]:
        unresolved = _check_lazy_references(
            models_registry=models_registry,
            packages_registry=packages_registry,
        )
        return unresolved
219
220
@register_check("models.database_tables")
class CheckDatabaseTables(PreflightCheck):
    """
    Checks for unknown tables in the database when plain.models is available.

    A table is "unknown" when the database reports it via
    ``introspection.table_names()`` but it is absent from
    ``introspection.plain_table_names()``. The ``plainmigrations`` table is
    never reported. At most one warning is produced; it lists every unknown
    table and suggests a ``DROP TABLE`` command for one of them.
    """

    def run(self) -> list[PreflightResult]:
        db_tables = db_connection.introspection.table_names()
        model_tables = db_connection.introspection.plain_table_names()

        # "plainmigrations" may legitimately be present, so it is excluded.
        # Sort the rest: iterating a set of strings depends on per-process
        # hash randomization. Without sorting, both the listed order and the
        # example table in the suggested command would change between runs.
        unknown_tables = sorted(
            set(db_tables) - set(model_tables) - {"plainmigrations"}
        )
        if not unknown_tables:
            return []

        table_names = ", ".join(unknown_tables)
        specific_fix = (
            f'echo "DROP TABLE IF EXISTS {unknown_tables[0]}" | plain db shell'
        )
        return [
            PreflightResult(
                fix=f"Unknown tables in default database: {table_names}. "
                "Tables may be from packages/models that have been uninstalled. "
                "Make sure you have a backup and delete the tables manually "
                f"(ex. `{specific_fix}`).",
                id="models.unknown_database_tables",
                warning=True,
            )
        ]
249
250
@register_check("models.prunable_migrations")
class CheckPrunableMigrations(PreflightCheck):
    """
    Warns about stale migration records in the database.

    A record is stale when the migration recorder lists it as applied but
    the loader found no such migration on disk. Stale records are grouped by
    whether their package still has migrations. They are summarized in a
    single warning that points to ``plain migrations prune``.
    """

    def run(self) -> list[PreflightResult]:
        # Imported locally to avoid circular import issues.
        from plain.models.migrations.loader import MigrationLoader
        from plain.models.migrations.recorder import MigrationRecorder

        # Read migrations from disk (loader) and from the database (recorder).
        loader = MigrationLoader(db_connection, ignore_no_migrations=True)
        applied = MigrationRecorder(db_connection).applied_migrations()

        # The loader should have populated this already; the None guard
        # mainly satisfies the type checker.
        on_disk = loader.disk_migrations
        if on_disk is None:
            return []

        stale = [key for key in applied if key not in on_disk]
        if not stale:
            return []

        # Split by whether the record's package still has migrations.
        known_packages = set(loader.migrated_packages)
        from_existing: list[tuple[str, str]] = []
        from_removed: list[tuple[str, str]] = []
        for key in stale:
            package_label, _ = key
            bucket = from_existing if package_label in known_packages else from_removed
            bucket.append(key)

        def preview(keys: list[tuple[str, str]]) -> str:
            """Render up to three "package.name" entries plus an overflow note."""
            shown = ", ".join(f"{pkg}.{name}" for pkg, name in keys[:3])
            hidden = len(keys) - 3
            if hidden > 0:
                shown += f" (and {hidden} more)"
            return shown

        count = len(stale)
        plural = "s" if count != 1 else ""
        parts = [f"Found {count} stale migration record{plural} in the database."]
        if from_existing:
            parts.append(f"From existing packages: {preview(from_existing)}.")
        if from_removed:
            parts.append(f"From removed packages: {preview(from_removed)}.")
        parts.append("Run 'plain migrations prune' to review and remove them.")

        return [
            PreflightResult(
                fix=" ".join(parts),
                id="models.prunable_migrations",
                warning=True,
            )
        ]