Plain is headed towards 1.0! Subscribe for development updates →

  1from functools import cached_property
  2
  3from plain import models
  4from plain.models.migrations.operations.base import Operation
  5from plain.models.migrations.state import ModelState
  6from plain.models.migrations.utils import field_references, resolve_relation
  7
  8from .fields import AddField, AlterField, FieldOperation, RemoveField, RenameField
  9
 10
 11def _check_for_duplicates(arg_name, objs):
 12    used_vals = set()
 13    for val in objs:
 14        if val in used_vals:
 15            raise ValueError(
 16                f"Found duplicate value {val} in CreateModel {arg_name} argument."
 17            )
 18        used_vals.add(val)
 19
 20
 21class ModelOperation(Operation):
 22    def __init__(self, name):
 23        self.name = name
 24
 25    @cached_property
 26    def name_lower(self):
 27        return self.name.lower()
 28
 29    def references_model(self, name, package_label):
 30        return name.lower() == self.name_lower
 31
 32    def reduce(self, operation, package_label):
 33        return super().reduce(operation, package_label) or self.can_reduce_through(
 34            operation, package_label
 35        )
 36
 37    def can_reduce_through(self, operation, package_label):
 38        return not operation.references_model(self.name, package_label)
 39
 40
 41class CreateModel(ModelOperation):
 42    """Create a model's table."""
 43
 44    serialization_expand_args = ["fields", "options", "managers"]
 45
 46    def __init__(self, name, fields, options=None, bases=None, managers=None):
 47        self.fields = fields
 48        self.options = options or {}
 49        self.bases = bases or (models.Model,)
 50        self.managers = managers or []
 51        super().__init__(name)
 52        # Sanity-check that there are no duplicated field names, bases, or
 53        # manager names
 54        _check_for_duplicates("fields", (name for name, _ in self.fields))
 55        _check_for_duplicates(
 56            "bases",
 57            (
 58                base._meta.label_lower
 59                if hasattr(base, "_meta")
 60                else base.lower()
 61                if isinstance(base, str)
 62                else base
 63                for base in self.bases
 64            ),
 65        )
 66        _check_for_duplicates("managers", (name for name, _ in self.managers))
 67
 68    def deconstruct(self):
 69        kwargs = {
 70            "name": self.name,
 71            "fields": self.fields,
 72        }
 73        if self.options:
 74            kwargs["options"] = self.options
 75        if self.bases and self.bases != (models.Model,):
 76            kwargs["bases"] = self.bases
 77        if self.managers and self.managers != [("objects", models.Manager())]:
 78            kwargs["managers"] = self.managers
 79        return (self.__class__.__qualname__, [], kwargs)
 80
 81    def state_forwards(self, package_label, state):
 82        state.add_model(
 83            ModelState(
 84                package_label,
 85                self.name,
 86                list(self.fields),
 87                dict(self.options),
 88                tuple(self.bases),
 89                list(self.managers),
 90            )
 91        )
 92
 93    def database_forwards(self, package_label, schema_editor, from_state, to_state):
 94        model = to_state.models_registry.get_model(package_label, self.name)
 95        if self.allow_migrate_model(schema_editor.connection, model):
 96            schema_editor.create_model(model)
 97
 98    def describe(self):
 99        return f"Create model {self.name}"
100
101    @property
102    def migration_name_fragment(self):
103        return self.name_lower
104
105    def references_model(self, name, package_label):
106        name_lower = name.lower()
107        if name_lower == self.name_lower:
108            return True
109
110        # Check we didn't inherit from the model
111        reference_model_tuple = (package_label, name_lower)
112        for base in self.bases:
113            if (
114                base is not models.Model
115                and isinstance(base, models.base.ModelBase | str)
116                and resolve_relation(base, package_label) == reference_model_tuple
117            ):
118                return True
119
120        # Check we have no FKs/M2Ms with it
121        for _name, field in self.fields:
122            if field_references(
123                (package_label, self.name_lower), field, reference_model_tuple
124            ):
125                return True
126        return False
127
128    def reduce(self, operation, package_label):
129        if (
130            isinstance(operation, DeleteModel)
131            and self.name_lower == operation.name_lower
132        ):
133            return []
134        elif (
135            isinstance(operation, RenameModel)
136            and self.name_lower == operation.old_name_lower
137        ):
138            return [
139                CreateModel(
140                    operation.new_name,
141                    fields=self.fields,
142                    options=self.options,
143                    bases=self.bases,
144                    managers=self.managers,
145                ),
146            ]
147        elif (
148            isinstance(operation, AlterModelOptions)
149            and self.name_lower == operation.name_lower
150        ):
151            options = {**self.options, **operation.options}
152            for key in operation.ALTER_OPTION_KEYS:
153                if key not in operation.options:
154                    options.pop(key, None)
155            return [
156                CreateModel(
157                    self.name,
158                    fields=self.fields,
159                    options=options,
160                    bases=self.bases,
161                    managers=self.managers,
162                ),
163            ]
164        elif (
165            isinstance(operation, AlterModelManagers)
166            and self.name_lower == operation.name_lower
167        ):
168            return [
169                CreateModel(
170                    self.name,
171                    fields=self.fields,
172                    options=self.options,
173                    bases=self.bases,
174                    managers=operation.managers,
175                ),
176            ]
177        elif (
178            isinstance(operation, FieldOperation)
179            and self.name_lower == operation.model_name_lower
180        ):
181            if isinstance(operation, AddField):
182                return [
183                    CreateModel(
184                        self.name,
185                        fields=self.fields + [(operation.name, operation.field)],
186                        options=self.options,
187                        bases=self.bases,
188                        managers=self.managers,
189                    ),
190                ]
191            elif isinstance(operation, AlterField):
192                return [
193                    CreateModel(
194                        self.name,
195                        fields=[
196                            (n, operation.field if n == operation.name else v)
197                            for n, v in self.fields
198                        ],
199                        options=self.options,
200                        bases=self.bases,
201                        managers=self.managers,
202                    ),
203                ]
204            elif isinstance(operation, RemoveField):
205                options = self.options.copy()
206
207                return [
208                    CreateModel(
209                        self.name,
210                        fields=[
211                            (n, v)
212                            for n, v in self.fields
213                            if n.lower() != operation.name_lower
214                        ],
215                        options=options,
216                        bases=self.bases,
217                        managers=self.managers,
218                    ),
219                ]
220            elif isinstance(operation, RenameField):
221                options = self.options.copy()
222
223                return [
224                    CreateModel(
225                        self.name,
226                        fields=[
227                            (operation.new_name if n == operation.old_name else n, v)
228                            for n, v in self.fields
229                        ],
230                        options=options,
231                        bases=self.bases,
232                        managers=self.managers,
233                    ),
234                ]
235        return super().reduce(operation, package_label)
236
237
class DeleteModel(ModelOperation):
    """Drop a model's table and remove the model from migration state."""

    def deconstruct(self):
        return (self.__class__.__qualname__, [], {"name": self.name})

    def state_forwards(self, package_label, state):
        state.remove_model(package_label, self.name_lower)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        # state_forwards removes the model, so it is looked up in from_state.
        model = from_state.models_registry.get_model(package_label, self.name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        schema_editor.delete_model(model)

    def references_model(self, name, package_label):
        # Always answer yes: the model being deleted may have related fields
        # pointing at `name`, and its definition is not available here.
        return True

    def describe(self):
        return f"Delete model {self.name}"

    @property
    def migration_name_fragment(self):
        return f"delete_{self.name_lower}"
266
267
class RenameModel(ModelOperation):
    """Rename a model, its table and the fields that point at it."""

    def __init__(self, old_name, new_name):
        self.old_name = old_name
        self.new_name = new_name
        # ModelOperation.name / name_lower track the *old* name.
        super().__init__(old_name)

    @cached_property
    def old_name_lower(self):
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self):
        return self.new_name.lower()

    def deconstruct(self):
        kwargs = {"old_name": self.old_name, "new_name": self.new_name}
        return (self.__class__.__qualname__, [], kwargs)

    def state_forwards(self, package_label, state):
        state.rename_model(package_label, self.old_name, self.new_name)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        new_model = to_state.models_registry.get_model(package_label, self.new_name)
        if not self.allow_migrate_model(schema_editor.connection, new_model):
            return
        old_model = from_state.models_registry.get_model(package_label, self.old_name)

        # Rename the model's own table first.
        schema_editor.alter_db_table(
            new_model,
            old_model._meta.db_table,
            new_model._meta.db_table,
        )

        # Then alter every field that points at the old model.
        for related_object in old_model._meta.related_objects:
            source_model = related_object.related_model
            if source_model == old_model:
                # Self-referential relation: it now lives on the renamed model.
                target_model = new_model
                target_key = (package_label, self.new_name_lower)
            else:
                target_model = source_model
                target_key = (
                    source_model._meta.package_label,
                    source_model._meta.model_name,
                )
            new_field = to_state.models_registry.get_model(
                *target_key
            )._meta.get_field(related_object.field.name)
            schema_editor.alter_field(target_model, related_object.field, new_field)

    def references_model(self, name, package_label):
        lowered = name.lower()
        return lowered == self.old_name_lower or lowered == self.new_name_lower

    def describe(self):
        return f"Rename model {self.old_name} to {self.new_name}"

    @property
    def migration_name_fragment(self):
        return f"rename_{self.old_name_lower}_{self.new_name_lower}"

    def reduce(self, operation, package_label):
        if (
            isinstance(operation, RenameModel)
            and self.new_name_lower == operation.old_name_lower
        ):
            # A -> B followed by B -> C collapses into A -> C.
            return [RenameModel(self.old_name, operation.new_name)]
        # Bypass ModelOperation.reduce: its reference check uses self.name
        # (the old name), but here the check must use self.new_name.
        reduced = super(ModelOperation, self).reduce(operation, package_label)
        return reduced or not operation.references_model(self.new_name, package_label)
354
355
class ModelOptionOperation(ModelOperation):
    """Base class for operations that change one model-level option."""

    def reduce(self, operation, package_label):
        supersedes = isinstance(operation, self.__class__ | DeleteModel)
        if supersedes and self.name_lower == operation.name_lower:
            # A later operation of the same type, or a delete, on the same
            # model replaces this one outright.
            return [operation]
        return super().reduce(operation, package_label)
364
365
class AlterModelTable(ModelOptionOperation):
    """Rename a model's table by changing its ``db_table`` option.

    ``table`` may be ``None``; ``describe`` then reports it as the default.
    """

    def __init__(self, name, table):
        self.table = table
        super().__init__(name)

    def deconstruct(self):
        kwargs = {"name": self.name, "table": self.table}
        return (self.__class__.__qualname__, [], kwargs)

    def state_forwards(self, package_label, state):
        new_options = {"db_table": self.table}
        state.alter_model_options(package_label, self.name_lower, new_options)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        new_model = to_state.models_registry.get_model(package_label, self.name)
        if not self.allow_migrate_model(schema_editor.connection, new_model):
            return
        old_model = from_state.models_registry.get_model(package_label, self.name)
        schema_editor.alter_db_table(
            new_model,
            old_model._meta.db_table,
            new_model._meta.db_table,
        )

    def describe(self):
        target = "(default)" if self.table is None else self.table
        return f"Rename table for {self.name} to {target}"

    @property
    def migration_name_fragment(self):
        return f"alter_{self.name_lower}_table"
404
405
class AlterModelTableComment(ModelOptionOperation):
    """Change a model's database table comment (its ``db_table_comment`` option)."""

    def __init__(self, name, table_comment):
        self.table_comment = table_comment
        super().__init__(name)

    def deconstruct(self):
        kwargs = {"name": self.name, "table_comment": self.table_comment}
        return (self.__class__.__qualname__, [], kwargs)

    def state_forwards(self, package_label, state):
        new_options = {"db_table_comment": self.table_comment}
        state.alter_model_options(package_label, self.name_lower, new_options)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        new_model = to_state.models_registry.get_model(package_label, self.name)
        if not self.allow_migrate_model(schema_editor.connection, new_model):
            return
        old_model = from_state.models_registry.get_model(package_label, self.name)
        schema_editor.alter_db_table_comment(
            new_model,
            old_model._meta.db_table_comment,
            new_model._meta.db_table_comment,
        )

    def describe(self):
        return f"Alter {self.name} table comment"

    @property
    def migration_name_fragment(self):
        return f"alter_{self.name_lower}_table_comment"
439
440
class AlterModelOptions(ModelOptionOperation):
    """
    Change model options that have no direct effect on the database schema,
    such as ``ordering``. They are still recorded in migration state because
    Python code run inside migrations may rely on them.
    """

    # Option keys this operation manages. CreateModel.reduce also uses this
    # list to drop keys that an AlterModelOptions leaves unset.
    ALTER_OPTION_KEYS = [
        "base_manager_name",
        "default_manager_name",
        "default_related_name",
        "get_latest_by",
        "ordering",
    ]

    def __init__(self, name, options):
        self.options = options
        super().__init__(name)

    def deconstruct(self):
        kwargs = {"name": self.name, "options": self.options}
        return (self.__class__.__qualname__, [], kwargs)

    def state_forwards(self, package_label, state):
        # NOTE(review): the key list presumably tells the state which managed
        # options to clear when absent from self.options — confirm against
        # alter_model_options.
        state.alter_model_options(
            package_label,
            self.name_lower,
            self.options,
            self.ALTER_OPTION_KEYS,
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        # State-only change: nothing is sent to the schema editor.
        return None

    def describe(self):
        return f"Change Meta options on {self.name}"

    @property
    def migration_name_fragment(self):
        return f"alter_{self.name_lower}_options"
485
486
class AlterModelManagers(ModelOptionOperation):
    """Replace the managers recorded for a model in migration state."""

    serialization_expand_args = ["managers"]

    def __init__(self, name, managers):
        self.managers = managers
        super().__init__(name)

    def deconstruct(self):
        # Unlike most operations here, arguments are emitted positionally.
        positional = [self.name, self.managers]
        return (self.__class__.__qualname__, positional, {})

    def state_forwards(self, package_label, state):
        state.alter_model_managers(package_label, self.name_lower, self.managers)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        # State-only change: nothing is sent to the schema editor.
        return None

    def describe(self):
        return f"Change managers on {self.name}"

    @property
    def migration_name_fragment(self):
        return f"alter_{self.name_lower}_managers"
511
512
class IndexOperation(Operation):
    """Base class for operations on a model's indexes or constraints.

    Subclasses set ``model_name``. ``option_name`` is ``"indexes"`` here;
    the constraint operations override it with ``"constraints"``.
    """

    option_name = "indexes"

    @cached_property
    def model_name_lower(self):
        """Lower-cased ``model_name``, computed once and cached."""
        return self.model_name.lower()
519
520
class AddIndex(IndexOperation):
    """Add a named index to a model.

    Raises ``ValueError`` at construction time if the index has no name.
    """

    def __init__(self, model_name, index):
        self.model_name = model_name
        if not index.name:
            raise ValueError(
                "Indexes passed to AddIndex operations require a name "
                f"argument. {index!r} doesn't have one."
            )
        self.index = index

    def state_forwards(self, package_label, state):
        state.add_index(package_label, self.model_name_lower, self.index)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        schema_editor.add_index(model, self.index)

    def deconstruct(self):
        kwargs = {"model_name": self.model_name, "index": self.index}
        return (self.__class__.__qualname__, [], kwargs)

    def describe(self):
        index = self.index
        if not index.expressions:
            return "Create index {} on field(s) {} of model {}".format(
                index.name,
                ", ".join(index.fields),
                self.model_name,
            )
        # Expression-based index: list the rendered expressions instead.
        rendered = ", ".join(str(expression) for expression in index.expressions)
        return "Create index {} on {} on model {}".format(
            index.name,
            rendered,
            self.model_name,
        )

    @property
    def migration_name_fragment(self):
        return f"{self.model_name_lower}_{self.index.name.lower()}"
568
569
class RemoveIndex(IndexOperation):
    """Remove the index called ``name`` from a model."""

    def __init__(self, model_name, name):
        self.model_name = model_name
        self.name = name

    def state_forwards(self, package_label, state):
        state.remove_index(package_label, self.model_name_lower, self.name)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        # state_forwards drops the index, so both the model and the index
        # definition are resolved from from_state.
        model = from_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        model_state = from_state.models[package_label, self.model_name_lower]
        schema_editor.remove_index(model, model_state.get_index_by_name(self.name))

    def deconstruct(self):
        kwargs = {"model_name": self.model_name, "name": self.name}
        return (self.__class__.__qualname__, [], kwargs)

    def describe(self):
        return f"Remove index {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"remove_{self.model_name_lower}_{self.name.lower()}"
604
605
class RenameIndex(IndexOperation):
    """Rename an index.

    The existing index is identified either by ``old_name`` or, for an
    unnamed index, by the list of fields it covers (``old_fields``).
    Exactly one of the two must be given.
    """

    def __init__(self, model_name, new_name, old_name=None, old_fields=None):
        if not (old_name or old_fields):
            raise ValueError(
                "RenameIndex requires one of old_name and old_fields arguments to be set."
            )
        if old_name and old_fields:
            raise ValueError(
                "RenameIndex.old_name and old_fields are mutually exclusive."
            )
        self.model_name = model_name
        self.new_name = new_name
        self.old_name = old_name
        self.old_fields = old_fields

    @cached_property
    def old_name_lower(self):
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self):
        return self.new_name.lower()

    def deconstruct(self):
        kwargs = {"model_name": self.model_name, "new_name": self.new_name}
        # __init__ guarantees that at most one of these two is set.
        if self.old_name:
            kwargs["old_name"] = self.old_name
        if self.old_fields:
            kwargs["old_fields"] = self.old_fields
        return (self.__class__.__qualname__, [], kwargs)

    def state_forwards(self, package_label, state):
        if not self.old_fields:
            state.rename_index(
                package_label, self.model_name_lower, self.old_name, self.new_name
            )
            return
        # Unnamed index identified by its fields: add it under the new name.
        state.add_index(
            package_label,
            self.model_name_lower,
            models.Index(fields=self.old_fields, name=self.new_name),
        )

    def _find_old_index(self, package_label, schema_editor, from_state):
        """Return the index as it exists before the rename.

        By name, it comes from the from_state model state. By fields, the
        database is introspected for exactly one index on those columns;
        ``ValueError`` is raised when the match count is not one.
        """
        if not self.old_fields:
            from_model_state = from_state.models[package_label, self.model_name_lower]
            return from_model_state.get_index_by_name(self.old_name)

        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        columns = [
            from_model._meta.get_field(field_name).column
            for field_name in self.old_fields
        ]
        matches = schema_editor._constraint_names(
            from_model, column_names=columns, index=True
        )
        if len(matches) != 1:
            raise ValueError(
                "Found wrong number ({}) of indexes for {}({}).".format(
                    len(matches),
                    from_model._meta.db_table,
                    ", ".join(columns),
                )
            )
        return models.Index(fields=self.old_fields, name=matches[0])

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return

        old_index = self._find_old_index(package_label, schema_editor, from_state)
        # Nothing to do if the index already carries the target name.
        if old_index.name == self.new_name:
            return

        to_model_state = to_state.models[package_label, self.model_name_lower]
        schema_editor.rename_index(
            model, old_index, to_model_state.get_index_by_name(self.new_name)
        )

    def describe(self):
        if not self.old_name:
            return (
                f"Rename unnamed index for {self.old_fields} on {self.model_name} to "
                f"{self.new_name}"
            )
        return f"Rename index {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self):
        if self.old_name:
            return f"rename_{self.old_name_lower}_{self.new_name_lower}"
        joined_fields = "_".join(self.old_fields)
        return f"rename_{self.model_name_lower}_{joined_fields}_{self.new_name_lower}"

    def reduce(self, operation, package_label):
        chains_onto_self = (
            isinstance(operation, RenameIndex)
            and self.model_name_lower == operation.model_name_lower
            and operation.old_name
            and self.new_name_lower == operation.old_name_lower
        )
        if chains_onto_self:
            # X -> Y followed by Y -> Z on the same model collapses into X -> Z.
            return [
                RenameIndex(
                    self.model_name,
                    new_name=operation.new_name,
                    old_name=self.old_name,
                    old_fields=self.old_fields,
                )
            ]
        return super().reduce(operation, package_label)
729
730
class AddConstraint(IndexOperation):
    """Add a constraint to a model."""

    option_name = "constraints"

    def __init__(self, model_name, constraint):
        self.model_name = model_name
        self.constraint = constraint

    def state_forwards(self, package_label, state):
        state.add_constraint(package_label, self.model_name_lower, self.constraint)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection, model):
            schema_editor.add_constraint(model, self.constraint)

    def deconstruct(self):
        # Use __qualname__, like every other operation in this module, so a
        # subclass nested inside another class reports its full dotted path
        # rather than only its bare class name.
        return (
            self.__class__.__qualname__,
            [],
            {
                "model_name": self.model_name,
                "constraint": self.constraint,
            },
        )

    def describe(self):
        return f"Create constraint {self.constraint.name} on model {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"{self.model_name_lower}_{self.constraint.name.lower()}"
762
763
class RemoveConstraint(IndexOperation):
    """Remove the constraint called ``name`` from a model."""

    option_name = "constraints"

    def __init__(self, model_name, name):
        self.model_name = model_name
        self.name = name

    def state_forwards(self, package_label, state):
        state.remove_constraint(package_label, self.model_name_lower, self.name)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection, model):
            # state_forwards drops the constraint, so its definition is read
            # from from_state.
            from_model_state = from_state.models[package_label, self.model_name_lower]
            constraint = from_model_state.get_constraint_by_name(self.name)
            schema_editor.remove_constraint(model, constraint)

    def deconstruct(self):
        # Use __qualname__, like every other operation in this module, so a
        # subclass nested inside another class reports its full dotted path
        # rather than only its bare class name.
        return (
            self.__class__.__qualname__,
            [],
            {
                "model_name": self.model_name,
                "name": self.name,
            },
        )

    def describe(self):
        return f"Remove constraint {self.name} from model {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"remove_{self.model_name_lower}_{self.name.lower()}"