1from plain.models import migrations
2from plain.models.db import router
3from plain.packages.registry import packages as global_packages
4
5from .exceptions import InvalidMigrationPlan
6from .loader import MigrationLoader
7from .recorder import MigrationRecorder
8from .state import ProjectState
9
10
class MigrationExecutor:
    """
    End-to-end migration execution - load migrations and run them up or down
    to a specified set of targets.

    ``progress_callback``, when given, is invoked either with a bare event
    name ("render_start", "render_success") or with an event name, the
    migration and the ``fake`` flag ("apply_start", "apply_success",
    "unapply_start", "unapply_success").
    """

    def __init__(self, connection, progress_callback=None):
        self.connection = connection
        # Loader (graph + applied set) and recorder share the same connection.
        self.loader = MigrationLoader(self.connection)
        self.recorder = MigrationRecorder(self.connection)
        self.progress_callback = progress_callback

    def migration_plan(self, targets, clean_start=False):
        """
        Given a set of targets, return a list of (Migration instance, backwards?).

        Each target is a ``(package_label, migration_name)`` key:

        * ``(package_label, None)`` unapplies every applied migration of the
          package.
        * An already-applied target is migrated backwards to: its immediate
          same-package children (and what the graph's backwards plan for them
          yields) are unapplied, while the target itself stays applied.
        * Any other target is migrated forwards to, following the graph's
          forwards plan and skipping migrations that are already applied.

        When ``clean_start`` is true the plan is computed as if nothing had
        been applied yet (i.e. against an empty database).
        """
        plan = []
        # Working copy of the applied set; it is updated while the plan grows
        # so that later targets see the effect of earlier ones.
        applied = {} if clean_start else dict(self.loader.applied_migrations)

        def unapply_through(node):
            # Queue, for unapplying, every still-applied migration in the
            # graph's backwards plan for ``node``.
            for key in self.loader.graph.backwards_plan(node):
                if key in applied:
                    plan.append((self.loader.graph.nodes[key], True))
                    applied.pop(key)

        for target in targets:
            if target[1] is None:
                # (package_label, None) means unmigrate everything in the package.
                for root in self.loader.graph.root_nodes():
                    if root[0] == target[0]:
                        unapply_through(root)
            elif target in applied:
                # Already applied: migrate backwards to it.
                # If the target is missing, it's likely a replaced migration.
                # Reload the graph without replacements and plan again.
                if (
                    self.loader.replace_migrations
                    and target not in self.loader.graph.node_map
                ):
                    self.loader.replace_migrations = False
                    self.loader.build_graph()
                    return self.migration_plan(targets, clean_start=clean_start)
                # Don't migrate backwards all the way to the target node (that
                # may roll back dependencies in other packages that don't need
                # to be rolled back); instead roll back through target's
                # immediate child(ren) in the same package, and no further.
                next_in_package = sorted(
                    child
                    for child in self.loader.graph.node_map[target].children
                    if child[0] == target[0]
                )
                for node in next_in_package:
                    unapply_through(node)
            else:
                # Not applied yet: migrate forwards to it.
                for key in self.loader.graph.forwards_plan(target):
                    if key not in applied:
                        plan.append((self.loader.graph.nodes[key], False))
                        applied[key] = self.loader.graph.nodes[key]
        return plan

    def _applied_migration_instances(self):
        """
        Return the set of Migration instances that are recorded as applied
        AND present in the current migration graph.

        Applied keys without a matching graph node are silently skipped.
        """
        nodes = self.loader.graph.nodes
        return {nodes[key] for key in self.loader.applied_migrations if key in nodes}

    def _create_project_state(self, with_applied_migrations=False):
        """
        Create a project state including all the packages without
        migrations and applied migrations if with_applied_migrations=True.

        Applied migrations are folded in following the forwards plan Plain
        would use on an empty database, so they mutate the state in
        dependency order.
        """
        state = ProjectState(real_packages=self.loader.unmigrated_packages)
        if with_applied_migrations:
            # Create the forwards plan Plain would follow on an empty database
            full_plan = self.migration_plan(
                self.loader.graph.leaf_nodes(), clean_start=True
            )
            applied_migrations = self._applied_migration_instances()
            for migration, _ in full_plan:
                if migration in applied_migrations:
                    migration.mutate_state(state, preserve=False)
        return state

    def migrate(self, targets, plan=None, state=None, fake=False, fake_initial=False):
        """
        Migrate the database up to the given targets.

        Plain first needs to create all project states before a migration is
        (un)applied and in a second step run all the database operations.

        ``plan`` may be supplied precomputed (as returned by
        ``migration_plan``); otherwise it is computed from ``targets``.
        Return the resulting project state.

        Raise InvalidMigrationPlan if the plan mixes forwards and backwards
        migrations.
        """
        # The plain_migrations table must be present to record applied
        # migrations, but don't create it if there are no migrations to apply.
        if plan == []:
            if not self.recorder.has_table():
                return self._create_project_state(with_applied_migrations=False)
        else:
            self.recorder.ensure_schema()

        if plan is None:
            plan = self.migration_plan(targets)
        # Create the forwards plan Plain would follow on an empty database
        full_plan = self.migration_plan(
            self.loader.graph.leaf_nodes(), clean_start=True
        )

        all_forwards = all(not backwards for _, backwards in plan)
        all_backwards = all(backwards for _, backwards in plan)

        if not plan:
            if state is None:
                # The resulting state should include applied migrations.
                state = self._create_project_state(with_applied_migrations=True)
        elif all_forwards == all_backwards:
            # Both flags can only be equal for a non-empty plan if it is mixed.
            raise InvalidMigrationPlan(
                "Migration plans with both forwards and backwards migrations "
                "are not supported. Please split your migration process into "
                "separate plans of only forwards OR backwards migrations.",
                plan,
            )
        elif all_forwards:
            if state is None:
                # The resulting state should still include applied migrations.
                state = self._create_project_state(with_applied_migrations=True)
            state = self._migrate_all_forwards(
                state, plan, full_plan, fake=fake, fake_initial=fake_initial
            )
        else:
            # No need to check for `elif all_backwards` here, as that condition
            # would always evaluate to true.
            state = self._migrate_all_backwards(plan, full_plan, fake=fake)

        self.check_replacements()

        return state

    def _migrate_all_forwards(self, state, plan, full_plan, fake, fake_initial):
        """
        Take a list of 2-tuples of the form (migration instance, False) and
        apply them in the order they occur in the full_plan.

        Return the project state after the last applied migration.
        """
        migrations_to_run = {m[0] for m in plan}
        for migration, _ in full_plan:
            if not migrations_to_run:
                # Every planned migration has been applied; no need to walk
                # the rest of the full plan.
                break
            if migration in migrations_to_run:
                if "packages" not in state.__dict__:
                    # Render the state's packages once, up front, before the
                    # first migration runs.
                    if self.progress_callback:
                        self.progress_callback("render_start")
                    state.packages  # Render all -- performance critical
                    if self.progress_callback:
                        self.progress_callback("render_success")
                state = self.apply_migration(
                    state, migration, fake=fake, fake_initial=fake_initial
                )
                migrations_to_run.remove(migration)

        return state

    def _migrate_all_backwards(self, plan, full_plan, fake):
        """
        Take a list of 2-tuples of the form (migration instance, True) and
        unapply them in reverse order they occur in the full_plan.

        Since unapplying a migration requires the project state prior to that
        migration, Plain will compute the migration states before each of them
        in a first run over the plan and then unapply them in a second run over
        the plan.

        Return the project state with only the still-applied migrations.
        """
        migrations_to_run = {m[0] for m in plan}
        # Holds all migration states prior to the migrations being unapplied
        states = {}
        state = self._create_project_state()
        applied_migrations = self._applied_migration_instances()
        if self.progress_callback:
            self.progress_callback("render_start")
        for migration, _ in full_plan:
            if not migrations_to_run:
                # All pre-migration states are known; stop walking the plan.
                break
            if migration in migrations_to_run:
                if "packages" not in state.__dict__:
                    state.packages  # Render all -- performance critical
                # The state before this migration
                states[migration] = state
                # The old state keeps as-is, we continue with the new state
                state = migration.mutate_state(state, preserve=True)
                migrations_to_run.remove(migration)
            elif migration in applied_migrations:
                # Only mutate the state if the migration is actually applied
                # to make sure the resulting state doesn't include changes
                # from unrelated migrations.
                migration.mutate_state(state, preserve=False)
        if self.progress_callback:
            self.progress_callback("render_success")

        for migration, _ in plan:
            self.unapply_migration(states[migration], migration, fake=fake)
            applied_migrations.remove(migration)

        # Generate the post migration state by starting from the state before
        # the last migration is unapplied and replaying, in full-plan order,
        # every migration from that point on that is still applied.
        last_unapplied_migration = plan[-1][0]
        state = states[last_unapplied_migration]
        start = next(
            (
                index
                for index, (candidate, _) in enumerate(full_plan)
                if candidate == last_unapplied_migration
            ),
            len(full_plan),
        )
        for migration, _ in full_plan[start:]:
            if migration in applied_migrations:
                migration.mutate_state(state, preserve=False)

        return state

    def apply_migration(self, state, migration, fake=False, fake_initial=False):
        """
        Run a migration forwards and record it as applied.

        With ``fake`` the schema is left untouched and only the record is
        written. With ``fake_initial`` an initial migration whose tables and
        columns already exist (see ``detect_soft_applied``) is faked.
        Return the resulting project state.
        """
        migration_recorded = False
        if self.progress_callback:
            self.progress_callback("apply_start", migration, fake)
        if not fake:
            if fake_initial:
                # Test to see if this is an already-applied initial migration
                applied, state = self.detect_soft_applied(state, migration)
                if applied:
                    fake = True
            if not fake:
                # Alright, do it normally
                with self.connection.schema_editor(
                    atomic=migration.atomic
                ) as schema_editor:
                    state = migration.apply(state, schema_editor)
                    # With no deferred SQL pending, record inside the same
                    # schema-editor block (presumably the same transaction
                    # when atomic); otherwise record after the block exits.
                    if not schema_editor.deferred_sql:
                        self.record_migration(migration)
                        migration_recorded = True
        if not migration_recorded:
            self.record_migration(migration)
        # Report progress
        if self.progress_callback:
            self.progress_callback("apply_success", migration, fake)
        return state

    def record_migration(self, migration):
        """
        Record ``migration`` as applied.

        For a replacement (squashed) migration, the individual replaced
        migrations are recorded instead; ``check_replacements`` later marks
        the replacement itself applied once all of them are.
        """
        if migration.replaces:
            for package_label, name in migration.replaces:
                self.recorder.record_applied(package_label, name)
        else:
            self.recorder.record_applied(migration.package_label, migration.name)

    def unapply_migration(self, state, migration, fake=False):
        """
        Run a migration backwards and record it as unapplied.

        With ``fake`` only the records are changed. For replacement
        migrations, the replaced migrations AND the replacement itself are
        recorded as unapplied. Return the resulting project state.
        """
        if self.progress_callback:
            self.progress_callback("unapply_start", migration, fake)
        if not fake:
            with self.connection.schema_editor(
                atomic=migration.atomic
            ) as schema_editor:
                state = migration.unapply(state, schema_editor)
        # For replacement migrations, also record individual statuses.
        if migration.replaces:
            for package_label, name in migration.replaces:
                self.recorder.record_unapplied(package_label, name)
        self.recorder.record_unapplied(migration.package_label, migration.name)
        # Report progress
        if self.progress_callback:
            self.progress_callback("unapply_success", migration, fake)
        return state

    def check_replacements(self):
        """
        Mark replacement migrations applied if their replaced set all are.

        Do this unconditionally on every migrate, rather than just when
        migrations are applied or unapplied, to correctly handle the case
        when a new squash migration is pushed to a deployment that already had
        all its replaced migrations applied. In this case no new migration will
        be applied, but the applied state of the squashed migration must be
        maintained.
        """
        applied = self.recorder.applied_migrations()
        for key, migration in self.loader.replacements.items():
            all_applied = all(m in applied for m in migration.replaces)
            if all_applied and key not in applied:
                self.recorder.record_applied(*key)

    def detect_soft_applied(self, project_state, migration):
        """
        Test whether a migration has been implicitly applied - that the
        tables or columns it would create exist. This is intended only for use
        on initial migrations (as it only looks for CreateModel and AddField).

        Return ``(applied, state)``: when at least one CreateModel/AddField
        operation was found and all checked tables/columns exist, ``applied``
        is true and ``state`` is the post-migration state; otherwise
        ``applied`` is false and ``state`` is ``project_state`` when a check
        fails or the migration is not initial.
        """

        def should_skip_detecting_model(migration, model):
            """
            No need to detect tables for unmanaged models, or
            models that can't be migrated on the current database.
            """
            return not model._meta.managed or not router.allow_migrate(
                self.connection.alias,
                migration.package_label,
                model_name=model._meta.model_name,
            )

        if migration.initial is None:
            # Bail if the migration isn't the first one in its package
            if any(
                package == migration.package_label
                for package, _ in migration.dependencies
            ):
                return False, project_state
        elif migration.initial is False:
            # Bail if it's NOT an initial migration
            return False, project_state

        if project_state is None:
            after_state = self.loader.project_state(
                (migration.package_label, migration.name), at_end=True
            )
        else:
            after_state = migration.mutate_state(project_state)
        packages = after_state.packages
        found_create_model_migration = False
        found_add_field_migration = False
        fold_identifier_case = self.connection.features.ignores_table_name_case

        def normalize(identifier):
            # Compare identifiers case-insensitively on backends that report
            # ignores_table_name_case.
            return identifier.casefold() if fold_identifier_case else identifier

        def load_model(model_name):
            model = packages.get_model(migration.package_label, model_name)
            if model._meta.swapped:
                # We have to fetch the model to test with from the
                # main package registry, as it's not a direct dependency.
                model = global_packages.get_model(model._meta.swapped)
            return model

        with self.connection.cursor() as cursor:
            existing_table_names = {
                normalize(name)
                for name in self.connection.introspection.table_names(cursor)
            }
        # Make sure all create model and add field operations are done
        for operation in migration.operations:
            if isinstance(operation, migrations.CreateModel):
                model = load_model(operation.name)
                if should_skip_detecting_model(migration, model):
                    continue
                if normalize(model._meta.db_table) not in existing_table_names:
                    return False, project_state
                found_create_model_migration = True
            elif isinstance(operation, migrations.AddField):
                model = load_model(operation.model_name)
                if should_skip_detecting_model(migration, model):
                    continue

                table = model._meta.db_table
                field = model._meta.get_field(operation.name)

                # Handle implicit many-to-many tables created by AddField.
                if field.many_to_many:
                    through_db_table = field.remote_field.through._meta.db_table
                    if normalize(through_db_table) not in existing_table_names:
                        return False, project_state
                    found_add_field_migration = True
                    continue

                with self.connection.cursor() as cursor:
                    columns = self.connection.introspection.get_table_description(
                        cursor, table
                    )
                field_column = normalize(field.column)
                if not any(
                    normalize(column.name) == field_column for column in columns
                ):
                    return False, project_state
                found_add_field_migration = True
        # If we get this far and we found at least one CreateModel or AddField
        # migration, the migration is considered implicitly applied.
        return (found_create_model_migration or found_add_field_migration), after_state