1from __future__ import annotations
  2
  3from functools import cached_property
  4from typing import TYPE_CHECKING, Any
  5
  6from plain import models
  7from plain.models.base import ModelBase
  8from plain.models.migrations.operations.base import Operation
  9from plain.models.migrations.state import ModelState
 10from plain.models.migrations.utils import field_references, resolve_relation
 11
 12from .fields import AddField, AlterField, FieldOperation, RemoveField, RenameField
 13
 14if TYPE_CHECKING:
 15    from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
 16    from plain.models.fields import Field
 17    from plain.models.migrations.state import ProjectState
 18
 19
 20def _check_for_duplicates(arg_name: str, objs: Any) -> None:
 21    used_vals = set()
 22    for val in objs:
 23        if val in used_vals:
 24            raise ValueError(
 25                f"Found duplicate value {val} in CreateModel {arg_name} argument."
 26            )
 27        used_vals.add(val)
 28
 29
 30class ModelOperation(Operation):
 31    def __init__(self, name: str) -> None:
 32        self.name = name
 33
 34    @cached_property
 35    def name_lower(self) -> str:
 36        return self.name.lower()
 37
 38    def references_model(self, name: str, package_label: str) -> bool:
 39        return name.lower() == self.name_lower
 40
 41    def reduce(
 42        self, operation: Operation, package_label: str
 43    ) -> bool | list[Operation]:
 44        return super().reduce(operation, package_label) or self.can_reduce_through(
 45            operation, package_label
 46        )
 47
 48    def can_reduce_through(self, operation: Operation, package_label: str) -> bool:
 49        return not operation.references_model(self.name, package_label)
 50
 51
 52class CreateModel(ModelOperation):
 53    """Create a model's table."""
 54
 55    serialization_expand_args = ["fields", "options"]
 56
 57    def __init__(
 58        self,
 59        name: str,
 60        fields: list[tuple[str, Field]],
 61        options: dict[str, Any] | None = None,
 62        bases: tuple[Any, ...] | None = None,
 63    ) -> None:
 64        self.fields = fields
 65        self.options = options or {}
 66        self.bases = bases or (models.Model,)
 67        super().__init__(name)
 68        # Sanity-check that there are no duplicated field names or bases
 69        _check_for_duplicates("fields", (name for name, _ in self.fields))
 70        _check_for_duplicates(
 71            "bases",
 72            (
 73                base.model_options.label_lower
 74                if not isinstance(base, str)
 75                and base is not models.Model
 76                and hasattr(base, "_model_meta")
 77                else base.lower()
 78                if isinstance(base, str)
 79                else base
 80                for base in self.bases
 81            ),
 82        )
 83
 84    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
 85        kwargs: dict[str, Any] = {
 86            "name": self.name,
 87            "fields": self.fields,
 88        }
 89        if self.options:
 90            kwargs["options"] = self.options
 91        if self.bases and self.bases != (models.Model,):
 92            kwargs["bases"] = self.bases
 93        return (self.__class__.__qualname__, (), kwargs)
 94
 95    def state_forwards(self, package_label: str, state: ProjectState) -> None:
 96        state.add_model(
 97            ModelState(
 98                package_label,
 99                self.name,
100                list(self.fields),
101                dict(self.options),
102                tuple(self.bases),
103            )
104        )
105
106    def database_forwards(
107        self,
108        package_label: str,
109        schema_editor: BaseDatabaseSchemaEditor,
110        from_state: ProjectState,
111        to_state: ProjectState,
112    ) -> None:
113        model = to_state.models_registry.get_model(package_label, self.name)
114        if self.allow_migrate_model(schema_editor.connection, model):
115            schema_editor.create_model(model)
116
117    def describe(self) -> str:
118        return f"Create model {self.name}"
119
120    @property
121    def migration_name_fragment(self) -> str:
122        return self.name_lower
123
124    def references_model(self, name: str, package_label: str) -> bool:
125        name_lower = name.lower()
126        if name_lower == self.name_lower:
127            return True
128
129        # Check we didn't inherit from the model
130        reference_model_tuple = (package_label, name_lower)
131        for base in self.bases:
132            if (
133                base is not models.Model
134                and isinstance(base, ModelBase | str)
135                and resolve_relation(base, package_label) == reference_model_tuple
136            ):
137                return True
138
139        # Check we have no FKs/M2Ms with it
140        for _name, field in self.fields:
141            if field_references(
142                (package_label, self.name_lower), field, reference_model_tuple
143            ):
144                return True
145        return False
146
147    def reduce(
148        self, operation: Operation, package_label: str
149    ) -> bool | list[Operation]:
150        if (
151            isinstance(operation, DeleteModel)
152            and self.name_lower == operation.name_lower
153        ):
154            return []
155        elif (
156            isinstance(operation, RenameModel)
157            and self.name_lower == operation.old_name_lower
158        ):
159            return [
160                CreateModel(
161                    operation.new_name,
162                    fields=self.fields,
163                    options=self.options,
164                    bases=self.bases,
165                ),
166            ]
167        elif (
168            isinstance(operation, AlterModelOptions)
169            and self.name_lower == operation.name_lower
170        ):
171            options = {**self.options, **operation.options}
172            for key in operation.ALTER_OPTION_KEYS:
173                if key not in operation.options:
174                    options.pop(key, None)
175            return [
176                CreateModel(
177                    self.name,
178                    fields=self.fields,
179                    options=options,
180                    bases=self.bases,
181                ),
182            ]
183        elif (
184            isinstance(operation, FieldOperation)
185            and self.name_lower == operation.model_name_lower
186        ):
187            if isinstance(operation, AddField):
188                assert operation.field is not None
189                return [
190                    CreateModel(
191                        self.name,
192                        fields=self.fields + [(operation.name, operation.field)],
193                        options=self.options,
194                        bases=self.bases,
195                    ),
196                ]
197            elif isinstance(operation, AlterField):
198                assert operation.field is not None
199                return [
200                    CreateModel(
201                        self.name,
202                        fields=[
203                            (n, operation.field if n == operation.name else v)
204                            for n, v in self.fields
205                        ],
206                        options=self.options,
207                        bases=self.bases,
208                    ),
209                ]
210            elif isinstance(operation, RemoveField):
211                options = self.options.copy()
212
213                return [
214                    CreateModel(
215                        self.name,
216                        fields=[
217                            (n, v)
218                            for n, v in self.fields
219                            if n.lower() != operation.name_lower
220                        ],
221                        options=options,
222                        bases=self.bases,
223                    ),
224                ]
225            elif isinstance(operation, RenameField):
226                options = self.options.copy()
227
228                return [
229                    CreateModel(
230                        self.name,
231                        fields=[
232                            (operation.new_name if n == operation.old_name else n, v)
233                            for n, v in self.fields
234                        ],
235                        options=options,
236                        bases=self.bases,
237                    ),
238                ]
239        return super().reduce(operation, package_label)
240
241
class DeleteModel(ModelOperation):
    """Drop a model's table."""

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Serialize as ``(qualname, (), {"name": ...})``."""
        return (self.__class__.__qualname__, (), {"name": self.name})

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        """Remove the model from the in-memory project state."""
        state.remove_model(package_label, self.name_lower)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Drop the table of the model as it exists in *from_state*."""
        model = from_state.models_registry.get_model(package_label, self.name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        schema_editor.delete_model(model)

    def references_model(self, name: str, package_label: str) -> bool:
        # Conservatively report a reference to every model: the model being
        # deleted may still point at *name* through its related fields.
        return True

    def describe(self) -> str:
        return f"Delete model {self.name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"delete_{self.name_lower}"
276
277
class RenameModel(ModelOperation):
    """Rename a model.

    ``self.name`` (from ModelOperation) holds the *old* name.
    """

    def __init__(self, old_name: str, new_name: str) -> None:
        self.old_name = old_name
        self.new_name = new_name
        super().__init__(old_name)

    @cached_property
    def old_name_lower(self) -> str:
        """Lowercased ``old_name``, cached after first access."""
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self) -> str:
        """Lowercased ``new_name``, cached after first access."""
        return self.new_name.lower()

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Serialize as ``(qualname, (), {"old_name": ..., "new_name": ...})``."""
        kwargs: dict[str, Any] = {}
        kwargs["old_name"] = self.old_name
        kwargs["new_name"] = self.new_name
        return (self.__class__.__qualname__, (), kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        """Rename the model in the in-memory project state."""
        state.rename_model(package_label, self.old_name, self.new_name)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Rename the table and alter relational fields that point at it.

        Does nothing when ``allow_migrate_model`` rejects the new model.
        """
        new_model = to_state.models_registry.get_model(package_label, self.new_name)
        if not self.allow_migrate_model(schema_editor.connection, new_model):
            return
        old_model = from_state.models_registry.get_model(package_label, self.old_name)

        # Move the main table.
        schema_editor.alter_db_table(
            new_model,
            old_model.model_options.db_table,
            new_model.model_options.db_table,
        )

        # Alter every field that points at the model being renamed.
        for relation in old_model._model_meta.related_objects:
            if relation.related_model == old_model:
                # Self-referential relation: the referring field lives on the
                # renamed model itself, so look it up under the new name.
                owner = new_model
                owner_key = (package_label, self.new_name_lower)
            else:
                owner = relation.related_model
                owner_key = (
                    owner.model_options.package_label,
                    owner.model_options.model_name,
                )
            target_model = to_state.models_registry.get_model(*owner_key)
            new_field = target_model._model_meta.get_field(relation.field.name)
            schema_editor.alter_field(owner, relation.field, new_field)

    def references_model(self, name: str, package_label: str) -> bool:
        """Return True if *name* matches either the old or new name (any case)."""
        return name.lower() in (self.old_name_lower, self.new_name_lower)

    def describe(self) -> str:
        return f"Rename model {self.old_name} to {self.new_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"rename_{self.old_name_lower}_{self.new_name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> bool | list[Operation]:
        """Collapse chained renames; otherwise reduce against the new name."""
        if (
            isinstance(operation, RenameModel)
            and operation.old_name_lower == self.new_name_lower
        ):
            # A -> B followed by B -> C becomes A -> C.
            return [RenameModel(self.old_name, operation.new_name)]

        # Skip ModelOperation.reduce on purpose: it would check references
        # against self.name (the old name), but after this operation the
        # model is known by self.new_name.
        reduced = super(ModelOperation, self).reduce(operation, package_label)
        if reduced:
            return reduced
        return not operation.references_model(self.new_name, package_label)
372
373
class ModelOptionOperation(ModelOperation):
    """Base for operations that change a single model's options."""

    def reduce(
        self, operation: Operation, package_label: str
    ) -> bool | list[Operation]:
        """Let a later same-type operation, or a DeleteModel, on the same
        model replace this one; otherwise defer to ModelOperation.reduce."""
        # Use tuple syntax because self.__class__ is not compatible with union syntax in isinstance
        if not isinstance(operation, (self.__class__, DeleteModel)):  # noqa: UP038
            return super().reduce(operation, package_label)
        if operation.name_lower != self.name_lower:
            return super().reduce(operation, package_label)
        return [operation]
384
385
class AlterModelTable(ModelOptionOperation):
    """Rename a model's table."""

    def __init__(self, name: str, table: str | None) -> None:
        # ``table`` may be None; describe() renders that as "(default)".
        self.table = table
        super().__init__(name)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Serialize as ``(qualname, (), {"name": ..., "table": ...})``."""
        kwargs: dict[str, Any] = dict(name=self.name, table=self.table)
        return (self.__class__.__qualname__, (), kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        """Store the new table name as the model's ``db_table`` option."""
        new_options = {"db_table": self.table}
        state.alter_model_options(package_label, self.name_lower, new_options)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Rename the table from its *from_state* name to its *to_state* name."""
        new_model = to_state.models_registry.get_model(package_label, self.name)
        if not self.allow_migrate_model(schema_editor.connection, new_model):
            return
        old_model = from_state.models_registry.get_model(package_label, self.name)
        schema_editor.alter_db_table(
            new_model,
            old_model.model_options.db_table,
            new_model.model_options.db_table,
        )

    def describe(self) -> str:
        target = "(default)" if self.table is None else self.table
        return f"Rename table for {self.name} to {target}"

    @property
    def migration_name_fragment(self) -> str:
        return f"alter_{self.name_lower}_table"
430
431
class AlterModelOptions(ModelOptionOperation):
    """
    Set new model options that don't directly affect the database schema
    (like ordering). Python code in migrations may still need them, so they
    are tracked in the project state only.
    """

    # Model options we want to compare and preserve in an AlterModelOptions op
    ALTER_OPTION_KEYS = [
        "ordering",
    ]

    def __init__(self, name: str, options: dict[str, Any]) -> None:
        self.options = options
        super().__init__(name)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Serialize as ``(qualname, (), {"name": ..., "options": ...})``."""
        kwargs: dict[str, Any] = dict(name=self.name, options=self.options)
        return (self.__class__.__qualname__, (), kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        """Apply ``options`` to the model state, passing ALTER_OPTION_KEYS along."""
        state.alter_model_options(
            package_label, self.name_lower, self.options, self.ALTER_OPTION_KEYS
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """No-op: these options do not touch the database schema."""

    def describe(self) -> str:
        return f"Change Meta options on {self.name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"alter_{self.name_lower}_options"
478
479
class IndexOperation(Operation):
    """Shared base for operations on a model's indexes or constraints."""

    # Option key associated with this operation family; AddConstraint and
    # RemoveConstraint override it with "constraints".
    option_name = "indexes"
    # Assigned by each subclass's __init__.
    model_name: str

    @cached_property
    def model_name_lower(self) -> str:
        """Lowercased ``model_name``, cached after first access."""
        lowered = self.model_name.lower()
        return lowered
487
488
class AddIndex(IndexOperation):
    """Add an index on a model."""

    def __init__(self, model_name: str, index: Any) -> None:
        self.model_name = model_name
        # Indexes must be named: RemoveIndex and RenameIndex look indexes up
        # by name.
        if not index.name:
            raise ValueError(
                "Indexes passed to AddIndex operations require a name "
                f"argument. {index!r} doesn't have one."
            )
        self.index = index

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        """Record the index on the model in the in-memory project state."""
        state.add_index(package_label, self.model_name_lower, self.index)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Create the index on the model as it exists in *to_state*."""
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        schema_editor.add_index(model, self.index)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Serialize as ``(qualname, (), {"model_name": ..., "index": ...})``."""
        kwargs: dict[str, Any] = dict(model_name=self.model_name, index=self.index)
        return (self.__class__.__qualname__, (), kwargs)

    def describe(self) -> str:
        index = self.index
        if index.expressions:
            rendered = ", ".join(map(str, index.expressions))
            return (
                f"Create index {index.name} on {rendered} on model {self.model_name}"
            )
        field_list = ", ".join(index.fields)
        return (
            f"Create index {index.name} on field(s) {field_list} "
            f"of model {self.model_name}"
        )

    @property
    def migration_name_fragment(self) -> str:
        index_part = self.index.name.lower()
        return f"{self.model_name_lower}_{index_part}"
542
543
class RemoveIndex(IndexOperation):
    """Remove an index from a model."""

    def __init__(self, model_name: str, name: str) -> None:
        self.model_name = model_name
        # Name of the index to drop, as recorded in the project state.
        self.name = name

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        """Drop the named index from the in-memory project state."""
        state.remove_index(package_label, self.model_name_lower, self.name)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Drop the index, resolved by name from the *from_state* model state."""
        model = from_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        model_state = from_state.models[package_label, self.model_name_lower]
        schema_editor.remove_index(model, model_state.get_index_by_name(self.name))

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Serialize as ``(qualname, (), {"model_name": ..., "name": ...})``."""
        kwargs: dict[str, Any] = dict(model_name=self.model_name, name=self.name)
        return (self.__class__.__qualname__, (), kwargs)

    def describe(self) -> str:
        return f"Remove index {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        index_part = self.name.lower()
        return f"remove_{self.model_name_lower}_{index_part}"
584
585
class RenameIndex(IndexOperation):
    """Rename an index.

    The index to rename is identified either by its current name
    (``old_name``) or by the fields it covers (``old_fields``); exactly one
    of the two must be given.
    """

    def __init__(
        self,
        model_name: str,
        new_name: str,
        old_name: str | None = None,
        old_fields: list[str] | tuple[str, ...] | None = None,
    ) -> None:
        if old_name and old_fields:
            raise ValueError(
                "RenameIndex.old_name and old_fields are mutually exclusive."
            )
        if not (old_name or old_fields):
            raise ValueError(
                "RenameIndex requires one of old_name and old_fields arguments to be "
                "set."
            )
        self.model_name = model_name
        self.new_name = new_name
        self.old_name = old_name
        self.old_fields = old_fields

    @cached_property
    def old_name_lower(self) -> str:
        """Lowercased ``old_name``; only valid when ``old_name`` was given."""
        assert self.old_name is not None, "old_name is set during initialization"
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self) -> str:
        """Lowercased ``new_name``, cached after first access."""
        return self.new_name.lower()

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Serialize, including only whichever of old_name/old_fields is set."""
        kwargs: dict[str, Any] = dict(model_name=self.model_name, new_name=self.new_name)
        if self.old_name:
            kwargs["old_name"] = self.old_name
        if self.old_fields:
            kwargs["old_fields"] = self.old_fields
        return (self.__class__.__qualname__, (), kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        """Update the project state for the rename.

        With ``old_name``, the existing index is renamed in the state. With
        ``old_fields``, a new ``models.Index`` over those fields is added under
        ``new_name`` (presumably because the unnamed index is not tracked in
        the state by name).
        """
        if not self.old_fields:
            assert self.old_name is not None
            state.rename_index(
                package_label, self.model_name_lower, self.old_name, self.new_name
            )
            return
        state.add_index(
            package_label,
            self.model_name_lower,
            models.Index(fields=self.old_fields, name=self.new_name),
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Rename the index in the database unless its name is unchanged.

        Raises:
            ValueError: propagated from ``_find_old_index`` when ``old_fields``
                does not resolve to exactly one index.
        """
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return None

        old_index = self._find_old_index(package_label, schema_editor, from_state)
        # Don't alter when the index name is not changed.
        if old_index.name == self.new_name:
            return None

        target_state = to_state.models[package_label, self.model_name_lower]
        new_index = target_state.get_index_by_name(self.new_name)
        schema_editor.rename_index(model, old_index, new_index)
        return None

    def _find_old_index(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
    ) -> Any:
        """Return the index object as it exists before the rename.

        Raises:
            ValueError: if ``old_fields`` is used and
                ``schema_editor._constraint_names`` does not return exactly
                one index name for those columns.
        """
        if not self.old_fields:
            source_state = from_state.models[package_label, self.model_name_lower]
            assert self.old_name is not None
            return source_state.get_index_by_name(self.old_name)

        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        columns = [
            from_model._model_meta.get_forward_field(field_name).column
            for field_name in self.old_fields
        ]
        candidates = schema_editor._constraint_names(
            from_model, column_names=columns, index=True
        )
        if len(candidates) != 1:
            raise ValueError(
                f"Found wrong number ({len(candidates)}) of indexes for "
                f"{from_model.model_options.db_table}({', '.join(columns)})."
            )
        return models.Index(fields=self.old_fields, name=candidates[0])

    def describe(self) -> str:
        if not self.old_name:
            return (
                f"Rename unnamed index for {self.old_fields} on {self.model_name} to "
                f"{self.new_name}"
            )
        return f"Rename index {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self) -> str:
        if not self.old_name:
            assert self.old_fields is not None, "old_fields is set when old_name is None"
            joined_fields = "_".join(self.old_fields)
            return (
                f"rename_{self.model_name_lower}_{joined_fields}_{self.new_name_lower}"
            )
        return f"rename_{self.old_name_lower}_{self.new_name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> bool | list[Operation]:
        """Collapse a follow-up rename of the index this operation produces."""
        if isinstance(operation, RenameIndex) and (
            operation.model_name_lower == self.model_name_lower
            and operation.old_name
            and operation.old_name_lower == self.new_name_lower
        ):
            # X -> Y followed by Y -> Z becomes X -> Z, keeping our original
            # identification (old_name or old_fields).
            return [
                RenameIndex(
                    self.model_name,
                    new_name=operation.new_name,
                    old_name=self.old_name,
                    old_fields=self.old_fields,
                )
            ]
        return super().reduce(operation, package_label)
732
733
class AddConstraint(IndexOperation):
    """Add a constraint to a model."""

    option_name = "constraints"

    def __init__(self, model_name: str, constraint: Any) -> None:
        self.model_name = model_name
        self.constraint = constraint

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        """Record the constraint on the model in the in-memory project state."""
        state.add_constraint(package_label, self.model_name_lower, self.constraint)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Create the constraint on the model as it exists in *to_state*."""
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        schema_editor.add_constraint(model, self.constraint)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Serialize as ``(name, (), {"model_name": ..., "constraint": ...})``."""
        kwargs = dict(model_name=self.model_name, constraint=self.constraint)
        return (self.__class__.__name__, (), kwargs)

    def describe(self) -> str:
        return f"Create constraint {self.constraint.name} on model {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        constraint_part = self.constraint.name.lower()
        return f"{self.model_name_lower}_{constraint_part}"
771
772
class RemoveConstraint(IndexOperation):
    """Remove a named constraint from a model."""

    option_name = "constraints"

    def __init__(self, model_name: str, name: str) -> None:
        self.model_name = model_name
        # Name of the constraint to drop, as recorded in the project state.
        self.name = name

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        """Drop the named constraint from the in-memory project state."""
        state.remove_constraint(package_label, self.model_name_lower, self.name)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Drop the constraint, resolved by name from the *from_state* model state."""
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        source_state = from_state.models[package_label, self.model_name_lower]
        schema_editor.remove_constraint(
            model, source_state.get_constraint_by_name(self.name)
        )

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Serialize as ``(name, (), {"model_name": ..., "name": ...})``."""
        kwargs = dict(model_name=self.model_name, name=self.name)
        return (self.__class__.__name__, (), kwargs)

    def describe(self) -> str:
        return f"Remove constraint {self.name} from model {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        constraint_part = self.name.lower()
        return f"remove_{self.model_name_lower}_{constraint_part}"