1from __future__ import annotations
2
3import copy
4from decimal import Decimal
5from typing import TYPE_CHECKING, Any
6
7from plain.models import Options
8from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
9from plain.models.backends.ddl_references import Statement
10from plain.models.backends.utils import strip_quotes
11from plain.models.constraints import UniqueConstraint
12from plain.models.db import NotSupportedError
13from plain.models.registry import ModelsRegistry
14from plain.models.transaction import atomic
15
16if TYPE_CHECKING:
17 from plain.models.base import Model
18 from plain.models.constraints import BaseConstraint
19 from plain.models.fields import Field
20 from plain.models.fields.related import ManyToManyField
21 from plain.models.fields.reverse_related import ManyToManyRel
22
23
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
    """
    Schema editor for the SQLite backend.

    Overrides the SQL templates and several operations of the base editor.
    Many alterations are carried out by rebuilding the whole table (see
    `_remake_table`) instead of issuing an ALTER TABLE statement.
    """

    # Plain DROP TABLE statement used when deleting a model's table.
    sql_delete_table = "DROP TABLE %(table)s"
    # No standalone "add foreign key" statement is defined for this backend;
    # foreign keys are declared inline with the column instead.
    sql_create_fk = None
    # Inline foreign key clause; the constraint is declared deferred.
    sql_create_inline_fk = (
        "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
    )
    # Same inline clause, reused for the column-level foreign key template.
    sql_create_column_inline_fk = sql_create_inline_fk
    sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
    # Unique constraints are created and dropped as unique indexes.
    sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)"
    sql_delete_unique = "DROP INDEX %(name)s"
34
35 def __enter__(self) -> DatabaseSchemaEditor:
36 # Some SQLite schema alterations need foreign key constraints to be
37 # disabled. Enforce it here for the duration of the schema edition.
38 if not self.connection.disable_constraint_checking():
39 raise NotSupportedError(
40 "SQLite schema editor cannot be used while foreign key "
41 "constraint checks are enabled. Make sure to disable them "
42 "before entering a transaction.atomic() context because "
43 "SQLite does not support disabling them in the middle of "
44 "a multi-statement transaction."
45 )
46 return super().__enter__()
47
    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        """
        Leave the schema editing context.

        Runs the connection's constraint check first, then delegates to the
        base editor's `__exit__`, and finally turns constraint checking back
        on (it was disabled in `__enter__`).
        """
        # NOTE(review): if check_constraints() or the base __exit__ raises,
        # the statements after it are not executed.
        self.connection.check_constraints()
        super().__exit__(exc_type, exc_value, traceback)
        self.connection.enable_constraint_checking()
52
53 def quote_value(self, value: Any) -> str:
54 # The backend "mostly works" without this function and there are use
55 # cases for compiling Python without the sqlite3 libraries (e.g.
56 # security hardening).
57 try:
58 import sqlite3
59
60 value = sqlite3.adapt(value) # type: ignore[call-overload]
61 except ImportError:
62 pass
63 except sqlite3.ProgrammingError:
64 pass
65 # Manual emulation of SQLite parameter quoting
66 if isinstance(value, bool):
67 return str(int(value))
68 elif isinstance(value, Decimal | float | int):
69 return str(value)
70 elif isinstance(value, str):
71 return "'{}'".format(value.replace("'", "''"))
72 elif value is None:
73 return "NULL"
74 elif isinstance(value, bytes | bytearray | memoryview):
75 # Bytes are only allowed for BLOB fields, encoded as string
76 # literals containing hexadecimal data and preceded by a single "X"
77 # character.
78 return f"X'{value.hex()}'"
79 else:
80 raise ValueError(
81 f"Cannot quote parameter value {value!r} of type {type(value)}"
82 )
83
84 def prepare_default(self, value: Any) -> str:
85 return self.quote_value(value)
86
87 def _is_referenced_by_fk_constraint(
88 self, table_name: str, column_name: str | None = None, ignore_self: bool = False
89 ) -> bool:
90 """
91 Return whether or not the provided table name is referenced by another
92 one. If `column_name` is specified, only references pointing to that
93 column are considered. If `ignore_self` is True, self-referential
94 constraints are ignored.
95 """
96 with self.connection.cursor() as cursor:
97 for other_table in self.connection.introspection.get_table_list(cursor):
98 if ignore_self and other_table.name == table_name:
99 continue
100 relations = self.connection.introspection.get_relations(
101 cursor, other_table.name
102 )
103 for constraint_column, constraint_table in relations.values():
104 if constraint_table == table_name and (
105 column_name is None or constraint_column == column_name
106 ):
107 return True
108 return False
109
110 def alter_db_table(
111 self,
112 model: type[Model],
113 old_db_table: str,
114 new_db_table: str,
115 disable_constraints: bool = True,
116 ) -> None:
117 if (
118 not self.connection.features.supports_atomic_references_rename
119 and disable_constraints
120 and self._is_referenced_by_fk_constraint(old_db_table)
121 ):
122 if self.connection.in_atomic_block:
123 raise NotSupportedError(
124 f"Renaming the {old_db_table!r} table while in a transaction is not "
125 "supported on SQLite < 3.26 because it would break referential "
126 "integrity. Try adding `atomic = False` to the Migration class."
127 )
128 self.connection.enable_constraint_checking()
129 super().alter_db_table(model, old_db_table, new_db_table)
130 self.connection.disable_constraint_checking()
131 else:
132 super().alter_db_table(model, old_db_table, new_db_table)
133
134 def alter_field(
135 self,
136 model: type[Model],
137 old_field: Field,
138 new_field: Field,
139 strict: bool = False,
140 ) -> None:
141 if not self._field_should_be_altered(old_field, new_field):
142 return
143 old_field_name = old_field.name
144 table_name = model.model_options.db_table
145 _, old_column_name = old_field.get_attname_column()
146 if (
147 new_field.name != old_field_name
148 and not self.connection.features.supports_atomic_references_rename
149 and self._is_referenced_by_fk_constraint(
150 table_name, old_column_name, ignore_self=True
151 )
152 ):
153 if self.connection.in_atomic_block:
154 raise NotSupportedError(
155 f"Renaming the {model.model_options.db_table!r}.{old_field_name!r} column while in a transaction is not "
156 "supported on SQLite < 3.26 because it would break referential "
157 "integrity. Try adding `atomic = False` to the Migration class."
158 )
159 with atomic():
160 super().alter_field(model, old_field, new_field, strict=strict)
161 # Follow SQLite's documented procedure for performing changes
162 # that don't affect the on-disk content.
163 # https://sqlite.org/lang_altertable.html#otheralter
164 with self.connection.cursor() as cursor:
165 schema_version = cursor.execute("PRAGMA schema_version").fetchone()[
166 0
167 ]
168 cursor.execute("PRAGMA writable_schema = 1")
169 references_template = f' REFERENCES "{table_name}" ("%s") '
170 new_column_name = new_field.get_attname_column()[1]
171 search = references_template % old_column_name
172 replacement = references_template % new_column_name
173 cursor.execute(
174 "UPDATE sqlite_master SET sql = replace(sql, %s, %s)",
175 (search, replacement),
176 )
177 cursor.execute("PRAGMA schema_version = %d" % (schema_version + 1)) # noqa: UP031
178 cursor.execute("PRAGMA writable_schema = 0")
179 # The integrity check will raise an exception and rollback
180 # the transaction if the sqlite_master updates corrupt the
181 # database.
182 cursor.execute("PRAGMA integrity_check")
183 # Perform a VACUUM to refresh the database representation from
184 # the sqlite_master table.
185 with self.connection.cursor() as cursor:
186 cursor.execute("VACUUM")
187 else:
188 super().alter_field(model, old_field, new_field, strict=strict)
189
    def _remake_table(
        self,
        model: type[Model],
        create_field: Field | None = None,
        delete_field: Field | None = None,
        alter_fields: list[tuple[Field, Field]] | None = None,
    ) -> None:
        """
        Rebuild `model`'s table with an updated definition, preserving data.

        Args:
            model: The model whose table is rebuilt.
            create_field: A field to add to the rebuilt table, if any.
            delete_field: A field to leave out of the rebuilt table, if any.
            alter_fields: `(old_field, new_field)` pairs; each old field is
                replaced by the new one and its column data is carried over.

        This follows the procedure from SQLite's documentation for changes
        that ALTER TABLE cannot express:

        https://www.sqlite.org/lang_altertable.html#caution

        The essential steps are:
          1. Create a table with the updated definition called "new__app_model"
          2. Copy the data from the existing "app_model" table to the new table
          3. Drop the "app_model" table
          4. Rename the "new__app_model" table to "app_model"
          5. Restore any index of the previous "app_model" table.
        """

        # Self-referential fields must be recreated rather than copied from
        # the old model to ensure their remote_field.field_name doesn't refer
        # to an altered field.
        def is_self_referential(f: Field) -> bool:
            return f.is_relation and f.remote_field.model is model  # type: ignore[attr-defined]

        # Work out the new fields dict / mapping.
        # `body` maps field name -> field and becomes the new model's class body.
        body = {
            f.name: f.clone() if is_self_referential(f) else f
            for f in model._model_meta.local_concrete_fields
        }
        # `mapping` maps new column name -> SQL expression selected from the
        # old table. Since it might mix column names and default values,
        # its values must be already quoted.
        mapping = {
            f.column: self.quote_name(f.column)
            for f in model._model_meta.local_concrete_fields
        }
        # If any of the new or altered fields is introducing a new PK,
        # remove the old one
        restore_pk_field = None
        alter_fields = alter_fields or []
        if getattr(create_field, "primary_key", False) or any(
            getattr(new_field, "primary_key", False) for _, new_field in alter_fields
        ):
            for name, field in list(body.items()):
                if field.primary_key and not any(
                    # Do not remove the old primary key when an altered field
                    # that introduces a primary key is the same field.
                    name == new_field.name
                    for _, new_field in alter_fields
                ):
                    # The flag is flipped on the field object itself; it is
                    # switched back at the very end of this method.
                    field.primary_key = False
                    restore_pk_field = field
                    if field.auto_created:
                        del body[name]
                        del mapping[field.column]
        # Add in any created fields
        if create_field:
            body[create_field.name] = create_field
            # Choose a default and insert it into the copy map
            if not create_field.many_to_many and create_field.concrete:
                mapping[create_field.column] = self.prepare_default(
                    self.effective_default(create_field),
                )
        # Add in any altered fields
        for alter_field in alter_fields:
            old_field, new_field = alter_field
            body.pop(old_field.name, None)
            mapping.pop(old_field.column, None)
            body[new_field.name] = new_field
            if old_field.allow_null and not new_field.allow_null:
                # Nullable -> non-nullable: substitute the new field's
                # default for existing NULLs while copying.
                case_sql = f"coalesce({self.quote_name(old_field.column)}, {self.prepare_default(self.effective_default(new_field))})"
                mapping[new_field.column] = case_sql
            else:
                mapping[new_field.column] = self.quote_name(old_field.column)
        # Remove any deleted fields
        if delete_field:
            del body[delete_field.name]
            del mapping[delete_field.column]
        # Work inside a new models registry
        models_registry = ModelsRegistry()

        # Indexes that include the deleted field are left out.
        indexes = model.model_options.indexes
        if delete_field:
            indexes = [
                index for index in indexes if delete_field.name not in index.fields
            ]

        constraints = list(model.model_options.constraints)

        # Provide isolated instances of the fields to the new model body so
        # that the existing model's internals aren't interfered with when
        # the dummy model is constructed.
        body_copy = copy.deepcopy(body)

        # Construct a new model with the new fields to allow self referential
        # primary key to resolve to. This model won't ever be materialized as a
        # table and solely exists for foreign key reference resolution purposes.
        # This wouldn't be required if the schema editor was operating on model
        # states instead of rendered models.
        meta_options = Options(
            package_label=model.model_options.package_label,
            db_table=model.model_options.db_table,
            indexes=indexes,
            constraints=constraints,
        )
        body_copy["model_options"] = meta_options
        body_copy["__module__"] = model.__module__
        temp_model: type[Model] = type(  # type: ignore[assignment]
            model.model_options.object_name, model.__bases__, body_copy
        )
        models_registry.register_model(model.model_options.package_label, temp_model)

        # Construct a model with a renamed table name.
        body_copy = copy.deepcopy(body)
        meta_options = Options(
            package_label=model.model_options.package_label,
            db_table=f"new__{strip_quotes(model.model_options.db_table)}",
            indexes=indexes,
            constraints=constraints,
        )
        body_copy["model_options"] = meta_options
        body_copy["__module__"] = model.__module__
        new_model: type[Model] = type(  # type: ignore[assignment]
            f"New{model.model_options.object_name}", model.__bases__, body_copy
        )
        models_registry.register_model(model.model_options.package_label, new_model)

        # Create a new table with the updated schema.
        self.create_model(new_model)

        # Copy data from the old table into the new table. Keys of `mapping`
        # are the destination columns; values are the source expressions.
        self.execute(
            "INSERT INTO {} ({}) SELECT {} FROM {}".format(
                self.quote_name(new_model.model_options.db_table),
                ", ".join(self.quote_name(x) for x in mapping),
                ", ".join(mapping.values()),
                self.quote_name(model.model_options.db_table),
            )
        )

        # Delete the old table to make way for the new
        self.delete_model(model, handle_autom2m=False)

        # Rename the new table to take the place of the old one
        self.alter_db_table(
            new_model,  # type: ignore[arg-type]
            new_model.model_options.db_table,
            model.model_options.db_table,
            disable_constraints=False,
        )

        # Run deferred SQL on correct table
        for sql in self.deferred_sql:
            self.execute(sql)
        self.deferred_sql = []
        # Fix any PK-removed field
        if restore_pk_field:
            restore_pk_field.primary_key = True
352
353 def delete_model(self, model: type[Model], handle_autom2m: bool = True) -> None:
354 if handle_autom2m:
355 super().delete_model(model)
356 else:
357 # Delete the table (and only that)
358 self.execute(
359 self.sql_delete_table
360 % {
361 "table": self.quote_name(model.model_options.db_table),
362 }
363 )
364 # Remove all deferred statements referencing the deleted table.
365 for sql in list(self.deferred_sql):
366 if isinstance(sql, Statement) and sql.references_table(
367 model.model_options.db_table
368 ):
369 self.deferred_sql.remove(sql)
370
371 def add_field(self, model: type[Model], field: Field) -> None:
372 """Create a field on a model."""
373 if (
374 # Primary keys are not supported in ALTER TABLE
375 # ADD COLUMN.
376 field.primary_key
377 or
378 # Fields with default values cannot by handled by ALTER TABLE ADD
379 # COLUMN statement because DROP DEFAULT is not supported in
380 # ALTER TABLE.
381 not field.allow_null
382 or self.effective_default(field) is not None
383 ):
384 self._remake_table(model, create_field=field)
385 else:
386 super().add_field(model, field)
387
388 def remove_field(self, model: type[Model], field: Field) -> None:
389 """
390 Remove a field from a model. Usually involves deleting a column,
391 but for M2Ms may involve deleting a table.
392 """
393 # M2M fields are a special case
394 if field.many_to_many:
395 # For explicit "through" M2M fields, do nothing
396 pass
397 elif (
398 self.connection.features.can_alter_table_drop_column
399 # Primary keys, unique fields, indexed fields, and foreign keys are
400 # not supported in ALTER TABLE DROP COLUMN.
401 and not field.primary_key
402 and not (field.remote_field and field.db_index) # type: ignore[attr-defined]
403 and not (field.remote_field and field.db_constraint) # type: ignore[attr-defined]
404 ):
405 super().remove_field(model, field)
406 # For everything else, remake.
407 else:
408 # It might not actually have a column behind it
409 if field.db_parameters(connection=self.connection)["type"] is None:
410 return
411 self._remake_table(model, delete_field=field)
412
413 def _alter_field(
414 self,
415 model: type[Model],
416 old_field: Field,
417 new_field: Field,
418 old_type: str,
419 new_type: str,
420 old_db_params: dict[str, Any],
421 new_db_params: dict[str, Any],
422 strict: bool = False,
423 ) -> None:
424 """Perform a "physical" (non-ManyToMany) field update."""
425 # Use "ALTER TABLE ... RENAME COLUMN" if only the column name
426 # changed and there aren't any constraints.
427 if (
428 self.connection.features.can_alter_table_rename_column
429 and old_field.column != new_field.column
430 and self.column_sql(model, old_field) == self.column_sql(model, new_field)
431 and not (
432 old_field.remote_field
433 and old_field.db_constraint # type: ignore[attr-defined]
434 or new_field.remote_field
435 and new_field.db_constraint # type: ignore[attr-defined]
436 )
437 ):
438 return self.execute(
439 self._rename_field_sql(
440 model.model_options.db_table, old_field, new_field, new_type
441 )
442 )
443 # Alter by remaking table
444 self._remake_table(model, alter_fields=[(old_field, new_field)])
445 # Rebuild tables with FKs pointing to this field.
446 old_collation = old_db_params.get("collation")
447 new_collation = new_db_params.get("collation")
448 if new_field.primary_key and (
449 old_type != new_type or old_collation != new_collation
450 ):
451 related_models = set()
452 meta = new_field.model._model_meta
453 for remote_field in meta.related_objects:
454 # Ignore self-relationship since the table was already rebuilt.
455 if remote_field.related_model == model:
456 continue
457 if not remote_field.many_to_many:
458 if remote_field.field_name == new_field.name:
459 related_models.add(remote_field.related_model)
460 if new_field.primary_key:
461 for many_to_many in meta.many_to_many:
462 # Ignore self-relationship since the table was already rebuilt.
463 if many_to_many.related_model == model:
464 continue
465 for related_model in related_models:
466 self._remake_table(related_model)
467
468 def _alter_many_to_many(
469 self,
470 model: type[Model],
471 old_field: ManyToManyField,
472 new_field: ManyToManyField,
473 strict: bool,
474 ) -> None:
475 """Alter M2Ms to repoint their to= endpoints."""
476 # Type narrow for ManyToManyField.remote_field
477 old_rel: ManyToManyRel = old_field.remote_field # type: ignore[assignment]
478 new_rel: ManyToManyRel = new_field.remote_field # type: ignore[assignment]
479
480 if (
481 old_rel.through.model_options.db_table
482 == new_rel.through.model_options.db_table
483 ):
484 # The field name didn't change, but some options did, so we have to
485 # propagate this altering.
486 self._remake_table(
487 old_rel.through,
488 alter_fields=[
489 (
490 # The field that points to the target model is needed,
491 # so that table can be remade with the new m2m field -
492 # this is m2m_reverse_field_name().
493 old_rel.through._model_meta.get_field(
494 old_field.m2m_reverse_field_name() # type: ignore[attr-defined]
495 ),
496 new_rel.through._model_meta.get_field(
497 new_field.m2m_reverse_field_name() # type: ignore[attr-defined]
498 ),
499 ),
500 (
501 # The field that points to the model itself is needed,
502 # so that table can be remade with the new self field -
503 # this is m2m_field_name().
504 old_rel.through._model_meta.get_field(
505 old_field.m2m_field_name() # type: ignore[attr-defined]
506 ),
507 new_rel.through._model_meta.get_field(
508 new_field.m2m_field_name() # type: ignore[attr-defined]
509 ),
510 ),
511 ],
512 )
513 return
514
515 # Make a new through table
516 self.create_model(new_rel.through)
517 # Copy the data across
518 self.execute(
519 "INSERT INTO {} ({}) SELECT {} FROM {}".format(
520 self.quote_name(new_rel.through.model_options.db_table),
521 ", ".join(
522 [
523 "id",
524 new_field.m2m_column_name(), # type: ignore[attr-defined]
525 new_field.m2m_reverse_name(), # type: ignore[attr-defined]
526 ]
527 ),
528 ", ".join(
529 [
530 "id",
531 old_field.m2m_column_name(), # type: ignore[attr-defined]
532 old_field.m2m_reverse_name(), # type: ignore[attr-defined]
533 ]
534 ),
535 self.quote_name(old_rel.through.model_options.db_table),
536 )
537 )
538 # Delete the old through table
539 self.delete_model(old_rel.through)
540
541 def add_constraint(self, model: type[Model], constraint: BaseConstraint) -> None:
542 if isinstance(constraint, UniqueConstraint) and (
543 constraint.condition
544 or constraint.contains_expressions
545 or constraint.include
546 or constraint.deferrable
547 ):
548 super().add_constraint(model, constraint)
549 else:
550 self._remake_table(model)
551
552 def remove_constraint(self, model: type[Model], constraint: BaseConstraint) -> None:
553 if isinstance(constraint, UniqueConstraint) and (
554 constraint.condition
555 or constraint.contains_expressions
556 or constraint.include
557 or constraint.deferrable
558 ):
559 super().remove_constraint(model, constraint)
560 else:
561 self._remake_table(model)
562
563 def _collate_sql(self, collation: str) -> str:
564 return "COLLATE " + collation