1from __future__ import annotations
2
3from collections.abc import Callable
4from typing import TYPE_CHECKING, Any
5
6from psycopg import sql
7
8from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
9from plain.models.backends.ddl_references import Columns, IndexColumns, Statement
10from plain.models.backends.utils import strip_quotes
11from plain.models.fields import DbParameters
12from plain.models.fields.related import ForeignKeyField, RelatedField
13
14if TYPE_CHECKING:
15 from plain.models.backends.postgresql.base import PostgreSQLDatabaseWrapper
16 from plain.models.base import Model
17 from plain.models.fields import Field
18 from plain.models.indexes import Index
19
20
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
    """PostgreSQL-specific schema editor.

    Layers PostgreSQL behaviour over ``BaseDatabaseSchemaEditor``:

    * query parameters are merged into the SQL text on the client;
    * primary-key ``varchar``/``text`` columns get a companion "_like" index
      that uses a ``*_pattern_ops`` operator class;
    * indexes can be created and dropped ``CONCURRENTLY``;
    * IDENTITY and sequence objects are kept in step when a column changes to
      or from ``PrimaryKeyField``.
    """

    # Type-checker hint: this editor is always paired with the PostgreSQL
    # connection wrapper.
    connection: PostgreSQLDatabaseWrapper

    # --- Column / sequence templates ---------------------------------------

    # After NULLs are back-filled, every constraint is switched to IMMEDIATE
    # so data can be changed again inside the same transaction.
    sql_update_with_default = (
        "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
        "; SET CONSTRAINTS ALL IMMEDIATE"
    )
    sql_alter_sequence_type = "ALTER SEQUENCE IF EXISTS %(sequence)s AS %(type)s"
    sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
    sql_add_identity = (
        "ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
        "GENERATED BY DEFAULT AS IDENTITY"
    )
    # NOTE: the attribute name is spelled "indentity"; it is kept unchanged
    # because it is part of the class's attribute surface.
    sql_drop_indentity = (
        "ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP IDENTITY IF EXISTS"
    )

    # --- Index templates ---------------------------------------------------

    sql_create_index = (
        "CREATE INDEX %(name)s ON %(table)s%(using)s "
        "(%(columns)s)%(include)s%(extra)s%(condition)s"
    )
    sql_create_index_concurrently = (
        "CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s "
        "(%(columns)s)%(include)s%(extra)s%(condition)s"
    )
    sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
    sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"

    # --- Foreign-key templates ---------------------------------------------

    # The freshly created constraint is made IMMEDIATE so data can be
    # modified within the same transaction.
    sql_create_column_inline_fk = (
        "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
        "; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE"
    )
    # Switching the constraint to IMMEDIATE first runs any deferred checks,
    # which allows it to be dropped in the same transaction.
    sql_delete_fk = (
        "SET CONSTRAINTS %(name)s IMMEDIATE; "
        "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
    )

    # --- Execution and quoting ---------------------------------------------

    def execute(
        self, sql: str | Statement, params: tuple[Any, ...] | list[Any] | None = ()
    ) -> None:
        """Run ``sql``, interpolating ``params`` on the client beforehand.

        A ``params`` of ``None`` forwards the statement untouched. Any other
        value is merged into the SQL text with ``connection.ops.compose_sql``,
        and the result is forwarded with ``params=None`` so that the base
        implementation performs no further substitution.
        """
        if params is not None:
            # PostgreSQL won't merge the parameters server-side.
            sql = self.connection.ops.compose_sql(str(sql), params)
            params = None
        return super().execute(sql, params)

    def quote_value(self, value: Any) -> str:
        """Render ``value`` as a SQL literal via psycopg's ``sql.quote``.

        Percent signs in string values are doubled first, which is the
        escape for ``%``-style formatting.
        """
        escaped = value.replace("%", "%%") if isinstance(value, str) else value
        return sql.quote(escaped, self.connection.connection)

    # --- Field helpers -----------------------------------------------------

    def _field_indexes_sql(self, model: type[Model], field: Field) -> list[Statement]:
        """Return the base field indexes plus the optional "_like" index."""
        statements = super()._field_indexes_sql(model, field)
        like_index = self._create_like_index_sql(model, field)
        if like_index is not None:
            statements.append(like_index)
        return statements

    def _field_data_type(
        self, field: Field
    ) -> str | None | Callable[[dict[str, Any]], str]:
        """Return the column type used to decide whether a cast is required.

        Related fields delegate to ``rel_db_type()``. Every other field uses
        the backend's ``data_types`` entry for its internal type, falling back
        to ``field.db_type()``.
        """
        if not isinstance(field, RelatedField):
            return self.connection.data_types.get(
                field.get_internal_type(),
                field.db_type(self.connection),
            )
        return field.rel_db_type(self.connection)

    def _create_like_index_sql(
        self, model: type[Model], field: Field
    ) -> Statement | None:
        """
        Build the companion "_like" index for a primary-key column.

        The index is returned when ``field`` is a primary key whose database
        type starts with ``varchar`` (``varchar_pattern_ops``) or ``text``
        (``text_pattern_ops``). Array types such as ``varchar[size]`` are
        excluded. In every other case the result is None.
        """
        db_type = field.db_type(connection=self.connection)
        if db_type is None or not field.primary_key:
            return None
        # Array columns (varchar[size], text[size]) don't need this index.
        if "[" in db_type:
            return None
        # A pattern_ops operator class is needed for correct LIKE queries
        # outside the C locale (see #12234).
        for type_prefix, opclass in (
            ("varchar", "varchar_pattern_ops"),
            ("text", "text_pattern_ops"),
        ):
            if db_type.startswith(type_prefix):
                return self._create_index_sql(
                    model,
                    fields=[field],
                    suffix="_like",
                    opclasses=(opclass,),
                )
        return None

    def _using_sql(self, new_field: Field, old_field: Field) -> str:
        """Return a ``USING`` cast clause if the column data type changes."""
        type_changed = self._field_data_type(old_field) != self._field_data_type(
            new_field
        )
        return " USING %(column)s::%(type)s" if type_changed else ""

    def _get_sequence_name(self, table: str, column: str) -> str | None:
        """Return the introspected sequence name for ``column``, or None."""
        with self.connection.cursor() as cursor:
            sequences = self.connection.introspection.get_sequences(cursor, table)
            return next(
                (seq["name"] for seq in sequences if seq["column"] == column),
                None,
            )

    # --- Altering columns --------------------------------------------------

    def _alter_column_type_sql(
        self,
        model: type[Model],
        old_field: Field,
        new_field: Field,
        new_type: str,
    ) -> tuple[tuple[str, list[Any]], list[tuple[str, list[Any]]]]:
        """
        Build the ``ALTER COLUMN ... TYPE`` fragment and any follow-up actions.

        On top of the base behaviour, this method:

        * immediately drops the primary key's "_like" index when the column
          leaves the varchar/text/citext family;
        * appends a ``USING`` cast when the data type changes;
        * adds IDENTITY when a column becomes a ``PrimaryKeyField``;
        * when a column stops being a ``PrimaryKeyField``, immediately drops
          its IDENTITY and schedules removal of any leftover sequence;
        * retypes the sequence when switching between auto field types. This
          path is currently unreachable because ``PrimaryKeyField`` is the
          only auto type.

        Returns ``(fragment, other_actions)``, where each item is an
        ``(sql, params)`` pair.
        """
        old_type = old_field.db_parameters(connection=self.connection)["type"]
        assert old_type is not None, "old_type cannot be None for primary key field"

        # Leaving the varchar/text/citext family makes the "_like" index
        # meaningless, so drop it right away.
        leaves_text_family = any(
            old_type.startswith(family) and not new_type.startswith(family)
            for family in ("varchar", "text", "citext")
        )
        if old_field.primary_key and leaves_text_family:
            like_index_name = self._create_index_name(
                model.model_options.db_table, [old_field.column], suffix="_like"
            )
            self.execute(self._delete_index_sql(model, like_index_name))

        # Stored on the instance (rather than a local) before any super()
        # call. NOTE(review): presumably the base implementation reads this
        # attribute, so it picks up the cast too — confirm in the base class.
        self.sql_alter_column_type = "ALTER COLUMN %(column)s TYPE %(type)s"
        self.sql_alter_column_type += self._using_sql(new_field, old_field)

        new_internal = new_field.get_internal_type()
        old_internal = old_field.get_internal_type()
        table = strip_quotes(model.model_options.db_table)
        column = strip_quotes(new_field.column)
        auto_types = {"PrimaryKeyField"}
        was_auto = old_internal in auto_types
        is_auto = new_internal in auto_types

        if is_auto and not was_auto:
            # Becoming an auto column: retype it, then attach IDENTITY.
            type_change = (
                self.sql_alter_column_type
                % {"column": self.quote_name(column), "type": new_type},
                [],
            )
            add_identity = (
                self.sql_add_identity
                % {"table": self.quote_name(table), "column": self.quote_name(column)},
                [],
            )
            return type_change, [add_identity]

        if was_auto and not is_auto:
            # Older serial columns carry no IDENTITY, hence "IF EXISTS".
            self.execute(
                self.sql_drop_indentity
                % {"table": self.quote_name(table), "column": self.quote_name(column)}
            )
            fragment, _ = super()._alter_column_type_sql(
                model, old_field, new_field, new_type
            )
            # IDENTITY columns have no separate sequence; only a leftover
            # serial sequence needs dropping.
            sequence_name = self._get_sequence_name(table, column)
            if not sequence_name:
                return fragment, []
            drop_sequence = (
                self.sql_delete_sequence
                % {"sequence": self.quote_name(sequence_name)},
                [],
            )
            return fragment, [drop_sequence]

        if is_auto and was_auto and old_internal != new_internal:
            fragment, _ = super()._alter_column_type_sql(
                model, old_field, new_field, new_type
            )
            # Only a serial-backed column has a sequence to retype.
            sequence_name = self._get_sequence_name(table, column)
            if not sequence_name:
                return fragment, []
            sequence_types = {"PrimaryKeyField": "bigint"}
            retype_sequence = (
                self.sql_alter_sequence_type
                % {
                    "sequence": self.quote_name(sequence_name),
                    "type": sequence_types[new_internal],
                },
                [],
            )
            return fragment, [retype_sequence]

        return super()._alter_column_type_sql(model, old_field, new_field, new_type)

    def _alter_field(
        self,
        model: type[Model],
        old_field: Field,
        new_field: Field,
        old_type: str,
        new_type: str,
        old_db_params: DbParameters,
        new_db_params: DbParameters,
        strict: bool = False,
    ) -> None:
        """Run the base field alteration, then sync the "_like" index."""
        super()._alter_field(
            model,
            old_field,
            new_field,
            old_type,
            new_type,
            old_db_params,
            new_db_params,
            strict,
        )

        old_fk_indexed = isinstance(old_field, ForeignKeyField) and old_field.db_index
        new_fk_indexed = isinstance(new_field, ForeignKeyField) and new_field.db_index

        # Gained an index: either an FK newly indexed or a new primary key.
        gained_fk_index = (
            not (old_fk_indexed or old_field.primary_key) and new_fk_indexed
        )
        became_primary_key = not old_field.primary_key and new_field.primary_key
        if gained_fk_index or became_primary_key:
            like_index = self._create_like_index_sql(model, new_field)
            if like_index is not None:
                self.execute(like_index)

        # Lost the primary-key index: drop the old "_like" companion.
        if old_field.primary_key and not (new_fk_indexed or new_field.primary_key):
            stale_index_name = self._create_index_name(
                model.model_options.db_table, [old_field.column], suffix="_like"
            )
            self.execute(self._delete_index_sql(model, stale_index_name))

    # --- Index management --------------------------------------------------

    def _index_columns(
        self,
        table: str,
        columns: list[str],
        col_suffixes: tuple[str, ...],
        opclasses: tuple[str, ...],
    ) -> Columns | IndexColumns:
        """Use ``IndexColumns`` when operator classes are supplied."""
        if not opclasses:
            return super()._index_columns(table, columns, col_suffixes, opclasses)
        return IndexColumns(
            table,
            columns,
            self.quote_name,
            col_suffixes=col_suffixes,
            opclasses=opclasses,
        )

    def add_index(
        self, model: type[Model], index: Index, concurrently: bool = False
    ) -> None:
        """Create ``index``, optionally ``CONCURRENTLY``."""
        statement = index.create_sql(model, self, concurrently=concurrently)
        # params=None: execute() forwards the statement without merging.
        self.execute(statement, params=None)

    def remove_index(
        self, model: type[Model], index: Index, concurrently: bool = False
    ) -> None:
        """Drop ``index``, optionally ``CONCURRENTLY``."""
        statement = index.remove_sql(model, self, concurrently=concurrently)
        self.execute(statement)

    def _delete_index_sql(
        self,
        model: type[Model],
        name: str,
        sql: str | None = None,
        concurrently: bool = False,
    ) -> Statement:
        """Build a DROP INDEX statement for ``name``.

        The caller-supplied ``sql`` is ignored; the template is chosen solely
        from ``concurrently``.
        """
        if concurrently:
            template = self.sql_delete_index_concurrently
        else:
            template = self.sql_delete_index
        return super()._delete_index_sql(model, name, template)

    def _create_index_sql(
        self,
        model: type[Model],
        *,
        fields: list[Field] | None = None,
        name: str | None = None,
        suffix: str = "",
        using: str = "",
        col_suffixes: tuple[str, ...] = (),
        sql: str | None = None,
        opclasses: tuple[str, ...] = (),
        condition: str | None = None,
        concurrently: bool = False,
        include: list[str] | None = None,
        expressions: Any = None,
    ) -> Statement:
        """Build a CREATE INDEX statement.

        If ``sql`` is empty, the template is picked from ``concurrently``.
        ``concurrently`` itself is consumed here and not forwarded to the
        base implementation.
        """
        if not sql:
            sql = (
                self.sql_create_index_concurrently
                if concurrently
                else self.sql_create_index
            )
        return super()._create_index_sql(
            model,
            fields=fields,
            name=name,
            suffix=suffix,
            using=using,
            col_suffixes=col_suffixes,
            sql=sql,
            opclasses=opclasses,
            condition=condition,
            include=include,
            expressions=expressions,
        )