1from functools import cached_property
2
3from plain.models.fields import NOT_PROVIDED
4from plain.models.migrations.utils import field_references
5
6from .base import Operation
7
8
class FieldOperation(Operation):
    """
    Common base for migration operations that act on one field of one model.

    Holds the target model name, the field name and, optionally, the field
    instance itself, and implements the reference checks and the default
    reduction rule shared by the concrete field operations below.
    """

    def __init__(self, model_name, name, field=None):
        self.model_name = model_name
        self.name = name
        self.field = field

    @cached_property
    def model_name_lower(self):
        """The model name lowercased, for case-insensitive comparisons."""
        return self.model_name.lower()

    @cached_property
    def name_lower(self):
        """The field name lowercased, for case-insensitive comparisons."""
        return self.name.lower()

    def is_same_model_operation(self, operation):
        """Return True when `operation` targets the same model as this one."""
        return self.model_name_lower == operation.model_name_lower

    def is_same_field_operation(self, operation):
        """Return True when `operation` targets the same model AND field name."""
        if not self.is_same_model_operation(operation):
            return False
        return self.name_lower == operation.name_lower

    def references_model(self, name, package_label):
        """
        Return True if this operation involves the model called `name`.

        That is the case when the model is this operation's own target, or when
        `field_references()` reports that the attached field refers to it.
        """
        target = name.lower()
        if target == self.model_name_lower:
            return True
        if not self.field:
            return False
        return bool(
            field_references(
                (package_label, self.model_name_lower),
                self.field,
                (package_label, target),
            )
        )

    def references_field(self, model_name, name, package_label):
        """
        Return True if this operation involves field `name` on `model_name`.

        Checks local references first (same model, and either the same field
        name or a name listed in the field's `from_fields`, when present), then
        defers to `field_references()` for references made by the attached
        field. With no field attached only the local check can succeed.
        """
        target_model = model_name.lower()
        if target_model == self.model_name_lower:
            if name == self.name:
                return True
            # Not every field type defines `from_fields`; treat absence as empty.
            if self.field and name in getattr(self.field, "from_fields", ()):
                return True
        if self.field is None:
            return False
        return bool(
            field_references(
                (package_label, self.model_name_lower),
                self.field,
                (package_label, target_model),
                name,
            )
        )

    def reduce(self, operation, package_label):
        """
        Defer to `Operation.reduce`; if that yields nothing truthy, report
        whether `operation` can be optimized across this one, which is allowed
        only when it does not reference this operation's field.
        """
        result = super().reduce(operation, package_label)
        if result:
            return result
        return not operation.references_field(
            self.model_name, self.name, package_label
        )
74
75
class AddField(FieldOperation):
    """
    Add a field to a model.

    `preserve_default` is passed through to `state.add_field`. When it is
    False, `database_forwards` puts `self.field.default` on the field only for
    the duration of the schema change and resets it to NOT_PROVIDED afterwards.
    """

    def __init__(self, model_name, name, field, preserve_default=True):
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self):
        """Return (class name, args, kwargs); omit preserve_default when True."""
        kwargs = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, [], kwargs)

    def state_forwards(self, package_label, state):
        state.add_field(
            package_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Add the column via the schema editor if migrating this model is allowed."""
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, to_model):
            return
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        field = to_model._meta.get_field(self.name)
        if not self.preserve_default:
            # Expose this operation's default to schema_editor.add_field only.
            field.default = self.field.default
        try:
            schema_editor.add_field(
                from_model,
                field,
            )
        finally:
            # Always undo the temporary default, even if add_field raised, so
            # the to-state model is not left carrying it.
            if not self.preserve_default:
                field.default = NOT_PROVIDED

    def describe(self):
        return f"Add field {self.name} to {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"{self.model_name_lower}_{self.name_lower}"

    def reduce(self, operation, package_label):
        """
        Fold a following operation on the same field into this AddField.

        - AlterField: add the altered field directly, keeping the AlterField's
          preserve_default so the resulting state matches the original pair.
        - RemoveField: the add and remove cancel out.
        - RenameField: add the field under its new name, keeping this
          operation's preserve_default.
        """
        if isinstance(operation, FieldOperation) and self.is_same_field_operation(
            operation
        ):
            if isinstance(operation, AlterField):
                return [
                    AddField(
                        model_name=self.model_name,
                        name=operation.name,
                        field=operation.field,
                        preserve_default=operation.preserve_default,
                    ),
                ]
            elif isinstance(operation, RemoveField):
                return []
            elif isinstance(operation, RenameField):
                return [
                    AddField(
                        model_name=self.model_name,
                        name=operation.new_name,
                        field=self.field,
                        preserve_default=self.preserve_default,
                    ),
                ]
        return super().reduce(operation, package_label)
148
149
class RemoveField(FieldOperation):
    """
    Remove a field from a model.

    Only the model name and field name are recorded; no field instance is
    attached to this operation.
    """

    def deconstruct(self):
        """Return (class name, args, kwargs) describing this operation."""
        return (
            self.__class__.__name__,
            [],
            {"model_name": self.model_name, "name": self.name},
        )

    def state_forwards(self, package_label, state):
        state.remove_field(package_label, self.model_name_lower, self.name)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Drop the field, looked up on the from-state model, if allowed."""
        model = from_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        doomed_field = model._meta.get_field(self.name)
        schema_editor.remove_field(model, doomed_field)

    def describe(self):
        return f"Remove field {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"remove_{self.model_name_lower}_{self.name_lower}"

    def reduce(self, operation, package_label):
        """Collapse into a following DeleteModel of the same model."""
        from .models import DeleteModel

        deletes_this_model = (
            isinstance(operation, DeleteModel)
            and operation.name_lower == self.model_name_lower
        )
        if deletes_this_model:
            # Dropping the whole model makes removing one of its fields redundant.
            return [operation]
        return super().reduce(operation, package_label)
188
189
class AlterField(FieldOperation):
    """
    Alter a field's database column (e.g. null, max_length) to the provided
    new field.

    `preserve_default` is passed through to `state.alter_field`. When it is
    False, `database_forwards` puts `self.field.default` on the new field only
    for the duration of the schema change and resets it to NOT_PROVIDED.
    """

    def __init__(self, model_name, name, field, preserve_default=True):
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self):
        """Return (class name, args, kwargs); omit preserve_default when True."""
        kwargs = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, [], kwargs)

    def state_forwards(self, package_label, state):
        state.alter_field(
            package_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Alter the column from the from-state field to the to-state field."""
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, to_model):
            return
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        from_field = from_model._meta.get_field(self.name)
        to_field = to_model._meta.get_field(self.name)
        if not self.preserve_default:
            # Expose this operation's default to schema_editor.alter_field only.
            to_field.default = self.field.default
        try:
            schema_editor.alter_field(from_model, from_field, to_field)
        finally:
            # Always undo the temporary default, even if alter_field raised.
            if not self.preserve_default:
                to_field.default = NOT_PROVIDED

    def describe(self):
        return f"Alter field {self.name} on {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"alter_{self.model_name_lower}_{self.name_lower}"

    def reduce(self, operation, package_label):
        """
        Combine with a following operation on the same field.

        - AlterField / RemoveField: the later operation supersedes this one.
        - RenameField (only when this field has no explicit db_column): rename
          first, then apply this same alteration under the new name, keeping
          this operation's preserve_default.
        """
        if isinstance(
            operation, AlterField | RemoveField
        ) and self.is_same_field_operation(operation):
            return [operation]
        elif (
            isinstance(operation, RenameField)
            and self.is_same_field_operation(operation)
            and self.field.db_column is None
        ):
            return [
                operation,
                AlterField(
                    model_name=self.model_name,
                    name=operation.new_name,
                    field=self.field,
                    preserve_default=self.preserve_default,
                ),
            ]
        return super().reduce(operation, package_label)
259
260
class RenameField(FieldOperation):
    """
    Rename a field on the model. Might affect db_column too.

    The inherited `name` attribute holds the old name; no field instance is
    attached to this operation.
    """

    def __init__(self, model_name, old_name, new_name):
        self.old_name = old_name
        self.new_name = new_name
        super().__init__(model_name, old_name)

    @cached_property
    def old_name_lower(self):
        """The pre-rename field name, lowercased."""
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self):
        """The post-rename field name, lowercased."""
        return self.new_name.lower()

    def deconstruct(self):
        """Return (class name, args, kwargs) describing this operation."""
        return (
            self.__class__.__name__,
            [],
            {
                "model_name": self.model_name,
                "old_name": self.old_name,
                "new_name": self.new_name,
            },
        )

    def state_forwards(self, package_label, state):
        state.rename_field(
            package_label, self.model_name_lower, self.old_name, self.new_name
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Alter the old-named field into the new-named field, if allowed."""
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, to_model):
            return
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        before = from_model._meta.get_field(self.old_name)
        after = to_model._meta.get_field(self.new_name)
        schema_editor.alter_field(from_model, before, after)

    def describe(self):
        return f"Rename field {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self):
        return f"rename_{self.old_name_lower}_{self.model_name_lower}_{self.new_name_lower}"

    def references_field(self, model_name, name, package_label):
        """True if the model is referenced and `name` is the old or new name."""
        if not self.references_model(model_name, package_label):
            return False
        return name.lower() in (self.old_name_lower, self.new_name_lower)

    def reduce(self, operation, package_label):
        """
        Chain consecutive renames (A->B then B->C becomes A->C); otherwise
        check the other operation against both the old and the new name.
        """
        chains_onto_this_rename = (
            isinstance(operation, RenameField)
            and self.is_same_model_operation(operation)
            and self.new_name_lower == operation.old_name_lower
        )
        if chains_onto_this_rename:
            return [RenameField(self.model_name, self.old_name, operation.new_name)]
        # Deliberately bypass FieldOperation.reduce: it only checks self.name,
        # but a rename must be checked against both old_name and new_name.
        result = super(FieldOperation, self).reduce(operation, package_label)
        if result:
            return result
        return not (
            operation.references_field(self.model_name, self.old_name, package_label)
            or operation.references_field(self.model_name, self.new_name, package_label)
        )