1from __future__ import annotations
  2
  3from functools import cached_property
  4from typing import TYPE_CHECKING, Any
  5
  6from plain.models.fields import NOT_PROVIDED
  7from plain.models.migrations.utils import field_references
  8
  9from .base import Operation
 10
 11if TYPE_CHECKING:
 12    from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
 13    from plain.models.fields import Field
 14    from plain.models.migrations.state import ProjectState
 15
 16
 17class FieldOperation(Operation):
 18    def __init__(self, model_name: str, name: str, field: Field | None = None) -> None:
 19        self.model_name = model_name
 20        self.name = name
 21        self.field = field
 22
 23    @cached_property
 24    def model_name_lower(self) -> str:
 25        return self.model_name.lower()
 26
 27    @cached_property
 28    def name_lower(self) -> str:
 29        return self.name.lower()
 30
 31    def is_same_model_operation(self, operation: FieldOperation) -> bool:
 32        return self.model_name_lower == operation.model_name_lower
 33
 34    def is_same_field_operation(self, operation: FieldOperation) -> bool:
 35        return (
 36            self.is_same_model_operation(operation)
 37            and self.name_lower == operation.name_lower
 38        )
 39
 40    def references_model(self, name: str, package_label: str) -> bool:
 41        name_lower = name.lower()
 42        if name_lower == self.model_name_lower:
 43            return True
 44        if self.field:
 45            return bool(
 46                field_references(
 47                    (package_label, self.model_name_lower),
 48                    self.field,
 49                    (package_label, name_lower),
 50                )
 51            )
 52        return False
 53
 54    def references_field(self, model_name: str, name: str, package_label: str) -> bool:
 55        model_name_lower = model_name.lower()
 56        # Check if this operation locally references the field.
 57        if model_name_lower == self.model_name_lower:
 58            if name == self.name:
 59                return True
 60        # Check if this operation remotely references the field.
 61        if self.field is None:
 62            return False
 63        return bool(
 64            field_references(
 65                (package_label, self.model_name_lower),
 66                self.field,
 67                (package_label, model_name_lower),
 68                name,
 69            )
 70        )
 71
 72    def reduce(
 73        self, operation: Operation, package_label: str
 74    ) -> list[Operation] | bool:
 75        return super().reduce(
 76            operation, package_label
 77        ) or not operation.references_field(self.model_name, self.name, package_label)
 78
 79
 80class AddField(FieldOperation):
 81    """Add a field to a model."""
 82
 83    def __init__(
 84        self, model_name: str, name: str, field: Field, preserve_default: bool = True
 85    ) -> None:
 86        self.preserve_default = preserve_default
 87        super().__init__(model_name, name, field)
 88
 89    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
 90        kwargs: dict[str, Any] = {
 91            "model_name": self.model_name,
 92            "name": self.name,
 93            "field": self.field,
 94        }
 95        if self.preserve_default is not True:
 96            kwargs["preserve_default"] = self.preserve_default
 97        return (self.__class__.__name__, (), kwargs)
 98
 99    def state_forwards(self, package_label: str, state: Any) -> None:
100        state.add_field(
101            package_label,
102            self.model_name_lower,
103            self.name,
104            self.field,
105            self.preserve_default,
106        )
107
108    def database_forwards(
109        self,
110        package_label: str,
111        schema_editor: BaseDatabaseSchemaEditor,
112        from_state: ProjectState,
113        to_state: ProjectState,
114    ) -> None:
115        to_model = to_state.models_registry.get_model(package_label, self.model_name)
116        if self.allow_migrate_model(schema_editor.connection, to_model):
117            from_model = from_state.models_registry.get_model(
118                package_label, self.model_name
119            )
120            field = to_model._model_meta.get_field(self.name)
121            assert self.field is not None
122            if not self.preserve_default:
123                field.default = self.field.default
124            schema_editor.add_field(
125                from_model,
126                field,
127            )
128            if not self.preserve_default:
129                field.default = NOT_PROVIDED
130
131    def describe(self) -> str:
132        return f"Add field {self.name} to {self.model_name}"
133
134    @property
135    def migration_name_fragment(self) -> str:
136        return f"{self.model_name_lower}_{self.name_lower}"
137
138    def reduce(
139        self, operation: Operation, package_label: str
140    ) -> list[Operation] | bool:
141        if isinstance(operation, FieldOperation) and self.is_same_field_operation(
142            operation
143        ):
144            if isinstance(operation, AlterField):
145                assert operation.field is not None
146                return [
147                    AddField(
148                        model_name=self.model_name,
149                        name=operation.name,
150                        field=operation.field,
151                    ),
152                ]
153            elif isinstance(operation, RemoveField):
154                return []
155            elif isinstance(operation, RenameField):
156                assert self.field is not None  # AddField always has a field
157                return [
158                    AddField(
159                        model_name=self.model_name,
160                        name=operation.new_name,
161                        field=self.field,
162                    ),
163                ]
164        return super().reduce(operation, package_label)
165
166
class RemoveField(FieldOperation):
    """
    Drop a single field from a model.

    The operation carries no field instance. The field to drop is looked up on
    the model in ``from_state``.
    """

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, args, kwargs)`` for serializing this operation."""
        return (
            type(self).__name__,
            (),
            {"model_name": self.model_name, "name": self.name},
        )

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Drop the field from the migration *state* via ``state.remove_field``."""
        state.remove_field(package_label, self.model_name_lower, self.name)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Remove the field using its definition from *from_state*."""
        model = from_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        doomed_field = model._model_meta.get_field(self.name)
        schema_editor.remove_field(model, doomed_field)

    def describe(self) -> str:
        return f"Remove field {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"remove_{self.model_name_lower}_{self.name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """
        Absorb this removal into a later ``DeleteModel`` of the same model.

        Any other operation is handled by ``FieldOperation.reduce``.
        """
        from .models import DeleteModel

        if not isinstance(operation, DeleteModel):
            return super().reduce(operation, package_label)
        if operation.name_lower != self.model_name_lower:
            return super().reduce(operation, package_label)
        # Deleting the whole model makes dropping one of its fields redundant.
        return [operation]
213
214
class AlterField(FieldOperation):
    """
    Alter a field's database column (e.g. null, max_length) to the provided
    new field.

    When ``preserve_default`` is False, ``field.default`` is applied only for
    the duration of the schema change, and is reset to ``NOT_PROVIDED``
    afterwards. The flag is also passed to ``state.alter_field``.
    """

    def __init__(
        self, model_name: str, name: str, field: Field, preserve_default: bool = True
    ) -> None:
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """
        Return ``(class name, args, kwargs)`` for serializing this operation.

        ``preserve_default`` is only included when it is not True.
        """
        kwargs: dict[str, Any] = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, (), kwargs)

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Apply the alteration to the migration *state* via ``state.alter_field``."""
        state.alter_field(
            package_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Alter the field from its *from_state* definition to its *to_state* one."""
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, to_model):
            return
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        from_field = from_model._model_meta.get_field(self.name)
        to_field = to_model._model_meta.get_field(self.name)
        assert self.field is not None
        if self.preserve_default:
            schema_editor.alter_field(from_model, from_field, to_field)
            return
        # Apply this operation's default only while altering the column.
        # Always clear it again, even if alter_field raises, so the to_state
        # field is never left carrying the temporary default.
        to_field.default = self.field.default
        try:
            schema_editor.alter_field(from_model, from_field, to_field)
        finally:
            to_field.default = NOT_PROVIDED

    def describe(self) -> str:
        return f"Alter field {self.name} on {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"alter_{self.model_name_lower}_{self.name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """
        Fold a later operation on the same field into this alteration.

        - A later ``AlterField`` or ``RemoveField`` on the same field supersedes
          this one.
        - A later ``RenameField`` is moved first, and the alteration is re-issued
          under the new name. ``preserve_default`` is kept so the resulting
          state matches the original sequence.
        """
        if isinstance(
            operation, AlterField | RemoveField
        ) and self.is_same_field_operation(operation):
            return [operation]
        elif (
            isinstance(operation, RenameField)
            and self.is_same_field_operation(operation)
            and self.field is not None
        ):
            return [
                operation,
                AlterField(
                    model_name=self.model_name,
                    name=operation.new_name,
                    field=self.field,
                    preserve_default=self.preserve_default,
                ),
            ]
        return super().reduce(operation, package_label)
295
296
class RenameField(FieldOperation):
    """
    Rename a field on a model from ``old_name`` to ``new_name``.

    The operation carries no field instance. ``name`` is set to ``old_name``.
    """

    def __init__(self, model_name: str, old_name: str, new_name: str) -> None:
        super().__init__(model_name, old_name)
        self.old_name = old_name
        self.new_name = new_name

    @cached_property
    def old_name_lower(self) -> str:
        """The original field name lowercased, for case-insensitive matching."""
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self) -> str:
        """The new field name lowercased, for case-insensitive matching."""
        return self.new_name.lower()

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, args, kwargs)`` for serializing this operation."""
        kwargs: dict[str, Any] = dict(
            model_name=self.model_name,
            old_name=self.old_name,
            new_name=self.new_name,
        )
        return (type(self).__name__, (), kwargs)

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Rename the field in the migration *state* via ``state.rename_field``."""
        state.rename_field(
            package_label, self.model_name_lower, self.old_name, self.new_name
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Alter the old field (from *from_state*) into the new one (from *to_state*)."""
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, to_model):
            return
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        before = from_model._model_meta.get_field(self.old_name)
        after = to_model._model_meta.get_field(self.new_name)
        schema_editor.alter_field(from_model, before, after)

    def describe(self) -> str:
        return f"Rename field {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"rename_{self.old_name_lower}_{self.model_name_lower}_{self.new_name_lower}"

    def references_field(self, model_name: str, name: str, package_label: str) -> bool:
        """
        Return True if *name* on *model_name* is either side of this rename.

        The field name is compared case-insensitively against both the old and
        the new name.
        """
        if not self.references_model(model_name, package_label):
            return False
        candidate = name.lower()
        return candidate in (self.old_name_lower, self.new_name_lower)

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """
        Collapse chained renames; otherwise check both names for references.

        A later rename of ``new_name`` on the same model merges into one rename
        from ``old_name`` to that rename's target.
        """
        if (
            isinstance(operation, RenameField)
            and self.is_same_model_operation(operation)
            and self.new_name_lower == operation.old_name_lower
        ):
            # a -> b followed by b -> c becomes a single a -> c.
            return [RenameField(self.model_name, self.old_name, operation.new_name)]
        # Bypass `FieldOperation.reduce`: the reference check must cover both
        # self.old_name and self.new_name, not just self.name.
        combined = super(FieldOperation, self).reduce(operation, package_label)
        if combined:
            return combined
        touches_either_name = operation.references_field(
            self.model_name, self.old_name, package_label
        ) or operation.references_field(self.model_name, self.new_name, package_label)
        return not touches_either_name