1from functools import cached_property
2
3from plain import models
4from plain.models.migrations.operations.base import Operation
5from plain.models.migrations.state import ModelState
6from plain.models.migrations.utils import field_references, resolve_relation
7
8from .fields import AddField, AlterField, FieldOperation, RemoveField, RenameField
9
10
11def _check_for_duplicates(arg_name, objs):
12 used_vals = set()
13 for val in objs:
14 if val in used_vals:
15 raise ValueError(
16 f"Found duplicate value {val} in CreateModel {arg_name} argument."
17 )
18 used_vals.add(val)
19
20
class ModelOperation(Operation):
    """Base class for migration operations that target a single, named model."""

    def __init__(self, name):
        self.name = name

    @cached_property
    def name_lower(self):
        """Lower-cased model name, used for case-insensitive comparisons."""
        return self.name.lower()

    def references_model(self, name, package_label):
        # Model names are compared case-insensitively.
        return name.lower() == self.name_lower

    def reduce(self, operation, package_label):
        # Use the generic Operation reduction when it yields something truthy;
        # otherwise answer whether `operation` may be moved across this one.
        reduced = super().reduce(operation, package_label)
        if reduced:
            return reduced
        return self.can_reduce_through(operation, package_label)

    def can_reduce_through(self, operation, package_label):
        # Reordering is only safe when `operation` does not touch this model.
        return not operation.references_model(self.name, package_label)
39
40
class CreateModel(ModelOperation):
    """Create a model's table."""

    serialization_expand_args = ["fields", "options"]

    def __init__(self, name, fields, options=None, bases=None):
        """
        Args:
            name: The model's class name.
            fields: Sequence of ``(field_name, field)`` pairs.
            options: Model options dict; an empty dict when omitted/falsy.
            bases: Base classes (model classes, strings naming models, or
                other classes); ``(models.Model,)`` when omitted/falsy.

        Raises:
            ValueError: if a field name or a base appears more than once.
        """
        self.fields = fields
        self.options = options or {}
        self.bases = bases or (models.Model,)
        super().__init__(name)
        # Sanity-check that there are no duplicated field names or bases
        _check_for_duplicates(
            "fields", (field_name for field_name, _ in self.fields)
        )
        _check_for_duplicates(
            "bases",
            (
                # Normalize bases so the same model given as a class and as a
                # string (in any letter case) is detected as a duplicate.
                base._meta.label_lower
                if hasattr(base, "_meta")
                else base.lower()
                if isinstance(base, str)
                else base
                for base in self.bases
            ),
        )

    def deconstruct(self):
        kwargs = {
            "name": self.name,
            "fields": self.fields,
        }
        # Omit arguments that equal their defaults to keep migrations terse.
        if self.options:
            kwargs["options"] = self.options
        if self.bases and self.bases != (models.Model,):
            kwargs["bases"] = self.bases
        return (self.__class__.__qualname__, [], kwargs)

    def state_forwards(self, package_label, state):
        # Copy the containers so the state does not share them with this op.
        state.add_model(
            ModelState(
                package_label,
                self.name,
                list(self.fields),
                dict(self.options),
                tuple(self.bases),
            )
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        model = to_state.models_registry.get_model(package_label, self.name)
        if self.allow_migrate_model(schema_editor.connection, model):
            schema_editor.create_model(model)

    def describe(self):
        return f"Create model {self.name}"

    @property
    def migration_name_fragment(self):
        return self.name_lower

    def references_model(self, name, package_label):
        """
        Return True if this model is ``name``, inherits from it, or has a
        field that references it.
        """
        name_lower = name.lower()
        if name_lower == self.name_lower:
            return True

        # Check we didn't inherit from the model
        reference_model_tuple = (package_label, name_lower)
        for base in self.bases:
            if (
                base is not models.Model
                and isinstance(base, models.base.ModelBase | str)
                and resolve_relation(base, package_label) == reference_model_tuple
            ):
                return True

        # Check we have no FKs/M2Ms with it
        for _name, field in self.fields:
            if field_references(
                (package_label, self.name_lower), field, reference_model_tuple
            ):
                return True
        return False

    def reduce(self, operation, package_label):
        """
        Fold a following operation on this model into this CreateModel.

        A DeleteModel cancels the creation entirely. RenameModel,
        AlterModelOptions, and AddField/AlterField/RemoveField/RenameField
        are absorbed into a single equivalent CreateModel. Anything else is
        delegated to ModelOperation.reduce.
        """
        if (
            isinstance(operation, DeleteModel)
            and self.name_lower == operation.name_lower
        ):
            return []
        elif (
            isinstance(operation, RenameModel)
            and self.name_lower == operation.old_name_lower
        ):
            return [
                CreateModel(
                    operation.new_name,
                    fields=self.fields,
                    options=self.options,
                    bases=self.bases,
                ),
            ]
        elif (
            isinstance(operation, AlterModelOptions)
            and self.name_lower == operation.name_lower
        ):
            options = {**self.options, **operation.options}
            # AlterModelOptions replaces these keys wholesale, so a key it
            # does not mention must be dropped rather than kept from before.
            for key in operation.ALTER_OPTION_KEYS:
                if key not in operation.options:
                    options.pop(key, None)
            return [
                CreateModel(
                    self.name,
                    fields=self.fields,
                    options=options,
                    bases=self.bases,
                ),
            ]
        elif (
            isinstance(operation, FieldOperation)
            and self.name_lower == operation.model_name_lower
        ):
            if isinstance(operation, AddField):
                # Unpack rather than `+` so `fields` may be any sequence;
                # `tuple + list` would raise TypeError.
                return [
                    CreateModel(
                        self.name,
                        fields=[*self.fields, (operation.name, operation.field)],
                        options=self.options,
                        bases=self.bases,
                    ),
                ]
            elif isinstance(operation, AlterField):
                return [
                    CreateModel(
                        self.name,
                        fields=[
                            (n, operation.field if n == operation.name else v)
                            for n, v in self.fields
                        ],
                        options=self.options,
                        bases=self.bases,
                    ),
                ]
            elif isinstance(operation, RemoveField):
                return [
                    CreateModel(
                        self.name,
                        fields=[
                            (n, v)
                            for n, v in self.fields
                            if n.lower() != operation.name_lower
                        ],
                        options=self.options,
                        bases=self.bases,
                    ),
                ]
            elif isinstance(operation, RenameField):
                return [
                    CreateModel(
                        self.name,
                        fields=[
                            (operation.new_name if n == operation.old_name else n, v)
                            for n, v in self.fields
                        ],
                        options=self.options,
                        bases=self.bases,
                    ),
                ]
        return super().reduce(operation, package_label)
211
212
class DeleteModel(ModelOperation):
    """Drop a model's table."""

    def deconstruct(self):
        return (self.__class__.__qualname__, [], {"name": self.name})

    def state_forwards(self, package_label, state):
        state.remove_model(package_label, self.name_lower)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        # The model is resolved from the "from" state: it is being removed.
        model = from_state.models_registry.get_model(package_label, self.name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        schema_editor.delete_model(model)

    def references_model(self, name, package_label):
        # Always True: the deleted model could be referencing the specified
        # model through related fields, so nothing may be reordered across it.
        return True

    def describe(self):
        return f"Delete model {self.name}"

    @property
    def migration_name_fragment(self):
        return f"delete_{self.name_lower}"
241
242
class RenameModel(ModelOperation):
    """Rename a model.

    Renames the model in the migration state, renames its table in the
    database, and alters every field that points at the model.
    """

    def __init__(self, old_name, new_name):
        self.old_name = old_name
        self.new_name = new_name
        # The inherited `name` / `name_lower` refer to the OLD model name.
        super().__init__(old_name)

    @cached_property
    def old_name_lower(self):
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self):
        return self.new_name.lower()

    def deconstruct(self):
        kwargs = {
            "old_name": self.old_name,
            "new_name": self.new_name,
        }
        return (self.__class__.__qualname__, [], kwargs)

    def state_forwards(self, package_label, state):
        state.rename_model(package_label, self.old_name, self.new_name)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        # The new name is looked up in to_state and the old name in
        # from_state, since each name belongs to a different side of the rename.
        new_model = to_state.models_registry.get_model(package_label, self.new_name)
        if self.allow_migrate_model(schema_editor.connection, new_model):
            old_model = from_state.models_registry.get_model(
                package_label, self.old_name
            )
            # Move the main table
            schema_editor.alter_db_table(
                new_model,
                old_model._meta.db_table,
                new_model._meta.db_table,
            )
            # Alter the fields pointing to us
            for related_object in old_model._meta.related_objects:
                if related_object.related_model == old_model:
                    # Self-referential relation: the field lives on the model
                    # being renamed, so it must be found under the new name.
                    model = new_model
                    related_key = (package_label, self.new_name_lower)
                else:
                    # The field lives on another model whose name is unchanged.
                    model = related_object.related_model
                    related_key = (
                        related_object.related_model._meta.package_label,
                        related_object.related_model._meta.model_name,
                    )
                # The same field as it exists in the target state, i.e. now
                # pointing at the renamed model.
                to_field = to_state.models_registry.get_model(
                    *related_key
                )._meta.get_field(related_object.field.name)
                schema_editor.alter_field(
                    model,
                    related_object.field,
                    to_field,
                )

    def references_model(self, name, package_label):
        # Both the old and the new name count as references.
        return (
            name.lower() == self.old_name_lower or name.lower() == self.new_name_lower
        )

    def describe(self):
        return f"Rename model {self.old_name} to {self.new_name}"

    @property
    def migration_name_fragment(self):
        return f"rename_{self.old_name_lower}_{self.new_name_lower}"

    def reduce(self, operation, package_label):
        # Chained renames (A -> B followed by B -> C) collapse into A -> C.
        if (
            isinstance(operation, RenameModel)
            and self.new_name_lower == operation.old_name_lower
        ):
            return [
                RenameModel(
                    self.old_name,
                    operation.new_name,
                ),
            ]
        # Skip `ModelOperation.reduce` as we want to run `references_model`
        # against self.new_name.
        return super(ModelOperation, self).reduce(
            operation, package_label
        ) or not operation.references_model(self.new_name, package_label)
329
330
class ModelOptionOperation(ModelOperation):
    """Base class for operations that change a single model's options."""

    def reduce(self, operation, package_label):
        # A later operation of the same type, or a DeleteModel, on the same
        # model supersedes this one, so only the later operation is kept.
        supersedes = isinstance(operation, (self.__class__, DeleteModel))
        if supersedes and self.name_lower == operation.name_lower:
            return [operation]
        return super().reduce(operation, package_label)
339
340
class AlterModelTable(ModelOptionOperation):
    """Rename a model's table."""

    def __init__(self, name, table):
        self.table = table
        super().__init__(name)

    def deconstruct(self):
        return (
            self.__class__.__qualname__,
            [],
            {"name": self.name, "table": self.table},
        )

    def state_forwards(self, package_label, state):
        # Record the new table name as the model's `db_table` option.
        state.alter_model_options(
            package_label, self.name_lower, {"db_table": self.table}
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        new_model = to_state.models_registry.get_model(package_label, self.name)
        if not self.allow_migrate_model(schema_editor.connection, new_model):
            return
        old_model = from_state.models_registry.get_model(package_label, self.name)
        schema_editor.alter_db_table(
            new_model,
            old_model._meta.db_table,
            new_model._meta.db_table,
        )

    def describe(self):
        # A `None` table means the default table name is used.
        target = "(default)" if self.table is None else self.table
        return f"Rename table for {self.name} to {target}"

    @property
    def migration_name_fragment(self):
        return f"alter_{self.name_lower}_table"
379
380
class AlterModelTableComment(ModelOptionOperation):
    """Change the comment attached to a model's database table."""

    def __init__(self, name, table_comment):
        self.table_comment = table_comment
        super().__init__(name)

    def deconstruct(self):
        return (
            self.__class__.__qualname__,
            [],
            {"name": self.name, "table_comment": self.table_comment},
        )

    def state_forwards(self, package_label, state):
        # Stored as the model's `db_table_comment` option.
        state.alter_model_options(
            package_label, self.name_lower, {"db_table_comment": self.table_comment}
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        new_model = to_state.models_registry.get_model(package_label, self.name)
        if not self.allow_migrate_model(schema_editor.connection, new_model):
            return
        old_model = from_state.models_registry.get_model(package_label, self.name)
        schema_editor.alter_db_table_comment(
            new_model,
            old_model._meta.db_table_comment,
            new_model._meta.db_table_comment,
        )

    def describe(self):
        return f"Alter {self.name} table comment"

    @property
    def migration_name_fragment(self):
        return f"alter_{self.name_lower}_table_comment"
414
415
class AlterModelOptions(ModelOptionOperation):
    """
    Set new model options that don't directly affect the database schema
    (like ordering). Python code in migrations may still need them, so they
    are tracked in the migration state even though nothing is run against
    the database.
    """

    # Model options we want to compare and preserve in an AlterModelOptions op
    ALTER_OPTION_KEYS = [
        "ordering",
    ]

    def __init__(self, name, options):
        self.options = options
        super().__init__(name)

    def deconstruct(self):
        return (
            self.__class__.__qualname__,
            [],
            {"name": self.name, "options": self.options},
        )

    def state_forwards(self, package_label, state):
        # ALTER_OPTION_KEYS is handed to the state as well; presumably so that
        # managed keys absent from `self.options` are cleared — confirm in
        # ProjectState.alter_model_options.
        state.alter_model_options(
            package_label,
            self.name_lower,
            self.options,
            self.ALTER_OPTION_KEYS,
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        # These options have no schema effect, so there is nothing to do.
        return None

    def describe(self):
        return f"Change Meta options on {self.name}"

    @property
    def migration_name_fragment(self):
        return f"alter_{self.name_lower}_options"
456
457
class IndexOperation(Operation):
    """Base class for operations acting on a collection option of one model."""

    # Model option these operations act on; the constraint operations in this
    # module override it with "constraints".
    option_name = "indexes"

    @cached_property
    def model_name_lower(self):
        # Subclasses set `self.model_name` in their __init__.
        return self.model_name.lower()
464
465
class AddIndex(IndexOperation):
    """Add an index on a model."""

    def __init__(self, model_name, index):
        self.model_name = model_name
        # Later operations (e.g. RemoveIndex, RenameIndex) look indexes up by
        # name, so an unnamed index cannot be added by a migration.
        if not index.name:
            raise ValueError(
                "Indexes passed to AddIndex operations require a name "
                f"argument. {index!r} doesn't have one."
            )
        self.index = index

    def state_forwards(self, package_label, state):
        state.add_index(package_label, self.model_name_lower, self.index)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        schema_editor.add_index(model, self.index)

    def deconstruct(self):
        return (
            self.__class__.__qualname__,
            [],
            {"model_name": self.model_name, "index": self.index},
        )

    def describe(self):
        index_name = self.index.name
        if self.index.expressions:
            targets = ", ".join(
                str(expression) for expression in self.index.expressions
            )
            return f"Create index {index_name} on {targets} on model {self.model_name}"
        columns = ", ".join(self.index.fields)
        return (
            f"Create index {index_name} on field(s) {columns} "
            f"of model {self.model_name}"
        )

    @property
    def migration_name_fragment(self):
        return f"{self.model_name_lower}_{self.index.name.lower()}"
513
514
class RemoveIndex(IndexOperation):
    """Remove an index from a model."""

    def __init__(self, model_name, name):
        self.model_name = model_name
        self.name = name

    def state_forwards(self, package_label, state):
        state.remove_index(package_label, self.model_name_lower, self.name)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        # Both the model and the index definition come from the "from" state,
        # which still contains the index being removed.
        model = from_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        model_state = from_state.models[package_label, self.model_name_lower]
        schema_editor.remove_index(model, model_state.get_index_by_name(self.name))

    def deconstruct(self):
        return (
            self.__class__.__qualname__,
            [],
            {"model_name": self.model_name, "name": self.name},
        )

    def describe(self):
        return f"Remove index {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"remove_{self.model_name_lower}_{self.name.lower()}"
549
550
class RenameIndex(IndexOperation):
    """Rename an index.

    The existing index is identified either by ``old_name`` or, for an
    unnamed index, by ``old_fields``; exactly one of them must be given.
    """

    def __init__(self, model_name, new_name, old_name=None, old_fields=None):
        if not (old_name or old_fields):
            raise ValueError(
                "RenameIndex requires one of old_name and old_fields arguments to be "
                "set."
            )
        if old_name and old_fields:
            raise ValueError(
                "RenameIndex.old_name and old_fields are mutually exclusive."
            )
        self.model_name = model_name
        self.new_name = new_name
        self.old_name = old_name
        self.old_fields = old_fields

    @cached_property
    def old_name_lower(self):
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self):
        return self.new_name.lower()

    def deconstruct(self):
        kwargs = {"model_name": self.model_name, "new_name": self.new_name}
        # Only the identification method actually used is serialized.
        if self.old_name:
            kwargs["old_name"] = self.old_name
        if self.old_fields:
            kwargs["old_fields"] = self.old_fields
        return (self.__class__.__qualname__, [], kwargs)

    def state_forwards(self, package_label, state):
        if not self.old_fields:
            state.rename_index(
                package_label, self.model_name_lower, self.old_name, self.new_name
            )
            return
        # The unnamed index is not tracked by name in the state, so it is
        # added under its new name instead.
        state.add_index(
            package_label,
            self.model_name_lower,
            models.Index(fields=self.old_fields, name=self.new_name),
        )

    def _index_from_old_fields(self, package_label, schema_editor, from_state):
        """Introspect the database for the single index covering old_fields."""
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        columns = [
            from_model._meta.get_field(field).column for field in self.old_fields
        ]
        matches = schema_editor._constraint_names(
            from_model, column_names=columns, index=True
        )
        if len(matches) != 1:
            raise ValueError(
                "Found wrong number ({}) of indexes for {}({}).".format(
                    len(matches),
                    from_model._meta.db_table,
                    ", ".join(columns),
                )
            )
        return models.Index(fields=self.old_fields, name=matches[0])

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return

        if self.old_fields:
            old_index = self._index_from_old_fields(
                package_label, schema_editor, from_state
            )
        else:
            old_state = from_state.models[package_label, self.model_name_lower]
            old_index = old_state.get_index_by_name(self.old_name)
            # Don't alter when the index name is not changed.
            if old_index.name == self.new_name:
                return

        new_state = to_state.models[package_label, self.model_name_lower]
        schema_editor.rename_index(
            model, old_index, new_state.get_index_by_name(self.new_name)
        )

    def describe(self):
        if not self.old_name:
            return (
                f"Rename unnamed index for {self.old_fields} on {self.model_name} to "
                f"{self.new_name}"
            )
        return f"Rename index {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self):
        if self.old_name:
            return f"rename_{self.old_name_lower}_{self.new_name_lower}"
        joined_fields = "_".join(self.old_fields)
        return f"rename_{self.model_name_lower}_{joined_fields}_{self.new_name_lower}"

    def reduce(self, operation, package_label):
        # Chained renames of the same index (X -> Y, then Y -> Z) collapse
        # into one rename from this operation's original identification.
        chained = (
            isinstance(operation, RenameIndex)
            and self.model_name_lower == operation.model_name_lower
            and operation.old_name
            and self.new_name_lower == operation.old_name_lower
        )
        if not chained:
            return super().reduce(operation, package_label)
        return [
            RenameIndex(
                self.model_name,
                new_name=operation.new_name,
                old_name=self.old_name,
                old_fields=self.old_fields,
            )
        ]
674
675
class AddConstraint(IndexOperation):
    """Add a constraint to a model."""

    # Overrides IndexOperation.option_name ("indexes").
    option_name = "constraints"

    def __init__(self, model_name, constraint):
        self.model_name = model_name
        self.constraint = constraint

    def state_forwards(self, package_label, state):
        state.add_constraint(package_label, self.model_name_lower, self.constraint)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection, model):
            schema_editor.add_constraint(model, self.constraint)

    def deconstruct(self):
        # Use __qualname__ like every other operation in this module.
        return (
            self.__class__.__qualname__,
            [],
            {
                "model_name": self.model_name,
                "constraint": self.constraint,
            },
        )

    def describe(self):
        return f"Create constraint {self.constraint.name} on model {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"{self.model_name_lower}_{self.constraint.name.lower()}"
707
708
class RemoveConstraint(IndexOperation):
    """Remove a named constraint from a model."""

    # Overrides IndexOperation.option_name ("indexes").
    option_name = "constraints"

    def __init__(self, model_name, name):
        self.model_name = model_name
        self.name = name

    def state_forwards(self, package_label, state):
        state.remove_constraint(package_label, self.model_name_lower, self.name)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        model = to_state.models_registry.get_model(package_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection, model):
            # The constraint definition is taken from the "from" state, which
            # still contains it.
            from_model_state = from_state.models[package_label, self.model_name_lower]
            constraint = from_model_state.get_constraint_by_name(self.name)
            schema_editor.remove_constraint(model, constraint)

    def deconstruct(self):
        # Use __qualname__ like every other operation in this module.
        return (
            self.__class__.__qualname__,
            [],
            {
                "model_name": self.model_name,
                "name": self.name,
            },
        )

    def describe(self):
        return f"Remove constraint {self.name} from model {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"remove_{self.model_name_lower}_{self.name.lower()}"