1from psycopg import sql
2
3from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
4from plain.models.backends.ddl_references import IndexColumns
5from plain.models.backends.utils import strip_quotes
6
7
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
    """
    PostgreSQL-specific schema editor.

    Extends the base editor with:
    - client-side merging of query parameters;
    - IDENTITY/sequence handling when a column switches between auto and
      non-auto field types;
    - CONCURRENTLY variants of CREATE/DROP INDEX;
    - the extra "_like" pattern-ops index on varchar/text primary keys.
    """

    # Setting all constraints to IMMEDIATE to allow changing data in the same
    # transaction.
    sql_update_with_default = (
        "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
        "; SET CONSTRAINTS ALL IMMEDIATE"
    )
    sql_alter_sequence_type = "ALTER SEQUENCE IF EXISTS %(sequence)s AS %(type)s"
    sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"

    sql_create_index = (
        "CREATE INDEX %(name)s ON %(table)s%(using)s "
        "(%(columns)s)%(include)s%(extra)s%(condition)s"
    )
    sql_create_index_concurrently = (
        "CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s "
        "(%(columns)s)%(include)s%(extra)s%(condition)s"
    )
    sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
    sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"

    # Setting the constraint to IMMEDIATE to allow changing data in the same
    # transaction.
    sql_create_column_inline_fk = (
        "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
        "; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE"
    )
    # Setting the constraint to IMMEDIATE runs any deferred checks to allow
    # dropping it in the same transaction.
    sql_delete_fk = (
        "SET CONSTRAINTS %(name)s IMMEDIATE; "
        "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
    )

    # Templates used by _alter_column_type_sql() when a column switches
    # between auto and non-auto field types.
    sql_add_identity = (
        "ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
        "GENERATED BY DEFAULT AS IDENTITY"
    )
    # NOTE: the attribute name is misspelled ("indentity"). It is kept as-is
    # because other code or subclasses may reference it by this name.
    sql_drop_indentity = (
        "ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP IDENTITY IF EXISTS"
    )

    def execute(self, sql, params=()):
        """
        Execute ``sql``, merging ``params`` into it on the client.

        When ``params`` is None the statement is passed to the base class
        untouched. Otherwise the parameters are composed into the SQL text
        first, and the base class receives ``None`` as params.
        """
        # Merge the query client-side, as PostgreSQL won't do it server-side.
        if params is None:
            return super().execute(sql, params)
        sql = self.connection.ops.compose_sql(str(sql), params)
        # Don't let the superclass touch anything.
        return super().execute(sql, None)

    def quote_value(self, value):
        """
        Return ``value`` as a quoted SQL literal using psycopg's quoting.

        In string values, ``%`` is doubled before quoting. This is presumably
        so the literal survives the later ``%``-style parameter merge in
        execute().
        """
        if isinstance(value, str):
            value = value.replace("%", "%%")
        return sql.quote(value, self.connection.connection)

    def _field_indexes_sql(self, model, field):
        """Return the base index statements for ``field`` plus any "_like" index."""
        output = super()._field_indexes_sql(model, field)
        like_index_statement = self._create_like_index_sql(model, field)
        if like_index_statement is not None:
            output.append(like_index_statement)
        return output

    def _field_data_type(self, field):
        """
        Return the data type used to compare two versions of ``field``.

        Related fields use ``rel_db_type()``. Other fields use the backend's
        ``data_types`` entry for their internal type, falling back to
        ``db_type()`` when there is no entry.
        """
        if field.is_relation:
            return field.rel_db_type(self.connection)
        return self.connection.data_types.get(
            field.get_internal_type(),
            field.db_type(self.connection),
        )

    def _field_base_data_types(self, field):
        """
        Yield the innermost base data type(s) of an array field.

        Nested ArrayFields are unwrapped recursively. Expects ``field`` to have
        a ``base_field`` attribute.
        """
        # Yield base data types for array fields.
        if field.base_field.get_internal_type() == "ArrayField":
            yield from self._field_base_data_types(field.base_field)
        else:
            yield self._field_data_type(field.base_field)

    @staticmethod
    def _like_index_type_changed(old_type, new_type):
        """
        Return True when a column type change leaves the varchar, text, or
        citext family (e.g. varchar -> text, or text -> integer).

        After such a change, a "_like" pattern-ops index built for the old
        type no longer matches the column, so it must be dropped and possibly
        rebuilt. Returns False if either type is None.
        """
        if old_type is None or new_type is None:
            return False
        return any(
            old_type.startswith(prefix) and not new_type.startswith(prefix)
            for prefix in ("varchar", "text", "citext")
        )

    def _create_like_index_sql(self, model, field):
        """
        Return a statement creating the "_like" pattern-ops index for
        ``field``, or None if no such index is needed.

        An index is created only when all of the following hold:
        - the field is a primary key;
        - its column type starts with 'varchar' or 'text' and is not an array
          type;
        - it has no explicit ``db_collation``.
        """
        db_type = field.db_type(connection=self.connection)
        if db_type is None or not field.primary_key:
            return None
        # Fields with database column types of `varchar` and `text` need
        # a second index that specifies their operator class, which is
        # needed when performing correct LIKE queries outside the
        # C locale. See #12234 (upstream Django ticket).
        #
        # The same doesn't apply to array fields such as varchar[size]
        # and text[size], so skip them.
        if "[" in db_type:
            return None
        # Non-deterministic collations on PostgreSQL don't support indexes
        # for operator classes varchar_pattern_ops/text_pattern_ops. Any
        # explicit collation is skipped conservatively, since it might be
        # non-deterministic.
        if getattr(field, "db_collation", None):
            return None
        if db_type.startswith("varchar"):
            opclass = "varchar_pattern_ops"
        elif db_type.startswith("text"):
            opclass = "text_pattern_ops"
        else:
            return None
        return self._create_index_sql(
            model,
            fields=[field],
            suffix="_like",
            opclasses=[opclass],
        )

    def _using_sql(self, new_field, old_field):
        """
        Return a ``USING <column>::<type>`` cast clause when the field's data
        type changes, or "" when it does not.

        For two ArrayFields, the innermost base data types are compared
        instead of the array types themselves.
        """
        using_sql = " USING %(column)s::%(type)s"
        new_internal_type = new_field.get_internal_type()
        old_internal_type = old_field.get_internal_type()
        if new_internal_type == "ArrayField" and new_internal_type == old_internal_type:
            # Compare base data types for array fields.
            if list(self._field_base_data_types(old_field)) != list(
                self._field_base_data_types(new_field)
            ):
                return using_sql
        elif self._field_data_type(old_field) != self._field_data_type(new_field):
            return using_sql
        return ""

    def _get_sequence_name(self, table, column):
        """Return the name of the sequence owned by ``table.column``, or None."""
        with self.connection.cursor() as cursor:
            for sequence in self.connection.introspection.get_sequences(cursor, table):
                if sequence["column"] == column:
                    return sequence["name"]
        return None

    def _alter_column_type_sql(
        self, model, old_field, new_field, new_type, old_collation, new_collation
    ):
        """
        Return the SQL to change a column's type, with PostgreSQL handling of
        "_like" indexes and IDENTITY/sequences.

        Returns ``(fragment, other_actions)``:
        - ``fragment`` is an ``(sql, params)`` ALTER COLUMN clause;
        - ``other_actions`` is a list of ``(sql, params)`` statements to run
          afterwards.

        Side effects:
        - Some statements are executed immediately: dropping a primary key's
          "_like" index, and dropping IDENTITY when an auto field becomes
          non-auto.
        - ``self.sql_alter_column_type`` is overwritten, with a USING cast
          appended when the data type changes.
        """
        # Drop the "_like" index of a varchar/text/citext primary key whose
        # type is moving to a different family. _alter_field() rebuilds it
        # afterwards if the new type still needs one.
        old_db_params = old_field.db_parameters(connection=self.connection)
        old_type = old_db_params["type"]
        if old_field.primary_key and self._like_index_type_changed(old_type, new_type):
            index_name = self._create_index_name(
                model._meta.db_table, [old_field.column], suffix="_like"
            )
            self.execute(self._delete_index_sql(model, index_name))

        self.sql_alter_column_type = (
            "ALTER COLUMN %(column)s TYPE %(type)s%(collation)s"
        )
        # Cast when data type changed.
        if using_sql := self._using_sql(new_field, old_field):
            self.sql_alter_column_type += using_sql

        # Make ALTER TYPE with IDENTITY make sense.
        new_internal_type = new_field.get_internal_type()
        old_internal_type = old_field.get_internal_type()
        table = strip_quotes(model._meta.db_table)
        column = strip_quotes(new_field.column)
        auto_field_types = {
            "AutoField",
            "BigAutoField",
            "SmallAutoField",
        }
        old_is_auto = old_internal_type in auto_field_types
        new_is_auto = new_internal_type in auto_field_types

        if new_is_auto and not old_is_auto:
            # Change the type, then attach an IDENTITY so the column gets
            # generated values. Collation is irrelevant for the integer auto
            # types, so it is left empty.
            return (
                (
                    self.sql_alter_column_type
                    % {
                        "column": self.quote_name(column),
                        "type": new_type,
                        "collation": "",
                    },
                    [],
                ),
                [
                    (
                        self.sql_add_identity
                        % {
                            "table": self.quote_name(table),
                            "column": self.quote_name(column),
                        },
                        [],
                    ),
                ],
            )
        elif old_is_auto and not new_is_auto:
            # Drop IDENTITY if it exists. Columns created as `serial` rather
            # than IDENTITY don't have it.
            self.execute(
                self.sql_drop_indentity
                % {
                    "table": self.quote_name(table),
                    "column": self.quote_name(column),
                }
            )
            fragment, _ = super()._alter_column_type_sql(
                model, old_field, new_field, new_type, old_collation, new_collation
            )
            # Drop the owned sequence if introspection finds one. IDENTITY
            # columns normally don't report one here.
            other_actions = []
            if sequence_name := self._get_sequence_name(table, column):
                other_actions = [
                    (
                        self.sql_delete_sequence
                        % {
                            "sequence": self.quote_name(sequence_name),
                        },
                        [],
                    )
                ]
            return fragment, other_actions
        elif new_is_auto and old_is_auto and old_internal_type != new_internal_type:
            fragment, _ = super()._alter_column_type_sql(
                model, old_field, new_field, new_type, old_collation, new_collation
            )
            db_types = {
                "AutoField": "integer",
                "BigAutoField": "bigint",
                "SmallAutoField": "smallint",
            }
            # Resize the owned sequence if introspection finds one. IDENTITY
            # columns normally don't report one here.
            other_actions = []
            if sequence_name := self._get_sequence_name(table, column):
                other_actions = [
                    (
                        self.sql_alter_sequence_type
                        % {
                            "sequence": self.quote_name(sequence_name),
                            "type": db_types[new_internal_type],
                        },
                        [],
                    ),
                ]
            return fragment, other_actions
        else:
            return super()._alter_column_type_sql(
                model, old_field, new_field, new_type, old_collation, new_collation
            )

    def _alter_field(
        self,
        model,
        old_field,
        new_field,
        old_type,
        new_type,
        old_db_params,
        new_db_params,
        strict=False,
    ):
        """
        Alter the field via the base implementation, then create or drop the
        PostgreSQL-only "_like" index so it matches the field's new state.
        """
        super()._alter_field(
            model,
            old_field,
            new_field,
            old_type,
            new_type,
            old_db_params,
            new_db_params,
            strict,
        )
        old_is_pk = old_field.primary_key
        new_is_pk = new_field.primary_key

        # Added an index? Create any PostgreSQL-specific indexes.
        if (
            (
                not ((old_field.remote_field and old_field.db_index) or old_is_pk)
                and (new_field.remote_field and new_field.db_index)
            )
            or (not old_is_pk and new_is_pk)
            # The primary key stayed a primary key, but its type moved between
            # the varchar/text/citext families. _alter_column_type_sql()
            # dropped the old "_like" index (same check on the same old/new
            # types), so rebuild it for the new type. _create_like_index_sql()
            # returns None when the new type needs no such index.
            or (
                old_is_pk
                and new_is_pk
                and self._like_index_type_changed(old_type, new_type)
            )
        ):
            like_index_statement = self._create_like_index_sql(model, new_field)
            if like_index_statement is not None:
                self.execute(like_index_statement)

        # Removed an index? Drop any PostgreSQL-specific indexes.
        if old_is_pk and not (
            (new_field.remote_field and new_field.db_index) or new_is_pk
        ):
            index_to_remove = self._create_index_name(
                model._meta.db_table, [old_field.column], suffix="_like"
            )
            self.execute(self._delete_index_sql(model, index_to_remove))

    def _index_columns(self, table, columns, col_suffixes, opclasses):
        """Use IndexColumns (which renders opclasses) when opclasses are given."""
        if opclasses:
            return IndexColumns(
                table,
                columns,
                self.quote_name,
                col_suffixes=col_suffixes,
                opclasses=opclasses,
            )
        return super()._index_columns(table, columns, col_suffixes, opclasses)

    def add_index(self, model, index, concurrently=False):
        """Create ``index``, optionally with CREATE INDEX CONCURRENTLY."""
        # params=None makes execute() skip the client-side parameter merge.
        self.execute(
            index.create_sql(model, self, concurrently=concurrently), params=None
        )

    def remove_index(self, model, index, concurrently=False):
        """Drop ``index``, optionally with DROP INDEX CONCURRENTLY."""
        self.execute(index.remove_sql(model, self, concurrently=concurrently))

    def _delete_index_sql(self, model, name, sql=None, concurrently=False):
        """
        Return a DROP INDEX statement for index ``name``.

        An explicit ``sql`` template takes precedence, mirroring
        _create_index_sql(). Otherwise the concurrent or plain template is
        chosen based on ``concurrently``.
        """
        sql = sql or (
            self.sql_delete_index_concurrently
            if concurrently
            else self.sql_delete_index
        )
        return super()._delete_index_sql(model, name, sql)

    def _create_index_sql(
        self,
        model,
        *,
        fields=None,
        name=None,
        suffix="",
        using="",
        col_suffixes=(),
        sql=None,
        opclasses=(),
        condition=None,
        concurrently=False,
        include=None,
        expressions=None,
    ):
        """
        Return a CREATE INDEX statement.

        An explicit ``sql`` template takes precedence. Otherwise the
        concurrent or plain template is chosen based on ``concurrently``. All
        other arguments are forwarded to the base implementation.
        """
        sql = sql or (
            self.sql_create_index
            if not concurrently
            else self.sql_create_index_concurrently
        )
        return super()._create_index_sql(
            model,
            fields=fields,
            name=name,
            suffix=suffix,
            using=using,
            col_suffixes=col_suffixes,
            sql=sql,
            opclasses=opclasses,
            condition=condition,
            include=include,
            expressions=expressions,
        )