Plain is headed towards 1.0! Subscribe for development updates →

  1from psycopg import sql
  2
  3from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
  4from plain.models.backends.ddl_references import IndexColumns
  5from plain.models.backends.utils import strip_quotes
  6
  7
  8class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
  9    # Setting all constraints to IMMEDIATE to allow changing data in the same
 10    # transaction.
 11    sql_update_with_default = (
 12        "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
 13        "; SET CONSTRAINTS ALL IMMEDIATE"
 14    )
 15    sql_alter_sequence_type = "ALTER SEQUENCE IF EXISTS %(sequence)s AS %(type)s"
 16    sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
 17
 18    sql_create_index = (
 19        "CREATE INDEX %(name)s ON %(table)s%(using)s "
 20        "(%(columns)s)%(include)s%(extra)s%(condition)s"
 21    )
 22    sql_create_index_concurrently = (
 23        "CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s "
 24        "(%(columns)s)%(include)s%(extra)s%(condition)s"
 25    )
 26    sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
 27    sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"
 28
 29    # Setting the constraint to IMMEDIATE to allow changing data in the same
 30    # transaction.
 31    sql_create_column_inline_fk = (
 32        "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
 33        "; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE"
 34    )
 35    # Setting the constraint to IMMEDIATE runs any deferred checks to allow
 36    # dropping it in the same transaction.
 37    sql_delete_fk = (
 38        "SET CONSTRAINTS %(name)s IMMEDIATE; "
 39        "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
 40    )
 41
 42    def execute(self, sql, params=()):
 43        # Merge the query client-side, as PostgreSQL won't do it server-side.
 44        if params is None:
 45            return super().execute(sql, params)
 46        sql = self.connection.ops.compose_sql(str(sql), params)
 47        # Don't let the superclass touch anything.
 48        return super().execute(sql, None)
 49
 50    sql_add_identity = (
 51        "ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
 52        "GENERATED BY DEFAULT AS IDENTITY"
 53    )
 54    sql_drop_indentity = (
 55        "ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP IDENTITY IF EXISTS"
 56    )
 57
 58    def quote_value(self, value):
 59        if isinstance(value, str):
 60            value = value.replace("%", "%%")
 61        return sql.quote(value, self.connection.connection)
 62
 63    def _field_indexes_sql(self, model, field):
 64        output = super()._field_indexes_sql(model, field)
 65        like_index_statement = self._create_like_index_sql(model, field)
 66        if like_index_statement is not None:
 67            output.append(like_index_statement)
 68        return output
 69
 70    def _field_data_type(self, field):
 71        if field.is_relation:
 72            return field.rel_db_type(self.connection)
 73        return self.connection.data_types.get(
 74            field.get_internal_type(),
 75            field.db_type(self.connection),
 76        )
 77
 78    def _field_base_data_types(self, field):
 79        # Yield base data types for array fields.
 80        if field.base_field.get_internal_type() == "ArrayField":
 81            yield from self._field_base_data_types(field.base_field)
 82        else:
 83            yield self._field_data_type(field.base_field)
 84
 85    def _create_like_index_sql(self, model, field):
 86        """
 87        Return the statement to create an index with varchar operator pattern
 88        when the column type is 'varchar' or 'text', otherwise return None.
 89        """
 90        db_type = field.db_type(connection=self.connection)
 91        if db_type is not None and field.primary_key:
 92            # Fields with database column types of `varchar` and `text` need
 93            # a second index that specifies their operator class, which is
 94            # needed when performing correct LIKE queries outside the
 95            # C locale. See #12234.
 96            #
 97            # The same doesn't apply to array fields such as varchar[size]
 98            # and text[size], so skip them.
 99            if "[" in db_type:
100                return None
101            # Non-deterministic collations on Postgresql don't support indexes
102            # for operator classes varchar_pattern_ops/text_pattern_ops.
103            if getattr(field, "db_collation", None):
104                return None
105            if db_type.startswith("varchar"):
106                return self._create_index_sql(
107                    model,
108                    fields=[field],
109                    suffix="_like",
110                    opclasses=["varchar_pattern_ops"],
111                )
112            elif db_type.startswith("text"):
113                return self._create_index_sql(
114                    model,
115                    fields=[field],
116                    suffix="_like",
117                    opclasses=["text_pattern_ops"],
118                )
119        return None
120
121    def _using_sql(self, new_field, old_field):
122        using_sql = " USING %(column)s::%(type)s"
123        new_internal_type = new_field.get_internal_type()
124        old_internal_type = old_field.get_internal_type()
125        if new_internal_type == "ArrayField" and new_internal_type == old_internal_type:
126            # Compare base data types for array fields.
127            if list(self._field_base_data_types(old_field)) != list(
128                self._field_base_data_types(new_field)
129            ):
130                return using_sql
131        elif self._field_data_type(old_field) != self._field_data_type(new_field):
132            return using_sql
133        return ""
134
135    def _get_sequence_name(self, table, column):
136        with self.connection.cursor() as cursor:
137            for sequence in self.connection.introspection.get_sequences(cursor, table):
138                if sequence["column"] == column:
139                    return sequence["name"]
140        return None
141
142    def _alter_column_type_sql(
143        self, model, old_field, new_field, new_type, old_collation, new_collation
144    ):
145        # Drop indexes on varchar/text/citext columns that are changing to a
146        # different type.
147        old_db_params = old_field.db_parameters(connection=self.connection)
148        old_type = old_db_params["type"]
149        if old_field.primary_key and (
150            (old_type.startswith("varchar") and not new_type.startswith("varchar"))
151            or (old_type.startswith("text") and not new_type.startswith("text"))
152            or (old_type.startswith("citext") and not new_type.startswith("citext"))
153        ):
154            index_name = self._create_index_name(
155                model._meta.db_table, [old_field.column], suffix="_like"
156            )
157            self.execute(self._delete_index_sql(model, index_name))
158
159        self.sql_alter_column_type = (
160            "ALTER COLUMN %(column)s TYPE %(type)s%(collation)s"
161        )
162        # Cast when data type changed.
163        if using_sql := self._using_sql(new_field, old_field):
164            self.sql_alter_column_type += using_sql
165        new_internal_type = new_field.get_internal_type()
166        old_internal_type = old_field.get_internal_type()
167        # Make ALTER TYPE with IDENTITY make sense.
168        table = strip_quotes(model._meta.db_table)
169        auto_field_types = {"PrimaryKeyField"}
170        old_is_auto = old_internal_type in auto_field_types
171        new_is_auto = new_internal_type in auto_field_types
172        if new_is_auto and not old_is_auto:
173            column = strip_quotes(new_field.column)
174            return (
175                (
176                    self.sql_alter_column_type
177                    % {
178                        "column": self.quote_name(column),
179                        "type": new_type,
180                        "collation": "",
181                    },
182                    [],
183                ),
184                [
185                    (
186                        self.sql_add_identity
187                        % {
188                            "table": self.quote_name(table),
189                            "column": self.quote_name(column),
190                        },
191                        [],
192                    ),
193                ],
194            )
195        elif old_is_auto and not new_is_auto:
196            # Drop IDENTITY if exists (pre-Plain 4.1 serial columns don't have
197            # it).
198            self.execute(
199                self.sql_drop_indentity
200                % {
201                    "table": self.quote_name(table),
202                    "column": self.quote_name(strip_quotes(new_field.column)),
203                }
204            )
205            column = strip_quotes(new_field.column)
206            fragment, _ = super()._alter_column_type_sql(
207                model, old_field, new_field, new_type, old_collation, new_collation
208            )
209            # Drop the sequence if exists (Plain 4.1+ identity columns don't
210            # have it).
211            other_actions = []
212            if sequence_name := self._get_sequence_name(table, column):
213                other_actions = [
214                    (
215                        self.sql_delete_sequence
216                        % {
217                            "sequence": self.quote_name(sequence_name),
218                        },
219                        [],
220                    )
221                ]
222            return fragment, other_actions
223        elif new_is_auto and old_is_auto and old_internal_type != new_internal_type:
224            fragment, _ = super()._alter_column_type_sql(
225                model, old_field, new_field, new_type, old_collation, new_collation
226            )
227            column = strip_quotes(new_field.column)
228            db_types = {"PrimaryKeyField": "bigint"}
229            # Alter the sequence type if exists (Plain 4.1+ identity columns
230            # don't have it).
231            other_actions = []
232            if sequence_name := self._get_sequence_name(table, column):
233                other_actions = [
234                    (
235                        self.sql_alter_sequence_type
236                        % {
237                            "sequence": self.quote_name(sequence_name),
238                            "type": db_types[new_internal_type],
239                        },
240                        [],
241                    ),
242                ]
243            return fragment, other_actions
244        else:
245            return super()._alter_column_type_sql(
246                model, old_field, new_field, new_type, old_collation, new_collation
247            )
248
249    def _alter_field(
250        self,
251        model,
252        old_field,
253        new_field,
254        old_type,
255        new_type,
256        old_db_params,
257        new_db_params,
258        strict=False,
259    ):
260        super()._alter_field(
261            model,
262            old_field,
263            new_field,
264            old_type,
265            new_type,
266            old_db_params,
267            new_db_params,
268            strict,
269        )
270        # Added an index? Create any PostgreSQL-specific indexes.
271        if (
272            not (
273                (old_field.remote_field and old_field.db_index) or old_field.primary_key
274            )
275            and (new_field.remote_field and new_field.db_index)
276        ) or (not old_field.primary_key and new_field.primary_key):
277            like_index_statement = self._create_like_index_sql(model, new_field)
278            if like_index_statement is not None:
279                self.execute(like_index_statement)
280
281        # Removed an index? Drop any PostgreSQL-specific indexes.
282        if old_field.primary_key and not (
283            (new_field.remote_field and new_field.db_index) or new_field.primary_key
284        ):
285            index_to_remove = self._create_index_name(
286                model._meta.db_table, [old_field.column], suffix="_like"
287            )
288            self.execute(self._delete_index_sql(model, index_to_remove))
289
290    def _index_columns(self, table, columns, col_suffixes, opclasses):
291        if opclasses:
292            return IndexColumns(
293                table,
294                columns,
295                self.quote_name,
296                col_suffixes=col_suffixes,
297                opclasses=opclasses,
298            )
299        return super()._index_columns(table, columns, col_suffixes, opclasses)
300
301    def add_index(self, model, index, concurrently=False):
302        self.execute(
303            index.create_sql(model, self, concurrently=concurrently), params=None
304        )
305
306    def remove_index(self, model, index, concurrently=False):
307        self.execute(index.remove_sql(model, self, concurrently=concurrently))
308
309    def _delete_index_sql(self, model, name, sql=None, concurrently=False):
310        sql = (
311            self.sql_delete_index_concurrently
312            if concurrently
313            else self.sql_delete_index
314        )
315        return super()._delete_index_sql(model, name, sql)
316
317    def _create_index_sql(
318        self,
319        model,
320        *,
321        fields=None,
322        name=None,
323        suffix="",
324        using="",
325        col_suffixes=(),
326        sql=None,
327        opclasses=(),
328        condition=None,
329        concurrently=False,
330        include=None,
331        expressions=None,
332    ):
333        sql = sql or (
334            self.sql_create_index
335            if not concurrently
336            else self.sql_create_index_concurrently
337        )
338        return super()._create_index_sql(
339            model,
340            fields=fields,
341            name=name,
342            suffix=suffix,
343            using=using,
344            col_suffixes=col_suffixes,
345            sql=sql,
346            opclasses=opclasses,
347            condition=condition,
348            include=include,
349            expressions=expressions,
350        )