Plain is headed towards 1.0! Subscribe for development updates →

  1from plain.models.fields import NOT_PROVIDED
  2from plain.models.migrations.utils import field_references
  3from plain.utils.functional import cached_property
  4
  5from .base import Operation
  6
  7
  8class FieldOperation(Operation):
  9    def __init__(self, model_name, name, field=None):
 10        self.model_name = model_name
 11        self.name = name
 12        self.field = field
 13
 14    @cached_property
 15    def model_name_lower(self):
 16        return self.model_name.lower()
 17
 18    @cached_property
 19    def name_lower(self):
 20        return self.name.lower()
 21
 22    def is_same_model_operation(self, operation):
 23        return self.model_name_lower == operation.model_name_lower
 24
 25    def is_same_field_operation(self, operation):
 26        return (
 27            self.is_same_model_operation(operation)
 28            and self.name_lower == operation.name_lower
 29        )
 30
 31    def references_model(self, name, package_label):
 32        name_lower = name.lower()
 33        if name_lower == self.model_name_lower:
 34            return True
 35        if self.field:
 36            return bool(
 37                field_references(
 38                    (package_label, self.model_name_lower),
 39                    self.field,
 40                    (package_label, name_lower),
 41                )
 42            )
 43        return False
 44
 45    def references_field(self, model_name, name, package_label):
 46        model_name_lower = model_name.lower()
 47        # Check if this operation locally references the field.
 48        if model_name_lower == self.model_name_lower:
 49            if name == self.name:
 50                return True
 51            elif (
 52                self.field
 53                and hasattr(self.field, "from_fields")
 54                and name in self.field.from_fields
 55            ):
 56                return True
 57        # Check if this operation remotely references the field.
 58        if self.field is None:
 59            return False
 60        return bool(
 61            field_references(
 62                (package_label, self.model_name_lower),
 63                self.field,
 64                (package_label, model_name_lower),
 65                name,
 66            )
 67        )
 68
 69    def reduce(self, operation, package_label):
 70        return super().reduce(
 71            operation, package_label
 72        ) or not operation.references_field(self.model_name, self.name, package_label)
 73
 74
 75class AddField(FieldOperation):
 76    """Add a field to a model."""
 77
 78    def __init__(self, model_name, name, field, preserve_default=True):
 79        self.preserve_default = preserve_default
 80        super().__init__(model_name, name, field)
 81
 82    def deconstruct(self):
 83        kwargs = {
 84            "model_name": self.model_name,
 85            "name": self.name,
 86            "field": self.field,
 87        }
 88        if self.preserve_default is not True:
 89            kwargs["preserve_default"] = self.preserve_default
 90        return (self.__class__.__name__, [], kwargs)
 91
 92    def state_forwards(self, package_label, state):
 93        state.add_field(
 94            package_label,
 95            self.model_name_lower,
 96            self.name,
 97            self.field,
 98            self.preserve_default,
 99        )
100
101    def database_forwards(self, package_label, schema_editor, from_state, to_state):
102        to_model = to_state.packages.get_model(package_label, self.model_name)
103        if self.allow_migrate_model(schema_editor.connection.alias, to_model):
104            from_model = from_state.packages.get_model(package_label, self.model_name)
105            field = to_model._meta.get_field(self.name)
106            if not self.preserve_default:
107                field.default = self.field.default
108            schema_editor.add_field(
109                from_model,
110                field,
111            )
112            if not self.preserve_default:
113                field.default = NOT_PROVIDED
114
115    def database_backwards(self, package_label, schema_editor, from_state, to_state):
116        from_model = from_state.packages.get_model(package_label, self.model_name)
117        if self.allow_migrate_model(schema_editor.connection.alias, from_model):
118            schema_editor.remove_field(
119                from_model, from_model._meta.get_field(self.name)
120            )
121
122    def describe(self):
123        return f"Add field {self.name} to {self.model_name}"
124
125    @property
126    def migration_name_fragment(self):
127        return f"{self.model_name_lower}_{self.name_lower}"
128
129    def reduce(self, operation, package_label):
130        if isinstance(operation, FieldOperation) and self.is_same_field_operation(
131            operation
132        ):
133            if isinstance(operation, AlterField):
134                return [
135                    AddField(
136                        model_name=self.model_name,
137                        name=operation.name,
138                        field=operation.field,
139                    ),
140                ]
141            elif isinstance(operation, RemoveField):
142                return []
143            elif isinstance(operation, RenameField):
144                return [
145                    AddField(
146                        model_name=self.model_name,
147                        name=operation.new_name,
148                        field=self.field,
149                    ),
150                ]
151        return super().reduce(operation, package_label)
152
153
class RemoveField(FieldOperation):
    """Drop a field (and its column) from a model."""

    def deconstruct(self):
        """Return ``(class name, positional args, kwargs)`` for serialization."""
        return (
            self.__class__.__name__,
            [],
            {"model_name": self.model_name, "name": self.name},
        )

    def state_forwards(self, package_label, state):
        """Delete the field from the in-memory project state."""
        state.remove_field(package_label, self.model_name_lower, self.name)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Remove the column, looked up on the model as rendered in ``from_state``."""
        model = from_state.packages.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection.alias, model):
            return
        schema_editor.remove_field(model, model._meta.get_field(self.name))

    def database_backwards(self, package_label, schema_editor, from_state, to_state):
        """Re-add the column using the field definition found in ``to_state``."""
        target_model = to_state.packages.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection.alias, target_model):
            return
        source_model = from_state.packages.get_model(package_label, self.model_name)
        schema_editor.add_field(source_model, target_model._meta.get_field(self.name))

    def describe(self):
        return f"Remove field {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"remove_{self.model_name_lower}_{self.name_lower}"

    def reduce(self, operation, package_label):
        """
        Absorb this removal into a following DeleteModel of the same model.

        Anything else is deferred to ``FieldOperation.reduce``.
        """
        from .models import DeleteModel

        deletes_owning_model = (
            isinstance(operation, DeleteModel)
            and operation.name_lower == self.model_name_lower
        )
        if deletes_owning_model:
            return [operation]
        return super().reduce(operation, package_label)
196
197
class AlterField(FieldOperation):
    """
    Alter a field's database column (e.g. null, max_length) to the provided
    new field.
    """

    def __init__(self, model_name, name, field, preserve_default=True):
        # When False, ``field.default`` is applied only while the column is
        # being altered (see ``database_forwards``) and is reset to
        # ``NOT_PROVIDED`` afterwards. The flag is also forwarded to
        # ``state.alter_field``.
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self):
        """
        Return ``(class name, positional args, kwargs)`` for serialization.

        ``preserve_default`` is only emitted when it differs from True.
        """
        kwargs = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, [], kwargs)

    def state_forwards(self, package_label, state):
        """Replace the field definition in the in-memory project state."""
        state.alter_field(
            package_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Alter the column from its ``from_state`` field to its ``to_state`` field."""
        to_model = to_state.packages.get_model(package_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection.alias, to_model):
            from_model = from_state.packages.get_model(package_label, self.model_name)
            from_field = from_model._meta.get_field(self.name)
            to_field = to_model._meta.get_field(self.name)
            if not self.preserve_default:
                # Temporarily give the target field this operation's default
                # so the schema editor can use it during the alteration.
                to_field.default = self.field.default
            try:
                schema_editor.alter_field(from_model, from_field, to_field)
            finally:
                # Undo the temporary default even if the alteration fails,
                # so the state's field is not left mutated.
                if not self.preserve_default:
                    to_field.default = NOT_PROVIDED

    def database_backwards(self, package_label, schema_editor, from_state, to_state):
        # Reuses the forward logic unchanged. The caller is presumably passing
        # the states for the reverse direction. NOTE(review): confirm against
        # the migration executor.
        self.database_forwards(package_label, schema_editor, from_state, to_state)

    def describe(self):
        return f"Alter field {self.name} on {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"alter_{self.model_name_lower}_{self.name_lower}"

    def reduce(self, operation, package_label):
        """
        Fold a following operation on the same field into this one.

        - AlterField / RemoveField of the same field: only the later operation
          matters, so it replaces this one.
        - RenameField of this field, when ``self.field.db_column`` is None:
          emit the rename first, then alter the field under its new name.
          This operation's ``preserve_default`` is kept on the new AlterField.
        """
        if isinstance(
            operation, AlterField | RemoveField
        ) and self.is_same_field_operation(operation):
            return [operation]
        elif (
            isinstance(operation, RenameField)
            and self.is_same_field_operation(operation)
            and self.field.db_column is None
        ):
            return [
                operation,
                AlterField(
                    model_name=self.model_name,
                    name=operation.new_name,
                    field=self.field,
                    preserve_default=self.preserve_default,
                ),
            ]
        return super().reduce(operation, package_label)
268
269
class RenameField(FieldOperation):
    """Rename a field on a model; this may also affect its db_column."""

    def __init__(self, model_name, old_name, new_name):
        self.old_name = old_name
        self.new_name = new_name
        # The base class tracks the pre-rename name as ``name``; no field
        # instance is attached.
        super().__init__(model_name, old_name)

    @cached_property
    def old_name_lower(self):
        """The pre-rename field name, lowercased."""
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self):
        """The post-rename field name, lowercased."""
        return self.new_name.lower()

    def deconstruct(self):
        """Return ``(class name, positional args, kwargs)`` for serialization."""
        return (
            self.__class__.__name__,
            [],
            {
                "model_name": self.model_name,
                "old_name": self.old_name,
                "new_name": self.new_name,
            },
        )

    def state_forwards(self, package_label, state):
        """Rename the field in the in-memory project state."""
        state.rename_field(
            package_label, self.model_name_lower, self.old_name, self.new_name
        )

    def _apply_rename(
        self, package_label, schema_editor, from_state, to_state, source, target
    ):
        """
        Alter the ``from_state`` field named ``source`` into the ``to_state``
        field named ``target``.
        """
        target_model = to_state.packages.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection.alias, target_model):
            return
        source_model = from_state.packages.get_model(package_label, self.model_name)
        schema_editor.alter_field(
            source_model,
            source_model._meta.get_field(source),
            target_model._meta.get_field(target),
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        self._apply_rename(
            package_label,
            schema_editor,
            from_state,
            to_state,
            self.old_name,
            self.new_name,
        )

    def database_backwards(self, package_label, schema_editor, from_state, to_state):
        # Same alteration as forwards, with the two names swapped.
        self._apply_rename(
            package_label,
            schema_editor,
            from_state,
            to_state,
            self.new_name,
            self.old_name,
        )

    def describe(self):
        return f"Rename field {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self):
        return (
            f"rename_{self.old_name_lower}_{self.model_name_lower}"
            f"_{self.new_name_lower}"
        )

    def references_field(self, model_name, name, package_label):
        """Whether ``name`` on ``model_name`` matches either the old or new name."""
        if not self.references_model(model_name, package_label):
            return False
        return name.lower() in (self.old_name_lower, self.new_name_lower)

    def reduce(self, operation, package_label):
        """
        Collapse chained renames, or decide whether ``operation`` can be
        reduced past this rename.

        A rename A -> B followed by B -> C on the same model becomes A -> C.
        """
        chains_onto_this = (
            isinstance(operation, RenameField)
            and self.is_same_model_operation(operation)
            and self.new_name_lower == operation.old_name_lower
        )
        if chains_onto_this:
            return [
                RenameField(
                    self.model_name,
                    self.old_name,
                    operation.new_name,
                ),
            ]
        # Bypass FieldOperation.reduce: it checks references against
        # ``self.name`` (the old name) only, whereas a rename must also guard
        # against references to the new name.
        base_result = super(FieldOperation, self).reduce(operation, package_label)
        if base_result:
            return base_result
        touches_either_name = operation.references_field(
            self.model_name, self.old_name, package_label
        ) or operation.references_field(self.model_name, self.new_name, package_label)
        return not touches_either_name