Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from collections.abc import Callable
  4from typing import TYPE_CHECKING, Any
  5
  6from psycopg import sql
  7
  8from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
  9from plain.models.backends.ddl_references import Columns, IndexColumns, Statement
 10from plain.models.backends.utils import strip_quotes
 11from plain.models.fields import DbParameters
 12from plain.models.fields.related import ForeignKeyField, RelatedField
 13
 14if TYPE_CHECKING:
 15    from plain.models.backends.postgresql.base import PostgreSQLDatabaseWrapper
 16    from plain.models.base import Model
 17    from plain.models.fields import Field
 18    from plain.models.indexes import Index
 19
 20
 21class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
 22    # Type checker hint: connection is always PostgreSQLDatabaseWrapper in this class
 23    connection: PostgreSQLDatabaseWrapper
 24
 25    # Setting all constraints to IMMEDIATE to allow changing data in the same
 26    # transaction.
 27    sql_update_with_default = (
 28        "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
 29        "; SET CONSTRAINTS ALL IMMEDIATE"
 30    )
 31    sql_alter_sequence_type = "ALTER SEQUENCE IF EXISTS %(sequence)s AS %(type)s"
 32    sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
 33
 34    sql_create_index = (
 35        "CREATE INDEX %(name)s ON %(table)s%(using)s "
 36        "(%(columns)s)%(include)s%(extra)s%(condition)s"
 37    )
 38    sql_create_index_concurrently = (
 39        "CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s "
 40        "(%(columns)s)%(include)s%(extra)s%(condition)s"
 41    )
 42    sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
 43    sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"
 44
 45    # Setting the constraint to IMMEDIATE to allow changing data in the same
 46    # transaction.
 47    sql_create_column_inline_fk = (
 48        "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
 49        "; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE"
 50    )
 51    # Setting the constraint to IMMEDIATE runs any deferred checks to allow
 52    # dropping it in the same transaction.
 53    sql_delete_fk = (
 54        "SET CONSTRAINTS %(name)s IMMEDIATE; "
 55        "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
 56    )
 57
 58    def execute(
 59        self, sql: str | Statement, params: tuple[Any, ...] | list[Any] | None = ()
 60    ) -> None:
 61        # Merge the query client-side, as PostgreSQL won't do it server-side.
 62        if params is None:
 63            return super().execute(sql, params)
 64        sql = self.connection.ops.compose_sql(str(sql), params)
 65        # Don't let the superclass touch anything.
 66        return super().execute(sql, None)
 67
 68    sql_add_identity = (
 69        "ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
 70        "GENERATED BY DEFAULT AS IDENTITY"
 71    )
 72    sql_drop_indentity = (
 73        "ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP IDENTITY IF EXISTS"
 74    )
 75
 76    def quote_value(self, value: Any) -> str:
 77        if isinstance(value, str):
 78            value = value.replace("%", "%%")
 79        return sql.quote(value, self.connection.connection)
 80
 81    def _field_indexes_sql(self, model: type[Model], field: Field) -> list[Statement]:
 82        output = super()._field_indexes_sql(model, field)
 83        like_index_statement = self._create_like_index_sql(model, field)
 84        if like_index_statement is not None:
 85            output.append(like_index_statement)
 86        return output
 87
 88    def _field_data_type(
 89        self, field: Field
 90    ) -> str | None | Callable[[dict[str, Any]], str]:
 91        if isinstance(field, RelatedField):
 92            return field.rel_db_type(self.connection)
 93        return self.connection.data_types.get(
 94            field.get_internal_type(),
 95            field.db_type(self.connection),
 96        )
 97
 98    def _create_like_index_sql(
 99        self, model: type[Model], field: Field
100    ) -> Statement | None:
101        """
102        Return the statement to create an index with varchar operator pattern
103        when the column type is 'varchar' or 'text', otherwise return None.
104        """
105        db_type = field.db_type(connection=self.connection)
106        if db_type is not None and field.primary_key:
107            # Fields with database column types of `varchar` and `text` need
108            # a second index that specifies their operator class, which is
109            # needed when performing correct LIKE queries outside the
110            # C locale. See #12234.
111            #
112            # The same doesn't apply to array fields such as varchar[size]
113            # and text[size], so skip them.
114            if "[" in db_type:
115                return None
116            # Non-deterministic collations on Postgresql don't support indexes
117            # for operator classes varchar_pattern_ops/text_pattern_ops.
118            if getattr(field, "db_collation", None):
119                return None
120            if db_type.startswith("varchar"):
121                return self._create_index_sql(
122                    model,
123                    fields=[field],
124                    suffix="_like",
125                    opclasses=("varchar_pattern_ops",),
126                )
127            elif db_type.startswith("text"):
128                return self._create_index_sql(
129                    model,
130                    fields=[field],
131                    suffix="_like",
132                    opclasses=("text_pattern_ops",),
133                )
134        return None
135
136    def _using_sql(self, new_field: Field, old_field: Field) -> str:
137        using_sql = " USING %(column)s::%(type)s"
138        if self._field_data_type(old_field) != self._field_data_type(new_field):
139            return using_sql
140        return ""
141
142    def _get_sequence_name(self, table: str, column: str) -> str | None:
143        with self.connection.cursor() as cursor:
144            for sequence in self.connection.introspection.get_sequences(cursor, table):
145                if sequence["column"] == column:
146                    return sequence["name"]
147        return None
148
149    def _alter_column_type_sql(
150        self,
151        model: type[Model],
152        old_field: Field,
153        new_field: Field,
154        new_type: str,
155        old_collation: str | None,
156        new_collation: str | None,
157    ) -> tuple[tuple[str, list[Any]], list[tuple[str, list[Any]]]]:
158        # Drop indexes on varchar/text/citext columns that are changing to a
159        # different type.
160        old_db_params = old_field.db_parameters(connection=self.connection)
161        old_type = old_db_params["type"]
162        assert old_type is not None, "old_type cannot be None for primary key field"
163        if old_field.primary_key and (
164            (old_type.startswith("varchar") and not new_type.startswith("varchar"))
165            or (old_type.startswith("text") and not new_type.startswith("text"))
166            or (old_type.startswith("citext") and not new_type.startswith("citext"))
167        ):
168            index_name = self._create_index_name(
169                model.model_options.db_table, [old_field.column], suffix="_like"
170            )
171            self.execute(self._delete_index_sql(model, index_name))
172
173        self.sql_alter_column_type = (
174            "ALTER COLUMN %(column)s TYPE %(type)s%(collation)s"
175        )
176        # Cast when data type changed.
177        if using_sql := self._using_sql(new_field, old_field):
178            self.sql_alter_column_type += using_sql
179        new_internal_type = new_field.get_internal_type()
180        old_internal_type = old_field.get_internal_type()
181        # Make ALTER TYPE with IDENTITY make sense.
182        table = strip_quotes(model.model_options.db_table)
183        auto_field_types = {"PrimaryKeyField"}
184        old_is_auto = old_internal_type in auto_field_types
185        new_is_auto = new_internal_type in auto_field_types
186        if new_is_auto and not old_is_auto:
187            column = strip_quotes(new_field.column)
188            return (
189                (
190                    self.sql_alter_column_type
191                    % {
192                        "column": self.quote_name(column),
193                        "type": new_type,
194                        "collation": "",
195                    },
196                    [],
197                ),
198                [
199                    (
200                        self.sql_add_identity
201                        % {
202                            "table": self.quote_name(table),
203                            "column": self.quote_name(column),
204                        },
205                        [],
206                    ),
207                ],
208            )
209        elif old_is_auto and not new_is_auto:
210            # Drop IDENTITY if exists (pre-Plain 4.1 serial columns don't have
211            # it).
212            self.execute(
213                self.sql_drop_indentity
214                % {
215                    "table": self.quote_name(table),
216                    "column": self.quote_name(strip_quotes(new_field.column)),
217                }
218            )
219            column = strip_quotes(new_field.column)
220            fragment, _ = super()._alter_column_type_sql(
221                model, old_field, new_field, new_type, old_collation, new_collation
222            )
223            # Drop the sequence if exists (Plain 4.1+ identity columns don't
224            # have it).
225            other_actions = []
226            if sequence_name := self._get_sequence_name(table, column):
227                other_actions = [
228                    (
229                        self.sql_delete_sequence
230                        % {
231                            "sequence": self.quote_name(sequence_name),
232                        },
233                        [],
234                    )
235                ]
236            return fragment, other_actions
237        elif new_is_auto and old_is_auto and old_internal_type != new_internal_type:
238            fragment, _ = super()._alter_column_type_sql(
239                model, old_field, new_field, new_type, old_collation, new_collation
240            )
241            column = strip_quotes(new_field.column)
242            db_types = {"PrimaryKeyField": "bigint"}
243            # Alter the sequence type if exists (Plain 4.1+ identity columns
244            # don't have it).
245            other_actions = []
246            if sequence_name := self._get_sequence_name(table, column):
247                other_actions = [
248                    (
249                        self.sql_alter_sequence_type
250                        % {
251                            "sequence": self.quote_name(sequence_name),
252                            "type": db_types[new_internal_type],
253                        },
254                        [],
255                    ),
256                ]
257            return fragment, other_actions
258        else:
259            return super()._alter_column_type_sql(
260                model, old_field, new_field, new_type, old_collation, new_collation
261            )
262
263    def _alter_field(
264        self,
265        model: type[Model],
266        old_field: Field,
267        new_field: Field,
268        old_type: str,
269        new_type: str,
270        old_db_params: DbParameters,
271        new_db_params: DbParameters,
272        strict: bool = False,
273    ) -> None:
274        super()._alter_field(
275            model,
276            old_field,
277            new_field,
278            old_type,
279            new_type,
280            old_db_params,
281            new_db_params,
282            strict,
283        )
284        # Added an index? Create any PostgreSQL-specific indexes.
285        if (
286            not (
287                (isinstance(old_field, ForeignKeyField) and old_field.db_index)
288                or old_field.primary_key
289            )
290            and isinstance(new_field, ForeignKeyField)
291            and new_field.db_index
292        ) or (not old_field.primary_key and new_field.primary_key):
293            like_index_statement = self._create_like_index_sql(model, new_field)
294            if like_index_statement is not None:
295                self.execute(like_index_statement)
296
297        # Removed an index? Drop any PostgreSQL-specific indexes.
298        if old_field.primary_key and not (
299            (isinstance(new_field, ForeignKeyField) and new_field.db_index)
300            or new_field.primary_key
301        ):
302            index_to_remove = self._create_index_name(
303                model.model_options.db_table, [old_field.column], suffix="_like"
304            )
305            self.execute(self._delete_index_sql(model, index_to_remove))
306
307    def _index_columns(
308        self,
309        table: str,
310        columns: list[str],
311        col_suffixes: tuple[str, ...],
312        opclasses: tuple[str, ...],
313    ) -> Columns | IndexColumns:
314        if opclasses:
315            return IndexColumns(
316                table,
317                columns,
318                self.quote_name,
319                col_suffixes=col_suffixes,
320                opclasses=opclasses,
321            )
322        return super()._index_columns(table, columns, col_suffixes, opclasses)
323
324    def add_index(
325        self, model: type[Model], index: Index, concurrently: bool = False
326    ) -> None:
327        self.execute(
328            index.create_sql(model, self, concurrently=concurrently), params=None
329        )
330
331    def remove_index(
332        self, model: type[Model], index: Index, concurrently: bool = False
333    ) -> None:
334        self.execute(index.remove_sql(model, self, concurrently=concurrently))
335
336    def _delete_index_sql(
337        self,
338        model: type[Model],
339        name: str,
340        sql: str | None = None,
341        concurrently: bool = False,
342    ) -> Statement:
343        sql = (
344            self.sql_delete_index_concurrently
345            if concurrently
346            else self.sql_delete_index
347        )
348        return super()._delete_index_sql(model, name, sql)
349
350    def _create_index_sql(
351        self,
352        model: type[Model],
353        *,
354        fields: list[Field] | None = None,
355        name: str | None = None,
356        suffix: str = "",
357        using: str = "",
358        col_suffixes: tuple[str, ...] = (),
359        sql: str | None = None,
360        opclasses: tuple[str, ...] = (),
361        condition: str | None = None,
362        concurrently: bool = False,
363        include: list[str] | None = None,
364        expressions: Any = None,
365    ) -> Statement:
366        sql = sql or (
367            self.sql_create_index
368            if not concurrently
369            else self.sql_create_index_concurrently
370        )
371        return super()._create_index_sql(
372            model,
373            fields=fields,
374            name=name,
375            suffix=suffix,
376            using=using,
377            col_suffixes=col_suffixes,
378            sql=sql,
379            opclasses=opclasses,
380            condition=condition,
381            include=include,
382            expressions=expressions,
383        )