1from __future__ import annotations
2
3from collections.abc import Callable
4from contextlib import nullcontext
5from typing import TYPE_CHECKING, Any
6
7from plain.models.connections import DatabaseConnection
8
9from ..transaction import atomic
10from .loader import MigrationLoader
11from .migration import Migration
12from .recorder import MigrationRecorder
13from .state import ProjectState
14
15if TYPE_CHECKING:
16 from plain.models.backends.base.base import BaseDatabaseWrapper
17
18
class MigrationExecutor:
    """
    End-to-end migration execution - load migrations and run them forwards
    to a specified set of targets.
    """

    def __init__(
        self,
        connection: BaseDatabaseWrapper | DatabaseConnection,
        progress_callback: Callable[..., Any] | None = None,
    ) -> None:
        """
        connection: Database connection the migrations are loaded for and
            applied to.
        progress_callback: Optional callable invoked with an event name
            ("apply_start" / "apply_success") plus ``migration=`` and
            ``fake=`` keyword arguments. It is also handed to
            ``Migration.apply`` as its ``operation_callback``.
        """
        self.connection = connection
        self.loader = MigrationLoader(self.connection)
        self.recorder = MigrationRecorder(self.connection)
        self.progress_callback = progress_callback

    def migration_plan(
        self, targets: list[tuple[str, str]], clean_start: bool = False
    ) -> list[Migration]:
        """
        Return the Migration instances needed to reach ``targets``, in order.

        Each target's forwards plan is walked in turn. A migration is included
        the first time it is encountered. It is skipped if it is already
        recorded as applied, or if an earlier target already added it.

        targets: (package_label, migration_name) graph nodes to migrate to.
        clean_start: When True, ignore the applied migrations and plan as if
            the database were empty.
        """
        graph = self.loader.graph
        # Keys already applied (or already planned) -> Migration instance.
        applied: dict[tuple[str, str], Any]
        if clean_start:
            applied = {}
        else:
            # applied_migrations may be None; treat that as "nothing applied".
            applied = dict(self.loader.applied_migrations or {})

        plan: list[Migration] = []
        for target in targets:
            for key in graph.forwards_plan(target):
                if key in applied:
                    continue
                migration = graph.nodes[key]
                plan.append(migration)
                # Mark as seen so a later target sharing this node skips it.
                applied[key] = migration
        return plan

    def _create_project_state(
        self, with_applied_migrations: bool = False
    ) -> ProjectState:
        """
        Build a ProjectState containing all packages without migrations.

        When ``with_applied_migrations`` is True, also replay the state
        changes of every migration recorded as applied. Recorded migrations
        that are no longer in the graph are ignored. Replay follows the full
        forwards-plan order.
        """
        state = ProjectState(real_packages=self.loader.unmigrated_packages)
        if not with_applied_migrations:
            return state

        graph = self.loader.graph
        # Create the forwards plan Plain would follow on an empty database.
        full_plan = self.migration_plan(graph.leaf_nodes(), clean_start=True)
        applied_migrations = {
            graph.nodes[key]
            for key in self.loader.applied_migrations or {}
            if key in graph.nodes
        }
        for migration in full_plan:
            if migration in applied_migrations:
                migration.mutate_state(state, preserve=False)
        return state

    def migrate(
        self,
        targets: list[tuple[str, str]],
        plan: list[Migration] | None = None,
        state: ProjectState | None = None,
        fake: bool = False,
        atomic_batch: bool = False,
    ) -> ProjectState:
        """
        Migrate the database forwards up to the given targets.

        The starting project state is built (if not supplied) before any
        migration is applied. The migrations in ``plan`` are then applied in
        the order of the full forwards plan (the order Plain would follow on
        an empty database), regardless of their order within ``plan``.

        targets: (package_label, migration_name) nodes to migrate to. Only
            used to compute the plan when ``plan`` is None.
        plan: Precomputed migrations to apply. Computed from ``targets``
            when None.
        state: Starting project state. When None, it is built from the
            migrations already recorded as applied.
        fake: Record migrations as applied without running their operations.
        atomic_batch: Whether to run all migrations in a single transaction.
            This only takes effect when more than one migration is applied.

        Returns the resulting project state. If ``plan`` is empty and the
        migrations table does not exist yet, returns a state built without
        any applied migrations, and the table is not created.
        """
        # The plain_migrations table must be present to record applied
        # migrations, but don't create it if there are no migrations to apply.
        if plan == []:
            if not self.recorder.has_table():
                return self._create_project_state(with_applied_migrations=False)
        else:
            self.recorder.ensure_schema()

        if plan is None:
            plan = self.migration_plan(targets)

        if state is None:
            # The resulting state should include applied migrations.
            state = self._create_project_state(with_applied_migrations=True)

        if plan:
            # Create the forwards plan Plain would follow on an empty
            # database. It is only needed when something will actually run.
            full_plan = self.migration_plan(
                self.loader.graph.leaf_nodes(), clean_start=True
            )
            migrations_to_run = set(plan)

            # Group everything in one transaction only when requested and
            # when there is more than one migration to group.
            use_batch = atomic_batch and len(plan) > 1
            batch_context = atomic() if use_batch else nullcontext()

            with batch_context:
                for migration in full_plan:
                    if not migrations_to_run:
                        # Every planned migration has been applied; don't walk
                        # the rest of the full plan.
                        break
                    if migration not in migrations_to_run:
                        continue
                    # models_registry is presumably a cached property; only
                    # touch it if it hasn't been computed yet.
                    if "models_registry" not in state.__dict__:
                        state.models_registry  # Render all -- performance critical
                    state = self.apply_migration(state, migration, fake=fake)
                    migrations_to_run.remove(migration)

        self.check_replacements()
        return state

    def apply_migration(
        self, state: ProjectState, migration: Migration, fake: bool = False
    ) -> ProjectState:
        """
        Run a migration forwards and record it as applied.

        With ``fake=True`` the migration's operations are not run. The
        migration is only recorded, and ``state`` is returned unchanged.
        Otherwise the state returned by ``migration.apply`` is returned.
        """
        migration_recorded = False
        if self.progress_callback:
            self.progress_callback("apply_start", migration=migration, fake=fake)
        if not fake:
            # Alright, do it normally
            with self.connection.schema_editor(
                atomic=migration.atomic
            ) as schema_editor:
                state = migration.apply(
                    state, schema_editor, operation_callback=self.progress_callback
                )
                # Record inside the schema editor context unless SQL is still
                # deferred. Deferred SQL presumably runs when the context
                # exits, so in that case record afterwards (below).
                if not schema_editor.deferred_sql:
                    self.record_migration(migration)
                    migration_recorded = True
        if not migration_recorded:
            self.record_migration(migration)
        # Report progress
        if self.progress_callback:
            self.progress_callback("apply_success", migration=migration, fake=fake)
        return state

    def record_migration(self, migration: Migration) -> None:
        """
        Record ``migration`` as applied.

        A replacement (squashed) migration records each migration it
        replaces instead of itself.
        """
        if migration.replaces:
            for package_label, name in migration.replaces:
                self.recorder.record_applied(package_label, name)
        else:
            self.recorder.record_applied(migration.package_label, migration.name)

    def check_replacements(self) -> None:
        """
        Mark replacement migrations applied if their replaced set all are.

        Do this unconditionally on every migrate, rather than just when
        migrations are applied or unapplied. That way the case where a new
        squash migration is pushed to a deployment that already has all its
        replaced migrations applied is handled correctly. No new migration is
        applied in that case, but the applied state of the squashed migration
        must still be maintained.
        """
        applied = self.recorder.applied_migrations()
        for key, migration in self.loader.replacements.items():
            all_applied = all(m in applied for m in migration.replaces)
            if all_applied and key not in applied:
                self.recorder.record_applied(*key)