Plain is headed towards 1.0! Subscribe for development updates →

  1from functools import cached_property
  2
  3from plain.models.fields import NOT_PROVIDED
  4from plain.models.migrations.utils import field_references
  5
  6from .base import Operation
  7
  8
  9class FieldOperation(Operation):
 10    def __init__(self, model_name, name, field=None):
 11        self.model_name = model_name
 12        self.name = name
 13        self.field = field
 14
 15    @cached_property
 16    def model_name_lower(self):
 17        return self.model_name.lower()
 18
 19    @cached_property
 20    def name_lower(self):
 21        return self.name.lower()
 22
 23    def is_same_model_operation(self, operation):
 24        return self.model_name_lower == operation.model_name_lower
 25
 26    def is_same_field_operation(self, operation):
 27        return (
 28            self.is_same_model_operation(operation)
 29            and self.name_lower == operation.name_lower
 30        )
 31
 32    def references_model(self, name, package_label):
 33        name_lower = name.lower()
 34        if name_lower == self.model_name_lower:
 35            return True
 36        if self.field:
 37            return bool(
 38                field_references(
 39                    (package_label, self.model_name_lower),
 40                    self.field,
 41                    (package_label, name_lower),
 42                )
 43            )
 44        return False
 45
 46    def references_field(self, model_name, name, package_label):
 47        model_name_lower = model_name.lower()
 48        # Check if this operation locally references the field.
 49        if model_name_lower == self.model_name_lower:
 50            if name == self.name:
 51                return True
 52            elif (
 53                self.field
 54                and hasattr(self.field, "from_fields")
 55                and name in self.field.from_fields
 56            ):
 57                return True
 58        # Check if this operation remotely references the field.
 59        if self.field is None:
 60            return False
 61        return bool(
 62            field_references(
 63                (package_label, self.model_name_lower),
 64                self.field,
 65                (package_label, model_name_lower),
 66                name,
 67            )
 68        )
 69
 70    def reduce(self, operation, package_label):
 71        return super().reduce(
 72            operation, package_label
 73        ) or not operation.references_field(self.model_name, self.name, package_label)
 74
 75
 76class AddField(FieldOperation):
 77    """Add a field to a model."""
 78
 79    def __init__(self, model_name, name, field, preserve_default=True):
 80        self.preserve_default = preserve_default
 81        super().__init__(model_name, name, field)
 82
 83    def deconstruct(self):
 84        kwargs = {
 85            "model_name": self.model_name,
 86            "name": self.name,
 87            "field": self.field,
 88        }
 89        if self.preserve_default is not True:
 90            kwargs["preserve_default"] = self.preserve_default
 91        return (self.__class__.__name__, [], kwargs)
 92
 93    def state_forwards(self, package_label, state):
 94        state.add_field(
 95            package_label,
 96            self.model_name_lower,
 97            self.name,
 98            self.field,
 99            self.preserve_default,
100        )
101
102    def database_forwards(self, package_label, schema_editor, from_state, to_state):
103        to_model = to_state.models_registry.get_model(package_label, self.model_name)
104        if self.allow_migrate_model(schema_editor.connection, to_model):
105            from_model = from_state.models_registry.get_model(
106                package_label, self.model_name
107            )
108            field = to_model._meta.get_field(self.name)
109            if not self.preserve_default:
110                field.default = self.field.default
111            schema_editor.add_field(
112                from_model,
113                field,
114            )
115            if not self.preserve_default:
116                field.default = NOT_PROVIDED
117
118    def describe(self):
119        return f"Add field {self.name} to {self.model_name}"
120
121    @property
122    def migration_name_fragment(self):
123        return f"{self.model_name_lower}_{self.name_lower}"
124
125    def reduce(self, operation, package_label):
126        if isinstance(operation, FieldOperation) and self.is_same_field_operation(
127            operation
128        ):
129            if isinstance(operation, AlterField):
130                return [
131                    AddField(
132                        model_name=self.model_name,
133                        name=operation.name,
134                        field=operation.field,
135                    ),
136                ]
137            elif isinstance(operation, RemoveField):
138                return []
139            elif isinstance(operation, RenameField):
140                return [
141                    AddField(
142                        model_name=self.model_name,
143                        name=operation.new_name,
144                        field=self.field,
145                    ),
146                ]
147        return super().reduce(operation, package_label)
148
149
class RemoveField(FieldOperation):
    """Remove a field from a model, in migration state and via the schema editor."""

    def deconstruct(self):
        """Return ``(class name, args, kwargs)`` describing this operation."""
        return (
            type(self).__name__,
            [],
            {"model_name": self.model_name, "name": self.name},
        )

    def state_forwards(self, package_label, state):
        state.remove_field(package_label, self.model_name_lower, self.name)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Drop the field as it exists in ``from_state`` (the pre-removal model)."""
        model = from_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        doomed = model._meta.get_field(self.name)
        schema_editor.remove_field(model, doomed)

    def describe(self):
        return f"Remove field {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"remove_{self.model_name_lower}_{self.name_lower}"

    def reduce(self, operation, package_label):
        """A following DeleteModel of the same model absorbs this removal."""
        # Local import; presumably avoids an import cycle with .models — verify.
        from .models import DeleteModel

        deletes_model = (
            isinstance(operation, DeleteModel)
            and operation.name_lower == self.model_name_lower
        )
        if deletes_model:
            return [operation]
        return super().reduce(operation, package_label)
188
189
class AlterField(FieldOperation):
    """
    Alter a field's database column (e.g. null, max_length) to the provided
    new field.
    """

    def __init__(self, model_name, name, field, preserve_default=True):
        # When False, ``field.default`` is only applied while the schema is
        # changed (see ``database_forwards``); the flag is also forwarded to
        # ``state.alter_field``.
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self):
        """Return ``(class name, args, kwargs)``; ``preserve_default`` only when not True."""
        kwargs = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, [], kwargs)

    def state_forwards(self, package_label, state):
        """Replace the field definition in the migration state."""
        state.alter_field(
            package_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """
        Alter the field from its ``from_state`` form to its ``to_state`` form.

        When ``preserve_default`` is False, ``self.field.default`` is installed
        on the ``to_state`` field only for the duration of the schema change
        and is reset to ``NOT_PROVIDED`` afterwards, even if the change raises.
        """
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection, to_model):
            from_model = from_state.models_registry.get_model(
                package_label, self.model_name
            )
            from_field = from_model._meta.get_field(self.name)
            to_field = to_model._meta.get_field(self.name)
            if not self.preserve_default:
                to_field.default = self.field.default
            try:
                schema_editor.alter_field(from_model, from_field, to_field)
            finally:
                # Don't leave the temporary default on the to_state field.
                if not self.preserve_default:
                    to_field.default = NOT_PROVIDED

    def describe(self):
        return f"Alter field {self.name} on {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"alter_{self.model_name_lower}_{self.name_lower}"

    def reduce(self, operation, package_label):
        """
        Merge a following operation on the same field.

        - AlterField / RemoveField: the later operation supersedes this one.
        - RenameField (only when no explicit ``db_column``): rename first,
          then apply this alteration under the new name. The alteration keeps
          this operation's ``preserve_default`` so a one-off default stays
          one-off after the merge.
        """
        if isinstance(
            operation, AlterField | RemoveField
        ) and self.is_same_field_operation(operation):
            return [operation]
        elif (
            isinstance(operation, RenameField)
            and self.is_same_field_operation(operation)
            and self.field.db_column is None
        ):
            return [
                operation,
                AlterField(
                    model_name=self.model_name,
                    name=operation.new_name,
                    field=self.field,
                    preserve_default=self.preserve_default,
                ),
            ]
        return super().reduce(operation, package_label)
259
260
class RenameField(FieldOperation):
    """Rename a field on the model. Might affect db_column too."""

    def __init__(self, model_name, old_name, new_name):
        # The inherited ``name`` is the field's name before the rename.
        super().__init__(model_name, old_name)
        self.old_name = old_name
        self.new_name = new_name

    @cached_property
    def old_name_lower(self):
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self):
        return self.new_name.lower()

    def deconstruct(self):
        """Return ``(class name, args, kwargs)`` describing this rename."""
        return (
            self.__class__.__name__,
            [],
            {
                "model_name": self.model_name,
                "old_name": self.old_name,
                "new_name": self.new_name,
            },
        )

    def state_forwards(self, package_label, state):
        state.rename_field(
            package_label, self.model_name_lower, self.old_name, self.new_name
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Alter the old-named field (from_state) into the new-named one (to_state)."""
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, to_model):
            return
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        old_field = from_model._meta.get_field(self.old_name)
        new_field = to_model._meta.get_field(self.new_name)
        schema_editor.alter_field(from_model, old_field, new_field)

    def describe(self):
        return f"Rename field {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self):
        return f"rename_{self.old_name_lower}_{self.model_name_lower}_{self.new_name_lower}"

    def references_field(self, model_name, name, package_label):
        """True when ``model_name`` is referenced and ``name`` is either side of the rename."""
        if not self.references_model(model_name, package_label):
            return False
        return name.lower() in (self.old_name_lower, self.new_name_lower)

    def reduce(self, operation, package_label):
        """
        Collapse chained renames (a -> b, then b -> a') into a single rename;
        otherwise allow reordering only when ``operation`` touches neither the
        old nor the new field name.
        """
        chains_rename = (
            isinstance(operation, RenameField)
            and self.is_same_model_operation(operation)
            and self.new_name_lower == operation.old_name_lower
        )
        if chains_rename:
            return [RenameField(self.model_name, self.old_name, operation.new_name)]
        # Bypass FieldOperation.reduce: it only checks ``self.name`` (the old
        # name), whereas a rename must be checked against both names.
        merged = super(FieldOperation, self).reduce(operation, package_label)
        if merged:
            return merged
        touches_either = operation.references_field(
            self.model_name, self.old_name, package_label
        ) or operation.references_field(self.model_name, self.new_name, package_label)
        return not touches_either