1from __future__ import annotations
  2
  3import inspect
  4from collections import defaultdict
  5from collections.abc import Callable
  6from typing import Any
  7
  8from plain.packages import packages_registry
  9from plain.postgres.db import get_connection
 10from plain.postgres.migrations.recorder import MIGRATION_TABLE_NAME
 11from plain.postgres.registry import ModelsRegistry, models_registry
 12from plain.preflight import PreflightCheck, PreflightResult, register_check
 13
 14
 15@register_check("postgres.all_models")
 16class CheckAllModels(PreflightCheck):
 17    """Validates all model definitions for common issues."""
 18
 19    def run(self) -> list[PreflightResult]:
 20        db_table_models = defaultdict(list)
 21        indexes = defaultdict(list)
 22        constraints = defaultdict(list)
 23        errors = []
 24        models = models_registry.get_models()
 25        for model in models:
 26            db_table_models[model.model_options.db_table].append(
 27                model.model_options.label
 28            )
 29            if not inspect.ismethod(model.preflight):
 30                errors.append(
 31                    PreflightResult(
 32                        fix=f"The '{model.__name__}.preflight()' class method is currently overridden by {model.preflight!r}.",
 33                        obj=model,
 34                        id="postgres.preflight_method_overridden",
 35                    )
 36                )
 37            else:
 38                errors.extend(model.preflight())
 39            for model_index in model.model_options.indexes:
 40                indexes[model_index.name].append(model.model_options.label)
 41            for model_constraint in model.model_options.constraints:
 42                constraints[model_constraint.name].append(model.model_options.label)
 43        for db_table, model_labels in db_table_models.items():
 44            if len(model_labels) != 1:
 45                model_labels_str = ", ".join(model_labels)
 46                errors.append(
 47                    PreflightResult(
 48                        fix=f"db_table '{db_table}' is used by multiple models: {model_labels_str}.",
 49                        obj=db_table,
 50                        id="postgres.duplicate_db_table",
 51                    )
 52                )
 53        for index_name, model_labels in indexes.items():
 54            if len(model_labels) > 1:
 55                model_labels = set(model_labels)
 56                errors.append(
 57                    PreflightResult(
 58                        fix="index name '{}' is not unique {} {}.".format(
 59                            index_name,
 60                            "for model" if len(model_labels) == 1 else "among models:",
 61                            ", ".join(sorted(model_labels)),
 62                        ),
 63                        id="postgres.index_name_not_unique_single"
 64                        if len(model_labels) == 1
 65                        else "postgres.index_name_not_unique_multiple",
 66                    ),
 67                )
 68        for constraint_name, model_labels in constraints.items():
 69            if len(model_labels) > 1:
 70                model_labels = set(model_labels)
 71                errors.append(
 72                    PreflightResult(
 73                        fix="constraint name '{}' is not unique {} {}.".format(
 74                            constraint_name,
 75                            "for model" if len(model_labels) == 1 else "among models:",
 76                            ", ".join(sorted(model_labels)),
 77                        ),
 78                        id="postgres.constraint_name_not_unique_single"
 79                        if len(model_labels) == 1
 80                        else "postgres.constraint_name_not_unique_multiple",
 81                    ),
 82                )
 83        return errors
 84
 85
 86def _check_lazy_references(
 87    models_registry: ModelsRegistry, packages_registry: Any
 88) -> list[PreflightResult]:
 89    """
 90    Ensure all lazy (i.e. string) model references have been resolved.
 91
 92    Lazy references are used in various places throughout Plain, primarily in
 93    related fields and model signals. Identify those common cases and provide
 94    more helpful error messages for them.
 95    """
 96    pending_models = set(models_registry._pending_operations)
 97
 98    # Short circuit if there aren't any errors.
 99    if not pending_models:
100        return []
101
102    def extract_operation(
103        obj: Any,
104    ) -> tuple[Callable[..., Any], list[Any], dict[str, Any]]:
105        """
106        Take a callable found in Packages._pending_operations and identify the
107        original callable passed to Packages.lazy_model_operation(). If that
108        callable was a partial, return the inner, non-partial function and
109        any arguments and keyword arguments that were supplied with it.
110
111        obj is a callback defined locally in Packages.lazy_model_operation() and
112        annotated there with a `func` attribute so as to imitate a partial.
113        """
114        operation, args, keywords = obj, [], {}
115        while hasattr(operation, "func"):
116            args.extend(getattr(operation, "args", []))
117            keywords.update(getattr(operation, "keywords", {}))
118            operation = operation.func
119        return operation, args, keywords
120
121    def app_model_error(model_key: tuple[str, str]) -> str:
122        try:
123            packages_registry.get_package_config(model_key[0])
124            model_error = "app '{}' doesn't provide model '{}'".format(*model_key)
125        except LookupError:
126            model_error = f"app '{model_key[0]}' isn't installed"
127        return model_error
128
129    # Here are several functions which return CheckMessage instances for the
130    # most common usages of lazy operations throughout Plain. These functions
131    # take the model that was being waited on as an (package_label, modelname)
132    # pair, the original lazy function, and its positional and keyword args as
133    # determined by extract_operation().
134
135    def field_error(
136        model_key: tuple[str, str],
137        func: Callable[..., Any],
138        args: list[Any],
139        keywords: dict[str, Any],
140    ) -> PreflightResult:
141        error_msg = (
142            "The field %(field)s was declared with a lazy reference "
143            "to '%(model)s', but %(model_error)s."
144        )
145        params = {
146            "model": ".".join(model_key),
147            "field": keywords["field"],
148            "model_error": app_model_error(model_key),
149        }
150        return PreflightResult(
151            fix=error_msg % params,
152            obj=keywords["field"],
153            id="fields.lazy_reference_not_resolvable",
154        )
155
156    def default_error(
157        model_key: tuple[str, str],
158        func: Callable[..., Any],
159        args: list[Any],
160        keywords: dict[str, Any],
161    ) -> PreflightResult:
162        error_msg = (
163            "%(op)s contains a lazy reference to %(model)s, but %(model_error)s."
164        )
165        params = {
166            "op": func,
167            "model": ".".join(model_key),
168            "model_error": app_model_error(model_key),
169        }
170        return PreflightResult(
171            fix=error_msg % params,
172            obj=func,
173            id="postgres.lazy_reference_resolution_failed",
174        )
175
176    # Maps common uses of lazy operations to corresponding error functions
177    # defined above. If a key maps to None, no error will be produced.
178    # default_error() will be used for usages that don't appear in this dict.
179    known_lazy = {
180        ("plain.postgres.fields.related", "resolve_related_class"): field_error,
181    }
182
183    def build_error(
184        model_key: tuple[str, str],
185        func: Callable[..., Any],
186        args: list[Any],
187        keywords: dict[str, Any],
188    ) -> PreflightResult | None:
189        key = (func.__module__, func.__name__)  # type: ignore[attr-defined]
190        error_fn = known_lazy.get(key, default_error)
191        return error_fn(model_key, func, args, keywords) if error_fn else None
192
193    return sorted(
194        filter(
195            None,
196            (
197                build_error(model_key, *extract_operation(func))
198                for model_key in pending_models
199                for func in models_registry._pending_operations[model_key]
200            ),
201        ),
202        key=lambda error: error.fix,
203    )
204
205
@register_check("postgres.lazy_references")
class CheckLazyReferences(PreflightCheck):
    """Flags lazy (string) model references that never resolved to a model."""

    def run(self) -> list[PreflightResult]:
        # All the work lives in the module-level helper; this check only
        # supplies the process-wide model and package registries.
        unresolved = _check_lazy_references(models_registry, packages_registry)
        return unresolved
212
213
@register_check("postgres.postgres_version")
class CheckPostgresVersion(PreflightCheck):
    """Checks that the PostgreSQL server meets the minimum version requirement."""

    # Oldest PostgreSQL major version accepted.
    MINIMUM_VERSION = 16

    @staticmethod
    def _format_version(version_num: int) -> str:
        """Render a numeric server version as a human-readable string.

        Assumes the libpq / ``server_version_num`` encoding (which the
        ``// 10000`` major extraction below already relies on):
        from PostgreSQL 10 on it is ``major * 10000 + minor`` (160002 -> "16.2").
        Before 10 it is ``major * 10000 + minor * 100 + patch`` with a two-part
        major version (90624 -> "9.6.24").
        """
        major, remainder = divmod(version_num, 10000)
        if major >= 10:
            return f"{major}.{remainder}"
        minor, patch = divmod(remainder, 100)
        return f"{major}.{minor}.{patch}"

    def run(self) -> list[PreflightResult]:
        conn = get_connection()
        version_num = conn.pg_version
        major = divmod(version_num, 10000)[0]
        if major < self.MINIMUM_VERSION:
            return [
                PreflightResult(
                    fix=f"PostgreSQL {self.MINIMUM_VERSION} or later is required (found {self._format_version(version_num)}).",
                    id="postgres.postgres_version_too_old",
                )
            ]
        return []
231
232
@register_check("postgres.database_tables")
class CheckDatabaseTables(PreflightCheck):
    """Warns about database tables that no installed Plain model accounts for.

    The migration bookkeeping table is always treated as known.
    """

    def run(self) -> list[PreflightResult]:
        conn = get_connection()
        present = set(conn.table_names())
        expected = set(conn.plain_table_names()) | {MIGRATION_TABLE_NAME}
        unknown_tables = present - expected

        if unknown_tables:
            listing = ", ".join(sorted(unknown_tables))
            warning_text = (
                f"Unknown tables in default database: {listing}. "
                "Tables may be from packages/models that have been uninstalled. "
                "Make sure you have a backup, then run `plain db drop-unknown-tables` to remove them."
            )
            return [
                PreflightResult(
                    fix=warning_text,
                    id="postgres.unknown_database_tables",
                    warning=True,
                )
            ]
        return []
258
259
@register_check("postgres.prunable_migrations")
class CheckPrunableMigrations(PreflightCheck):
    """Warns about stale migration records in the database.

    A record is stale ("prunable") when the database records it as applied
    but the migration loader found no migration with that (package, name)
    on disk.
    """

    # How many "package.name" entries to list per group before collapsing
    # the remainder into "(and N more)".
    PREVIEW_LIMIT = 3

    @classmethod
    def _preview(cls, migrations: list[tuple[str, str]]) -> str:
        """Join up to PREVIEW_LIMIT migrations as "package.name", comma-separated,
        appending " (and N more)" when the list was truncated."""
        shown = ", ".join(
            f"{pkg}.{name}" for pkg, name in migrations[: cls.PREVIEW_LIMIT]
        )
        hidden = len(migrations) - cls.PREVIEW_LIMIT
        if hidden > 0:
            shown += f" (and {hidden} more)"
        return shown

    def run(self) -> list[PreflightResult]:
        # Import here to avoid circular import issues
        from plain.postgres.migrations.loader import MigrationLoader
        from plain.postgres.migrations.recorder import MigrationRecorder

        errors: list[PreflightResult] = []

        # Load migrations from disk and database
        conn = get_connection()
        loader = MigrationLoader(conn, ignore_no_migrations=True)
        recorder = MigrationRecorder(conn)
        recorded_migrations = recorder.applied_migrations()

        # disk_migrations should not be None after MigrationLoader
        # initialization, but check to satisfy the type checker.
        if loader.disk_migrations is None:
            return errors

        # Prunable = recorded as applied but not present on disk.
        all_prunable = [
            migration
            for migration in recorded_migrations
            if migration not in loader.disk_migrations
        ]
        if not all_prunable:
            return errors

        # Split by whether the owning package is still among the loader's
        # migrated packages ("existing") or not ("removed" in the message).
        existing_packages = set(loader.migrated_packages)
        prunable_existing: list[tuple[str, str]] = []
        prunable_orphaned: list[tuple[str, str]] = []
        for migration in all_prunable:
            package, _name = migration
            if package in existing_packages:
                prunable_existing.append(migration)
            else:
                prunable_orphaned.append(migration)

        # Build the warning message
        total_count = len(all_prunable)
        message_parts = [
            f"Found {total_count} stale migration record{'s' if total_count != 1 else ''} in the database."
        ]
        if prunable_existing:
            message_parts.append(
                f"From existing packages: {self._preview(prunable_existing)}."
            )
        if prunable_orphaned:
            message_parts.append(
                f"From removed packages: {self._preview(prunable_orphaned)}."
            )
        message_parts.append("Run 'plain migrations prune' to review and remove them.")

        errors.append(
            PreflightResult(
                fix=" ".join(message_parts),
                id="postgres.prunable_migrations",
                warning=True,
            )
        )
        return errors