Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import copy
  4from decimal import Decimal
  5from typing import TYPE_CHECKING, Any
  6
  7from plain.models import Options
  8from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
  9from plain.models.backends.ddl_references import Statement
 10from plain.models.backends.utils import strip_quotes
 11from plain.models.constraints import UniqueConstraint
 12from plain.models.db import NotSupportedError
 13from plain.models.registry import ModelsRegistry
 14from plain.models.transaction import atomic
 15
 16if TYPE_CHECKING:
 17    from plain.models.base import Model
 18    from plain.models.constraints import BaseConstraint
 19    from plain.models.fields import Field
 20    from plain.models.fields.related import ManyToManyField
 21    from plain.models.fields.reverse_related import ManyToManyRel
 22
 23
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
    """
    SQLite flavour of the schema editor.

    Overrides the SQL templates of the base editor and performs the changes
    that are not issued as a plain statement here by rebuilding the table
    (see `_remake_table`).
    """

    sql_delete_table = "DROP TABLE %(table)s"
    # No standalone "add foreign key" statement is defined for this backend;
    # foreign keys are only emitted through the inline templates below.
    sql_create_fk = None
    # Inline foreign keys are declared deferrable and initially deferred.
    sql_create_inline_fk = (
        "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
    )
    # Column-level inline foreign keys reuse the exact same template.
    sql_create_column_inline_fk = sql_create_inline_fk
    sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
    # Unique constraints are created and dropped as unique indexes.
    sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)"
    sql_delete_unique = "DROP INDEX %(name)s"
 34
 35    def __enter__(self) -> DatabaseSchemaEditor:
 36        # Some SQLite schema alterations need foreign key constraints to be
 37        # disabled. Enforce it here for the duration of the schema edition.
 38        if not self.connection.disable_constraint_checking():
 39            raise NotSupportedError(
 40                "SQLite schema editor cannot be used while foreign key "
 41                "constraint checks are enabled. Make sure to disable them "
 42                "before entering a transaction.atomic() context because "
 43                "SQLite does not support disabling them in the middle of "
 44                "a multi-statement transaction."
 45            )
 46        return super().__enter__()
 47
 48    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
 49        self.connection.check_constraints()
 50        super().__exit__(exc_type, exc_value, traceback)
 51        self.connection.enable_constraint_checking()
 52
 53    def quote_value(self, value: Any) -> str:
 54        # The backend "mostly works" without this function and there are use
 55        # cases for compiling Python without the sqlite3 libraries (e.g.
 56        # security hardening).
 57        try:
 58            import sqlite3
 59
 60            value = sqlite3.adapt(value)  # type: ignore[call-overload]
 61        except ImportError:
 62            pass
 63        except sqlite3.ProgrammingError:
 64            pass
 65        # Manual emulation of SQLite parameter quoting
 66        if isinstance(value, bool):
 67            return str(int(value))
 68        elif isinstance(value, Decimal | float | int):
 69            return str(value)
 70        elif isinstance(value, str):
 71            return "'{}'".format(value.replace("'", "''"))
 72        elif value is None:
 73            return "NULL"
 74        elif isinstance(value, bytes | bytearray | memoryview):
 75            # Bytes are only allowed for BLOB fields, encoded as string
 76            # literals containing hexadecimal data and preceded by a single "X"
 77            # character.
 78            return f"X'{value.hex()}'"
 79        else:
 80            raise ValueError(
 81                f"Cannot quote parameter value {value!r} of type {type(value)}"
 82            )
 83
 84    def prepare_default(self, value: Any) -> str:
 85        return self.quote_value(value)
 86
 87    def _is_referenced_by_fk_constraint(
 88        self, table_name: str, column_name: str | None = None, ignore_self: bool = False
 89    ) -> bool:
 90        """
 91        Return whether or not the provided table name is referenced by another
 92        one. If `column_name` is specified, only references pointing to that
 93        column are considered. If `ignore_self` is True, self-referential
 94        constraints are ignored.
 95        """
 96        with self.connection.cursor() as cursor:
 97            for other_table in self.connection.introspection.get_table_list(cursor):
 98                if ignore_self and other_table.name == table_name:
 99                    continue
100                relations = self.connection.introspection.get_relations(
101                    cursor, other_table.name
102                )
103                for constraint_column, constraint_table in relations.values():
104                    if constraint_table == table_name and (
105                        column_name is None or constraint_column == column_name
106                    ):
107                        return True
108        return False
109
110    def alter_db_table(
111        self,
112        model: type[Model],
113        old_db_table: str,
114        new_db_table: str,
115        disable_constraints: bool = True,
116    ) -> None:
117        if (
118            not self.connection.features.supports_atomic_references_rename
119            and disable_constraints
120            and self._is_referenced_by_fk_constraint(old_db_table)
121        ):
122            if self.connection.in_atomic_block:
123                raise NotSupportedError(
124                    f"Renaming the {old_db_table!r} table while in a transaction is not "
125                    "supported on SQLite < 3.26 because it would break referential "
126                    "integrity. Try adding `atomic = False` to the Migration class."
127                )
128            self.connection.enable_constraint_checking()
129            super().alter_db_table(model, old_db_table, new_db_table)
130            self.connection.disable_constraint_checking()
131        else:
132            super().alter_db_table(model, old_db_table, new_db_table)
133
134    def alter_field(
135        self,
136        model: type[Model],
137        old_field: Field,
138        new_field: Field,
139        strict: bool = False,
140    ) -> None:
141        if not self._field_should_be_altered(old_field, new_field):
142            return
143        old_field_name = old_field.name
144        table_name = model.model_options.db_table
145        _, old_column_name = old_field.get_attname_column()
146        if (
147            new_field.name != old_field_name
148            and not self.connection.features.supports_atomic_references_rename
149            and self._is_referenced_by_fk_constraint(
150                table_name, old_column_name, ignore_self=True
151            )
152        ):
153            if self.connection.in_atomic_block:
154                raise NotSupportedError(
155                    f"Renaming the {model.model_options.db_table!r}.{old_field_name!r} column while in a transaction is not "
156                    "supported on SQLite < 3.26 because it would break referential "
157                    "integrity. Try adding `atomic = False` to the Migration class."
158                )
159            with atomic():
160                super().alter_field(model, old_field, new_field, strict=strict)
161                # Follow SQLite's documented procedure for performing changes
162                # that don't affect the on-disk content.
163                # https://sqlite.org/lang_altertable.html#otheralter
164                with self.connection.cursor() as cursor:
165                    schema_version = cursor.execute("PRAGMA schema_version").fetchone()[
166                        0
167                    ]
168                    cursor.execute("PRAGMA writable_schema = 1")
169                    references_template = f' REFERENCES "{table_name}" ("%s") '
170                    new_column_name = new_field.get_attname_column()[1]
171                    search = references_template % old_column_name
172                    replacement = references_template % new_column_name
173                    cursor.execute(
174                        "UPDATE sqlite_master SET sql = replace(sql, %s, %s)",
175                        (search, replacement),
176                    )
177                    cursor.execute("PRAGMA schema_version = %d" % (schema_version + 1))  # noqa: UP031
178                    cursor.execute("PRAGMA writable_schema = 0")
179                    # The integrity check will raise an exception and rollback
180                    # the transaction if the sqlite_master updates corrupt the
181                    # database.
182                    cursor.execute("PRAGMA integrity_check")
183            # Perform a VACUUM to refresh the database representation from
184            # the sqlite_master table.
185            with self.connection.cursor() as cursor:
186                cursor.execute("VACUUM")
187        else:
188            super().alter_field(model, old_field, new_field, strict=strict)
189
    def _remake_table(
        self,
        model: type[Model],
        create_field: Field | None = None,
        delete_field: Field | None = None,
        alter_fields: list[tuple[Field, Field]] | None = None,
    ) -> None:
        """
        Rebuild the table of `model`, optionally adding `create_field`,
        dropping `delete_field` and replacing each `(old_field, new_field)`
        pair listed in `alter_fields`.

        This follows the correct procedure to perform non-rename or column
        addition operations based on SQLite's documentation

        https://www.sqlite.org/lang_altertable.html#caution

        The essential steps are:
          1. Create a table with the updated definition called "new__app_model"
          2. Copy the data from the existing "app_model" table to the new table
          3. Drop the "app_model" table
          4. Rename the "new__app_model" table to "app_model"
          5. Restore any index of the previous "app_model" table.

        Any deferred SQL collected by the editor is executed and cleared at
        the end of the rebuild.
        """

        # Self-referential fields must be recreated rather than copied from
        # the old model to ensure their remote_field.field_name doesn't refer
        # to an altered field.
        def is_self_referential(f: Field) -> bool:
            return f.is_relation and f.remote_field.model is model  # type: ignore[attr-defined]

        # Work out the new fields dict / mapping
        # `body` maps field name -> field object for the rebuilt model.
        body = {
            f.name: f.clone() if is_self_referential(f) else f
            for f in model._model_meta.local_concrete_fields
        }
        # `mapping` maps new column name -> SQL expression selected from the
        # old table when copying rows.
        # Since mapping might mix column names and default values,
        # its values must be already quoted.
        mapping = {
            f.column: self.quote_name(f.column)
            for f in model._model_meta.local_concrete_fields
        }
        # If any of the new or altered fields is introducing a new PK,
        # remove the old one
        restore_pk_field = None
        alter_fields = alter_fields or []
        if getattr(create_field, "primary_key", False) or any(
            getattr(new_field, "primary_key", False) for _, new_field in alter_fields
        ):
            for name, field in list(body.items()):
                if field.primary_key and not any(
                    # Do not remove the old primary key when an altered field
                    # that introduces a primary key is the same field.
                    name == new_field.name
                    for _, new_field in alter_fields
                ):
                    # The flag is flipped on the field object itself and only
                    # put back at the very end of this method.
                    # NOTE(review): it is not restored if an exception is
                    # raised in between.
                    field.primary_key = False
                    restore_pk_field = field
                    if field.auto_created:
                        del body[name]
                        del mapping[field.column]
        # Add in any created fields
        if create_field:
            body[create_field.name] = create_field
            # Choose a default and insert it into the copy map
            if not create_field.many_to_many and create_field.concrete:
                mapping[create_field.column] = self.prepare_default(
                    self.effective_default(create_field),
                )
        # Add in any altered fields
        for alter_field in alter_fields:
            old_field, new_field = alter_field
            body.pop(old_field.name, None)
            mapping.pop(old_field.column, None)
            body[new_field.name] = new_field
            if old_field.allow_null and not new_field.allow_null:
                # Going from nullable to non-nullable: replace existing NULLs
                # with the new field's effective default while copying.
                case_sql = f"coalesce({self.quote_name(old_field.column)}, {self.prepare_default(self.effective_default(new_field))})"
                mapping[new_field.column] = case_sql
            else:
                mapping[new_field.column] = self.quote_name(old_field.column)
        # Remove any deleted fields
        if delete_field:
            del body[delete_field.name]
            del mapping[delete_field.column]
        # Work inside a new models registry
        models_registry = ModelsRegistry()

        # Indexes that mention the deleted field are not carried over.
        indexes = model.model_options.indexes
        if delete_field:
            indexes = [
                index for index in indexes if delete_field.name not in index.fields
            ]

        constraints = list(model.model_options.constraints)

        # Provide isolated instances of the fields to the new model body so
        # that the existing model's internals aren't interfered with when
        # the dummy model is constructed.
        body_copy = copy.deepcopy(body)

        # Construct a new model with the new fields to allow self referential
        # primary key to resolve to. This model won't ever be materialized as a
        # table and solely exists for foreign key reference resolution purposes.
        # This wouldn't be required if the schema editor was operating on model
        # states instead of rendered models.
        meta_options = Options(
            package_label=model.model_options.package_label,
            db_table=model.model_options.db_table,
            indexes=indexes,
            constraints=constraints,
        )
        body_copy["model_options"] = meta_options
        body_copy["__module__"] = model.__module__
        temp_model: type[Model] = type(  # type: ignore[assignment]
            model.model_options.object_name, model.__bases__, body_copy
        )
        models_registry.register_model(model.model_options.package_label, temp_model)

        # Construct a model with a renamed table name.
        body_copy = copy.deepcopy(body)
        meta_options = Options(
            package_label=model.model_options.package_label,
            db_table=f"new__{strip_quotes(model.model_options.db_table)}",
            indexes=indexes,
            constraints=constraints,
        )
        body_copy["model_options"] = meta_options
        body_copy["__module__"] = model.__module__
        new_model: type[Model] = type(  # type: ignore[assignment]
            f"New{model.model_options.object_name}", model.__bases__, body_copy
        )
        models_registry.register_model(model.model_options.package_label, new_model)

        # Create a new table with the updated schema.
        self.create_model(new_model)

        # Copy data from the old table into the new table
        # (`mapping` keys are the target columns, its values the source
        # expressions, in matching order).
        self.execute(
            "INSERT INTO {} ({}) SELECT {} FROM {}".format(
                self.quote_name(new_model.model_options.db_table),
                ", ".join(self.quote_name(x) for x in mapping),
                ", ".join(mapping.values()),
                self.quote_name(model.model_options.db_table),
            )
        )

        # Delete the old table to make way for the new
        self.delete_model(model, handle_autom2m=False)

        # Rename the new table to take the place of the old one
        self.alter_db_table(
            new_model,  # type: ignore[arg-type]
            new_model.model_options.db_table,
            model.model_options.db_table,
            disable_constraints=False,
        )

        # Run deferred SQL on correct table
        for sql in self.deferred_sql:
            self.execute(sql)
        self.deferred_sql = []
        # Fix any PK-removed field
        if restore_pk_field:
            restore_pk_field.primary_key = True
352
353    def delete_model(self, model: type[Model], handle_autom2m: bool = True) -> None:
354        if handle_autom2m:
355            super().delete_model(model)
356        else:
357            # Delete the table (and only that)
358            self.execute(
359                self.sql_delete_table
360                % {
361                    "table": self.quote_name(model.model_options.db_table),
362                }
363            )
364            # Remove all deferred statements referencing the deleted table.
365            for sql in list(self.deferred_sql):
366                if isinstance(sql, Statement) and sql.references_table(
367                    model.model_options.db_table
368                ):
369                    self.deferred_sql.remove(sql)
370
371    def add_field(self, model: type[Model], field: Field) -> None:
372        """Create a field on a model."""
373        if (
374            # Primary keys are not supported in ALTER TABLE
375            # ADD COLUMN.
376            field.primary_key
377            or
378            # Fields with default values cannot by handled by ALTER TABLE ADD
379            # COLUMN statement because DROP DEFAULT is not supported in
380            # ALTER TABLE.
381            not field.allow_null
382            or self.effective_default(field) is not None
383        ):
384            self._remake_table(model, create_field=field)
385        else:
386            super().add_field(model, field)
387
388    def remove_field(self, model: type[Model], field: Field) -> None:
389        """
390        Remove a field from a model. Usually involves deleting a column,
391        but for M2Ms may involve deleting a table.
392        """
393        # M2M fields are a special case
394        if field.many_to_many:
395            # For explicit "through" M2M fields, do nothing
396            pass
397        elif (
398            self.connection.features.can_alter_table_drop_column
399            # Primary keys, unique fields, indexed fields, and foreign keys are
400            # not supported in ALTER TABLE DROP COLUMN.
401            and not field.primary_key
402            and not (field.remote_field and field.db_index)  # type: ignore[attr-defined]
403            and not (field.remote_field and field.db_constraint)  # type: ignore[attr-defined]
404        ):
405            super().remove_field(model, field)
406        # For everything else, remake.
407        else:
408            # It might not actually have a column behind it
409            if field.db_parameters(connection=self.connection)["type"] is None:
410                return
411            self._remake_table(model, delete_field=field)
412
413    def _alter_field(
414        self,
415        model: type[Model],
416        old_field: Field,
417        new_field: Field,
418        old_type: str,
419        new_type: str,
420        old_db_params: dict[str, Any],
421        new_db_params: dict[str, Any],
422        strict: bool = False,
423    ) -> None:
424        """Perform a "physical" (non-ManyToMany) field update."""
425        # Use "ALTER TABLE ... RENAME COLUMN" if only the column name
426        # changed and there aren't any constraints.
427        if (
428            self.connection.features.can_alter_table_rename_column
429            and old_field.column != new_field.column
430            and self.column_sql(model, old_field) == self.column_sql(model, new_field)
431            and not (
432                old_field.remote_field
433                and old_field.db_constraint  # type: ignore[attr-defined]
434                or new_field.remote_field
435                and new_field.db_constraint  # type: ignore[attr-defined]
436            )
437        ):
438            return self.execute(
439                self._rename_field_sql(
440                    model.model_options.db_table, old_field, new_field, new_type
441                )
442            )
443        # Alter by remaking table
444        self._remake_table(model, alter_fields=[(old_field, new_field)])
445        # Rebuild tables with FKs pointing to this field.
446        old_collation = old_db_params.get("collation")
447        new_collation = new_db_params.get("collation")
448        if new_field.primary_key and (
449            old_type != new_type or old_collation != new_collation
450        ):
451            related_models = set()
452            meta = new_field.model._model_meta
453            for remote_field in meta.related_objects:
454                # Ignore self-relationship since the table was already rebuilt.
455                if remote_field.related_model == model:
456                    continue
457                if not remote_field.many_to_many:
458                    if remote_field.field_name == new_field.name:
459                        related_models.add(remote_field.related_model)
460            if new_field.primary_key:
461                for many_to_many in meta.many_to_many:
462                    # Ignore self-relationship since the table was already rebuilt.
463                    if many_to_many.related_model == model:
464                        continue
465            for related_model in related_models:
466                self._remake_table(related_model)
467
468    def _alter_many_to_many(
469        self,
470        model: type[Model],
471        old_field: ManyToManyField,
472        new_field: ManyToManyField,
473        strict: bool,
474    ) -> None:
475        """Alter M2Ms to repoint their to= endpoints."""
476        # Type narrow for ManyToManyField.remote_field
477        old_rel: ManyToManyRel = old_field.remote_field  # type: ignore[assignment]
478        new_rel: ManyToManyRel = new_field.remote_field  # type: ignore[assignment]
479
480        if (
481            old_rel.through.model_options.db_table
482            == new_rel.through.model_options.db_table
483        ):
484            # The field name didn't change, but some options did, so we have to
485            # propagate this altering.
486            self._remake_table(
487                old_rel.through,
488                alter_fields=[
489                    (
490                        # The field that points to the target model is needed,
491                        # so that table can be remade with the new m2m field -
492                        # this is m2m_reverse_field_name().
493                        old_rel.through._model_meta.get_field(
494                            old_field.m2m_reverse_field_name()  # type: ignore[attr-defined]
495                        ),
496                        new_rel.through._model_meta.get_field(
497                            new_field.m2m_reverse_field_name()  # type: ignore[attr-defined]
498                        ),
499                    ),
500                    (
501                        # The field that points to the model itself is needed,
502                        # so that table can be remade with the new self field -
503                        # this is m2m_field_name().
504                        old_rel.through._model_meta.get_field(
505                            old_field.m2m_field_name()  # type: ignore[attr-defined]
506                        ),
507                        new_rel.through._model_meta.get_field(
508                            new_field.m2m_field_name()  # type: ignore[attr-defined]
509                        ),
510                    ),
511                ],
512            )
513            return
514
515        # Make a new through table
516        self.create_model(new_rel.through)
517        # Copy the data across
518        self.execute(
519            "INSERT INTO {} ({}) SELECT {} FROM {}".format(
520                self.quote_name(new_rel.through.model_options.db_table),
521                ", ".join(
522                    [
523                        "id",
524                        new_field.m2m_column_name(),  # type: ignore[attr-defined]
525                        new_field.m2m_reverse_name(),  # type: ignore[attr-defined]
526                    ]
527                ),
528                ", ".join(
529                    [
530                        "id",
531                        old_field.m2m_column_name(),  # type: ignore[attr-defined]
532                        old_field.m2m_reverse_name(),  # type: ignore[attr-defined]
533                    ]
534                ),
535                self.quote_name(old_rel.through.model_options.db_table),
536            )
537        )
538        # Delete the old through table
539        self.delete_model(old_rel.through)
540
541    def add_constraint(self, model: type[Model], constraint: BaseConstraint) -> None:
542        if isinstance(constraint, UniqueConstraint) and (
543            constraint.condition
544            or constraint.contains_expressions
545            or constraint.include
546            or constraint.deferrable
547        ):
548            super().add_constraint(model, constraint)
549        else:
550            self._remake_table(model)
551
552    def remove_constraint(self, model: type[Model], constraint: BaseConstraint) -> None:
553        if isinstance(constraint, UniqueConstraint) and (
554            constraint.condition
555            or constraint.contains_expressions
556            or constraint.include
557            or constraint.deferrable
558        ):
559            super().remove_constraint(model, constraint)
560        else:
561            self._remake_table(model)
562
563    def _collate_sql(self, collation: str) -> str:
564        return "COLLATE " + collation