1from plain.models.fields import NOT_PROVIDED
2from plain.models.migrations.utils import field_references
3from plain.utils.functional import cached_property
4
5from .base import Operation
6
7
class FieldOperation(Operation):
    """Base class for migration operations that act on a single model field.

    Args:
        model_name: Name of the model that owns the field.
        name: Name of the field the operation targets.
        field: The field instance carried by the operation, or None when the
            operation does not carry one.
    """

    def __init__(self, model_name, name, field=None):
        self.model_name = model_name
        self.name = name
        self.field = field

    @cached_property
    def model_name_lower(self):
        """``model_name`` lower-cased, for case-insensitive comparisons."""
        return self.model_name.lower()

    @cached_property
    def name_lower(self):
        """Field ``name`` lower-cased, for case-insensitive comparisons."""
        return self.name.lower()

    def is_same_model_operation(self, operation):
        """Return True when ``operation`` targets the same model (ignoring case)."""
        return self.model_name_lower == operation.model_name_lower

    def is_same_field_operation(self, operation):
        """Return True when ``operation`` targets the same model and field name."""
        if not self.is_same_model_operation(operation):
            return False
        return self.name_lower == operation.name_lower

    def references_model(self, name, package_label):
        """Return True if this operation refers to the model called ``name``.

        That is the case when the operation targets that model directly, or
        when ``field_references`` reports that this operation's field points
        at it. Both models are looked up under the same ``package_label``.
        """
        target_model = name.lower()
        if target_model == self.model_name_lower:
            return True
        if not self.field:
            return False
        found = field_references(
            (package_label, self.model_name_lower),
            self.field,
            (package_label, target_model),
        )
        return bool(found)

    def references_field(self, model_name, name, package_label):
        """Return True if this operation refers to field ``name`` of ``model_name``.

        A local reference is either the targeted field itself (exact, case-
        sensitive name match) or a name listed in the field's ``from_fields``
        when that attribute exists. Otherwise ``field_references`` is asked
        whether this operation's field points at the given model/field.
        """
        target_model = model_name.lower()
        # Local reference: same model, and either the very field this
        # operation acts on or one of the field's ``from_fields``.
        if target_model == self.model_name_lower:
            if name == self.name:
                return True
            if self.field and name in getattr(self.field, "from_fields", ()):
                return True
        # Remote reference: only possible when a field is attached.
        if self.field is None:
            return False
        found = field_references(
            (package_label, self.model_name_lower),
            self.field,
            (package_label, target_model),
            name,
        )
        return bool(found)

    def reduce(self, operation, package_label):
        """Return the base-class reduction if it is truthy.

        Otherwise return True exactly when ``operation`` does not reference
        this operation's field. NOTE(review): the meaning of a bare True is
        defined by ``Operation.reduce`` (not visible here) — presumably "the
        two operations may be optimized across each other"; confirm.
        """
        inherited = super().reduce(operation, package_label)
        if inherited:
            return inherited
        return not operation.references_field(
            self.model_name, self.name, package_label
        )
73
74
class AddField(FieldOperation):
    """Add a field to a model.

    Args:
        model_name: Name of the model receiving the field.
        name: Name of the new field.
        field: The field instance to add.
        preserve_default: When False, ``field.default`` is treated as a
            one-off value. ``database_forwards`` puts it on the field only
            while ``schema_editor.add_field`` runs, then resets it to
            ``NOT_PROVIDED``. The flag is also forwarded to
            ``state.add_field``.
    """

    def __init__(self, model_name, name, field, preserve_default=True):
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self):
        """Return ``(class_name, args, kwargs)`` describing this operation.

        ``preserve_default`` is only included when it is not ``True``.
        """
        kwargs = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, [], kwargs)

    def state_forwards(self, package_label, state):
        """Record the new field on the project state."""
        state.add_field(
            package_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Add the field to the database table of the ``from_state`` model."""
        to_model = to_state.packages.get_model(package_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection.alias, to_model):
            from_model = from_state.packages.get_model(package_label, self.model_name)
            field = to_model._meta.get_field(self.name)
            if not self.preserve_default:
                # Temporarily put the one-off default on the field instance
                # that is handed to the schema editor.
                field.default = self.field.default
            try:
                schema_editor.add_field(from_model, field)
            finally:
                # Undo the temporary default even if add_field raised, so the
                # field taken from to_state is never left holding it.
                if not self.preserve_default:
                    field.default = NOT_PROVIDED

    def database_backwards(self, package_label, schema_editor, from_state, to_state):
        """Remove the field again, using the model from ``from_state``."""
        from_model = from_state.packages.get_model(package_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection.alias, from_model):
            schema_editor.remove_field(
                from_model, from_model._meta.get_field(self.name)
            )

    def describe(self):
        """Return a human-readable summary of the operation."""
        return f"Add field {self.name} to {self.model_name}"

    @property
    def migration_name_fragment(self):
        """Fragment used when generating a migration file name."""
        return f"{self.model_name_lower}_{self.name_lower}"

    def reduce(self, operation, package_label):
        """Try to merge a following operation on the same field into this one.

        Behavior by type of ``operation``:

        * ``AlterField``: collapses into a single ``AddField`` of the altered field.
        * ``RemoveField``: cancels this operation out (empty list).
        * ``RenameField``: becomes an ``AddField`` under the new name.
        * Anything else: falls back to ``FieldOperation.reduce``.
        """
        if isinstance(operation, FieldOperation) and self.is_same_field_operation(
            operation
        ):
            if isinstance(operation, AlterField):
                # NOTE(review): preserve_default from neither operation is
                # carried over here; confirm this is intended.
                return [
                    AddField(
                        model_name=self.model_name,
                        name=operation.name,
                        field=operation.field,
                    ),
                ]
            elif isinstance(operation, RemoveField):
                return []
            elif isinstance(operation, RenameField):
                # The field itself is unchanged by a rename, so keep this
                # operation's preserve_default. Dropping it would pass a
                # different flag to state.add_field / database_forwards than
                # the unreduced sequence does.
                return [
                    AddField(
                        model_name=self.model_name,
                        name=operation.new_name,
                        field=self.field,
                        preserve_default=self.preserve_default,
                    ),
                ]
        return super().reduce(operation, package_label)
152
153
class RemoveField(FieldOperation):
    """Migration operation that drops a field from a model."""

    def deconstruct(self):
        """Return ``(class_name, args, kwargs)`` describing this operation."""
        return (
            type(self).__name__,
            [],
            {"model_name": self.model_name, "name": self.name},
        )

    def state_forwards(self, package_label, state):
        """Drop the field from the project state."""
        state.remove_field(package_label, self.model_name_lower, self.name)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Drop the field's column using the model from ``from_state``."""
        current_model = from_state.packages.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection.alias, current_model):
            return
        doomed_field = current_model._meta.get_field(self.name)
        schema_editor.remove_field(current_model, doomed_field)

    def database_backwards(self, package_label, schema_editor, from_state, to_state):
        """Re-add the field to the ``from_state`` model, as defined in ``to_state``."""
        target_model = to_state.packages.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection.alias, target_model):
            return
        current_model = from_state.packages.get_model(package_label, self.model_name)
        restored_field = target_model._meta.get_field(self.name)
        schema_editor.add_field(current_model, restored_field)

    def describe(self):
        """Return a human-readable summary of the operation."""
        return f"Remove field {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self):
        """Fragment used when generating a migration file name."""
        return f"remove_{self.model_name_lower}_{self.name_lower}"

    def reduce(self, operation, package_label):
        """Fold into a following ``DeleteModel`` of the same model.

        Any other operation falls back to ``FieldOperation.reduce``.
        """
        from .models import DeleteModel

        deletes_owner = (
            isinstance(operation, DeleteModel)
            and operation.name_lower == self.model_name_lower
        )
        if deletes_owner:
            # Dropping the whole model makes removing one of its fields moot.
            return [operation]
        return super().reduce(operation, package_label)
196
197
class AlterField(FieldOperation):
    """
    Alter a field's database column (e.g. null, max_length) to the provided
    new field.

    Args:
        model_name: Name of the model that owns the field.
        name: Name of the field being altered.
        field: The new field definition.
        preserve_default: When False, ``field.default`` is treated as a
            one-off value. ``database_forwards`` puts it on the target field
            only while ``schema_editor.alter_field`` runs, then resets it to
            ``NOT_PROVIDED``. The flag is also forwarded to
            ``state.alter_field``.
    """

    def __init__(self, model_name, name, field, preserve_default=True):
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self):
        """Return ``(class_name, args, kwargs)`` describing this operation.

        ``preserve_default`` is only included when it is not ``True``.
        """
        kwargs = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, [], kwargs)

    def state_forwards(self, package_label, state):
        """Replace the field definition in the project state."""
        state.alter_field(
            package_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Alter the column from its ``from_state`` form to its ``to_state`` form."""
        to_model = to_state.packages.get_model(package_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection.alias, to_model):
            from_model = from_state.packages.get_model(package_label, self.model_name)
            from_field = from_model._meta.get_field(self.name)
            to_field = to_model._meta.get_field(self.name)
            if not self.preserve_default:
                # Temporarily put the one-off default on the target field
                # handed to the schema editor.
                to_field.default = self.field.default
            try:
                schema_editor.alter_field(from_model, from_field, to_field)
            finally:
                # Undo the temporary default even if alter_field raised, so
                # the field taken from to_state is never left holding it.
                if not self.preserve_default:
                    to_field.default = NOT_PROVIDED

    def database_backwards(self, package_label, schema_editor, from_state, to_state):
        """Reverse the alteration by running ``database_forwards``.

        The states are passed through unchanged, so the caller-supplied states
        decide the direction of the change.
        """
        self.database_forwards(package_label, schema_editor, from_state, to_state)

    def describe(self):
        """Return a human-readable summary of the operation."""
        return f"Alter field {self.name} on {self.model_name}"

    @property
    def migration_name_fragment(self):
        """Fragment used when generating a migration file name."""
        return f"alter_{self.model_name_lower}_{self.name_lower}"

    def reduce(self, operation, package_label):
        """Try to merge a following operation on the same field into this one.

        * A later ``AlterField`` or ``RemoveField`` on the same field
          supersedes this one.
        * A ``RenameField`` of the same field is moved in front of this
          alteration, provided the field has no explicit ``db_column``.
        * Anything else falls back to ``FieldOperation.reduce``.
        """
        if isinstance(
            operation, AlterField | RemoveField
        ) and self.is_same_field_operation(operation):
            return [operation]
        elif (
            isinstance(operation, RenameField)
            and self.is_same_field_operation(operation)
            and self.field.db_column is None
        ):
            # Keep preserve_default: the alteration is the same, only the
            # field name differs. Dropping it would pass a different flag to
            # state.alter_field / database_forwards than the unreduced
            # sequence does.
            return [
                operation,
                AlterField(
                    model_name=self.model_name,
                    name=operation.new_name,
                    field=self.field,
                    preserve_default=self.preserve_default,
                ),
            ]
        return super().reduce(operation, package_label)
268
269
class RenameField(FieldOperation):
    """Rename a field on a model; the database column may change as well.

    Args:
        model_name: Name of the model that owns the field.
        old_name: Current name of the field.
        new_name: Name the field is renamed to.
    """

    def __init__(self, model_name, old_name, new_name):
        self.old_name = old_name
        self.new_name = new_name
        # The base class tracks the old name as ``name``; no field is attached.
        super().__init__(model_name, old_name)

    @cached_property
    def old_name_lower(self):
        """``old_name`` lower-cased."""
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self):
        """``new_name`` lower-cased."""
        return self.new_name.lower()

    def deconstruct(self):
        """Return ``(class_name, args, kwargs)`` describing this operation."""
        return (
            self.__class__.__name__,
            [],
            {
                "model_name": self.model_name,
                "old_name": self.old_name,
                "new_name": self.new_name,
            },
        )

    def state_forwards(self, package_label, state):
        """Rename the field in the project state."""
        state.rename_field(
            package_label, self.model_name_lower, self.old_name, self.new_name
        )

    def _alter_between(
        self, package_label, schema_editor, from_state, to_state, source, target
    ):
        """Alter field ``source`` of the from_state model into ``target`` of the to_state model."""
        target_model = to_state.packages.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection.alias, target_model):
            return
        source_model = from_state.packages.get_model(package_label, self.model_name)
        schema_editor.alter_field(
            source_model,
            source_model._meta.get_field(source),
            target_model._meta.get_field(target),
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """Apply the rename (old name -> new name) on the database."""
        self._alter_between(
            package_label, schema_editor, from_state, to_state,
            self.old_name, self.new_name,
        )

    def database_backwards(self, package_label, schema_editor, from_state, to_state):
        """Undo the rename (new name -> old name) on the database."""
        self._alter_between(
            package_label, schema_editor, from_state, to_state,
            self.new_name, self.old_name,
        )

    def describe(self):
        """Return a human-readable summary of the operation."""
        return f"Rename field {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self):
        """Fragment used when generating a migration file name."""
        return (
            f"rename_{self.old_name_lower}_{self.model_name_lower}_{self.new_name_lower}"
        )

    def references_field(self, model_name, name, package_label):
        """Return True if ``name`` (either the old or the new name) is on a referenced model."""
        if not self.references_model(model_name, package_label):
            return False
        return name.lower() in (self.old_name_lower, self.new_name_lower)

    def reduce(self, operation, package_label):
        """Merge chained renames; otherwise defer to ``Operation.reduce``.

        A following rename on the same model whose old name equals this
        rename's new name collapses into a single rename.
        """
        chained = (
            isinstance(operation, RenameField)
            and self.is_same_model_operation(operation)
            and self.new_name_lower == operation.old_name_lower
        )
        if chained:
            return [RenameField(self.model_name, self.old_name, operation.new_name)]
        # Skip `FieldOperation.reduce` as we want to run `references_field`
        # against self.old_name and self.new_name.
        inherited = super(FieldOperation, self).reduce(operation, package_label)
        if inherited:
            return inherited
        touches_field = operation.references_field(
            self.model_name, self.old_name, package_label
        ) or operation.references_field(self.model_name, self.new_name, package_label)
        return not touches_field