1from __future__ import annotations
2
3from collections.abc import Callable
4from typing import TYPE_CHECKING, Any
5
6from psycopg import sql
7
8from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
9from plain.models.backends.ddl_references import Columns, IndexColumns, Statement
10from plain.models.backends.utils import strip_quotes
11from plain.models.fields.related import ForeignKeyField, RelatedField
12
13if TYPE_CHECKING:
14 from plain.models.backends.postgresql.base import PostgreSQLDatabaseWrapper
15 from plain.models.base import Model
16 from plain.models.fields import Field
17 from plain.models.indexes import Index
18
19
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
    """Schema editor for the PostgreSQL backend.

    Specialises ``BaseDatabaseSchemaEditor`` with PostgreSQL DDL templates and
    hooks:

    * parameters are merged into statements client-side before execution;
    * varchar/text primary keys get an extra "_like" index using the
      ``*_pattern_ops`` operator classes;
    * changing a column to/from ``PrimaryKeyField`` adds/drops an identity and
      drops or retypes any backing sequence;
    * indexes can be created and dropped ``CONCURRENTLY``.
    """

    # Type checker hint: connection is always PostgreSQLDatabaseWrapper in this class
    connection: PostgreSQLDatabaseWrapper

    # Setting all constraints to IMMEDIATE to allow changing data in the same
    # transaction.
    sql_update_with_default = (
        "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
        "; SET CONSTRAINTS ALL IMMEDIATE"
    )
    sql_alter_sequence_type = "ALTER SEQUENCE IF EXISTS %(sequence)s AS %(type)s"
    sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"

    # Plain and CONCURRENTLY variants of index DDL; the concurrent ones are
    # selected by the ``concurrently`` flag of _create_index_sql() and
    # _delete_index_sql().
    sql_create_index = (
        "CREATE INDEX %(name)s ON %(table)s%(using)s "
        "(%(columns)s)%(include)s%(extra)s%(condition)s"
    )
    sql_create_index_concurrently = (
        "CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s "
        "(%(columns)s)%(include)s%(extra)s%(condition)s"
    )
    sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
    sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"

    # Setting the constraint to IMMEDIATE to allow changing data in the same
    # transaction.
    sql_create_column_inline_fk = (
        "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
        "; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE"
    )
    # Setting the constraint to IMMEDIATE runs any deferred checks to allow
    # dropping it in the same transaction.
    sql_delete_fk = (
        "SET CONSTRAINTS %(name)s IMMEDIATE; "
        "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
    )

    # Identity management used by _alter_column_type_sql() when a column
    # becomes / stops being a PrimaryKeyField. (``sql_drop_indentity`` is
    # misspelled, but it is the attribute's existing public name and may be
    # referenced elsewhere, so it is kept as-is.)
    sql_add_identity = (
        "ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
        "GENERATED BY DEFAULT AS IDENTITY"
    )
    sql_drop_indentity = (
        "ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP IDENTITY IF EXISTS"
    )

    def execute(
        self, sql: str | Statement, params: tuple[Any, ...] | list[Any] | None = ()
    ) -> None:
        """Execute *sql*, interpolating *params* on the client.

        When *params* is ``None`` the statement is handed to the base class
        unchanged. Otherwise the parameters are merged into the SQL text with
        ``connection.ops.compose_sql`` and the result is executed with
        ``params=None`` so the base class performs no further interpolation.
        """
        # Merge the query client-side, as PostgreSQL won't do it server-side.
        if params is None:
            return super().execute(sql, params)
        sql = self.connection.ops.compose_sql(str(sql), params)
        # Don't let the superclass touch anything.
        return super().execute(sql, None)

    def quote_value(self, value: Any) -> str:
        """Return *value* rendered as a PostgreSQL literal via ``sql.quote``.

        ``%`` characters in string values are doubled first — presumably so
        the literal survives a later ``%``-style formatting pass over the
        statement it is embedded in.
        """
        if isinstance(value, str):
            value = value.replace("%", "%%")
        return sql.quote(value, self.connection.connection)

    def _field_indexes_sql(self, model: type[Model], field: Field) -> list[Statement]:
        """Return the base field indexes plus the "_like" index, if any."""
        output = super()._field_indexes_sql(model, field)
        like_index_statement = self._create_like_index_sql(model, field)
        if like_index_statement is not None:
            output.append(like_index_statement)
        return output

    def _field_data_type(
        self, field: Field
    ) -> str | None | Callable[[dict[str, Any]], str]:
        """Return the data type used to compare old and new field types.

        Related fields use ``rel_db_type``; other fields look up their internal
        type in ``connection.data_types`` and fall back to ``field.db_type``.
        """
        if isinstance(field, RelatedField):
            return field.rel_db_type(self.connection)
        return self.connection.data_types.get(
            field.get_internal_type(),
            field.db_type(self.connection),
        )

    def _create_like_index_sql(
        self, model: type[Model], field: Field
    ) -> Statement | None:
        """
        Return the statement to create an index with varchar operator pattern
        when the column type is 'varchar' or 'text', otherwise return None.

        Only primary-key fields get such an index. Array column types
        (e.g. ``varchar[]``) and fields with a ``db_collation`` are skipped.
        """
        db_type = field.db_type(connection=self.connection)
        if db_type is None or not field.primary_key:
            return None
        # Fields with database column types of `varchar` and `text` need
        # a second index that specifies their operator class, which is
        # needed when performing correct LIKE queries outside the
        # C locale. See #12234.
        #
        # The same doesn't apply to array fields such as varchar[size]
        # and text[size], so skip them.
        if "[" in db_type:
            return None
        # Non-deterministic collations on PostgreSQL don't support indexes
        # for operator classes varchar_pattern_ops/text_pattern_ops.
        if getattr(field, "db_collation", None):
            return None
        if db_type.startswith("varchar"):
            opclass = "varchar_pattern_ops"
        elif db_type.startswith("text"):
            opclass = "text_pattern_ops"
        else:
            return None
        return self._create_index_sql(
            model,
            fields=[field],
            suffix="_like",
            opclasses=(opclass,),
        )

    def _using_sql(self, new_field: Field, old_field: Field) -> str:
        """Return a ``USING`` cast clause if the field's data type changes.

        The returned clause still contains ``%(column)s``/``%(type)s``
        placeholders; it is appended to ``sql_alter_column_type`` and
        formatted together with it. Returns ``""`` when no cast is needed.
        """
        if self._field_data_type(old_field) == self._field_data_type(new_field):
            return ""
        return " USING %(column)s::%(type)s"

    def _get_sequence_name(self, table: str, column: str) -> str | None:
        """Return the name of the sequence introspected for *table*.*column*.

        Returns ``None`` when introspection reports no sequence for the column.
        """
        with self.connection.cursor() as cursor:
            for sequence in self.connection.introspection.get_sequences(cursor, table):
                if sequence["column"] == column:
                    return sequence["name"]
        return None

    def _alter_column_type_sql(
        self,
        model: type[Model],
        old_field: Field,
        new_field: Field,
        new_type: str,
        old_collation: str | None,
        new_collation: str | None,
    ) -> tuple[tuple[str, list[Any]], list[tuple[str, list[Any]]]]:
        """Return ``(fragment, other_actions)`` for an ALTER COLUMN TYPE.

        Each element is a ``(sql, params)`` pair. On top of the base behavior:

        * a varchar/text/citext primary key whose type family changes has its
          "_like" index dropped (executed immediately);
        * a ``USING`` cast is appended when the data type changes;
        * a column becoming a ``PrimaryKeyField`` gets an identity added;
        * a column leaving ``PrimaryKeyField`` has its identity dropped
          (executed immediately) and any remaining sequence dropped.
        """
        # Drop indexes on varchar/text/citext columns that are changing to a
        # different type.
        old_db_params = old_field.db_parameters(connection=self.connection)
        old_type = old_db_params["type"]
        if old_field.primary_key and any(
            old_type.startswith(prefix) and not new_type.startswith(prefix)
            for prefix in ("varchar", "text", "citext")
        ):
            index_name = self._create_index_name(
                model.model_options.db_table, [old_field.column], suffix="_like"
            )
            self.execute(self._delete_index_sql(model, index_name))

        # Re-assigned on every call (as an instance attribute) so that a USING
        # clause appended by a previous call does not accumulate.
        self.sql_alter_column_type = (
            "ALTER COLUMN %(column)s TYPE %(type)s%(collation)s"
        )
        # Cast when data type changed.
        if using_sql := self._using_sql(new_field, old_field):
            self.sql_alter_column_type += using_sql

        new_internal_type = new_field.get_internal_type()
        old_internal_type = old_field.get_internal_type()
        # Make ALTER TYPE with IDENTITY make sense.
        table = strip_quotes(model.model_options.db_table)
        # Auto-incrementing field types mapped to their sequence type; the
        # keys are also the set of internal types treated as "auto".
        auto_field_db_types = {"PrimaryKeyField": "bigint"}
        old_is_auto = old_internal_type in auto_field_db_types
        new_is_auto = new_internal_type in auto_field_db_types

        if new_is_auto and not old_is_auto:
            column = strip_quotes(new_field.column)
            return (
                (
                    self.sql_alter_column_type
                    % {
                        "column": self.quote_name(column),
                        "type": new_type,
                        "collation": "",
                    },
                    [],
                ),
                [
                    (
                        self.sql_add_identity
                        % {
                            "table": self.quote_name(table),
                            "column": self.quote_name(column),
                        },
                        [],
                    ),
                ],
            )

        if old_is_auto and not new_is_auto:
            column = strip_quotes(new_field.column)
            # Drop IDENTITY if it exists (legacy serial columns don't have it).
            self.execute(
                self.sql_drop_indentity
                % {
                    "table": self.quote_name(table),
                    "column": self.quote_name(column),
                }
            )
            fragment, _ = super()._alter_column_type_sql(
                model, old_field, new_field, new_type, old_collation, new_collation
            )
            # Drop the sequence if it exists (identity columns don't have a
            # separately introspected one once the identity is dropped).
            other_actions: list[tuple[str, list[Any]]] = []
            if sequence_name := self._get_sequence_name(table, column):
                other_actions.append(
                    (
                        self.sql_delete_sequence
                        % {"sequence": self.quote_name(sequence_name)},
                        [],
                    )
                )
            return fragment, other_actions

        if new_is_auto and old_is_auto and old_internal_type != new_internal_type:
            # NOTE(review): unreachable while auto_field_db_types has a single
            # entry; presumably kept for when more auto field types exist.
            fragment, _ = super()._alter_column_type_sql(
                model, old_field, new_field, new_type, old_collation, new_collation
            )
            column = strip_quotes(new_field.column)
            # Alter the sequence type if it exists (identity columns don't
            # have a separate one).
            other_actions = []
            if sequence_name := self._get_sequence_name(table, column):
                other_actions.append(
                    (
                        self.sql_alter_sequence_type
                        % {
                            "sequence": self.quote_name(sequence_name),
                            "type": auto_field_db_types[new_internal_type],
                        },
                        [],
                    )
                )
            return fragment, other_actions

        return super()._alter_column_type_sql(
            model, old_field, new_field, new_type, old_collation, new_collation
        )

    def _alter_field(
        self,
        model: type[Model],
        old_field: Field,
        new_field: Field,
        old_type: str,
        new_type: str,
        old_db_params: dict[str, Any],
        new_db_params: dict[str, Any],
        strict: bool = False,
    ) -> None:
        """Alter a field, then create or drop its PostgreSQL "_like" index.

        After the base alteration, the "_like" index is created when the field
        gains an index (a ``ForeignKeyField`` gaining ``db_index`` or the field
        becoming the primary key). It is dropped when a primary key becomes
        neither a primary key nor an indexed ``ForeignKeyField``.
        """
        super()._alter_field(
            model,
            old_field,
            new_field,
            old_type,
            new_type,
            old_db_params,
            new_db_params,
            strict,
        )
        old_has_fk_index = isinstance(old_field, ForeignKeyField) and old_field.db_index
        new_has_fk_index = isinstance(new_field, ForeignKeyField) and new_field.db_index

        # Added an index? Create any PostgreSQL-specific indexes.
        if (
            not (old_has_fk_index or old_field.primary_key) and new_has_fk_index
        ) or (not old_field.primary_key and new_field.primary_key):
            # NOTE(review): _create_like_index_sql() only returns a statement
            # for primary keys, so the ForeignKeyField case yields None here.
            like_index_statement = self._create_like_index_sql(model, new_field)
            if like_index_statement is not None:
                self.execute(like_index_statement)

        # Removed an index? Drop any PostgreSQL-specific indexes.
        if old_field.primary_key and not (new_has_fk_index or new_field.primary_key):
            index_to_remove = self._create_index_name(
                model.model_options.db_table, [old_field.column], suffix="_like"
            )
            # sql_delete_index uses IF EXISTS, so this is safe when no "_like"
            # index was ever created (e.g. non-varchar/text primary keys).
            self.execute(self._delete_index_sql(model, index_to_remove))

    def _index_columns(
        self,
        table: str,
        columns: list[str],
        col_suffixes: tuple[str, ...],
        opclasses: tuple[str, ...],
    ) -> Columns | IndexColumns:
        """Return ``IndexColumns`` carrying *opclasses* when any are given.

        Without operator classes, defer to the base implementation.
        """
        if opclasses:
            return IndexColumns(
                table,
                columns,
                self.quote_name,
                col_suffixes=col_suffixes,
                opclasses=opclasses,
            )
        return super()._index_columns(table, columns, col_suffixes, opclasses)

    def add_index(
        self, model: type[Model], index: Index, concurrently: bool = False
    ) -> None:
        """Create *index* on *model*'s table, optionally CONCURRENTLY."""
        # params=None makes execute() skip compose_sql(), so the generated
        # index SQL is executed as-is without client-side interpolation.
        self.execute(
            index.create_sql(model, self, concurrently=concurrently), params=None
        )

    def remove_index(
        self, model: type[Model], index: Index, concurrently: bool = False
    ) -> None:
        """Drop *index* from *model*'s table, optionally CONCURRENTLY."""
        self.execute(index.remove_sql(model, self, concurrently=concurrently))

    def _delete_index_sql(
        self,
        model: type[Model],
        name: str,
        sql: str | None = None,
        concurrently: bool = False,
    ) -> Statement:
        """Return the statement dropping index *name*.

        An explicit *sql* template takes precedence. Otherwise the
        ``DROP INDEX [CONCURRENTLY] IF EXISTS`` template is chosen based on
        *concurrently*. This mirrors how _create_index_sql() treats its
        ``sql`` argument.
        """
        sql = sql or (
            self.sql_delete_index_concurrently
            if concurrently
            else self.sql_delete_index
        )
        return super()._delete_index_sql(model, name, sql)

    def _create_index_sql(
        self,
        model: type[Model],
        *,
        fields: list[Field] | None = None,
        name: str | None = None,
        suffix: str = "",
        using: str = "",
        col_suffixes: tuple[str, ...] = (),
        sql: str | None = None,
        opclasses: tuple[str, ...] = (),
        condition: str | None = None,
        concurrently: bool = False,
        include: list[str] | None = None,
        expressions: Any = None,
    ) -> Statement:
        """Return the CREATE INDEX statement, optionally CONCURRENTLY.

        An explicit *sql* template takes precedence; otherwise the plain or
        concurrent template is chosen from *concurrently*. *concurrently* is
        consumed here and not forwarded to the base implementation.
        """
        sql = sql or (
            self.sql_create_index
            if not concurrently
            else self.sql_create_index_concurrently
        )
        return super()._create_index_sql(
            model,
            fields=fields,
            name=name,
            suffix=suffix,
            using=using,
            col_suffixes=col_suffixes,
            sql=sql,
            opclasses=opclasses,
            condition=condition,
            include=include,
            expressions=expressions,
        )