Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from functools import cached_property
  4from typing import TYPE_CHECKING, Any
  5
  6from plain.models.fields import NOT_PROVIDED
  7from plain.models.migrations.utils import field_references
  8
  9from .base import Operation
 10
 11if TYPE_CHECKING:
 12    from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
 13    from plain.models.fields import Field
 14    from plain.models.migrations.state import ProjectState
 15
 16
 17class FieldOperation(Operation):
 18    def __init__(self, model_name: str, name: str, field: Field | None = None) -> None:
 19        self.model_name = model_name
 20        self.name = name
 21        self.field = field
 22
 23    @cached_property
 24    def model_name_lower(self) -> str:
 25        return self.model_name.lower()
 26
 27    @cached_property
 28    def name_lower(self) -> str:
 29        return self.name.lower()
 30
 31    def is_same_model_operation(self, operation: FieldOperation) -> bool:
 32        return self.model_name_lower == operation.model_name_lower
 33
 34    def is_same_field_operation(self, operation: FieldOperation) -> bool:
 35        return (
 36            self.is_same_model_operation(operation)
 37            and self.name_lower == operation.name_lower
 38        )
 39
 40    def references_model(self, name: str, package_label: str) -> bool:
 41        name_lower = name.lower()
 42        if name_lower == self.model_name_lower:
 43            return True
 44        if self.field:
 45            return bool(
 46                field_references(
 47                    (package_label, self.model_name_lower),
 48                    self.field,
 49                    (package_label, name_lower),
 50                )
 51            )
 52        return False
 53
 54    def references_field(self, model_name: str, name: str, package_label: str) -> bool:
 55        model_name_lower = model_name.lower()
 56        # Check if this operation locally references the field.
 57        if model_name_lower == self.model_name_lower:
 58            if name == self.name:
 59                return True
 60        # Check if this operation remotely references the field.
 61        if self.field is None:
 62            return False
 63        return bool(
 64            field_references(
 65                (package_label, self.model_name_lower),
 66                self.field,
 67                (package_label, model_name_lower),
 68                name,
 69            )
 70        )
 71
 72    def reduce(
 73        self, operation: Operation, package_label: str
 74    ) -> list[Operation] | bool:
 75        return super().reduce(
 76            operation, package_label
 77        ) or not operation.references_field(self.model_name, self.name, package_label)
 78
 79
 80class AddField(FieldOperation):
 81    """Add a field to a model."""
 82
 83    def __init__(
 84        self, model_name: str, name: str, field: Field, preserve_default: bool = True
 85    ) -> None:
 86        self.preserve_default = preserve_default
 87        super().__init__(model_name, name, field)
 88
 89    def deconstruct(self) -> tuple[str, list[Any], dict[str, Any]]:
 90        kwargs: dict[str, Any] = {
 91            "model_name": self.model_name,
 92            "name": self.name,
 93            "field": self.field,
 94        }
 95        if self.preserve_default is not True:
 96            kwargs["preserve_default"] = self.preserve_default
 97        return (self.__class__.__name__, [], kwargs)
 98
 99    def state_forwards(self, package_label: str, state: Any) -> None:
100        state.add_field(
101            package_label,
102            self.model_name_lower,
103            self.name,
104            self.field,
105            self.preserve_default,
106        )
107
108    def database_forwards(
109        self,
110        package_label: str,
111        schema_editor: BaseDatabaseSchemaEditor,
112        from_state: ProjectState,
113        to_state: ProjectState,
114    ) -> None:
115        to_model = to_state.models_registry.get_model(package_label, self.model_name)
116        if self.allow_migrate_model(schema_editor.connection, to_model):
117            from_model = from_state.models_registry.get_model(
118                package_label, self.model_name
119            )
120            field = to_model._model_meta.get_field(self.name)
121            assert self.field is not None
122            if not self.preserve_default:
123                field.default = self.field.default
124            schema_editor.add_field(
125                from_model,
126                field,
127            )
128            if not self.preserve_default:
129                field.default = NOT_PROVIDED
130
131    def describe(self) -> str:
132        return f"Add field {self.name} to {self.model_name}"
133
134    @property
135    def migration_name_fragment(self) -> str:
136        return f"{self.model_name_lower}_{self.name_lower}"
137
138    def reduce(
139        self, operation: Operation, package_label: str
140    ) -> list[Operation] | bool:
141        if isinstance(operation, FieldOperation) and self.is_same_field_operation(
142            operation
143        ):
144            if isinstance(operation, AlterField):
145                assert operation.field is not None
146                return [
147                    AddField(
148                        model_name=self.model_name,
149                        name=operation.name,
150                        field=operation.field,
151                    ),
152                ]
153            elif isinstance(operation, RemoveField):
154                return []
155            elif isinstance(operation, RenameField):
156                assert self.field is not None  # AddField always has a field
157                return [
158                    AddField(
159                        model_name=self.model_name,
160                        name=operation.new_name,
161                        field=self.field,
162                    ),
163                ]
164        return super().reduce(operation, package_label)
165
166
class RemoveField(FieldOperation):
    """Remove a field from a model."""

    def deconstruct(self) -> tuple[str, list[Any], dict[str, Any]]:
        """Return ``(class name, args, kwargs)`` needed to rebuild this operation."""
        kwargs: dict[str, Any] = dict(model_name=self.model_name, name=self.name)
        return (type(self).__name__, [], kwargs)

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Drop the field from the model in the migration *state*."""
        state.remove_field(package_label, self.model_name_lower, self.name)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Remove the field's column using the model as defined in *from_state*."""
        model = from_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        doomed_field = model._model_meta.get_field(self.name)
        schema_editor.remove_field(model, doomed_field)

    def describe(self) -> str:
        return f"Remove field {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"remove_{self.model_name_lower}_{self.name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """
        Collapse into a later DeleteModel of the same model.

        Any other operation is handled by ``FieldOperation.reduce``.
        """
        from .models import DeleteModel

        deletes_owning_model = (
            isinstance(operation, DeleteModel)
            and operation.name_lower == self.model_name_lower
        )
        if deletes_owning_model:
            # Deleting the whole model subsumes removing one of its fields.
            return [operation]
        return super().reduce(operation, package_label)
213
214
class AlterField(FieldOperation):
    """
    Alter a field's database column (e.g. null, max_length) to the provided
    new field.

    ``preserve_default=False`` marks the new field's default as a one-off
    value. ``database_forwards`` applies it only during the schema change and
    then resets it to ``NOT_PROVIDED``. The flag is also forwarded to the
    project state through ``state.alter_field``.
    """

    def __init__(
        self, model_name: str, name: str, field: Field, preserve_default: bool = True
    ) -> None:
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self) -> tuple[str, list[Any], dict[str, Any]]:
        """
        Return ``(class name, args, kwargs)`` describing this operation.

        ``preserve_default`` is only emitted when it is not ``True``.
        """
        kwargs: dict[str, Any] = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, [], kwargs)

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Replace the field definition on the model in the migration *state*."""
        state.alter_field(
            package_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Alter the column from its *from_state* definition to its *to_state* one."""
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, to_model):
            return
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        from_field = from_model._model_meta.get_field(self.name)
        to_field = to_model._model_meta.get_field(self.name)
        assert self.field is not None
        if not self.preserve_default:
            # Temporarily expose the one-off default to the schema editor.
            to_field.default = self.field.default
        try:
            schema_editor.alter_field(from_model, from_field, to_field)
        finally:
            # Reset even if the schema change raises, so the one-off default
            # is not left behind on the to_state field.
            if not self.preserve_default:
                to_field.default = NOT_PROVIDED

    def describe(self) -> str:
        return f"Alter field {self.name} on {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"alter_{self.model_name_lower}_{self.name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """
        Merge this AlterField with a later operation on the same field.

        - A later AlterField or RemoveField supersedes this one.
        - A later RenameField is moved first, and this alteration is re-applied
          under the new name, keeping ``preserve_default``. This only happens
          when the field has no explicit ``db_column``.
        """
        if isinstance(
            operation, AlterField | RemoveField
        ) and self.is_same_field_operation(operation):
            return [operation]
        elif (
            isinstance(operation, RenameField)
            and self.is_same_field_operation(operation)
            and self.field is not None
            # NOTE(review): presumably skipped because an explicit db_column
            # decouples the column name from the field name — confirm.
            and self.field.db_column is None
        ):
            return [
                operation,
                AlterField(
                    model_name=self.model_name,
                    name=operation.new_name,
                    field=self.field,
                    preserve_default=self.preserve_default,
                ),
            ]
        return super().reduce(operation, package_label)
296
297
class RenameField(FieldOperation):
    """
    Rename a field on a model. Might affect db_column too.

    The base class tracks the operation under the field's old name.
    """

    def __init__(self, model_name: str, old_name: str, new_name: str) -> None:
        self.old_name = old_name
        self.new_name = new_name
        super().__init__(model_name, old_name)

    @cached_property
    def old_name_lower(self) -> str:
        """Lowercased pre-rename field name."""
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self) -> str:
        """Lowercased post-rename field name."""
        return self.new_name.lower()

    def deconstruct(self) -> tuple[str, list[Any], dict[str, Any]]:
        """Return ``(class name, args, kwargs)`` needed to rebuild this operation."""
        kwargs: dict[str, Any] = dict(
            model_name=self.model_name,
            old_name=self.old_name,
            new_name=self.new_name,
        )
        return (type(self).__name__, [], kwargs)

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Rename the field on the model in the migration *state*."""
        state.rename_field(
            package_label, self.model_name_lower, self.old_name, self.new_name
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Alter the old field (from *from_state*) into the new one (from *to_state*)."""
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, to_model):
            return
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        before = from_model._model_meta.get_field(self.old_name)
        after = to_model._model_meta.get_field(self.new_name)
        schema_editor.alter_field(from_model, before, after)

    def describe(self) -> str:
        return f"Rename field {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"rename_{self.old_name_lower}_{self.model_name_lower}_{self.new_name_lower}"

    def references_field(self, model_name: str, name: str, package_label: str) -> bool:
        """
        Return True if *name* on *model_name* is either side of this rename.

        The field-name comparison is case-insensitive.
        """
        if not self.references_model(model_name, package_label):
            return False
        return name.lower() in (self.old_name_lower, self.new_name_lower)

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """
        Chain consecutive renames, or report whether *operation* is independent.

        A later rename of this operation's new name on the same model is folded
        into a single rename from the original old name.
        """
        chains_onto_this = (
            isinstance(operation, RenameField)
            and self.is_same_model_operation(operation)
            and operation.old_name_lower == self.new_name_lower
        )
        if chains_onto_this:
            # old -> mid followed by mid -> new collapses into old -> new.
            return [RenameField(self.model_name, self.old_name, operation.new_name)]
        # Bypass FieldOperation.reduce: it would only check self.name (the old
        # name), whereas later operations may refer to either name.
        reduced = super(FieldOperation, self).reduce(operation, package_label)
        if reduced:
            return reduced
        touches_either_name = operation.references_field(
            self.model_name, self.old_name, package_label
        ) or operation.references_field(self.model_name, self.new_name, package_label)
        return not touches_either_name
376            or operation.references_field(self.model_name, self.new_name, package_label)
377        )