1from __future__ import annotations
2
3from typing import TYPE_CHECKING, Any
4
5from psycopg import sql # type: ignore[import-untyped]
6
7from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
8from plain.models.backends.ddl_references import Columns, IndexColumns, Statement
9from plain.models.backends.utils import strip_quotes
10
11if TYPE_CHECKING:
12 from collections.abc import Generator
13
14 from plain.models.base import Model
15 from plain.models.fields import Field
16 from plain.models.indexes import Index
17
18
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
    """
    PostgreSQL schema editor built on top of BaseDatabaseSchemaEditor.

    The class-level ``sql_*`` attributes are %-style statement templates
    whose ``%(name)s`` placeholders are filled in when a statement is built.
    """

    # Setting all constraints to IMMEDIATE to allow changing data in the same
    # transaction.
    sql_update_with_default = (
        "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
        "; SET CONSTRAINTS ALL IMMEDIATE"
    )
    # Sequence maintenance used by _alter_column_type_sql. IF EXISTS makes
    # both statements tolerate a sequence that is not there.
    sql_alter_sequence_type = "ALTER SEQUENCE IF EXISTS %(sequence)s AS %(type)s"
    sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"

    # Index creation templates; _create_index_sql picks the CONCURRENTLY
    # variant when it is called with concurrently=True.
    sql_create_index = (
        "CREATE INDEX %(name)s ON %(table)s%(using)s "
        "(%(columns)s)%(include)s%(extra)s%(condition)s"
    )
    sql_create_index_concurrently = (
        "CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s "
        "(%(columns)s)%(include)s%(extra)s%(condition)s"
    )
    # Index removal templates; _delete_index_sql chooses between them based
    # on its concurrently flag.
    sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
    sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"

    # Setting the constraint to IMMEDIATE to allow changing data in the same
    # transaction.
    sql_create_column_inline_fk = (
        "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
        "; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE"
    )
    # Setting the constraint to IMMEDIATE runs any deferred checks to allow
    # dropping it in the same transaction.
    sql_delete_fk = (
        "SET CONSTRAINTS %(name)s IMMEDIATE; "
        "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
    )
52
53 def execute(
54 self, sql: str | Statement, params: tuple[Any, ...] | list[Any] | None = ()
55 ) -> None:
56 # Merge the query client-side, as PostgreSQL won't do it server-side.
57 if params is None:
58 return super().execute(sql, params)
59 sql = self.connection.ops.compose_sql(str(sql), params)
60 # Don't let the superclass touch anything.
61 return super().execute(sql, None)
62
    # Attaches a GENERATED BY DEFAULT identity to an existing column; used by
    # _alter_column_type_sql when a column turns into an auto primary key.
    sql_add_identity = (
        "ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
        "GENERATED BY DEFAULT AS IDENTITY"
    )
    # Removes the identity from a column; IF EXISTS tolerates columns that
    # have none. NOTE: the attribute name is misspelled ("indentity") but
    # _alter_column_type_sql refers to it under this name, so it is kept.
    sql_drop_indentity = (
        "ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP IDENTITY IF EXISTS"
    )
70
71 def quote_value(self, value: Any) -> str:
72 if isinstance(value, str):
73 value = value.replace("%", "%%")
74 return sql.quote(value, self.connection.connection)
75
76 def _field_indexes_sql(self, model: type[Model], field: Field) -> list[Statement]:
77 output = super()._field_indexes_sql(model, field)
78 like_index_statement = self._create_like_index_sql(model, field)
79 if like_index_statement is not None:
80 output.append(like_index_statement)
81 return output
82
83 def _field_data_type(self, field: Field) -> str | None:
84 if field.is_relation:
85 return field.rel_db_type(self.connection)
86 return self.connection.data_types.get(
87 field.get_internal_type(),
88 field.db_type(self.connection),
89 )
90
91 def _field_base_data_types(self, field: Field) -> Generator[str | None, None, None]:
92 # Yield base data types for array fields.
93 if field.base_field.get_internal_type() == "ArrayField": # type: ignore[attr-defined]
94 yield from self._field_base_data_types(field.base_field) # type: ignore[attr-defined]
95 else:
96 yield self._field_data_type(field.base_field) # type: ignore[attr-defined]
97
98 def _create_like_index_sql(
99 self, model: type[Model], field: Field
100 ) -> Statement | None:
101 """
102 Return the statement to create an index with varchar operator pattern
103 when the column type is 'varchar' or 'text', otherwise return None.
104 """
105 db_type = field.db_type(connection=self.connection)
106 if db_type is not None and field.primary_key:
107 # Fields with database column types of `varchar` and `text` need
108 # a second index that specifies their operator class, which is
109 # needed when performing correct LIKE queries outside the
110 # C locale. See #12234.
111 #
112 # The same doesn't apply to array fields such as varchar[size]
113 # and text[size], so skip them.
114 if "[" in db_type:
115 return None
116 # Non-deterministic collations on Postgresql don't support indexes
117 # for operator classes varchar_pattern_ops/text_pattern_ops.
118 if getattr(field, "db_collation", None):
119 return None
120 if db_type.startswith("varchar"):
121 return self._create_index_sql(
122 model,
123 fields=[field],
124 suffix="_like",
125 opclasses=("varchar_pattern_ops",),
126 )
127 elif db_type.startswith("text"):
128 return self._create_index_sql(
129 model,
130 fields=[field],
131 suffix="_like",
132 opclasses=("text_pattern_ops",),
133 )
134 return None
135
136 def _using_sql(self, new_field: Field, old_field: Field) -> str:
137 using_sql = " USING %(column)s::%(type)s"
138 new_internal_type = new_field.get_internal_type()
139 old_internal_type = old_field.get_internal_type()
140 if new_internal_type == "ArrayField" and new_internal_type == old_internal_type:
141 # Compare base data types for array fields.
142 if list(self._field_base_data_types(old_field)) != list(
143 self._field_base_data_types(new_field)
144 ):
145 return using_sql
146 elif self._field_data_type(old_field) != self._field_data_type(new_field):
147 return using_sql
148 return ""
149
150 def _get_sequence_name(self, table: str, column: str) -> str | None:
151 with self.connection.cursor() as cursor:
152 for sequence in self.connection.introspection.get_sequences(cursor, table):
153 if sequence["column"] == column:
154 return sequence["name"]
155 return None
156
157 def _alter_column_type_sql(
158 self,
159 model: type[Model],
160 old_field: Field,
161 new_field: Field,
162 new_type: str,
163 old_collation: str | None,
164 new_collation: str | None,
165 ) -> tuple[tuple[str, list[Any]], list[tuple[str, list[Any]]]]:
166 # Drop indexes on varchar/text/citext columns that are changing to a
167 # different type.
168 old_db_params = old_field.db_parameters(connection=self.connection)
169 old_type = old_db_params["type"]
170 if old_field.primary_key and (
171 (old_type.startswith("varchar") and not new_type.startswith("varchar"))
172 or (old_type.startswith("text") and not new_type.startswith("text"))
173 or (old_type.startswith("citext") and not new_type.startswith("citext"))
174 ):
175 index_name = self._create_index_name(
176 model.model_options.db_table, [old_field.column], suffix="_like"
177 )
178 self.execute(self._delete_index_sql(model, index_name))
179
180 self.sql_alter_column_type = (
181 "ALTER COLUMN %(column)s TYPE %(type)s%(collation)s"
182 )
183 # Cast when data type changed.
184 if using_sql := self._using_sql(new_field, old_field):
185 self.sql_alter_column_type += using_sql
186 new_internal_type = new_field.get_internal_type()
187 old_internal_type = old_field.get_internal_type()
188 # Make ALTER TYPE with IDENTITY make sense.
189 table = strip_quotes(model.model_options.db_table)
190 auto_field_types = {"PrimaryKeyField"}
191 old_is_auto = old_internal_type in auto_field_types
192 new_is_auto = new_internal_type in auto_field_types
193 if new_is_auto and not old_is_auto:
194 column = strip_quotes(new_field.column)
195 return (
196 (
197 self.sql_alter_column_type
198 % {
199 "column": self.quote_name(column),
200 "type": new_type,
201 "collation": "",
202 },
203 [],
204 ),
205 [
206 (
207 self.sql_add_identity
208 % {
209 "table": self.quote_name(table),
210 "column": self.quote_name(column),
211 },
212 [],
213 ),
214 ],
215 )
216 elif old_is_auto and not new_is_auto:
217 # Drop IDENTITY if exists (pre-Plain 4.1 serial columns don't have
218 # it).
219 self.execute(
220 self.sql_drop_indentity
221 % {
222 "table": self.quote_name(table),
223 "column": self.quote_name(strip_quotes(new_field.column)),
224 }
225 )
226 column = strip_quotes(new_field.column)
227 fragment, _ = super()._alter_column_type_sql(
228 model, old_field, new_field, new_type, old_collation, new_collation
229 )
230 # Drop the sequence if exists (Plain 4.1+ identity columns don't
231 # have it).
232 other_actions = []
233 if sequence_name := self._get_sequence_name(table, column):
234 other_actions = [
235 (
236 self.sql_delete_sequence
237 % {
238 "sequence": self.quote_name(sequence_name),
239 },
240 [],
241 )
242 ]
243 return fragment, other_actions
244 elif new_is_auto and old_is_auto and old_internal_type != new_internal_type:
245 fragment, _ = super()._alter_column_type_sql(
246 model, old_field, new_field, new_type, old_collation, new_collation
247 )
248 column = strip_quotes(new_field.column)
249 db_types = {"PrimaryKeyField": "bigint"}
250 # Alter the sequence type if exists (Plain 4.1+ identity columns
251 # don't have it).
252 other_actions = []
253 if sequence_name := self._get_sequence_name(table, column):
254 other_actions = [
255 (
256 self.sql_alter_sequence_type
257 % {
258 "sequence": self.quote_name(sequence_name),
259 "type": db_types[new_internal_type],
260 },
261 [],
262 ),
263 ]
264 return fragment, other_actions
265 else:
266 return super()._alter_column_type_sql(
267 model, old_field, new_field, new_type, old_collation, new_collation
268 )
269
270 def _alter_field(
271 self,
272 model: type[Model],
273 old_field: Field,
274 new_field: Field,
275 old_type: str,
276 new_type: str,
277 old_db_params: dict[str, Any],
278 new_db_params: dict[str, Any],
279 strict: bool = False,
280 ) -> None:
281 super()._alter_field(
282 model,
283 old_field,
284 new_field,
285 old_type,
286 new_type,
287 old_db_params,
288 new_db_params,
289 strict,
290 )
291 # Added an index? Create any PostgreSQL-specific indexes.
292 if (
293 not (
294 (old_field.remote_field and old_field.db_index) # type: ignore[attr-defined]
295 or old_field.primary_key
296 )
297 and (new_field.remote_field and new_field.db_index) # type: ignore[attr-defined]
298 ) or (not old_field.primary_key and new_field.primary_key):
299 like_index_statement = self._create_like_index_sql(model, new_field)
300 if like_index_statement is not None:
301 self.execute(like_index_statement)
302
303 # Removed an index? Drop any PostgreSQL-specific indexes.
304 if old_field.primary_key and not (
305 (new_field.remote_field and new_field.db_index) # type: ignore[attr-defined]
306 or new_field.primary_key
307 ):
308 index_to_remove = self._create_index_name(
309 model.model_options.db_table, [old_field.column], suffix="_like"
310 )
311 self.execute(self._delete_index_sql(model, index_to_remove))
312
313 def _index_columns(
314 self,
315 table: str,
316 columns: list[str],
317 col_suffixes: tuple[str, ...],
318 opclasses: tuple[str, ...],
319 ) -> Columns | IndexColumns:
320 if opclasses:
321 return IndexColumns(
322 table,
323 columns,
324 self.quote_name,
325 col_suffixes=col_suffixes,
326 opclasses=opclasses,
327 )
328 return super()._index_columns(table, columns, col_suffixes, opclasses)
329
330 def add_index(
331 self, model: type[Model], index: Index, concurrently: bool = False
332 ) -> None:
333 self.execute(
334 index.create_sql(model, self, concurrently=concurrently), params=None
335 )
336
337 def remove_index(
338 self, model: type[Model], index: Index, concurrently: bool = False
339 ) -> None:
340 self.execute(index.remove_sql(model, self, concurrently=concurrently))
341
342 def _delete_index_sql(
343 self,
344 model: type[Model],
345 name: str,
346 sql: str | None = None,
347 concurrently: bool = False,
348 ) -> Statement:
349 sql = (
350 self.sql_delete_index_concurrently
351 if concurrently
352 else self.sql_delete_index
353 )
354 return super()._delete_index_sql(model, name, sql)
355
356 def _create_index_sql(
357 self,
358 model: type[Model],
359 *,
360 fields: list[Field] | None = None,
361 name: str | None = None,
362 suffix: str = "",
363 using: str = "",
364 col_suffixes: tuple[str, ...] = (),
365 sql: str | None = None,
366 opclasses: tuple[str, ...] = (),
367 condition: str | None = None,
368 concurrently: bool = False,
369 include: list[str] | None = None,
370 expressions: Any = None,
371 ) -> Statement:
372 sql = sql or (
373 self.sql_create_index
374 if not concurrently
375 else self.sql_create_index_concurrently
376 )
377 return super()._create_index_sql(
378 model,
379 fields=fields,
380 name=name,
381 suffix=suffix,
382 using=using,
383 col_suffixes=col_suffixes,
384 sql=sql,
385 opclasses=opclasses,
386 condition=condition,
387 include=include,
388 expressions=expressions,
389 )