1from __future__ import annotations
2
3from functools import cached_property
4from typing import TYPE_CHECKING, Any
5
6from plain.postgres.fields import NOT_PROVIDED
7from plain.postgres.migrations.utils import field_references
8
9from .base import Operation
10
11if TYPE_CHECKING:
12 from plain.postgres.fields import Field
13 from plain.postgres.migrations.state import ProjectState
14 from plain.postgres.schema import DatabaseSchemaEditor
15
16
class FieldOperation(Operation):
    """Common base for migration operations that act on one field of a model.

    Stores the target model name, the field name and (optionally) the field
    instance, and provides the comparison and reference helpers used by the
    concrete field operations.
    """

    def __init__(self, model_name: str, name: str, field: Field | None = None) -> None:
        self.model_name = model_name
        self.name = name
        self.field = field

    @cached_property
    def model_name_lower(self) -> str:
        """Lower-cased ``model_name``, computed once per instance."""
        return self.model_name.lower()

    @cached_property
    def name_lower(self) -> str:
        """Lower-cased field ``name``, computed once per instance."""
        return self.name.lower()

    def is_same_model_operation(self, operation: FieldOperation) -> bool:
        """Return whether ``operation`` targets the same model (case-insensitive)."""
        return self.model_name_lower == operation.model_name_lower

    def is_same_field_operation(self, operation: FieldOperation) -> bool:
        """Return whether ``operation`` targets the same model and field name.

        Both comparisons are made on the lower-cased names.
        """
        if not self.is_same_model_operation(operation):
            return False
        return self.name_lower == operation.name_lower

    def references_model(self, name: str, package_label: str) -> bool:
        """Return whether this operation refers to the model called ``name``.

        True when ``name`` matches this operation's own model
        (case-insensitive), or when the operation carries a field and
        ``field_references`` reports a reference to that model.
        """
        target = name.lower()
        if target == self.model_name_lower:
            return True
        if not self.field:
            # Without a field there is nothing else that could point at it.
            return False
        own_key = (package_label, self.model_name_lower)
        target_key = (package_label, target)
        return bool(field_references(own_key, self.field, target_key))

    def references_field(self, model_name: str, name: str, package_label: str) -> bool:
        """Return whether this operation refers to field ``name`` of ``model_name``.

        The model name is compared case-insensitively; the field name is
        compared exactly against ``self.name``.
        """
        target_model = model_name.lower()
        # Local reference: same model and exactly the same field name.
        if target_model == self.model_name_lower and name == self.name:
            return True
        # Remote reference: only possible through the field this operation carries.
        if self.field is None:
            return False
        own_key = (package_label, self.model_name_lower)
        target_key = (package_label, target_model)
        return bool(field_references(own_key, self.field, target_key, name))

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """Combine this operation with ``operation``.

        Returns the parent class's result when it is truthy; otherwise returns
        ``True`` exactly when ``operation`` does not reference this
        operation's field.
        """
        inherited = super().reduce(operation, package_label)
        if inherited:
            return inherited
        return not operation.references_field(
            self.model_name, self.name, package_label
        )
78
79
class AddField(FieldOperation):
    """Add a field to a model."""

    def __init__(
        self, model_name: str, name: str, field: Field, preserve_default: bool = True
    ) -> None:
        # When False, the field's default is applied only while the column is
        # being added (see ``database_forwards``).
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, positional args, keyword args)`` for this operation.

        ``preserve_default`` is included only when it is not ``True``.
        """
        kwargs: dict[str, Any] = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, (), kwargs)

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Register the new field on ``state`` via ``state.add_field``."""
        state.add_field(
            package_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Add the field's column through ``schema_editor.add_field``.

        When ``preserve_default`` is False, the operation's default is set on
        the target-state field for the duration of the schema call and reset
        to ``NOT_PROVIDED`` afterwards, even if the schema editor raises.
        """
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        field = to_model._model_meta.get_forward_field(self.name)
        assert self.field is not None
        if not self.preserve_default:
            field.default = self.field.default
        try:
            schema_editor.add_field(
                from_model,
                field,
            )
        finally:
            # Always undo the temporary default so a failing schema change
            # cannot leave the one-off default on the state model's field.
            if not self.preserve_default:
                field.default = NOT_PROVIDED

    def describe(self) -> str:
        """Return a one-line human-readable summary of the operation."""
        return f"Add field {self.name} to {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        """Return the ``<model>_<field>`` fragment, both lower-cased."""
        return f"{self.model_name_lower}_{self.name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """Merge a following operation on the same field into this one.

        For an operation on the same model and field:

        * ``AlterField``  -> a single ``AddField`` carrying the altered field.
        * ``RemoveField`` -> an empty list (the two cancel out).
        * ``RenameField`` -> a single ``AddField`` under the new name.

        Anything else is delegated to ``FieldOperation.reduce``.
        """
        if isinstance(operation, FieldOperation) and self.is_same_field_operation(
            operation
        ):
            if isinstance(operation, AlterField):
                assert operation.field is not None
                # NOTE(review): the merged AddField falls back to
                # preserve_default=True; confirm that dropping the original
                # flags is intended.
                return [
                    AddField(
                        model_name=self.model_name,
                        name=operation.name,
                        field=operation.field,
                    ),
                ]
            elif isinstance(operation, RemoveField):
                return []
            elif isinstance(operation, RenameField):
                assert self.field is not None  # AddField always has a field
                # NOTE(review): self.preserve_default is not forwarded here
                # either; confirm against the optimizer's expectations.
                return [
                    AddField(
                        model_name=self.model_name,
                        name=operation.new_name,
                        field=self.field,
                    ),
                ]
        return super().reduce(operation, package_label)
164
165
class RemoveField(FieldOperation):
    """Remove a field from a model."""

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, positional args, keyword args)`` for this operation."""
        init_kwargs: dict[str, Any] = {}
        init_kwargs["model_name"] = self.model_name
        init_kwargs["name"] = self.name
        return (self.__class__.__name__, (), init_kwargs)

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Drop the field from ``state`` via ``state.remove_field``."""
        state.remove_field(package_label, self.model_name_lower, self.name)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Remove the field's column through ``schema_editor.remove_field``.

        The model and the field are both looked up in ``from_state``;
        ``to_state`` is not used.
        """
        model = from_state.models_registry.get_model(package_label, self.model_name)
        removed_field = model._model_meta.get_forward_field(self.name)
        schema_editor.remove_field(model, removed_field)

    def describe(self) -> str:
        """Return a one-line human-readable summary of the operation."""
        return f"Remove field {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        """Return the ``remove_<model>_<field>`` fragment, names lower-cased."""
        return f"remove_{self.model_name_lower}_{self.name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """Collapse into a following ``DeleteModel`` of the same model.

        If ``operation`` deletes the model this field belongs to, only that
        deletion is returned; otherwise ``FieldOperation.reduce`` decides.
        """
        # Imported here rather than at module level, as in the original code.
        from .models import DeleteModel

        if isinstance(operation, DeleteModel):
            if operation.name_lower == self.model_name_lower:
                return [operation]
        return super().reduce(operation, package_label)
212
213
class AlterField(FieldOperation):
    """
    Alter a field's database column (e.g. null, max_length) to the provided
    new field.
    """

    def __init__(
        self, model_name: str, name: str, field: Field, preserve_default: bool = True
    ) -> None:
        # When False, the field's default is applied only while the column is
        # being altered (see ``database_forwards``).
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, positional args, keyword args)`` for this operation.

        ``preserve_default`` is included only when it is not ``True``.
        """
        kwargs: dict[str, Any] = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, (), kwargs)

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Record the altered field on ``state`` via ``state.alter_field``."""
        state.alter_field(
            package_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Alter the column through ``schema_editor.alter_field``.

        When ``preserve_default`` is False, the operation's default is set on
        the target-state field for the duration of the schema call and reset
        to ``NOT_PROVIDED`` afterwards, even if the schema editor raises.
        """
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        from_field = from_model._model_meta.get_forward_field(self.name)
        to_field = to_model._model_meta.get_forward_field(self.name)
        assert self.field is not None
        if not self.preserve_default:
            to_field.default = self.field.default
        try:
            schema_editor.alter_field(from_model, from_field, to_field)
        finally:
            # Always undo the temporary default so a failing schema change
            # cannot leave the one-off default on the state model's field.
            if not self.preserve_default:
                to_field.default = NOT_PROVIDED

    def describe(self) -> str:
        """Return a one-line human-readable summary of the operation."""
        return f"Alter field {self.name} on {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        """Return the ``alter_<model>_<field>`` fragment, names lower-cased."""
        return f"alter_{self.model_name_lower}_{self.name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """Merge a following operation on the same field into this one.

        * A later ``AlterField`` or ``RemoveField`` on the same field replaces
          this operation entirely.
        * A later ``RenameField`` on the same field is moved first, followed
          by this alteration re-targeted at the new name.

        Anything else is delegated to ``FieldOperation.reduce``.
        """
        if isinstance(
            operation, AlterField | RemoveField
        ) and self.is_same_field_operation(operation):
            return [operation]
        elif (
            isinstance(operation, RenameField)
            and self.is_same_field_operation(operation)
            and self.field is not None
        ):
            return [
                operation,
                AlterField(
                    model_name=self.model_name,
                    name=operation.new_name,
                    field=self.field,
                ),
            ]
        return super().reduce(operation, package_label)
293
294
class RenameField(FieldOperation):
    """Rename a field on the model."""

    def __init__(self, model_name: str, old_name: str, new_name: str) -> None:
        self.old_name = old_name
        self.new_name = new_name
        # The base class tracks the field under its old name and no field instance.
        super().__init__(model_name, old_name)

    @cached_property
    def old_name_lower(self) -> str:
        """Lower-cased ``old_name``, computed once per instance."""
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self) -> str:
        """Lower-cased ``new_name``, computed once per instance."""
        return self.new_name.lower()

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(class name, positional args, keyword args)`` for this operation."""
        init_kwargs: dict[str, Any] = {}
        init_kwargs["model_name"] = self.model_name
        init_kwargs["old_name"] = self.old_name
        init_kwargs["new_name"] = self.new_name
        return (self.__class__.__name__, (), init_kwargs)

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Apply the rename to ``state`` via ``state.rename_field``."""
        state.rename_field(
            package_label, self.model_name_lower, self.old_name, self.new_name
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Perform the rename through ``schema_editor.alter_field``.

        The old field is taken from the ``from_state`` model and the new field
        from the ``to_state`` model.
        """
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        old_field = from_model._model_meta.get_forward_field(self.old_name)
        new_field = to_model._model_meta.get_forward_field(self.new_name)
        schema_editor.alter_field(from_model, old_field, new_field)

    def describe(self) -> str:
        """Return a one-line human-readable summary of the operation."""
        return f"Rename field {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self) -> str:
        """Return the ``rename_<old>_<model>_<new>`` fragment, names lower-cased."""
        pieces = (
            "rename",
            self.old_name_lower,
            self.model_name_lower,
            self.new_name_lower,
        )
        return "_".join(pieces)

    def references_field(self, model_name: str, name: str, package_label: str) -> bool:
        """Return whether field ``name`` of ``model_name`` is touched by the rename.

        True when this operation references the model and ``name`` equals
        either the old or the new field name (case-insensitive).
        """
        if not self.references_model(model_name, package_label):
            return False
        candidate = name.lower()
        return candidate == self.old_name_lower or candidate == self.new_name_lower

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """Combine this rename with ``operation``.

        A following rename on the same model whose old name is this rename's
        new name is merged into a single rename from ``old_name`` straight to
        the final name. Otherwise the result of ``Operation.reduce`` is used
        when truthy, falling back to ``True`` exactly when ``operation``
        references neither the old nor the new field name.
        """
        if isinstance(operation, RenameField):
            continues_chain = (
                self.is_same_model_operation(operation)
                and self.new_name_lower == operation.old_name_lower
            )
            if continues_chain:
                merged = RenameField(
                    self.model_name,
                    self.old_name,
                    operation.new_name,
                )
                return [merged]
        # Skip `FieldOperation.reduce` as we want to run `references_field`
        # against self.old_name and self.new_name.
        inherited = super(FieldOperation, self).reduce(operation, package_label)
        if inherited:
            return inherited
        for field_name in (self.old_name, self.new_name):
            if operation.references_field(self.model_name, field_name, package_label):
                return False
        return True