1from __future__ import annotations
2
3from collections.abc import Callable
4from typing import TYPE_CHECKING, Any
5
6from psycopg import sql # type: ignore[import-untyped]
7
8from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
9from plain.models.backends.ddl_references import Columns, IndexColumns, Statement
10from plain.models.backends.utils import strip_quotes
11
12if TYPE_CHECKING:
13 from collections.abc import Generator
14
15 from plain.models.backends.postgresql.base import PostgreSQLDatabaseWrapper
16 from plain.models.base import Model
17 from plain.models.fields import Field
18 from plain.models.indexes import Index
19
20
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
    """PostgreSQL schema editor.

    Adds PostgreSQL-specific behavior on top of the base editor:

    * statement parameters are merged into the SQL client-side;
    * varchar/text primary keys get an extra ``*_pattern_ops`` "_like" index;
    * indexes can be created and dropped ``CONCURRENTLY``;
    * identity columns and sequences are adjusted when a column changes to or
      from an auto primary key type.
    """

    # Type-checker hint: in this class the connection is always a
    # PostgreSQLDatabaseWrapper.
    connection: PostgreSQLDatabaseWrapper

    # Setting all constraints to IMMEDIATE to allow changing data in the same
    # transaction.
    sql_update_with_default = (
        "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
        "; SET CONSTRAINTS ALL IMMEDIATE"
    )
    sql_alter_sequence_type = "ALTER SEQUENCE IF EXISTS %(sequence)s AS %(type)s"
    sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"

    sql_create_index = (
        "CREATE INDEX %(name)s ON %(table)s%(using)s "
        "(%(columns)s)%(include)s%(extra)s%(condition)s"
    )
    sql_create_index_concurrently = (
        "CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s "
        "(%(columns)s)%(include)s%(extra)s%(condition)s"
    )
    sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
    sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"

    # Setting the constraint to IMMEDIATE to allow changing data in the same
    # transaction.
    sql_create_column_inline_fk = (
        "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
        "; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE"
    )
    # Setting the constraint to IMMEDIATE runs any deferred checks to allow
    # dropping it in the same transaction.
    sql_delete_fk = (
        "SET CONSTRAINTS %(name)s IMMEDIATE; "
        "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
    )

    sql_add_identity = (
        "ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
        "GENERATED BY DEFAULT AS IDENTITY"
    )
    # NOTE: the misspelled attribute name ("indentity") is kept so existing
    # references and subclass overrides keep working.
    sql_drop_indentity = (
        "ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP IDENTITY IF EXISTS"
    )

    def execute(
        self, sql: str | Statement, params: tuple[Any, ...] | list[Any] | None = ()
    ) -> None:
        """Execute *sql*, merging *params* into it client-side.

        PostgreSQL won't merge parameters server-side, so the statement is
        composed with ``connection.ops.compose_sql`` and handed to the base
        class with ``params=None``. When *params* is already ``None`` the call
        is passed straight through.
        """
        if params is None:
            return super().execute(sql, params)
        composed = self.connection.ops.compose_sql(str(sql), params)
        # Don't let the superclass touch anything.
        return super().execute(composed, None)

    def quote_value(self, value: Any) -> str:
        """Return *value* rendered as an SQL literal using psycopg.

        ``%`` characters in strings are doubled first, presumably so the
        literal survives later %-style interpolation of the enclosing
        statement.
        """
        if isinstance(value, str):
            value = value.replace("%", "%%")
        return sql.quote(value, self.connection.connection)

    def _field_indexes_sql(self, model: type[Model], field: Field) -> list[Statement]:
        """Return the base indexes for *field* plus its "_like" index, if any."""
        output = super()._field_indexes_sql(model, field)
        like_index_statement = self._create_like_index_sql(model, field)
        if like_index_statement is not None:
            output.append(like_index_statement)
        return output

    def _field_data_type(
        self, field: Field
    ) -> str | None | Callable[[dict[str, Any]], str]:
        """Return the column data type used to detect type changes.

        Relations report the type of the column they point at; other fields
        are looked up in ``connection.data_types`` by internal type, falling
        back to ``field.db_type()``.
        """
        if field.is_relation:
            return field.rel_db_type(self.connection)
        return self.connection.data_types.get(
            field.get_internal_type(),
            field.db_type(self.connection),
        )

    def _field_base_data_types(self, field: Field) -> Generator[str | None, None, None]:
        """Yield the innermost base data type of an (optionally nested) array field."""
        base_field = field.base_field  # type: ignore[attr-defined]
        if base_field.get_internal_type() == "ArrayField":
            yield from self._field_base_data_types(base_field)
        else:
            yield self._field_data_type(base_field)

    def _create_like_index_sql(
        self, model: type[Model], field: Field
    ) -> Statement | None:
        """
        Return the statement to create an index with varchar operator pattern
        when the column type is 'varchar' or 'text', otherwise return None.

        Only primary-key fields without an explicit collation qualify; array
        columns are always skipped.
        """
        db_type = field.db_type(connection=self.connection)
        if db_type is None or not field.primary_key:
            return None
        # Fields with database column types of `varchar` and `text` need a
        # second index that specifies their operator class, which is needed
        # when performing correct LIKE queries outside the C locale.
        # See #12234.
        #
        # The same doesn't apply to array fields such as varchar[size] and
        # text[size], so skip them.
        if "[" in db_type:
            return None
        # Non-deterministic collations on PostgreSQL don't support indexes for
        # the varchar_pattern_ops/text_pattern_ops operator classes, so any
        # explicit collation is skipped.
        if getattr(field, "db_collation", None):
            return None
        for type_prefix, opclass in (
            ("varchar", "varchar_pattern_ops"),
            ("text", "text_pattern_ops"),
        ):
            if db_type.startswith(type_prefix):
                return self._create_index_sql(
                    model,
                    fields=[field],
                    suffix="_like",
                    opclasses=(opclass,),
                )
        return None

    def _using_sql(self, new_field: Field, old_field: Field) -> str:
        """Return a ``USING`` cast clause if the column's data type changes.

        Array fields are compared by their base data types; all other fields
        by ``_field_data_type``. Returns ``""`` when no cast is needed.
        """
        new_internal_type = new_field.get_internal_type()
        old_internal_type = old_field.get_internal_type()
        if new_internal_type == "ArrayField" and new_internal_type == old_internal_type:
            # Compare base data types for array fields.
            type_changed = list(self._field_base_data_types(old_field)) != list(
                self._field_base_data_types(new_field)
            )
        else:
            type_changed = self._field_data_type(
                old_field
            ) != self._field_data_type(new_field)
        return " USING %(column)s::%(type)s" if type_changed else ""

    def _get_sequence_name(self, table: str, column: str) -> str | None:
        """Return the introspected sequence name for *table*.*column*, or None."""
        with self.connection.cursor() as cursor:
            for sequence in self.connection.introspection.get_sequences(cursor, table):
                if sequence["column"] == column:
                    return sequence["name"]
        return None

    def _alter_column_type_sql(
        self,
        model: type[Model],
        old_field: Field,
        new_field: Field,
        new_type: str,
        old_collation: str | None,
        new_collation: str | None,
    ) -> tuple[tuple[str, list[Any]], list[tuple[str, list[Any]]]]:
        """Return ``(fragment, other_actions)`` for changing a column's type.

        Besides the base behavior this:

        * drops the "_like" index when a varchar/text/citext primary key
          changes to a type outside that family (executed immediately);
        * adds a ``USING`` cast when the data type changes;
        * adds an identity when a column becomes an auto field, and drops the
          identity plus any leftover sequence when it stops being one.
        """
        # Drop indexes on varchar/text/citext columns that are changing to a
        # different type.
        old_type = old_field.db_parameters(connection=self.connection)["type"]
        if old_field.primary_key and any(
            old_type.startswith(prefix) and not new_type.startswith(prefix)
            for prefix in ("varchar", "text", "citext")
        ):
            index_name = self._create_index_name(
                model.model_options.db_table, [old_field.column], suffix="_like"
            )
            self.execute(self._delete_index_sql(model, index_name))

        # Reset the template on every call (it is used below and, presumably,
        # by the base implementation), then cast when the data type changed.
        self.sql_alter_column_type = (
            "ALTER COLUMN %(column)s TYPE %(type)s%(collation)s"
        )
        if using_sql := self._using_sql(new_field, old_field):
            self.sql_alter_column_type += using_sql

        new_internal_type = new_field.get_internal_type()
        old_internal_type = old_field.get_internal_type()
        # Make ALTER TYPE with IDENTITY make sense.
        table = strip_quotes(model.model_options.db_table)
        column = strip_quotes(new_field.column)
        # Internal types treated as auto (identity) fields, mapped to the type
        # their backing sequence should have.
        auto_field_db_types = {"PrimaryKeyField": "bigint"}
        old_is_auto = old_internal_type in auto_field_db_types
        new_is_auto = new_internal_type in auto_field_db_types

        if new_is_auto and not old_is_auto:
            # Becoming an auto field: change the type, then attach an identity.
            alter_type = (
                self.sql_alter_column_type
                % {
                    "column": self.quote_name(column),
                    "type": new_type,
                    "collation": "",
                },
                [],
            )
            add_identity = (
                self.sql_add_identity
                % {
                    "table": self.quote_name(table),
                    "column": self.quote_name(column),
                },
                [],
            )
            return alter_type, [add_identity]

        if old_is_auto and not new_is_auto:
            # Drop IDENTITY if it exists (columns created as legacy serial
            # types don't have one).
            self.execute(
                self.sql_drop_indentity
                % {
                    "table": self.quote_name(table),
                    "column": self.quote_name(column),
                }
            )
            fragment, _ = super()._alter_column_type_sql(
                model, old_field, new_field, new_type, old_collation, new_collation
            )
            # Drop the sequence if one is reported (identity columns aren't
            # expected to report one).
            drop_actions: list[tuple[str, list[Any]]] = []
            if sequence_name := self._get_sequence_name(table, column):
                drop_actions.append(
                    (
                        self.sql_delete_sequence
                        % {"sequence": self.quote_name(sequence_name)},
                        [],
                    )
                )
            return fragment, drop_actions

        if new_is_auto and old_is_auto and old_internal_type != new_internal_type:
            # Switching between two different auto field types: retype any
            # reported sequence. Unreachable while only one auto type is
            # registered in ``auto_field_db_types``.
            fragment, _ = super()._alter_column_type_sql(
                model, old_field, new_field, new_type, old_collation, new_collation
            )
            retype_actions: list[tuple[str, list[Any]]] = []
            if sequence_name := self._get_sequence_name(table, column):
                retype_actions.append(
                    (
                        self.sql_alter_sequence_type
                        % {
                            "sequence": self.quote_name(sequence_name),
                            "type": auto_field_db_types[new_internal_type],
                        },
                        [],
                    )
                )
            return fragment, retype_actions

        return super()._alter_column_type_sql(
            model, old_field, new_field, new_type, old_collation, new_collation
        )

    @staticmethod
    def _is_indexed_foreign_key(field: Field) -> bool:
        """Return True if *field* has a ``remote_field`` and ``db_index`` set."""
        # ``db_index`` is only read when ``remote_field`` is truthy.
        return bool(field.remote_field and field.db_index)  # type: ignore[attr-defined]

    def _alter_field(
        self,
        model: type[Model],
        old_field: Field,
        new_field: Field,
        old_type: str,
        new_type: str,
        old_db_params: dict[str, Any],
        new_db_params: dict[str, Any],
        strict: bool = False,
    ) -> None:
        """Alter the field via the base class, then sync the "_like" index.

        The index is (re)created when the field gains an index (it becomes an
        indexed foreign key or a primary key) and dropped when a former primary
        key is now neither a primary key nor an indexed foreign key.
        """
        super()._alter_field(
            model,
            old_field,
            new_field,
            old_type,
            new_type,
            old_db_params,
            new_db_params,
            strict,
        )
        # Added an index? Create any PostgreSQL-specific indexes.
        if (
            not (self._is_indexed_foreign_key(old_field) or old_field.primary_key)
            and self._is_indexed_foreign_key(new_field)
        ) or (not old_field.primary_key and new_field.primary_key):
            like_index_statement = self._create_like_index_sql(model, new_field)
            if like_index_statement is not None:
                self.execute(like_index_statement)

        # Removed an index? Drop any PostgreSQL-specific indexes.
        if old_field.primary_key and not (
            self._is_indexed_foreign_key(new_field) or new_field.primary_key
        ):
            index_to_remove = self._create_index_name(
                model.model_options.db_table, [old_field.column], suffix="_like"
            )
            self.execute(self._delete_index_sql(model, index_to_remove))

    def _index_columns(
        self,
        table: str,
        columns: list[str],
        col_suffixes: tuple[str, ...],
        opclasses: tuple[str, ...],
    ) -> Columns | IndexColumns:
        """Return index column references, using ``IndexColumns`` for opclasses."""
        if opclasses:
            return IndexColumns(
                table,
                columns,
                self.quote_name,
                col_suffixes=col_suffixes,
                opclasses=opclasses,
            )
        return super()._index_columns(table, columns, col_suffixes, opclasses)

    def add_index(
        self, model: type[Model], index: Index, concurrently: bool = False
    ) -> None:
        """Create *index* on *model*, optionally ``CONCURRENTLY``.

        ``params=None`` makes ``execute`` skip client-side parameter merging.
        """
        self.execute(
            index.create_sql(model, self, concurrently=concurrently), params=None
        )

    def remove_index(
        self, model: type[Model], index: Index, concurrently: bool = False
    ) -> None:
        """Drop *index* from *model*, optionally ``CONCURRENTLY``."""
        self.execute(index.remove_sql(model, self, concurrently=concurrently))

    def _delete_index_sql(
        self,
        model: type[Model],
        name: str,
        sql: str | None = None,
        concurrently: bool = False,
    ) -> Statement:
        """Return the statement that drops index *name*.

        An explicit *sql* template is honored (matching ``_create_index_sql``);
        otherwise the ``CONCURRENTLY`` or plain ``DROP INDEX IF EXISTS``
        template is chosen according to *concurrently*.
        """
        sql = sql or (
            self.sql_delete_index_concurrently
            if concurrently
            else self.sql_delete_index
        )
        return super()._delete_index_sql(model, name, sql)

    def _create_index_sql(
        self,
        model: type[Model],
        *,
        fields: list[Field] | None = None,
        name: str | None = None,
        suffix: str = "",
        using: str = "",
        col_suffixes: tuple[str, ...] = (),
        sql: str | None = None,
        opclasses: tuple[str, ...] = (),
        condition: str | None = None,
        concurrently: bool = False,
        include: list[str] | None = None,
        expressions: Any = None,
    ) -> Statement:
        """Return the CREATE INDEX statement for the given fields/expressions.

        An explicit *sql* template wins; otherwise *concurrently* selects the
        ``CONCURRENTLY`` variant. *concurrently* only selects the template and
        is not forwarded to the base implementation.
        """
        sql = sql or (
            self.sql_create_index
            if not concurrently
            else self.sql_create_index_concurrently
        )
        return super()._create_index_sql(
            model,
            fields=fields,
            name=name,
            suffix=suffix,
            using=using,
            col_suffixes=col_suffixes,
            sql=sql,
            opclasses=opclasses,
            condition=condition,
            include=include,
            expressions=expressions,
        )