1from __future__ import annotations
  2
  3import inspect
  4from collections import defaultdict
  5from collections.abc import Callable
  6from typing import Any
  7
  8from plain.models.db import db_connection
  9from plain.models.migrations.recorder import MIGRATION_TABLE_NAME
 10from plain.models.registry import ModelsRegistry, models_registry
 11from plain.packages import packages_registry
 12from plain.preflight import PreflightCheck, PreflightResult, register_check
 13
 14
 15@register_check("models.all_models")
 16class CheckAllModels(PreflightCheck):
 17    """Validates all model definitions for common issues."""
 18
 19    def run(self) -> list[PreflightResult]:
 20        db_table_models = defaultdict(list)
 21        indexes = defaultdict(list)
 22        constraints = defaultdict(list)
 23        errors = []
 24        models = models_registry.get_models()
 25        for model in models:
 26            db_table_models[model.model_options.db_table].append(
 27                model.model_options.label
 28            )
 29            if not inspect.ismethod(model.preflight):
 30                errors.append(
 31                    PreflightResult(
 32                        fix=f"The '{model.__name__}.preflight()' class method is currently overridden by {model.preflight!r}.",
 33                        obj=model,
 34                        id="models.preflight_method_overridden",
 35                    )
 36                )
 37            else:
 38                errors.extend(model.preflight())
 39            for model_index in model.model_options.indexes:
 40                indexes[model_index.name].append(model.model_options.label)
 41            for model_constraint in model.model_options.constraints:
 42                constraints[model_constraint.name].append(model.model_options.label)
 43        for db_table, model_labels in db_table_models.items():
 44            if len(model_labels) != 1:
 45                model_labels_str = ", ".join(model_labels)
 46                errors.append(
 47                    PreflightResult(
 48                        fix=f"db_table '{db_table}' is used by multiple models: {model_labels_str}.",
 49                        obj=db_table,
 50                        id="models.duplicate_db_table",
 51                    )
 52                )
 53        for index_name, model_labels in indexes.items():
 54            if len(model_labels) > 1:
 55                model_labels = set(model_labels)
 56                errors.append(
 57                    PreflightResult(
 58                        fix="index name '{}' is not unique {} {}.".format(
 59                            index_name,
 60                            "for model" if len(model_labels) == 1 else "among models:",
 61                            ", ".join(sorted(model_labels)),
 62                        ),
 63                        id="models.index_name_not_unique_single"
 64                        if len(model_labels) == 1
 65                        else "models.index_name_not_unique_multiple",
 66                    ),
 67                )
 68        for constraint_name, model_labels in constraints.items():
 69            if len(model_labels) > 1:
 70                model_labels = set(model_labels)
 71                errors.append(
 72                    PreflightResult(
 73                        fix="constraint name '{}' is not unique {} {}.".format(
 74                            constraint_name,
 75                            "for model" if len(model_labels) == 1 else "among models:",
 76                            ", ".join(sorted(model_labels)),
 77                        ),
 78                        id="models.constraint_name_not_unique_single"
 79                        if len(model_labels) == 1
 80                        else "models.constraint_name_not_unique_multiple",
 81                    ),
 82                )
 83        return errors
 84
 85
 86def _check_lazy_references(
 87    models_registry: ModelsRegistry, packages_registry: Any
 88) -> list[PreflightResult]:
 89    """
 90    Ensure all lazy (i.e. string) model references have been resolved.
 91
 92    Lazy references are used in various places throughout Plain, primarily in
 93    related fields and model signals. Identify those common cases and provide
 94    more helpful error messages for them.
 95    """
 96    pending_models = set(models_registry._pending_operations)
 97
 98    # Short circuit if there aren't any errors.
 99    if not pending_models:
100        return []
101
102    def extract_operation(
103        obj: Any,
104    ) -> tuple[Callable[..., Any], list[Any], dict[str, Any]]:
105        """
106        Take a callable found in Packages._pending_operations and identify the
107        original callable passed to Packages.lazy_model_operation(). If that
108        callable was a partial, return the inner, non-partial function and
109        any arguments and keyword arguments that were supplied with it.
110
111        obj is a callback defined locally in Packages.lazy_model_operation() and
112        annotated there with a `func` attribute so as to imitate a partial.
113        """
114        operation, args, keywords = obj, [], {}
115        while hasattr(operation, "func"):
116            args.extend(getattr(operation, "args", []))
117            keywords.update(getattr(operation, "keywords", {}))
118            operation = operation.func
119        return operation, args, keywords
120
121    def app_model_error(model_key: tuple[str, str]) -> str:
122        try:
123            packages_registry.get_package_config(model_key[0])
124            model_error = "app '{}' doesn't provide model '{}'".format(*model_key)
125        except LookupError:
126            model_error = f"app '{model_key[0]}' isn't installed"
127        return model_error
128
129    # Here are several functions which return CheckMessage instances for the
130    # most common usages of lazy operations throughout Plain. These functions
131    # take the model that was being waited on as an (package_label, modelname)
132    # pair, the original lazy function, and its positional and keyword args as
133    # determined by extract_operation().
134
135    def field_error(
136        model_key: tuple[str, str],
137        func: Callable[..., Any],
138        args: list[Any],
139        keywords: dict[str, Any],
140    ) -> PreflightResult:
141        error_msg = (
142            "The field %(field)s was declared with a lazy reference "
143            "to '%(model)s', but %(model_error)s."
144        )
145        params = {
146            "model": ".".join(model_key),
147            "field": keywords["field"],
148            "model_error": app_model_error(model_key),
149        }
150        return PreflightResult(
151            fix=error_msg % params,
152            obj=keywords["field"],
153            id="fields.lazy_reference_not_resolvable",
154        )
155
156    def default_error(
157        model_key: tuple[str, str],
158        func: Callable[..., Any],
159        args: list[Any],
160        keywords: dict[str, Any],
161    ) -> PreflightResult:
162        error_msg = (
163            "%(op)s contains a lazy reference to %(model)s, but %(model_error)s."
164        )
165        params = {
166            "op": func,
167            "model": ".".join(model_key),
168            "model_error": app_model_error(model_key),
169        }
170        return PreflightResult(
171            fix=error_msg % params,
172            obj=func,
173            id="models.lazy_reference_resolution_failed",
174        )
175
176    # Maps common uses of lazy operations to corresponding error functions
177    # defined above. If a key maps to None, no error will be produced.
178    # default_error() will be used for usages that don't appear in this dict.
179    known_lazy = {
180        ("plain.models.fields.related", "resolve_related_class"): field_error,
181    }
182
183    def build_error(
184        model_key: tuple[str, str],
185        func: Callable[..., Any],
186        args: list[Any],
187        keywords: dict[str, Any],
188    ) -> PreflightResult | None:
189        key = (func.__module__, func.__name__)  # type: ignore[attr-defined]
190        error_fn = known_lazy.get(key, default_error)
191        return error_fn(model_key, func, args, keywords) if error_fn else None
192
193    return sorted(
194        filter(
195            None,
196            (
197                build_error(model_key, *extract_operation(func))
198                for model_key in pending_models
199                for func in models_registry._pending_operations[model_key]
200            ),
201        ),
202        key=lambda error: error.fix,
203    )
204
205
@register_check("models.lazy_references")
class CheckLazyReferences(PreflightCheck):
    """Preflight check reporting lazy (string) model references left unresolved."""

    def run(self) -> list[PreflightResult]:
        """Run _check_lazy_references against the module-level registries."""
        unresolved = _check_lazy_references(
            models_registry=models_registry,
            packages_registry=packages_registry,
        )
        return unresolved
212
213
@register_check("models.database_tables")
class CheckDatabaseTables(PreflightCheck):
    """Warns about database tables that no model or the migration recorder owns."""

    def run(self) -> list[PreflightResult]:
        """
        Compare the tables reported by the connection against the expected ones.

        Any table returned by ``table_names()`` that is neither in
        ``plain_table_names()`` nor the migration table is reported in a
        single warning result; an empty list is returned otherwise.
        """
        present = set(db_connection.table_names())
        expected = set(db_connection.plain_table_names())
        unknown_tables = present.difference(expected, {MIGRATION_TABLE_NAME})

        if unknown_tables:
            table_names = ", ".join(sorted(unknown_tables))
            warning = PreflightResult(
                fix=f"Unknown tables in default database: {table_names}. "
                "Tables may be from packages/models that have been uninstalled. "
                "Make sure you have a backup, then run `plain db drop-unknown-tables` to remove them.",
                id="models.unknown_database_tables",
                warning=True,
            )
            return [warning]

        return []
238
239
@register_check("models.prunable_migrations")
class CheckPrunableMigrations(PreflightCheck):
    """Warns when the database records migrations that are not found on disk."""

    def run(self) -> list[PreflightResult]:
        """
        Build a single warning summarising stale migration records.

        A record is stale when it is in the recorder's applied migrations
        but not in the loader's ``disk_migrations``. Stale records are split
        by whether their package is among the loader's ``migrated_packages``,
        and up to three of each group are listed in the message.
        """
        # Imported locally to avoid circular import issues.
        from plain.models.migrations.loader import MigrationLoader
        from plain.models.migrations.recorder import MigrationRecorder

        def preview(records: list[tuple[str, str]]) -> str:
            """Render the first three records, noting how many were left out."""
            text = ", ".join(f"{pkg}.{name}" for pkg, name in records[:3])
            hidden = len(records) - 3
            if hidden > 0:
                text += f" (and {hidden} more)"
            return text

        # Migrations known from disk and those recorded in the database.
        loader = MigrationLoader(db_connection, ignore_no_migrations=True)
        recorder = MigrationRecorder(db_connection)
        recorded_migrations = recorder.applied_migrations()

        # Not expected to be None once the loader is initialised; the guard
        # exists to satisfy the type checker.
        if loader.disk_migrations is None:
            return []

        stale_records = [
            record
            for record in recorded_migrations
            if record not in loader.disk_migrations
        ]
        if not stale_records:
            return []

        # Split by whether the owning package still has migrations loaded.
        known_packages = set(loader.migrated_packages)
        from_existing: list[tuple[str, str]] = []
        from_orphaned: list[tuple[str, str]] = []
        for record in stale_records:
            package_label, _ = record
            bucket = from_existing if package_label in known_packages else from_orphaned
            bucket.append(record)

        stale_count = len(stale_records)
        plural = "" if stale_count == 1 else "s"
        sentences = [
            f"Found {stale_count} stale migration record{plural} in the database."
        ]
        if from_existing:
            sentences.append(f"From existing packages: {preview(from_existing)}.")
        if from_orphaned:
            sentences.append(f"From removed packages: {preview(from_orphaned)}.")
        sentences.append("Run 'plain migrations prune' to review and remove them.")

        return [
            PreflightResult(
                fix=" ".join(sentences),
                id="models.prunable_migrations",
                warning=True,
            )
        ]