Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from collections.abc import Callable
  4from typing import TYPE_CHECKING, Any
  5
  6from psycopg import sql
  7
  8from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
  9from plain.models.backends.ddl_references import Columns, IndexColumns, Statement
 10from plain.models.backends.utils import strip_quotes
 11from plain.models.fields.related import ForeignKeyField, RelatedField
 12
 13if TYPE_CHECKING:
 14    from plain.models.backends.postgresql.base import PostgreSQLDatabaseWrapper
 15    from plain.models.base import Model
 16    from plain.models.fields import Field
 17    from plain.models.indexes import Index
 18
 19
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
    """
    Schema editor used with ``PostgreSQLDatabaseWrapper`` connections.

    Overrides SQL templates and hooks of ``BaseDatabaseSchemaEditor`` with
    PostgreSQL statements: identity columns, sequences, concurrent index
    creation/removal, and the extra ``*_pattern_ops`` "_like" indexes.
    """

    # Type checker hint: connection is always PostgreSQLDatabaseWrapper in this class
    connection: PostgreSQLDatabaseWrapper

    # Setting all constraints to IMMEDIATE to allow changing data in the same
    # transaction.
    sql_update_with_default = (
        "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
        "; SET CONSTRAINTS ALL IMMEDIATE"
    )
    # Sequence templates; both are filled in by _alter_column_type_sql.
    sql_alter_sequence_type = "ALTER SEQUENCE IF EXISTS %(sequence)s AS %(type)s"
    sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"

    # Index templates. The CONCURRENTLY variants are the ones picked by
    # _create_index_sql / _delete_index_sql when called with concurrently=True.
    sql_create_index = (
        "CREATE INDEX %(name)s ON %(table)s%(using)s "
        "(%(columns)s)%(include)s%(extra)s%(condition)s"
    )
    sql_create_index_concurrently = (
        "CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s "
        "(%(columns)s)%(include)s%(extra)s%(condition)s"
    )
    sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
    sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"

    # Setting the constraint to IMMEDIATE to allow changing data in the same
    # transaction.
    sql_create_column_inline_fk = (
        "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
        "; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE"
    )
    # Setting the constraint to IMMEDIATE runs any deferred checks to allow
    # dropping it in the same transaction.
    sql_delete_fk = (
        "SET CONSTRAINTS %(name)s IMMEDIATE; "
        "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
    )

 56
 57    def execute(
 58        self, sql: str | Statement, params: tuple[Any, ...] | list[Any] | None = ()
 59    ) -> None:
 60        # Merge the query client-side, as PostgreSQL won't do it server-side.
 61        if params is None:
 62            return super().execute(sql, params)
 63        sql = self.connection.ops.compose_sql(str(sql), params)
 64        # Don't let the superclass touch anything.
 65        return super().execute(sql, None)
 66
    # Turns an existing column into an identity column; used by
    # _alter_column_type_sql when a field becomes an auto field.
    sql_add_identity = (
        "ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
        "GENERATED BY DEFAULT AS IDENTITY"
    )
    # NOTE: the attribute name is misspelled ("indentity"), but
    # _alter_column_type_sql looks it up under exactly this name, so it must
    # not be renamed in isolation.
    sql_drop_indentity = (
        "ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP IDENTITY IF EXISTS"
    )

 74
 75    def quote_value(self, value: Any) -> str:
 76        if isinstance(value, str):
 77            value = value.replace("%", "%%")
 78        return sql.quote(value, self.connection.connection)
 79
 80    def _field_indexes_sql(self, model: type[Model], field: Field) -> list[Statement]:
 81        output = super()._field_indexes_sql(model, field)
 82        like_index_statement = self._create_like_index_sql(model, field)
 83        if like_index_statement is not None:
 84            output.append(like_index_statement)
 85        return output
 86
 87    def _field_data_type(
 88        self, field: Field
 89    ) -> str | None | Callable[[dict[str, Any]], str]:
 90        if isinstance(field, RelatedField):
 91            return field.rel_db_type(self.connection)
 92        return self.connection.data_types.get(
 93            field.get_internal_type(),
 94            field.db_type(self.connection),
 95        )
 96
 97    def _create_like_index_sql(
 98        self, model: type[Model], field: Field
 99    ) -> Statement | None:
100        """
101        Return the statement to create an index with varchar operator pattern
102        when the column type is 'varchar' or 'text', otherwise return None.
103        """
104        db_type = field.db_type(connection=self.connection)
105        if db_type is not None and field.primary_key:
106            # Fields with database column types of `varchar` and `text` need
107            # a second index that specifies their operator class, which is
108            # needed when performing correct LIKE queries outside the
109            # C locale. See #12234.
110            #
111            # The same doesn't apply to array fields such as varchar[size]
112            # and text[size], so skip them.
113            if "[" in db_type:
114                return None
115            # Non-deterministic collations on Postgresql don't support indexes
116            # for operator classes varchar_pattern_ops/text_pattern_ops.
117            if getattr(field, "db_collation", None):
118                return None
119            if db_type.startswith("varchar"):
120                return self._create_index_sql(
121                    model,
122                    fields=[field],
123                    suffix="_like",
124                    opclasses=("varchar_pattern_ops",),
125                )
126            elif db_type.startswith("text"):
127                return self._create_index_sql(
128                    model,
129                    fields=[field],
130                    suffix="_like",
131                    opclasses=("text_pattern_ops",),
132                )
133        return None
134
135    def _using_sql(self, new_field: Field, old_field: Field) -> str:
136        using_sql = " USING %(column)s::%(type)s"
137        if self._field_data_type(old_field) != self._field_data_type(new_field):
138            return using_sql
139        return ""
140
141    def _get_sequence_name(self, table: str, column: str) -> str | None:
142        with self.connection.cursor() as cursor:
143            for sequence in self.connection.introspection.get_sequences(cursor, table):
144                if sequence["column"] == column:
145                    return sequence["name"]
146        return None
147
    def _alter_column_type_sql(
        self,
        model: type[Model],
        old_field: Field,
        new_field: Field,
        new_type: str,
        old_collation: str | None,
        new_collation: str | None,
    ) -> tuple[tuple[str, list[Any]], list[tuple[str, list[Any]]]]:
        """
        Build the SQL that changes a column's type, handling identity columns.

        Returns ``(fragment, other_actions)``: ``fragment`` is a
        ``(sql, params)`` pair for the ALTER COLUMN clause and
        ``other_actions`` is a list of further ``(sql, params)`` pairs.

        Side effects, performed immediately rather than returned:
        - drops the "_like" index when a primary-key varchar/text/citext
          column changes to a type with a different prefix;
        - reassigns ``self.sql_alter_column_type`` on the instance;
        - drops the column's identity when an auto field becomes non-auto.
        """
        # Drop indexes on varchar/text/citext columns that are changing to a
        # different type.
        old_db_params = old_field.db_parameters(connection=self.connection)
        old_type = old_db_params["type"]
        if old_field.primary_key and (
            (old_type.startswith("varchar") and not new_type.startswith("varchar"))
            or (old_type.startswith("text") and not new_type.startswith("text"))
            or (old_type.startswith("citext") and not new_type.startswith("citext"))
        ):
            index_name = self._create_index_name(
                model.model_options.db_table, [old_field.column], suffix="_like"
            )
            self.execute(self._delete_index_sql(model, index_name))

        # Reset the template on every call so a USING clause appended by a
        # previous call does not carry over.
        self.sql_alter_column_type = (
            "ALTER COLUMN %(column)s TYPE %(type)s%(collation)s"
        )
        # Cast when data type changed.
        if using_sql := self._using_sql(new_field, old_field):
            self.sql_alter_column_type += using_sql
        new_internal_type = new_field.get_internal_type()
        old_internal_type = old_field.get_internal_type()
        # Make ALTER TYPE with IDENTITY make sense.
        table = strip_quotes(model.model_options.db_table)
        auto_field_types = {"PrimaryKeyField"}
        old_is_auto = old_internal_type in auto_field_types
        new_is_auto = new_internal_type in auto_field_types
        if new_is_auto and not old_is_auto:
            # Becoming an auto field: alter the type, then add the identity
            # as a follow-up action.
            column = strip_quotes(new_field.column)
            return (
                (
                    self.sql_alter_column_type
                    % {
                        "column": self.quote_name(column),
                        "type": new_type,
                        "collation": "",
                    },
                    [],
                ),
                [
                    (
                        self.sql_add_identity
                        % {
                            "table": self.quote_name(table),
                            "column": self.quote_name(column),
                        },
                        [],
                    ),
                ],
            )
        elif old_is_auto and not new_is_auto:
            # Drop IDENTITY if exists (pre-Plain 4.1 serial columns don't have
            # it).
            self.execute(
                self.sql_drop_indentity
                % {
                    "table": self.quote_name(table),
                    "column": self.quote_name(strip_quotes(new_field.column)),
                }
            )
            column = strip_quotes(new_field.column)
            # Only the fragment from the base implementation is kept; its
            # other actions are discarded.
            fragment, _ = super()._alter_column_type_sql(
                model, old_field, new_field, new_type, old_collation, new_collation
            )
            # Drop the sequence if exists (Plain 4.1+ identity columns don't
            # have it).
            other_actions = []
            if sequence_name := self._get_sequence_name(table, column):
                other_actions = [
                    (
                        self.sql_delete_sequence
                        % {
                            "sequence": self.quote_name(sequence_name),
                        },
                        [],
                    )
                ]
            return fragment, other_actions
        # NOTE(review): with a single entry in auto_field_types, both fields
        # being auto implies equal internal types, so this branch looks
        # unreachable as written — confirm before extending or removing it.
        elif new_is_auto and old_is_auto and old_internal_type != new_internal_type:
            fragment, _ = super()._alter_column_type_sql(
                model, old_field, new_field, new_type, old_collation, new_collation
            )
            column = strip_quotes(new_field.column)
            db_types = {"PrimaryKeyField": "bigint"}
            # Alter the sequence type if exists (Plain 4.1+ identity columns
            # don't have it).
            other_actions = []
            if sequence_name := self._get_sequence_name(table, column):
                other_actions = [
                    (
                        self.sql_alter_sequence_type
                        % {
                            "sequence": self.quote_name(sequence_name),
                            "type": db_types[new_internal_type],
                        },
                        [],
                    ),
                ]
            return fragment, other_actions
        else:
            # No identity change involved: defer entirely to the base class.
            return super()._alter_column_type_sql(
                model, old_field, new_field, new_type, old_collation, new_collation
            )
260
261    def _alter_field(
262        self,
263        model: type[Model],
264        old_field: Field,
265        new_field: Field,
266        old_type: str,
267        new_type: str,
268        old_db_params: dict[str, Any],
269        new_db_params: dict[str, Any],
270        strict: bool = False,
271    ) -> None:
272        super()._alter_field(
273            model,
274            old_field,
275            new_field,
276            old_type,
277            new_type,
278            old_db_params,
279            new_db_params,
280            strict,
281        )
282        # Added an index? Create any PostgreSQL-specific indexes.
283        if (
284            not (
285                (isinstance(old_field, ForeignKeyField) and old_field.db_index)
286                or old_field.primary_key
287            )
288            and isinstance(new_field, ForeignKeyField)
289            and new_field.db_index
290        ) or (not old_field.primary_key and new_field.primary_key):
291            like_index_statement = self._create_like_index_sql(model, new_field)
292            if like_index_statement is not None:
293                self.execute(like_index_statement)
294
295        # Removed an index? Drop any PostgreSQL-specific indexes.
296        if old_field.primary_key and not (
297            (isinstance(new_field, ForeignKeyField) and new_field.db_index)
298            or new_field.primary_key
299        ):
300            index_to_remove = self._create_index_name(
301                model.model_options.db_table, [old_field.column], suffix="_like"
302            )
303            self.execute(self._delete_index_sql(model, index_to_remove))
304
305    def _index_columns(
306        self,
307        table: str,
308        columns: list[str],
309        col_suffixes: tuple[str, ...],
310        opclasses: tuple[str, ...],
311    ) -> Columns | IndexColumns:
312        if opclasses:
313            return IndexColumns(
314                table,
315                columns,
316                self.quote_name,
317                col_suffixes=col_suffixes,
318                opclasses=opclasses,
319            )
320        return super()._index_columns(table, columns, col_suffixes, opclasses)
321
322    def add_index(
323        self, model: type[Model], index: Index, concurrently: bool = False
324    ) -> None:
325        self.execute(
326            index.create_sql(model, self, concurrently=concurrently), params=None
327        )
328
329    def remove_index(
330        self, model: type[Model], index: Index, concurrently: bool = False
331    ) -> None:
332        self.execute(index.remove_sql(model, self, concurrently=concurrently))
333
334    def _delete_index_sql(
335        self,
336        model: type[Model],
337        name: str,
338        sql: str | None = None,
339        concurrently: bool = False,
340    ) -> Statement:
341        sql = (
342            self.sql_delete_index_concurrently
343            if concurrently
344            else self.sql_delete_index
345        )
346        return super()._delete_index_sql(model, name, sql)
347
348    def _create_index_sql(
349        self,
350        model: type[Model],
351        *,
352        fields: list[Field] | None = None,
353        name: str | None = None,
354        suffix: str = "",
355        using: str = "",
356        col_suffixes: tuple[str, ...] = (),
357        sql: str | None = None,
358        opclasses: tuple[str, ...] = (),
359        condition: str | None = None,
360        concurrently: bool = False,
361        include: list[str] | None = None,
362        expressions: Any = None,
363    ) -> Statement:
364        sql = sql or (
365            self.sql_create_index
366            if not concurrently
367            else self.sql_create_index_concurrently
368        )
369        return super()._create_index_sql(
370            model,
371            fields=fields,
372            name=name,
373            suffix=suffix,
374            using=using,
375            col_suffixes=col_suffixes,
376            sql=sql,
377            opclasses=opclasses,
378            condition=condition,
379            include=include,
380            expressions=expressions,
381        )