Plain is headed towards 1.0! Subscribe for development updates →

  1from functools import cached_property
  2
  3from plain.models.fields import NOT_PROVIDED
  4from plain.models.migrations.utils import field_references
  5
  6from .base import Operation
  7
  8
  9class FieldOperation(Operation):
 10    def __init__(self, model_name, name, field=None):
 11        self.model_name = model_name
 12        self.name = name
 13        self.field = field
 14
 15    @cached_property
 16    def model_name_lower(self):
 17        return self.model_name.lower()
 18
 19    @cached_property
 20    def name_lower(self):
 21        return self.name.lower()
 22
 23    def is_same_model_operation(self, operation):
 24        return self.model_name_lower == operation.model_name_lower
 25
 26    def is_same_field_operation(self, operation):
 27        return (
 28            self.is_same_model_operation(operation)
 29            and self.name_lower == operation.name_lower
 30        )
 31
 32    def references_model(self, name, package_label):
 33        name_lower = name.lower()
 34        if name_lower == self.model_name_lower:
 35            return True
 36        if self.field:
 37            return bool(
 38                field_references(
 39                    (package_label, self.model_name_lower),
 40                    self.field,
 41                    (package_label, name_lower),
 42                )
 43            )
 44        return False
 45
 46    def references_field(self, model_name, name, package_label):
 47        model_name_lower = model_name.lower()
 48        # Check if this operation locally references the field.
 49        if model_name_lower == self.model_name_lower:
 50            if name == self.name:
 51                return True
 52        # Check if this operation remotely references the field.
 53        if self.field is None:
 54            return False
 55        return bool(
 56            field_references(
 57                (package_label, self.model_name_lower),
 58                self.field,
 59                (package_label, model_name_lower),
 60                name,
 61            )
 62        )
 63
 64    def reduce(self, operation, package_label):
 65        return super().reduce(
 66            operation, package_label
 67        ) or not operation.references_field(self.model_name, self.name, package_label)
 68
 69
 70class AddField(FieldOperation):
 71    """Add a field to a model."""
 72
 73    def __init__(self, model_name, name, field, preserve_default=True):
 74        self.preserve_default = preserve_default
 75        super().__init__(model_name, name, field)
 76
 77    def deconstruct(self):
 78        kwargs = {
 79            "model_name": self.model_name,
 80            "name": self.name,
 81            "field": self.field,
 82        }
 83        if self.preserve_default is not True:
 84            kwargs["preserve_default"] = self.preserve_default
 85        return (self.__class__.__name__, [], kwargs)
 86
 87    def state_forwards(self, package_label, state):
 88        state.add_field(
 89            package_label,
 90            self.model_name_lower,
 91            self.name,
 92            self.field,
 93            self.preserve_default,
 94        )
 95
 96    def database_forwards(self, package_label, schema_editor, from_state, to_state):
 97        to_model = to_state.models_registry.get_model(package_label, self.model_name)
 98        if self.allow_migrate_model(schema_editor.connection, to_model):
 99            from_model = from_state.models_registry.get_model(
100                package_label, self.model_name
101            )
102            field = to_model._meta.get_field(self.name)
103            if not self.preserve_default:
104                field.default = self.field.default
105            schema_editor.add_field(
106                from_model,
107                field,
108            )
109            if not self.preserve_default:
110                field.default = NOT_PROVIDED
111
112    def describe(self):
113        return f"Add field {self.name} to {self.model_name}"
114
115    @property
116    def migration_name_fragment(self):
117        return f"{self.model_name_lower}_{self.name_lower}"
118
119    def reduce(self, operation, package_label):
120        if isinstance(operation, FieldOperation) and self.is_same_field_operation(
121            operation
122        ):
123            if isinstance(operation, AlterField):
124                return [
125                    AddField(
126                        model_name=self.model_name,
127                        name=operation.name,
128                        field=operation.field,
129                    ),
130                ]
131            elif isinstance(operation, RemoveField):
132                return []
133            elif isinstance(operation, RenameField):
134                return [
135                    AddField(
136                        model_name=self.model_name,
137                        name=operation.new_name,
138                        field=self.field,
139                    ),
140                ]
141        return super().reduce(operation, package_label)
142
143
class RemoveField(FieldOperation):
    """Remove a field from a model."""

    def deconstruct(self):
        """Return ``(class name, positional args, keyword args)`` for serialization."""
        return (
            type(self).__name__,
            [],
            {"model_name": self.model_name, "name": self.name},
        )

    def state_forwards(self, package_label, state):
        """Drop the field from the in-memory project state."""
        state.remove_field(package_label, self.model_name_lower, self.name)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Drop the field's column, using the model as it was before this operation."""
        model = from_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        doomed_field = model._meta.get_field(self.name)
        schema_editor.remove_field(model, doomed_field)

    def describe(self):
        """Return a human-readable description of the operation."""
        return f"Remove field {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self):
        """Fragment used when auto-naming a migration containing this operation."""
        return f"remove_{self.model_name_lower}_{self.name_lower}"

    def reduce(self, operation, package_label):
        """
        Absorb this removal into a following DeleteModel of the same model.

        Otherwise defer to ``FieldOperation.reduce``.
        """
        from .models import DeleteModel

        # Deleting the whole model makes removing one of its fields redundant.
        deletes_this_model = (
            isinstance(operation, DeleteModel)
            and operation.name_lower == self.model_name_lower
        )
        if deletes_this_model:
            return [operation]
        return super().reduce(operation, package_label)
182
183
class AlterField(FieldOperation):
    """
    Alter a field's database column (e.g. null, max_length) to the provided
    new field.
    """

    def __init__(self, model_name, name, field, preserve_default=True):
        # When preserve_default is False, self.field.default is only applied
        # while the column is being altered (see database_forwards) and the
        # to_state field's default is reset to NOT_PROVIDED afterwards.
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self):
        """
        Return ``(class name, positional args, keyword args)`` for serialization.

        ``preserve_default`` is only included when it is not ``True``.
        """
        kwargs = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, [], kwargs)

    def state_forwards(self, package_label, state):
        """Replace the field definition on the in-memory project state."""
        state.alter_field(
            package_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """
        Alter the column from its from_state definition to its to_state one.

        With ``preserve_default=False`` the operation's own default is set on
        the to_state field just for the ``alter_field`` call. It is then reset
        to ``NOT_PROVIDED`` in a ``finally`` block, so the temporary default is
        not left behind on the to_state model if the schema change raises.
        """
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, to_model):
            return
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        from_field = from_model._meta.get_field(self.name)
        to_field = to_model._meta.get_field(self.name)
        if self.preserve_default:
            schema_editor.alter_field(from_model, from_field, to_field)
            return
        # Expose the one-off default only for the duration of the schema change.
        to_field.default = self.field.default
        try:
            schema_editor.alter_field(from_model, from_field, to_field)
        finally:
            to_field.default = NOT_PROVIDED

    def describe(self):
        """Return a human-readable description of the operation."""
        return f"Alter field {self.name} on {self.model_name}"

    @property
    def migration_name_fragment(self):
        """Fragment used when auto-naming a migration containing this operation."""
        return f"alter_{self.model_name_lower}_{self.name_lower}"

    def reduce(self, operation, package_label):
        """
        Fold a following operation on the same field into this AlterField.

        - AlterField + AlterField/RemoveField -> only the later operation.
        - AlterField + RenameField (no explicit db_column) -> the rename, then
          this alteration re-targeted at the new name.
        Anything else falls back to ``FieldOperation.reduce``.
        """
        if isinstance(
            operation, AlterField | RemoveField
        ) and self.is_same_field_operation(operation):
            return [operation]
        elif (
            isinstance(operation, RenameField)
            and self.is_same_field_operation(operation)
            and self.field.db_column is None
        ):
            # NOTE(review): self.preserve_default is not carried into the
            # re-targeted AlterField — confirm this is intended.
            return [
                operation,
                AlterField(
                    model_name=self.model_name,
                    name=operation.new_name,
                    field=self.field,
                ),
            ]
        return super().reduce(operation, package_label)
253
254
class RenameField(FieldOperation):
    """Rename a field on the model. Might affect db_column too."""

    def __init__(self, model_name, old_name, new_name):
        self.old_name = old_name
        self.new_name = new_name
        # The base class tracks the field under its old name (no field object).
        super().__init__(model_name, old_name)

    @cached_property
    def old_name_lower(self):
        """Case-folded original field name."""
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self):
        """Case-folded target field name."""
        return self.new_name.lower()

    def deconstruct(self):
        """Return ``(class name, positional args, keyword args)`` for serialization."""
        kwargs = dict(
            model_name=self.model_name,
            old_name=self.old_name,
            new_name=self.new_name,
        )
        return (type(self).__name__, [], kwargs)

    def state_forwards(self, package_label, state):
        """Rename the field on the in-memory project state."""
        state.rename_field(
            package_label, self.model_name_lower, self.old_name, self.new_name
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Alter the column from its old-name definition to its new-name definition."""
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, to_model):
            return
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        old_field = from_model._meta.get_field(self.old_name)
        new_field = to_model._meta.get_field(self.new_name)
        schema_editor.alter_field(from_model, old_field, new_field)

    def describe(self):
        """Return a human-readable description of the operation."""
        return f"Rename field {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self):
        """Fragment used when auto-naming a migration containing this operation."""
        return f"rename_{self.old_name_lower}_{self.model_name_lower}_{self.new_name_lower}"

    def references_field(self, model_name, name, package_label):
        """Return True if *name* on *model_name* is either the old or the new name."""
        if not self.references_model(model_name, package_label):
            return False
        return name.lower() in (self.old_name_lower, self.new_name_lower)

    def reduce(self, operation, package_label):
        """
        Collapse chained renames, or report independence from *operation*.

        A rename a -> b followed by b -> c on the same model becomes a -> c.
        """
        chains_onto_this = (
            isinstance(operation, RenameField)
            and self.is_same_model_operation(operation)
            and operation.old_name_lower == self.new_name_lower
        )
        if chains_onto_this:
            return [RenameField(self.model_name, self.old_name, operation.new_name)]
        # Bypass FieldOperation.reduce: independence must be checked against
        # both self.old_name and self.new_name, not just self.name.
        base_result = super(FieldOperation, self).reduce(operation, package_label)
        if base_result:
            return base_result
        touches_either_name = operation.references_field(
            self.model_name, self.old_name, package_label
        ) or operation.references_field(self.model_name, self.new_name, package_label)
        return not touches_either_name