Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from collections.abc import Callable
  4from typing import TYPE_CHECKING, Any
  5
  6from psycopg import sql  # type: ignore[import-untyped]
  7
  8from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
  9from plain.models.backends.ddl_references import Columns, IndexColumns, Statement
 10from plain.models.backends.utils import strip_quotes
 11
 12if TYPE_CHECKING:
 13    from collections.abc import Generator
 14
 15    from plain.models.backends.postgresql.base import PostgreSQLDatabaseWrapper
 16    from plain.models.base import Model
 17    from plain.models.fields import Field
 18    from plain.models.indexes import Index
 19
 20
 21class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
 22    # Type checker hint: connection is always PostgreSQLDatabaseWrapper in this class
 23    connection: PostgreSQLDatabaseWrapper
 24
 25    # Setting all constraints to IMMEDIATE to allow changing data in the same
 26    # transaction.
 27    sql_update_with_default = (
 28        "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
 29        "; SET CONSTRAINTS ALL IMMEDIATE"
 30    )
 31    sql_alter_sequence_type = "ALTER SEQUENCE IF EXISTS %(sequence)s AS %(type)s"
 32    sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
 33
 34    sql_create_index = (
 35        "CREATE INDEX %(name)s ON %(table)s%(using)s "
 36        "(%(columns)s)%(include)s%(extra)s%(condition)s"
 37    )
 38    sql_create_index_concurrently = (
 39        "CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s "
 40        "(%(columns)s)%(include)s%(extra)s%(condition)s"
 41    )
 42    sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
 43    sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"
 44
 45    # Setting the constraint to IMMEDIATE to allow changing data in the same
 46    # transaction.
 47    sql_create_column_inline_fk = (
 48        "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
 49        "; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE"
 50    )
 51    # Setting the constraint to IMMEDIATE runs any deferred checks to allow
 52    # dropping it in the same transaction.
 53    sql_delete_fk = (
 54        "SET CONSTRAINTS %(name)s IMMEDIATE; "
 55        "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
 56    )
 57
 58    def execute(
 59        self, sql: str | Statement, params: tuple[Any, ...] | list[Any] | None = ()
 60    ) -> None:
 61        # Merge the query client-side, as PostgreSQL won't do it server-side.
 62        if params is None:
 63            return super().execute(sql, params)
 64        sql = self.connection.ops.compose_sql(str(sql), params)
 65        # Don't let the superclass touch anything.
 66        return super().execute(sql, None)
 67
 68    sql_add_identity = (
 69        "ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
 70        "GENERATED BY DEFAULT AS IDENTITY"
 71    )
 72    sql_drop_indentity = (
 73        "ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP IDENTITY IF EXISTS"
 74    )
 75
 76    def quote_value(self, value: Any) -> str:
 77        if isinstance(value, str):
 78            value = value.replace("%", "%%")
 79        return sql.quote(value, self.connection.connection)
 80
 81    def _field_indexes_sql(self, model: type[Model], field: Field) -> list[Statement]:
 82        output = super()._field_indexes_sql(model, field)
 83        like_index_statement = self._create_like_index_sql(model, field)
 84        if like_index_statement is not None:
 85            output.append(like_index_statement)
 86        return output
 87
 88    def _field_data_type(
 89        self, field: Field
 90    ) -> str | None | Callable[[dict[str, Any]], str]:
 91        if field.is_relation:
 92            return field.rel_db_type(self.connection)
 93        return self.connection.data_types.get(
 94            field.get_internal_type(),
 95            field.db_type(self.connection),
 96        )
 97
 98    def _field_base_data_types(self, field: Field) -> Generator[str | None, None, None]:
 99        # Yield base data types for array fields.
100        if field.base_field.get_internal_type() == "ArrayField":  # type: ignore[attr-defined]
101            yield from self._field_base_data_types(field.base_field)  # type: ignore[attr-defined]
102        else:
103            yield self._field_data_type(field.base_field)  # type: ignore[attr-defined]
104
105    def _create_like_index_sql(
106        self, model: type[Model], field: Field
107    ) -> Statement | None:
108        """
109        Return the statement to create an index with varchar operator pattern
110        when the column type is 'varchar' or 'text', otherwise return None.
111        """
112        db_type = field.db_type(connection=self.connection)
113        if db_type is not None and field.primary_key:
114            # Fields with database column types of `varchar` and `text` need
115            # a second index that specifies their operator class, which is
116            # needed when performing correct LIKE queries outside the
117            # C locale. See #12234.
118            #
119            # The same doesn't apply to array fields such as varchar[size]
120            # and text[size], so skip them.
121            if "[" in db_type:
122                return None
123            # Non-deterministic collations on Postgresql don't support indexes
124            # for operator classes varchar_pattern_ops/text_pattern_ops.
125            if getattr(field, "db_collation", None):
126                return None
127            if db_type.startswith("varchar"):
128                return self._create_index_sql(
129                    model,
130                    fields=[field],
131                    suffix="_like",
132                    opclasses=("varchar_pattern_ops",),
133                )
134            elif db_type.startswith("text"):
135                return self._create_index_sql(
136                    model,
137                    fields=[field],
138                    suffix="_like",
139                    opclasses=("text_pattern_ops",),
140                )
141        return None
142
143    def _using_sql(self, new_field: Field, old_field: Field) -> str:
144        using_sql = " USING %(column)s::%(type)s"
145        new_internal_type = new_field.get_internal_type()
146        old_internal_type = old_field.get_internal_type()
147        if new_internal_type == "ArrayField" and new_internal_type == old_internal_type:
148            # Compare base data types for array fields.
149            if list(self._field_base_data_types(old_field)) != list(
150                self._field_base_data_types(new_field)
151            ):
152                return using_sql
153        elif self._field_data_type(old_field) != self._field_data_type(new_field):
154            return using_sql
155        return ""
156
157    def _get_sequence_name(self, table: str, column: str) -> str | None:
158        with self.connection.cursor() as cursor:
159            for sequence in self.connection.introspection.get_sequences(cursor, table):
160                if sequence["column"] == column:
161                    return sequence["name"]
162        return None
163
164    def _alter_column_type_sql(
165        self,
166        model: type[Model],
167        old_field: Field,
168        new_field: Field,
169        new_type: str,
170        old_collation: str | None,
171        new_collation: str | None,
172    ) -> tuple[tuple[str, list[Any]], list[tuple[str, list[Any]]]]:
173        # Drop indexes on varchar/text/citext columns that are changing to a
174        # different type.
175        old_db_params = old_field.db_parameters(connection=self.connection)
176        old_type = old_db_params["type"]
177        if old_field.primary_key and (
178            (old_type.startswith("varchar") and not new_type.startswith("varchar"))
179            or (old_type.startswith("text") and not new_type.startswith("text"))
180            or (old_type.startswith("citext") and not new_type.startswith("citext"))
181        ):
182            index_name = self._create_index_name(
183                model.model_options.db_table, [old_field.column], suffix="_like"
184            )
185            self.execute(self._delete_index_sql(model, index_name))
186
187        self.sql_alter_column_type = (
188            "ALTER COLUMN %(column)s TYPE %(type)s%(collation)s"
189        )
190        # Cast when data type changed.
191        if using_sql := self._using_sql(new_field, old_field):
192            self.sql_alter_column_type += using_sql
193        new_internal_type = new_field.get_internal_type()
194        old_internal_type = old_field.get_internal_type()
195        # Make ALTER TYPE with IDENTITY make sense.
196        table = strip_quotes(model.model_options.db_table)
197        auto_field_types = {"PrimaryKeyField"}
198        old_is_auto = old_internal_type in auto_field_types
199        new_is_auto = new_internal_type in auto_field_types
200        if new_is_auto and not old_is_auto:
201            column = strip_quotes(new_field.column)
202            return (
203                (
204                    self.sql_alter_column_type
205                    % {
206                        "column": self.quote_name(column),
207                        "type": new_type,
208                        "collation": "",
209                    },
210                    [],
211                ),
212                [
213                    (
214                        self.sql_add_identity
215                        % {
216                            "table": self.quote_name(table),
217                            "column": self.quote_name(column),
218                        },
219                        [],
220                    ),
221                ],
222            )
223        elif old_is_auto and not new_is_auto:
224            # Drop IDENTITY if exists (pre-Plain 4.1 serial columns don't have
225            # it).
226            self.execute(
227                self.sql_drop_indentity
228                % {
229                    "table": self.quote_name(table),
230                    "column": self.quote_name(strip_quotes(new_field.column)),
231                }
232            )
233            column = strip_quotes(new_field.column)
234            fragment, _ = super()._alter_column_type_sql(
235                model, old_field, new_field, new_type, old_collation, new_collation
236            )
237            # Drop the sequence if exists (Plain 4.1+ identity columns don't
238            # have it).
239            other_actions = []
240            if sequence_name := self._get_sequence_name(table, column):
241                other_actions = [
242                    (
243                        self.sql_delete_sequence
244                        % {
245                            "sequence": self.quote_name(sequence_name),
246                        },
247                        [],
248                    )
249                ]
250            return fragment, other_actions
251        elif new_is_auto and old_is_auto and old_internal_type != new_internal_type:
252            fragment, _ = super()._alter_column_type_sql(
253                model, old_field, new_field, new_type, old_collation, new_collation
254            )
255            column = strip_quotes(new_field.column)
256            db_types = {"PrimaryKeyField": "bigint"}
257            # Alter the sequence type if exists (Plain 4.1+ identity columns
258            # don't have it).
259            other_actions = []
260            if sequence_name := self._get_sequence_name(table, column):
261                other_actions = [
262                    (
263                        self.sql_alter_sequence_type
264                        % {
265                            "sequence": self.quote_name(sequence_name),
266                            "type": db_types[new_internal_type],
267                        },
268                        [],
269                    ),
270                ]
271            return fragment, other_actions
272        else:
273            return super()._alter_column_type_sql(
274                model, old_field, new_field, new_type, old_collation, new_collation
275            )
276
277    def _alter_field(
278        self,
279        model: type[Model],
280        old_field: Field,
281        new_field: Field,
282        old_type: str,
283        new_type: str,
284        old_db_params: dict[str, Any],
285        new_db_params: dict[str, Any],
286        strict: bool = False,
287    ) -> None:
288        super()._alter_field(
289            model,
290            old_field,
291            new_field,
292            old_type,
293            new_type,
294            old_db_params,
295            new_db_params,
296            strict,
297        )
298        # Added an index? Create any PostgreSQL-specific indexes.
299        if (
300            not (
301                (old_field.remote_field and old_field.db_index)  # type: ignore[attr-defined]
302                or old_field.primary_key
303            )
304            and (new_field.remote_field and new_field.db_index)  # type: ignore[attr-defined]
305        ) or (not old_field.primary_key and new_field.primary_key):
306            like_index_statement = self._create_like_index_sql(model, new_field)
307            if like_index_statement is not None:
308                self.execute(like_index_statement)
309
310        # Removed an index? Drop any PostgreSQL-specific indexes.
311        if old_field.primary_key and not (
312            (new_field.remote_field and new_field.db_index)  # type: ignore[attr-defined]
313            or new_field.primary_key
314        ):
315            index_to_remove = self._create_index_name(
316                model.model_options.db_table, [old_field.column], suffix="_like"
317            )
318            self.execute(self._delete_index_sql(model, index_to_remove))
319
320    def _index_columns(
321        self,
322        table: str,
323        columns: list[str],
324        col_suffixes: tuple[str, ...],
325        opclasses: tuple[str, ...],
326    ) -> Columns | IndexColumns:
327        if opclasses:
328            return IndexColumns(
329                table,
330                columns,
331                self.quote_name,
332                col_suffixes=col_suffixes,
333                opclasses=opclasses,
334            )
335        return super()._index_columns(table, columns, col_suffixes, opclasses)
336
337    def add_index(
338        self, model: type[Model], index: Index, concurrently: bool = False
339    ) -> None:
340        self.execute(
341            index.create_sql(model, self, concurrently=concurrently), params=None
342        )
343
344    def remove_index(
345        self, model: type[Model], index: Index, concurrently: bool = False
346    ) -> None:
347        self.execute(index.remove_sql(model, self, concurrently=concurrently))
348
349    def _delete_index_sql(
350        self,
351        model: type[Model],
352        name: str,
353        sql: str | None = None,
354        concurrently: bool = False,
355    ) -> Statement:
356        sql = (
357            self.sql_delete_index_concurrently
358            if concurrently
359            else self.sql_delete_index
360        )
361        return super()._delete_index_sql(model, name, sql)
362
363    def _create_index_sql(
364        self,
365        model: type[Model],
366        *,
367        fields: list[Field] | None = None,
368        name: str | None = None,
369        suffix: str = "",
370        using: str = "",
371        col_suffixes: tuple[str, ...] = (),
372        sql: str | None = None,
373        opclasses: tuple[str, ...] = (),
374        condition: str | None = None,
375        concurrently: bool = False,
376        include: list[str] | None = None,
377        expressions: Any = None,
378    ) -> Statement:
379        sql = sql or (
380            self.sql_create_index
381            if not concurrently
382            else self.sql_create_index_concurrently
383        )
384        return super()._create_index_sql(
385            model,
386            fields=fields,
387            name=name,
388            suffix=suffix,
389            using=using,
390            col_suffixes=col_suffixes,
391            sql=sql,
392            opclasses=opclasses,
393            condition=condition,
394            include=include,
395            expressions=expressions,
396        )