Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from collections.abc import Callable
  4from contextlib import nullcontext
  5from typing import TYPE_CHECKING, Any
  6
  7from plain.models.connections import DatabaseConnection
  8
  9from ..transaction import atomic
 10from .loader import MigrationLoader
 11from .migration import Migration
 12from .recorder import MigrationRecorder
 13from .state import ProjectState
 14
 15if TYPE_CHECKING:
 16    from plain.models.backends.base.base import BaseDatabaseWrapper
 17
 18
 19class MigrationExecutor:
 20    """
 21    End-to-end migration execution - load migrations and run them up or down
 22    to a specified set of targets.
 23    """
 24
 25    def __init__(
 26        self,
 27        connection: BaseDatabaseWrapper | DatabaseConnection,
 28        progress_callback: Callable[..., Any] | None = None,
 29    ) -> None:
 30        self.connection = connection
 31        self.loader = MigrationLoader(self.connection)
 32        self.recorder = MigrationRecorder(self.connection)
 33        self.progress_callback = progress_callback
 34
 35    def migration_plan(
 36        self, targets: list[tuple[str, str]], clean_start: bool = False
 37    ) -> list[Migration]:
 38        """
 39        Given a set of targets, return a list of Migration instances.
 40        """
 41        plan = []
 42        if clean_start:
 43            applied = {}
 44        else:
 45            applied = dict(self.loader.applied_migrations)
 46        for target in targets:
 47            for migration in self.loader.graph.forwards_plan(target):
 48                if migration not in applied:
 49                    plan.append(self.loader.graph.nodes[migration])
 50                    applied[migration] = self.loader.graph.nodes[migration]
 51        return plan
 52
 53    def _create_project_state(
 54        self, with_applied_migrations: bool = False
 55    ) -> ProjectState:
 56        """
 57        Create a project state including all the applications without
 58        migrations and applied migrations if with_applied_migrations=True.
 59        """
 60        state = ProjectState(real_packages=self.loader.unmigrated_packages)
 61        if with_applied_migrations:
 62            # Create the forwards plan Plain would follow on an empty database
 63            full_plan = self.migration_plan(
 64                self.loader.graph.leaf_nodes(), clean_start=True
 65            )
 66            applied_migrations = {
 67                self.loader.graph.nodes[key]
 68                for key in self.loader.applied_migrations
 69                if key in self.loader.graph.nodes
 70            }
 71            for migration in full_plan:
 72                if migration in applied_migrations:
 73                    migration.mutate_state(state, preserve=False)
 74        return state
 75
 76    def migrate(
 77        self,
 78        targets: list[tuple[str, str]],
 79        plan: list[Migration] | None = None,
 80        state: ProjectState | None = None,
 81        fake: bool = False,
 82        atomic_batch: bool = False,
 83    ) -> ProjectState:
 84        """
 85        Migrate the database up to the given targets.
 86
 87        Plain first needs to create all project states before a migration is
 88        (un)applied and in a second step run all the database operations.
 89
 90        atomic_batch: Whether to run all migrations in a single transaction.
 91        """
 92        # The plain_migrations table must be present to record applied
 93        # migrations, but don't create it if there are no migrations to apply.
 94        if plan == []:
 95            if not self.recorder.has_table():
 96                return self._create_project_state(with_applied_migrations=False)
 97        else:
 98            self.recorder.ensure_schema()
 99
100        if plan is None:
101            plan = self.migration_plan(targets)
102        # Create the forwards plan Plain would follow on an empty database
103        full_plan = self.migration_plan(
104            self.loader.graph.leaf_nodes(), clean_start=True
105        )
106
107        if not plan:
108            if state is None:
109                # The resulting state should include applied migrations.
110                state = self._create_project_state(with_applied_migrations=True)
111        else:
112            if state is None:
113                # The resulting state should still include applied migrations.
114                state = self._create_project_state(with_applied_migrations=True)
115
116            migrations_to_run = set(plan)
117
118            # Choose context manager based on atomic_batch
119            batch_context = atomic if (atomic_batch and len(plan) > 1) else nullcontext
120
121            with batch_context():
122                for migration in full_plan:
123                    if not migrations_to_run:
124                        # We remove every migration that we applied from these sets so
125                        # that we can bail out once the last migration has been applied
126                        # and don't always run until the very end of the migration
127                        # process.
128                        break
129                    if migration in migrations_to_run:
130                        if "models_registry" not in state.__dict__:
131                            state.models_registry  # Render all -- performance critical
132                        state = self.apply_migration(state, migration, fake=fake)
133                        migrations_to_run.remove(migration)
134
135        self.check_replacements()
136
137        assert state is not None
138        return state
139
140    def apply_migration(
141        self, state: ProjectState, migration: Migration, fake: bool = False
142    ) -> ProjectState:
143        """Run a migration forwards."""
144        migration_recorded = False
145        if self.progress_callback:
146            self.progress_callback("apply_start", migration=migration, fake=fake)
147        if not fake:
148            # Alright, do it normally
149            with self.connection.schema_editor(
150                atomic=migration.atomic
151            ) as schema_editor:
152                state = migration.apply(
153                    state, schema_editor, operation_callback=self.progress_callback
154                )
155                if not schema_editor.deferred_sql:
156                    self.record_migration(migration)
157                    migration_recorded = True
158        if not migration_recorded:
159            self.record_migration(migration)
160        # Report progress
161        if self.progress_callback:
162            self.progress_callback("apply_success", migration=migration, fake=fake)
163        return state
164
165    def record_migration(self, migration: Migration) -> None:
166        # For replacement migrations, record individual statuses
167        if migration.replaces:
168            for package_label, name in migration.replaces:
169                self.recorder.record_applied(package_label, name)
170        else:
171            self.recorder.record_applied(migration.package_label, migration.name)
172
173    def check_replacements(self) -> None:
174        """
175        Mark replacement migrations applied if their replaced set all are.
176
177        Do this unconditionally on every migrate, rather than just when
178        migrations are applied or unapplied, to correctly handle the case
179        when a new squash migration is pushed to a deployment that already had
180        all its replaced migrations applied. In this case no new migration will
181        be applied, but the applied state of the squashed migration must be
182        maintained.
183        """
184        applied = self.recorder.applied_migrations()
185        for key, migration in self.loader.replacements.items():
186            all_applied = all(m in applied for m in migration.replaces)
187            if all_applied and key not in applied:
188                self.recorder.record_applied(*key)