1from __future__ import annotations
2
3from functools import cached_property
4from typing import TYPE_CHECKING, Any
5
6from plain import models
7from plain.models.base import ModelBase
8from plain.models.migrations.operations.base import Operation
9from plain.models.migrations.state import ModelState
10from plain.models.migrations.utils import field_references, resolve_relation
11
12from .fields import AddField, AlterField, FieldOperation, RemoveField, RenameField
13
14if TYPE_CHECKING:
15 from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
16 from plain.models.fields import Field
17 from plain.models.migrations.state import ProjectState
18
19
20def _check_for_duplicates(arg_name: str, objs: Any) -> None:
21 used_vals = set()
22 for val in objs:
23 if val in used_vals:
24 raise ValueError(
25 f"Found duplicate value {val} in CreateModel {arg_name} argument."
26 )
27 used_vals.add(val)
28
29
class ModelOperation(Operation):
    """Common base for operations that act on one model, identified by ``name``."""

    def __init__(self, name: str) -> None:
        self.name = name

    @cached_property
    def name_lower(self) -> str:
        """The model name lowercased, for case-insensitive comparisons."""
        return self.name.lower()

    def references_model(self, name: str, package_label: str) -> bool:
        # Only the model name is compared; package_label is not consulted.
        return self.name_lower == name.lower()

    def reduce(
        self, operation: Operation, package_label: str
    ) -> bool | list[Operation]:
        """Defer to ``Operation.reduce`` first, then check reduce-through."""
        reduced = super().reduce(operation, package_label)
        if reduced:
            return reduced
        return self.can_reduce_through(operation, package_label)

    def can_reduce_through(self, operation: Operation, package_label: str) -> bool:
        """Whether ``operation`` can be moved past this one.

        That is allowed only when ``operation`` does not reference this model.
        """
        is_referenced = operation.references_model(self.name, package_label)
        return not is_referenced
50
51
class CreateModel(ModelOperation):
    """Create a model's table.

    The operation carries the model name, an ordered list of
    ``(field_name, field)`` pairs, an options dict, and a tuple of base
    classes (``models.Model`` when none are given).
    """

    serialization_expand_args = ["fields", "options"]

    def __init__(
        self,
        name: str,
        fields: list[tuple[str, Field]],
        options: dict[str, Any] | None = None,
        bases: tuple[Any, ...] | None = None,
    ) -> None:
        self.fields = fields
        self.options = options or {}
        self.bases = bases or (models.Model,)
        super().__init__(name)

        def normalized_base(base: Any) -> Any:
            # Reduce a base to a comparable key: strings compare
            # case-insensitively, model classes (other than models.Model)
            # by their lowercased label, and anything else as-is.
            if isinstance(base, str):
                return base.lower()
            if base is not models.Model and hasattr(base, "_model_meta"):
                return base.model_options.label_lower
            return base

        # Refuse definitions that repeat a field name or a base class.
        _check_for_duplicates(
            "fields", (field_name for field_name, _ in self.fields)
        )
        _check_for_duplicates("bases", map(normalized_base, self.bases))

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return the (path, args, kwargs) triple, omitting default-valued kwargs."""
        kwargs: dict[str, Any] = {"name": self.name, "fields": self.fields}
        if self.options:
            kwargs["options"] = self.options
        has_custom_bases = bool(self.bases) and self.bases != (models.Model,)
        if has_custom_bases:
            kwargs["bases"] = self.bases
        return (self.__class__.__qualname__, (), kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        # Hand the state fresh copies of the containers held by this operation.
        model_state = ModelState(
            package_label,
            self.name,
            list(self.fields),
            dict(self.options),
            tuple(self.bases),
        )
        state.add_model(model_state)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        model = to_state.models_registry.get_model(package_label, self.name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        schema_editor.create_model(model)

    def describe(self) -> str:
        return f"Create model {self.name}"

    @property
    def migration_name_fragment(self) -> str:
        return self.name_lower

    def references_model(self, name: str, package_label: str) -> bool:
        """Whether this model is, inherits from, or relates to model ``name``."""
        target = name.lower()
        if target == self.name_lower:
            return True

        reference_model_tuple = (package_label, target)
        # A base class resolving to the target means we inherit from it.
        inherits_from_target = any(
            base is not models.Model
            and isinstance(base, ModelBase | str)
            and resolve_relation(base, package_label) == reference_model_tuple
            for base in self.bases
        )
        if inherits_from_target:
            return True

        # Otherwise look for a field (FK/M2M) that points at the target.
        own_model_tuple = (package_label, self.name_lower)
        return any(
            field_references(own_model_tuple, field, reference_model_tuple)
            for _, field in self.fields
        )

    def reduce(
        self, operation: Operation, package_label: str
    ) -> bool | list[Operation]:
        """Try to fold a later ``operation`` on this model into this CreateModel.

        Returns a replacement list of operations when the two can be merged;
        otherwise defers to ``ModelOperation.reduce``.
        """

        def recreate(
            name: str, fields: list[tuple[str, Field]], options: dict[str, Any]
        ) -> list[Operation]:
            # A single CreateModel with the same bases but updated contents.
            return [
                CreateModel(name, fields=fields, options=options, bases=self.bases)
            ]

        if (
            isinstance(operation, DeleteModel)
            and operation.name_lower == self.name_lower
        ):
            # Creating and then deleting the same model cancels out.
            return []

        if (
            isinstance(operation, RenameModel)
            and operation.old_name_lower == self.name_lower
        ):
            # Create the model directly under its new name.
            return recreate(operation.new_name, self.fields, self.options)

        if (
            isinstance(operation, AlterModelOptions)
            and operation.name_lower == self.name_lower
        ):
            merged_options = {**self.options, **operation.options}
            # Managed option keys that the alteration leaves out are dropped.
            for managed_key in operation.ALTER_OPTION_KEYS:
                if managed_key not in operation.options:
                    merged_options.pop(managed_key, None)
            return recreate(self.name, self.fields, merged_options)

        if (
            isinstance(operation, FieldOperation)
            and operation.model_name_lower == self.name_lower
        ):
            if isinstance(operation, AddField):
                assert operation.field is not None
                extended = self.fields + [(operation.name, operation.field)]
                return recreate(self.name, extended, self.options)
            if isinstance(operation, AlterField):
                assert operation.field is not None
                replacement = operation.field
                altered = [
                    (field_name, replacement if field_name == operation.name else field)
                    for field_name, field in self.fields
                ]
                return recreate(self.name, altered, self.options)
            if isinstance(operation, RemoveField):
                kept = [
                    (field_name, field)
                    for field_name, field in self.fields
                    if field_name.lower() != operation.name_lower
                ]
                return recreate(self.name, kept, self.options.copy())
            if isinstance(operation, RenameField):
                renamed = [
                    (
                        operation.new_name
                        if field_name == operation.old_name
                        else field_name,
                        field,
                    )
                    for field_name, field in self.fields
                ]
                return recreate(self.name, renamed, self.options.copy())

        return super().reduce(operation, package_label)
240
241
class DeleteModel(ModelOperation):
    """Drop a model's table."""

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        kwargs: dict[str, Any] = {"name": self.name}
        return (self.__class__.__qualname__, (), kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        state.remove_model(package_label, self.name_lower)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        # The model only exists in the "before" state, so look it up there.
        model = from_state.models_registry.get_model(package_label, self.name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        schema_editor.delete_model(model)

    def references_model(self, name: str, package_label: str) -> bool:
        # Conservatively True: the dropped model's related fields may point
        # at any model, so treat every model as referenced.
        return True

    def describe(self) -> str:
        return f"Delete model {self.name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"delete_{self.name_lower}"
276
277
class RenameModel(ModelOperation):
    """Rename a model from ``old_name`` to ``new_name``."""

    def __init__(self, old_name: str, new_name: str) -> None:
        self.old_name = old_name
        self.new_name = new_name
        # The operation's ``name`` is the model's pre-rename name.
        super().__init__(old_name)

    @cached_property
    def old_name_lower(self) -> str:
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self) -> str:
        return self.new_name.lower()

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        kwargs: dict[str, Any] = {"old_name": self.old_name, "new_name": self.new_name}
        return (self.__class__.__qualname__, (), kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        state.rename_model(package_label, self.old_name, self.new_name)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        new_model = to_state.models_registry.get_model(package_label, self.new_name)
        if not self.allow_migrate_model(schema_editor.connection, new_model):
            return

        old_model = from_state.models_registry.get_model(package_label, self.old_name)
        # Rename the model's own table first.
        schema_editor.alter_db_table(
            new_model,
            old_model.model_options.db_table,
            new_model.model_options.db_table,
        )

        # Then update every relation that points at the renamed model.
        for related_object in old_model._model_meta.related_objects:
            referrer = related_object.related_model
            if referrer == old_model:
                # Self-referential relation: it now lives on the renamed model.
                model = new_model
                related_key = (package_label, self.new_name_lower)
            else:
                model = referrer
                related_key = (
                    referrer.model_options.package_label,
                    referrer.model_options.model_name,
                )
            target_model = to_state.models_registry.get_model(*related_key)
            to_field = target_model._model_meta.get_field(related_object.field.name)
            schema_editor.alter_field(model, related_object.field, to_field)

    def references_model(self, name: str, package_label: str) -> bool:
        # Both the old and the new name identify the model being renamed.
        return name.lower() in (self.old_name_lower, self.new_name_lower)

    def describe(self) -> str:
        return f"Rename model {self.old_name} to {self.new_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"rename_{self.old_name_lower}_{self.new_name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> bool | list[Operation]:
        """Collapse chained renames (A->B, B->C) into a single A->C rename."""
        if (
            isinstance(operation, RenameModel)
            and operation.old_name_lower == self.new_name_lower
        ):
            return [RenameModel(self.old_name, operation.new_name)]

        # Bypass ModelOperation.reduce: later operations refer to the model by
        # its new name, so the reference check must use self.new_name.
        base_result = super(ModelOperation, self).reduce(operation, package_label)
        if base_result:
            return base_result
        return not operation.references_model(self.new_name, package_label)
372
373
class ModelOptionOperation(ModelOperation):
    """Base for operations that change a single model's options."""

    def reduce(
        self, operation: Operation, package_label: str
    ) -> bool | list[Operation]:
        # A later operation of the same kind, or a deletion, on the same model
        # supersedes this one, so only the later operation is kept.
        # Use tuple syntax because self.__class__ is not compatible with union syntax in isinstance
        if isinstance(operation, (self.__class__, DeleteModel)) and (  # noqa: UP038
            operation.name_lower == self.name_lower
        ):
            return [operation]
        return super().reduce(operation, package_label)
384
385
class AlterModelTable(ModelOptionOperation):
    """Rename a model's table.

    ``table`` may be ``None``; ``describe`` reports that as "(default)".
    """

    def __init__(self, name: str, table: str | None) -> None:
        self.table = table
        super().__init__(name)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        kwargs: dict[str, Any] = {"name": self.name, "table": self.table}
        return (self.__class__.__qualname__, (), kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        new_options = {"db_table": self.table}
        state.alter_model_options(package_label, self.name_lower, new_options)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        new_model = to_state.models_registry.get_model(package_label, self.name)
        if not self.allow_migrate_model(schema_editor.connection, new_model):
            return
        old_model = from_state.models_registry.get_model(package_label, self.name)
        # Same model in both states; only its db_table differs.
        schema_editor.alter_db_table(
            new_model,
            old_model.model_options.db_table,
            new_model.model_options.db_table,
        )

    def describe(self) -> str:
        target = "(default)" if self.table is None else self.table
        return f"Rename table for {self.name} to {target}"

    @property
    def migration_name_fragment(self) -> str:
        return f"alter_{self.name_lower}_table"
430
431
class AlterModelOptions(ModelOptionOperation):
    """
    Change model options that do not directly affect the database schema
    (for example ``ordering``). They are still tracked in migration state
    because Python code running in migrations may rely on them.
    """

    # Model options we want to compare and preserve in an AlterModelOptions op
    ALTER_OPTION_KEYS = [
        "ordering",
    ]

    def __init__(self, name: str, options: dict[str, Any]) -> None:
        self.options = options
        super().__init__(name)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        kwargs: dict[str, Any] = {"name": self.name, "options": self.options}
        return (self.__class__.__qualname__, (), kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        state.alter_model_options(
            package_label,
            self.name_lower,
            self.options,
            self.ALTER_OPTION_KEYS,
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        # Nothing to apply to the database: these options are state-only.
        return None

    def describe(self) -> str:
        return f"Change Meta options on {self.name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"alter_{self.name_lower}_options"
478
479
class IndexOperation(Operation):
    """Base for operations that act on a model's indexes.

    Subclasses may override ``option_name`` (the constraint operations use
    "constraints") and must assign ``model_name``.
    """

    option_name = "indexes"
    model_name: str  # Set by subclasses

    @cached_property
    def model_name_lower(self) -> str:
        """The target model's name lowercased."""
        lowered = self.model_name.lower()
        return lowered
487
488
class AddIndex(IndexOperation):
    """Add a named index to a model."""

    def __init__(self, model_name: str, index: Any) -> None:
        self.model_name = model_name
        # Unnamed indexes cannot be referenced later, so reject them early.
        if not index.name:
            raise ValueError(
                "Indexes passed to AddIndex operations require a name argument. "
                f"{index!r} doesn't have one."
            )
        self.index = index

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        state.add_index(package_label, self.model_name_lower, self.index)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        schema_editor.add_index(model, self.index)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        kwargs: dict[str, Any] = {"model_name": self.model_name, "index": self.index}
        return (self.__class__.__qualname__, (), kwargs)

    def describe(self) -> str:
        index_name = self.index.name
        if self.index.expressions:
            rendered = ", ".join(str(expression) for expression in self.index.expressions)
            return f"Create index {index_name} on {rendered} on model {self.model_name}"
        columns = ", ".join(self.index.fields)
        return (
            f"Create index {index_name} on field(s) {columns} of model {self.model_name}"
        )

    @property
    def migration_name_fragment(self) -> str:
        return f"{self.model_name_lower}_{self.index.name.lower()}"
542
543
class RemoveIndex(IndexOperation):
    """Remove the index called ``name`` from a model."""

    def __init__(self, model_name: str, name: str) -> None:
        self.model_name = model_name
        self.name = name

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        state.remove_index(package_label, self.model_name_lower, self.name)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        model = from_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        # The index definition is only present in the "before" state.
        model_state = from_state.models[package_label, self.model_name_lower]
        schema_editor.remove_index(model, model_state.get_index_by_name(self.name))

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        kwargs: dict[str, Any] = {"model_name": self.model_name, "name": self.name}
        return (self.__class__.__qualname__, (), kwargs)

    def describe(self) -> str:
        return f"Remove index {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"remove_{self.model_name_lower}_{self.name.lower()}"
584
585
class RenameIndex(IndexOperation):
    """Rename an index.

    The existing index is identified either by ``old_name`` or, for an
    unnamed index, by the ``old_fields`` it covers. Exactly one of the two
    must be supplied.
    """

    def __init__(
        self,
        model_name: str,
        new_name: str,
        old_name: str | None = None,
        old_fields: list[str] | tuple[str, ...] | None = None,
    ) -> None:
        has_old_name = bool(old_name)
        has_old_fields = bool(old_fields)
        if not (has_old_name or has_old_fields):
            raise ValueError(
                "RenameIndex requires one of old_name and old_fields arguments to be set."
            )
        if has_old_name and has_old_fields:
            raise ValueError(
                "RenameIndex.old_name and old_fields are mutually exclusive."
            )
        self.model_name = model_name
        self.new_name = new_name
        self.old_name = old_name
        self.old_fields = old_fields

    @cached_property
    def old_name_lower(self) -> str:
        assert self.old_name is not None, "old_name is set during initialization"
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self) -> str:
        return self.new_name.lower()

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        kwargs: dict[str, Any] = {"model_name": self.model_name, "new_name": self.new_name}
        # Only the identifier that was actually supplied is serialized.
        if self.old_name:
            kwargs["old_name"] = self.old_name
        if self.old_fields:
            kwargs["old_fields"] = self.old_fields
        return (self.__class__.__qualname__, (), kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        if not self.old_fields:
            assert self.old_name is not None
            state.rename_index(
                package_label, self.model_name_lower, self.old_name, self.new_name
            )
            return
        # An unnamed index gets registered in state under its new name.
        named_index = models.Index(fields=self.old_fields, name=self.new_name)
        state.add_index(package_label, self.model_name_lower, named_index)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return None

        if not self.old_fields:
            from_model_state = from_state.models[package_label, self.model_name_lower]
            assert self.old_name is not None
            old_index = from_model_state.get_index_by_name(self.old_name)
        else:
            from_model = from_state.models_registry.get_model(
                package_label, self.model_name
            )
            columns = [
                from_model._model_meta.get_forward_field(field_name).column
                for field_name in self.old_fields
            ]
            # Ask the database which index covers exactly these columns.
            matching_index_name = schema_editor._constraint_names(
                from_model, column_names=columns, index=True
            )
            found = len(matching_index_name)
            if found != 1:
                table = from_model.model_options.db_table
                column_list = ", ".join(columns)
                raise ValueError(
                    f"Found wrong number ({found}) of indexes for {table}({column_list})."
                )
            old_index = models.Index(
                fields=self.old_fields, name=matching_index_name[0]
            )

        # Nothing to do when the index already carries the requested name.
        if old_index.name == self.new_name:
            return None

        to_model_state = to_state.models[package_label, self.model_name_lower]
        new_index = to_model_state.get_index_by_name(self.new_name)
        schema_editor.rename_index(model, old_index, new_index)
        return None

    def describe(self) -> str:
        if not self.old_name:
            return (
                f"Rename unnamed index for {self.old_fields} on {self.model_name} to "
                f"{self.new_name}"
            )
        return f"Rename index {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self) -> str:
        if not self.old_name:
            assert self.old_fields is not None, "old_fields is set when old_name is None"
            model_part = self.model_name_lower
            fields_part = "_".join(self.old_fields)
            return f"rename_{model_part}_{fields_part}_{self.new_name_lower}"
        return f"rename_{self.old_name_lower}_{self.new_name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> bool | list[Operation]:
        """Merge a follow-up rename of the same index (by name) into this one."""
        chains_onto_this = (
            isinstance(operation, RenameIndex)
            and self.model_name_lower == operation.model_name_lower
            and operation.old_name
            and self.new_name_lower == operation.old_name_lower
        )
        if chains_onto_this:
            assert isinstance(operation, RenameIndex)
            return [
                RenameIndex(
                    self.model_name,
                    new_name=operation.new_name,
                    old_name=self.old_name,
                    old_fields=self.old_fields,
                )
            ]
        return super().reduce(operation, package_label)
732
733
class AddConstraint(IndexOperation):
    """Add a constraint to a model."""

    option_name = "constraints"

    def __init__(self, model_name: str, constraint: Any) -> None:
        self.model_name = model_name
        self.constraint = constraint

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        state.add_constraint(package_label, self.model_name_lower, self.constraint)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection, model):
            schema_editor.add_constraint(model, self.constraint)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return the (path, args, kwargs) triple used to serialize this operation."""
        kwargs: dict[str, Any] = {
            "model_name": self.model_name,
            "constraint": self.constraint,
        }
        # Use __qualname__ like every other operation in this module.
        return (self.__class__.__qualname__, (), kwargs)

    def describe(self) -> str:
        return f"Create constraint {self.constraint.name} on model {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"{self.model_name_lower}_{self.constraint.name.lower()}"
771
772
class RemoveConstraint(IndexOperation):
    """Remove the constraint called ``name`` from a model."""

    option_name = "constraints"

    def __init__(self, model_name: str, name: str) -> None:
        self.model_name = model_name
        self.name = name

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        state.remove_constraint(package_label, self.model_name_lower, self.name)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection, model):
            # The constraint definition is looked up in the "before" state.
            from_model_state = from_state.models[package_label, self.model_name_lower]
            constraint = from_model_state.get_constraint_by_name(self.name)
            schema_editor.remove_constraint(model, constraint)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return the (path, args, kwargs) triple used to serialize this operation."""
        kwargs: dict[str, Any] = {
            "model_name": self.model_name,
            "name": self.name,
        }
        # Use __qualname__ like every other operation in this module.
        return (self.__class__.__qualname__, (), kwargs)

    def describe(self) -> str:
        return f"Remove constraint {self.name} from model {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"remove_{self.model_name_lower}_{self.name.lower()}"