1from __future__ import annotations
  2
  3from functools import cached_property
  4from typing import TYPE_CHECKING, Any
  5
  6from plain.postgres.fields import NOT_PROVIDED
  7from plain.postgres.migrations.utils import field_references
  8
  9from .base import Operation
 10
 11if TYPE_CHECKING:
 12    from plain.postgres.fields import Field
 13    from plain.postgres.migrations.state import ProjectState
 14    from plain.postgres.schema import DatabaseSchemaEditor
 15
 16
 17class FieldOperation(Operation):
 18    def __init__(self, model_name: str, name: str, field: Field | None = None) -> None:
 19        self.model_name = model_name
 20        self.name = name
 21        self.field = field
 22
 23    @cached_property
 24    def model_name_lower(self) -> str:
 25        return self.model_name.lower()
 26
 27    @cached_property
 28    def name_lower(self) -> str:
 29        return self.name.lower()
 30
 31    def is_same_model_operation(self, operation: FieldOperation) -> bool:
 32        return self.model_name_lower == operation.model_name_lower
 33
 34    def is_same_field_operation(self, operation: FieldOperation) -> bool:
 35        return (
 36            self.is_same_model_operation(operation)
 37            and self.name_lower == operation.name_lower
 38        )
 39
 40    def references_model(self, name: str, package_label: str) -> bool:
 41        name_lower = name.lower()
 42        if name_lower == self.model_name_lower:
 43            return True
 44        if self.field:
 45            return bool(
 46                field_references(
 47                    (package_label, self.model_name_lower),
 48                    self.field,
 49                    (package_label, name_lower),
 50                )
 51            )
 52        return False
 53
 54    def references_field(self, model_name: str, name: str, package_label: str) -> bool:
 55        model_name_lower = model_name.lower()
 56        # Check if this operation locally references the field.
 57        if model_name_lower == self.model_name_lower:
 58            if name == self.name:
 59                return True
 60        # Check if this operation remotely references the field.
 61        if self.field is None:
 62            return False
 63        return bool(
 64            field_references(
 65                (package_label, self.model_name_lower),
 66                self.field,
 67                (package_label, model_name_lower),
 68                name,
 69            )
 70        )
 71
 72    def reduce(
 73        self, operation: Operation, package_label: str
 74    ) -> list[Operation] | bool:
 75        return super().reduce(
 76            operation, package_label
 77        ) or not operation.references_field(self.model_name, self.name, package_label)
 78
 79
 80class AddField(FieldOperation):
 81    """Add a field to a model."""
 82
 83    def __init__(
 84        self, model_name: str, name: str, field: Field, preserve_default: bool = True
 85    ) -> None:
 86        self.preserve_default = preserve_default
 87        super().__init__(model_name, name, field)
 88
 89    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
 90        kwargs: dict[str, Any] = {
 91            "model_name": self.model_name,
 92            "name": self.name,
 93            "field": self.field,
 94        }
 95        if self.preserve_default is not True:
 96            kwargs["preserve_default"] = self.preserve_default
 97        return (self.__class__.__name__, (), kwargs)
 98
 99    def state_forwards(self, package_label: str, state: Any) -> None:
100        state.add_field(
101            package_label,
102            self.model_name_lower,
103            self.name,
104            self.field,
105            self.preserve_default,
106        )
107
108    def database_forwards(
109        self,
110        package_label: str,
111        schema_editor: DatabaseSchemaEditor,
112        from_state: ProjectState,
113        to_state: ProjectState,
114    ) -> None:
115        to_model = to_state.models_registry.get_model(package_label, self.model_name)
116        from_model = from_state.models_registry.get_model(
117            package_label, self.model_name
118        )
119        field = to_model._model_meta.get_forward_field(self.name)
120        assert self.field is not None
121        if not self.preserve_default:
122            field.default = self.field.default
123        schema_editor.add_field(
124            from_model,
125            field,
126        )
127        if not self.preserve_default:
128            field.default = NOT_PROVIDED
129
130    def describe(self) -> str:
131        return f"Add field {self.name} to {self.model_name}"
132
133    @property
134    def migration_name_fragment(self) -> str:
135        return f"{self.model_name_lower}_{self.name_lower}"
136
137    def reduce(
138        self, operation: Operation, package_label: str
139    ) -> list[Operation] | bool:
140        if isinstance(operation, FieldOperation) and self.is_same_field_operation(
141            operation
142        ):
143            if isinstance(operation, AlterField):
144                assert operation.field is not None
145                return [
146                    AddField(
147                        model_name=self.model_name,
148                        name=operation.name,
149                        field=operation.field,
150                    ),
151                ]
152            elif isinstance(operation, RemoveField):
153                return []
154            elif isinstance(operation, RenameField):
155                assert self.field is not None  # AddField always has a field
156                return [
157                    AddField(
158                        model_name=self.model_name,
159                        name=operation.new_name,
160                        field=self.field,
161                    ),
162                ]
163        return super().reduce(operation, package_label)
164
165
class RemoveField(FieldOperation):
    """Remove a field from a model, in both the project state and the database."""

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, args, kwargs)`` needed to recreate this operation."""
        # Dropping a field only needs its location; no field instance is stored.
        return (
            self.__class__.__name__,
            (),
            {"model_name": self.model_name, "name": self.name},
        )

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Drop the field from the model in the in-memory project *state*."""
        state.remove_field(package_label, self.model_name_lower, self.name)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Ask *schema_editor* to remove the field as defined in *from_state*."""
        model = from_state.models_registry.get_model(package_label, self.model_name)
        field = model._model_meta.get_forward_field(self.name)
        schema_editor.remove_field(model, field)

    def describe(self) -> str:
        return f"Remove field {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"remove_{self.model_name_lower}_{self.name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """Collapse into a following DeleteModel of the same model.

        Otherwise fall back to ``FieldOperation.reduce``.
        """
        from .models import DeleteModel

        deletes_this_model = (
            isinstance(operation, DeleteModel)
            and operation.name_lower == self.model_name_lower
        )
        if deletes_this_model:
            return [operation]
        return super().reduce(operation, package_label)
212
213
class AlterField(FieldOperation):
    """
    Alter a field's database column (e.g. null, max_length) to the provided
    new field.

    When ``preserve_default`` is False, ``field.default`` is treated as a
    one-off value. ``database_forwards`` applies it to the target field only
    during the alteration, then resets that field's default to
    ``NOT_PROVIDED``. The flag is also forwarded to ``state.alter_field``.
    """

    def __init__(
        self, model_name: str, name: str, field: Field, preserve_default: bool = True
    ) -> None:
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, args, kwargs)`` needed to recreate this operation."""
        kwargs: dict[str, Any] = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        # Only serialize the flag when it differs from its default of True.
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, (), kwargs)

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Replace the field definition in the in-memory project *state*."""
        state.alter_field(
            package_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Alter the column from its *from_state* to its *to_state* definition."""
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        from_field = from_model._model_meta.get_forward_field(self.name)
        to_field = to_model._model_meta.get_forward_field(self.name)
        assert self.field is not None
        if self.preserve_default:
            schema_editor.alter_field(from_model, from_field, to_field)
            return
        # Temporarily install the one-off default on the to_state field for the
        # alteration. Always reset it, even if alter_field raises, so the
        # shared field object does not keep the one-off default.
        to_field.default = self.field.default
        try:
            schema_editor.alter_field(from_model, from_field, to_field)
        finally:
            to_field.default = NOT_PROVIDED

    def describe(self) -> str:
        return f"Alter field {self.name} on {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"alter_{self.model_name_lower}_{self.name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """Merge a following operation on the same field with this AlterField.

        - AlterField / RemoveField: the later operation supersedes this one.
        - RenameField: rename first, then apply this alteration under the new
          name. The rebuilt AlterField keeps this operation's
          ``preserve_default``, so the flag passed to ``state.alter_field``
          is unchanged by the optimization.
        """
        if isinstance(
            operation, AlterField | RemoveField
        ) and self.is_same_field_operation(operation):
            return [operation]
        elif (
            isinstance(operation, RenameField)
            and self.is_same_field_operation(operation)
            and self.field is not None
        ):
            return [
                operation,
                AlterField(
                    model_name=self.model_name,
                    name=operation.new_name,
                    field=self.field,
                    preserve_default=self.preserve_default,
                ),
            ]
        return super().reduce(operation, package_label)
293
294
class RenameField(FieldOperation):
    """Rename a field on a model.

    In the database this is an ``alter_field`` from the old field (as defined
    in the source state) to the new field (as defined in the target state).
    """

    def __init__(self, model_name: str, old_name: str, new_name: str) -> None:
        self.old_name = old_name
        self.new_name = new_name
        # The inherited ``name`` is the old name; no field instance is stored.
        super().__init__(model_name, old_name)

    @cached_property
    def old_name_lower(self) -> str:
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self) -> str:
        return self.new_name.lower()

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, args, kwargs)`` needed to recreate this operation."""
        return (
            self.__class__.__name__,
            (),
            {
                "model_name": self.model_name,
                "old_name": self.old_name,
                "new_name": self.new_name,
            },
        )

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Rename the field on the model in the in-memory project *state*."""
        state.rename_field(
            package_label, self.model_name_lower, self.old_name, self.new_name
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Alter the old field (from_state) into the new field (to_state)."""
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        old_field = from_model._model_meta.get_forward_field(self.old_name)
        new_field = to_model._model_meta.get_forward_field(self.new_name)
        schema_editor.alter_field(from_model, old_field, new_field)

    def describe(self) -> str:
        return f"Rename field {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"rename_{self.old_name_lower}_{self.model_name_lower}_{self.new_name_lower}"

    def references_field(self, model_name: str, name: str, package_label: str) -> bool:
        """Return True if *name* on *model_name* is either the old or new name.

        The model must be referenced by this operation (see
        ``references_model``). Field names are compared case-insensitively.
        """
        if not self.references_model(model_name, package_label):
            return False
        return name.lower() in (self.old_name_lower, self.new_name_lower)

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """Collapse chained renames, or check independence on both names."""
        if (
            isinstance(operation, RenameField)
            and self.is_same_model_operation(operation)
            and operation.old_name_lower == self.new_name_lower
        ):
            # A -> B followed by B -> C becomes a single A -> C rename.
            return [RenameField(self.model_name, self.old_name, operation.new_name)]
        # Bypass FieldOperation.reduce (it only checks self.name, i.e. the old
        # name) and call the base Operation.reduce directly; then check
        # references to both the old and the new name ourselves.
        reduced = super(FieldOperation, self).reduce(operation, package_label)
        if reduced:
            return reduced
        if operation.references_field(self.model_name, self.old_name, package_label):
            return False
        return not operation.references_field(
            self.model_name, self.new_name, package_label
        )