1from __future__ import annotations
2
3from functools import cached_property
4from typing import TYPE_CHECKING, Any
5
6from plain.models.fields import NOT_PROVIDED
7from plain.models.migrations.utils import field_references
8
9from .base import Operation
10
11if TYPE_CHECKING:
12 from plain.models.fields import Field
13 from plain.models.migrations.state import ProjectState
14 from plain.models.postgres.schema import DatabaseSchemaEditor
15
16
class FieldOperation(Operation):
    """Base class for migration operations that target one field of one model.

    The target is identified by ``model_name`` and ``name``. ``field`` holds
    the field definition when the operation carries one and is ``None``
    otherwise.
    """

    def __init__(self, model_name: str, name: str, field: Field | None = None) -> None:
        self.model_name = model_name
        self.name = name
        self.field = field

    @cached_property
    def model_name_lower(self) -> str:
        """Lower-cased model name, used for case-insensitive comparisons."""
        return self.model_name.lower()

    @cached_property
    def name_lower(self) -> str:
        """Lower-cased field name, used for case-insensitive comparisons."""
        return self.name.lower()

    def is_same_model_operation(self, operation: FieldOperation) -> bool:
        """Return True when *operation* targets the same model (case-insensitively)."""
        return operation.model_name_lower == self.model_name_lower

    def is_same_field_operation(self, operation: FieldOperation) -> bool:
        """Return True when *operation* targets the same model and field name."""
        if not self.is_same_model_operation(operation):
            return False
        return operation.name_lower == self.name_lower

    def references_model(self, name: str, package_label: str) -> bool:
        """Return True if this operation targets, or its field points at, model *name*."""
        target_model = name.lower()
        if target_model == self.model_name_lower:
            return True
        if not self.field:
            return False
        # Ask the field itself whether it relates to the target model.
        return bool(
            field_references(
                (package_label, self.model_name_lower),
                self.field,
                (package_label, target_model),
            )
        )

    def references_field(self, model_name: str, name: str, package_label: str) -> bool:
        """Return True if this operation touches field *name* on *model_name*."""
        target_model = model_name.lower()
        # Local reference: this operation targets exactly that field.
        # (The field-name comparison here is case-sensitive.)
        if target_model == self.model_name_lower and name == self.name:
            return True
        # Remote reference: this operation's field points at that field.
        if self.field is None:
            return False
        return bool(
            field_references(
                (package_label, self.model_name_lower),
                self.field,
                (package_label, target_model),
                name,
            )
        )

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """Defer to the base reduction, else allow optimizing across *operation*.

        Optimizing across is allowed only if *operation* does not reference
        this operation's field.
        """
        reduced = super().reduce(operation, package_label)
        if reduced:
            return reduced
        return not operation.references_field(
            self.model_name, self.name, package_label
        )
78
79
class AddField(FieldOperation):
    """Add a field to a model.

    When ``preserve_default`` is False, ``field.default`` is applied only
    while the column is being added to the database. Afterwards the field
    instance is reset to ``NOT_PROVIDED``.
    """

    def __init__(
        self, model_name: str, name: str, field: Field, preserve_default: bool = True
    ) -> None:
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, args, kwargs)`` needed to recreate this operation.

        ``preserve_default`` is only included when it is not ``True``.
        """
        kwargs: dict[str, Any] = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, (), kwargs)

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Record the new field (and its preserve_default flag) in *state*."""
        state.add_field(
            package_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Add the field's column through *schema_editor*.

        The field instance is taken from *to_state*. The model handed to the
        schema editor is taken from *from_state*.
        """
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        field = to_model._model_meta.get_field(self.name)
        assert self.field is not None
        if self.preserve_default:
            schema_editor.add_field(from_model, field)
            return
        # Apply the one-off default only for the schema change. Reset it in
        # ``finally`` so a failing add_field() does not leave it attached to
        # the field instance obtained from *to_state*.
        field.default = self.field.default
        try:
            schema_editor.add_field(from_model, field)
        finally:
            field.default = NOT_PROVIDED

    def describe(self) -> str:
        return f"Add field {self.name} to {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"{self.model_name_lower}_{self.name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """Fold a following operation on the same field into this AddField."""
        if isinstance(operation, FieldOperation) and self.is_same_field_operation(
            operation
        ):
            if isinstance(operation, AlterField):
                assert operation.field is not None
                # Adding then altering equals adding the altered field.
                # Keep the alteration's preserve_default so the resulting
                # state matches what the AlterField would have produced.
                return [
                    AddField(
                        model_name=self.model_name,
                        name=operation.name,
                        field=operation.field,
                        preserve_default=operation.preserve_default,
                    ),
                ]
            elif isinstance(operation, RemoveField):
                # Adding then removing the same field cancels out.
                return []
            elif isinstance(operation, RenameField):
                assert self.field is not None  # AddField always has a field
                # A rename changes only the name. Keep this field definition
                # and its preserve_default flag.
                return [
                    AddField(
                        model_name=self.model_name,
                        name=operation.new_name,
                        field=self.field,
                        preserve_default=self.preserve_default,
                    ),
                ]
        return super().reduce(operation, package_label)
164
165
class RemoveField(FieldOperation):
    """Remove a field from a model."""

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, args, kwargs)`` needed to recreate this operation."""
        rebuild_kwargs: dict[str, Any] = dict(
            model_name=self.model_name,
            name=self.name,
        )
        return (self.__class__.__name__, (), rebuild_kwargs)

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Drop the field from the model in *state*."""
        state.remove_field(package_label, self.model_name_lower, self.name)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Remove the field's column, using the model as it was before this operation."""
        old_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        old_field = old_model._model_meta.get_field(self.name)
        schema_editor.remove_field(old_model, old_field)

    def describe(self) -> str:
        return f"Remove field {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"remove_{self.model_name_lower}_{self.name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """Let a DeleteModel on this field's model absorb the removal."""
        from .models import DeleteModel

        deletes_owner_model = (
            isinstance(operation, DeleteModel)
            and operation.name_lower == self.model_name_lower
        )
        if deletes_owner_model:
            return [operation]
        return super().reduce(operation, package_label)
211
212
class AlterField(FieldOperation):
    """
    Alter a field's database column (e.g. null, max_length) to the provided
    new field.

    When ``preserve_default`` is False, ``field.default`` is applied only while
    the column is being altered. Afterwards the field instance is reset to
    ``NOT_PROVIDED``.
    """

    def __init__(
        self, model_name: str, name: str, field: Field, preserve_default: bool = True
    ) -> None:
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, args, kwargs)`` needed to recreate this operation.

        ``preserve_default`` is only included when it is not ``True``.
        """
        kwargs: dict[str, Any] = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, (), kwargs)

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Replace the field definition (and preserve_default flag) in *state*."""
        state.alter_field(
            package_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Alter the column from its *from_state* field to its *to_state* field."""
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        from_field = from_model._model_meta.get_field(self.name)
        to_field = to_model._model_meta.get_field(self.name)
        assert self.field is not None
        if self.preserve_default:
            schema_editor.alter_field(from_model, from_field, to_field)
            return
        # Apply the one-off default only for the schema change. Reset it in
        # ``finally`` so a failing alter_field() does not leave it attached to
        # the field instance obtained from *to_state*.
        to_field.default = self.field.default
        try:
            schema_editor.alter_field(from_model, from_field, to_field)
        finally:
            to_field.default = NOT_PROVIDED

    def describe(self) -> str:
        return f"Alter field {self.name} on {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        return f"alter_{self.model_name_lower}_{self.name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """Fold a following operation on the same field into this alteration."""
        if isinstance(
            operation, AlterField | RemoveField
        ) and self.is_same_field_operation(operation):
            # A later alteration or removal of the same field supersedes this one.
            return [operation]
        elif (
            isinstance(operation, RenameField)
            and self.is_same_field_operation(operation)
            and self.field is not None
        ):
            # Rename first, then apply this alteration under the new name.
            # Keep preserve_default so the reduced pair yields the same state.
            return [
                operation,
                AlterField(
                    model_name=self.model_name,
                    name=operation.new_name,
                    field=self.field,
                    preserve_default=self.preserve_default,
                ),
            ]
        return super().reduce(operation, package_label)
292
293
class RenameField(FieldOperation):
    """Rename a field on the model."""

    def __init__(self, model_name: str, old_name: str, new_name: str) -> None:
        self.old_name = old_name
        self.new_name = new_name
        # The base operation identifies the field by its pre-rename name and
        # carries no field definition.
        super().__init__(model_name, old_name)

    @cached_property
    def old_name_lower(self) -> str:
        """Lower-cased original field name."""
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self) -> str:
        """Lower-cased new field name."""
        return self.new_name.lower()

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, args, kwargs)`` needed to recreate this operation."""
        rebuild_kwargs: dict[str, Any] = dict(
            model_name=self.model_name,
            old_name=self.old_name,
            new_name=self.new_name,
        )
        return (self.__class__.__name__, (), rebuild_kwargs)

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Rename the field on the model in *state*."""
        state.rename_field(
            package_label, self.model_name_lower, self.old_name, self.new_name
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Alter the column from the old-named field to the new-named field."""
        new_model = to_state.models_registry.get_model(package_label, self.model_name)
        old_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        before = old_model._model_meta.get_field(self.old_name)
        after = new_model._model_meta.get_field(self.new_name)
        schema_editor.alter_field(old_model, before, after)

    def describe(self) -> str:
        return f"Rename field {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self) -> str:
        return (
            f"rename_{self.old_name_lower}_{self.model_name_lower}"
            f"_{self.new_name_lower}"
        )

    def references_field(self, model_name: str, name: str, package_label: str) -> bool:
        """Return True for either the old or the new name on this model."""
        if not self.references_model(model_name, package_label):
            return False
        return name.lower() in (self.old_name_lower, self.new_name_lower)

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """Collapse chained renames; otherwise decide if optimizing across is safe."""
        is_chained_rename = (
            isinstance(operation, RenameField)
            and self.is_same_model_operation(operation)
            and self.new_name_lower == operation.old_name_lower
        )
        if is_chained_rename:
            # A -> B followed by B -> C collapses into a single A -> C rename.
            return [RenameField(self.model_name, self.old_name, operation.new_name)]
        # Bypass FieldOperation.reduce: that would only check self.name (the
        # old name), but both the old and the new name must be checked here.
        reduced = super(FieldOperation, self).reduce(operation, package_label)
        if reduced:
            return reduced
        touches_renamed_field = operation.references_field(
            self.model_name, self.old_name, package_label
        ) or operation.references_field(self.model_name, self.new_name, package_label)
        return not touches_renamed_field