1from __future__ import annotations
2
3from collections.abc import Callable
4from contextlib import nullcontext
5from typing import TYPE_CHECKING, Any
6
7from plain.models.connections import DatabaseConnection
8
9from ..transaction import atomic
10from .loader import MigrationLoader
11from .migration import Migration
12from .recorder import MigrationRecorder
13from .state import ProjectState
14
15if TYPE_CHECKING:
16 from plain.models.backends.base.base import BaseDatabaseWrapper
17
18
class MigrationExecutor:
    """
    End-to-end migration execution - load migrations and apply them forwards
    to a specified set of targets.
    """

    def __init__(
        self,
        connection: BaseDatabaseWrapper | DatabaseConnection,
        progress_callback: Callable[..., Any] | None = None,
    ) -> None:
        """
        Args:
            connection: Database connection used to load, record and apply
                migrations.
            progress_callback: Optional callable invoked as
                ``progress_callback(event, migration=..., fake=...)`` with
                ``event`` set to ``"apply_start"`` or ``"apply_success"``. It
                is also forwarded to ``Migration.apply`` as
                ``operation_callback``.
        """
        self.connection = connection
        self.loader = MigrationLoader(self.connection)
        self.recorder = MigrationRecorder(self.connection)
        self.progress_callback = progress_callback

    def migration_plan(
        self, targets: list[tuple[str, str]], clean_start: bool = False
    ) -> list[Migration]:
        """
        Given a set of targets, return a list of Migration instances.

        Each target is a graph node key ``(package_label, name)``. The forwards
        plan of every target is walked in order. Migrations that are already
        applied, or were already added for an earlier target, are skipped, so
        each migration appears at most once.

        Args:
            targets: Node keys to migrate up to.
            clean_start: If True, ignore the recorded applied migrations and
                plan as if the database were empty.
        """
        # Only membership of node keys matters here; the recorded values are
        # never read, so a set of keys is enough (and never aliases the
        # loader's own mapping).
        applied: set[tuple[str, str]] = (
            set() if clean_start else set(self.loader.applied_migrations)
        )
        plan: list[Migration] = []
        for target in targets:
            for key in self.loader.graph.forwards_plan(target):
                if key not in applied:
                    plan.append(self.loader.graph.nodes[key])
                    applied.add(key)
        return plan

    def _full_plan(self) -> list[Migration]:
        """Return the forwards plan Plain would follow on an empty database."""
        return self.migration_plan(
            self.loader.graph.leaf_nodes(), clean_start=True
        )

    def _create_project_state(
        self,
        with_applied_migrations: bool = False,
        full_plan: list[Migration] | None = None,
    ) -> ProjectState:
        """
        Create a project state for the unmigrated packages and, optionally,
        the applied migrations.

        Args:
            with_applied_migrations: If True, mutate the state with every
                migration that is both recorded as applied and present in the
                graph, in full-plan order.
            full_plan: Precomputed result of ``_full_plan()``, to avoid
                rebuilding it. It is computed on demand when omitted.
        """
        state = ProjectState(real_packages=self.loader.unmigrated_packages)
        if with_applied_migrations:
            if full_plan is None:
                full_plan = self._full_plan()
            # Recorded migrations that no longer exist in the graph are ignored.
            applied_migrations = {
                self.loader.graph.nodes[key]
                for key in self.loader.applied_migrations
                if key in self.loader.graph.nodes
            }
            for migration in full_plan:
                if migration in applied_migrations:
                    migration.mutate_state(state, preserve=False)
        return state

    def migrate(
        self,
        targets: list[tuple[str, str]],
        plan: list[Migration] | None = None,
        state: ProjectState | None = None,
        fake: bool = False,
        atomic_batch: bool = False,
    ) -> ProjectState:
        """
        Migrate the database up to the given targets.

        The migrations in ``plan`` are applied in the order they appear in the
        full forwards plan. ``check_replacements`` runs afterwards on every
        call, even when nothing was applied.

        Args:
            targets: Node keys to migrate up to. Used only when ``plan`` is
                None.
            plan: Explicit list of migrations to apply. Defaults to
                ``migration_plan(targets)``.
            state: Starting project state. Defaults to a state built from the
                applied migrations.
            fake: Record migrations as applied without running them.
            atomic_batch: Whether to run all migrations in a single
                transaction. This only takes effect when more than one
                migration is planned.

        Returns:
            The project state after applying the plan. If an explicit empty
            ``plan`` is given and the recorder table does not exist, a fresh
            state without applied migrations is returned immediately and
            ``state`` is ignored.
        """
        # The plain_migrations table must be present to record applied
        # migrations, but don't create it if there are no migrations to apply.
        if plan == []:
            if not self.recorder.has_table():
                return self._create_project_state(with_applied_migrations=False)
        else:
            self.recorder.ensure_schema()

        if plan is None:
            plan = self.migration_plan(targets)
        full_plan = self._full_plan()

        if state is None:
            # The resulting state should include applied migrations; reuse the
            # full plan instead of computing it a second time.
            state = self._create_project_state(
                with_applied_migrations=True, full_plan=full_plan
            )

        migrations_to_run = set(plan)

        # Choose context manager based on atomic_batch
        batch_context = atomic if (atomic_batch and len(plan) > 1) else nullcontext

        with batch_context():
            for migration in full_plan:
                if not migrations_to_run:
                    # Every planned migration has been applied; no need to walk
                    # the rest of the full plan.
                    break
                if migration in migrations_to_run:
                    if "models_registry" not in state.__dict__:
                        state.models_registry  # Render all -- performance critical
                    state = self.apply_migration(state, migration, fake=fake)
                    migrations_to_run.remove(migration)

        self.check_replacements()

        return state

    def apply_migration(
        self, state: ProjectState, migration: Migration, fake: bool = False
    ) -> ProjectState:
        """
        Run a migration forwards and record it as applied.

        Args:
            state: Project state before the migration.
            migration: The migration to apply.
            fake: If True, only record the migration and skip running it.

        Returns:
            The project state returned by ``migration.apply``, or ``state``
            unchanged when faking.
        """
        migration_recorded = False
        if self.progress_callback:
            self.progress_callback("apply_start", migration=migration, fake=fake)
        if not fake:
            # Alright, do it normally
            with self.connection.schema_editor(
                atomic=migration.atomic
            ) as schema_editor:
                state = migration.apply(
                    state, schema_editor, operation_callback=self.progress_callback
                )
                # Record inside the editor context unless SQL is still deferred.
                # In that case recording waits until the context has exited,
                # presumably after the deferred SQL runs (TODO confirm
                # schema_editor's exit behaviour).
                if not schema_editor.deferred_sql:
                    self.record_migration(migration)
                    migration_recorded = True
        if not migration_recorded:
            self.record_migration(migration)
        # Report progress
        if self.progress_callback:
            self.progress_callback("apply_success", migration=migration, fake=fake)
        return state

    def record_migration(self, migration: Migration) -> None:
        """
        Record ``migration`` as applied in the recorder.

        For a replacement (squashed) migration, each migration it replaces is
        recorded individually rather than the squash itself.
        ``check_replacements`` marks the squash as applied once all of its
        replaced migrations are.
        """
        if migration.replaces:
            for package_label, name in migration.replaces:
                self.recorder.record_applied(package_label, name)
        else:
            self.recorder.record_applied(migration.package_label, migration.name)

    def check_replacements(self) -> None:
        """
        Mark replacement migrations applied if their replaced set all are.

        Do this unconditionally on every migrate, rather than just when
        migrations are applied, to correctly handle the case when a new
        squash migration is pushed to a deployment that already had all its
        replaced migrations applied. In that case no new migration will be
        applied, but the applied state of the squashed migration must be
        maintained.
        """
        applied = self.recorder.applied_migrations()
        for key, migration in self.loader.replacements.items():
            all_applied = all(m in applied for m in migration.replaces)
            if all_applied and key not in applied:
                self.recorder.record_applied(*key)