Plain is headed towards 1.0! Subscribe for development updates →

  1from functools import cached_property
  2
  3from plain import models
  4from plain.models.migrations.operations.base import Operation
  5from plain.models.migrations.state import ModelState
  6from plain.models.migrations.utils import field_references, resolve_relation
  7
  8from .fields import AddField, AlterField, FieldOperation, RemoveField, RenameField
  9
 10
 11def _check_for_duplicates(arg_name, objs):
 12    used_vals = set()
 13    for val in objs:
 14        if val in used_vals:
 15            raise ValueError(
 16                f"Found duplicate value {val} in CreateModel {arg_name} argument."
 17            )
 18        used_vals.add(val)
 19
 20
 21class ModelOperation(Operation):
 22    def __init__(self, name):
 23        self.name = name
 24
 25    @cached_property
 26    def name_lower(self):
 27        return self.name.lower()
 28
 29    def references_model(self, name, package_label):
 30        return name.lower() == self.name_lower
 31
 32    def reduce(self, operation, package_label):
 33        return super().reduce(operation, package_label) or self.can_reduce_through(
 34            operation, package_label
 35        )
 36
 37    def can_reduce_through(self, operation, package_label):
 38        return not operation.references_model(self.name, package_label)
 39
 40
 41class CreateModel(ModelOperation):
 42    """Create a model's table."""
 43
 44    serialization_expand_args = ["fields", "options"]
 45
 46    def __init__(self, name, fields, options=None, bases=None):
 47        self.fields = fields
 48        self.options = options or {}
 49        self.bases = bases or (models.Model,)
 50        super().__init__(name)
 51        # Sanity-check that there are no duplicated field names or bases
 52        _check_for_duplicates("fields", (name for name, _ in self.fields))
 53        _check_for_duplicates(
 54            "bases",
 55            (
 56                base._meta.label_lower
 57                if hasattr(base, "_meta")
 58                else base.lower()
 59                if isinstance(base, str)
 60                else base
 61                for base in self.bases
 62            ),
 63        )
 64
 65    def deconstruct(self):
 66        kwargs = {
 67            "name": self.name,
 68            "fields": self.fields,
 69        }
 70        if self.options:
 71            kwargs["options"] = self.options
 72        if self.bases and self.bases != (models.Model,):
 73            kwargs["bases"] = self.bases
 74        return (self.__class__.__qualname__, [], kwargs)
 75
 76    def state_forwards(self, package_label, state):
 77        state.add_model(
 78            ModelState(
 79                package_label,
 80                self.name,
 81                list(self.fields),
 82                dict(self.options),
 83                tuple(self.bases),
 84            )
 85        )
 86
 87    def database_forwards(self, package_label, schema_editor, from_state, to_state):
 88        model = to_state.models_registry.get_model(package_label, self.name)
 89        if self.allow_migrate_model(schema_editor.connection, model):
 90            schema_editor.create_model(model)
 91
 92    def describe(self):
 93        return f"Create model {self.name}"
 94
 95    @property
 96    def migration_name_fragment(self):
 97        return self.name_lower
 98
 99    def references_model(self, name, package_label):
100        name_lower = name.lower()
101        if name_lower == self.name_lower:
102            return True
103
104        # Check we didn't inherit from the model
105        reference_model_tuple = (package_label, name_lower)
106        for base in self.bases:
107            if (
108                base is not models.Model
109                and isinstance(base, models.base.ModelBase | str)
110                and resolve_relation(base, package_label) == reference_model_tuple
111            ):
112                return True
113
114        # Check we have no FKs/M2Ms with it
115        for _name, field in self.fields:
116            if field_references(
117                (package_label, self.name_lower), field, reference_model_tuple
118            ):
119                return True
120        return False
121
122    def reduce(self, operation, package_label):
123        if (
124            isinstance(operation, DeleteModel)
125            and self.name_lower == operation.name_lower
126        ):
127            return []
128        elif (
129            isinstance(operation, RenameModel)
130            and self.name_lower == operation.old_name_lower
131        ):
132            return [
133                CreateModel(
134                    operation.new_name,
135                    fields=self.fields,
136                    options=self.options,
137                    bases=self.bases,
138                ),
139            ]
140        elif (
141            isinstance(operation, AlterModelOptions)
142            and self.name_lower == operation.name_lower
143        ):
144            options = {**self.options, **operation.options}
145            for key in operation.ALTER_OPTION_KEYS:
146                if key not in operation.options:
147                    options.pop(key, None)
148            return [
149                CreateModel(
150                    self.name,
151                    fields=self.fields,
152                    options=options,
153                    bases=self.bases,
154                ),
155            ]
156        elif (
157            isinstance(operation, FieldOperation)
158            and self.name_lower == operation.model_name_lower
159        ):
160            if isinstance(operation, AddField):
161                return [
162                    CreateModel(
163                        self.name,
164                        fields=self.fields + [(operation.name, operation.field)],
165                        options=self.options,
166                        bases=self.bases,
167                    ),
168                ]
169            elif isinstance(operation, AlterField):
170                return [
171                    CreateModel(
172                        self.name,
173                        fields=[
174                            (n, operation.field if n == operation.name else v)
175                            for n, v in self.fields
176                        ],
177                        options=self.options,
178                        bases=self.bases,
179                    ),
180                ]
181            elif isinstance(operation, RemoveField):
182                options = self.options.copy()
183
184                return [
185                    CreateModel(
186                        self.name,
187                        fields=[
188                            (n, v)
189                            for n, v in self.fields
190                            if n.lower() != operation.name_lower
191                        ],
192                        options=options,
193                        bases=self.bases,
194                    ),
195                ]
196            elif isinstance(operation, RenameField):
197                options = self.options.copy()
198
199                return [
200                    CreateModel(
201                        self.name,
202                        fields=[
203                            (operation.new_name if n == operation.old_name else n, v)
204                            for n, v in self.fields
205                        ],
206                        options=options,
207                        bases=self.bases,
208                    ),
209                ]
210        return super().reduce(operation, package_label)
211
212
class DeleteModel(ModelOperation):
    """Drop a model's table."""

    def deconstruct(self):
        """Return ``(qualified class name, [], kwargs)`` describing this operation."""
        return (self.__class__.__qualname__, [], {"name": self.name})

    def state_forwards(self, package_label, state):
        """Remove the model from the project state."""
        state.remove_model(package_label, self.name_lower)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Ask the schema editor to delete the model, if migration is allowed."""
        # The model is looked up in the state that still contains it.
        model = from_state.models_registry.get_model(package_label, self.name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        schema_editor.delete_model(model)

    def references_model(self, name, package_label):
        # The deleted model could be referencing the specified model through
        # related fields, so every model is reported as referenced.
        return True

    def describe(self):
        return f"Delete model {self.name}"

    @property
    def migration_name_fragment(self):
        fragment = f"delete_{self.name_lower}"
        return fragment
241
242
class RenameModel(ModelOperation):
    """Rename a model."""

    def __init__(self, old_name, new_name):
        self.old_name = old_name
        self.new_name = new_name
        # The inherited ``name`` attribute tracks the old name.
        super().__init__(old_name)

    @cached_property
    def old_name_lower(self):
        """Lower-cased ``old_name``, cached after the first access."""
        lowered = self.old_name.lower()
        return lowered

    @cached_property
    def new_name_lower(self):
        """Lower-cased ``new_name``, cached after the first access."""
        lowered = self.new_name.lower()
        return lowered

    def deconstruct(self):
        """Return ``(qualified class name, [], kwargs)`` describing this operation."""
        return (
            self.__class__.__qualname__,
            [],
            {"old_name": self.old_name, "new_name": self.new_name},
        )

    def state_forwards(self, package_label, state):
        """Rename the model in the project state."""
        state.rename_model(package_label, self.old_name, self.new_name)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Rename the table, then re-alter every field related to the old model."""
        registry_after = to_state.models_registry
        new_model = registry_after.get_model(package_label, self.new_name)
        if not self.allow_migrate_model(schema_editor.connection, new_model):
            return

        old_model = from_state.models_registry.get_model(package_label, self.old_name)
        # Move the main table
        schema_editor.alter_db_table(
            new_model,
            old_model._meta.db_table,
            new_model._meta.db_table,
        )
        # Alter the fields pointing to us
        for related_object in old_model._meta.related_objects:
            related_model = related_object.related_model
            if related_model == old_model:
                # Self-referencing relation: the owner is the renamed model.
                model = new_model
                related_key = (package_label, self.new_name_lower)
            else:
                model = related_model
                related_key = (
                    related_model._meta.package_label,
                    related_model._meta.model_name,
                )
            owner_after = registry_after.get_model(*related_key)
            to_field = owner_after._meta.get_field(related_object.field.name)
            schema_editor.alter_field(model, related_object.field, to_field)

    def references_model(self, name, package_label):
        """Return True when ``name`` matches the old or the new name, ignoring case."""
        requested = name.lower()
        if requested == self.old_name_lower:
            return True
        return requested == self.new_name_lower

    def describe(self):
        return f"Rename model {self.old_name} to {self.new_name}"

    @property
    def migration_name_fragment(self):
        return f"rename_{self.old_name_lower}_{self.new_name_lower}"

    def reduce(self, operation, package_label):
        """Collapse chained renames; otherwise defer past ``ModelOperation``."""
        chains_on = (
            isinstance(operation, RenameModel)
            and self.new_name_lower == operation.old_name_lower
        )
        if chains_on:
            return [RenameModel(self.old_name, operation.new_name)]
        # Skip `ModelOperation.reduce` as we want to run `references_model`
        # against self.new_name.
        inherited = super(ModelOperation, self).reduce(operation, package_label)
        if inherited:
            return inherited
        return not operation.references_model(self.new_name, package_label)
329
330
class ModelOptionOperation(ModelOperation):
    """Base class for operations that change an option of one model."""

    def reduce(self, operation, package_label):
        """Give way to a later operation of the same class, or a delete, on this model.

        In that case the later ``operation`` alone is returned; otherwise the
        parent class's ``reduce`` decides.
        """
        supersedes = isinstance(operation, self.__class__ | DeleteModel)
        if supersedes and self.name_lower == operation.name_lower:
            return [operation]
        return super().reduce(operation, package_label)
339
340
class AlterModelTable(ModelOptionOperation):
    """Rename a model's table."""

    def __init__(self, name, table):
        self.table = table
        super().__init__(name)

    def deconstruct(self):
        """Return ``(qualified class name, [], kwargs)`` describing this operation."""
        return (
            self.__class__.__qualname__,
            [],
            {"name": self.name, "table": self.table},
        )

    def state_forwards(self, package_label, state):
        """Record the new ``db_table`` option in the project state."""
        new_options = {"db_table": self.table}
        state.alter_model_options(package_label, self.name_lower, new_options)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Rename the table from its old to its new name, if migration is allowed."""
        new_model = to_state.models_registry.get_model(package_label, self.name)
        if not self.allow_migrate_model(schema_editor.connection, new_model):
            return
        old_model = from_state.models_registry.get_model(package_label, self.name)
        schema_editor.alter_db_table(
            new_model,
            old_model._meta.db_table,
            new_model._meta.db_table,
        )

    def describe(self):
        # A table of ``None`` is reported as "(default)".
        if self.table is not None:
            table_label = self.table
        else:
            table_label = "(default)"
        return "Rename table for {} to {}".format(self.name, table_label)

    @property
    def migration_name_fragment(self):
        return f"alter_{self.name_lower}_table"
379
380
class AlterModelTableComment(ModelOptionOperation):
    """Change the comment attached to a model's table."""

    def __init__(self, name, table_comment):
        self.table_comment = table_comment
        super().__init__(name)

    def deconstruct(self):
        """Return ``(qualified class name, [], kwargs)`` describing this operation."""
        return (
            self.__class__.__qualname__,
            [],
            {"name": self.name, "table_comment": self.table_comment},
        )

    def state_forwards(self, package_label, state):
        """Record the new ``db_table_comment`` option in the project state."""
        new_options = {"db_table_comment": self.table_comment}
        state.alter_model_options(package_label, self.name_lower, new_options)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Pass the old and new table comments to the schema editor, if allowed."""
        new_model = to_state.models_registry.get_model(package_label, self.name)
        if not self.allow_migrate_model(schema_editor.connection, new_model):
            return
        old_model = from_state.models_registry.get_model(package_label, self.name)
        schema_editor.alter_db_table_comment(
            new_model,
            old_model._meta.db_table_comment,
            new_model._meta.db_table_comment,
        )

    def describe(self):
        return f"Alter {self.name} table comment"

    @property
    def migration_name_fragment(self):
        return f"alter_{self.name_lower}_table_comment"
414
415
class AlterModelOptions(ModelOptionOperation):
    """
    Set new model options that don't directly affect the database schema
    (like ordering). Python code in migrations may still need them.
    """

    # Model options we want to compare and preserve in an AlterModelOptions op
    ALTER_OPTION_KEYS = [
        "ordering",
    ]

    def __init__(self, name, options):
        self.options = options
        super().__init__(name)

    def deconstruct(self):
        """Return ``(qualified class name, [], kwargs)`` describing this operation."""
        return (
            self.__class__.__qualname__,
            [],
            {"name": self.name, "options": self.options},
        )

    def state_forwards(self, package_label, state):
        """Apply the options to the project state, passing ``ALTER_OPTION_KEYS``."""
        managed_keys = self.ALTER_OPTION_KEYS
        state.alter_model_options(
            package_label, self.name_lower, self.options, managed_keys
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Do nothing: this operation makes no schema editor calls."""
        return None

    def describe(self):
        return f"Change Meta options on {self.name}"

    @property
    def migration_name_fragment(self):
        return f"alter_{self.name_lower}_options"
456
457
class IndexOperation(Operation):
    """Base class for operations that carry a ``model_name`` attribute.

    Subclasses are expected to set ``self.model_name``; it is not assigned
    here.
    """

    option_name = "indexes"

    @cached_property
    def model_name_lower(self):
        """Lower-cased ``model_name``, cached after the first access."""
        model_name = self.model_name
        return model_name.lower()
464
465
class AddIndex(IndexOperation):
    """Add an index on a model."""

    def __init__(self, model_name, index):
        """Store the target model name and index.

        Raises ``ValueError`` when ``index.name`` is falsy.
        """
        self.model_name = model_name
        has_name = bool(index.name)
        if not has_name:
            raise ValueError(
                "Indexes passed to AddIndex operations require a name "
                f"argument. {index!r} doesn't have one."
            )
        self.index = index

    def state_forwards(self, package_label, state):
        """Register the index on the model in the project state."""
        state.add_index(package_label, self.model_name_lower, self.index)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Ask the schema editor to add the index, if migration is allowed."""
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        schema_editor.add_index(model, self.index)

    def deconstruct(self):
        """Return ``(qualified class name, [], kwargs)`` describing this operation."""
        return (
            self.__class__.__qualname__,
            [],
            {"model_name": self.model_name, "index": self.index},
        )

    def describe(self):
        # Expression-based indexes list their expressions; others list fields.
        expressions = self.index.expressions
        if expressions:
            rendered = ", ".join(map(str, expressions))
            return "Create index {} on {} on model {}".format(
                self.index.name, rendered, self.model_name
            )
        field_list = ", ".join(self.index.fields)
        return "Create index {} on field(s) {} of model {}".format(
            self.index.name, field_list, self.model_name
        )

    @property
    def migration_name_fragment(self):
        return f"{self.model_name_lower}_{self.index.name.lower()}"
513
514
class RemoveIndex(IndexOperation):
    """Remove an index from a model."""

    def __init__(self, model_name, name):
        self.model_name = model_name
        self.name = name

    def state_forwards(self, package_label, state):
        """Drop the named index from the model in the project state."""
        state.remove_index(package_label, self.model_name_lower, self.name)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Ask the schema editor to remove the index, if migration is allowed."""
        model = from_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        # The index definition is read from the state that still has it.
        state_key = (package_label, self.model_name_lower)
        index = from_state.models[state_key].get_index_by_name(self.name)
        schema_editor.remove_index(model, index)

    def deconstruct(self):
        """Return ``(qualified class name, [], kwargs)`` describing this operation."""
        return (
            self.__class__.__qualname__,
            [],
            {"model_name": self.model_name, "name": self.name},
        )

    def describe(self):
        return f"Remove index {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"remove_{self.model_name_lower}_{self.name.lower()}"
549
550
class RenameIndex(IndexOperation):
    """Rename an index."""

    def __init__(self, model_name, new_name, old_name=None, old_fields=None):
        """Store the rename; exactly one of ``old_name``/``old_fields`` is required.

        Raises ``ValueError`` when neither or both of them are given.
        """
        if not (old_name or old_fields):
            raise ValueError(
                "RenameIndex requires one of old_name and old_fields arguments to be "
                "set."
            )
        if old_name and old_fields:
            raise ValueError(
                "RenameIndex.old_name and old_fields are mutually exclusive."
            )
        self.model_name = model_name
        self.new_name = new_name
        self.old_name = old_name
        self.old_fields = old_fields

    @cached_property
    def old_name_lower(self):
        """Lower-cased ``old_name``, cached after the first access."""
        lowered = self.old_name.lower()
        return lowered

    @cached_property
    def new_name_lower(self):
        """Lower-cased ``new_name``, cached after the first access."""
        lowered = self.new_name.lower()
        return lowered

    def deconstruct(self):
        """Return ``(qualified class name, [], kwargs)``; unset old_* keys are omitted."""
        kwargs = {"model_name": self.model_name, "new_name": self.new_name}
        if self.old_name:
            kwargs["old_name"] = self.old_name
        if self.old_fields:
            kwargs["old_fields"] = self.old_fields
        return (self.__class__.__qualname__, [], kwargs)

    def state_forwards(self, package_label, state):
        """Update the project state with the renamed index.

        With ``old_fields`` an index named ``new_name`` is added over those
        fields; with ``old_name`` the existing index entry is renamed.
        """
        if not self.old_fields:
            state.rename_index(
                package_label, self.model_name_lower, self.old_name, self.new_name
            )
            return
        replacement = models.Index(fields=self.old_fields, name=self.new_name)
        state.add_index(package_label, self.model_name_lower, replacement)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Resolve the old index and ask the schema editor to rename it.

        Raises ``ValueError`` when ``old_fields`` is set and the schema
        editor's lookup does not return exactly one index name for the
        columns.
        """
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return

        state_key = (package_label, self.model_name_lower)
        if self.old_fields:
            from_model = from_state.models_registry.get_model(
                package_label, self.model_name
            )
            columns = [
                from_model._meta.get_field(field_name).column
                for field_name in self.old_fields
            ]
            # Look the index name up through the schema editor by its columns.
            matching_index_name = schema_editor._constraint_names(
                from_model, column_names=columns, index=True
            )
            match_count = len(matching_index_name)
            if match_count != 1:
                raise ValueError(
                    "Found wrong number ({}) of indexes for {}({}).".format(
                        match_count,
                        from_model._meta.db_table,
                        ", ".join(columns),
                    )
                )
            old_index = models.Index(
                fields=self.old_fields,
                name=matching_index_name[0],
            )
        else:
            old_index = from_state.models[state_key].get_index_by_name(self.old_name)

        # Don't alter when the index name is not changed.
        if old_index.name == self.new_name:
            return

        new_index = to_state.models[state_key].get_index_by_name(self.new_name)
        schema_editor.rename_index(model, old_index, new_index)

    def describe(self):
        if not self.old_name:
            return (
                f"Rename unnamed index for {self.old_fields} on {self.model_name} to "
                f"{self.new_name}"
            )
        return f"Rename index {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self):
        if self.old_name:
            return f"rename_{self.old_name_lower}_{self.new_name_lower}"
        joined_fields = "_".join(self.old_fields)
        return "rename_{}_{}_{}".format(
            self.model_name_lower, joined_fields, self.new_name_lower
        )

    def reduce(self, operation, package_label):
        """Collapse a following rename that starts from this operation's new name.

        The merged operation keeps this operation's ``old_name``/``old_fields``
        and takes the later operation's ``new_name``. Anything else is left to
        the parent class's ``reduce``.
        """
        chains_on = (
            isinstance(operation, RenameIndex)
            and self.model_name_lower == operation.model_name_lower
            and operation.old_name
            and self.new_name_lower == operation.old_name_lower
        )
        if not chains_on:
            return super().reduce(operation, package_label)
        merged = RenameIndex(
            self.model_name,
            new_name=operation.new_name,
            old_name=self.old_name,
            old_fields=self.old_fields,
        )
        return [merged]
674
675
class AddConstraint(IndexOperation):
    """Add a constraint to a model."""

    option_name = "constraints"

    def __init__(self, model_name, constraint):
        self.model_name = model_name
        self.constraint = constraint

    def state_forwards(self, package_label, state):
        """Register the constraint on the model in the project state."""
        state.add_constraint(package_label, self.model_name_lower, self.constraint)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Ask the schema editor to add the constraint, if migration is allowed."""
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection, model):
            schema_editor.add_constraint(model, self.constraint)

    def deconstruct(self):
        """Return ``(qualified class name, [], kwargs)`` describing this operation.

        ``__qualname__`` is used, as in the other operations in this module,
        so a subclass nested inside another class keeps its enclosing scope in
        the returned name. For top-level classes it equals ``__name__``.
        """
        kwargs = {
            "model_name": self.model_name,
            "constraint": self.constraint,
        }
        return (self.__class__.__qualname__, [], kwargs)

    def describe(self):
        return f"Create constraint {self.constraint.name} on model {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"{self.model_name_lower}_{self.constraint.name.lower()}"
707
708
class RemoveConstraint(IndexOperation):
    """Remove a named constraint from a model."""

    option_name = "constraints"

    def __init__(self, model_name, name):
        self.model_name = model_name
        self.name = name

    def state_forwards(self, package_label, state):
        """Drop the named constraint from the model in the project state."""
        state.remove_constraint(package_label, self.model_name_lower, self.name)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Ask the schema editor to remove the constraint, if migration is allowed."""
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection, model):
            # The constraint definition is read from the state that still has it.
            from_model_state = from_state.models[package_label, self.model_name_lower]
            constraint = from_model_state.get_constraint_by_name(self.name)
            schema_editor.remove_constraint(model, constraint)

    def deconstruct(self):
        """Return ``(qualified class name, [], kwargs)`` describing this operation.

        ``__qualname__`` is used, as in the other operations in this module,
        so a subclass nested inside another class keeps its enclosing scope in
        the returned name. For top-level classes it equals ``__name__``.
        """
        kwargs = {
            "model_name": self.model_name,
            "name": self.name,
        }
        return (self.__class__.__qualname__, [], kwargs)

    def describe(self):
        return f"Remove constraint {self.name} from model {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"remove_{self.model_name_lower}_{self.name.lower()}"