1from plain.models import migrations
2
3from .loader import MigrationLoader
4from .recorder import MigrationRecorder
5from .state import ProjectState
6
7
class MigrationExecutor:
    """
    End-to-end migration execution - load migrations and run them up to a
    specified set of targets.
    """

    def __init__(self, connection, progress_callback=None):
        self.connection = connection
        self.loader = MigrationLoader(self.connection)
        self.recorder = MigrationRecorder(self.connection)
        # Optional hook called with an event name ("render_start",
        # "render_success", "apply_start", "apply_success"); the apply events
        # additionally receive the migration and the effective fake flag.
        self.progress_callback = progress_callback

    def migration_plan(self, targets, clean_start=False):
        """
        Given a set of targets, return a list of Migration instances.

        The result concatenates the forwards plan of each target in order,
        without duplicates. Migrations already recorded as applied are left
        out unless ``clean_start`` is True, in which case the plan is the one
        that would be followed on an empty database.
        """
        plan = []
        # Only membership is needed: keys are the graph keys that index
        # ``self.loader.graph.nodes``.
        seen = set() if clean_start else set(self.loader.applied_migrations)
        for target in targets:
            for key in self.loader.graph.forwards_plan(target):
                if key in seen:
                    continue
                plan.append(self.loader.graph.nodes[key])
                seen.add(key)
        return plan

    def _create_project_state(self, with_applied_migrations=False):
        """
        Create a project state including all the packages without migrations
        and, if with_applied_migrations=True, the applied migrations.

        Applied migrations are replayed in the order of the full forwards plan
        (the order Plain would follow on an empty database), not in the order
        they happen to be recorded.
        """
        state = ProjectState(real_packages=self.loader.unmigrated_packages)
        if with_applied_migrations:
            # Create the forwards plan Plain would follow on an empty database
            full_plan = self.migration_plan(
                self.loader.graph.leaf_nodes(), clean_start=True
            )
            # Recorded migrations that are no longer in the graph are ignored.
            applied_migrations = {
                self.loader.graph.nodes[key]
                for key in self.loader.applied_migrations
                if key in self.loader.graph.nodes
            }
            for migration in full_plan:
                if migration in applied_migrations:
                    migration.mutate_state(state, preserve=False)
        return state

    def migrate(self, targets, plan=None, state=None, fake=False, fake_initial=False):
        """
        Migrate the database up to the given targets.

        Plain first needs to create all project states before a migration is
        applied and in a second step run all the database operations.

        :param targets: graph keys to migrate to; used only when ``plan`` is
            None.
        :param plan: precomputed list of Migration instances to apply.
        :param state: starting project state; when None it is built from the
            unmigrated packages plus the already-applied migrations.
        :param fake: record migrations as applied without running them.
        :param fake_initial: fake initial migrations whose tables/columns
            already exist.
        :return: the resulting project state.
        """
        # The plain_migrations table must be present to record applied
        # migrations, but don't create it if there are no migrations to apply.
        if plan == []:
            if not self.recorder.has_table():
                return self._create_project_state(with_applied_migrations=False)
        else:
            self.recorder.ensure_schema()

        if plan is None:
            plan = self.migration_plan(targets)

        if state is None:
            # The resulting state should include applied migrations.
            state = self._create_project_state(with_applied_migrations=True)

        if plan:
            # Create the forwards plan Plain would follow on an empty
            # database. It is only needed when there is something to apply,
            # so skip walking the whole graph otherwise.
            full_plan = self.migration_plan(
                self.loader.graph.leaf_nodes(), clean_start=True
            )
            state = self._migrate_all_forwards(
                state, plan, full_plan, fake=fake, fake_initial=fake_initial
            )

        self.check_replacements()

        return state

    def _migrate_all_forwards(self, state, plan, full_plan, fake, fake_initial):
        """
        Apply the Migration instances in ``plan`` in the order they occur in
        ``full_plan`` and return the resulting project state.
        """
        migrations_to_run = set(plan)
        for migration in full_plan:
            if not migrations_to_run:
                # Every planned migration is removed from the set once it has
                # been applied, so we can bail out after the last one instead
                # of walking the rest of the full plan.
                break
            if migration not in migrations_to_run:
                continue
            if "models_registry" not in state.__dict__:
                # Accessing models_registry renders all models; do it once,
                # up front -- performance critical.
                if self.progress_callback:
                    self.progress_callback("render_start")
                state.models_registry
                if self.progress_callback:
                    self.progress_callback("render_success")
            state = self.apply_migration(
                state, migration, fake=fake, fake_initial=fake_initial
            )
            migrations_to_run.remove(migration)

        return state

    def apply_migration(self, state, migration, fake=False, fake_initial=False):
        """
        Run a migration forwards, record it as applied, and return the new
        project state.

        With ``fake=True`` the migration is only recorded. With
        ``fake_initial=True`` an initial migration that detect_soft_applied()
        reports as already applied is faked as well; the progress callback
        receives the resulting (possibly updated) fake flag.
        """
        migration_recorded = False
        if self.progress_callback:
            self.progress_callback("apply_start", migration, fake)
        if not fake:
            if fake_initial:
                # Test to see if this is an already-applied initial migration
                applied, state = self.detect_soft_applied(state, migration)
                if applied:
                    fake = True
            if not fake:
                # Alright, do it normally
                with self.connection.schema_editor(
                    atomic=migration.atomic
                ) as schema_editor:
                    state = migration.apply(state, schema_editor)
                    # Record inside the editor's context only when no SQL is
                    # still deferred; otherwise wait until the context has
                    # exited (presumably after the deferred SQL has run).
                    if not schema_editor.deferred_sql:
                        self.record_migration(migration)
                        migration_recorded = True
        if not migration_recorded:
            self.record_migration(migration)
        # Report progress
        if self.progress_callback:
            self.progress_callback("apply_success", migration, fake)
        return state

    def record_migration(self, migration):
        """Record ``migration`` (or every migration it replaces) as applied."""
        # For replacement migrations, record individual statuses
        if migration.replaces:
            for package_label, name in migration.replaces:
                self.recorder.record_applied(package_label, name)
        else:
            self.recorder.record_applied(migration.package_label, migration.name)

    def check_replacements(self):
        """
        Mark replacement migrations applied if their replaced set all are.

        Do this unconditionally on every migrate, rather than just when
        migrations are applied or unapplied, to correctly handle the case
        when a new squash migration is pushed to a deployment that already had
        all its replaced migrations applied. In this case no new migration will
        be applied, but the applied state of the squashed migration must be
        maintained.
        """
        applied = self.recorder.applied_migrations()
        for key, migration in self.loader.replacements.items():
            all_applied = all(m in applied for m in migration.replaces)
            if all_applied and key not in applied:
                self.recorder.record_applied(*key)

    def detect_soft_applied(self, project_state, migration):
        """
        Test whether a migration has been implicitly applied - that the
        tables or columns it would create exist. This is intended only for use
        on initial migrations (as it only looks for CreateModel and AddField).

        Return a tuple ``(applied, state)``. On any early bail-out ``applied``
        is False and ``state`` is the ``project_state`` passed in; otherwise
        ``state`` is the project state after the migration and ``applied`` is
        True only if at least one CreateModel or AddField was verified.
        """
        if migration.initial is None:
            # Bail if the migration isn't the first one in its package
            if any(
                app == migration.package_label for app, _ in migration.dependencies
            ):
                return False, project_state
        elif migration.initial is False:
            # Bail if it's NOT an initial migration
            return False, project_state

        if project_state is None:
            after_state = self.loader.project_state(
                (migration.package_label, migration.name), at_end=True
            )
        else:
            after_state = migration.mutate_state(project_state)
        models_registry = after_state.models_registry
        found_create_model_migration = False
        found_add_field_migration = False
        # Some backends compare identifiers case-insensitively; compare
        # casefolded names there.
        fold_identifier_case = self.connection.features.ignores_table_name_case
        with self.connection.cursor() as cursor:
            existing_table_names = set(
                self.connection.introspection.table_names(cursor)
            )
        if fold_identifier_case:
            existing_table_names = {name.casefold() for name in existing_table_names}

        # Make sure all create model and add field operations are done
        for operation in migration.operations:
            if isinstance(operation, migrations.CreateModel):
                model = models_registry.get_model(
                    migration.package_label, operation.name
                )
                db_table = model._meta.db_table
                if fold_identifier_case:
                    db_table = db_table.casefold()
                if db_table not in existing_table_names:
                    return False, project_state
                found_create_model_migration = True
            elif isinstance(operation, migrations.AddField):
                model = models_registry.get_model(
                    migration.package_label, operation.model_name
                )
                table = model._meta.db_table
                field = model._meta.get_field(operation.name)

                # Handle implicit many-to-many tables created by AddField.
                if field.many_to_many:
                    through_db_table = field.remote_field.through._meta.db_table
                    if fold_identifier_case:
                        through_db_table = through_db_table.casefold()
                    if through_db_table not in existing_table_names:
                        return False, project_state
                    found_add_field_migration = True
                    continue

                with self.connection.cursor() as cursor:
                    columns = self.connection.introspection.get_table_description(
                        cursor, table
                    )
                # The expected column name does not change per column, so
                # compute (and fold) it once.
                field_column = field.column
                if fold_identifier_case:
                    field_column = field_column.casefold()
                for column in columns:
                    column_name = column.name
                    if fold_identifier_case:
                        column_name = column_name.casefold()
                    if column_name == field_column:
                        found_add_field_migration = True
                        break
                else:
                    # No existing column matches the added field.
                    return False, project_state
        # If we get this far and we found at least one CreateModel or AddField
        # migration, the migration is considered implicitly applied.
        return (found_create_model_migration or found_add_field_migration), after_state