Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from typing import TYPE_CHECKING, Any
  4
  5from psycopg import sql  # type: ignore[import-untyped]
  6
  7from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
  8from plain.models.backends.ddl_references import Columns, IndexColumns, Statement
  9from plain.models.backends.utils import strip_quotes
 10
 11if TYPE_CHECKING:
 12    from collections.abc import Generator
 13
 14    from plain.models.base import Model
 15    from plain.models.fields import Field
 16    from plain.models.indexes import Index
 17
 18
 19class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
 20    # Setting all constraints to IMMEDIATE to allow changing data in the same
 21    # transaction.
 22    sql_update_with_default = (
 23        "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
 24        "; SET CONSTRAINTS ALL IMMEDIATE"
 25    )
 26    sql_alter_sequence_type = "ALTER SEQUENCE IF EXISTS %(sequence)s AS %(type)s"
 27    sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
 28
 29    sql_create_index = (
 30        "CREATE INDEX %(name)s ON %(table)s%(using)s "
 31        "(%(columns)s)%(include)s%(extra)s%(condition)s"
 32    )
 33    sql_create_index_concurrently = (
 34        "CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s "
 35        "(%(columns)s)%(include)s%(extra)s%(condition)s"
 36    )
 37    sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
 38    sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"
 39
 40    # Setting the constraint to IMMEDIATE to allow changing data in the same
 41    # transaction.
 42    sql_create_column_inline_fk = (
 43        "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
 44        "; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE"
 45    )
 46    # Setting the constraint to IMMEDIATE runs any deferred checks to allow
 47    # dropping it in the same transaction.
 48    sql_delete_fk = (
 49        "SET CONSTRAINTS %(name)s IMMEDIATE; "
 50        "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
 51    )
 52
 53    def execute(
 54        self, sql: str | Statement, params: tuple[Any, ...] | list[Any] | None = ()
 55    ) -> None:
 56        # Merge the query client-side, as PostgreSQL won't do it server-side.
 57        if params is None:
 58            return super().execute(sql, params)
 59        sql = self.connection.ops.compose_sql(str(sql), params)
 60        # Don't let the superclass touch anything.
 61        return super().execute(sql, None)
 62
 63    sql_add_identity = (
 64        "ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
 65        "GENERATED BY DEFAULT AS IDENTITY"
 66    )
 67    sql_drop_indentity = (
 68        "ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP IDENTITY IF EXISTS"
 69    )
 70
 71    def quote_value(self, value: Any) -> str:
 72        if isinstance(value, str):
 73            value = value.replace("%", "%%")
 74        return sql.quote(value, self.connection.connection)
 75
 76    def _field_indexes_sql(self, model: type[Model], field: Field) -> list[Statement]:
 77        output = super()._field_indexes_sql(model, field)
 78        like_index_statement = self._create_like_index_sql(model, field)
 79        if like_index_statement is not None:
 80            output.append(like_index_statement)
 81        return output
 82
 83    def _field_data_type(self, field: Field) -> str | None:
 84        if field.is_relation:
 85            return field.rel_db_type(self.connection)
 86        return self.connection.data_types.get(
 87            field.get_internal_type(),
 88            field.db_type(self.connection),
 89        )
 90
 91    def _field_base_data_types(self, field: Field) -> Generator[str | None, None, None]:
 92        # Yield base data types for array fields.
 93        if field.base_field.get_internal_type() == "ArrayField":  # type: ignore[attr-defined]
 94            yield from self._field_base_data_types(field.base_field)  # type: ignore[attr-defined]
 95        else:
 96            yield self._field_data_type(field.base_field)  # type: ignore[attr-defined]
 97
 98    def _create_like_index_sql(
 99        self, model: type[Model], field: Field
100    ) -> Statement | None:
101        """
102        Return the statement to create an index with varchar operator pattern
103        when the column type is 'varchar' or 'text', otherwise return None.
104        """
105        db_type = field.db_type(connection=self.connection)
106        if db_type is not None and field.primary_key:
107            # Fields with database column types of `varchar` and `text` need
108            # a second index that specifies their operator class, which is
109            # needed when performing correct LIKE queries outside the
110            # C locale. See #12234.
111            #
112            # The same doesn't apply to array fields such as varchar[size]
113            # and text[size], so skip them.
114            if "[" in db_type:
115                return None
116            # Non-deterministic collations on Postgresql don't support indexes
117            # for operator classes varchar_pattern_ops/text_pattern_ops.
118            if getattr(field, "db_collation", None):
119                return None
120            if db_type.startswith("varchar"):
121                return self._create_index_sql(
122                    model,
123                    fields=[field],
124                    suffix="_like",
125                    opclasses=("varchar_pattern_ops",),
126                )
127            elif db_type.startswith("text"):
128                return self._create_index_sql(
129                    model,
130                    fields=[field],
131                    suffix="_like",
132                    opclasses=("text_pattern_ops",),
133                )
134        return None
135
136    def _using_sql(self, new_field: Field, old_field: Field) -> str:
137        using_sql = " USING %(column)s::%(type)s"
138        new_internal_type = new_field.get_internal_type()
139        old_internal_type = old_field.get_internal_type()
140        if new_internal_type == "ArrayField" and new_internal_type == old_internal_type:
141            # Compare base data types for array fields.
142            if list(self._field_base_data_types(old_field)) != list(
143                self._field_base_data_types(new_field)
144            ):
145                return using_sql
146        elif self._field_data_type(old_field) != self._field_data_type(new_field):
147            return using_sql
148        return ""
149
150    def _get_sequence_name(self, table: str, column: str) -> str | None:
151        with self.connection.cursor() as cursor:
152            for sequence in self.connection.introspection.get_sequences(cursor, table):
153                if sequence["column"] == column:
154                    return sequence["name"]
155        return None
156
157    def _alter_column_type_sql(
158        self,
159        model: type[Model],
160        old_field: Field,
161        new_field: Field,
162        new_type: str,
163        old_collation: str | None,
164        new_collation: str | None,
165    ) -> tuple[tuple[str, list[Any]], list[tuple[str, list[Any]]]]:
166        # Drop indexes on varchar/text/citext columns that are changing to a
167        # different type.
168        old_db_params = old_field.db_parameters(connection=self.connection)
169        old_type = old_db_params["type"]
170        if old_field.primary_key and (
171            (old_type.startswith("varchar") and not new_type.startswith("varchar"))
172            or (old_type.startswith("text") and not new_type.startswith("text"))
173            or (old_type.startswith("citext") and not new_type.startswith("citext"))
174        ):
175            index_name = self._create_index_name(
176                model.model_options.db_table, [old_field.column], suffix="_like"
177            )
178            self.execute(self._delete_index_sql(model, index_name))
179
180        self.sql_alter_column_type = (
181            "ALTER COLUMN %(column)s TYPE %(type)s%(collation)s"
182        )
183        # Cast when data type changed.
184        if using_sql := self._using_sql(new_field, old_field):
185            self.sql_alter_column_type += using_sql
186        new_internal_type = new_field.get_internal_type()
187        old_internal_type = old_field.get_internal_type()
188        # Make ALTER TYPE with IDENTITY make sense.
189        table = strip_quotes(model.model_options.db_table)
190        auto_field_types = {"PrimaryKeyField"}
191        old_is_auto = old_internal_type in auto_field_types
192        new_is_auto = new_internal_type in auto_field_types
193        if new_is_auto and not old_is_auto:
194            column = strip_quotes(new_field.column)
195            return (
196                (
197                    self.sql_alter_column_type
198                    % {
199                        "column": self.quote_name(column),
200                        "type": new_type,
201                        "collation": "",
202                    },
203                    [],
204                ),
205                [
206                    (
207                        self.sql_add_identity
208                        % {
209                            "table": self.quote_name(table),
210                            "column": self.quote_name(column),
211                        },
212                        [],
213                    ),
214                ],
215            )
216        elif old_is_auto and not new_is_auto:
217            # Drop IDENTITY if exists (pre-Plain 4.1 serial columns don't have
218            # it).
219            self.execute(
220                self.sql_drop_indentity
221                % {
222                    "table": self.quote_name(table),
223                    "column": self.quote_name(strip_quotes(new_field.column)),
224                }
225            )
226            column = strip_quotes(new_field.column)
227            fragment, _ = super()._alter_column_type_sql(
228                model, old_field, new_field, new_type, old_collation, new_collation
229            )
230            # Drop the sequence if exists (Plain 4.1+ identity columns don't
231            # have it).
232            other_actions = []
233            if sequence_name := self._get_sequence_name(table, column):
234                other_actions = [
235                    (
236                        self.sql_delete_sequence
237                        % {
238                            "sequence": self.quote_name(sequence_name),
239                        },
240                        [],
241                    )
242                ]
243            return fragment, other_actions
244        elif new_is_auto and old_is_auto and old_internal_type != new_internal_type:
245            fragment, _ = super()._alter_column_type_sql(
246                model, old_field, new_field, new_type, old_collation, new_collation
247            )
248            column = strip_quotes(new_field.column)
249            db_types = {"PrimaryKeyField": "bigint"}
250            # Alter the sequence type if exists (Plain 4.1+ identity columns
251            # don't have it).
252            other_actions = []
253            if sequence_name := self._get_sequence_name(table, column):
254                other_actions = [
255                    (
256                        self.sql_alter_sequence_type
257                        % {
258                            "sequence": self.quote_name(sequence_name),
259                            "type": db_types[new_internal_type],
260                        },
261                        [],
262                    ),
263                ]
264            return fragment, other_actions
265        else:
266            return super()._alter_column_type_sql(
267                model, old_field, new_field, new_type, old_collation, new_collation
268            )
269
270    def _alter_field(
271        self,
272        model: type[Model],
273        old_field: Field,
274        new_field: Field,
275        old_type: str,
276        new_type: str,
277        old_db_params: dict[str, Any],
278        new_db_params: dict[str, Any],
279        strict: bool = False,
280    ) -> None:
281        super()._alter_field(
282            model,
283            old_field,
284            new_field,
285            old_type,
286            new_type,
287            old_db_params,
288            new_db_params,
289            strict,
290        )
291        # Added an index? Create any PostgreSQL-specific indexes.
292        if (
293            not (
294                (old_field.remote_field and old_field.db_index)  # type: ignore[attr-defined]
295                or old_field.primary_key
296            )
297            and (new_field.remote_field and new_field.db_index)  # type: ignore[attr-defined]
298        ) or (not old_field.primary_key and new_field.primary_key):
299            like_index_statement = self._create_like_index_sql(model, new_field)
300            if like_index_statement is not None:
301                self.execute(like_index_statement)
302
303        # Removed an index? Drop any PostgreSQL-specific indexes.
304        if old_field.primary_key and not (
305            (new_field.remote_field and new_field.db_index)  # type: ignore[attr-defined]
306            or new_field.primary_key
307        ):
308            index_to_remove = self._create_index_name(
309                model.model_options.db_table, [old_field.column], suffix="_like"
310            )
311            self.execute(self._delete_index_sql(model, index_to_remove))
312
313    def _index_columns(
314        self,
315        table: str,
316        columns: list[str],
317        col_suffixes: tuple[str, ...],
318        opclasses: tuple[str, ...],
319    ) -> Columns | IndexColumns:
320        if opclasses:
321            return IndexColumns(
322                table,
323                columns,
324                self.quote_name,
325                col_suffixes=col_suffixes,
326                opclasses=opclasses,
327            )
328        return super()._index_columns(table, columns, col_suffixes, opclasses)
329
330    def add_index(
331        self, model: type[Model], index: Index, concurrently: bool = False
332    ) -> None:
333        self.execute(
334            index.create_sql(model, self, concurrently=concurrently), params=None
335        )
336
337    def remove_index(
338        self, model: type[Model], index: Index, concurrently: bool = False
339    ) -> None:
340        self.execute(index.remove_sql(model, self, concurrently=concurrently))
341
342    def _delete_index_sql(
343        self,
344        model: type[Model],
345        name: str,
346        sql: str | None = None,
347        concurrently: bool = False,
348    ) -> Statement:
349        sql = (
350            self.sql_delete_index_concurrently
351            if concurrently
352            else self.sql_delete_index
353        )
354        return super()._delete_index_sql(model, name, sql)
355
356    def _create_index_sql(
357        self,
358        model: type[Model],
359        *,
360        fields: list[Field] | None = None,
361        name: str | None = None,
362        suffix: str = "",
363        using: str = "",
364        col_suffixes: tuple[str, ...] = (),
365        sql: str | None = None,
366        opclasses: tuple[str, ...] = (),
367        condition: str | None = None,
368        concurrently: bool = False,
369        include: list[str] | None = None,
370        expressions: Any = None,
371    ) -> Statement:
372        sql = sql or (
373            self.sql_create_index
374            if not concurrently
375            else self.sql_create_index_concurrently
376        )
377        return super()._create_index_sql(
378            model,
379            fields=fields,
380            name=name,
381            suffix=suffix,
382            using=using,
383            col_suffixes=col_suffixes,
384            sql=sql,
385            opclasses=opclasses,
386            condition=condition,
387            include=include,
388            expressions=expressions,
389        )