1from __future__ import annotations
  2
  3from functools import cached_property
  4from typing import TYPE_CHECKING, Any
  5
  6from plain import models
  7from plain.models.base import ModelBase
  8from plain.models.migrations.operations.base import Operation
  9from plain.models.migrations.state import ModelState
 10from plain.models.migrations.utils import field_references, resolve_relation
 11
 12from .fields import AddField, AlterField, FieldOperation, RemoveField, RenameField
 13
 14if TYPE_CHECKING:
 15    from plain.models.fields import Field
 16    from plain.models.migrations.state import ProjectState
 17    from plain.models.postgres.schema import DatabaseSchemaEditor
 18
 19
 20def _check_for_duplicates(arg_name: str, objs: Any) -> None:
 21    used_vals = set()
 22    for val in objs:
 23        if val in used_vals:
 24            raise ValueError(
 25                f"Found duplicate value {val} in CreateModel {arg_name} argument."
 26            )
 27        used_vals.add(val)
 28
 29
 30class ModelOperation(Operation):
 31    def __init__(self, name: str) -> None:
 32        self.name = name
 33
 34    @cached_property
 35    def name_lower(self) -> str:
 36        return self.name.lower()
 37
 38    def references_model(self, name: str, package_label: str) -> bool:
 39        return name.lower() == self.name_lower
 40
 41    def reduce(
 42        self, operation: Operation, package_label: str
 43    ) -> bool | list[Operation]:
 44        return super().reduce(operation, package_label) or self.can_reduce_through(
 45            operation, package_label
 46        )
 47
 48    def can_reduce_through(self, operation: Operation, package_label: str) -> bool:
 49        return not operation.references_model(self.name, package_label)
 50
 51
 52class CreateModel(ModelOperation):
 53    """Create a model's table."""
 54
 55    serialization_expand_args = ["fields", "options"]
 56
 57    def __init__(
 58        self,
 59        name: str,
 60        fields: list[tuple[str, Field]],
 61        options: dict[str, Any] | None = None,
 62        bases: tuple[Any, ...] | None = None,
 63    ) -> None:
 64        self.fields = fields
 65        self.options = options or {}
 66        self.bases = bases or (models.Model,)
 67        super().__init__(name)
 68        # Sanity-check that there are no duplicated field names or bases
 69        _check_for_duplicates("fields", (name for name, _ in self.fields))
 70        _check_for_duplicates(
 71            "bases",
 72            (
 73                base.model_options.label_lower
 74                if not isinstance(base, str)
 75                and base is not models.Model
 76                and hasattr(base, "_model_meta")
 77                else base.lower()
 78                if isinstance(base, str)
 79                else base
 80                for base in self.bases
 81            ),
 82        )
 83
 84    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
 85        kwargs: dict[str, Any] = {
 86            "name": self.name,
 87            "fields": self.fields,
 88        }
 89        if self.options:
 90            kwargs["options"] = self.options
 91        if self.bases and self.bases != (models.Model,):
 92            kwargs["bases"] = self.bases
 93        return (self.__class__.__qualname__, (), kwargs)
 94
 95    def state_forwards(self, package_label: str, state: ProjectState) -> None:
 96        state.add_model(
 97            ModelState(
 98                package_label,
 99                self.name,
100                list(self.fields),
101                dict(self.options),
102                tuple(self.bases),
103            )
104        )
105
106    def database_forwards(
107        self,
108        package_label: str,
109        schema_editor: DatabaseSchemaEditor,
110        from_state: ProjectState,
111        to_state: ProjectState,
112    ) -> None:
113        model = to_state.models_registry.get_model(package_label, self.name)
114        schema_editor.create_model(model)
115
116    def describe(self) -> str:
117        return f"Create model {self.name}"
118
119    @property
120    def migration_name_fragment(self) -> str:
121        return self.name_lower
122
123    def references_model(self, name: str, package_label: str) -> bool:
124        name_lower = name.lower()
125        if name_lower == self.name_lower:
126            return True
127
128        # Check we didn't inherit from the model
129        reference_model_tuple = (package_label, name_lower)
130        for base in self.bases:
131            if (
132                base is not models.Model
133                and isinstance(base, ModelBase | str)
134                and resolve_relation(base, package_label) == reference_model_tuple
135            ):
136                return True
137
138        # Check we have no FKs/M2Ms with it
139        for _name, field in self.fields:
140            if field_references(
141                (package_label, self.name_lower), field, reference_model_tuple
142            ):
143                return True
144        return False
145
146    def reduce(
147        self, operation: Operation, package_label: str
148    ) -> bool | list[Operation]:
149        if (
150            isinstance(operation, DeleteModel)
151            and self.name_lower == operation.name_lower
152        ):
153            return []
154        elif (
155            isinstance(operation, RenameModel)
156            and self.name_lower == operation.old_name_lower
157        ):
158            return [
159                CreateModel(
160                    operation.new_name,
161                    fields=self.fields,
162                    options=self.options,
163                    bases=self.bases,
164                ),
165            ]
166        elif (
167            isinstance(operation, AlterModelOptions)
168            and self.name_lower == operation.name_lower
169        ):
170            options = {**self.options, **operation.options}
171            for key in operation.ALTER_OPTION_KEYS:
172                if key not in operation.options:
173                    options.pop(key, None)
174            return [
175                CreateModel(
176                    self.name,
177                    fields=self.fields,
178                    options=options,
179                    bases=self.bases,
180                ),
181            ]
182        elif (
183            isinstance(operation, FieldOperation)
184            and self.name_lower == operation.model_name_lower
185        ):
186            if isinstance(operation, AddField):
187                assert operation.field is not None
188                return [
189                    CreateModel(
190                        self.name,
191                        fields=self.fields + [(operation.name, operation.field)],
192                        options=self.options,
193                        bases=self.bases,
194                    ),
195                ]
196            elif isinstance(operation, AlterField):
197                assert operation.field is not None
198                return [
199                    CreateModel(
200                        self.name,
201                        fields=[
202                            (n, operation.field if n == operation.name else v)
203                            for n, v in self.fields
204                        ],
205                        options=self.options,
206                        bases=self.bases,
207                    ),
208                ]
209            elif isinstance(operation, RemoveField):
210                options = self.options.copy()
211
212                return [
213                    CreateModel(
214                        self.name,
215                        fields=[
216                            (n, v)
217                            for n, v in self.fields
218                            if n.lower() != operation.name_lower
219                        ],
220                        options=options,
221                        bases=self.bases,
222                    ),
223                ]
224            elif isinstance(operation, RenameField):
225                options = self.options.copy()
226
227                return [
228                    CreateModel(
229                        self.name,
230                        fields=[
231                            (operation.new_name if n == operation.old_name else n, v)
232                            for n, v in self.fields
233                        ],
234                        options=options,
235                        bases=self.bases,
236                    ),
237                ]
238        return super().reduce(operation, package_label)
239
240
class DeleteModel(ModelOperation):
    """Drop a model's table."""

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, args, kwargs)`` for serialization."""
        return (self.__class__.__qualname__, (), {"name": self.name})

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        """Remove the model from *state*."""
        state.remove_model(package_label, self.name_lower)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Drop the table, using the model as it existed in *from_state*."""
        schema_editor.delete_model(
            from_state.models_registry.get_model(package_label, self.name)
        )

    def references_model(self, name: str, package_label: str) -> bool:
        # Always answer yes: the model being deleted may point at *name*
        # through its related fields.
        return True

    def describe(self) -> str:
        return f"Delete model {self.name}"

    @property
    def migration_name_fragment(self) -> str:
        return "delete_" + self.name_lower
274
275
class RenameModel(ModelOperation):
    """Rename a model, moving its table and re-pointing relations to it."""

    def __init__(self, old_name: str, new_name: str) -> None:
        self.old_name = old_name
        self.new_name = new_name
        # ModelOperation.name holds the *old* name.
        super().__init__(old_name)

    @cached_property
    def old_name_lower(self) -> str:
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self) -> str:
        return self.new_name.lower()

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, args, kwargs)`` for serialization."""
        return (
            self.__class__.__qualname__,
            (),
            {"old_name": self.old_name, "new_name": self.new_name},
        )

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        state.rename_model(package_label, self.old_name, self.new_name)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Rename the table, then alter every relation that targets the model."""
        new_model = to_state.models_registry.get_model(package_label, self.new_name)
        old_model = from_state.models_registry.get_model(package_label, self.old_name)
        schema_editor.alter_db_table(
            new_model,
            old_model.model_options.db_table,
            new_model.model_options.db_table,
        )
        for rel in old_model._model_meta.related_objects:
            source_model = rel.related_model
            if source_model == old_model:
                # Self-referential relation: it now lives on the renamed model.
                owner = new_model
                owner_key = (package_label, self.new_name_lower)
            else:
                owner = source_model
                owner_key = (
                    source_model.model_options.package_label,
                    source_model.model_options.model_name,
                )
            new_field = to_state.models_registry.get_model(
                *owner_key
            )._model_meta.get_field(rel.field.name)
            schema_editor.alter_field(owner, rel.field, new_field)

    def references_model(self, name: str, package_label: str) -> bool:
        """Return True if *name* matches either the old or the new model name."""
        lowered = name.lower()
        return lowered in (self.old_name_lower, self.new_name_lower)

    def describe(self) -> str:
        return f"Rename model {self.old_name} to {self.new_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"rename_{self.old_name_lower}_{self.new_name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> bool | list[Operation]:
        """Collapse chained renames; otherwise reduce against the *new* name."""
        if (
            isinstance(operation, RenameModel)
            and operation.old_name_lower == self.new_name_lower
        ):
            # A -> B followed by B -> C becomes A -> C.
            return [RenameModel(self.old_name, operation.new_name)]
        # Bypass ModelOperation.reduce: its reference check would use
        # self.name (the old name), but later operations refer to the new one.
        generic = super(ModelOperation, self).reduce(operation, package_label)
        if generic:
            return generic
        return not operation.references_model(self.new_name, package_label)
367
368
class ModelOptionOperation(ModelOperation):
    """Base for operations that only change a model's options."""

    def reduce(
        self, operation: Operation, package_label: str
    ) -> bool | list[Operation]:
        """Replace this operation with *operation* when it is the same kind of
        option change, or a DeleteModel, on the same model."""
        # Tuple form kept because the first member is only known at runtime.
        if not isinstance(operation, (self.__class__, DeleteModel)):  # noqa: UP038
            return super().reduce(operation, package_label)
        if operation.name_lower != self.name_lower:
            return super().reduce(operation, package_label)
        return [operation]
379
380
class AlterModelTable(ModelOptionOperation):
    """Rename a model's table."""

    def __init__(self, name: str, table: str | None) -> None:
        # ``None`` is described as "(default)" by describe().
        self.table = table
        super().__init__(name)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, args, kwargs)`` for serialization."""
        return (
            self.__class__.__qualname__,
            (),
            {"name": self.name, "table": self.table},
        )

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        """Record the new ``db_table`` option in *state*."""
        new_options = {"db_table": self.table}
        state.alter_model_options(package_label, self.name_lower, new_options)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Rename the table from its *from_state* name to its *to_state* name."""
        new_model = to_state.models_registry.get_model(package_label, self.name)
        old_model = from_state.models_registry.get_model(package_label, self.name)
        old_table = old_model.model_options.db_table
        new_table = new_model.model_options.db_table
        schema_editor.alter_db_table(new_model, old_table, new_table)

    def describe(self) -> str:
        target = "(default)" if self.table is None else self.table
        return f"Rename table for {self.name} to {target}"

    @property
    def migration_name_fragment(self) -> str:
        return f"alter_{self.name_lower}_table"
424
425
class AlterModelOptions(ModelOptionOperation):
    """
    Change model options that have no direct effect on the database schema
    (such as ordering). They are still recorded in migration state because
    Python code in migrations may rely on them.
    """

    # Options this operation manages. CreateModel.reduce treats a listed key
    # that is absent from ``options`` as removed; state_forwards passes the
    # same list to the state (presumably with the same meaning — confirm in
    # ProjectState.alter_model_options).
    ALTER_OPTION_KEYS = ["ordering"]

    def __init__(self, name: str, options: dict[str, Any]) -> None:
        self.options = options
        super().__init__(name)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, args, kwargs)`` for serialization."""
        return (
            self.__class__.__qualname__,
            (),
            {"name": self.name, "options": self.options},
        )

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        state.alter_model_options(
            package_label, self.name_lower, self.options, self.ALTER_OPTION_KEYS
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """No-op: these options do not touch the database schema."""

    def describe(self) -> str:
        return f"Change Meta options on {self.name}"

    @property
    def migration_name_fragment(self) -> str:
        return "alter_" + self.name_lower + "_options"
472
473
class IndexOperation(Operation):
    """Base for operations on a model's indexes (and, in subclasses, constraints)."""

    option_name = "indexes"
    model_name: str  # assigned by each subclass's __init__

    @cached_property
    def model_name_lower(self) -> str:
        """Lowercased ``model_name``, used as the key for state lookups."""
        lowered = self.model_name.lower()
        return lowered
481
482
class AddIndex(IndexOperation):
    """Add a named index to a model.

    Raises:
        ValueError: If the supplied index has no name.
    """

    def __init__(self, model_name: str, index: Any) -> None:
        self.model_name = model_name
        if not index.name:
            # A name is required so later operations can refer to the index.
            raise ValueError(
                "Indexes passed to AddIndex operations require a name "
                f"argument. {index!r} doesn't have one."
            )
        self.index = index

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        state.add_index(package_label, self.model_name_lower, self.index)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Create the index on the model as it exists in *to_state*."""
        schema_editor.add_index(
            to_state.models_registry.get_model(package_label, self.model_name),
            self.index,
        )

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, args, kwargs)`` for serialization."""
        return (
            self.__class__.__qualname__,
            (),
            {"model_name": self.model_name, "index": self.index},
        )

    def describe(self) -> str:
        index_name = self.index.name
        expressions = self.index.expressions
        if expressions:
            rendered = ", ".join(str(expr) for expr in expressions)
            return (
                f"Create index {index_name} on {rendered} "
                f"on model {self.model_name}"
            )
        field_list = ", ".join(self.index.fields)
        return (
            f"Create index {index_name} on field(s) {field_list} "
            f"of model {self.model_name}"
        )

    @property
    def migration_name_fragment(self) -> str:
        return f"{self.model_name_lower}_{self.index.name.lower()}"
535
536
class RemoveIndex(IndexOperation):
    """Remove a named index from a model."""

    def __init__(self, model_name: str, name: str) -> None:
        self.model_name = model_name
        self.name = name  # name of the index to drop

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        state.remove_index(package_label, self.model_name_lower, self.name)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Drop the index, reading its definition from *from_state*."""
        model = from_state.models_registry.get_model(package_label, self.model_name)
        model_state = from_state.models[package_label, self.model_name_lower]
        schema_editor.remove_index(model, model_state.get_index_by_name(self.name))

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, args, kwargs)`` for serialization."""
        return (
            self.__class__.__qualname__,
            (),
            {"model_name": self.model_name, "name": self.name},
        )

    def describe(self) -> str:
        return f"Remove index {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        return "remove_" + self.model_name_lower + "_" + self.name.lower()
576
577
class RenameIndex(IndexOperation):
    """Rename an index.

    The existing index is identified either by ``old_name`` or, for an
    unnamed index, by ``old_fields``. Exactly one of the two must be given.

    Raises:
        ValueError: If neither or both of ``old_name`` / ``old_fields`` are set.
    """

    def __init__(
        self,
        model_name: str,
        new_name: str,
        old_name: str | None = None,
        old_fields: list[str] | tuple[str, ...] | None = None,
    ) -> None:
        given_name = bool(old_name)
        given_fields = bool(old_fields)
        if not given_name and not given_fields:
            raise ValueError(
                "RenameIndex requires one of old_name and old_fields arguments to be "
                "set."
            )
        if given_name and given_fields:
            raise ValueError(
                "RenameIndex.old_name and old_fields are mutually exclusive."
            )
        self.model_name = model_name
        self.new_name = new_name
        self.old_name = old_name
        self.old_fields = old_fields

    @cached_property
    def old_name_lower(self) -> str:
        assert self.old_name is not None, "old_name is set during initialization"
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self) -> str:
        return self.new_name.lower()

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, args, kwargs)``; unset identifiers are omitted."""
        kwargs: dict[str, Any] = {
            "model_name": self.model_name,
            "new_name": self.new_name,
        }
        if self.old_name:
            kwargs["old_name"] = self.old_name
        if self.old_fields:
            kwargs["old_fields"] = self.old_fields
        return (self.__class__.__qualname__, (), kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        """Rename a named index in state, or add an unnamed one under its new name."""
        if not self.old_fields:
            assert self.old_name is not None
            state.rename_index(
                package_label, self.model_name_lower, self.old_name, self.new_name
            )
            return
        state.add_index(
            package_label,
            self.model_name_lower,
            models.Index(fields=self.old_fields, name=self.new_name),
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Rename the index in the database.

        Raises:
            ValueError: If ``old_fields`` is used and the schema editor does not
                report exactly one matching index.
        """
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.old_fields:
            # Named index: its definition comes from the pre-rename state.
            old_model_state = from_state.models[package_label, self.model_name_lower]
            assert self.old_name is not None
            old_index = old_model_state.get_index_by_name(self.old_name)
        else:
            # Unnamed index: ask the schema editor for index names covering
            # these columns (presumably via database introspection).
            from_model = from_state.models_registry.get_model(
                package_label, self.model_name
            )
            columns = [
                from_model._model_meta.get_forward_field(field_name).column
                for field_name in self.old_fields
            ]
            candidates = schema_editor._constraint_names(
                from_model, column_names=columns, index=True
            )
            if len(candidates) != 1:
                raise ValueError(
                    "Found wrong number ({}) of indexes for {}({}).".format(
                        len(candidates),
                        from_model.model_options.db_table,
                        ", ".join(columns),
                    )
                )
            old_index = models.Index(fields=self.old_fields, name=candidates[0])

        if old_index.name == self.new_name:
            # Already carries the requested name: nothing to rename.
            return None

        new_index = to_state.models[
            package_label, self.model_name_lower
        ].get_index_by_name(self.new_name)
        schema_editor.rename_index(model, old_index, new_index)
        return None

    def describe(self) -> str:
        if not self.old_name:
            return (
                f"Rename unnamed index for {self.old_fields} on {self.model_name} to "
                f"{self.new_name}"
            )
        return f"Rename index {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self) -> str:
        if not self.old_name:
            assert self.old_fields is not None, "old_fields is set when old_name is None"
            joined_fields = "_".join(self.old_fields)
            return f"rename_{self.model_name_lower}_{joined_fields}_{self.new_name_lower}"
        return f"rename_{self.old_name_lower}_{self.new_name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> bool | list[Operation]:
        """Collapse ``X -> Y`` followed by a named ``Y -> Z`` on the same model
        into a single rename ``X -> Z``."""
        if (
            isinstance(operation, RenameIndex)
            and operation.model_name_lower == self.model_name_lower
            and operation.old_name
            and operation.old_name_lower == self.new_name_lower
        ):
            merged = RenameIndex(
                self.model_name,
                new_name=operation.new_name,
                old_name=self.old_name,
                old_fields=self.old_fields,
            )
            return [merged]
        return super().reduce(operation, package_label)
721
722
class AddConstraint(IndexOperation):
    """Add a constraint to a model."""

    option_name = "constraints"

    def __init__(self, model_name: str, constraint: Any) -> None:
        self.model_name = model_name
        # Expected to expose a ``name`` attribute (used by describe() and
        # migration_name_fragment).
        self.constraint = constraint

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        state.add_constraint(package_label, self.model_name_lower, self.constraint)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Create the constraint on the model as it exists in *to_state*."""
        model = to_state.models_registry.get_model(package_label, self.model_name)
        schema_editor.add_constraint(model, self.constraint)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, args, kwargs)`` for serialization.

        Uses ``__qualname__``, like every other operation in this module.
        """
        return (
            self.__class__.__qualname__,
            (),
            {
                "model_name": self.model_name,
                "constraint": self.constraint,
            },
        )

    def describe(self) -> str:
        return f"Create constraint {self.constraint.name} on model {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"{self.model_name_lower}_{self.constraint.name.lower()}"
759
760
class RemoveConstraint(IndexOperation):
    """Remove a named constraint from a model."""

    option_name = "constraints"

    def __init__(self, model_name: str, name: str) -> None:
        self.model_name = model_name
        self.name = name  # name of the constraint to drop

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        state.remove_constraint(package_label, self.model_name_lower, self.name)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Drop the constraint.

        The constraint definition is looked up by name in *from_state*, where
        it still exists. The model handed to the schema editor comes from
        *to_state*.
        """
        model = to_state.models_registry.get_model(package_label, self.model_name)
        from_model_state = from_state.models[package_label, self.model_name_lower]
        constraint = from_model_state.get_constraint_by_name(self.name)
        schema_editor.remove_constraint(model, constraint)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, args, kwargs)`` for serialization.

        Uses ``__qualname__``, like every other operation in this module.
        """
        return (
            self.__class__.__qualname__,
            (),
            {
                "model_name": self.model_name,
                "name": self.name,
            },
        )

    def describe(self) -> str:
        return f"Remove constraint {self.name} from model {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"remove_{self.model_name_lower}_{self.name.lower()}"