Plain is headed towards 1.0! Subscribe for development updates →

  1from plain.models import migrations
  2from plain.models.db import router
  3from plain.packages.registry import packages as global_packages
  4
  5from .exceptions import InvalidMigrationPlan
  6from .loader import MigrationLoader
  7from .recorder import MigrationRecorder
  8from .state import ProjectState
  9
 10
 11class MigrationExecutor:
 12    """
 13    End-to-end migration execution - load migrations and run them up or down
 14    to a specified set of targets.
 15    """
 16
 17    def __init__(self, connection, progress_callback=None):
 18        self.connection = connection
 19        self.loader = MigrationLoader(self.connection)
 20        self.recorder = MigrationRecorder(self.connection)
 21        self.progress_callback = progress_callback
 22
 23    def migration_plan(self, targets, clean_start=False):
 24        """
 25        Given a set of targets, return a list of (Migration instance, backwards?).
 26        """
 27        plan = []
 28        if clean_start:
 29            applied = {}
 30        else:
 31            applied = dict(self.loader.applied_migrations)
 32        for target in targets:
 33            # If the target is (package_label, None), that means unmigrate everything
 34            if target[1] is None:
 35                for root in self.loader.graph.root_nodes():
 36                    if root[0] == target[0]:
 37                        for migration in self.loader.graph.backwards_plan(root):
 38                            if migration in applied:
 39                                plan.append((self.loader.graph.nodes[migration], True))
 40                                applied.pop(migration)
 41            # If the migration is already applied, do backwards mode,
 42            # otherwise do forwards mode.
 43            elif target in applied:
 44                # If the target is missing, it's likely a replaced migration.
 45                # Reload the graph without replacements.
 46                if (
 47                    self.loader.replace_migrations
 48                    and target not in self.loader.graph.node_map
 49                ):
 50                    self.loader.replace_migrations = False
 51                    self.loader.build_graph()
 52                    return self.migration_plan(targets, clean_start=clean_start)
 53                # Don't migrate backwards all the way to the target node (that
 54                # may roll back dependencies in other packages that don't need to
 55                # be rolled back); instead roll back through target's immediate
 56                # child(ren) in the same app, and no further.
 57                next_in_app = sorted(
 58                    n
 59                    for n in self.loader.graph.node_map[target].children
 60                    if n[0] == target[0]
 61                )
 62                for node in next_in_app:
 63                    for migration in self.loader.graph.backwards_plan(node):
 64                        if migration in applied:
 65                            plan.append((self.loader.graph.nodes[migration], True))
 66                            applied.pop(migration)
 67            else:
 68                for migration in self.loader.graph.forwards_plan(target):
 69                    if migration not in applied:
 70                        plan.append((self.loader.graph.nodes[migration], False))
 71                        applied[migration] = self.loader.graph.nodes[migration]
 72        return plan
 73
 74    def _create_project_state(self, with_applied_migrations=False):
 75        """
 76        Create a project state including all the applications without
 77        migrations and applied migrations if with_applied_migrations=True.
 78        """
 79        state = ProjectState(real_packages=self.loader.unmigrated_packages)
 80        if with_applied_migrations:
 81            # Create the forwards plan Plain would follow on an empty database
 82            full_plan = self.migration_plan(
 83                self.loader.graph.leaf_nodes(), clean_start=True
 84            )
 85            applied_migrations = {
 86                self.loader.graph.nodes[key]
 87                for key in self.loader.applied_migrations
 88                if key in self.loader.graph.nodes
 89            }
 90            for migration, _ in full_plan:
 91                if migration in applied_migrations:
 92                    migration.mutate_state(state, preserve=False)
 93        return state
 94
 95    def migrate(self, targets, plan=None, state=None, fake=False, fake_initial=False):
 96        """
 97        Migrate the database up to the given targets.
 98
 99        Plain first needs to create all project states before a migration is
100        (un)applied and in a second step run all the database operations.
101        """
102        # The plain_migrations table must be present to record applied
103        # migrations, but don't create it if there are no migrations to apply.
104        if plan == []:
105            if not self.recorder.has_table():
106                return self._create_project_state(with_applied_migrations=False)
107        else:
108            self.recorder.ensure_schema()
109
110        if plan is None:
111            plan = self.migration_plan(targets)
112        # Create the forwards plan Plain would follow on an empty database
113        full_plan = self.migration_plan(
114            self.loader.graph.leaf_nodes(), clean_start=True
115        )
116
117        all_forwards = all(not backwards for mig, backwards in plan)
118        all_backwards = all(backwards for mig, backwards in plan)
119
120        if not plan:
121            if state is None:
122                # The resulting state should include applied migrations.
123                state = self._create_project_state(with_applied_migrations=True)
124        elif all_forwards == all_backwards:
125            # This should only happen if there's a mixed plan
126            raise InvalidMigrationPlan(
127                "Migration plans with both forwards and backwards migrations "
128                "are not supported. Please split your migration process into "
129                "separate plans of only forwards OR backwards migrations.",
130                plan,
131            )
132        elif all_forwards:
133            if state is None:
134                # The resulting state should still include applied migrations.
135                state = self._create_project_state(with_applied_migrations=True)
136            state = self._migrate_all_forwards(
137                state, plan, full_plan, fake=fake, fake_initial=fake_initial
138            )
139        else:
140            # No need to check for `elif all_backwards` here, as that condition
141            # would always evaluate to true.
142            state = self._migrate_all_backwards(plan, full_plan, fake=fake)
143
144        self.check_replacements()
145
146        return state
147
148    def _migrate_all_forwards(self, state, plan, full_plan, fake, fake_initial):
149        """
150        Take a list of 2-tuples of the form (migration instance, False) and
151        apply them in the order they occur in the full_plan.
152        """
153        migrations_to_run = {m[0] for m in plan}
154        for migration, _ in full_plan:
155            if not migrations_to_run:
156                # We remove every migration that we applied from these sets so
157                # that we can bail out once the last migration has been applied
158                # and don't always run until the very end of the migration
159                # process.
160                break
161            if migration in migrations_to_run:
162                if "packages" not in state.__dict__:
163                    if self.progress_callback:
164                        self.progress_callback("render_start")
165                    state.packages  # Render all -- performance critical
166                    if self.progress_callback:
167                        self.progress_callback("render_success")
168                state = self.apply_migration(
169                    state, migration, fake=fake, fake_initial=fake_initial
170                )
171                migrations_to_run.remove(migration)
172
173        return state
174
175    def _migrate_all_backwards(self, plan, full_plan, fake):
176        """
177        Take a list of 2-tuples of the form (migration instance, True) and
178        unapply them in reverse order they occur in the full_plan.
179
180        Since unapplying a migration requires the project state prior to that
181        migration, Plain will compute the migration states before each of them
182        in a first run over the plan and then unapply them in a second run over
183        the plan.
184        """
185        migrations_to_run = {m[0] for m in plan}
186        # Holds all migration states prior to the migrations being unapplied
187        states = {}
188        state = self._create_project_state()
189        applied_migrations = {
190            self.loader.graph.nodes[key]
191            for key in self.loader.applied_migrations
192            if key in self.loader.graph.nodes
193        }
194        if self.progress_callback:
195            self.progress_callback("render_start")
196        for migration, _ in full_plan:
197            if not migrations_to_run:
198                # We remove every migration that we applied from this set so
199                # that we can bail out once the last migration has been applied
200                # and don't always run until the very end of the migration
201                # process.
202                break
203            if migration in migrations_to_run:
204                if "packages" not in state.__dict__:
205                    state.packages  # Render all -- performance critical
206                # The state before this migration
207                states[migration] = state
208                # The old state keeps as-is, we continue with the new state
209                state = migration.mutate_state(state, preserve=True)
210                migrations_to_run.remove(migration)
211            elif migration in applied_migrations:
212                # Only mutate the state if the migration is actually applied
213                # to make sure the resulting state doesn't include changes
214                # from unrelated migrations.
215                migration.mutate_state(state, preserve=False)
216        if self.progress_callback:
217            self.progress_callback("render_success")
218
219        for migration, _ in plan:
220            self.unapply_migration(states[migration], migration, fake=fake)
221            applied_migrations.remove(migration)
222
223        # Generate the post migration state by starting from the state before
224        # the last migration is unapplied and mutating it to include all the
225        # remaining applied migrations.
226        last_unapplied_migration = plan[-1][0]
227        state = states[last_unapplied_migration]
228        for index, (migration, _) in enumerate(full_plan):
229            if migration == last_unapplied_migration:
230                for migration, _ in full_plan[index:]:
231                    if migration in applied_migrations:
232                        migration.mutate_state(state, preserve=False)
233                break
234
235        return state
236
237    def apply_migration(self, state, migration, fake=False, fake_initial=False):
238        """Run a migration forwards."""
239        migration_recorded = False
240        if self.progress_callback:
241            self.progress_callback("apply_start", migration, fake)
242        if not fake:
243            if fake_initial:
244                # Test to see if this is an already-applied initial migration
245                applied, state = self.detect_soft_applied(state, migration)
246                if applied:
247                    fake = True
248            if not fake:
249                # Alright, do it normally
250                with self.connection.schema_editor(
251                    atomic=migration.atomic
252                ) as schema_editor:
253                    state = migration.apply(state, schema_editor)
254                    if not schema_editor.deferred_sql:
255                        self.record_migration(migration)
256                        migration_recorded = True
257        if not migration_recorded:
258            self.record_migration(migration)
259        # Report progress
260        if self.progress_callback:
261            self.progress_callback("apply_success", migration, fake)
262        return state
263
264    def record_migration(self, migration):
265        # For replacement migrations, record individual statuses
266        if migration.replaces:
267            for package_label, name in migration.replaces:
268                self.recorder.record_applied(package_label, name)
269        else:
270            self.recorder.record_applied(migration.package_label, migration.name)
271
272    def unapply_migration(self, state, migration, fake=False):
273        """Run a migration backwards."""
274        if self.progress_callback:
275            self.progress_callback("unapply_start", migration, fake)
276        if not fake:
277            with self.connection.schema_editor(
278                atomic=migration.atomic
279            ) as schema_editor:
280                state = migration.unapply(state, schema_editor)
281        # For replacement migrations, also record individual statuses.
282        if migration.replaces:
283            for package_label, name in migration.replaces:
284                self.recorder.record_unapplied(package_label, name)
285        self.recorder.record_unapplied(migration.package_label, migration.name)
286        # Report progress
287        if self.progress_callback:
288            self.progress_callback("unapply_success", migration, fake)
289        return state
290
291    def check_replacements(self):
292        """
293        Mark replacement migrations applied if their replaced set all are.
294
295        Do this unconditionally on every migrate, rather than just when
296        migrations are applied or unapplied, to correctly handle the case
297        when a new squash migration is pushed to a deployment that already had
298        all its replaced migrations applied. In this case no new migration will
299        be applied, but the applied state of the squashed migration must be
300        maintained.
301        """
302        applied = self.recorder.applied_migrations()
303        for key, migration in self.loader.replacements.items():
304            all_applied = all(m in applied for m in migration.replaces)
305            if all_applied and key not in applied:
306                self.recorder.record_applied(*key)
307
308    def detect_soft_applied(self, project_state, migration):
309        """
310        Test whether a migration has been implicitly applied - that the
311        tables or columns it would create exist. This is intended only for use
312        on initial migrations (as it only looks for CreateModel and AddField).
313        """
314
315        def should_skip_detecting_model(migration, model):
316            """
317            No need to detect tables for unmanaged models, or
318            models that can't be migrated on the current database.
319            """
320            return not model._meta.managed or not router.allow_migrate(
321                self.connection.alias,
322                migration.package_label,
323                model_name=model._meta.model_name,
324            )
325
326        if migration.initial is None:
327            # Bail if the migration isn't the first one in its app
328            if any(
329                app == migration.package_label for app, name in migration.dependencies
330            ):
331                return False, project_state
332        elif migration.initial is False:
333            # Bail if it's NOT an initial migration
334            return False, project_state
335
336        if project_state is None:
337            after_state = self.loader.project_state(
338                (migration.package_label, migration.name), at_end=True
339            )
340        else:
341            after_state = migration.mutate_state(project_state)
342        packages = after_state.packages
343        found_create_model_migration = False
344        found_add_field_migration = False
345        fold_identifier_case = self.connection.features.ignores_table_name_case
346        with self.connection.cursor() as cursor:
347            existing_table_names = set(
348                self.connection.introspection.table_names(cursor)
349            )
350            if fold_identifier_case:
351                existing_table_names = {
352                    name.casefold() for name in existing_table_names
353                }
354        # Make sure all create model and add field operations are done
355        for operation in migration.operations:
356            if isinstance(operation, migrations.CreateModel):
357                model = packages.get_model(migration.package_label, operation.name)
358                if model._meta.swapped:
359                    # We have to fetch the model to test with from the
360                    # main app cache, as it's not a direct dependency.
361                    model = global_packages.get_model(model._meta.swapped)
362                if should_skip_detecting_model(migration, model):
363                    continue
364                db_table = model._meta.db_table
365                if fold_identifier_case:
366                    db_table = db_table.casefold()
367                if db_table not in existing_table_names:
368                    return False, project_state
369                found_create_model_migration = True
370            elif isinstance(operation, migrations.AddField):
371                model = packages.get_model(
372                    migration.package_label, operation.model_name
373                )
374                if model._meta.swapped:
375                    # We have to fetch the model to test with from the
376                    # main app cache, as it's not a direct dependency.
377                    model = global_packages.get_model(model._meta.swapped)
378                if should_skip_detecting_model(migration, model):
379                    continue
380
381                table = model._meta.db_table
382                field = model._meta.get_field(operation.name)
383
384                # Handle implicit many-to-many tables created by AddField.
385                if field.many_to_many:
386                    through_db_table = field.remote_field.through._meta.db_table
387                    if fold_identifier_case:
388                        through_db_table = through_db_table.casefold()
389                    if through_db_table not in existing_table_names:
390                        return False, project_state
391                    else:
392                        found_add_field_migration = True
393                        continue
394                with self.connection.cursor() as cursor:
395                    columns = self.connection.introspection.get_table_description(
396                        cursor, table
397                    )
398                for column in columns:
399                    field_column = field.column
400                    column_name = column.name
401                    if fold_identifier_case:
402                        column_name = column_name.casefold()
403                        field_column = field_column.casefold()
404                    if column_name == field_column:
405                        found_add_field_migration = True
406                        break
407                else:
408                    return False, project_state
409        # If we get this far and we found at least one CreateModel or AddField
410        # migration, the migration is considered implicitly applied.
411        return (found_create_model_migration or found_add_field_migration), after_state