Plain is headed towards 1.0! Subscribe for development updates →

  1from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
  2from plain.models.backends.ddl_references import IndexColumns
  3from plain.models.backends.postgresql.psycopg_any import sql
  4from plain.models.backends.utils import strip_quotes
  5
  6
  7class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
  8    # Setting all constraints to IMMEDIATE to allow changing data in the same
  9    # transaction.
 10    sql_update_with_default = (
 11        "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
 12        "; SET CONSTRAINTS ALL IMMEDIATE"
 13    )
 14    sql_alter_sequence_type = "ALTER SEQUENCE IF EXISTS %(sequence)s AS %(type)s"
 15    sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
 16
 17    sql_create_index = (
 18        "CREATE INDEX %(name)s ON %(table)s%(using)s "
 19        "(%(columns)s)%(include)s%(extra)s%(condition)s"
 20    )
 21    sql_create_index_concurrently = (
 22        "CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s "
 23        "(%(columns)s)%(include)s%(extra)s%(condition)s"
 24    )
 25    sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
 26    sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"
 27
 28    # Setting the constraint to IMMEDIATE to allow changing data in the same
 29    # transaction.
 30    sql_create_column_inline_fk = (
 31        "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
 32        "; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE"
 33    )
 34    # Setting the constraint to IMMEDIATE runs any deferred checks to allow
 35    # dropping it in the same transaction.
 36    sql_delete_fk = (
 37        "SET CONSTRAINTS %(name)s IMMEDIATE; "
 38        "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
 39    )
 40    sql_delete_procedure = "DROP FUNCTION %(procedure)s(%(param_types)s)"
 41
 42    def execute(self, sql, params=()):
 43        # Merge the query client-side, as PostgreSQL won't do it server-side.
 44        if params is None:
 45            return super().execute(sql, params)
 46        sql = self.connection.ops.compose_sql(str(sql), params)
 47        # Don't let the superclass touch anything.
 48        return super().execute(sql, None)
 49
 50    sql_add_identity = (
 51        "ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
 52        "GENERATED BY DEFAULT AS IDENTITY"
 53    )
 54    sql_drop_indentity = (
 55        "ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP IDENTITY IF EXISTS"
 56    )
 57
 58    def quote_value(self, value):
 59        if isinstance(value, str):
 60            value = value.replace("%", "%%")
 61        return sql.quote(value, self.connection.connection)
 62
 63    def _field_indexes_sql(self, model, field):
 64        output = super()._field_indexes_sql(model, field)
 65        like_index_statement = self._create_like_index_sql(model, field)
 66        if like_index_statement is not None:
 67            output.append(like_index_statement)
 68        return output
 69
 70    def _field_data_type(self, field):
 71        if field.is_relation:
 72            return field.rel_db_type(self.connection)
 73        return self.connection.data_types.get(
 74            field.get_internal_type(),
 75            field.db_type(self.connection),
 76        )
 77
 78    def _field_base_data_types(self, field):
 79        # Yield base data types for array fields.
 80        if field.base_field.get_internal_type() == "ArrayField":
 81            yield from self._field_base_data_types(field.base_field)
 82        else:
 83            yield self._field_data_type(field.base_field)
 84
 85    def _create_like_index_sql(self, model, field):
 86        """
 87        Return the statement to create an index with varchar operator pattern
 88        when the column type is 'varchar' or 'text', otherwise return None.
 89        """
 90        db_type = field.db_type(connection=self.connection)
 91        if db_type is not None and (field.db_index or field.unique):
 92            # Fields with database column types of `varchar` and `text` need
 93            # a second index that specifies their operator class, which is
 94            # needed when performing correct LIKE queries outside the
 95            # C locale. See #12234.
 96            #
 97            # The same doesn't apply to array fields such as varchar[size]
 98            # and text[size], so skip them.
 99            if "[" in db_type:
100                return None
101            # Non-deterministic collations on Postgresql don't support indexes
102            # for operator classes varchar_pattern_ops/text_pattern_ops.
103            if getattr(field, "db_collation", None):
104                return None
105            if db_type.startswith("varchar"):
106                return self._create_index_sql(
107                    model,
108                    fields=[field],
109                    suffix="_like",
110                    opclasses=["varchar_pattern_ops"],
111                )
112            elif db_type.startswith("text"):
113                return self._create_index_sql(
114                    model,
115                    fields=[field],
116                    suffix="_like",
117                    opclasses=["text_pattern_ops"],
118                )
119        return None
120
121    def _using_sql(self, new_field, old_field):
122        using_sql = " USING %(column)s::%(type)s"
123        new_internal_type = new_field.get_internal_type()
124        old_internal_type = old_field.get_internal_type()
125        if new_internal_type == "ArrayField" and new_internal_type == old_internal_type:
126            # Compare base data types for array fields.
127            if list(self._field_base_data_types(old_field)) != list(
128                self._field_base_data_types(new_field)
129            ):
130                return using_sql
131        elif self._field_data_type(old_field) != self._field_data_type(new_field):
132            return using_sql
133        return ""
134
135    def _get_sequence_name(self, table, column):
136        with self.connection.cursor() as cursor:
137            for sequence in self.connection.introspection.get_sequences(cursor, table):
138                if sequence["column"] == column:
139                    return sequence["name"]
140        return None
141
142    def _alter_column_type_sql(
143        self, model, old_field, new_field, new_type, old_collation, new_collation
144    ):
145        # Drop indexes on varchar/text/citext columns that are changing to a
146        # different type.
147        old_db_params = old_field.db_parameters(connection=self.connection)
148        old_type = old_db_params["type"]
149        if (old_field.db_index or old_field.unique) and (
150            (old_type.startswith("varchar") and not new_type.startswith("varchar"))
151            or (old_type.startswith("text") and not new_type.startswith("text"))
152            or (old_type.startswith("citext") and not new_type.startswith("citext"))
153        ):
154            index_name = self._create_index_name(
155                model._meta.db_table, [old_field.column], suffix="_like"
156            )
157            self.execute(self._delete_index_sql(model, index_name))
158
159        self.sql_alter_column_type = (
160            "ALTER COLUMN %(column)s TYPE %(type)s%(collation)s"
161        )
162        # Cast when data type changed.
163        if using_sql := self._using_sql(new_field, old_field):
164            self.sql_alter_column_type += using_sql
165        new_internal_type = new_field.get_internal_type()
166        old_internal_type = old_field.get_internal_type()
167        # Make ALTER TYPE with IDENTITY make sense.
168        table = strip_quotes(model._meta.db_table)
169        auto_field_types = {
170            "AutoField",
171            "BigAutoField",
172            "SmallAutoField",
173        }
174        old_is_auto = old_internal_type in auto_field_types
175        new_is_auto = new_internal_type in auto_field_types
176        if new_is_auto and not old_is_auto:
177            column = strip_quotes(new_field.column)
178            return (
179                (
180                    self.sql_alter_column_type
181                    % {
182                        "column": self.quote_name(column),
183                        "type": new_type,
184                        "collation": "",
185                    },
186                    [],
187                ),
188                [
189                    (
190                        self.sql_add_identity
191                        % {
192                            "table": self.quote_name(table),
193                            "column": self.quote_name(column),
194                        },
195                        [],
196                    ),
197                ],
198            )
199        elif old_is_auto and not new_is_auto:
200            # Drop IDENTITY if exists (pre-Plain 4.1 serial columns don't have
201            # it).
202            self.execute(
203                self.sql_drop_indentity
204                % {
205                    "table": self.quote_name(table),
206                    "column": self.quote_name(strip_quotes(new_field.column)),
207                }
208            )
209            column = strip_quotes(new_field.column)
210            fragment, _ = super()._alter_column_type_sql(
211                model, old_field, new_field, new_type, old_collation, new_collation
212            )
213            # Drop the sequence if exists (Plain 4.1+ identity columns don't
214            # have it).
215            other_actions = []
216            if sequence_name := self._get_sequence_name(table, column):
217                other_actions = [
218                    (
219                        self.sql_delete_sequence
220                        % {
221                            "sequence": self.quote_name(sequence_name),
222                        },
223                        [],
224                    )
225                ]
226            return fragment, other_actions
227        elif new_is_auto and old_is_auto and old_internal_type != new_internal_type:
228            fragment, _ = super()._alter_column_type_sql(
229                model, old_field, new_field, new_type, old_collation, new_collation
230            )
231            column = strip_quotes(new_field.column)
232            db_types = {
233                "AutoField": "integer",
234                "BigAutoField": "bigint",
235                "SmallAutoField": "smallint",
236            }
237            # Alter the sequence type if exists (Plain 4.1+ identity columns
238            # don't have it).
239            other_actions = []
240            if sequence_name := self._get_sequence_name(table, column):
241                other_actions = [
242                    (
243                        self.sql_alter_sequence_type
244                        % {
245                            "sequence": self.quote_name(sequence_name),
246                            "type": db_types[new_internal_type],
247                        },
248                        [],
249                    ),
250                ]
251            return fragment, other_actions
252        else:
253            return super()._alter_column_type_sql(
254                model, old_field, new_field, new_type, old_collation, new_collation
255            )
256
257    def _alter_column_collation_sql(
258        self, model, new_field, new_type, new_collation, old_field
259    ):
260        sql = self.sql_alter_column_collate
261        # Cast when data type changed.
262        if using_sql := self._using_sql(new_field, old_field):
263            sql += using_sql
264        return (
265            sql
266            % {
267                "column": self.quote_name(new_field.column),
268                "type": new_type,
269                "collation": " " + self._collate_sql(new_collation)
270                if new_collation
271                else "",
272            },
273            [],
274        )
275
276    def _alter_field(
277        self,
278        model,
279        old_field,
280        new_field,
281        old_type,
282        new_type,
283        old_db_params,
284        new_db_params,
285        strict=False,
286    ):
287        super()._alter_field(
288            model,
289            old_field,
290            new_field,
291            old_type,
292            new_type,
293            old_db_params,
294            new_db_params,
295            strict,
296        )
297        # Added an index? Create any PostgreSQL-specific indexes.
298        if (not (old_field.db_index or old_field.unique) and new_field.db_index) or (
299            not old_field.unique and new_field.unique
300        ):
301            like_index_statement = self._create_like_index_sql(model, new_field)
302            if like_index_statement is not None:
303                self.execute(like_index_statement)
304
305        # Removed an index? Drop any PostgreSQL-specific indexes.
306        if old_field.unique and not (new_field.db_index or new_field.unique):
307            index_to_remove = self._create_index_name(
308                model._meta.db_table, [old_field.column], suffix="_like"
309            )
310            self.execute(self._delete_index_sql(model, index_to_remove))
311
312    def _index_columns(self, table, columns, col_suffixes, opclasses):
313        if opclasses:
314            return IndexColumns(
315                table,
316                columns,
317                self.quote_name,
318                col_suffixes=col_suffixes,
319                opclasses=opclasses,
320            )
321        return super()._index_columns(table, columns, col_suffixes, opclasses)
322
323    def add_index(self, model, index, concurrently=False):
324        self.execute(
325            index.create_sql(model, self, concurrently=concurrently), params=None
326        )
327
328    def remove_index(self, model, index, concurrently=False):
329        self.execute(index.remove_sql(model, self, concurrently=concurrently))
330
331    def _delete_index_sql(self, model, name, sql=None, concurrently=False):
332        sql = (
333            self.sql_delete_index_concurrently
334            if concurrently
335            else self.sql_delete_index
336        )
337        return super()._delete_index_sql(model, name, sql)
338
339    def _create_index_sql(
340        self,
341        model,
342        *,
343        fields=None,
344        name=None,
345        suffix="",
346        using="",
347        db_tablespace=None,
348        col_suffixes=(),
349        sql=None,
350        opclasses=(),
351        condition=None,
352        concurrently=False,
353        include=None,
354        expressions=None,
355    ):
356        sql = sql or (
357            self.sql_create_index
358            if not concurrently
359            else self.sql_create_index_concurrently
360        )
361        return super()._create_index_sql(
362            model,
363            fields=fields,
364            name=name,
365            suffix=suffix,
366            using=using,
367            db_tablespace=db_tablespace,
368            col_suffixes=col_suffixes,
369            sql=sql,
370            opclasses=opclasses,
371            condition=condition,
372            include=include,
373            expressions=expressions,
374        )