1from __future__ import annotations
2
3from functools import cached_property
4from typing import TYPE_CHECKING, Any
5
6from plain import models
7from plain.models.base import ModelBase
8from plain.models.migrations.operations.base import Operation
9from plain.models.migrations.state import ModelState
10from plain.models.migrations.utils import field_references, resolve_relation
11
12from .fields import AddField, AlterField, FieldOperation, RemoveField, RenameField
13
14if TYPE_CHECKING:
15 from plain.models.fields import Field
16 from plain.models.migrations.state import ProjectState
17 from plain.models.postgres.schema import DatabaseSchemaEditor
18
19
20def _check_for_duplicates(arg_name: str, objs: Any) -> None:
21 used_vals = set()
22 for val in objs:
23 if val in used_vals:
24 raise ValueError(
25 f"Found duplicate value {val} in CreateModel {arg_name} argument."
26 )
27 used_vals.add(val)
28
29
class ModelOperation(Operation):
    """Base class for migration operations that act on one model, ``name``."""

    def __init__(self, name: str) -> None:
        self.name = name

    @cached_property
    def name_lower(self) -> str:
        """Lower-cased ``name``, used for case-insensitive model matching."""
        return self.name.lower()

    def references_model(self, name: str, package_label: str) -> bool:
        """Return whether *name* (compared case-insensitively) is this model.

        ``package_label`` is part of the shared interface but is not
        consulted here.
        """
        return self.name_lower == name.lower()

    def reduce(
        self, operation: Operation, package_label: str
    ) -> bool | list[Operation]:
        """Try to optimize this operation together with a later *operation*.

        A truthy result from ``Operation.reduce`` wins; otherwise the answer
        comes from ``can_reduce_through``.
        """
        inherited = super().reduce(operation, package_label)
        if inherited:
            return inherited
        return self.can_reduce_through(operation, package_label)

    def can_reduce_through(self, operation: Operation, package_label: str) -> bool:
        """Return True when *operation* does not reference this model."""
        touches_model = operation.references_model(self.name, package_label)
        return not touches_model
50
51
class CreateModel(ModelOperation):
    """Create the database table for a new model.

    ``fields`` is a list of ``(name, field)`` pairs, ``options`` holds the
    model options (an empty dict when omitted) and ``bases`` the model's base
    classes (``(models.Model,)`` when omitted).
    """

    serialization_expand_args = ["fields", "options"]

    def __init__(
        self,
        name: str,
        fields: list[tuple[str, Field]],
        options: dict[str, Any] | None = None,
        bases: tuple[Any, ...] | None = None,
    ) -> None:
        self.fields = fields
        self.options = options or {}
        self.bases = bases or (models.Model,)
        super().__init__(name)
        # Reject repeated field names and repeated base classes up front.
        _check_for_duplicates(
            "fields", (field_name for field_name, _field in self.fields)
        )
        _check_for_duplicates(
            "bases", (self._base_identity(base) for base in self.bases)
        )

    @staticmethod
    def _base_identity(base: Any) -> Any:
        """Return the value compared when looking for duplicate ``bases``.

        String references are lower-cased. Classes other than
        ``models.Model`` that expose ``_model_meta`` are represented by their
        ``model_options.label_lower``. Anything else is compared as-is.
        """
        if isinstance(base, str):
            return base.lower()
        if base is not models.Model and hasattr(base, "_model_meta"):
            return base.model_options.label_lower
        return base

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        kwargs: dict[str, Any] = {"name": self.name, "fields": self.fields}
        # Omit arguments that still hold their constructor defaults.
        if self.options:
            kwargs["options"] = self.options
        has_custom_bases = self.bases and self.bases != (models.Model,)
        if has_custom_bases:
            kwargs["bases"] = self.bases
        return (self.__class__.__qualname__, (), kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        # Hand the project state copies of the containers, not this
        # operation's own list/dict/tuple objects.
        new_model_state = ModelState(
            package_label,
            self.name,
            list(self.fields),
            dict(self.options),
            tuple(self.bases),
        )
        state.add_model(new_model_state)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        # The model only exists in the state produced by this operation.
        created_model = to_state.models_registry.get_model(package_label, self.name)
        schema_editor.create_model(created_model)

    def describe(self) -> str:
        return f"Create model {self.name}"

    @property
    def migration_name_fragment(self) -> str:
        return self.name_lower

    def references_model(self, name: str, package_label: str) -> bool:
        """Return whether this operation depends on the model *name*.

        True when *name* is the model being created, when one of ``bases``
        resolves to it, or when ``field_references`` reports one of
        ``fields`` as pointing at it.
        """
        target = name.lower()
        if target == self.name_lower:
            return True

        reference_model_tuple = (package_label, target)
        # Inheriting from the model counts as a reference.
        if any(
            base is not models.Model
            and isinstance(base, ModelBase | str)
            and resolve_relation(base, package_label) == reference_model_tuple
            for base in self.bases
        ):
            return True

        # So does any FK/M2M-style field that points at it.
        this_model = (package_label, self.name_lower)
        return any(
            field_references(this_model, field, reference_model_tuple)
            for _field_name, field in self.fields
        )

    def reduce(
        self, operation: Operation, package_label: str
    ) -> bool | list[Operation]:
        """Fold a later operation on this same model into this CreateModel.

        A matching DeleteModel cancels this operation entirely. RenameModel,
        AlterModelOptions and the field operations handled by
        ``_reduce_field_operation`` produce one replacement CreateModel.
        Anything else is delegated to ``ModelOperation.reduce``.
        """
        if (
            isinstance(operation, DeleteModel)
            and self.name_lower == operation.name_lower
        ):
            return []

        if (
            isinstance(operation, RenameModel)
            and self.name_lower == operation.old_name_lower
        ):
            renamed = CreateModel(
                operation.new_name,
                fields=self.fields,
                options=self.options,
                bases=self.bases,
            )
            return [renamed]

        if (
            isinstance(operation, AlterModelOptions)
            and self.name_lower == operation.name_lower
        ):
            merged_options = {**self.options, **operation.options}
            # Keys listed in ALTER_OPTION_KEYS but absent from the new
            # options are removed rather than carried over.
            for option_key in operation.ALTER_OPTION_KEYS:
                if option_key not in operation.options:
                    merged_options.pop(option_key, None)
            return [
                CreateModel(
                    self.name,
                    fields=self.fields,
                    options=merged_options,
                    bases=self.bases,
                )
            ]

        if (
            isinstance(operation, FieldOperation)
            and self.name_lower == operation.model_name_lower
        ):
            folded = self._reduce_field_operation(operation)
            if folded is not None:
                return folded

        return super().reduce(operation, package_label)

    def _reduce_field_operation(
        self, operation: FieldOperation
    ) -> list[Operation] | None:
        """Apply a field operation targeting this model to ``fields``.

        Returns a one-element list holding the replacement CreateModel, or
        None when *operation* is not an AddField, AlterField, RemoveField or
        RenameField.
        """
        if isinstance(operation, AddField):
            assert operation.field is not None
            new_fields = self.fields + [(operation.name, operation.field)]
            new_options = self.options
        elif isinstance(operation, AlterField):
            assert operation.field is not None
            replacement = operation.field
            new_fields = [
                (field_name, replacement if field_name == operation.name else field)
                for field_name, field in self.fields
            ]
            new_options = self.options
        elif isinstance(operation, RemoveField):
            new_fields = [
                (field_name, field)
                for field_name, field in self.fields
                if field_name.lower() != operation.name_lower
            ]
            new_options = self.options.copy()
        elif isinstance(operation, RenameField):
            new_fields = [
                (
                    operation.new_name if field_name == operation.old_name else field_name,
                    field,
                )
                for field_name, field in self.fields
            ]
            new_options = self.options.copy()
        else:
            return None

        return [
            CreateModel(
                self.name,
                fields=new_fields,
                options=new_options,
                bases=self.bases,
            )
        ]
239
240
class DeleteModel(ModelOperation):
    """Drop the database table of an existing model."""

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        return (self.__class__.__qualname__, (), {"name": self.name})

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        state.remove_model(package_label, self.name_lower)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        # Resolve the model from the *previous* state: state_forwards removes it.
        doomed_model = from_state.models_registry.get_model(package_label, self.name)
        schema_editor.delete_model(doomed_model)

    def references_model(self, name: str, package_label: str) -> bool:
        # Answer conservatively: the dropped model may have related fields
        # pointing at *name*, and they are not inspected here.
        return True

    def describe(self) -> str:
        return f"Delete model {self.name}"

    @property
    def migration_name_fragment(self) -> str:
        return "delete_" + self.name_lower
274
275
class RenameModel(ModelOperation):
    """Rename a model: move its table and update fields that point at it."""

    def __init__(self, old_name: str, new_name: str) -> None:
        self.old_name = old_name
        self.new_name = new_name
        # ModelOperation.name (and name_lower) track the *old* name.
        super().__init__(old_name)

    @cached_property
    def old_name_lower(self) -> str:
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self) -> str:
        return self.new_name.lower()

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        return (
            self.__class__.__qualname__,
            (),
            {"old_name": self.old_name, "new_name": self.new_name},
        )

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        state.rename_model(package_label, self.old_name, self.new_name)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        new_model = to_state.models_registry.get_model(package_label, self.new_name)
        old_model = from_state.models_registry.get_model(package_label, self.old_name)

        # Rename the model's own table.
        old_table = old_model.model_options.db_table
        new_table = new_model.model_options.db_table
        schema_editor.alter_db_table(new_model, old_table, new_table)

        # Re-point every field that references the renamed model.
        for rel in old_model._model_meta.related_objects:
            referencing_model = rel.related_model
            if referencing_model == old_model:
                # Self-referential field: it now lives on the renamed model.
                model = new_model
                related_key = (package_label, self.new_name_lower)
            else:
                model = referencing_model
                related_key = (
                    referencing_model.model_options.package_label,
                    referencing_model.model_options.model_name,
                )
            updated_model = to_state.models_registry.get_model(*related_key)
            to_field = updated_model._model_meta.get_field(rel.field.name)
            schema_editor.alter_field(model, rel.field, to_field)

    def references_model(self, name: str, package_label: str) -> bool:
        """True when *name* matches either the old or the new model name."""
        lowered = name.lower()
        return lowered in (self.old_name_lower, self.new_name_lower)

    def describe(self) -> str:
        return f"Rename model {self.old_name} to {self.new_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"rename_{self.old_name_lower}_{self.new_name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> bool | list[Operation]:
        """Collapse chained renames; otherwise test against the new name."""
        if (
            isinstance(operation, RenameModel)
            and self.new_name_lower == operation.old_name_lower
        ):
            # A -> B followed by B -> C becomes a single A -> C.
            return [RenameModel(self.old_name, operation.new_name)]

        # Deliberately bypass ModelOperation.reduce: it would check references
        # against self.name (the old name) instead of self.new_name.
        base_result = super(ModelOperation, self).reduce(operation, package_label)
        if base_result:
            return base_result
        return not operation.references_model(self.new_name, package_label)
367
368
class ModelOptionOperation(ModelOperation):
    """Base class for operations that only change a model's options."""

    def reduce(
        self, operation: Operation, package_label: str
    ) -> bool | list[Operation]:
        """Let a later same-class operation or DeleteModel on this model win.

        In that case only the later *operation* is kept; otherwise defer to
        ``ModelOperation.reduce``.
        """
        # Tuple form on purpose: self.__class__ is not usable with the union
        # syntax in isinstance here.
        if isinstance(operation, (self.__class__, DeleteModel)):  # noqa: UP038
            if operation.name_lower == self.name_lower:
                return [operation]
        return super().reduce(operation, package_label)
379
380
class AlterModelTable(ModelOptionOperation):
    """Change the database table a model is stored in.

    ``table=None`` is described as switching to the "(default)" table.
    """

    def __init__(self, name: str, table: str | None) -> None:
        self.table = table
        super().__init__(name)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        return (
            self.__class__.__qualname__,
            (),
            {"name": self.name, "table": self.table},
        )

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        new_options = {"db_table": self.table}
        state.alter_model_options(package_label, self.name_lower, new_options)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        new_model = to_state.models_registry.get_model(package_label, self.name)
        old_model = from_state.models_registry.get_model(package_label, self.name)
        old_table = old_model.model_options.db_table
        new_table = new_model.model_options.db_table
        schema_editor.alter_db_table(new_model, old_table, new_table)

    def describe(self) -> str:
        target = "(default)" if self.table is None else self.table
        return f"Rename table for {self.name} to {target}"

    @property
    def migration_name_fragment(self) -> str:
        return "alter_" + self.name_lower + "_table"
424
425
class AlterModelOptions(ModelOptionOperation):
    """
    Change model options that have no direct effect on the database schema
    (such as ``ordering``). They are still tracked in migration state because
    Python code inside migrations may rely on them.
    """

    # Option keys this operation compares and preserves. CreateModel.reduce
    # drops any of these keys that are missing from ``options``, and they are
    # passed along to ``state.alter_model_options``.
    ALTER_OPTION_KEYS = [
        "ordering",
    ]

    def __init__(self, name: str, options: dict[str, Any]) -> None:
        self.options = options
        super().__init__(name)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        return (
            self.__class__.__qualname__,
            (),
            {"name": self.name, "options": self.options},
        )

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        state.alter_model_options(
            package_label, self.name_lower, self.options, self.ALTER_OPTION_KEYS
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        # These options don't touch the schema, so there is nothing to run.
        return None

    def describe(self) -> str:
        return f"Change Meta options on {self.name}"

    @property
    def migration_name_fragment(self) -> str:
        return "alter_" + self.name_lower + "_options"
472
473
class IndexOperation(Operation):
    """Common base for operations on a model's indexes.

    Subclasses working on constraints override ``option_name``.
    """

    option_name = "indexes"
    # Assigned by each concrete subclass's __init__.
    model_name: str

    @cached_property
    def model_name_lower(self) -> str:
        """Lower-cased ``model_name``, used when keying into model states."""
        lowered = self.model_name.lower()
        return lowered
481
482
class AddIndex(IndexOperation):
    """Add a named index to a model."""

    def __init__(self, model_name: str, index: Any) -> None:
        self.model_name = model_name
        # A name is mandatory; it is also used in migration_name_fragment.
        if not index.name:
            raise ValueError(
                "Indexes passed to AddIndex operations require a name "
                f"argument. {index!r} doesn't have one."
            )
        self.index = index

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        state.add_index(package_label, self.model_name_lower, self.index)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        target_model = to_state.models_registry.get_model(package_label, self.model_name)
        schema_editor.add_index(target_model, self.index)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        return (
            self.__class__.__qualname__,
            (),
            {"model_name": self.model_name, "index": self.index},
        )

    def describe(self) -> str:
        index = self.index
        if index.expressions:
            targets = ", ".join(map(str, index.expressions))
            return f"Create index {index.name} on {targets} on model {self.model_name}"
        columns = ", ".join(index.fields)
        return (
            f"Create index {index.name} on field(s) {columns} "
            f"of model {self.model_name}"
        )

    @property
    def migration_name_fragment(self) -> str:
        index_part = self.index.name.lower()
        return f"{self.model_name_lower}_{index_part}"
535
536
class RemoveIndex(IndexOperation):
    """Remove the index called ``name`` from a model."""

    def __init__(self, model_name: str, name: str) -> None:
        self.model_name = model_name
        self.name = name

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        state.remove_index(package_label, self.model_name_lower, self.name)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        # Both the model and the index definition come from the state in
        # which the index still exists.
        source_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        source_model_state = from_state.models[package_label, self.model_name_lower]
        doomed_index = source_model_state.get_index_by_name(self.name)
        schema_editor.remove_index(source_model, doomed_index)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        return (
            self.__class__.__qualname__,
            (),
            {"model_name": self.model_name, "name": self.name},
        )

    def describe(self) -> str:
        return f"Remove index {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        index_part = self.name.lower()
        return f"remove_{self.model_name_lower}_{index_part}"
576
577
class RenameIndex(IndexOperation):
    """Rename an index, identified either by its current name or by its fields.

    Exactly one of ``old_name`` and ``old_fields`` must be given. With
    ``old_fields`` the existing (unnamed) index is located by introspecting
    the database.
    """

    def __init__(
        self,
        model_name: str,
        new_name: str,
        old_name: str | None = None,
        old_fields: list[str] | tuple[str, ...] | None = None,
    ) -> None:
        has_name = bool(old_name)
        has_fields = bool(old_fields)
        if not (has_name or has_fields):
            raise ValueError(
                "RenameIndex requires one of old_name and old_fields arguments to be "
                "set."
            )
        if has_name and has_fields:
            raise ValueError(
                "RenameIndex.old_name and old_fields are mutually exclusive."
            )
        self.model_name = model_name
        self.new_name = new_name
        self.old_name = old_name
        self.old_fields = old_fields

    @cached_property
    def old_name_lower(self) -> str:
        assert self.old_name is not None, "old_name is set during initialization"
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self) -> str:
        return self.new_name.lower()

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        kwargs: dict[str, Any] = {
            "model_name": self.model_name,
            "new_name": self.new_name,
        }
        # Only serialize whichever identification of the old index is set.
        for optional_key in ("old_name", "old_fields"):
            value = getattr(self, optional_key)
            if value:
                kwargs[optional_key] = value
        return (self.__class__.__qualname__, (), kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        if not self.old_fields:
            assert self.old_name is not None
            state.rename_index(
                package_label,
                self.model_name_lower,
                self.old_name,
                self.new_name,
            )
            return
        # An unnamed index identified by its fields enters the state under
        # the new name.
        state.add_index(
            package_label,
            self.model_name_lower,
            models.Index(fields=self.old_fields, name=self.new_name),
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if self.old_fields:
            old_index = self._find_unnamed_index(package_label, schema_editor, from_state)
        else:
            from_model_state = from_state.models[package_label, self.model_name_lower]
            assert self.old_name is not None
            old_index = from_model_state.get_index_by_name(self.old_name)

        # Nothing to do when the name doesn't actually change.
        if old_index.name == self.new_name:
            return None

        to_model_state = to_state.models[package_label, self.model_name_lower]
        new_index = to_model_state.get_index_by_name(self.new_name)
        schema_editor.rename_index(model, old_index, new_index)
        return None

    def _find_unnamed_index(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
    ) -> Any:
        """Build an Index for the single database index covering ``old_fields``.

        Raises:
            ValueError: If the database reports anything other than exactly
                one index over those columns.
        """
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        columns = [
            from_model._model_meta.get_forward_field(field_name).column
            for field_name in self.old_fields
        ]
        candidates = schema_editor._constraint_names(
            from_model, column_names=columns, index=True
        )
        if len(candidates) != 1:
            raise ValueError(
                "Found wrong number ({}) of indexes for {}({}).".format(
                    len(candidates),
                    from_model.model_options.db_table,
                    ", ".join(columns),
                )
            )
        return models.Index(fields=self.old_fields, name=candidates[0])

    def describe(self) -> str:
        if not self.old_name:
            return (
                f"Rename unnamed index for {self.old_fields} on {self.model_name} to "
                f"{self.new_name}"
            )
        return f"Rename index {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self) -> str:
        if self.old_name:
            return f"rename_{self.old_name_lower}_{self.new_name_lower}"
        assert self.old_fields is not None, "old_fields is set when old_name is None"
        joined_fields = "_".join(self.old_fields)
        return f"rename_{self.model_name_lower}_{joined_fields}_{self.new_name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> bool | list[Operation]:
        """Collapse X -> Y followed by a named rename Y -> Z into X -> Z."""
        if (
            isinstance(operation, RenameIndex)
            and operation.model_name_lower == self.model_name_lower
            and operation.old_name
            and operation.old_name_lower == self.new_name_lower
        ):
            merged = RenameIndex(
                self.model_name,
                new_name=operation.new_name,
                old_name=self.old_name,
                old_fields=self.old_fields,
            )
            return [merged]
        return super().reduce(operation, package_label)
721
722
class AddConstraint(IndexOperation):
    """Add ``constraint`` to the model ``model_name``."""

    option_name = "constraints"

    def __init__(self, model_name: str, constraint: Any) -> None:
        self.model_name = model_name
        self.constraint = constraint

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        state.add_constraint(package_label, self.model_name_lower, self.constraint)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        model = to_state.models_registry.get_model(package_label, self.model_name)
        schema_editor.add_constraint(model, self.constraint)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class path, args, kwargs)`` for serializing this operation.

        Uses ``__qualname__``, like every other operation in this module, so a
        subclass nested inside another class reports its dotted name (e.g.
        ``Outer.MyAddConstraint``) rather than the bare class name.
        """
        return (
            self.__class__.__qualname__,
            (),
            {
                "model_name": self.model_name,
                "constraint": self.constraint,
            },
        )

    def describe(self) -> str:
        return f"Create constraint {self.constraint.name} on model {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"{self.model_name_lower}_{self.constraint.name.lower()}"
759
760
class RemoveConstraint(IndexOperation):
    """Remove the constraint called ``name`` from the model ``model_name``."""

    option_name = "constraints"

    def __init__(self, model_name: str, name: str) -> None:
        self.model_name = model_name
        self.name = name

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        state.remove_constraint(package_label, self.model_name_lower, self.name)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        model = to_state.models_registry.get_model(package_label, self.model_name)
        # The constraint definition only exists in the state before removal.
        from_model_state = from_state.models[package_label, self.model_name_lower]
        constraint = from_model_state.get_constraint_by_name(self.name)
        schema_editor.remove_constraint(model, constraint)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class path, args, kwargs)`` for serializing this operation.

        Uses ``__qualname__``, like every other operation in this module, so a
        subclass nested inside another class reports its dotted name rather
        than the bare class name.
        """
        return (
            self.__class__.__qualname__,
            (),
            {
                "model_name": self.model_name,
                "name": self.name,
            },
        )

    def describe(self) -> str:
        return f"Remove constraint {self.name} from model {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"remove_{self.model_name_lower}_{self.name.lower()}"