1from __future__ import annotations
2
3import copy
4from decimal import Decimal
5from typing import TYPE_CHECKING, Any, cast
6
7from plain.models import Options
8from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
9from plain.models.backends.ddl_references import Statement
10from plain.models.backends.utils import strip_quotes
11from plain.models.constraints import UniqueConstraint
12from plain.models.db import NotSupportedError
13from plain.models.fields import DbParameters
14from plain.models.fields.related import ForeignKeyField, RelatedField
15from plain.models.registry import ModelsRegistry
16from plain.models.transaction import atomic
17
18if TYPE_CHECKING:
19 from plain.models.backends.sqlite3.base import SQLiteDatabaseWrapper
20 from plain.models.base import Model
21 from plain.models.constraints import BaseConstraint
22 from plain.models.fields import Field
23 from plain.models.fields.related import ManyToManyField
24 from plain.models.fields.reverse_related import ManyToManyRel
25
26
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
    """SQLite schema editor layered on top of BaseDatabaseSchemaEditor.

    Overrides the base SQL templates below and several alteration hooks.
    Changes that are not issued as a single ALTER TABLE statement are
    routed through ``_remake_table``, which rebuilds the whole table.
    """

    # Type checker hint: connection is always SQLiteDatabaseWrapper in this class
    connection: SQLiteDatabaseWrapper

    # SQL templates. The %(name)s placeholders are filled in from a mapping
    # with the ``%`` operator (see ``delete_model`` for an example).
    sql_delete_table = "DROP TABLE %(table)s"
    # No standalone "create foreign key" statement is defined for this editor;
    # only the inline REFERENCES templates below are provided.
    sql_create_fk = None
    # Inline REFERENCES clause; the constraint is declared deferrable and
    # initially deferred.
    sql_create_inline_fk = (
        "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
    )
    # The same clause is reused for the "column inline" template.
    sql_create_column_inline_fk = sql_create_inline_fk
    sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
    # Unique constraints are created and dropped as unique indexes.
    sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)"
    sql_delete_unique = "DROP INDEX %(name)s"
40
41 def __enter__(self) -> DatabaseSchemaEditor:
42 # Some SQLite schema alterations need foreign key constraints to be
43 # disabled. Enforce it here for the duration of the schema edition.
44 if not self.connection.disable_constraint_checking():
45 raise NotSupportedError(
46 "SQLite schema editor cannot be used while foreign key "
47 "constraint checks are enabled. Make sure to disable them "
48 "before entering a transaction.atomic() context because "
49 "SQLite does not support disabling them in the middle of "
50 "a multi-statement transaction."
51 )
52 return super().__enter__()
53
    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        """Leave the schema-editing context.

        Runs the connection's constraint check, delegates to the base class
        ``__exit__`` and then re-enables constraint checking, undoing the
        disable performed in ``__enter__``.
        """
        # NOTE(review): there is no try/finally here, so if either of the
        # first two calls raises, enable_constraint_checking() is not reached.
        self.connection.check_constraints()
        super().__exit__(exc_type, exc_value, traceback)
        self.connection.enable_constraint_checking()
58
59 def quote_value(self, value: Any) -> str:
60 # The backend "mostly works" without this function and there are use
61 # cases for compiling Python without the sqlite3 libraries (e.g.
62 # security hardening).
63 try:
64 import sqlite3
65
66 value = sqlite3.adapt(value) # type: ignore[call-overload]
67 except ImportError:
68 pass
69 except sqlite3.ProgrammingError:
70 pass
71 # Manual emulation of SQLite parameter quoting
72 if isinstance(value, bool):
73 return str(int(value))
74 elif isinstance(value, Decimal | float | int):
75 return str(value)
76 elif isinstance(value, str):
77 return "'{}'".format(value.replace("'", "''"))
78 elif value is None:
79 return "NULL"
80 elif isinstance(value, bytes | bytearray | memoryview):
81 # Bytes are only allowed for BLOB fields, encoded as string
82 # literals containing hexadecimal data and preceded by a single "X"
83 # character.
84 return f"X'{value.hex()}'"
85 else:
86 raise ValueError(
87 f"Cannot quote parameter value {value!r} of type {type(value)}"
88 )
89
90 def prepare_default(self, value: Any) -> str:
91 return self.quote_value(value)
92
93 def _is_referenced_by_fk_constraint(
94 self, table_name: str, column_name: str | None = None, ignore_self: bool = False
95 ) -> bool:
96 """
97 Return whether or not the provided table name is referenced by another
98 one. If `column_name` is specified, only references pointing to that
99 column are considered. If `ignore_self` is True, self-referential
100 constraints are ignored.
101 """
102 with self.connection.cursor() as cursor:
103 for other_table in self.connection.introspection.get_table_list(cursor):
104 if ignore_self and other_table.name == table_name:
105 continue
106 relations = self.connection.introspection.get_relations(
107 cursor, other_table.name
108 )
109 for constraint_column, constraint_table in relations.values():
110 if constraint_table == table_name and (
111 column_name is None or constraint_column == column_name
112 ):
113 return True
114 return False
115
116 def alter_db_table(
117 self,
118 model: type[Model],
119 old_db_table: str,
120 new_db_table: str,
121 disable_constraints: bool = True,
122 ) -> None:
123 if (
124 not self.connection.features.supports_atomic_references_rename
125 and disable_constraints
126 and self._is_referenced_by_fk_constraint(old_db_table)
127 ):
128 if self.connection.in_atomic_block:
129 raise NotSupportedError(
130 f"Renaming the {old_db_table!r} table while in a transaction is not "
131 "supported on SQLite < 3.26 because it would break referential "
132 "integrity. Try adding `atomic = False` to the Migration class."
133 )
134 self.connection.enable_constraint_checking()
135 super().alter_db_table(model, old_db_table, new_db_table)
136 self.connection.disable_constraint_checking()
137 else:
138 super().alter_db_table(model, old_db_table, new_db_table)
139
    def alter_field(
        self,
        model: type[Model],
        old_field: Field,
        new_field: Field,
        strict: bool = False,
    ) -> None:
        """Alter ``old_field`` into ``new_field`` on ``model``.

        Does nothing when ``_field_should_be_altered`` reports no change.
        A field rename on a connection without atomic references rename
        support, where another table references the old column, takes a
        special path: after the regular alteration, the REFERENCES clauses
        stored in ``sqlite_master`` are rewritten to the new column name.
        Every other case simply delegates to the base implementation.

        Raises:
            NotSupportedError: if the special path is needed while the
                connection is inside an atomic block.
        """
        if not self._field_should_be_altered(old_field, new_field):
            return
        old_field_name = old_field.name
        table_name = model.model_options.db_table
        _, old_column_name = old_field.get_attname_column()
        if (
            new_field.name != old_field_name
            and not self.connection.features.supports_atomic_references_rename
            and self._is_referenced_by_fk_constraint(
                table_name, old_column_name, ignore_self=True
            )
        ):
            if self.connection.in_atomic_block:
                raise NotSupportedError(
                    f"Renaming the {model.model_options.db_table!r}.{old_field_name!r} column while in a transaction is not "
                    "supported on SQLite < 3.26 because it would break referential "
                    "integrity. Try adding `atomic = False` to the Migration class."
                )
            with atomic():
                super().alter_field(model, old_field, new_field, strict=strict)
                # Follow SQLite's documented procedure for performing changes
                # that don't affect the on-disk content.
                # https://sqlite.org/lang_altertable.html#otheralter
                with self.connection.cursor() as cursor:
                    # Remember the current schema version so it can be bumped
                    # after the schema text has been edited.
                    row = cursor.execute("PRAGMA schema_version").fetchone()
                    assert row is not None
                    schema_version = row[0]
                    cursor.execute("PRAGMA writable_schema = 1")
                    # Textual search/replace of the REFERENCES clause that
                    # names this table and the old column.
                    references_template = f' REFERENCES "{table_name}" ("%s") '
                    new_column_name = new_field.get_attname_column()[1]
                    search = references_template % old_column_name
                    replacement = references_template % new_column_name
                    cursor.execute(
                        "UPDATE sqlite_master SET sql = replace(sql, %s, %s)",
                        (search, replacement),
                    )
                    cursor.execute("PRAGMA schema_version = %d" % (schema_version + 1))  # noqa: UP031
                    cursor.execute("PRAGMA writable_schema = 0")
                    # The integrity check will raise an exception and rollback
                    # the transaction if the sqlite_master updates corrupt the
                    # database.
                    cursor.execute("PRAGMA integrity_check")
            # Perform a VACUUM to refresh the database representation from
            # the sqlite_master table. This runs after the atomic block above
            # has been exited.
            with self.connection.cursor() as cursor:
                cursor.execute("VACUUM")
        else:
            super().alter_field(model, old_field, new_field, strict=strict)
195
    def _remake_table(
        self,
        model: type[Model],
        create_field: Field | None = None,
        delete_field: Field | None = None,
        alter_fields: list[tuple[Field, Field]] | None = None,
    ) -> None:
        """
        Rebuild the table of `model`, optionally adding, removing or altering
        fields along the way.

        This follows the procedure recommended by SQLite's documentation for
        changes other than renames or plain column additions:

        https://www.sqlite.org/lang_altertable.html#caution

        The essential steps are:
          1. Create a table with the updated definition called "new__app_model"
          2. Copy the data from the existing "app_model" table to the new table
          3. Drop the "app_model" table
          4. Rename the "new__app_model" table to "app_model"
          5. Restore any index of the previous "app_model" table.

        Arguments:
          model: the model whose table is rebuilt.
          create_field: a field to add to the rebuilt table, if any.
          delete_field: a field to leave out of the rebuilt table, if any.
          alter_fields: (old_field, new_field) pairs to swap in the rebuilt
            table; None is treated as an empty list.
        """

        # Self-referential fields must be recreated rather than copied from
        # the old model to ensure their remote_field.field_name doesn't refer
        # to an altered field.
        def is_self_referential(f: Field) -> bool:
            return isinstance(f, RelatedField) and f.remote_field.model is model

        # Work out the new fields dict / mapping
        body = {
            f.name: f.clone() if is_self_referential(f) else f
            for f in model._model_meta.local_concrete_fields
        }
        # Since mapping might mix column names and default values,
        # its values must be already quoted.
        # Keys are destination columns; values are the SQL expressions
        # selected from the old table to populate them.
        mapping = {
            f.column: self.quote_name(f.column)
            for f in model._model_meta.local_concrete_fields
        }
        # If any of the new or altered fields is introducing a new PK,
        # remove the old one
        restore_pk_field = None
        alter_fields = alter_fields or []
        if getattr(create_field, "primary_key", False) or any(
            getattr(new_field, "primary_key", False) for _, new_field in alter_fields
        ):
            for name, field in list(body.items()):
                if field.primary_key and not any(
                    # Do not remove the old primary key when an altered field
                    # that introduces a primary key is the same field.
                    name == new_field.name
                    for _, new_field in alter_fields
                ):
                    # The flag is flipped on the field object itself and is
                    # set back to True at the end of this method.
                    # NOTE(review): nothing restores it if an exception is
                    # raised in between (no try/finally).
                    field.primary_key = False
                    restore_pk_field = field
                    if field.auto_created:
                        del body[name]
                        del mapping[field.column]
        # Add in any created fields
        if create_field:
            body[create_field.name] = create_field
            # Choose a default and insert it into the copy map
            from plain.models.fields.related import ManyToManyField

            if not isinstance(create_field, ManyToManyField) and create_field.concrete:
                mapping[create_field.column] = self.prepare_default(
                    self.effective_default(create_field),
                )
        # Add in any altered fields
        for alter_field in alter_fields:
            old_field, new_field = alter_field
            body.pop(old_field.name, None)
            mapping.pop(old_field.column, None)
            body[new_field.name] = new_field
            if old_field.allow_null and not new_field.allow_null:
                # The column becomes NOT NULL: replace existing NULLs with the
                # new field's effective default while copying the rows.
                case_sql = f"coalesce({self.quote_name(old_field.column)}, {self.prepare_default(self.effective_default(new_field))})"
                mapping[new_field.column] = case_sql
            else:
                mapping[new_field.column] = self.quote_name(old_field.column)
        # Remove any deleted fields
        if delete_field:
            del body[delete_field.name]
            del mapping[delete_field.column]
        # Work inside a new app registry
        models_registry = ModelsRegistry()

        indexes = model.model_options.indexes
        if delete_field:
            # Drop indexes that mention the field being deleted.
            indexes = [
                index for index in indexes if delete_field.name not in index.fields
            ]

        constraints = list(model.model_options.constraints)

        # Provide isolated instances of the fields to the new model body so
        # that the existing model's internals aren't interfered with when
        # the dummy model is constructed.
        body_copy = copy.deepcopy(body)

        # Construct a new model with the new fields to allow self referential
        # primary key to resolve to. This model won't ever be materialized as a
        # table and solely exists for foreign key reference resolution purposes.
        # This wouldn't be required if the schema editor was operating on model
        # states instead of rendered models.
        meta_options = Options(
            package_label=model.model_options.package_label,
            db_table=model.model_options.db_table,
            indexes=indexes,
            constraints=constraints,
        )
        body_copy["model_options"] = meta_options
        body_copy["__module__"] = model.__module__
        temp_model = cast(
            "type[Model]",
            type(model.model_options.object_name, model.__bases__, body_copy),
        )
        models_registry.register_model(model.model_options.package_label, temp_model)

        # Construct a model with a renamed table name.
        body_copy = copy.deepcopy(body)
        meta_options = Options(
            package_label=model.model_options.package_label,
            db_table=f"new__{strip_quotes(model.model_options.db_table)}",
            indexes=indexes,
            constraints=constraints,
        )
        body_copy["model_options"] = meta_options
        body_copy["__module__"] = model.__module__
        new_model = cast(
            "type[Model]",
            type(f"New{model.model_options.object_name}", model.__bases__, body_copy),
        )
        models_registry.register_model(model.model_options.package_label, new_model)

        # Create a new table with the updated schema.
        self.create_model(new_model)

        # Copy data from the old table into the new table
        self.execute(
            "INSERT INTO {} ({}) SELECT {} FROM {}".format(
                self.quote_name(new_model.model_options.db_table),
                ", ".join(self.quote_name(x) for x in mapping),
                ", ".join(mapping.values()),
                self.quote_name(model.model_options.db_table),
            )
        )

        # Delete the old table to make way for the new
        self.delete_model(model, handle_autom2m=False)

        # Rename the new table to take the place of the old one
        self.alter_db_table(
            new_model,
            new_model.model_options.db_table,
            model.model_options.db_table,
            disable_constraints=False,
        )

        # Run deferred SQL on correct table
        for sql in self.deferred_sql:
            self.execute(sql)
        self.deferred_sql = []
        # Fix any PK-removed field
        if restore_pk_field:
            restore_pk_field.primary_key = True
362
363 def delete_model(self, model: type[Model], handle_autom2m: bool = True) -> None:
364 if handle_autom2m:
365 super().delete_model(model)
366 else:
367 # Delete the table (and only that)
368 self.execute(
369 self.sql_delete_table
370 % {
371 "table": self.quote_name(model.model_options.db_table),
372 }
373 )
374 # Remove all deferred statements referencing the deleted table.
375 for sql in list(self.deferred_sql):
376 if isinstance(sql, Statement) and sql.references_table(
377 model.model_options.db_table
378 ):
379 self.deferred_sql.remove(sql)
380
381 def add_field(self, model: type[Model], field: Field) -> None:
382 """Create a field on a model."""
383 if (
384 # Primary keys are not supported in ALTER TABLE
385 # ADD COLUMN.
386 field.primary_key
387 or
388 # Fields with default values cannot by handled by ALTER TABLE ADD
389 # COLUMN statement because DROP DEFAULT is not supported in
390 # ALTER TABLE.
391 not field.allow_null
392 or self.effective_default(field) is not None
393 ):
394 self._remake_table(model, create_field=field)
395 else:
396 super().add_field(model, field)
397
398 def remove_field(self, model: type[Model], field: Field) -> None:
399 """
400 Remove a field from a model. Usually involves deleting a column,
401 but for M2Ms may involve deleting a table.
402 """
403 from plain.models.fields.related import ManyToManyField
404
405 # M2M fields are a special case
406 if isinstance(field, ManyToManyField):
407 # For explicit "through" M2M fields, do nothing
408 pass
409 elif (
410 self.connection.features.can_alter_table_drop_column
411 # Primary keys, unique fields, indexed fields, and foreign keys are
412 # not supported in ALTER TABLE DROP COLUMN.
413 and not field.primary_key
414 and not (isinstance(field, ForeignKeyField) and field.db_index)
415 and not (isinstance(field, ForeignKeyField) and field.db_constraint)
416 ):
417 super().remove_field(model, field)
418 # For everything else, remake.
419 else:
420 # It might not actually have a column behind it
421 if field.db_parameters(connection=self.connection)["type"] is None:
422 return
423 self._remake_table(model, delete_field=field)
424
425 def _alter_field(
426 self,
427 model: type[Model],
428 old_field: Field,
429 new_field: Field,
430 old_type: str,
431 new_type: str,
432 old_db_params: DbParameters,
433 new_db_params: DbParameters,
434 strict: bool = False,
435 ) -> None:
436 """Perform a "physical" (non-ManyToMany) field update."""
437 # Use "ALTER TABLE ... RENAME COLUMN" if only the column name
438 # changed and there aren't any constraints.
439 if (
440 self.connection.features.can_alter_table_rename_column
441 and old_field.column != new_field.column
442 and self.column_sql(model, old_field) == self.column_sql(model, new_field)
443 and not (
444 isinstance(old_field, ForeignKeyField)
445 and old_field.db_constraint
446 or isinstance(new_field, ForeignKeyField)
447 and new_field.db_constraint
448 )
449 ):
450 return self.execute(
451 self._rename_field_sql(
452 model.model_options.db_table, old_field, new_field, new_type
453 )
454 )
455 # Alter by remaking table
456 self._remake_table(model, alter_fields=[(old_field, new_field)])
457 # Rebuild tables with FKs pointing to this field.
458 old_collation = old_db_params.get("collation")
459 new_collation = new_db_params.get("collation")
460 if new_field.primary_key and (
461 old_type != new_type or old_collation != new_collation
462 ):
463 from plain.models.fields.reverse_related import ManyToManyRel
464
465 related_models = set()
466 meta = new_field.model._model_meta
467 for remote_field in meta.related_objects:
468 # Ignore self-relationship since the table was already rebuilt.
469 if remote_field.related_model == model:
470 continue
471 if not isinstance(remote_field, ManyToManyRel):
472 if remote_field.field_name == new_field.name:
473 related_models.add(remote_field.related_model)
474 if new_field.primary_key:
475 for many_to_many in meta.many_to_many:
476 # Ignore self-relationship since the table was already rebuilt.
477 if many_to_many.related_model == model:
478 continue
479 for related_model in related_models:
480 self._remake_table(related_model)
481
482 def _alter_many_to_many(
483 self,
484 model: type[Model],
485 old_field: ManyToManyField,
486 new_field: ManyToManyField,
487 strict: bool,
488 ) -> None:
489 """Alter M2Ms to repoint their to= endpoints."""
490 # Type narrow for ManyToManyField.remote_field
491 old_rel: ManyToManyRel = old_field.remote_field
492 new_rel: ManyToManyRel = new_field.remote_field
493
494 if (
495 old_rel.through.model_options.db_table
496 == new_rel.through.model_options.db_table
497 ):
498 # The field name didn't change, but some options did, so we have to
499 # propagate this altering.
500 self._remake_table(
501 old_rel.through,
502 alter_fields=[
503 (
504 # The field that points to the target model is needed,
505 # so that table can be remade with the new m2m field -
506 # this is m2m_reverse_field_name().
507 old_rel.through._model_meta.get_forward_field(
508 old_field.m2m_reverse_field_name()
509 ),
510 new_rel.through._model_meta.get_forward_field(
511 new_field.m2m_reverse_field_name()
512 ),
513 ),
514 (
515 # The field that points to the model itself is needed,
516 # so that table can be remade with the new self field -
517 # this is m2m_field_name().
518 old_rel.through._model_meta.get_forward_field(
519 old_field.m2m_field_name()
520 ),
521 new_rel.through._model_meta.get_forward_field(
522 new_field.m2m_field_name()
523 ),
524 ),
525 ],
526 )
527 return
528
529 # Make a new through table
530 self.create_model(new_rel.through)
531 # Copy the data across
532 self.execute(
533 "INSERT INTO {} ({}) SELECT {} FROM {}".format(
534 self.quote_name(new_rel.through.model_options.db_table),
535 ", ".join(
536 [
537 "id",
538 new_field.m2m_column_name(),
539 new_field.m2m_reverse_name(),
540 ]
541 ),
542 ", ".join(
543 [
544 "id",
545 old_field.m2m_column_name(),
546 old_field.m2m_reverse_name(),
547 ]
548 ),
549 self.quote_name(old_rel.through.model_options.db_table),
550 )
551 )
552 # Delete the old through table
553 self.delete_model(old_rel.through)
554
555 def add_constraint(self, model: type[Model], constraint: BaseConstraint) -> None:
556 if isinstance(constraint, UniqueConstraint) and (
557 constraint.condition
558 or constraint.contains_expressions
559 or constraint.include
560 or constraint.deferrable
561 ):
562 super().add_constraint(model, constraint)
563 else:
564 self._remake_table(model)
565
566 def remove_constraint(self, model: type[Model], constraint: BaseConstraint) -> None:
567 if isinstance(constraint, UniqueConstraint) and (
568 constraint.condition
569 or constraint.contains_expressions
570 or constraint.include
571 or constraint.deferrable
572 ):
573 super().remove_constraint(model, constraint)
574 else:
575 self._remake_table(model)
576
577 def _collate_sql(
578 self,
579 collation: str | None,
580 old_collation: str | None = None,
581 table_name: str | None = None,
582 ) -> str:
583 return "COLLATE " + collation if collation else ""