1from __future__ import annotations
  2
  3from collections.abc import Callable
  4from typing import TYPE_CHECKING, Any
  5
  6from psycopg import sql
  7
  8from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
  9from plain.models.backends.ddl_references import Columns, IndexColumns, Statement
 10from plain.models.backends.utils import strip_quotes
 11from plain.models.fields import DbParameters
 12from plain.models.fields.related import ForeignKeyField, RelatedField
 13
 14if TYPE_CHECKING:
 15    from plain.models.backends.postgresql.base import PostgreSQLDatabaseWrapper
 16    from plain.models.base import Model
 17    from plain.models.fields import Field
 18    from plain.models.indexes import Index
 19
 20
 21class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
 22    # Type checker hint: connection is always PostgreSQLDatabaseWrapper in this class
 23    connection: PostgreSQLDatabaseWrapper
 24
 25    # Setting all constraints to IMMEDIATE to allow changing data in the same
 26    # transaction.
 27    sql_update_with_default = (
 28        "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
 29        "; SET CONSTRAINTS ALL IMMEDIATE"
 30    )
 31    sql_alter_sequence_type = "ALTER SEQUENCE IF EXISTS %(sequence)s AS %(type)s"
 32    sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
 33
 34    sql_create_index = (
 35        "CREATE INDEX %(name)s ON %(table)s%(using)s "
 36        "(%(columns)s)%(include)s%(extra)s%(condition)s"
 37    )
 38    sql_create_index_concurrently = (
 39        "CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s "
 40        "(%(columns)s)%(include)s%(extra)s%(condition)s"
 41    )
 42    sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
 43    sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"
 44
 45    # Setting the constraint to IMMEDIATE to allow changing data in the same
 46    # transaction.
 47    sql_create_column_inline_fk = (
 48        "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
 49        "; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE"
 50    )
 51    # Setting the constraint to IMMEDIATE runs any deferred checks to allow
 52    # dropping it in the same transaction.
 53    sql_delete_fk = (
 54        "SET CONSTRAINTS %(name)s IMMEDIATE; "
 55        "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
 56    )
 57
 58    def execute(
 59        self, sql: str | Statement, params: tuple[Any, ...] | list[Any] | None = ()
 60    ) -> None:
 61        # Merge the query client-side, as PostgreSQL won't do it server-side.
 62        if params is None:
 63            return super().execute(sql, params)
 64        sql = self.connection.ops.compose_sql(str(sql), params)
 65        # Don't let the superclass touch anything.
 66        return super().execute(sql, None)
 67
 68    sql_add_identity = (
 69        "ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
 70        "GENERATED BY DEFAULT AS IDENTITY"
 71    )
 72    sql_drop_indentity = (
 73        "ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP IDENTITY IF EXISTS"
 74    )
 75
 76    def quote_value(self, value: Any) -> str:
 77        if isinstance(value, str):
 78            value = value.replace("%", "%%")
 79        return sql.quote(value, self.connection.connection)
 80
 81    def _field_indexes_sql(self, model: type[Model], field: Field) -> list[Statement]:
 82        output = super()._field_indexes_sql(model, field)
 83        like_index_statement = self._create_like_index_sql(model, field)
 84        if like_index_statement is not None:
 85            output.append(like_index_statement)
 86        return output
 87
 88    def _field_data_type(
 89        self, field: Field
 90    ) -> str | None | Callable[[dict[str, Any]], str]:
 91        if isinstance(field, RelatedField):
 92            return field.rel_db_type(self.connection)
 93        return self.connection.data_types.get(
 94            field.get_internal_type(),
 95            field.db_type(self.connection),
 96        )
 97
 98    def _create_like_index_sql(
 99        self, model: type[Model], field: Field
100    ) -> Statement | None:
101        """
102        Return the statement to create an index with varchar operator pattern
103        when the column type is 'varchar' or 'text', otherwise return None.
104        """
105        db_type = field.db_type(connection=self.connection)
106        if db_type is not None and field.primary_key:
107            # Fields with database column types of `varchar` and `text` need
108            # a second index that specifies their operator class, which is
109            # needed when performing correct LIKE queries outside the
110            # C locale. See #12234.
111            #
112            # The same doesn't apply to array fields such as varchar[size]
113            # and text[size], so skip them.
114            if "[" in db_type:
115                return None
116            if db_type.startswith("varchar"):
117                return self._create_index_sql(
118                    model,
119                    fields=[field],
120                    suffix="_like",
121                    opclasses=("varchar_pattern_ops",),
122                )
123            elif db_type.startswith("text"):
124                return self._create_index_sql(
125                    model,
126                    fields=[field],
127                    suffix="_like",
128                    opclasses=("text_pattern_ops",),
129                )
130        return None
131
132    def _using_sql(self, new_field: Field, old_field: Field) -> str:
133        using_sql = " USING %(column)s::%(type)s"
134        if self._field_data_type(old_field) != self._field_data_type(new_field):
135            return using_sql
136        return ""
137
138    def _get_sequence_name(self, table: str, column: str) -> str | None:
139        with self.connection.cursor() as cursor:
140            for sequence in self.connection.introspection.get_sequences(cursor, table):
141                if sequence["column"] == column:
142                    return sequence["name"]
143        return None
144
145    def _alter_column_type_sql(
146        self,
147        model: type[Model],
148        old_field: Field,
149        new_field: Field,
150        new_type: str,
151    ) -> tuple[tuple[str, list[Any]], list[tuple[str, list[Any]]]]:
152        # Drop indexes on varchar/text/citext columns that are changing to a
153        # different type.
154        old_db_params = old_field.db_parameters(connection=self.connection)
155        old_type = old_db_params["type"]
156        assert old_type is not None, "old_type cannot be None for primary key field"
157        if old_field.primary_key and (
158            (old_type.startswith("varchar") and not new_type.startswith("varchar"))
159            or (old_type.startswith("text") and not new_type.startswith("text"))
160            or (old_type.startswith("citext") and not new_type.startswith("citext"))
161        ):
162            index_name = self._create_index_name(
163                model.model_options.db_table, [old_field.column], suffix="_like"
164            )
165            self.execute(self._delete_index_sql(model, index_name))
166
167        self.sql_alter_column_type = "ALTER COLUMN %(column)s TYPE %(type)s"
168        # Cast when data type changed.
169        if using_sql := self._using_sql(new_field, old_field):
170            self.sql_alter_column_type += using_sql
171        new_internal_type = new_field.get_internal_type()
172        old_internal_type = old_field.get_internal_type()
173        # Make ALTER TYPE with IDENTITY make sense.
174        table = strip_quotes(model.model_options.db_table)
175        auto_field_types = {"PrimaryKeyField"}
176        old_is_auto = old_internal_type in auto_field_types
177        new_is_auto = new_internal_type in auto_field_types
178        if new_is_auto and not old_is_auto:
179            column = strip_quotes(new_field.column)
180            return (
181                (
182                    self.sql_alter_column_type
183                    % {
184                        "column": self.quote_name(column),
185                        "type": new_type,
186                    },
187                    [],
188                ),
189                [
190                    (
191                        self.sql_add_identity
192                        % {
193                            "table": self.quote_name(table),
194                            "column": self.quote_name(column),
195                        },
196                        [],
197                    ),
198                ],
199            )
200        elif old_is_auto and not new_is_auto:
201            # Drop IDENTITY if exists (pre-Plain 4.1 serial columns don't have
202            # it).
203            self.execute(
204                self.sql_drop_indentity
205                % {
206                    "table": self.quote_name(table),
207                    "column": self.quote_name(strip_quotes(new_field.column)),
208                }
209            )
210            column = strip_quotes(new_field.column)
211            fragment, _ = super()._alter_column_type_sql(
212                model, old_field, new_field, new_type
213            )
214            # Drop the sequence if exists (Plain 4.1+ identity columns don't
215            # have it).
216            other_actions = []
217            if sequence_name := self._get_sequence_name(table, column):
218                other_actions = [
219                    (
220                        self.sql_delete_sequence
221                        % {
222                            "sequence": self.quote_name(sequence_name),
223                        },
224                        [],
225                    )
226                ]
227            return fragment, other_actions
228        elif new_is_auto and old_is_auto and old_internal_type != new_internal_type:
229            fragment, _ = super()._alter_column_type_sql(
230                model, old_field, new_field, new_type
231            )
232            column = strip_quotes(new_field.column)
233            db_types = {"PrimaryKeyField": "bigint"}
234            # Alter the sequence type if exists (Plain 4.1+ identity columns
235            # don't have it).
236            other_actions = []
237            if sequence_name := self._get_sequence_name(table, column):
238                other_actions = [
239                    (
240                        self.sql_alter_sequence_type
241                        % {
242                            "sequence": self.quote_name(sequence_name),
243                            "type": db_types[new_internal_type],
244                        },
245                        [],
246                    ),
247                ]
248            return fragment, other_actions
249        else:
250            return super()._alter_column_type_sql(model, old_field, new_field, new_type)
251
252    def _alter_field(
253        self,
254        model: type[Model],
255        old_field: Field,
256        new_field: Field,
257        old_type: str,
258        new_type: str,
259        old_db_params: DbParameters,
260        new_db_params: DbParameters,
261        strict: bool = False,
262    ) -> None:
263        super()._alter_field(
264            model,
265            old_field,
266            new_field,
267            old_type,
268            new_type,
269            old_db_params,
270            new_db_params,
271            strict,
272        )
273        # Added an index? Create any PostgreSQL-specific indexes.
274        if (
275            not (
276                (isinstance(old_field, ForeignKeyField) and old_field.db_index)
277                or old_field.primary_key
278            )
279            and isinstance(new_field, ForeignKeyField)
280            and new_field.db_index
281        ) or (not old_field.primary_key and new_field.primary_key):
282            like_index_statement = self._create_like_index_sql(model, new_field)
283            if like_index_statement is not None:
284                self.execute(like_index_statement)
285
286        # Removed an index? Drop any PostgreSQL-specific indexes.
287        if old_field.primary_key and not (
288            (isinstance(new_field, ForeignKeyField) and new_field.db_index)
289            or new_field.primary_key
290        ):
291            index_to_remove = self._create_index_name(
292                model.model_options.db_table, [old_field.column], suffix="_like"
293            )
294            self.execute(self._delete_index_sql(model, index_to_remove))
295
296    def _index_columns(
297        self,
298        table: str,
299        columns: list[str],
300        col_suffixes: tuple[str, ...],
301        opclasses: tuple[str, ...],
302    ) -> Columns | IndexColumns:
303        if opclasses:
304            return IndexColumns(
305                table,
306                columns,
307                self.quote_name,
308                col_suffixes=col_suffixes,
309                opclasses=opclasses,
310            )
311        return super()._index_columns(table, columns, col_suffixes, opclasses)
312
313    def add_index(
314        self, model: type[Model], index: Index, concurrently: bool = False
315    ) -> None:
316        self.execute(
317            index.create_sql(model, self, concurrently=concurrently), params=None
318        )
319
320    def remove_index(
321        self, model: type[Model], index: Index, concurrently: bool = False
322    ) -> None:
323        self.execute(index.remove_sql(model, self, concurrently=concurrently))
324
325    def _delete_index_sql(
326        self,
327        model: type[Model],
328        name: str,
329        sql: str | None = None,
330        concurrently: bool = False,
331    ) -> Statement:
332        sql = (
333            self.sql_delete_index_concurrently
334            if concurrently
335            else self.sql_delete_index
336        )
337        return super()._delete_index_sql(model, name, sql)
338
339    def _create_index_sql(
340        self,
341        model: type[Model],
342        *,
343        fields: list[Field] | None = None,
344        name: str | None = None,
345        suffix: str = "",
346        using: str = "",
347        col_suffixes: tuple[str, ...] = (),
348        sql: str | None = None,
349        opclasses: tuple[str, ...] = (),
350        condition: str | None = None,
351        concurrently: bool = False,
352        include: list[str] | None = None,
353        expressions: Any = None,
354    ) -> Statement:
355        sql = sql or (
356            self.sql_create_index
357            if not concurrently
358            else self.sql_create_index_concurrently
359        )
360        return super()._create_index_sql(
361            model,
362            fields=fields,
363            name=name,
364            suffix=suffix,
365            using=using,
366            col_suffixes=col_suffixes,
367            sql=sql,
368            opclasses=opclasses,
369            condition=condition,
370            include=include,
371            expressions=expressions,
372        )