1from __future__ import annotations
  2
  3import pkgutil
  4import sys
  5from importlib import import_module, reload
  6from typing import TYPE_CHECKING, Any
  7
  8from plain.models.migrations.graph import MigrationGraph
  9from plain.models.migrations.recorder import MigrationRecorder
 10from plain.packages import packages_registry
 11
 12from .exceptions import (
 13    AmbiguityError,
 14    BadMigrationError,
 15    InconsistentMigrationHistory,
 16    NodeNotFoundError,
 17)
 18
 19if TYPE_CHECKING:
 20    from plain.models.migrations.migration import Migration
 21    from plain.models.postgres.connection import DatabaseConnection
 22
 23MIGRATIONS_MODULE_NAME = "migrations"
 24
 25
 26class MigrationLoader:
 27    """
 28    Load migration files from disk and their status from the database.
 29
 30    Migration files are expected to live in the "migrations" directory of
 31    an app. Their names are entirely unimportant from a code perspective,
 32    but will probably follow the 1234_name.py convention.
 33
 34    On initialization, this class will scan those directories, and open and
 35    read the Python files, looking for a class called Migration, which should
 36    inherit from plain.models.migrations.Migration. See
 37    plain.models.migrations.migration for what that looks like.
 38
 39    Some migrations will be marked as "replacing" another set of migrations.
 40    These are loaded into a separate set of migrations away from the main ones.
 41    If all the migrations they replace are either unapplied or missing from
 42    disk, then they are injected into the main set, replacing the named migrations.
 43    Any dependency pointers to the replaced migrations are re-pointed to the
 44    new migration.
 45
 46    This does mean that this class MUST also talk to the database as well as
 47    to disk, but this is probably fine. We're already not just operating
 48    in memory.
 49    """
 50
 51    def __init__(
 52        self,
 53        connection: DatabaseConnection | None,
 54        load: bool = True,
 55        ignore_no_migrations: bool = False,
 56        replace_migrations: bool = True,
 57    ):
 58        self.connection = connection
 59        self.disk_migrations: dict[tuple[str, str], Migration] | None = None
 60        self.applied_migrations: dict[tuple[str, str], Any] | None = None
 61        self.ignore_no_migrations = ignore_no_migrations
 62        self.replace_migrations = replace_migrations
 63        self.unmigrated_packages: set[str]
 64        self.migrated_packages: set[str]
 65        self.graph: MigrationGraph
 66        self.replacements: dict[tuple[str, str], Migration]
 67        if load:
 68            self.build_graph()
 69
 70    @classmethod
 71    def migrations_module(cls, package_label: str) -> tuple[str | None, bool]:
 72        """
 73        Return the path to the migrations module for the specified package_label
 74        and a boolean indicating if the module is specified in
 75        settings.MIGRATION_MODULE.
 76        """
 77
 78        # This package (plain-models) has different code under migrations/
 79        if package_label == "plainmodels":
 80            return None, True
 81
 82        app = packages_registry.get_package_config(package_label)
 83        return f"{app.name}.{MIGRATIONS_MODULE_NAME}", False
 84
 85    def load_disk(self) -> None:
 86        """Load the migrations from all INSTALLED_PACKAGES from disk."""
 87        self.disk_migrations = {}
 88        self.unmigrated_packages = set()
 89        self.migrated_packages = set()
 90        for package_config in packages_registry.get_package_configs():
 91            # Get the migrations module directory
 92            module_name, explicit = self.migrations_module(package_config.package_label)
 93            if module_name is None:
 94                self.unmigrated_packages.add(package_config.package_label)
 95                continue
 96            was_loaded = module_name in sys.modules
 97            try:
 98                module = import_module(module_name)
 99            except ModuleNotFoundError as e:
100                if (explicit and self.ignore_no_migrations) or (
101                    not explicit
102                    and e.name is not None
103                    and MIGRATIONS_MODULE_NAME in e.name.split(".")
104                ):
105                    self.unmigrated_packages.add(package_config.package_label)
106                    continue
107                raise
108            else:
109                # Module is not a package (e.g. migrations.py).
110                if not hasattr(module, "__path__"):
111                    self.unmigrated_packages.add(package_config.package_label)
112                    continue
113                # Empty directories are namespaces. Namespace packages have no
114                # __file__ and don't use a list for __path__. See
115                # https://docs.python.org/3/reference/import.html#namespace-packages
116                if getattr(module, "__file__", None) is None and not isinstance(
117                    module.__path__, list
118                ):
119                    self.unmigrated_packages.add(package_config.package_label)
120                    continue
121                # Force a reload if it's already loaded (tests need this)
122                if was_loaded:
123                    reload(module)
124            self.migrated_packages.add(package_config.package_label)
125            migration_names = {
126                name
127                for _, name, is_pkg in pkgutil.iter_modules(module.__path__)
128                if not is_pkg and name[0] not in "_~"
129            }
130            # Load migrations
131            for migration_name in migration_names:
132                migration_path = f"{module_name}.{migration_name}"
133                try:
134                    migration_module = import_module(migration_path)
135                except ImportError as e:
136                    if "bad magic number" in str(e):
137                        raise ImportError(
138                            f"Couldn't import {migration_path!r} as it appears to be a stale "
139                            ".pyc file."
140                        ) from e
141                    else:
142                        raise
143                if not hasattr(migration_module, "Migration"):
144                    raise BadMigrationError(
145                        f"Migration {migration_name} in app {package_config.package_label} has no Migration class"
146                    )
147                self.disk_migrations[package_config.package_label, migration_name] = (
148                    migration_module.Migration(
149                        migration_name,
150                        package_config.package_label,
151                    )
152                )
153
154    def get_migration(self, package_label: str, name_prefix: str) -> Migration | None:
155        """Return the named migration or raise NodeNotFoundError."""
156        return self.graph.nodes[package_label, name_prefix]
157
158    def get_migration_by_prefix(
159        self, package_label: str, name_prefix: str
160    ) -> Migration:
161        """
162        Return the migration(s) which match the given app label and name_prefix.
163        """
164        # Do the search
165        assert self.disk_migrations is not None, "load_disk() must be called first"
166        results = []
167        for migration_package_label, migration_name in self.disk_migrations:
168            if migration_package_label == package_label and migration_name.startswith(
169                name_prefix
170            ):
171                results.append((migration_package_label, migration_name))
172        if len(results) > 1:
173            raise AmbiguityError(
174                f"There is more than one migration for '{package_label}' with the prefix '{name_prefix}'"
175            )
176        elif not results:
177            raise KeyError(
178                f"There is no migration for '{package_label}' with the prefix "
179                f"'{name_prefix}'"
180            )
181        else:
182            return self.disk_migrations[results[0]]
183
184    def check_key(
185        self, key: tuple[str, str], current_package: str
186    ) -> tuple[str, str] | None:
187        if (key[1] != "__first__" and key[1] != "__latest__") or key in self.graph:
188            return key
189        # Special-case __first__, which means "the first migration" for
190        # migrated packages, and is ignored for unmigrated packages. It allows
191        # makemigrations to declare dependencies on packages before they even have
192        # migrations.
193        if key[0] == current_package:
194            # Ignore __first__ references to the same app (#22325)
195            return None
196        if key[0] in self.unmigrated_packages:
197            # This app isn't migrated, but something depends on it.
198            # The models will get auto-added into the state, though
199            # so we're fine.
200            return None
201        if key[0] in self.migrated_packages:
202            try:
203                if key[1] == "__first__":
204                    return self.graph.root_nodes(key[0])[0]
205                else:  # "__latest__"
206                    return self.graph.leaf_nodes(key[0])[0]
207            except IndexError:
208                if self.ignore_no_migrations:
209                    return None
210                else:
211                    raise ValueError(f"Dependency on app with no migrations: {key[0]}")
212        raise ValueError(f"Dependency on unknown app: {key[0]}")
213
214    def add_internal_dependencies(
215        self, key: tuple[str, str], migration: Migration
216    ) -> None:
217        """
218        Internal dependencies need to be added first to ensure `__first__`
219        dependencies find the correct root node.
220        """
221        for parent in migration.dependencies:
222            # Ignore __first__ references to the same app.
223            if parent[0] == key[0] and parent[1] != "__first__":
224                # Migration object is used only for error messages in add_dependency
225                self.graph.add_dependency(migration, key, parent, skip_validation=True)
226
227    def add_external_dependencies(
228        self, key: tuple[str, str], migration: Migration
229    ) -> None:
230        for parent in migration.dependencies:
231            # Skip internal dependencies
232            if key[0] == parent[0]:
233                continue
234            parent = self.check_key(parent, key[0])
235            if parent is not None:
236                # Migration object is used only for error messages in add_dependency
237                self.graph.add_dependency(migration, key, parent, skip_validation=True)
238
    def build_graph(self) -> None:
        """
        Build a migration dependency graph using both the disk and database.
        You'll need to rebuild the graph if you apply migrations. This isn't
        usually a problem as generally migration stuff runs in a one-shot process.

        Steps, in order:

        1. load_disk() fills ``disk_migrations``.
        2. ``applied_migrations`` is read through MigrationRecorder, or set to
           an empty dict when ``self.connection`` is None.
        3. Every disk migration becomes a graph node; migrations with a
           truthy ``replaces`` are also recorded in ``self.replacements``.
        4. Same-package dependencies are added for all migrations, then
           cross-package dependencies.
        5. If ``replace_migrations`` is set, each replacing migration is
           either swapped in for its targets or removed from the graph,
           depending on how many of its targets are applied.
        6. The graph is validated for consistency and then for cycles.

        Raises:
            NodeNotFoundError: propagated from validate_consistency(); when
                the missing node is a replaced migration whose replacing
                migration(s) are no longer in the graph, a more descriptive
                NodeNotFoundError is raised instead (chained to the original).
        """
        # Load disk data
        self.load_disk()
        assert self.disk_migrations is not None  # load_disk() ensures this
        # Load database data
        if self.connection is None:
            # No database available: treat everything as unapplied.
            self.applied_migrations = {}
        else:
            recorder = MigrationRecorder(self.connection)
            self.applied_migrations = recorder.applied_migrations()
        # To start, populate the migration graph with nodes for ALL migrations
        # and their dependencies. Also make note of replacing migrations at this step.
        self.graph = MigrationGraph()
        self.replacements = {}
        for key, migration in self.disk_migrations.items():
            self.graph.add_node(key, migration)
            # Replacing migrations.
            if migration.replaces:
                self.replacements[key] = migration
        # All nodes must exist before any dependency is wired, hence the
        # separate passes over disk_migrations.
        for key, migration in self.disk_migrations.items():
            # Internal (same app) dependencies.
            self.add_internal_dependencies(key, migration)
        # Add external dependencies now that the internal ones have been resolved.
        for key, migration in self.disk_migrations.items():
            self.add_external_dependencies(key, migration)
        # Carry out replacements where possible and if enabled.
        if self.replace_migrations:
            for key, migration in self.replacements.items():
                # Get applied status of each of this migration's replacement
                # targets.
                applied_statuses = [
                    (target in self.applied_migrations) for target in migration.replaces
                ]
                # The replacing migration is only marked as applied if all of
                # its replacement targets are.
                if all(applied_statuses):
                    self.applied_migrations[key] = migration
                else:
                    # Drop any existing applied record for the replacing
                    # migration; pop() with a default tolerates its absence.
                    self.applied_migrations.pop(key, None)
                # A replacing migration can be used if either all or none of
                # its replacement targets have been applied.
                if all(applied_statuses) or (not any(applied_statuses)):
                    self.graph.remove_replaced_nodes(key, migration.replaces)
                else:
                    # This replacing migration cannot be used because it is
                    # partially applied. Remove it from the graph and remap
                    # dependencies to it (#25945).
                    self.graph.remove_replacement_node(key, migration.replaces)
        # Ensure the graph is consistent.
        try:
            self.graph.validate_consistency()
        except NodeNotFoundError as exc:
            # Check if the missing node could have been replaced by any squash
            # migration but wasn't because the squash migration was partially
            # applied before. In that case raise a more understandable exception
            # (#23556).
            # Get reverse replacements: replaced key -> set of replacing keys.
            reverse_replacements = {}
            for key, migration in self.replacements.items():
                for replaced in migration.replaces:
                    reverse_replacements.setdefault(replaced, set()).add(key)
            # Try to reraise exception with more detail.
            if exc.node in reverse_replacements:
                candidates = reverse_replacements.get(exc.node, set())
                # True if at least one replacing migration is still a graph node.
                is_replaced = any(
                    candidate in self.graph.nodes for candidate in candidates
                )
                if not is_replaced:
                    tries = ", ".join("{}.{}".format(*c) for c in candidates)
                    raise NodeNotFoundError(
                        f"Migration {exc.origin} depends on nonexistent node ('{exc.node[0]}', '{exc.node[1]}'). "
                        f"Plain tried to replace migration {exc.node[0]}.{exc.node[1]} with any of [{tries}] "
                        "but wasn't able to because some of the replaced migrations "
                        "are already applied.",
                        exc.node,
                    ) from exc
            # Not explainable by a partially applied squash: re-raise as is.
            raise
        self.graph.ensure_not_cyclic()
322
323    def check_consistent_history(self, connection: DatabaseConnection) -> None:
324        """
325        Raise InconsistentMigrationHistory if any applied migrations have
326        unapplied dependencies.
327        """
328        recorder = MigrationRecorder(connection)
329        applied = recorder.applied_migrations()
330        for migration in applied:
331            # If the migration is unknown, skip it.
332            if migration not in self.graph.nodes:
333                continue
334            for parent in self.graph.node_map[migration].parents:
335                if parent not in applied:
336                    # Skip unapplied squashed migrations that have all of their
337                    # `replaces` applied.
338                    # Use parent.key for dict lookup (Node.__eq__ allows `in` check)
339                    if parent.key in self.replacements:
340                        if all(
341                            m in applied for m in self.replacements[parent.key].replaces
342                        ):
343                            continue
344                    raise InconsistentMigrationHistory(
345                        f"Migration {migration[0]}.{migration[1]} is applied before its dependency "
346                        f"{parent[0]}.{parent[1]} on the database."
347                    )
348
349    def detect_conflicts(self) -> dict[str, list[str]]:
350        """
351        Look through the loaded graph and detect any conflicts - packages
352        with more than one leaf migration. Return a dict of the app labels
353        that conflict with the migration names that conflict.
354        """
355        seen_packages = {}
356        conflicting_packages = set()
357        for package_label, migration_name in self.graph.leaf_nodes():
358            if package_label in seen_packages:
359                conflicting_packages.add(package_label)
360            seen_packages.setdefault(package_label, set()).add(migration_name)
361        return {
362            package_label: sorted(seen_packages[package_label])
363            for package_label in conflicting_packages
364        }
365
366    def project_state(
367        self, nodes: tuple[str, str] | None = None, at_end: bool = True
368    ) -> Any:
369        """
370        Return a ProjectState object representing the most recent state
371        that the loaded migrations represent.
372
373        See graph.make_state() for the meaning of "nodes" and "at_end".
374        """
375        return self.graph.make_state(
376            nodes=nodes, at_end=at_end, real_packages=self.unmigrated_packages
377        )