Plain is headed towards 1.0! Subscribe for development updates →

  1from plain.models import migrations
  2
  3from .loader import MigrationLoader
  4from .recorder import MigrationRecorder
  5from .state import ProjectState
  6
  7
  8class MigrationExecutor:
  9    """
 10    End-to-end migration execution - load migrations and run them up or down
 11    to a specified set of targets.
 12    """
 13
 14    def __init__(self, connection, progress_callback=None):
 15        self.connection = connection
 16        self.loader = MigrationLoader(self.connection)
 17        self.recorder = MigrationRecorder(self.connection)
 18        self.progress_callback = progress_callback
 19
 20    def migration_plan(self, targets, clean_start=False):
 21        """
 22        Given a set of targets, return a list of Migration instances.
 23        """
 24        plan = []
 25        if clean_start:
 26            applied = {}
 27        else:
 28            applied = dict(self.loader.applied_migrations)
 29        for target in targets:
 30            for migration in self.loader.graph.forwards_plan(target):
 31                if migration not in applied:
 32                    plan.append(self.loader.graph.nodes[migration])
 33                    applied[migration] = self.loader.graph.nodes[migration]
 34        return plan
 35
 36    def _create_project_state(self, with_applied_migrations=False):
 37        """
 38        Create a project state including all the applications without
 39        migrations and applied migrations if with_applied_migrations=True.
 40        """
 41        state = ProjectState(real_packages=self.loader.unmigrated_packages)
 42        if with_applied_migrations:
 43            # Create the forwards plan Plain would follow on an empty database
 44            full_plan = self.migration_plan(
 45                self.loader.graph.leaf_nodes(), clean_start=True
 46            )
 47            applied_migrations = {
 48                self.loader.graph.nodes[key]
 49                for key in self.loader.applied_migrations
 50                if key in self.loader.graph.nodes
 51            }
 52            for migration in full_plan:
 53                if migration in applied_migrations:
 54                    migration.mutate_state(state, preserve=False)
 55        return state
 56
 57    def migrate(self, targets, plan=None, state=None, fake=False, fake_initial=False):
 58        """
 59        Migrate the database up to the given targets.
 60
 61        Plain first needs to create all project states before a migration is
 62        (un)applied and in a second step run all the database operations.
 63        """
 64        # The plain_migrations table must be present to record applied
 65        # migrations, but don't create it if there are no migrations to apply.
 66        if plan == []:
 67            if not self.recorder.has_table():
 68                return self._create_project_state(with_applied_migrations=False)
 69        else:
 70            self.recorder.ensure_schema()
 71
 72        if plan is None:
 73            plan = self.migration_plan(targets)
 74        # Create the forwards plan Plain would follow on an empty database
 75        full_plan = self.migration_plan(
 76            self.loader.graph.leaf_nodes(), clean_start=True
 77        )
 78
 79        if not plan:
 80            if state is None:
 81                # The resulting state should include applied migrations.
 82                state = self._create_project_state(with_applied_migrations=True)
 83        else:
 84            if state is None:
 85                # The resulting state should still include applied migrations.
 86                state = self._create_project_state(with_applied_migrations=True)
 87            state = self._migrate_all_forwards(
 88                state, plan, full_plan, fake=fake, fake_initial=fake_initial
 89            )
 90
 91        self.check_replacements()
 92
 93        return state
 94
 95    def _migrate_all_forwards(self, state, plan, full_plan, fake, fake_initial):
 96        """
 97        Take a list of 2-tuples of the form (migration instance, False) and
 98        apply them in the order they occur in the full_plan.
 99        """
100        migrations_to_run = set(plan)
101        for migration in full_plan:
102            if not migrations_to_run:
103                # We remove every migration that we applied from these sets so
104                # that we can bail out once the last migration has been applied
105                # and don't always run until the very end of the migration
106                # process.
107                break
108            if migration in migrations_to_run:
109                if "models_registry" not in state.__dict__:
110                    if self.progress_callback:
111                        self.progress_callback("render_start")
112                    state.models_registry  # Render all -- performance critical
113                    if self.progress_callback:
114                        self.progress_callback("render_success")
115                state = self.apply_migration(
116                    state, migration, fake=fake, fake_initial=fake_initial
117                )
118                migrations_to_run.remove(migration)
119
120        return state
121
122    def apply_migration(self, state, migration, fake=False, fake_initial=False):
123        """Run a migration forwards."""
124        migration_recorded = False
125        if self.progress_callback:
126            self.progress_callback("apply_start", migration, fake)
127        if not fake:
128            if fake_initial:
129                # Test to see if this is an already-applied initial migration
130                applied, state = self.detect_soft_applied(state, migration)
131                if applied:
132                    fake = True
133            if not fake:
134                # Alright, do it normally
135                with self.connection.schema_editor(
136                    atomic=migration.atomic
137                ) as schema_editor:
138                    state = migration.apply(state, schema_editor)
139                    if not schema_editor.deferred_sql:
140                        self.record_migration(migration)
141                        migration_recorded = True
142        if not migration_recorded:
143            self.record_migration(migration)
144        # Report progress
145        if self.progress_callback:
146            self.progress_callback("apply_success", migration, fake)
147        return state
148
149    def record_migration(self, migration):
150        # For replacement migrations, record individual statuses
151        if migration.replaces:
152            for package_label, name in migration.replaces:
153                self.recorder.record_applied(package_label, name)
154        else:
155            self.recorder.record_applied(migration.package_label, migration.name)
156
157    def check_replacements(self):
158        """
159        Mark replacement migrations applied if their replaced set all are.
160
161        Do this unconditionally on every migrate, rather than just when
162        migrations are applied or unapplied, to correctly handle the case
163        when a new squash migration is pushed to a deployment that already had
164        all its replaced migrations applied. In this case no new migration will
165        be applied, but the applied state of the squashed migration must be
166        maintained.
167        """
168        applied = self.recorder.applied_migrations()
169        for key, migration in self.loader.replacements.items():
170            all_applied = all(m in applied for m in migration.replaces)
171            if all_applied and key not in applied:
172                self.recorder.record_applied(*key)
173
174    def detect_soft_applied(self, project_state, migration):
175        """
176        Test whether a migration has been implicitly applied - that the
177        tables or columns it would create exist. This is intended only for use
178        on initial migrations (as it only looks for CreateModel and AddField).
179        """
180
181        if migration.initial is None:
182            # Bail if the migration isn't the first one in its app
183            if any(
184                app == migration.package_label for app, name in migration.dependencies
185            ):
186                return False, project_state
187        elif migration.initial is False:
188            # Bail if it's NOT an initial migration
189            return False, project_state
190
191        if project_state is None:
192            after_state = self.loader.project_state(
193                (migration.package_label, migration.name), at_end=True
194            )
195        else:
196            after_state = migration.mutate_state(project_state)
197        models_registry = after_state.models_registry
198        found_create_model_migration = False
199        found_add_field_migration = False
200        fold_identifier_case = self.connection.features.ignores_table_name_case
201        with self.connection.cursor() as cursor:
202            existing_table_names = set(
203                self.connection.introspection.table_names(cursor)
204            )
205            if fold_identifier_case:
206                existing_table_names = {
207                    name.casefold() for name in existing_table_names
208                }
209        # Make sure all create model and add field operations are done
210        for operation in migration.operations:
211            if isinstance(operation, migrations.CreateModel):
212                model = models_registry.get_model(
213                    migration.package_label, operation.name
214                )
215
216                db_table = model._meta.db_table
217                if fold_identifier_case:
218                    db_table = db_table.casefold()
219                if db_table not in existing_table_names:
220                    return False, project_state
221                found_create_model_migration = True
222            elif isinstance(operation, migrations.AddField):
223                model = models_registry.get_model(
224                    migration.package_label, operation.model_name
225                )
226
227                table = model._meta.db_table
228                field = model._meta.get_field(operation.name)
229
230                # Handle implicit many-to-many tables created by AddField.
231                if field.many_to_many:
232                    through_db_table = field.remote_field.through._meta.db_table
233                    if fold_identifier_case:
234                        through_db_table = through_db_table.casefold()
235                    if through_db_table not in existing_table_names:
236                        return False, project_state
237                    else:
238                        found_add_field_migration = True
239                        continue
240                with self.connection.cursor() as cursor:
241                    columns = self.connection.introspection.get_table_description(
242                        cursor, table
243                    )
244                for column in columns:
245                    field_column = field.column
246                    column_name = column.name
247                    if fold_identifier_case:
248                        column_name = column_name.casefold()
249                        field_column = field_column.casefold()
250                    if column_name == field_column:
251                        found_add_field_migration = True
252                        break
253                else:
254                    return False, project_state
255        # If we get this far and we found at least one CreateModel or AddField
256        # migration, the migration is considered implicitly applied.
257        return (found_create_model_migration or found_add_field_migration), after_state