Plain is headed towards 1.0! Subscribe for development updates →

  1import copy
  2from decimal import Decimal
  3
  4from plain.models import register_model
  5from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
  6from plain.models.backends.ddl_references import Statement
  7from plain.models.backends.utils import strip_quotes
  8from plain.models.constraints import UniqueConstraint
  9from plain.models.db import NotSupportedError
 10from plain.models.registry import ModelsRegistry
 11from plain.models.transaction import atomic
 12
 13
 14class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
 15    sql_delete_table = "DROP TABLE %(table)s"
 16    sql_create_fk = None
 17    sql_create_inline_fk = (
 18        "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
 19    )
 20    sql_create_column_inline_fk = sql_create_inline_fk
 21    sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
 22    sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)"
 23    sql_delete_unique = "DROP INDEX %(name)s"
 24
 25    def __enter__(self):
 26        # Some SQLite schema alterations need foreign key constraints to be
 27        # disabled. Enforce it here for the duration of the schema edition.
 28        if not self.connection.disable_constraint_checking():
 29            raise NotSupportedError(
 30                "SQLite schema editor cannot be used while foreign key "
 31                "constraint checks are enabled. Make sure to disable them "
 32                "before entering a transaction.atomic() context because "
 33                "SQLite does not support disabling them in the middle of "
 34                "a multi-statement transaction."
 35            )
 36        return super().__enter__()
 37
 38    def __exit__(self, exc_type, exc_value, traceback):
 39        self.connection.check_constraints()
 40        super().__exit__(exc_type, exc_value, traceback)
 41        self.connection.enable_constraint_checking()
 42
 43    def quote_value(self, value):
 44        # The backend "mostly works" without this function and there are use
 45        # cases for compiling Python without the sqlite3 libraries (e.g.
 46        # security hardening).
 47        try:
 48            import sqlite3
 49
 50            value = sqlite3.adapt(value)
 51        except ImportError:
 52            pass
 53        except sqlite3.ProgrammingError:
 54            pass
 55        # Manual emulation of SQLite parameter quoting
 56        if isinstance(value, bool):
 57            return str(int(value))
 58        elif isinstance(value, Decimal | float | int):
 59            return str(value)
 60        elif isinstance(value, str):
 61            return "'{}'".format(value.replace("'", "''"))
 62        elif value is None:
 63            return "NULL"
 64        elif isinstance(value, bytes | bytearray | memoryview):
 65            # Bytes are only allowed for BLOB fields, encoded as string
 66            # literals containing hexadecimal data and preceded by a single "X"
 67            # character.
 68            return f"X'{value.hex()}'"
 69        else:
 70            raise ValueError(
 71                f"Cannot quote parameter value {value!r} of type {type(value)}"
 72            )
 73
 74    def prepare_default(self, value):
 75        return self.quote_value(value)
 76
 77    def _is_referenced_by_fk_constraint(
 78        self, table_name, column_name=None, ignore_self=False
 79    ):
 80        """
 81        Return whether or not the provided table name is referenced by another
 82        one. If `column_name` is specified, only references pointing to that
 83        column are considered. If `ignore_self` is True, self-referential
 84        constraints are ignored.
 85        """
 86        with self.connection.cursor() as cursor:
 87            for other_table in self.connection.introspection.get_table_list(cursor):
 88                if ignore_self and other_table.name == table_name:
 89                    continue
 90                relations = self.connection.introspection.get_relations(
 91                    cursor, other_table.name
 92                )
 93                for constraint_column, constraint_table in relations.values():
 94                    if constraint_table == table_name and (
 95                        column_name is None or constraint_column == column_name
 96                    ):
 97                        return True
 98        return False
 99
100    def alter_db_table(
101        self, model, old_db_table, new_db_table, disable_constraints=True
102    ):
103        if (
104            not self.connection.features.supports_atomic_references_rename
105            and disable_constraints
106            and self._is_referenced_by_fk_constraint(old_db_table)
107        ):
108            if self.connection.in_atomic_block:
109                raise NotSupportedError(
110                    f"Renaming the {old_db_table!r} table while in a transaction is not "
111                    "supported on SQLite < 3.26 because it would break referential "
112                    "integrity. Try adding `atomic = False` to the Migration class."
113                )
114            self.connection.enable_constraint_checking()
115            super().alter_db_table(model, old_db_table, new_db_table)
116            self.connection.disable_constraint_checking()
117        else:
118            super().alter_db_table(model, old_db_table, new_db_table)
119
120    def alter_field(self, model, old_field, new_field, strict=False):
121        if not self._field_should_be_altered(old_field, new_field):
122            return
123        old_field_name = old_field.name
124        table_name = model._meta.db_table
125        _, old_column_name = old_field.get_attname_column()
126        if (
127            new_field.name != old_field_name
128            and not self.connection.features.supports_atomic_references_rename
129            and self._is_referenced_by_fk_constraint(
130                table_name, old_column_name, ignore_self=True
131            )
132        ):
133            if self.connection.in_atomic_block:
134                raise NotSupportedError(
135                    f"Renaming the {model._meta.db_table!r}.{old_field_name!r} column while in a transaction is not "
136                    "supported on SQLite < 3.26 because it would break referential "
137                    "integrity. Try adding `atomic = False` to the Migration class."
138                )
139            with atomic():
140                super().alter_field(model, old_field, new_field, strict=strict)
141                # Follow SQLite's documented procedure for performing changes
142                # that don't affect the on-disk content.
143                # https://sqlite.org/lang_altertable.html#otheralter
144                with self.connection.cursor() as cursor:
145                    schema_version = cursor.execute("PRAGMA schema_version").fetchone()[
146                        0
147                    ]
148                    cursor.execute("PRAGMA writable_schema = 1")
149                    references_template = f' REFERENCES "{table_name}" ("%s") '
150                    new_column_name = new_field.get_attname_column()[1]
151                    search = references_template % old_column_name
152                    replacement = references_template % new_column_name
153                    cursor.execute(
154                        "UPDATE sqlite_master SET sql = replace(sql, %s, %s)",
155                        (search, replacement),
156                    )
157                    cursor.execute("PRAGMA schema_version = %d" % (schema_version + 1))  # noqa: UP031
158                    cursor.execute("PRAGMA writable_schema = 0")
159                    # The integrity check will raise an exception and rollback
160                    # the transaction if the sqlite_master updates corrupt the
161                    # database.
162                    cursor.execute("PRAGMA integrity_check")
163            # Perform a VACUUM to refresh the database representation from
164            # the sqlite_master table.
165            with self.connection.cursor() as cursor:
166                cursor.execute("VACUUM")
167        else:
168            super().alter_field(model, old_field, new_field, strict=strict)
169
170    def _remake_table(
171        self, model, create_field=None, delete_field=None, alter_fields=None
172    ):
173        """
174        Shortcut to transform a model from old_model into new_model
175
176        This follows the correct procedure to perform non-rename or column
177        addition operations based on SQLite's documentation
178
179        https://www.sqlite.org/lang_altertable.html#caution
180
181        The essential steps are:
182          1. Create a table with the updated definition called "new__app_model"
183          2. Copy the data from the existing "app_model" table to the new table
184          3. Drop the "app_model" table
185          4. Rename the "new__app_model" table to "app_model"
186          5. Restore any index of the previous "app_model" table.
187        """
188
189        # Self-referential fields must be recreated rather than copied from
190        # the old model to ensure their remote_field.field_name doesn't refer
191        # to an altered field.
192        def is_self_referential(f):
193            return f.is_relation and f.remote_field.model is model
194
195        # Work out the new fields dict / mapping
196        body = {
197            f.name: f.clone() if is_self_referential(f) else f
198            for f in model._meta.local_concrete_fields
199        }
200        # Since mapping might mix column names and default values,
201        # its values must be already quoted.
202        mapping = {
203            f.column: self.quote_name(f.column)
204            for f in model._meta.local_concrete_fields
205        }
206        # If any of the new or altered fields is introducing a new PK,
207        # remove the old one
208        restore_pk_field = None
209        alter_fields = alter_fields or []
210        if getattr(create_field, "primary_key", False) or any(
211            getattr(new_field, "primary_key", False) for _, new_field in alter_fields
212        ):
213            for name, field in list(body.items()):
214                if field.primary_key and not any(
215                    # Do not remove the old primary key when an altered field
216                    # that introduces a primary key is the same field.
217                    name == new_field.name
218                    for _, new_field in alter_fields
219                ):
220                    field.primary_key = False
221                    restore_pk_field = field
222                    if field.auto_created:
223                        del body[name]
224                        del mapping[field.column]
225        # Add in any created fields
226        if create_field:
227            body[create_field.name] = create_field
228            # Choose a default and insert it into the copy map
229            if not create_field.many_to_many and create_field.concrete:
230                mapping[create_field.column] = self.prepare_default(
231                    self.effective_default(create_field),
232                )
233        # Add in any altered fields
234        for alter_field in alter_fields:
235            old_field, new_field = alter_field
236            body.pop(old_field.name, None)
237            mapping.pop(old_field.column, None)
238            body[new_field.name] = new_field
239            if old_field.allow_null and not new_field.allow_null:
240                case_sql = f"coalesce({self.quote_name(old_field.column)}, {self.prepare_default(self.effective_default(new_field))})"
241                mapping[new_field.column] = case_sql
242            else:
243                mapping[new_field.column] = self.quote_name(old_field.column)
244        # Remove any deleted fields
245        if delete_field:
246            del body[delete_field.name]
247            del mapping[delete_field.column]
248        # Work inside a new app registry
249        models_registry = ModelsRegistry()
250
251        indexes = model._meta.indexes
252        if delete_field:
253            indexes = [
254                index for index in indexes if delete_field.name not in index.fields
255            ]
256
257        constraints = list(model._meta.constraints)
258
259        # Provide isolated instances of the fields to the new model body so
260        # that the existing model's internals aren't interfered with when
261        # the dummy model is constructed.
262        body_copy = copy.deepcopy(body)
263
264        # Construct a new model with the new fields to allow self referential
265        # primary key to resolve to. This model won't ever be materialized as a
266        # table and solely exists for foreign key reference resolution purposes.
267        # This wouldn't be required if the schema editor was operating on model
268        # states instead of rendered models.
269        meta_contents = {
270            "package_label": model._meta.package_label,
271            "db_table": model._meta.db_table,
272            "indexes": indexes,
273            "constraints": constraints,
274            "models_registry": models_registry,
275        }
276        meta = type("Meta", (), meta_contents)
277        body_copy["Meta"] = meta
278        body_copy["__module__"] = model.__module__
279        register_model(type(model._meta.object_name, model.__bases__, body_copy))
280
281        # Construct a model with a renamed table name.
282        body_copy = copy.deepcopy(body)
283        meta_contents = {
284            "package_label": model._meta.package_label,
285            "db_table": f"new__{strip_quotes(model._meta.db_table)}",
286            "indexes": indexes,
287            "constraints": constraints,
288            "models_registry": models_registry,
289        }
290        meta = type("Meta", (), meta_contents)
291        body_copy["Meta"] = meta
292        body_copy["__module__"] = model.__module__
293        new_model = type(f"New{model._meta.object_name}", model.__bases__, body_copy)
294        register_model(new_model)
295
296        # Create a new table with the updated schema.
297        self.create_model(new_model)
298
299        # Copy data from the old table into the new table
300        self.execute(
301            "INSERT INTO {} ({}) SELECT {} FROM {}".format(
302                self.quote_name(new_model._meta.db_table),
303                ", ".join(self.quote_name(x) for x in mapping),
304                ", ".join(mapping.values()),
305                self.quote_name(model._meta.db_table),
306            )
307        )
308
309        # Delete the old table to make way for the new
310        self.delete_model(model, handle_autom2m=False)
311
312        # Rename the new table to take way for the old
313        self.alter_db_table(
314            new_model,
315            new_model._meta.db_table,
316            model._meta.db_table,
317            disable_constraints=False,
318        )
319
320        # Run deferred SQL on correct table
321        for sql in self.deferred_sql:
322            self.execute(sql)
323        self.deferred_sql = []
324        # Fix any PK-removed field
325        if restore_pk_field:
326            restore_pk_field.primary_key = True
327
328    def delete_model(self, model, handle_autom2m=True):
329        if handle_autom2m:
330            super().delete_model(model)
331        else:
332            # Delete the table (and only that)
333            self.execute(
334                self.sql_delete_table
335                % {
336                    "table": self.quote_name(model._meta.db_table),
337                }
338            )
339            # Remove all deferred statements referencing the deleted table.
340            for sql in list(self.deferred_sql):
341                if isinstance(sql, Statement) and sql.references_table(
342                    model._meta.db_table
343                ):
344                    self.deferred_sql.remove(sql)
345
346    def add_field(self, model, field):
347        """Create a field on a model."""
348        if (
349            # Primary keys are not supported in ALTER TABLE
350            # ADD COLUMN.
351            field.primary_key
352            or
353            # Fields with default values cannot by handled by ALTER TABLE ADD
354            # COLUMN statement because DROP DEFAULT is not supported in
355            # ALTER TABLE.
356            not field.allow_null
357            or self.effective_default(field) is not None
358        ):
359            self._remake_table(model, create_field=field)
360        else:
361            super().add_field(model, field)
362
363    def remove_field(self, model, field):
364        """
365        Remove a field from a model. Usually involves deleting a column,
366        but for M2Ms may involve deleting a table.
367        """
368        # M2M fields are a special case
369        if field.many_to_many:
370            # For explicit "through" M2M fields, do nothing
371            pass
372        elif (
373            self.connection.features.can_alter_table_drop_column
374            # Primary keys, unique fields, indexed fields, and foreign keys are
375            # not supported in ALTER TABLE DROP COLUMN.
376            and not field.primary_key
377            and not (field.remote_field and field.db_index)
378            and not (field.remote_field and field.db_constraint)
379        ):
380            super().remove_field(model, field)
381        # For everything else, remake.
382        else:
383            # It might not actually have a column behind it
384            if field.db_parameters(connection=self.connection)["type"] is None:
385                return
386            self._remake_table(model, delete_field=field)
387
388    def _alter_field(
389        self,
390        model,
391        old_field,
392        new_field,
393        old_type,
394        new_type,
395        old_db_params,
396        new_db_params,
397        strict=False,
398    ):
399        """Perform a "physical" (non-ManyToMany) field update."""
400        # Use "ALTER TABLE ... RENAME COLUMN" if only the column name
401        # changed and there aren't any constraints.
402        if (
403            self.connection.features.can_alter_table_rename_column
404            and old_field.column != new_field.column
405            and self.column_sql(model, old_field) == self.column_sql(model, new_field)
406            and not (
407                old_field.remote_field
408                and old_field.db_constraint
409                or new_field.remote_field
410                and new_field.db_constraint
411            )
412        ):
413            return self.execute(
414                self._rename_field_sql(
415                    model._meta.db_table, old_field, new_field, new_type
416                )
417            )
418        # Alter by remaking table
419        self._remake_table(model, alter_fields=[(old_field, new_field)])
420        # Rebuild tables with FKs pointing to this field.
421        old_collation = old_db_params.get("collation")
422        new_collation = new_db_params.get("collation")
423        if new_field.primary_key and (
424            old_type != new_type or old_collation != new_collation
425        ):
426            related_models = set()
427            opts = new_field.model._meta
428            for remote_field in opts.related_objects:
429                # Ignore self-relationship since the table was already rebuilt.
430                if remote_field.related_model == model:
431                    continue
432                if not remote_field.many_to_many:
433                    if remote_field.field_name == new_field.name:
434                        related_models.add(remote_field.related_model)
435            if new_field.primary_key:
436                for many_to_many in opts.many_to_many:
437                    # Ignore self-relationship since the table was already rebuilt.
438                    if many_to_many.related_model == model:
439                        continue
440            for related_model in related_models:
441                self._remake_table(related_model)
442
443    def _alter_many_to_many(self, model, old_field, new_field, strict):
444        """Alter M2Ms to repoint their to= endpoints."""
445        if (
446            old_field.remote_field.through._meta.db_table
447            == new_field.remote_field.through._meta.db_table
448        ):
449            # The field name didn't change, but some options did, so we have to
450            # propagate this altering.
451            self._remake_table(
452                old_field.remote_field.through,
453                alter_fields=[
454                    (
455                        # The field that points to the target model is needed,
456                        # so that table can be remade with the new m2m field -
457                        # this is m2m_reverse_field_name().
458                        old_field.remote_field.through._meta.get_field(
459                            old_field.m2m_reverse_field_name()
460                        ),
461                        new_field.remote_field.through._meta.get_field(
462                            new_field.m2m_reverse_field_name()
463                        ),
464                    ),
465                    (
466                        # The field that points to the model itself is needed,
467                        # so that table can be remade with the new self field -
468                        # this is m2m_field_name().
469                        old_field.remote_field.through._meta.get_field(
470                            old_field.m2m_field_name()
471                        ),
472                        new_field.remote_field.through._meta.get_field(
473                            new_field.m2m_field_name()
474                        ),
475                    ),
476                ],
477            )
478            return
479
480        # Make a new through table
481        self.create_model(new_field.remote_field.through)
482        # Copy the data across
483        self.execute(
484            "INSERT INTO {} ({}) SELECT {} FROM {}".format(
485                self.quote_name(new_field.remote_field.through._meta.db_table),
486                ", ".join(
487                    [
488                        "id",
489                        new_field.m2m_column_name(),
490                        new_field.m2m_reverse_name(),
491                    ]
492                ),
493                ", ".join(
494                    [
495                        "id",
496                        old_field.m2m_column_name(),
497                        old_field.m2m_reverse_name(),
498                    ]
499                ),
500                self.quote_name(old_field.remote_field.through._meta.db_table),
501            )
502        )
503        # Delete the old through table
504        self.delete_model(old_field.remote_field.through)
505
506    def add_constraint(self, model, constraint):
507        if isinstance(constraint, UniqueConstraint) and (
508            constraint.condition
509            or constraint.contains_expressions
510            or constraint.include
511            or constraint.deferrable
512        ):
513            super().add_constraint(model, constraint)
514        else:
515            self._remake_table(model)
516
517    def remove_constraint(self, model, constraint):
518        if isinstance(constraint, UniqueConstraint) and (
519            constraint.condition
520            or constraint.contains_expressions
521            or constraint.include
522            or constraint.deferrable
523        ):
524            super().remove_constraint(model, constraint)
525        else:
526            self._remake_table(model)
527
528    def _collate_sql(self, collation):
529        return "COLLATE " + collation