Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from collections.abc import Callable
  4from contextlib import nullcontext
  5from typing import TYPE_CHECKING, Any
  6
  7from plain.models.connections import DatabaseConnection
  8
  9from ..transaction import atomic
 10from .loader import MigrationLoader
 11from .migration import Migration
 12from .recorder import MigrationRecorder
 13from .state import ProjectState
 14
 15if TYPE_CHECKING:
 16    from plain.models.backends.base.base import BaseDatabaseWrapper
 17
 18
 19class MigrationExecutor:
 20    """
 21    End-to-end migration execution - load migrations and run them up or down
 22    to a specified set of targets.
 23    """
 24
 25    def __init__(
 26        self,
 27        connection: BaseDatabaseWrapper | DatabaseConnection,
 28        progress_callback: Callable[..., Any] | None = None,
 29    ) -> None:
 30        self.connection = connection
 31        self.loader = MigrationLoader(self.connection)
 32        self.recorder = MigrationRecorder(self.connection)
 33        self.progress_callback = progress_callback
 34
 35    def migration_plan(
 36        self, targets: list[tuple[str, str]], clean_start: bool = False
 37    ) -> list[Migration]:
 38        """
 39        Given a set of targets, return a list of Migration instances.
 40        """
 41        plan = []
 42        if clean_start:
 43            applied = {}
 44        else:
 45            applied_source = self.loader.applied_migrations or {}
 46            applied = dict(applied_source)
 47        for target in targets:
 48            for migration in self.loader.graph.forwards_plan(target):
 49                if migration not in applied:
 50                    plan.append(self.loader.graph.nodes[migration])
 51                    applied[migration] = self.loader.graph.nodes[migration]
 52        return plan
 53
 54    def _create_project_state(
 55        self, with_applied_migrations: bool = False
 56    ) -> ProjectState:
 57        """
 58        Create a project state including all the applications without
 59        migrations and applied migrations if with_applied_migrations=True.
 60        """
 61        state = ProjectState(real_packages=self.loader.unmigrated_packages)
 62        if with_applied_migrations:
 63            # Create the forwards plan Plain would follow on an empty database
 64            full_plan = self.migration_plan(
 65                self.loader.graph.leaf_nodes(), clean_start=True
 66            )
 67            applied_source = self.loader.applied_migrations or {}
 68            applied_migrations = {
 69                self.loader.graph.nodes[key]
 70                for key in applied_source
 71                if key in self.loader.graph.nodes
 72            }
 73            for migration in full_plan:
 74                if migration in applied_migrations:
 75                    migration.mutate_state(state, preserve=False)
 76        return state
 77
 78    def migrate(
 79        self,
 80        targets: list[tuple[str, str]],
 81        plan: list[Migration] | None = None,
 82        state: ProjectState | None = None,
 83        fake: bool = False,
 84        atomic_batch: bool = False,
 85    ) -> ProjectState:
 86        """
 87        Migrate the database up to the given targets.
 88
 89        Plain first needs to create all project states before a migration is
 90        (un)applied and in a second step run all the database operations.
 91
 92        atomic_batch: Whether to run all migrations in a single transaction.
 93        """
 94        # The plain_migrations table must be present to record applied
 95        # migrations, but don't create it if there are no migrations to apply.
 96        if plan == []:
 97            if not self.recorder.has_table():
 98                return self._create_project_state(with_applied_migrations=False)
 99        else:
100            self.recorder.ensure_schema()
101
102        if plan is None:
103            plan = self.migration_plan(targets)
104        # Create the forwards plan Plain would follow on an empty database
105        full_plan = self.migration_plan(
106            self.loader.graph.leaf_nodes(), clean_start=True
107        )
108
109        if not plan:
110            if state is None:
111                # The resulting state should include applied migrations.
112                state = self._create_project_state(with_applied_migrations=True)
113        else:
114            if state is None:
115                # The resulting state should still include applied migrations.
116                state = self._create_project_state(with_applied_migrations=True)
117
118            migrations_to_run = set(plan)
119
120            # Choose context manager based on atomic_batch
121            batch_context = atomic if (atomic_batch and len(plan) > 1) else nullcontext
122
123            with batch_context():
124                for migration in full_plan:
125                    if not migrations_to_run:
126                        # We remove every migration that we applied from these sets so
127                        # that we can bail out once the last migration has been applied
128                        # and don't always run until the very end of the migration
129                        # process.
130                        break
131                    if migration in migrations_to_run:
132                        if "models_registry" not in state.__dict__:
133                            state.models_registry  # Render all -- performance critical
134                        state = self.apply_migration(state, migration, fake=fake)
135                        migrations_to_run.remove(migration)
136
137        self.check_replacements()
138
139        assert state is not None
140        return state
141
142    def apply_migration(
143        self, state: ProjectState, migration: Migration, fake: bool = False
144    ) -> ProjectState:
145        """Run a migration forwards."""
146        migration_recorded = False
147        if self.progress_callback:
148            self.progress_callback("apply_start", migration=migration, fake=fake)
149        if not fake:
150            # Alright, do it normally
151            with self.connection.schema_editor(
152                atomic=migration.atomic
153            ) as schema_editor:
154                state = migration.apply(
155                    state, schema_editor, operation_callback=self.progress_callback
156                )
157                if not schema_editor.deferred_sql:
158                    self.record_migration(migration)
159                    migration_recorded = True
160        if not migration_recorded:
161            self.record_migration(migration)
162        # Report progress
163        if self.progress_callback:
164            self.progress_callback("apply_success", migration=migration, fake=fake)
165        return state
166
167    def record_migration(self, migration: Migration) -> None:
168        # For replacement migrations, record individual statuses
169        if migration.replaces:
170            for package_label, name in migration.replaces:
171                self.recorder.record_applied(package_label, name)
172        else:
173            self.recorder.record_applied(migration.package_label, migration.name)
174
175    def check_replacements(self) -> None:
176        """
177        Mark replacement migrations applied if their replaced set all are.
178
179        Do this unconditionally on every migrate, rather than just when
180        migrations are applied or unapplied, to correctly handle the case
181        when a new squash migration is pushed to a deployment that already had
182        all its replaced migrations applied. In this case no new migration will
183        be applied, but the applied state of the squashed migration must be
184        maintained.
185        """
186        applied = self.recorder.applied_migrations()
187        for key, migration in self.loader.replacements.items():
188            all_applied = all(m in applied for m in migration.replaces)
189            if all_applied and key not in applied:
190                self.recorder.record_applied(*key)