1from functools import cached_property
2
3from plain import models
4from plain.models.migrations.operations.base import Operation
5from plain.models.migrations.state import ModelState
6from plain.models.migrations.utils import field_references, resolve_relation
7
8from .fields import AddField, AlterField, FieldOperation, RemoveField, RenameField
9
10
11def _check_for_duplicates(arg_name, objs):
12 used_vals = set()
13 for val in objs:
14 if val in used_vals:
15 raise ValueError(
16 f"Found duplicate value {val} in CreateModel {arg_name} argument."
17 )
18 used_vals.add(val)
19
20
class ModelOperation(Operation):
    """Base class for migration operations that target one model by name."""

    def __init__(self, name):
        self.name = name

    @cached_property
    def name_lower(self):
        """The model name lower-cased, for case-insensitive comparisons."""
        return self.name.lower()

    def references_model(self, name, package_label):
        """Return whether *name* is this operation's model (case-insensitive).

        ``package_label`` belongs to the shared signature but is not consulted.
        """
        return name.lower() == self.name_lower

    def reduce(self, operation, package_label):
        """Try to combine this operation with a later *operation*.

        A truthy result from the parent ``reduce`` wins. Otherwise the answer
        comes from :meth:`can_reduce_through`.
        """
        merged = super().reduce(operation, package_label)
        if merged:
            return merged
        return self.can_reduce_through(operation, package_label)

    def can_reduce_through(self, operation, package_label):
        """Return True when *operation* does not reference this model."""
        touches_model = operation.references_model(self.name, package_label)
        return not touches_model
39
40
class CreateModel(ModelOperation):
    """Create a model's table."""

    serialization_expand_args = ["fields", "options", "managers"]

    def __init__(self, name, fields, options=None, bases=None, managers=None):
        self.fields = fields
        self.options = options or {}
        self.bases = bases or (models.Model,)
        self.managers = managers or []
        super().__init__(name)

        def normalized_base(base):
            # Model classes compare by lower-cased label and strings by their
            # lower-cased form, so two spellings of the same base are caught
            # as duplicates. Anything else is compared as-is.
            if hasattr(base, "_meta"):
                return base._meta.label_lower
            if isinstance(base, str):
                return base.lower()
            return base

        # Sanity-check that field names, bases and manager names are unique.
        _check_for_duplicates(
            "fields", (field_name for field_name, _ in self.fields)
        )
        _check_for_duplicates(
            "bases", (normalized_base(base) for base in self.bases)
        )
        _check_for_duplicates(
            "managers", (manager_name for manager_name, _ in self.managers)
        )

    def deconstruct(self):
        """Return ``(qualname, args, kwargs)``, omitting default-valued kwargs."""
        kwargs = {"name": self.name, "fields": self.fields}
        if self.options:
            kwargs["options"] = self.options
        has_custom_bases = self.bases and self.bases != (models.Model,)
        if has_custom_bases:
            kwargs["bases"] = self.bases
        # The default-manager list is only constructed when there are managers
        # to compare against.
        if self.managers and self.managers != [("objects", models.Manager())]:
            kwargs["managers"] = self.managers
        return (self.__class__.__qualname__, [], kwargs)

    def state_forwards(self, package_label, state):
        """Add a ModelState for the new model to *state*.

        Containers are copied so later mutation of the state does not leak back
        into this operation.
        """
        model_state = ModelState(
            package_label,
            self.name,
            list(self.fields),
            dict(self.options),
            tuple(self.bases),
            list(self.managers),
        )
        state.add_model(model_state)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Create the table for the model as it exists in *to_state*."""
        model = to_state.models_registry.get_model(package_label, self.name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        schema_editor.create_model(model)

    def describe(self):
        return f"Create model {self.name}"

    @property
    def migration_name_fragment(self):
        return self.name_lower

    def references_model(self, name, package_label):
        """Return whether this model is, inherits from, or relates to *name*."""
        lowered = name.lower()
        if lowered == self.name_lower:
            return True

        target = (package_label, lowered)

        # Inheriting from the model counts as a reference.
        inherits = any(
            base is not models.Model
            and isinstance(base, models.base.ModelBase | str)
            and resolve_relation(base, package_label) == target
            for base in self.bases
        )
        if inherits:
            return True

        # So does any field (FK/M2M) pointing at it.
        return any(
            field_references((package_label, self.name_lower), field, target)
            for _, field in self.fields
        )

    def reduce(self, operation, package_label):
        """Fold a later *operation* on this model into the CreateModel.

        Returns a list of replacement operations (empty when the model is
        deleted again), or defers to ``ModelOperation.reduce``.
        """

        def rebuilt(name, **overrides):
            # Copy of this CreateModel under *name*, with *overrides* replacing
            # the matching keyword arguments.
            kwargs = {
                "fields": self.fields,
                "options": self.options,
                "bases": self.bases,
                "managers": self.managers,
            }
            kwargs.update(overrides)
            return CreateModel(name, **kwargs)

        if (
            isinstance(operation, DeleteModel)
            and self.name_lower == operation.name_lower
        ):
            # Creating and then deleting the model cancels out.
            return []

        if (
            isinstance(operation, RenameModel)
            and self.name_lower == operation.old_name_lower
        ):
            return [rebuilt(operation.new_name)]

        if (
            isinstance(operation, AlterModelOptions)
            and self.name_lower == operation.name_lower
        ):
            merged = {**self.options, **operation.options}
            # Any ALTER_OPTION_KEYS entry the later operation omits is dropped.
            for key in operation.ALTER_OPTION_KEYS:
                if key not in operation.options:
                    merged.pop(key, None)
            return [rebuilt(self.name, options=merged)]

        if (
            isinstance(operation, AlterModelManagers)
            and self.name_lower == operation.name_lower
        ):
            return [rebuilt(self.name, managers=operation.managers)]

        if (
            isinstance(operation, FieldOperation)
            and self.name_lower == operation.model_name_lower
        ):
            if isinstance(operation, AddField):
                appended = self.fields + [(operation.name, operation.field)]
                return [rebuilt(self.name, fields=appended)]
            if isinstance(operation, AlterField):
                altered = [
                    (
                        field_name,
                        operation.field if field_name == operation.name else field,
                    )
                    for field_name, field in self.fields
                ]
                return [rebuilt(self.name, fields=altered)]
            if isinstance(operation, RemoveField):
                remaining = [
                    (field_name, field)
                    for field_name, field in self.fields
                    if field_name.lower() != operation.name_lower
                ]
                return [
                    rebuilt(
                        self.name,
                        fields=remaining,
                        options=self.options.copy(),
                    )
                ]
            if isinstance(operation, RenameField):
                renamed = [
                    (
                        operation.new_name
                        if field_name == operation.old_name
                        else field_name,
                        field,
                    )
                    for field_name, field in self.fields
                ]
                return [
                    rebuilt(
                        self.name,
                        fields=renamed,
                        options=self.options.copy(),
                    )
                ]

        return super().reduce(operation, package_label)
236
237
class DeleteModel(ModelOperation):
    """Drop a model's table."""

    def deconstruct(self):
        """Return ``(qualname, args, kwargs)`` for serialization."""
        return (self.__class__.__qualname__, [], {"name": self.name})

    def state_forwards(self, package_label, state):
        """Remove the model from *state*."""
        state.remove_model(package_label, self.name_lower)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Drop the table of the model as it existed in *from_state*."""
        model = from_state.models_registry.get_model(package_label, self.name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        schema_editor.delete_model(model)

    def references_model(self, name, package_label):
        # Always True: the deleted model may point at *name* through its
        # related fields, so err on the side of "referenced".
        return True

    def describe(self):
        return f"Delete model {self.name}"

    @property
    def migration_name_fragment(self):
        return f"delete_{self.name_lower}"
266
267
class RenameModel(ModelOperation):
    """Rename a model."""

    def __init__(self, old_name, new_name):
        self.old_name = old_name
        self.new_name = new_name
        # ``self.name`` (from ModelOperation) holds the *old* name.
        super().__init__(old_name)

    @cached_property
    def old_name_lower(self):
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self):
        return self.new_name.lower()

    def deconstruct(self):
        """Return ``(qualname, args, kwargs)`` for serialization."""
        return (
            self.__class__.__qualname__,
            [],
            {"old_name": self.old_name, "new_name": self.new_name},
        )

    def state_forwards(self, package_label, state):
        """Rename the model inside *state*."""
        state.rename_model(package_label, self.old_name, self.new_name)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Rename the table, then re-target every relation to the old model."""
        new_model = to_state.models_registry.get_model(package_label, self.new_name)
        if not self.allow_migrate_model(schema_editor.connection, new_model):
            return
        old_model = from_state.models_registry.get_model(package_label, self.old_name)

        # Move the model's own table.
        schema_editor.alter_db_table(
            new_model,
            old_model._meta.db_table,
            new_model._meta.db_table,
        )

        # Alter the fields that pointed at the old model.
        for rel in old_model._meta.related_objects:
            if rel.related_model == old_model:
                # Self-referential relation: its owner is the renamed model.
                owner = new_model
                owner_key = (package_label, self.new_name_lower)
            else:
                owner = rel.related_model
                owner_key = (owner._meta.package_label, owner._meta.model_name)
            replacement = to_state.models_registry.get_model(
                *owner_key
            )._meta.get_field(rel.field.name)
            schema_editor.alter_field(owner, rel.field, replacement)

    def references_model(self, name, package_label):
        """Return whether *name* matches either the old or the new model name."""
        lowered = name.lower()
        return lowered in (self.old_name_lower, self.new_name_lower)

    def describe(self):
        return f"Rename model {self.old_name} to {self.new_name}"

    @property
    def migration_name_fragment(self):
        return f"rename_{self.old_name_lower}_{self.new_name_lower}"

    def reduce(self, operation, package_label):
        """Collapse chained renames; otherwise test reducibility by new name."""
        chained = (
            isinstance(operation, RenameModel)
            and self.new_name_lower == operation.old_name_lower
        )
        if chained:
            # A -> B followed by B -> C becomes A -> C.
            return [RenameModel(self.old_name, operation.new_name)]
        # Bypass ModelOperation.reduce on purpose: later operations must be
        # checked against self.new_name, not self.name (the old name).
        base_result = super(ModelOperation, self).reduce(operation, package_label)
        return base_result or not operation.references_model(
            self.new_name, package_label
        )
354
355
class ModelOptionOperation(ModelOperation):
    """Base class for operations that change a model-level option."""

    def reduce(self, operation, package_label):
        """A later operation of the same class, or a DeleteModel, on the same
        model supersedes this one and is returned in its place."""
        supersedes = isinstance(operation, (self.__class__, DeleteModel))
        if supersedes and self.name_lower == operation.name_lower:
            return [operation]
        return super().reduce(operation, package_label)
364
365
class AlterModelTable(ModelOptionOperation):
    """Rename a model's table."""

    def __init__(self, name, table):
        self.table = table
        super().__init__(name)

    def deconstruct(self):
        """Return ``(qualname, args, kwargs)`` for serialization."""
        return (
            self.__class__.__qualname__,
            [],
            {"name": self.name, "table": self.table},
        )

    def state_forwards(self, package_label, state):
        """Store the new ``db_table`` option on the model state."""
        new_options = {"db_table": self.table}
        state.alter_model_options(package_label, self.name_lower, new_options)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Rename the table from its *from_state* name to its *to_state* name."""
        new_model = to_state.models_registry.get_model(package_label, self.name)
        if not self.allow_migrate_model(schema_editor.connection, new_model):
            return
        old_model = from_state.models_registry.get_model(package_label, self.name)
        schema_editor.alter_db_table(
            new_model,
            old_model._meta.db_table,
            new_model._meta.db_table,
        )

    def describe(self):
        target = "(default)" if self.table is None else self.table
        return f"Rename table for {self.name} to {target}"

    @property
    def migration_name_fragment(self):
        return f"alter_{self.name_lower}_table"
404
405
class AlterModelTableComment(ModelOptionOperation):
    """Change the comment attached to a model's table."""

    def __init__(self, name, table_comment):
        self.table_comment = table_comment
        super().__init__(name)

    def deconstruct(self):
        """Return ``(qualname, args, kwargs)`` for serialization."""
        return (
            self.__class__.__qualname__,
            [],
            {"name": self.name, "table_comment": self.table_comment},
        )

    def state_forwards(self, package_label, state):
        """Store the new ``db_table_comment`` option on the model state."""
        new_options = {"db_table_comment": self.table_comment}
        state.alter_model_options(package_label, self.name_lower, new_options)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Apply the comment change between *from_state* and *to_state*."""
        new_model = to_state.models_registry.get_model(package_label, self.name)
        if not self.allow_migrate_model(schema_editor.connection, new_model):
            return
        old_model = from_state.models_registry.get_model(package_label, self.name)
        schema_editor.alter_db_table_comment(
            new_model,
            old_model._meta.db_table_comment,
            new_model._meta.db_table_comment,
        )

    def describe(self):
        return f"Alter {self.name} table comment"

    @property
    def migration_name_fragment(self):
        return f"alter_{self.name_lower}_table_comment"
439
440
class AlterModelOptions(ModelOptionOperation):
    """
    Change model options that have no direct effect on the database schema
    (such as ordering). They are still tracked because Python code running
    inside migrations may rely on them.
    """

    # Options managed by this operation. Any of these keys missing from
    # ``options`` is dropped when a CreateModel absorbs this operation.
    ALTER_OPTION_KEYS = [
        "base_manager_name",
        "default_manager_name",
        "default_related_name",
        "get_latest_by",
        "ordering",
    ]

    def __init__(self, name, options):
        self.options = options
        super().__init__(name)

    def deconstruct(self):
        """Return ``(qualname, args, kwargs)`` for serialization."""
        return (
            self.__class__.__qualname__,
            [],
            {"name": self.name, "options": self.options},
        )

    def state_forwards(self, package_label, state):
        """Apply ``options`` to the model state, scoped to ALTER_OPTION_KEYS."""
        state.alter_model_options(
            package_label, self.name_lower, self.options, self.ALTER_OPTION_KEYS
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        # These options never touch the schema; nothing to run.
        pass

    def describe(self):
        return f"Change Meta options on {self.name}"

    @property
    def migration_name_fragment(self):
        return f"alter_{self.name_lower}_options"
485
486
class AlterModelManagers(ModelOptionOperation):
    """Alter the model's managers."""

    serialization_expand_args = ["managers"]

    def __init__(self, name, managers):
        self.managers = managers
        super().__init__(name)

    def deconstruct(self):
        """Return ``(qualname, args, kwargs)``; both arguments are positional."""
        positional = [self.name, self.managers]
        return (self.__class__.__qualname__, positional, {})

    def state_forwards(self, package_label, state):
        """Replace the managers recorded on the model state."""
        state.alter_model_managers(package_label, self.name_lower, self.managers)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        # Managers exist only in Python; the database is unaffected.
        pass

    def describe(self):
        return f"Change managers on {self.name}"

    @property
    def migration_name_fragment(self):
        return f"alter_{self.name_lower}_managers"
511
512
class IndexOperation(Operation):
    """Base class for operations on a model's indexes.

    Subclasses set ``self.model_name``; ``option_name`` names the model option
    they manage.
    """

    option_name = "indexes"

    @cached_property
    def model_name_lower(self):
        """``model_name`` lower-cased, as used for model-state lookups."""
        lowered = self.model_name.lower()
        return lowered
519
520
class AddIndex(IndexOperation):
    """Add an index on a model."""

    def __init__(self, model_name, index):
        """
        Raises:
            ValueError: If *index* has no name; migrations need a stable name
                to refer to it later.
        """
        self.model_name = model_name
        if not index.name:
            raise ValueError(
                "Indexes passed to AddIndex operations require a name "
                f"argument. {index!r} doesn't have one."
            )
        self.index = index

    def state_forwards(self, package_label, state):
        """Record the index on the model state."""
        state.add_index(package_label, self.model_name_lower, self.index)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Create the index on the model as defined in *to_state*."""
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        schema_editor.add_index(model, self.index)

    def deconstruct(self):
        """Return ``(qualname, args, kwargs)`` for serialization."""
        kwargs = {"model_name": self.model_name, "index": self.index}
        return (self.__class__.__qualname__, [], kwargs)

    def describe(self):
        index_name = self.index.name
        if self.index.expressions:
            rendered = ", ".join(map(str, self.index.expressions))
            return (
                f"Create index {index_name} on {rendered} "
                f"on model {self.model_name}"
            )
        columns = ", ".join(self.index.fields)
        return (
            f"Create index {index_name} on field(s) {columns} "
            f"of model {self.model_name}"
        )

    @property
    def migration_name_fragment(self):
        return f"{self.model_name_lower}_{self.index.name.lower()}"
568
569
class RemoveIndex(IndexOperation):
    """Remove an index from a model."""

    def __init__(self, model_name, name):
        self.model_name = model_name
        self.name = name

    def state_forwards(self, package_label, state):
        """Drop the named index from the model state."""
        state.remove_index(package_label, self.model_name_lower, self.name)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Drop the index, looked up by name in *from_state*."""
        model = from_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        source_state = from_state.models[package_label, self.model_name_lower]
        doomed = source_state.get_index_by_name(self.name)
        schema_editor.remove_index(model, doomed)

    def deconstruct(self):
        """Return ``(qualname, args, kwargs)`` for serialization."""
        kwargs = {"model_name": self.model_name, "name": self.name}
        return (self.__class__.__qualname__, [], kwargs)

    def describe(self):
        return f"Remove index {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"remove_{self.model_name_lower}_{self.name.lower()}"
604
605
class RenameIndex(IndexOperation):
    """Rename an index.

    The index is identified either by its current name (``old_name``) or, for
    an unnamed index, by the fields it covers (``old_fields``). Exactly one of
    the two must be given.
    """

    def __init__(self, model_name, new_name, old_name=None, old_fields=None):
        if not (old_name or old_fields):
            raise ValueError(
                "RenameIndex requires one of old_name and old_fields arguments "
                "to be set."
            )
        if old_name and old_fields:
            raise ValueError(
                "RenameIndex.old_name and old_fields are mutually exclusive."
            )
        self.model_name = model_name
        self.new_name = new_name
        self.old_name = old_name
        self.old_fields = old_fields

    @cached_property
    def old_name_lower(self):
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self):
        return self.new_name.lower()

    def deconstruct(self):
        """Return ``(qualname, args, kwargs)``; unset identifiers are omitted."""
        kwargs = {"model_name": self.model_name, "new_name": self.new_name}
        optional = {"old_name": self.old_name, "old_fields": self.old_fields}
        kwargs.update({key: value for key, value in optional.items() if value})
        return (self.__class__.__qualname__, [], kwargs)

    def state_forwards(self, package_label, state):
        """Rename the named index, or register the unnamed one under new_name."""
        if not self.old_fields:
            state.rename_index(
                package_label, self.model_name_lower, self.old_name, self.new_name
            )
            return
        state.add_index(
            package_label,
            self.model_name_lower,
            models.Index(fields=self.old_fields, name=self.new_name),
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Rename the index in the database.

        For ``old_fields`` the current index name is introspected from the
        database.

        Raises:
            ValueError: If introspection does not find exactly one index on
                those columns.
        """
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return

        if not self.old_fields:
            source_state = from_state.models[package_label, self.model_name_lower]
            old_index = source_state.get_index_by_name(self.old_name)
            # Renaming to the same name is a no-op.
            if old_index.name == self.new_name:
                return
        else:
            from_model = from_state.models_registry.get_model(
                package_label, self.model_name
            )
            columns = [
                from_model._meta.get_field(field_name).column
                for field_name in self.old_fields
            ]
            candidates = schema_editor._constraint_names(
                from_model, column_names=columns, index=True
            )
            if len(candidates) != 1:
                joined = ", ".join(columns)
                raise ValueError(
                    f"Found wrong number ({len(candidates)}) of indexes for "
                    f"{from_model._meta.db_table}({joined})."
                )
            old_index = models.Index(fields=self.old_fields, name=candidates[0])

        target_state = to_state.models[package_label, self.model_name_lower]
        new_index = target_state.get_index_by_name(self.new_name)
        schema_editor.rename_index(model, old_index, new_index)

    def describe(self):
        if not self.old_name:
            return (
                f"Rename unnamed index for {self.old_fields} on {self.model_name} to "
                f"{self.new_name}"
            )
        return f"Rename index {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self):
        if self.old_name:
            return f"rename_{self.old_name_lower}_{self.new_name_lower}"
        fields_part = "_".join(self.old_fields)
        return f"rename_{self.model_name_lower}_{fields_part}_{self.new_name_lower}"

    def reduce(self, operation, package_label):
        """Collapse X -> Y followed by a named rename Y -> Z into X -> Z."""
        mergeable = (
            isinstance(operation, RenameIndex)
            and self.model_name_lower == operation.model_name_lower
            and operation.old_name
            and self.new_name_lower == operation.old_name_lower
        )
        if mergeable:
            return [
                RenameIndex(
                    self.model_name,
                    new_name=operation.new_name,
                    old_name=self.old_name,
                    old_fields=self.old_fields,
                )
            ]
        return super().reduce(operation, package_label)
729
730
class AddConstraint(IndexOperation):
    """Add a constraint to a model."""

    option_name = "constraints"

    def __init__(self, model_name, constraint):
        self.model_name = model_name
        self.constraint = constraint

    def state_forwards(self, package_label, state):
        """Record the constraint on the model state."""
        state.add_constraint(package_label, self.model_name_lower, self.constraint)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Create the constraint on the model as defined in *to_state*."""
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection, model):
            schema_editor.add_constraint(model, self.constraint)

    def deconstruct(self):
        """Return ``(qualname, args, kwargs)`` for serialization.

        ``__qualname__`` (as used by e.g. AddIndex and CreateModel) keeps the
        enclosing class path for nested subclasses, where ``__name__`` would
        drop it.
        """
        return (
            self.__class__.__qualname__,
            [],
            {
                "model_name": self.model_name,
                "constraint": self.constraint,
            },
        )

    def describe(self):
        return f"Create constraint {self.constraint.name} on model {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"{self.model_name_lower}_{self.constraint.name.lower()}"
762
763
class RemoveConstraint(IndexOperation):
    """Remove a constraint, identified by name, from a model."""

    option_name = "constraints"

    def __init__(self, model_name, name):
        self.model_name = model_name
        self.name = name

    def state_forwards(self, package_label, state):
        """Drop the named constraint from the model state."""
        state.remove_constraint(package_label, self.model_name_lower, self.name)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Drop the constraint; its definition is looked up in *from_state*."""
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection, model):
            from_model_state = from_state.models[package_label, self.model_name_lower]
            constraint = from_model_state.get_constraint_by_name(self.name)
            schema_editor.remove_constraint(model, constraint)

    def deconstruct(self):
        """Return ``(qualname, args, kwargs)`` for serialization.

        ``__qualname__`` (as used by e.g. RemoveIndex and CreateModel) keeps the
        enclosing class path for nested subclasses, where ``__name__`` would
        drop it.
        """
        return (
            self.__class__.__qualname__,
            [],
            {
                "model_name": self.model_name,
                "name": self.name,
            },
        )

    def describe(self):
        return f"Remove constraint {self.name} from model {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"remove_{self.model_name_lower}_{self.name.lower()}"