Plain is headed towards 1.0! Subscribe for development updates →

  1from .loader import MigrationLoader
  2from .recorder import MigrationRecorder
  3from .state import ProjectState
  4
  5
  6class MigrationExecutor:
  7    """
  8    End-to-end migration execution - load migrations and run them up or down
  9    to a specified set of targets.
 10    """
 11
 12    def __init__(self, connection, progress_callback=None):
 13        self.connection = connection
 14        self.loader = MigrationLoader(self.connection)
 15        self.recorder = MigrationRecorder(self.connection)
 16        self.progress_callback = progress_callback
 17
 18    def migration_plan(self, targets, clean_start=False):
 19        """
 20        Given a set of targets, return a list of Migration instances.
 21        """
 22        plan = []
 23        if clean_start:
 24            applied = {}
 25        else:
 26            applied = dict(self.loader.applied_migrations)
 27        for target in targets:
 28            for migration in self.loader.graph.forwards_plan(target):
 29                if migration not in applied:
 30                    plan.append(self.loader.graph.nodes[migration])
 31                    applied[migration] = self.loader.graph.nodes[migration]
 32        return plan
 33
 34    def _create_project_state(self, with_applied_migrations=False):
 35        """
 36        Create a project state including all the applications without
 37        migrations and applied migrations if with_applied_migrations=True.
 38        """
 39        state = ProjectState(real_packages=self.loader.unmigrated_packages)
 40        if with_applied_migrations:
 41            # Create the forwards plan Plain would follow on an empty database
 42            full_plan = self.migration_plan(
 43                self.loader.graph.leaf_nodes(), clean_start=True
 44            )
 45            applied_migrations = {
 46                self.loader.graph.nodes[key]
 47                for key in self.loader.applied_migrations
 48                if key in self.loader.graph.nodes
 49            }
 50            for migration in full_plan:
 51                if migration in applied_migrations:
 52                    migration.mutate_state(state, preserve=False)
 53        return state
 54
 55    def migrate(self, targets, plan=None, state=None, fake=False):
 56        """
 57        Migrate the database up to the given targets.
 58
 59        Plain first needs to create all project states before a migration is
 60        (un)applied and in a second step run all the database operations.
 61        """
 62        # The plain_migrations table must be present to record applied
 63        # migrations, but don't create it if there are no migrations to apply.
 64        if plan == []:
 65            if not self.recorder.has_table():
 66                return self._create_project_state(with_applied_migrations=False)
 67        else:
 68            self.recorder.ensure_schema()
 69
 70        if plan is None:
 71            plan = self.migration_plan(targets)
 72        # Create the forwards plan Plain would follow on an empty database
 73        full_plan = self.migration_plan(
 74            self.loader.graph.leaf_nodes(), clean_start=True
 75        )
 76
 77        if not plan:
 78            if state is None:
 79                # The resulting state should include applied migrations.
 80                state = self._create_project_state(with_applied_migrations=True)
 81        else:
 82            if state is None:
 83                # The resulting state should still include applied migrations.
 84                state = self._create_project_state(with_applied_migrations=True)
 85            state = self._migrate_all_forwards(state, plan, full_plan, fake=fake)
 86
 87        self.check_replacements()
 88
 89        return state
 90
 91    def _migrate_all_forwards(self, state, plan, full_plan, fake):
 92        """
 93        Take a list of 2-tuples of the form (migration instance, False) and
 94        apply them in the order they occur in the full_plan.
 95        """
 96        migrations_to_run = set(plan)
 97        for migration in full_plan:
 98            if not migrations_to_run:
 99                # We remove every migration that we applied from these sets so
100                # that we can bail out once the last migration has been applied
101                # and don't always run until the very end of the migration
102                # process.
103                break
104            if migration in migrations_to_run:
105                if "models_registry" not in state.__dict__:
106                    if self.progress_callback:
107                        self.progress_callback("render_start")
108                    state.models_registry  # Render all -- performance critical
109                    if self.progress_callback:
110                        self.progress_callback("render_success")
111                state = self.apply_migration(state, migration, fake=fake)
112                migrations_to_run.remove(migration)
113
114        return state
115
116    def apply_migration(self, state, migration, fake=False):
117        """Run a migration forwards."""
118        migration_recorded = False
119        if self.progress_callback:
120            self.progress_callback("apply_start", migration, fake)
121        if not fake:
122            # Alright, do it normally
123            with self.connection.schema_editor(
124                atomic=migration.atomic
125            ) as schema_editor:
126                state = migration.apply(state, schema_editor)
127                if not schema_editor.deferred_sql:
128                    self.record_migration(migration)
129                    migration_recorded = True
130        if not migration_recorded:
131            self.record_migration(migration)
132        # Report progress
133        if self.progress_callback:
134            self.progress_callback("apply_success", migration, fake)
135        return state
136
137    def record_migration(self, migration):
138        # For replacement migrations, record individual statuses
139        if migration.replaces:
140            for package_label, name in migration.replaces:
141                self.recorder.record_applied(package_label, name)
142        else:
143            self.recorder.record_applied(migration.package_label, migration.name)
144
145    def check_replacements(self):
146        """
147        Mark replacement migrations applied if their replaced set all are.
148
149        Do this unconditionally on every migrate, rather than just when
150        migrations are applied or unapplied, to correctly handle the case
151        when a new squash migration is pushed to a deployment that already had
152        all its replaced migrations applied. In this case no new migration will
153        be applied, but the applied state of the squashed migration must be
154        maintained.
155        """
156        applied = self.recorder.applied_migrations()
157        for key, migration in self.loader.replacements.items():
158            all_applied = all(m in applied for m in migration.replaces)
159            if all_applied and key not in applied:
160                self.recorder.record_applied(*key)