1from __future__ import annotations
2
3from functools import cached_property
4from typing import TYPE_CHECKING, Any
5
6from plain.models.fields import NOT_PROVIDED
7from plain.models.migrations.utils import field_references
8
9from .base import Operation
10
11if TYPE_CHECKING:
12 from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
13 from plain.models.fields import Field
14 from plain.models.migrations.state import ProjectState
15
16
class FieldOperation(Operation):
    """Common base for migration operations that target a single model field.

    Holds the model name, the field name and (optionally) the field instance,
    and provides the comparison and reference helpers used by the concrete
    field operations.
    """

    def __init__(self, model_name: str, name: str, field: Field | None = None) -> None:
        """Store the targeted model name, field name and optional field."""
        self.model_name = model_name
        self.name = name
        self.field = field

    @cached_property
    def model_name_lower(self) -> str:
        """Lower-cased ``model_name``, computed once per instance."""
        return self.model_name.lower()

    @cached_property
    def name_lower(self) -> str:
        """Lower-cased ``name``, computed once per instance."""
        return self.name.lower()

    def is_same_model_operation(self, operation: FieldOperation) -> bool:
        """Return True when ``operation`` targets the same model (case-insensitive)."""
        return operation.model_name_lower == self.model_name_lower

    def is_same_field_operation(self, operation: FieldOperation) -> bool:
        """Return True when ``operation`` targets the same model and field name.

        Both comparisons use the lower-cased names.
        """
        if not self.is_same_model_operation(operation):
            return False
        return self.name_lower == operation.name_lower

    def references_model(self, name: str, package_label: str) -> bool:
        """Return True when this operation references the model ``name``.

        The operation's own model always counts as referenced. Otherwise, when
        a field is attached, the answer is delegated to ``field_references``.
        """
        target = name.lower()
        if target == self.model_name_lower:
            return True
        if not self.field:
            return False
        found = field_references(
            (package_label, self.model_name_lower),
            self.field,
            (package_label, target),
        )
        return bool(found)

    def references_field(self, model_name: str, name: str, package_label: str) -> bool:
        """Return True when this operation references ``model_name.name``.

        A local reference requires the same model (case-insensitive) and
        exactly the same field name. Otherwise, when a field is attached, the
        answer is delegated to ``field_references``.
        """
        target_model = model_name.lower()
        # Local reference: same model and the very same field name.
        if target_model == self.model_name_lower and name == self.name:
            return True
        # Remote reference: only possible when a field instance is available.
        if self.field is None:
            return False
        found = field_references(
            (package_label, self.model_name_lower),
            self.field,
            (package_label, target_model),
            name,
        )
        return bool(found)

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """Return the base-class reduction if truthy, else a reference check.

        When the parent ``reduce`` yields nothing truthy, the result is True
        exactly when ``operation`` does not reference this operation's field.
        """
        reduced = super().reduce(operation, package_label)
        if reduced:
            return reduced
        return not operation.references_field(
            self.model_name, self.name, package_label
        )
78
79
class AddField(FieldOperation):
    """Add a field to a model."""

    def __init__(
        self, model_name: str, name: str, field: Field, preserve_default: bool = True
    ) -> None:
        """Store the field to add and whether its default is kept afterwards.

        ``preserve_default`` is forwarded to ``state.add_field`` and, when
        False, makes ``database_forwards`` apply the field's default only for
        the duration of the schema change.
        """
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self) -> tuple[str, list[Any], dict[str, Any]]:
        """Return ``(class name, [], kwargs)`` describing this operation.

        ``preserve_default`` is only included when it is not ``True``.
        """
        kwargs: dict[str, Any] = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, [], kwargs)

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Register the new field on ``state`` via ``state.add_field``."""
        state.add_field(
            package_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Add the field through ``schema_editor`` if migration is allowed.

        When ``preserve_default`` is False, the operation's default is placed
        on the target-state field only while ``schema_editor.add_field`` runs
        and is reset to ``NOT_PROVIDED`` afterwards.
        """
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, to_model):
            return
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        field = to_model._model_meta.get_field(self.name)
        assert self.field is not None
        if self.preserve_default:
            schema_editor.add_field(from_model, field)
            return
        field.default = self.field.default
        try:
            schema_editor.add_field(from_model, field)
        finally:
            # Always undo the temporary default, even if the schema change
            # raised, so the state field is never left mutated.
            field.default = NOT_PROVIDED

    def describe(self) -> str:
        """Return a human-readable summary of the operation."""
        return f"Add field {self.name} to {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        """Name fragment built from the lower-cased model and field names."""
        return f"{self.model_name_lower}_{self.name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """Merge this operation with a following operation on the same field.

        * ``AlterField``  -> one ``AddField`` carrying the altered field.
        * ``RemoveField`` -> nothing (the two operations cancel out).
        * ``RenameField`` -> one ``AddField`` using the new name.

        Anything else is deferred to ``FieldOperation.reduce``.
        """
        if isinstance(operation, FieldOperation) and self.is_same_field_operation(
            operation
        ):
            if isinstance(operation, AlterField):
                assert operation.field is not None
                return [
                    AddField(
                        model_name=self.model_name,
                        name=operation.name,
                        field=operation.field,
                        # The altered field decides whether its default is
                        # kept; dropping the flag would turn a one-off default
                        # into a permanent one in the resulting state.
                        preserve_default=operation.preserve_default,
                    ),
                ]
            elif isinstance(operation, RemoveField):
                return []
            elif isinstance(operation, RenameField):
                assert self.field is not None  # AddField always has a field
                return [
                    AddField(
                        model_name=self.model_name,
                        name=operation.new_name,
                        field=self.field,
                        # A rename must not change how the default is handled.
                        preserve_default=self.preserve_default,
                    ),
                ]
        return super().reduce(operation, package_label)
165
166
class RemoveField(FieldOperation):
    """Remove a field from a model."""

    def deconstruct(self) -> tuple[str, list[Any], dict[str, Any]]:
        """Return ``(class name, [], kwargs)`` with the model and field names."""
        kwargs: dict[str, Any] = {}
        kwargs["model_name"] = self.model_name
        kwargs["name"] = self.name
        return (self.__class__.__name__, [], kwargs)

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Drop the field from ``state`` via ``state.remove_field``."""
        state.remove_field(package_label, self.model_name_lower, self.name)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Remove the field through ``schema_editor`` if migration is allowed.

        The model and the field are both looked up in ``from_state``.
        """
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        if not self.allow_migrate_model(schema_editor.connection, from_model):
            return
        doomed_field = from_model._model_meta.get_field(self.name)
        schema_editor.remove_field(from_model, doomed_field)

    def describe(self) -> str:
        """Return a human-readable summary of the operation."""
        return f"Remove field {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        """Name fragment: ``remove_`` plus lower-cased model and field names."""
        return f"remove_{self.model_name_lower}_{self.name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """Collapse into a following ``DeleteModel`` of the same model.

        Any other operation is deferred to ``FieldOperation.reduce``.
        """
        # Imported here rather than at module level, as in the original code.
        from .models import DeleteModel

        deletes_own_model = (
            isinstance(operation, DeleteModel)
            and operation.name_lower == self.model_name_lower
        )
        if deletes_own_model:
            return [operation]
        return super().reduce(operation, package_label)
213
214
class AlterField(FieldOperation):
    """
    Alter a field's database column (e.g. null, max_length) to the provided
    new field.
    """

    def __init__(
        self, model_name: str, name: str, field: Field, preserve_default: bool = True
    ) -> None:
        """Store the replacement field and whether its default is kept.

        ``preserve_default`` is forwarded to ``state.alter_field`` and, when
        False, makes ``database_forwards`` apply the field's default only for
        the duration of the schema change.
        """
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self) -> tuple[str, list[Any], dict[str, Any]]:
        """Return ``(class name, [], kwargs)`` describing this operation.

        ``preserve_default`` is only included when it is not ``True``.
        """
        kwargs: dict[str, Any] = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, [], kwargs)

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Replace the field on ``state`` via ``state.alter_field``."""
        state.alter_field(
            package_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Alter the field through ``schema_editor`` if migration is allowed.

        When ``preserve_default`` is False, the operation's default is placed
        on the target-state field only while ``schema_editor.alter_field``
        runs and is reset to ``NOT_PROVIDED`` afterwards.
        """
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, to_model):
            return
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        from_field = from_model._model_meta.get_field(self.name)
        to_field = to_model._model_meta.get_field(self.name)
        assert self.field is not None
        if self.preserve_default:
            schema_editor.alter_field(from_model, from_field, to_field)
            return
        to_field.default = self.field.default
        try:
            schema_editor.alter_field(from_model, from_field, to_field)
        finally:
            # Always undo the temporary default, even if the schema change
            # raised, so the state field is never left mutated.
            to_field.default = NOT_PROVIDED

    def describe(self) -> str:
        """Return a human-readable summary of the operation."""
        return f"Alter field {self.name} on {self.model_name}"

    @property
    def migration_name_fragment(self) -> str:
        """Name fragment: ``alter_`` plus lower-cased model and field names."""
        return f"alter_{self.model_name_lower}_{self.name_lower}"

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """Merge this operation with a following operation on the same field.

        * ``AlterField`` / ``RemoveField`` -> the later operation alone.
        * ``RenameField`` (only when this field has no explicit ``db_column``)
          -> the rename followed by this alteration under the new name.

        Anything else is deferred to ``FieldOperation.reduce``.
        """
        if isinstance(
            operation, AlterField | RemoveField
        ) and self.is_same_field_operation(operation):
            return [operation]
        elif (
            isinstance(operation, RenameField)
            and self.is_same_field_operation(operation)
            and self.field is not None
            and self.field.db_column is None
        ):
            return [
                operation,
                AlterField(
                    model_name=self.model_name,
                    name=operation.new_name,
                    field=self.field,
                    # Re-targeting the alteration must not change how the
                    # default is handled; dropping the flag would turn a
                    # one-off default into a permanent one in the state.
                    preserve_default=self.preserve_default,
                ),
            ]
        return super().reduce(operation, package_label)
296
297
class RenameField(FieldOperation):
    """Rename a field on the model. Might affect db_column too."""

    def __init__(self, model_name: str, old_name: str, new_name: str) -> None:
        """Store both names; the base class is initialised with ``old_name``."""
        self.old_name = old_name
        self.new_name = new_name
        super().__init__(model_name, old_name)

    @cached_property
    def old_name_lower(self) -> str:
        """Lower-cased ``old_name``, computed once per instance."""
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self) -> str:
        """Lower-cased ``new_name``, computed once per instance."""
        return self.new_name.lower()

    def deconstruct(self) -> tuple[str, list[Any], dict[str, Any]]:
        """Return ``(class name, [], kwargs)`` with the model and both names."""
        kwargs: dict[str, Any] = {}
        kwargs["model_name"] = self.model_name
        kwargs["old_name"] = self.old_name
        kwargs["new_name"] = self.new_name
        return (self.__class__.__name__, [], kwargs)

    def state_forwards(self, package_label: str, state: Any) -> None:
        """Apply the rename to ``state`` via ``state.rename_field``."""
        state.rename_field(
            package_label, self.model_name_lower, self.old_name, self.new_name
        )

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Rename the field through ``schema_editor`` if migration is allowed.

        The rename is issued as ``schema_editor.alter_field`` from the
        old-named field in ``from_state`` to the new-named one in ``to_state``.
        """
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, to_model):
            return
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        old_field = from_model._model_meta.get_field(self.old_name)
        new_field = to_model._model_meta.get_field(self.new_name)
        schema_editor.alter_field(from_model, old_field, new_field)

    def describe(self) -> str:
        """Return a human-readable summary of the operation."""
        return f"Rename field {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self) -> str:
        """Name fragment built from the old name, model name and new name."""
        return f"rename_{self.old_name_lower}_{self.model_name_lower}_{self.new_name_lower}"

    def references_field(self, model_name: str, name: str, package_label: str) -> bool:
        """Return True when ``model_name.name`` is touched by this rename.

        The model must be referenced by this operation and ``name`` must match
        either the old or the new field name, compared case-insensitively.
        """
        if not self.references_model(model_name, package_label):
            return False
        lowered = name.lower()
        return lowered == self.old_name_lower or lowered == self.new_name_lower

    def reduce(
        self, operation: Operation, package_label: str
    ) -> list[Operation] | bool:
        """Chain consecutive renames, otherwise fall back to a reference check.

        A following ``RenameField`` on the same model that starts from this
        operation's new name is merged into one rename from the original old
        name to the final new name.
        """
        chained_rename = (
            isinstance(operation, RenameField)
            and self.is_same_model_operation(operation)
            and self.new_name_lower == operation.old_name_lower
        )
        if chained_rename:
            return [
                RenameField(self.model_name, self.old_name, operation.new_name),
            ]
        # Bypass `FieldOperation.reduce` on purpose: `references_field` has to
        # be checked against both self.old_name and self.new_name.
        reduced = super(FieldOperation, self).reduce(operation, package_label)
        if reduced:
            return reduced
        if operation.references_field(self.model_name, self.old_name, package_label):
            return False
        return not operation.references_field(
            self.model_name, self.new_name, package_label
        )