1from __future__ import annotations
  2
  3from collections.abc import Callable
  4from contextlib import nullcontext
  5from typing import TYPE_CHECKING, Any
  6
  7from ..transaction import atomic
  8from .loader import MigrationLoader
  9from .migration import Migration
 10from .recorder import MigrationRecorder
 11from .state import ProjectState
 12
 13if TYPE_CHECKING:
 14    from plain.postgres.connection import DatabaseConnection
 15
 16
 17class MigrationExecutor:
 18    """
 19    End-to-end migration execution - load migrations and run them up or down
 20    to a specified set of targets.
 21    """
 22
 23    def __init__(
 24        self,
 25        connection: DatabaseConnection,
 26        progress_callback: Callable[..., Any] | None = None,
 27    ) -> None:
 28        self.connection = connection
 29        self.loader = MigrationLoader(self.connection)
 30        self.recorder = MigrationRecorder(self.connection)
 31        self.progress_callback = progress_callback
 32
 33    def migration_plan(
 34        self, targets: list[tuple[str, str]], clean_start: bool = False
 35    ) -> list[Migration]:
 36        """
 37        Given a set of targets, return a list of Migration instances.
 38        """
 39        plan = []
 40        if clean_start:
 41            applied = {}
 42        else:
 43            applied_source = self.loader.applied_migrations or {}
 44            applied = dict(applied_source)
 45        for target in targets:
 46            for migration in self.loader.graph.forwards_plan(target):
 47                if migration not in applied:
 48                    plan.append(self.loader.graph.nodes[migration])
 49                    applied[migration] = self.loader.graph.nodes[migration]
 50        return plan
 51
 52    def _create_project_state(
 53        self, with_applied_migrations: bool = False
 54    ) -> ProjectState:
 55        """
 56        Create a project state including all the applications without
 57        migrations and applied migrations if with_applied_migrations=True.
 58        """
 59        state = ProjectState(real_packages=self.loader.unmigrated_packages)
 60        if with_applied_migrations:
 61            # Create the forwards plan Plain would follow on an empty database
 62            full_plan = self.migration_plan(
 63                self.loader.graph.leaf_nodes(), clean_start=True
 64            )
 65            applied_source = self.loader.applied_migrations or {}
 66            applied_migrations = {
 67                self.loader.graph.nodes[key]
 68                for key in applied_source
 69                if key in self.loader.graph.nodes
 70            }
 71            for migration in full_plan:
 72                if migration in applied_migrations:
 73                    migration.mutate_state(state, preserve=False)
 74        return state
 75
 76    def migrate(
 77        self,
 78        targets: list[tuple[str, str]],
 79        plan: list[Migration] | None = None,
 80        state: ProjectState | None = None,
 81        fake: bool = False,
 82        atomic_batch: bool = False,
 83    ) -> ProjectState:
 84        """
 85        Migrate the database up to the given targets.
 86
 87        Plain first needs to create all project states before a migration is
 88        (un)applied and in a second step run all the database operations.
 89
 90        atomic_batch: Whether to run all migrations in a single transaction.
 91        """
 92        # The plain_migrations table must be present to record applied
 93        # migrations, but don't create it if there are no migrations to apply.
 94        if plan == []:
 95            if not self.recorder.has_table():
 96                return self._create_project_state(with_applied_migrations=False)
 97        else:
 98            self.recorder.ensure_schema()
 99
100        if plan is None:
101            plan = self.migration_plan(targets)
102        # Create the forwards plan Plain would follow on an empty database
103        full_plan = self.migration_plan(
104            self.loader.graph.leaf_nodes(), clean_start=True
105        )
106
107        if not plan:
108            if state is None:
109                # The resulting state should include applied migrations.
110                state = self._create_project_state(with_applied_migrations=True)
111        else:
112            if state is None:
113                # The resulting state should still include applied migrations.
114                state = self._create_project_state(with_applied_migrations=True)
115
116            migrations_to_run = set(plan)
117
118            # Choose context manager based on atomic_batch
119            batch_context = atomic if (atomic_batch and len(plan) > 1) else nullcontext
120
121            with batch_context():
122                for migration in full_plan:
123                    if not migrations_to_run:
124                        # We remove every migration that we applied from these sets so
125                        # that we can bail out once the last migration has been applied
126                        # and don't always run until the very end of the migration
127                        # process.
128                        break
129                    if migration in migrations_to_run:
130                        if "models_registry" not in state.__dict__:
131                            state.models_registry  # Render all -- performance critical
132                        state = self.apply_migration(state, migration, fake=fake)
133                        migrations_to_run.remove(migration)
134
135        self.check_replacements()
136
137        assert state is not None
138        return state
139
140    def apply_migration(
141        self, state: ProjectState, migration: Migration, fake: bool = False
142    ) -> ProjectState:
143        """Run a migration forwards."""
144        migration_recorded = False
145        if self.progress_callback:
146            self.progress_callback("apply_start", migration=migration, fake=fake)
147        if not fake:
148            # Alright, do it normally
149            with self.connection.schema_editor(
150                atomic=migration.atomic
151            ) as schema_editor:
152                state = migration.apply(
153                    state, schema_editor, operation_callback=self.progress_callback
154                )
155                if not schema_editor.deferred_sql:
156                    self.record_migration(migration)
157                    migration_recorded = True
158        if not migration_recorded:
159            self.record_migration(migration)
160        # Report progress
161        if self.progress_callback:
162            self.progress_callback("apply_success", migration=migration, fake=fake)
163        return state
164
165    def record_migration(self, migration: Migration) -> None:
166        # For replacement migrations, record individual statuses
167        if migration.replaces:
168            for package_label, name in migration.replaces:
169                self.recorder.record_applied(package_label, name)
170        else:
171            self.recorder.record_applied(migration.package_label, migration.name)
172
173    def check_replacements(self) -> None:
174        """
175        Mark replacement migrations applied if their replaced set all are.
176
177        Do this unconditionally on every migrate, rather than just when
178        migrations are applied or unapplied, to correctly handle the case
179        when a new squash migration is pushed to a deployment that already had
180        all its replaced migrations applied. In this case no new migration will
181        be applied, but the applied state of the squashed migration must be
182        maintained.
183        """
184        applied = self.recorder.applied_migrations()
185        for key, migration in self.loader.replacements.items():
186            all_applied = all(m in applied for m in migration.replaces)
187            if all_applied and key not in applied:
188                self.recorder.record_applied(*key)