Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import inspect
  4from collections import defaultdict
  5from collections.abc import Callable
  6from typing import Any
  7
  8from plain.models.db import db_connection
  9from plain.models.registry import ModelsRegistry, models_registry
 10from plain.packages import packages_registry
 11from plain.preflight import PreflightCheck, PreflightResult, register_check
 12
 13
 14@register_check("models.database_backends")
 15class CheckDatabaseBackends(PreflightCheck):
 16    """Validates database backend configuration when plain.models is available."""
 17
 18    def run(self) -> list[PreflightResult]:
 19        return db_connection.validation.preflight()
 20
 21
 22@register_check("models.all_models")
 23class CheckAllModels(PreflightCheck):
 24    """Validates all model definitions for common issues."""
 25
 26    def run(self) -> list[PreflightResult]:
 27        db_table_models = defaultdict(list)
 28        indexes = defaultdict(list)
 29        constraints = defaultdict(list)
 30        errors = []
 31        models = models_registry.get_models()
 32        for model in models:
 33            db_table_models[model.model_options.db_table].append(
 34                model.model_options.label
 35            )
 36            if not inspect.ismethod(model.preflight):
 37                errors.append(
 38                    PreflightResult(
 39                        fix=f"The '{model.__name__}.preflight()' class method is currently overridden by {model.preflight!r}.",
 40                        obj=model,
 41                        id="models.preflight_method_overridden",
 42                    )
 43                )
 44            else:
 45                errors.extend(model.preflight())
 46            for model_index in model.model_options.indexes:
 47                indexes[model_index.name].append(model.model_options.label)
 48            for model_constraint in model.model_options.constraints:
 49                constraints[model_constraint.name].append(model.model_options.label)
 50        for db_table, model_labels in db_table_models.items():
 51            if len(model_labels) != 1:
 52                model_labels_str = ", ".join(model_labels)
 53                errors.append(
 54                    PreflightResult(
 55                        fix=f"db_table '{db_table}' is used by multiple models: {model_labels_str}.",
 56                        obj=db_table,
 57                        id="models.duplicate_db_table",
 58                    )
 59                )
 60        for index_name, model_labels in indexes.items():
 61            if len(model_labels) > 1:
 62                model_labels = set(model_labels)
 63                errors.append(
 64                    PreflightResult(
 65                        fix="index name '{}' is not unique {} {}.".format(
 66                            index_name,
 67                            "for model" if len(model_labels) == 1 else "among models:",
 68                            ", ".join(sorted(model_labels)),
 69                        ),
 70                        id="models.index_name_not_unique_single"
 71                        if len(model_labels) == 1
 72                        else "models.index_name_not_unique_multiple",
 73                    ),
 74                )
 75        for constraint_name, model_labels in constraints.items():
 76            if len(model_labels) > 1:
 77                model_labels = set(model_labels)
 78                errors.append(
 79                    PreflightResult(
 80                        fix="constraint name '{}' is not unique {} {}.".format(
 81                            constraint_name,
 82                            "for model" if len(model_labels) == 1 else "among models:",
 83                            ", ".join(sorted(model_labels)),
 84                        ),
 85                        id="models.constraint_name_not_unique_single"
 86                        if len(model_labels) == 1
 87                        else "models.constraint_name_not_unique_multiple",
 88                    ),
 89                )
 90        return errors
 91
 92
 93def _check_lazy_references(
 94    models_registry: ModelsRegistry, packages_registry: Any
 95) -> list[PreflightResult]:
 96    """
 97    Ensure all lazy (i.e. string) model references have been resolved.
 98
 99    Lazy references are used in various places throughout Plain, primarily in
100    related fields and model signals. Identify those common cases and provide
101    more helpful error messages for them.
102    """
103    pending_models = set(models_registry._pending_operations)
104
105    # Short circuit if there aren't any errors.
106    if not pending_models:
107        return []
108
109    def extract_operation(
110        obj: Any,
111    ) -> tuple[Callable[..., Any], list[Any], dict[str, Any]]:
112        """
113        Take a callable found in Packages._pending_operations and identify the
114        original callable passed to Packages.lazy_model_operation(). If that
115        callable was a partial, return the inner, non-partial function and
116        any arguments and keyword arguments that were supplied with it.
117
118        obj is a callback defined locally in Packages.lazy_model_operation() and
119        annotated there with a `func` attribute so as to imitate a partial.
120        """
121        operation, args, keywords = obj, [], {}
122        while hasattr(operation, "func"):
123            args.extend(getattr(operation, "args", []))
124            keywords.update(getattr(operation, "keywords", {}))
125            operation = operation.func
126        return operation, args, keywords
127
128    def app_model_error(model_key: tuple[str, str]) -> str:
129        try:
130            packages_registry.get_package_config(model_key[0])
131            model_error = "app '{}' doesn't provide model '{}'".format(*model_key)
132        except LookupError:
133            model_error = f"app '{model_key[0]}' isn't installed"
134        return model_error
135
136    # Here are several functions which return CheckMessage instances for the
137    # most common usages of lazy operations throughout Plain. These functions
138    # take the model that was being waited on as an (package_label, modelname)
139    # pair, the original lazy function, and its positional and keyword args as
140    # determined by extract_operation().
141
142    def field_error(
143        model_key: tuple[str, str],
144        func: Callable[..., Any],
145        args: list[Any],
146        keywords: dict[str, Any],
147    ) -> PreflightResult:
148        error_msg = (
149            "The field %(field)s was declared with a lazy reference "
150            "to '%(model)s', but %(model_error)s."
151        )
152        params = {
153            "model": ".".join(model_key),
154            "field": keywords["field"],
155            "model_error": app_model_error(model_key),
156        }
157        return PreflightResult(
158            fix=error_msg % params,
159            obj=keywords["field"],
160            id="fields.lazy_reference_not_resolvable",
161        )
162
163    def default_error(
164        model_key: tuple[str, str],
165        func: Callable[..., Any],
166        args: list[Any],
167        keywords: dict[str, Any],
168    ) -> PreflightResult:
169        error_msg = (
170            "%(op)s contains a lazy reference to %(model)s, but %(model_error)s."
171        )
172        params = {
173            "op": func,
174            "model": ".".join(model_key),
175            "model_error": app_model_error(model_key),
176        }
177        return PreflightResult(
178            fix=error_msg % params,
179            obj=func,
180            id="models.lazy_reference_resolution_failed",
181        )
182
183    # Maps common uses of lazy operations to corresponding error functions
184    # defined above. If a key maps to None, no error will be produced.
185    # default_error() will be used for usages that don't appear in this dict.
186    known_lazy = {
187        ("plain.models.fields.related", "resolve_related_class"): field_error,
188    }
189
190    def build_error(
191        model_key: tuple[str, str],
192        func: Callable[..., Any],
193        args: list[Any],
194        keywords: dict[str, Any],
195    ) -> PreflightResult | None:
196        key = (func.__module__, func.__name__)  # type: ignore[attr-defined]
197        error_fn = known_lazy.get(key, default_error)
198        return error_fn(model_key, func, args, keywords) if error_fn else None
199
200    return sorted(
201        filter(
202            None,
203            (
204                build_error(model_key, *extract_operation(func))
205                for model_key in pending_models
206                for func in models_registry._pending_operations[model_key]
207            ),
208        ),
209        key=lambda error: error.fix,
210    )
211
212
@register_check("models.lazy_references")
class CheckLazyReferences(PreflightCheck):
    """Report string model references that never got resolved.

    Thin adapter that runs ``_check_lazy_references`` against the global
    models and packages registries.
    """

    def run(self) -> list[PreflightResult]:
        results = _check_lazy_references(
            models_registry=models_registry,
            packages_registry=packages_registry,
        )
        return results
219
220
@register_check("models.database_tables")
class CheckDatabaseTables(PreflightCheck):
    """Checks for unknown tables in the database when plain.models is available.

    Any table returned by ``introspection.table_names()`` that is not in
    ``introspection.plain_table_names()`` is reported as a single warning.
    The one exception is ``plainmigrations``.
    """

    def run(self) -> list[PreflightResult]:
        db_tables = db_connection.introspection.table_names()
        model_tables = db_connection.introspection.plain_table_names()
        # "plainmigrations" is expected to be present even though it is not
        # in plain_table_names() (presumably the migration recorder's table).
        # The result is sorted so the warning text and the example table are
        # stable between runs: raw set iteration order varies with string
        # hash randomization.
        unknown_tables = sorted(
            set(db_tables) - set(model_tables) - {"plainmigrations"}
        )
        if not unknown_tables:
            return []

        table_names = ", ".join(unknown_tables)
        specific_fix = (
            f'echo "DROP TABLE IF EXISTS {unknown_tables[0]}" | plain db shell'
        )
        return [
            PreflightResult(
                fix=f"Unknown tables in default database: {table_names}. "
                "Tables may be from packages/models that have been uninstalled. "
                "Make sure you have a backup and delete the tables manually "
                f"(ex. `{specific_fix}`).",
                id="models.unknown_database_tables",
                warning=True,
            )
        ]
249
250
@register_check("models.prunable_migrations")
class CheckPrunableMigrations(PreflightCheck):
    """Warns about stale migration records in the database.

    A record is "prunable" when the database records it as applied but
    ``MigrationLoader`` finds no such migration on disk. Records are split by
    whether their package is still in ``loader.migrated_packages``.
    """

    # How many migrations to name per group before summarizing the rest.
    _MAX_LISTED = 3

    @staticmethod
    def _summarize(migrations: list[tuple[str, str]], limit: int) -> str:
        """Format up to ``limit`` migrations as ``pkg.name``, noting any remainder."""
        listed = ", ".join(f"{pkg}.{name}" for pkg, name in migrations[:limit])
        remaining = len(migrations) - limit
        if remaining > 0:
            listed += f" (and {remaining} more)"
        return listed

    def run(self) -> list[PreflightResult]:
        # Import here to avoid circular import issues
        from plain.models.migrations.loader import MigrationLoader
        from plain.models.migrations.recorder import MigrationRecorder

        # Load migrations from disk and database
        loader = MigrationLoader(db_connection, ignore_no_migrations=True)
        recorder = MigrationRecorder(db_connection)
        recorded_migrations = recorder.applied_migrations()

        # disk_migrations should not be None after MigrationLoader initialization,
        # but check to satisfy type checker
        if loader.disk_migrations is None:
            return []

        # Find all prunable migrations (recorded but not on disk). They are
        # sorted so the examples listed below don't depend on the order the
        # recorder happens to return them in.
        all_prunable: list[tuple[str, str]] = sorted(
            migration
            for migration in recorded_migrations
            if migration not in loader.disk_migrations
        )
        if not all_prunable:
            return []

        # Separate into existing packages vs orphaned packages
        existing_packages = set(loader.migrated_packages)
        prunable_existing = [m for m in all_prunable if m[0] in existing_packages]
        prunable_orphaned = [
            m for m in all_prunable if m[0] not in existing_packages
        ]

        # Build the warning message
        total_count = len(all_prunable)
        message_parts = [
            f"Found {total_count} stale migration record{'s' if total_count != 1 else ''} in the database."
        ]
        if prunable_existing:
            existing_list = self._summarize(prunable_existing, self._MAX_LISTED)
            message_parts.append(f"From existing packages: {existing_list}.")
        if prunable_orphaned:
            orphaned_list = self._summarize(prunable_orphaned, self._MAX_LISTED)
            message_parts.append(f"From removed packages: {orphaned_list}.")
        message_parts.append("Run 'plain migrations prune' to review and remove them.")

        return [
            PreflightResult(
                fix=" ".join(message_parts),
                id="models.prunable_migrations",
                warning=True,
            )
        ]
325
326        return errors