1from functools import cached_property
2
3from plain.models.fields import NOT_PROVIDED
4from plain.models.migrations.utils import field_references
5
6from .base import Operation
7
8
class FieldOperation(Operation):
    """
    Base class for migration operations that target a single field of a model.

    Holds the target model name, the field name and, optionally, the field
    instance itself. Provides lower-cased name helpers plus the
    ``references_*`` / ``reduce`` hooks shared by the concrete field
    operations.
    """

    def __init__(self, model_name, name, field=None):
        self.model_name = model_name
        self.name = name
        self.field = field

    @cached_property
    def model_name_lower(self):
        """Lower-cased ``model_name`` (computed once, then cached)."""
        return self.model_name.lower()

    @cached_property
    def name_lower(self):
        """Lower-cased field ``name`` (computed once, then cached)."""
        return self.name.lower()

    def is_same_model_operation(self, operation):
        """Return whether ``operation`` targets the same model, ignoring case."""
        return self.model_name_lower == operation.model_name_lower

    def is_same_field_operation(self, operation):
        """Return whether ``operation`` targets the same model and field, ignoring case."""
        if not self.is_same_model_operation(operation):
            return False
        return self.name_lower == operation.name_lower

    def references_model(self, name, package_label):
        """
        Return whether this operation involves the model called ``name``.

        True when ``name`` (case-insensitive) is the model this operation acts
        on. Otherwise, when a field instance is attached, defer to
        ``field_references`` to detect a reference from the field to that
        model.
        """
        target_model = name.lower()
        if target_model == self.model_name_lower:
            return True
        if not self.field:
            return False
        return bool(
            field_references(
                (package_label, self.model_name_lower),
                self.field,
                (package_label, target_model),
            )
        )

    def references_field(self, model_name, name, package_label):
        """
        Return whether this operation involves field ``name`` on ``model_name``.

        A direct hit is the same model (case-insensitive) with an exactly
        equal field name. Otherwise, when a field instance is attached,
        ``field_references`` decides whether the field points at the given
        model/field.
        """
        target_model = model_name.lower()
        # Local reference: this operation acts on that very field.
        if target_model == self.model_name_lower and name == self.name:
            return True
        # Remote reference: only possible if we carry a field definition.
        if self.field is None:
            return False
        return bool(
            field_references(
                (package_label, self.model_name_lower),
                self.field,
                (package_label, target_model),
                name,
            )
        )

    def reduce(self, operation, package_label):
        """
        Return the base class's result when it is truthy.

        Otherwise return whether ``operation`` does NOT reference this
        operation's field.
        """
        inherited = super().reduce(operation, package_label)
        if inherited:
            return inherited
        return not operation.references_field(
            self.model_name, self.name, package_label
        )
68
69
class AddField(FieldOperation):
    """
    Add a field to a model.

    ``preserve_default`` controls what happens to ``field.default``.
    - It is forwarded to ``state.add_field``.
    - When it is False, ``database_forwards`` installs ``self.field.default``
      on the to-state field only while the schema editor adds the field.
      The field's default is then reset to ``NOT_PROVIDED``.
    """

    def __init__(self, model_name, name, field, preserve_default=True):
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self):
        """
        Return ``(class name, args, kwargs)`` describing this operation.

        ``preserve_default`` is only included when it is not ``True``.
        """
        kwargs = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, [], kwargs)

    def state_forwards(self, package_label, state):
        """Record the new field in the migration ``state``."""
        state.add_field(
            package_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """
        Ask ``schema_editor`` to add the to-state field to the from-state model.

        Does nothing when ``allow_migrate_model`` rejects the model.
        """
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, to_model):
            return
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        field = to_model._meta.get_field(self.name)
        if not self.preserve_default:
            # Temporarily expose the one-off default to the schema editor
            # (presumably used to fill existing rows — confirm per backend).
            field.default = self.field.default
        try:
            schema_editor.add_field(
                from_model,
                field,
            )
        finally:
            # Always undo the temporary default, even if add_field raised, so
            # the to-state field is not left carrying it.
            if not self.preserve_default:
                field.default = NOT_PROVIDED

    def describe(self):
        return f"Add field {self.name} to {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"{self.model_name_lower}_{self.name_lower}"

    def reduce(self, operation, package_label):
        """
        Merge a later operation on the same field into this AddField.

        Cases:
        - AlterField: becomes an AddField with the altered field definition.
        - RemoveField: the add/remove pair cancels out, so an empty list is
          returned.
        - RenameField: becomes an AddField under the new name. It keeps this
          operation's field and its ``preserve_default``.

        Anything else falls back to ``FieldOperation.reduce``.
        """
        if isinstance(operation, FieldOperation) and self.is_same_field_operation(
            operation
        ):
            if isinstance(operation, AlterField):
                # NOTE(review): the merged AddField uses the default
                # preserve_default=True; confirm that is intended when either
                # side carried a one-off default.
                return [
                    AddField(
                        model_name=self.model_name,
                        name=operation.name,
                        field=operation.field,
                    ),
                ]
            elif isinstance(operation, RemoveField):
                return []
            elif isinstance(operation, RenameField):
                # The field definition is unchanged by a rename, so its
                # one-off-default semantics must be carried over too.
                # Otherwise the squashed state would keep a default the
                # original migrations discarded.
                return [
                    AddField(
                        model_name=self.model_name,
                        name=operation.new_name,
                        field=self.field,
                        preserve_default=self.preserve_default,
                    ),
                ]
        return super().reduce(operation, package_label)
142
143
class RemoveField(FieldOperation):
    """Migration operation that deletes one field from a model."""

    def deconstruct(self):
        """Return ``(class name, args, kwargs)``; no field instance is needed."""
        return (
            self.__class__.__name__,
            [],
            {"model_name": self.model_name, "name": self.name},
        )

    def state_forwards(self, package_label, state):
        """Drop the field from the migration ``state``."""
        state.remove_field(package_label, self.model_name_lower, self.name)

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """
        Ask ``schema_editor`` to remove the field, looked up on the from-state model.

        Does nothing when ``allow_migrate_model`` rejects the model.
        """
        model = from_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, model):
            return
        doomed_field = model._meta.get_field(self.name)
        schema_editor.remove_field(model, doomed_field)

    def describe(self):
        return f"Remove field {self.name} from {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"remove_{self.model_name_lower}_{self.name_lower}"

    def reduce(self, operation, package_label):
        """
        Collapse into a later ``DeleteModel`` of the owning model.

        Otherwise defer to ``FieldOperation.reduce``.
        """
        from .models import DeleteModel

        owner_deleted = (
            isinstance(operation, DeleteModel)
            and operation.name_lower == self.model_name_lower
        )
        if owner_deleted:
            return [operation]
        return super().reduce(operation, package_label)
182
183
class AlterField(FieldOperation):
    """
    Alter a field's database column (e.g. null, max_length) to the provided
    new field.

    ``preserve_default`` controls what happens to ``field.default``.
    - It is forwarded to ``state.alter_field``.
    - When it is False, ``database_forwards`` installs ``self.field.default``
      on the to-state field only while the schema editor performs the
      alteration. The field's default is then reset to ``NOT_PROVIDED``.
    """

    def __init__(self, model_name, name, field, preserve_default=True):
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self):
        """
        Return ``(class name, args, kwargs)`` describing this operation.

        ``preserve_default`` is only included when it is not ``True``.
        """
        kwargs = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, [], kwargs)

    def state_forwards(self, package_label, state):
        """Replace the field definition in the migration ``state``."""
        state.alter_field(
            package_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """
        Ask ``schema_editor`` to alter the field from its from-state to its to-state form.

        Does nothing when ``allow_migrate_model`` rejects the model.
        """
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, to_model):
            return
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        from_field = from_model._meta.get_field(self.name)
        to_field = to_model._meta.get_field(self.name)
        if not self.preserve_default:
            # Temporarily expose the one-off default to the schema editor.
            to_field.default = self.field.default
        try:
            schema_editor.alter_field(from_model, from_field, to_field)
        finally:
            # Always undo the temporary default, even if alter_field raised.
            if not self.preserve_default:
                to_field.default = NOT_PROVIDED

    def describe(self):
        return f"Alter field {self.name} on {self.model_name}"

    @property
    def migration_name_fragment(self):
        return f"alter_{self.model_name_lower}_{self.name_lower}"

    def reduce(self, operation, package_label):
        """
        Merge with a later operation on the same field.

        Cases:
        - A later AlterField or RemoveField on the same field supersedes this
          one.
        - A later RenameField (only when this field has no explicit
          ``db_column``) is moved first. This alteration is then re-applied
          under the new name, keeping ``preserve_default``.

        Anything else falls back to ``FieldOperation.reduce``.
        """
        if isinstance(
            operation, AlterField | RemoveField
        ) and self.is_same_field_operation(operation):
            return [operation]
        elif (
            isinstance(operation, RenameField)
            and self.is_same_field_operation(operation)
            and self.field.db_column is None
        ):
            return [
                operation,
                # Carry preserve_default so a one-off default does not become
                # permanent in the optimized migration state.
                AlterField(
                    model_name=self.model_name,
                    name=operation.new_name,
                    field=self.field,
                    preserve_default=self.preserve_default,
                ),
            ]
        return super().reduce(operation, package_label)
253
254
class RenameField(FieldOperation):
    """
    Rename a field on the model. Might affect db_column too.

    The base ``name`` is the old name, and no field instance is stored.
    """

    def __init__(self, model_name, old_name, new_name):
        self.old_name = old_name
        self.new_name = new_name
        super().__init__(model_name, old_name)

    @cached_property
    def old_name_lower(self):
        """Lower-cased previous field name (cached)."""
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self):
        """Lower-cased new field name (cached)."""
        return self.new_name.lower()

    def deconstruct(self):
        """Return ``(class name, args, kwargs)`` describing this rename."""
        return (
            self.__class__.__name__,
            [],
            {
                "model_name": self.model_name,
                "old_name": self.old_name,
                "new_name": self.new_name,
            },
        )

    def state_forwards(self, package_label, state):
        """Rename the field in the migration ``state``."""
        state.rename_field(
            package_label, self.model_name_lower, self.old_name, self.new_name
        )

    def database_forwards(self, package_label, schema_editor, from_state, to_state):
        """
        Ask ``schema_editor`` to alter the old-named field into the new-named one.

        Does nothing when ``allow_migrate_model`` rejects the model.
        """
        to_model = to_state.models_registry.get_model(package_label, self.model_name)
        if not self.allow_migrate_model(schema_editor.connection, to_model):
            return
        from_model = from_state.models_registry.get_model(
            package_label, self.model_name
        )
        before = from_model._meta.get_field(self.old_name)
        after = to_model._meta.get_field(self.new_name)
        schema_editor.alter_field(from_model, before, after)

    def describe(self):
        return f"Rename field {self.old_name} on {self.model_name} to {self.new_name}"

    @property
    def migration_name_fragment(self):
        return "_".join(
            ("rename", self.old_name_lower, self.model_name_lower, self.new_name_lower)
        )

    def references_field(self, model_name, name, package_label):
        """
        Return whether this rename involves ``name`` on ``model_name``.

        True when the model is referenced and ``name`` matches either the old
        or the new field name, ignoring case.
        """
        if not self.references_model(model_name, package_label):
            return False
        return name.lower() in (self.old_name_lower, self.new_name_lower)

    def reduce(self, operation, package_label):
        """
        Merge with a later operation, if possible.

        Two consecutive renames (A -> B, then B -> C) on the same model
        collapse into a single A -> C rename.
        """
        chains_onto_us = (
            isinstance(operation, RenameField)
            and self.is_same_model_operation(operation)
            and self.new_name_lower == operation.old_name_lower
        )
        if chains_onto_us:
            return [RenameField(self.model_name, self.old_name, operation.new_name)]
        # Skip `FieldOperation.reduce` as we want to run `references_field`
        # against self.old_name and self.new_name.
        inherited = super(FieldOperation, self).reduce(operation, package_label)
        if inherited:
            return inherited
        if operation.references_field(self.model_name, self.old_name, package_label):
            return False
        return not operation.references_field(
            self.model_name, self.new_name, package_label
        )