1from __future__ import annotations
  2
  3import copy
  4from decimal import Decimal
  5from typing import TYPE_CHECKING, Any, cast
  6
  7from plain.models import Options
  8from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
  9from plain.models.backends.ddl_references import Statement
 10from plain.models.backends.utils import strip_quotes
 11from plain.models.constraints import UniqueConstraint
 12from plain.models.db import NotSupportedError
 13from plain.models.fields import DbParameters
 14from plain.models.fields.related import ForeignKeyField, RelatedField
 15from plain.models.registry import ModelsRegistry
 16from plain.models.transaction import atomic
 17
 18if TYPE_CHECKING:
 19    from plain.models.backends.sqlite3.base import SQLiteDatabaseWrapper
 20    from plain.models.base import Model
 21    from plain.models.constraints import BaseConstraint
 22    from plain.models.fields import Field
 23    from plain.models.fields.related import ManyToManyField
 24    from plain.models.fields.reverse_related import ManyToManyRel
 25
 26
 27class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
 28    # Type checker hint: connection is always SQLiteDatabaseWrapper in this class
 29    connection: SQLiteDatabaseWrapper
 30
 31    sql_delete_table = "DROP TABLE %(table)s"
 32    sql_create_fk = None
 33    sql_create_inline_fk = (
 34        "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
 35    )
 36    sql_create_column_inline_fk = sql_create_inline_fk
 37    sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
 38    sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)"
 39    sql_delete_unique = "DROP INDEX %(name)s"
 40
 41    def __enter__(self) -> DatabaseSchemaEditor:
 42        # Some SQLite schema alterations need foreign key constraints to be
 43        # disabled. Enforce it here for the duration of the schema edition.
 44        if not self.connection.disable_constraint_checking():
 45            raise NotSupportedError(
 46                "SQLite schema editor cannot be used while foreign key "
 47                "constraint checks are enabled. Make sure to disable them "
 48                "before entering a transaction.atomic() context because "
 49                "SQLite does not support disabling them in the middle of "
 50                "a multi-statement transaction."
 51            )
 52        return super().__enter__()
 53
 54    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
 55        self.connection.check_constraints()
 56        super().__exit__(exc_type, exc_value, traceback)
 57        self.connection.enable_constraint_checking()
 58
 59    def quote_value(self, value: Any) -> str:
 60        # The backend "mostly works" without this function and there are use
 61        # cases for compiling Python without the sqlite3 libraries (e.g.
 62        # security hardening).
 63        try:
 64            import sqlite3
 65
 66            value = sqlite3.adapt(value)  # type: ignore[call-overload]
 67        except ImportError:
 68            pass
 69        except sqlite3.ProgrammingError:
 70            pass
 71        # Manual emulation of SQLite parameter quoting
 72        if isinstance(value, bool):
 73            return str(int(value))
 74        elif isinstance(value, Decimal | float | int):
 75            return str(value)
 76        elif isinstance(value, str):
 77            return "'{}'".format(value.replace("'", "''"))
 78        elif value is None:
 79            return "NULL"
 80        elif isinstance(value, bytes | bytearray | memoryview):
 81            # Bytes are only allowed for BLOB fields, encoded as string
 82            # literals containing hexadecimal data and preceded by a single "X"
 83            # character.
 84            return f"X'{value.hex()}'"
 85        else:
 86            raise ValueError(
 87                f"Cannot quote parameter value {value!r} of type {type(value)}"
 88            )
 89
 90    def prepare_default(self, value: Any) -> str:
 91        return self.quote_value(value)
 92
 93    def _is_referenced_by_fk_constraint(
 94        self, table_name: str, column_name: str | None = None, ignore_self: bool = False
 95    ) -> bool:
 96        """
 97        Return whether or not the provided table name is referenced by another
 98        one. If `column_name` is specified, only references pointing to that
 99        column are considered. If `ignore_self` is True, self-referential
100        constraints are ignored.
101        """
102        with self.connection.cursor() as cursor:
103            for other_table in self.connection.introspection.get_table_list(cursor):
104                if ignore_self and other_table.name == table_name:
105                    continue
106                relations = self.connection.introspection.get_relations(
107                    cursor, other_table.name
108                )
109                for constraint_column, constraint_table in relations.values():
110                    if constraint_table == table_name and (
111                        column_name is None or constraint_column == column_name
112                    ):
113                        return True
114        return False
115
116    def alter_db_table(
117        self,
118        model: type[Model],
119        old_db_table: str,
120        new_db_table: str,
121        disable_constraints: bool = True,
122    ) -> None:
123        if (
124            not self.connection.features.supports_atomic_references_rename
125            and disable_constraints
126            and self._is_referenced_by_fk_constraint(old_db_table)
127        ):
128            if self.connection.in_atomic_block:
129                raise NotSupportedError(
130                    f"Renaming the {old_db_table!r} table while in a transaction is not "
131                    "supported on SQLite < 3.26 because it would break referential "
132                    "integrity. Try adding `atomic = False` to the Migration class."
133                )
134            self.connection.enable_constraint_checking()
135            super().alter_db_table(model, old_db_table, new_db_table)
136            self.connection.disable_constraint_checking()
137        else:
138            super().alter_db_table(model, old_db_table, new_db_table)
139
140    def alter_field(
141        self,
142        model: type[Model],
143        old_field: Field,
144        new_field: Field,
145        strict: bool = False,
146    ) -> None:
147        if not self._field_should_be_altered(old_field, new_field):
148            return
149        old_field_name = old_field.name
150        table_name = model.model_options.db_table
151        _, old_column_name = old_field.get_attname_column()
152        if (
153            new_field.name != old_field_name
154            and not self.connection.features.supports_atomic_references_rename
155            and self._is_referenced_by_fk_constraint(
156                table_name, old_column_name, ignore_self=True
157            )
158        ):
159            if self.connection.in_atomic_block:
160                raise NotSupportedError(
161                    f"Renaming the {model.model_options.db_table!r}.{old_field_name!r} column while in a transaction is not "
162                    "supported on SQLite < 3.26 because it would break referential "
163                    "integrity. Try adding `atomic = False` to the Migration class."
164                )
165            with atomic():
166                super().alter_field(model, old_field, new_field, strict=strict)
167                # Follow SQLite's documented procedure for performing changes
168                # that don't affect the on-disk content.
169                # https://sqlite.org/lang_altertable.html#otheralter
170                with self.connection.cursor() as cursor:
171                    row = cursor.execute("PRAGMA schema_version").fetchone()
172                    assert row is not None
173                    schema_version = row[0]
174                    cursor.execute("PRAGMA writable_schema = 1")
175                    references_template = f' REFERENCES "{table_name}" ("%s") '
176                    new_column_name = new_field.get_attname_column()[1]
177                    search = references_template % old_column_name
178                    replacement = references_template % new_column_name
179                    cursor.execute(
180                        "UPDATE sqlite_master SET sql = replace(sql, %s, %s)",
181                        (search, replacement),
182                    )
183                    cursor.execute("PRAGMA schema_version = %d" % (schema_version + 1))  # noqa: UP031
184                    cursor.execute("PRAGMA writable_schema = 0")
185                    # The integrity check will raise an exception and rollback
186                    # the transaction if the sqlite_master updates corrupt the
187                    # database.
188                    cursor.execute("PRAGMA integrity_check")
189            # Perform a VACUUM to refresh the database representation from
190            # the sqlite_master table.
191            with self.connection.cursor() as cursor:
192                cursor.execute("VACUUM")
193        else:
194            super().alter_field(model, old_field, new_field, strict=strict)
195
196    def _remake_table(
197        self,
198        model: type[Model],
199        create_field: Field | None = None,
200        delete_field: Field | None = None,
201        alter_fields: list[tuple[Field, Field]] | None = None,
202    ) -> None:
203        """
204        Shortcut to transform a model from old_model into new_model
205
206        This follows the correct procedure to perform non-rename or column
207        addition operations based on SQLite's documentation
208
209        https://www.sqlite.org/lang_altertable.html#caution
210
211        The essential steps are:
212          1. Create a table with the updated definition called "new__app_model"
213          2. Copy the data from the existing "app_model" table to the new table
214          3. Drop the "app_model" table
215          4. Rename the "new__app_model" table to "app_model"
216          5. Restore any index of the previous "app_model" table.
217        """
218
219        # Self-referential fields must be recreated rather than copied from
220        # the old model to ensure their remote_field.field_name doesn't refer
221        # to an altered field.
222        def is_self_referential(f: Field) -> bool:
223            return isinstance(f, RelatedField) and f.remote_field.model is model
224
225        # Work out the new fields dict / mapping
226        body = {
227            f.name: f.clone() if is_self_referential(f) else f
228            for f in model._model_meta.local_concrete_fields
229        }
230        # Since mapping might mix column names and default values,
231        # its values must be already quoted.
232        mapping = {
233            f.column: self.quote_name(f.column)
234            for f in model._model_meta.local_concrete_fields
235        }
236        # If any of the new or altered fields is introducing a new PK,
237        # remove the old one
238        restore_pk_field = None
239        alter_fields = alter_fields or []
240        if getattr(create_field, "primary_key", False) or any(
241            getattr(new_field, "primary_key", False) for _, new_field in alter_fields
242        ):
243            for name, field in list(body.items()):
244                if field.primary_key and not any(
245                    # Do not remove the old primary key when an altered field
246                    # that introduces a primary key is the same field.
247                    name == new_field.name
248                    for _, new_field in alter_fields
249                ):
250                    field.primary_key = False
251                    restore_pk_field = field
252                    if field.auto_created:
253                        del body[name]
254                        del mapping[field.column]
255        # Add in any created fields
256        if create_field:
257            body[create_field.name] = create_field
258            # Choose a default and insert it into the copy map
259            from plain.models.fields.related import ManyToManyField
260
261            if not isinstance(create_field, ManyToManyField) and create_field.concrete:
262                mapping[create_field.column] = self.prepare_default(
263                    self.effective_default(create_field),
264                )
265        # Add in any altered fields
266        for alter_field in alter_fields:
267            old_field, new_field = alter_field
268            body.pop(old_field.name, None)
269            mapping.pop(old_field.column, None)
270            body[new_field.name] = new_field
271            if old_field.allow_null and not new_field.allow_null:
272                case_sql = f"coalesce({self.quote_name(old_field.column)}, {self.prepare_default(self.effective_default(new_field))})"
273                mapping[new_field.column] = case_sql
274            else:
275                mapping[new_field.column] = self.quote_name(old_field.column)
276        # Remove any deleted fields
277        if delete_field:
278            del body[delete_field.name]
279            del mapping[delete_field.column]
280        # Work inside a new app registry
281        models_registry = ModelsRegistry()
282
283        indexes = model.model_options.indexes
284        if delete_field:
285            indexes = [
286                index for index in indexes if delete_field.name not in index.fields
287            ]
288
289        constraints = list(model.model_options.constraints)
290
291        # Provide isolated instances of the fields to the new model body so
292        # that the existing model's internals aren't interfered with when
293        # the dummy model is constructed.
294        body_copy = copy.deepcopy(body)
295
296        # Construct a new model with the new fields to allow self referential
297        # primary key to resolve to. This model won't ever be materialized as a
298        # table and solely exists for foreign key reference resolution purposes.
299        # This wouldn't be required if the schema editor was operating on model
300        # states instead of rendered models.
301        meta_options = Options(
302            package_label=model.model_options.package_label,
303            db_table=model.model_options.db_table,
304            indexes=indexes,
305            constraints=constraints,
306        )
307        body_copy["model_options"] = meta_options
308        body_copy["__module__"] = model.__module__
309        temp_model = cast(
310            "type[Model]",
311            type(model.model_options.object_name, model.__bases__, body_copy),
312        )
313        models_registry.register_model(model.model_options.package_label, temp_model)
314
315        # Construct a model with a renamed table name.
316        body_copy = copy.deepcopy(body)
317        meta_options = Options(
318            package_label=model.model_options.package_label,
319            db_table=f"new__{strip_quotes(model.model_options.db_table)}",
320            indexes=indexes,
321            constraints=constraints,
322        )
323        body_copy["model_options"] = meta_options
324        body_copy["__module__"] = model.__module__
325        new_model = cast(
326            "type[Model]",
327            type(f"New{model.model_options.object_name}", model.__bases__, body_copy),
328        )
329        models_registry.register_model(model.model_options.package_label, new_model)
330
331        # Create a new table with the updated schema.
332        self.create_model(new_model)
333
334        # Copy data from the old table into the new table
335        self.execute(
336            "INSERT INTO {} ({}) SELECT {} FROM {}".format(
337                self.quote_name(new_model.model_options.db_table),
338                ", ".join(self.quote_name(x) for x in mapping),
339                ", ".join(mapping.values()),
340                self.quote_name(model.model_options.db_table),
341            )
342        )
343
344        # Delete the old table to make way for the new
345        self.delete_model(model, handle_autom2m=False)
346
347        # Rename the new table to take way for the old
348        self.alter_db_table(
349            new_model,
350            new_model.model_options.db_table,
351            model.model_options.db_table,
352            disable_constraints=False,
353        )
354
355        # Run deferred SQL on correct table
356        for sql in self.deferred_sql:
357            self.execute(sql)
358        self.deferred_sql = []
359        # Fix any PK-removed field
360        if restore_pk_field:
361            restore_pk_field.primary_key = True
362
363    def delete_model(self, model: type[Model], handle_autom2m: bool = True) -> None:
364        if handle_autom2m:
365            super().delete_model(model)
366        else:
367            # Delete the table (and only that)
368            self.execute(
369                self.sql_delete_table
370                % {
371                    "table": self.quote_name(model.model_options.db_table),
372                }
373            )
374            # Remove all deferred statements referencing the deleted table.
375            for sql in list(self.deferred_sql):
376                if isinstance(sql, Statement) and sql.references_table(
377                    model.model_options.db_table
378                ):
379                    self.deferred_sql.remove(sql)
380
381    def add_field(self, model: type[Model], field: Field) -> None:
382        """Create a field on a model."""
383        if (
384            # Primary keys are not supported in ALTER TABLE
385            # ADD COLUMN.
386            field.primary_key
387            or
388            # Fields with default values cannot by handled by ALTER TABLE ADD
389            # COLUMN statement because DROP DEFAULT is not supported in
390            # ALTER TABLE.
391            not field.allow_null
392            or self.effective_default(field) is not None
393        ):
394            self._remake_table(model, create_field=field)
395        else:
396            super().add_field(model, field)
397
398    def remove_field(self, model: type[Model], field: Field) -> None:
399        """
400        Remove a field from a model. Usually involves deleting a column,
401        but for M2Ms may involve deleting a table.
402        """
403        from plain.models.fields.related import ManyToManyField
404
405        # M2M fields are a special case
406        if isinstance(field, ManyToManyField):
407            # For explicit "through" M2M fields, do nothing
408            pass
409        elif (
410            self.connection.features.can_alter_table_drop_column
411            # Primary keys, unique fields, indexed fields, and foreign keys are
412            # not supported in ALTER TABLE DROP COLUMN.
413            and not field.primary_key
414            and not (isinstance(field, ForeignKeyField) and field.db_index)
415            and not (isinstance(field, ForeignKeyField) and field.db_constraint)
416        ):
417            super().remove_field(model, field)
418        # For everything else, remake.
419        else:
420            # It might not actually have a column behind it
421            if field.db_parameters(connection=self.connection)["type"] is None:
422                return
423            self._remake_table(model, delete_field=field)
424
425    def _alter_field(
426        self,
427        model: type[Model],
428        old_field: Field,
429        new_field: Field,
430        old_type: str,
431        new_type: str,
432        old_db_params: DbParameters,
433        new_db_params: DbParameters,
434        strict: bool = False,
435    ) -> None:
436        """Perform a "physical" (non-ManyToMany) field update."""
437        # Use "ALTER TABLE ... RENAME COLUMN" if only the column name
438        # changed and there aren't any constraints.
439        if (
440            self.connection.features.can_alter_table_rename_column
441            and old_field.column != new_field.column
442            and self.column_sql(model, old_field) == self.column_sql(model, new_field)
443            and not (
444                isinstance(old_field, ForeignKeyField)
445                and old_field.db_constraint
446                or isinstance(new_field, ForeignKeyField)
447                and new_field.db_constraint
448            )
449        ):
450            return self.execute(
451                self._rename_field_sql(
452                    model.model_options.db_table, old_field, new_field, new_type
453                )
454            )
455        # Alter by remaking table
456        self._remake_table(model, alter_fields=[(old_field, new_field)])
457        # Rebuild tables with FKs pointing to this field.
458        old_collation = old_db_params.get("collation")
459        new_collation = new_db_params.get("collation")
460        if new_field.primary_key and (
461            old_type != new_type or old_collation != new_collation
462        ):
463            from plain.models.fields.reverse_related import ManyToManyRel
464
465            related_models = set()
466            meta = new_field.model._model_meta
467            for remote_field in meta.related_objects:
468                # Ignore self-relationship since the table was already rebuilt.
469                if remote_field.related_model == model:
470                    continue
471                if not isinstance(remote_field, ManyToManyRel):
472                    if remote_field.field_name == new_field.name:
473                        related_models.add(remote_field.related_model)
474            if new_field.primary_key:
475                for many_to_many in meta.many_to_many:
476                    # Ignore self-relationship since the table was already rebuilt.
477                    if many_to_many.related_model == model:
478                        continue
479            for related_model in related_models:
480                self._remake_table(related_model)
481
482    def _alter_many_to_many(
483        self,
484        model: type[Model],
485        old_field: ManyToManyField,
486        new_field: ManyToManyField,
487        strict: bool,
488    ) -> None:
489        """Alter M2Ms to repoint their to= endpoints."""
490        # Type narrow for ManyToManyField.remote_field
491        old_rel: ManyToManyRel = old_field.remote_field
492        new_rel: ManyToManyRel = new_field.remote_field
493
494        if (
495            old_rel.through.model_options.db_table
496            == new_rel.through.model_options.db_table
497        ):
498            # The field name didn't change, but some options did, so we have to
499            # propagate this altering.
500            self._remake_table(
501                old_rel.through,
502                alter_fields=[
503                    (
504                        # The field that points to the target model is needed,
505                        # so that table can be remade with the new m2m field -
506                        # this is m2m_reverse_field_name().
507                        old_rel.through._model_meta.get_forward_field(
508                            old_field.m2m_reverse_field_name()
509                        ),
510                        new_rel.through._model_meta.get_forward_field(
511                            new_field.m2m_reverse_field_name()
512                        ),
513                    ),
514                    (
515                        # The field that points to the model itself is needed,
516                        # so that table can be remade with the new self field -
517                        # this is m2m_field_name().
518                        old_rel.through._model_meta.get_forward_field(
519                            old_field.m2m_field_name()
520                        ),
521                        new_rel.through._model_meta.get_forward_field(
522                            new_field.m2m_field_name()
523                        ),
524                    ),
525                ],
526            )
527            return
528
529        # Make a new through table
530        self.create_model(new_rel.through)
531        # Copy the data across
532        self.execute(
533            "INSERT INTO {} ({}) SELECT {} FROM {}".format(
534                self.quote_name(new_rel.through.model_options.db_table),
535                ", ".join(
536                    [
537                        "id",
538                        new_field.m2m_column_name(),
539                        new_field.m2m_reverse_name(),
540                    ]
541                ),
542                ", ".join(
543                    [
544                        "id",
545                        old_field.m2m_column_name(),
546                        old_field.m2m_reverse_name(),
547                    ]
548                ),
549                self.quote_name(old_rel.through.model_options.db_table),
550            )
551        )
552        # Delete the old through table
553        self.delete_model(old_rel.through)
554
555    def add_constraint(self, model: type[Model], constraint: BaseConstraint) -> None:
556        if isinstance(constraint, UniqueConstraint) and (
557            constraint.condition
558            or constraint.contains_expressions
559            or constraint.include
560            or constraint.deferrable
561        ):
562            super().add_constraint(model, constraint)
563        else:
564            self._remake_table(model)
565
566    def remove_constraint(self, model: type[Model], constraint: BaseConstraint) -> None:
567        if isinstance(constraint, UniqueConstraint) and (
568            constraint.condition
569            or constraint.contains_expressions
570            or constraint.include
571            or constraint.deferrable
572        ):
573            super().remove_constraint(model, constraint)
574        else:
575            self._remake_table(model)
576
577    def _collate_sql(
578        self,
579        collation: str | None,
580        old_collation: str | None = None,
581        table_name: str | None = None,
582    ) -> str:
583        return "COLLATE " + collation if collation else ""