1from __future__ import annotations
  2
  3from functools import cached_property
  4from typing import TYPE_CHECKING, Any
  5
  6from plain.models.fields import NOT_PROVIDED
  7from plain.models.migrations.utils import field_references
  8
  9from .base import Operation
 10
 11if TYPE_CHECKING:
 12    from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
 13    from plain.models.fields import Field
 14    from plain.models.migrations.state import ProjectState
 15
 16
 17class FieldOperation(Operation):
 18    def __init__(self, model_name: str, name: str, field: Field | None = None) -> None:
 19        self.model_name = model_name
 20        self.name = name
 21        self.field = field
 22
 23    @cached_property
 24    def model_name_lower(self) -> str:
 25        return self.model_name.lower()
 26
 27    @cached_property
 28    def name_lower(self) -> str:
 29        return self.name.lower()
 30
 31    def is_same_model_operation(self, operation: FieldOperation) -> bool:
 32        return self.model_name_lower == operation.model_name_lower
 33
 34    def is_same_field_operation(self, operation: FieldOperation) -> bool:
 35        return (
 36            self.is_same_model_operation(operation)
 37            and self.name_lower == operation.name_lower
 38        )
 39
 40    def references_model(self, name: str, package_label: str) -> bool:
 41        name_lower = name.lower()
 42        if name_lower == self.model_name_lower:
 43            return True
 44        if self.field:
 45            return bool(
 46                field_references(
 47                    (package_label, self.model_name_lower),
 48                    self.field,
 49                    (package_label, name_lower),
 50                )
 51            )
 52        return False
 53
 54    def references_field(self, model_name: str, name: str, package_label: str) -> bool:
 55        model_name_lower = model_name.lower()
 56        # Check if this operation locally references the field.
 57        if model_name_lower == self.model_name_lower:
 58            if name == self.name:
 59                return True
 60        # Check if this operation remotely references the field.
 61        if self.field is None:
 62            return False
 63        return bool(
 64            field_references(
 65                (package_label, self.model_name_lower),
 66                self.field,
 67                (package_label, model_name_lower),
 68                name,
 69            )
 70        )
 71
 72    def reduce(
 73        self, operation: Operation, package_label: str
 74    ) -> list[Operation] | bool:
 75        return super().reduce(
 76            operation, package_label
 77        ) or not operation.references_field(self.model_name, self.name, package_label)
 78
 79
 80class AddField(FieldOperation):
 81    """Add a field to a model."""
 82
 83    def __init__(
 84        self, model_name: str, name: str, field: Field, preserve_default: bool = True
 85    ) -> None:
 86        self.preserve_default = preserve_default
 87        super().__init__(model_name, name, field)
 88
 89    def deconstruct(self) -> tuple[str, list[Any], dict[str, Any]]:
 90        kwargs: dict[str, Any] = {
 91            "model_name": self.model_name,
 92            "name": self.name,
 93            "field": self.field,
 94        }
 95        if self.preserve_default is not True:
 96            kwargs["preserve_default"] = self.preserve_default
 97        return (self.__class__.__name__, [], kwargs)
 98
 99    def state_forwards(self, package_label: str, state: Any) -> None:
100        state.add_field(
101            package_label,
102            self.model_name_lower,
103            self.name,
104            self.field,
105            self.preserve_default,
106        )
107
108    def database_forwards(
109        self,
110        package_label: str,
111        schema_editor: BaseDatabaseSchemaEditor,
112        from_state: ProjectState,
113        to_state: ProjectState,
114    ) -> None:
115        to_model = to_state.models_registry.get_model(package_label, self.model_name)
116        if self.allow_migrate_model(schema_editor.connection, to_model):
117            from_model = from_state.models_registry.get_model(
118                package_label, self.model_name
119            )
120            field = to_model._model_meta.get_field(self.name)
121            if not self.preserve_default:
122                field.default = self.field.default
123            schema_editor.add_field(
124                from_model,
125                field,
126            )
127            if not self.preserve_default:
128                field.default = NOT_PROVIDED
129
130    def describe(self) -> str:
131        return f"Add field {self.name} to {self.model_name}"
132
133    @property
134    def migration_name_fragment(self) -> str:
135        return f"{self.model_name_lower}_{self.name_lower}"
136
137    def reduce(
138        self, operation: Operation, package_label: str
139    ) -> list[Operation] | bool:
140        if isinstance(operation, FieldOperation) and self.is_same_field_operation(
141            operation
142        ):
143            if isinstance(operation, AlterField):
144                assert operation.field is not None
145                return [
146                    AddField(
147                        model_name=self.model_name,
148                        name=operation.name,
149                        field=operation.field,
150                    ),
151                ]
152            elif isinstance(operation, RemoveField):
153                return []
154            elif isinstance(operation, RenameField):
155                return [
156                    AddField(
157                        model_name=self.model_name,
158                        name=operation.new_name,
159                        field=self.field,
160                    ),
161                ]
162        return super().reduce(operation, package_label)
163
164
class RemoveField(FieldOperation):
    """Remove a field from a model."""

    def deconstruct(self) -> tuple[str, list[Any], dict[str, Any]]:
        """Serialize as (class name, positional args, keyword args)."""
        return (
            type(self).__name__,
            [],
            {"model_name": self.model_name, "name": self.name},
        )

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Drop the field from the in-memory project state."""
        state.remove_field(package_label, self.model_name_lower, self.name)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Drop the field's column, resolving the field from ``from_state``."""
        source_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        if not self.allow_migrate_model(schema_editor.connection, source_model):
            return
        doomed_field = source_model._model_meta.get_field(self.name)
        schema_editor.remove_field(source_model, doomed_field)

    def describe(self) -> str:
        return f"Remove field {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"remove_{self.model_name_lower}_{self.name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """Let a later DeleteModel of the same model absorb this removal."""
        from .models import DeleteModel

        if not (
            isinstance(operation, DeleteModel)
            and operation.name_lower == self.model_name_lower
        ):
            return super().reduce(operation, package_label)
        # The whole model is deleted afterwards, so only that operation is kept.
        return [operation]
211
212
class AlterField(FieldOperation):
    """
    Alter a field's database column (e.g. null, max_length) to the provided
    new field.

    When ``preserve_default`` is False, the new field's default is handed to
    the schema editor only for the duration of the alteration and is reset to
    ``NOT_PROVIDED`` on the state's field afterwards.
    """

    def __init__(
        self, model_name: str, name: str, field: Field, preserve_default: bool = True
    ) -> None:
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self) -> tuple[str, list[Any], dict[str, Any]]:
        """Serialize as (class name, positional args, keyword args).

        ``preserve_default`` is only emitted when it differs from the default.
        """
        kwargs: dict[str, Any] = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, [], kwargs)

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Replace the field definition in the in-memory project state."""
        state.alter_field(
            package_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Alter the column from its ``from_state`` to its ``to_state`` definition.

        When ``preserve_default`` is False, the target field's default is set
        temporarily from ``self.field`` and always reset afterwards.
        """
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection, to_model):
            from_model = from_state.models_registry.get_model(
                package_label, self.model_name
            )
            from_field = from_model._model_meta.get_field(self.name)
            to_field = to_model._model_meta.get_field(self.name)
            if not self.preserve_default:
                # Expose the one-off default to the schema editor only.
                to_field.default = self.field.default
            try:
                schema_editor.alter_field(from_model, from_field, to_field)
            finally:
                # Undo the temporary default even if the schema change fails,
                # so the state's field is not left carrying it.
                if not self.preserve_default:
                    to_field.default = NOT_PROVIDED

    def describe(self) -> str:
        return f"Alter field {self.name} on {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"alter_{self.model_name_lower}_{self.name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """Try to fold a later *operation* on the same field into this one.

        - A later AlterField/RemoveField of the same field supersedes this one.
        - A later RenameField of the same field is moved before this alteration,
          which is re-targeted at the new name (only when the field has no
          explicit ``db_column``).
        Anything else falls back to ``FieldOperation.reduce``.
        """
        if isinstance(
            operation, AlterField | RemoveField
        ) and self.is_same_field_operation(operation):
            return [operation]
        elif (
            isinstance(operation, RenameField)
            and self.is_same_field_operation(operation)
            and self.field.db_column is None
        ):
            return [
                operation,
                AlterField(
                    model_name=self.model_name,
                    name=operation.new_name,
                    field=self.field,
                    # Renaming does not change one-off default semantics.
                    preserve_default=self.preserve_default,
                ),
            ]
        return super().reduce(operation, package_label)
292
293
class RenameField(FieldOperation):
    """Rename a field on the model. Might affect db_column too."""

    def __init__(self, model_name: str, old_name: str, new_name: str) -> None:
        self.old_name = old_name
        self.new_name = new_name
        # The inherited `name` attribute holds the pre-rename field name.
        super().__init__(model_name, old_name)

    @cached_property
    def old_name_lower(self) -> str:
        """Lowercased original field name."""
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self) -> str:
        """Lowercased new field name."""
        return self.new_name.lower()

    def deconstruct(self) -> tuple[str, list[Any], dict[str, Any]]:
        """Serialize as (class name, positional args, keyword args)."""
        return (
            type(self).__name__,
            [],
            {
                "model_name": self.model_name,
                "old_name": self.old_name,
                "new_name": self.new_name,
            },
        )

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Rename the field in the in-memory project state."""
        state.rename_field(
            package_label, self.model_name_lower, self.old_name, self.new_name
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Apply the rename as an alter_field from the old to the new field."""
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, to_model):
            return
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        before = from_model._model_meta.get_field(self.old_name)
        after = to_model._model_meta.get_field(self.new_name)
        schema_editor.alter_field(from_model, before, after)

    def describe(self) -> str:
        return f"Rename field {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"rename_{self.old_name_lower}_{self.model_name_lower}_{self.new_name_lower}"

    def references_field(self, model_name: str, name: str, package_label: str) -> bool:
        """
        Return True if *name* (case-insensitive) is either the old or the new
        field name and this operation references *model_name*.
        """
        if not self.references_model(model_name, package_label):
            return False
        return name.lower() in (self.old_name_lower, self.new_name_lower)

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """
        Collapse chained renames on the same model (A -> B, then B -> C into
        A -> C); otherwise allow optimizing across this operation only if
        *operation* references neither the old nor the new field name.
        """
        if (
            isinstance(operation, RenameField)
            and self.is_same_model_operation(operation)
            and self.new_name_lower == operation.old_name_lower
        ):
            return [RenameField(self.model_name, self.old_name, operation.new_name)]
        # Deliberately bypass FieldOperation.reduce, which would only check
        # self.name (the old name); both names must be checked here.
        base_result = super(FieldOperation, self).reduce(operation, package_label)
        if base_result:
            return base_result
        touches_old = operation.references_field(
            self.model_name, self.old_name, package_label
        )
        if touches_old:
            return False
        return not operation.references_field(
            self.model_name, self.new_name, package_label
        )