Plain is headed towards 1.0! Subscribe for development updates →

  1import copy
  2from decimal import Decimal
  3
  4from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
  5from plain.models.backends.ddl_references import Statement
  6from plain.models.backends.utils import strip_quotes
  7from plain.models.constraints import UniqueConstraint
  8from plain.models.db import NotSupportedError
  9from plain.models.transaction import atomic
 10from plain.packages.registry import Packages
 11
 12
 13class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
 14    sql_delete_table = "DROP TABLE %(table)s"
 15    sql_create_fk = None
 16    sql_create_inline_fk = (
 17        "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
 18    )
 19    sql_create_column_inline_fk = sql_create_inline_fk
 20    sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
 21    sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)"
 22    sql_delete_unique = "DROP INDEX %(name)s"
 23
 24    def __enter__(self):
 25        # Some SQLite schema alterations need foreign key constraints to be
 26        # disabled. Enforce it here for the duration of the schema edition.
 27        if not self.connection.disable_constraint_checking():
 28            raise NotSupportedError(
 29                "SQLite schema editor cannot be used while foreign key "
 30                "constraint checks are enabled. Make sure to disable them "
 31                "before entering a transaction.atomic() context because "
 32                "SQLite does not support disabling them in the middle of "
 33                "a multi-statement transaction."
 34            )
 35        return super().__enter__()
 36
 37    def __exit__(self, exc_type, exc_value, traceback):
 38        self.connection.check_constraints()
 39        super().__exit__(exc_type, exc_value, traceback)
 40        self.connection.enable_constraint_checking()
 41
 42    def quote_value(self, value):
 43        # The backend "mostly works" without this function and there are use
 44        # cases for compiling Python without the sqlite3 libraries (e.g.
 45        # security hardening).
 46        try:
 47            import sqlite3
 48
 49            value = sqlite3.adapt(value)
 50        except ImportError:
 51            pass
 52        except sqlite3.ProgrammingError:
 53            pass
 54        # Manual emulation of SQLite parameter quoting
 55        if isinstance(value, bool):
 56            return str(int(value))
 57        elif isinstance(value, Decimal | float | int):
 58            return str(value)
 59        elif isinstance(value, str):
 60            return "'%s'" % value.replace("'", "''")
 61        elif value is None:
 62            return "NULL"
 63        elif isinstance(value, bytes | bytearray | memoryview):
 64            # Bytes are only allowed for BLOB fields, encoded as string
 65            # literals containing hexadecimal data and preceded by a single "X"
 66            # character.
 67            return "X'%s'" % value.hex()
 68        else:
 69            raise ValueError(
 70                f"Cannot quote parameter value {value!r} of type {type(value)}"
 71            )
 72
 73    def prepare_default(self, value):
 74        return self.quote_value(value)
 75
 76    def _is_referenced_by_fk_constraint(
 77        self, table_name, column_name=None, ignore_self=False
 78    ):
 79        """
 80        Return whether or not the provided table name is referenced by another
 81        one. If `column_name` is specified, only references pointing to that
 82        column are considered. If `ignore_self` is True, self-referential
 83        constraints are ignored.
 84        """
 85        with self.connection.cursor() as cursor:
 86            for other_table in self.connection.introspection.get_table_list(cursor):
 87                if ignore_self and other_table.name == table_name:
 88                    continue
 89                relations = self.connection.introspection.get_relations(
 90                    cursor, other_table.name
 91                )
 92                for constraint_column, constraint_table in relations.values():
 93                    if constraint_table == table_name and (
 94                        column_name is None or constraint_column == column_name
 95                    ):
 96                        return True
 97        return False
 98
 99    def alter_db_table(
100        self, model, old_db_table, new_db_table, disable_constraints=True
101    ):
102        if (
103            not self.connection.features.supports_atomic_references_rename
104            and disable_constraints
105            and self._is_referenced_by_fk_constraint(old_db_table)
106        ):
107            if self.connection.in_atomic_block:
108                raise NotSupportedError(
109                    (
110                        "Renaming the %r table while in a transaction is not "
111                        "supported on SQLite < 3.26 because it would break referential "
112                        "integrity. Try adding `atomic = False` to the Migration class."
113                    )
114                    % old_db_table
115                )
116            self.connection.enable_constraint_checking()
117            super().alter_db_table(model, old_db_table, new_db_table)
118            self.connection.disable_constraint_checking()
119        else:
120            super().alter_db_table(model, old_db_table, new_db_table)
121
122    def alter_field(self, model, old_field, new_field, strict=False):
123        if not self._field_should_be_altered(old_field, new_field):
124            return
125        old_field_name = old_field.name
126        table_name = model._meta.db_table
127        _, old_column_name = old_field.get_attname_column()
128        if (
129            new_field.name != old_field_name
130            and not self.connection.features.supports_atomic_references_rename
131            and self._is_referenced_by_fk_constraint(
132                table_name, old_column_name, ignore_self=True
133            )
134        ):
135            if self.connection.in_atomic_block:
136                raise NotSupportedError(
137                    (
138                        "Renaming the {!r}.{!r} column while in a transaction is not "
139                        "supported on SQLite < 3.26 because it would break referential "
140                        "integrity. Try adding `atomic = False` to the Migration class."
141                    ).format(model._meta.db_table, old_field_name)
142                )
143            with atomic(self.connection.alias):
144                super().alter_field(model, old_field, new_field, strict=strict)
145                # Follow SQLite's documented procedure for performing changes
146                # that don't affect the on-disk content.
147                # https://sqlite.org/lang_altertable.html#otheralter
148                with self.connection.cursor() as cursor:
149                    schema_version = cursor.execute("PRAGMA schema_version").fetchone()[
150                        0
151                    ]
152                    cursor.execute("PRAGMA writable_schema = 1")
153                    references_template = ' REFERENCES "%s" ("%%s") ' % table_name
154                    new_column_name = new_field.get_attname_column()[1]
155                    search = references_template % old_column_name
156                    replacement = references_template % new_column_name
157                    cursor.execute(
158                        "UPDATE sqlite_master SET sql = replace(sql, %s, %s)",
159                        (search, replacement),
160                    )
161                    cursor.execute("PRAGMA schema_version = %d" % (schema_version + 1))
162                    cursor.execute("PRAGMA writable_schema = 0")
163                    # The integrity check will raise an exception and rollback
164                    # the transaction if the sqlite_master updates corrupt the
165                    # database.
166                    cursor.execute("PRAGMA integrity_check")
167            # Perform a VACUUM to refresh the database representation from
168            # the sqlite_master table.
169            with self.connection.cursor() as cursor:
170                cursor.execute("VACUUM")
171        else:
172            super().alter_field(model, old_field, new_field, strict=strict)
173
174    def _remake_table(
175        self, model, create_field=None, delete_field=None, alter_fields=None
176    ):
177        """
178        Shortcut to transform a model from old_model into new_model
179
180        This follows the correct procedure to perform non-rename or column
181        addition operations based on SQLite's documentation
182
183        https://www.sqlite.org/lang_altertable.html#caution
184
185        The essential steps are:
186          1. Create a table with the updated definition called "new__app_model"
187          2. Copy the data from the existing "app_model" table to the new table
188          3. Drop the "app_model" table
189          4. Rename the "new__app_model" table to "app_model"
190          5. Restore any index of the previous "app_model" table.
191        """
192
193        # Self-referential fields must be recreated rather than copied from
194        # the old model to ensure their remote_field.field_name doesn't refer
195        # to an altered field.
196        def is_self_referential(f):
197            return f.is_relation and f.remote_field.model is model
198
199        # Work out the new fields dict / mapping
200        body = {
201            f.name: f.clone() if is_self_referential(f) else f
202            for f in model._meta.local_concrete_fields
203        }
204        # Since mapping might mix column names and default values,
205        # its values must be already quoted.
206        mapping = {
207            f.column: self.quote_name(f.column)
208            for f in model._meta.local_concrete_fields
209        }
210        # If any of the new or altered fields is introducing a new PK,
211        # remove the old one
212        restore_pk_field = None
213        alter_fields = alter_fields or []
214        if getattr(create_field, "primary_key", False) or any(
215            getattr(new_field, "primary_key", False) for _, new_field in alter_fields
216        ):
217            for name, field in list(body.items()):
218                if field.primary_key and not any(
219                    # Do not remove the old primary key when an altered field
220                    # that introduces a primary key is the same field.
221                    name == new_field.name
222                    for _, new_field in alter_fields
223                ):
224                    field.primary_key = False
225                    restore_pk_field = field
226                    if field.auto_created:
227                        del body[name]
228                        del mapping[field.column]
229        # Add in any created fields
230        if create_field:
231            body[create_field.name] = create_field
232            # Choose a default and insert it into the copy map
233            if not create_field.many_to_many and create_field.concrete:
234                mapping[create_field.column] = self.prepare_default(
235                    self.effective_default(create_field),
236                )
237        # Add in any altered fields
238        for alter_field in alter_fields:
239            old_field, new_field = alter_field
240            body.pop(old_field.name, None)
241            mapping.pop(old_field.column, None)
242            body[new_field.name] = new_field
243            if old_field.null and not new_field.null:
244                case_sql = "coalesce({col}, {default})".format(
245                    col=self.quote_name(old_field.column),
246                    default=self.prepare_default(self.effective_default(new_field)),
247                )
248                mapping[new_field.column] = case_sql
249            else:
250                mapping[new_field.column] = self.quote_name(old_field.column)
251        # Remove any deleted fields
252        if delete_field:
253            del body[delete_field.name]
254            del mapping[delete_field.column]
255            # Remove any implicit M2M tables
256            if (
257                delete_field.many_to_many
258                and delete_field.remote_field.through._meta.auto_created
259            ):
260                return self.delete_model(delete_field.remote_field.through)
261        # Work inside a new app registry
262        packages = Packages()
263
264        indexes = model._meta.indexes
265        if delete_field:
266            indexes = [
267                index for index in indexes if delete_field.name not in index.fields
268            ]
269
270        constraints = list(model._meta.constraints)
271
272        # Provide isolated instances of the fields to the new model body so
273        # that the existing model's internals aren't interfered with when
274        # the dummy model is constructed.
275        body_copy = copy.deepcopy(body)
276
277        # Construct a new model with the new fields to allow self referential
278        # primary key to resolve to. This model won't ever be materialized as a
279        # table and solely exists for foreign key reference resolution purposes.
280        # This wouldn't be required if the schema editor was operating on model
281        # states instead of rendered models.
282        meta_contents = {
283            "package_label": model._meta.package_label,
284            "db_table": model._meta.db_table,
285            "indexes": indexes,
286            "constraints": constraints,
287            "packages": packages,
288        }
289        meta = type("Meta", (), meta_contents)
290        body_copy["Meta"] = meta
291        body_copy["__module__"] = model.__module__
292        type(model._meta.object_name, model.__bases__, body_copy)
293
294        # Construct a model with a renamed table name.
295        body_copy = copy.deepcopy(body)
296        meta_contents = {
297            "package_label": model._meta.package_label,
298            "db_table": "new__%s" % strip_quotes(model._meta.db_table),
299            "indexes": indexes,
300            "constraints": constraints,
301            "packages": packages,
302        }
303        meta = type("Meta", (), meta_contents)
304        body_copy["Meta"] = meta
305        body_copy["__module__"] = model.__module__
306        new_model = type("New%s" % model._meta.object_name, model.__bases__, body_copy)
307
308        # Create a new table with the updated schema.
309        self.create_model(new_model)
310
311        # Copy data from the old table into the new table
312        self.execute(
313            "INSERT INTO {} ({}) SELECT {} FROM {}".format(
314                self.quote_name(new_model._meta.db_table),
315                ", ".join(self.quote_name(x) for x in mapping),
316                ", ".join(mapping.values()),
317                self.quote_name(model._meta.db_table),
318            )
319        )
320
321        # Delete the old table to make way for the new
322        self.delete_model(model, handle_autom2m=False)
323
324        # Rename the new table to take way for the old
325        self.alter_db_table(
326            new_model,
327            new_model._meta.db_table,
328            model._meta.db_table,
329            disable_constraints=False,
330        )
331
332        # Run deferred SQL on correct table
333        for sql in self.deferred_sql:
334            self.execute(sql)
335        self.deferred_sql = []
336        # Fix any PK-removed field
337        if restore_pk_field:
338            restore_pk_field.primary_key = True
339
340    def delete_model(self, model, handle_autom2m=True):
341        if handle_autom2m:
342            super().delete_model(model)
343        else:
344            # Delete the table (and only that)
345            self.execute(
346                self.sql_delete_table
347                % {
348                    "table": self.quote_name(model._meta.db_table),
349                }
350            )
351            # Remove all deferred statements referencing the deleted table.
352            for sql in list(self.deferred_sql):
353                if isinstance(sql, Statement) and sql.references_table(
354                    model._meta.db_table
355                ):
356                    self.deferred_sql.remove(sql)
357
358    def add_field(self, model, field):
359        """Create a field on a model."""
360        # Special-case implicit M2M tables.
361        if field.many_to_many and field.remote_field.through._meta.auto_created:
362            self.create_model(field.remote_field.through)
363        elif (
364            # Primary keys and unique fields are not supported in ALTER TABLE
365            # ADD COLUMN.
366            field.primary_key
367            or field.unique
368            or
369            # Fields with default values cannot by handled by ALTER TABLE ADD
370            # COLUMN statement because DROP DEFAULT is not supported in
371            # ALTER TABLE.
372            not field.null
373            or self.effective_default(field) is not None
374        ):
375            self._remake_table(model, create_field=field)
376        else:
377            super().add_field(model, field)
378
379    def remove_field(self, model, field):
380        """
381        Remove a field from a model. Usually involves deleting a column,
382        but for M2Ms may involve deleting a table.
383        """
384        # M2M fields are a special case
385        if field.many_to_many:
386            # For implicit M2M tables, delete the auto-created table
387            if field.remote_field.through._meta.auto_created:
388                self.delete_model(field.remote_field.through)
389            # For explicit "through" M2M fields, do nothing
390        elif (
391            self.connection.features.can_alter_table_drop_column
392            # Primary keys, unique fields, indexed fields, and foreign keys are
393            # not supported in ALTER TABLE DROP COLUMN.
394            and not field.primary_key
395            and not field.unique
396            and not field.db_index
397            and not (field.remote_field and field.db_constraint)
398        ):
399            super().remove_field(model, field)
400        # For everything else, remake.
401        else:
402            # It might not actually have a column behind it
403            if field.db_parameters(connection=self.connection)["type"] is None:
404                return
405            self._remake_table(model, delete_field=field)
406
407    def _alter_field(
408        self,
409        model,
410        old_field,
411        new_field,
412        old_type,
413        new_type,
414        old_db_params,
415        new_db_params,
416        strict=False,
417    ):
418        """Perform a "physical" (non-ManyToMany) field update."""
419        # Use "ALTER TABLE ... RENAME COLUMN" if only the column name
420        # changed and there aren't any constraints.
421        if (
422            self.connection.features.can_alter_table_rename_column
423            and old_field.column != new_field.column
424            and self.column_sql(model, old_field) == self.column_sql(model, new_field)
425            and not (
426                old_field.remote_field
427                and old_field.db_constraint
428                or new_field.remote_field
429                and new_field.db_constraint
430            )
431        ):
432            return self.execute(
433                self._rename_field_sql(
434                    model._meta.db_table, old_field, new_field, new_type
435                )
436            )
437        # Alter by remaking table
438        self._remake_table(model, alter_fields=[(old_field, new_field)])
439        # Rebuild tables with FKs pointing to this field.
440        old_collation = old_db_params.get("collation")
441        new_collation = new_db_params.get("collation")
442        if new_field.unique and (
443            old_type != new_type or old_collation != new_collation
444        ):
445            related_models = set()
446            opts = new_field.model._meta
447            for remote_field in opts.related_objects:
448                # Ignore self-relationship since the table was already rebuilt.
449                if remote_field.related_model == model:
450                    continue
451                if not remote_field.many_to_many:
452                    if remote_field.field_name == new_field.name:
453                        related_models.add(remote_field.related_model)
454                elif new_field.primary_key and remote_field.through._meta.auto_created:
455                    related_models.add(remote_field.through)
456            if new_field.primary_key:
457                for many_to_many in opts.many_to_many:
458                    # Ignore self-relationship since the table was already rebuilt.
459                    if many_to_many.related_model == model:
460                        continue
461                    if many_to_many.remote_field.through._meta.auto_created:
462                        related_models.add(many_to_many.remote_field.through)
463            for related_model in related_models:
464                self._remake_table(related_model)
465
466    def _alter_many_to_many(self, model, old_field, new_field, strict):
467        """Alter M2Ms to repoint their to= endpoints."""
468        if (
469            old_field.remote_field.through._meta.db_table
470            == new_field.remote_field.through._meta.db_table
471        ):
472            # The field name didn't change, but some options did, so we have to
473            # propagate this altering.
474            self._remake_table(
475                old_field.remote_field.through,
476                alter_fields=[
477                    (
478                        # The field that points to the target model is needed,
479                        # so that table can be remade with the new m2m field -
480                        # this is m2m_reverse_field_name().
481                        old_field.remote_field.through._meta.get_field(
482                            old_field.m2m_reverse_field_name()
483                        ),
484                        new_field.remote_field.through._meta.get_field(
485                            new_field.m2m_reverse_field_name()
486                        ),
487                    ),
488                    (
489                        # The field that points to the model itself is needed,
490                        # so that table can be remade with the new self field -
491                        # this is m2m_field_name().
492                        old_field.remote_field.through._meta.get_field(
493                            old_field.m2m_field_name()
494                        ),
495                        new_field.remote_field.through._meta.get_field(
496                            new_field.m2m_field_name()
497                        ),
498                    ),
499                ],
500            )
501            return
502
503        # Make a new through table
504        self.create_model(new_field.remote_field.through)
505        # Copy the data across
506        self.execute(
507            "INSERT INTO {} ({}) SELECT {} FROM {}".format(
508                self.quote_name(new_field.remote_field.through._meta.db_table),
509                ", ".join(
510                    [
511                        "id",
512                        new_field.m2m_column_name(),
513                        new_field.m2m_reverse_name(),
514                    ]
515                ),
516                ", ".join(
517                    [
518                        "id",
519                        old_field.m2m_column_name(),
520                        old_field.m2m_reverse_name(),
521                    ]
522                ),
523                self.quote_name(old_field.remote_field.through._meta.db_table),
524            )
525        )
526        # Delete the old through table
527        self.delete_model(old_field.remote_field.through)
528
529    def add_constraint(self, model, constraint):
530        if isinstance(constraint, UniqueConstraint) and (
531            constraint.condition
532            or constraint.contains_expressions
533            or constraint.include
534            or constraint.deferrable
535        ):
536            super().add_constraint(model, constraint)
537        else:
538            self._remake_table(model)
539
540    def remove_constraint(self, model, constraint):
541        if isinstance(constraint, UniqueConstraint) and (
542            constraint.condition
543            or constraint.contains_expressions
544            or constraint.include
545            or constraint.deferrable
546        ):
547            super().remove_constraint(model, constraint)
548        else:
549            self._remake_table(model)
550
551    def _collate_sql(self, collation):
552        return "COLLATE " + collation