1import copy
2from collections import defaultdict
3from contextlib import contextmanager
4from functools import partial
5
6from plain import models
7from plain.exceptions import FieldDoesNotExist
8from plain.models.fields import NOT_PROVIDED
9from plain.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
10from plain.models.migrations.utils import field_is_referenced, get_references
11from plain.models.options import DEFAULT_NAMES
12from plain.models.utils import make_model_tuple
13from plain.packages import PackageConfig
14from plain.packages.registry import Packages
15from plain.packages.registry import packages as global_packages
16from plain.runtime import settings
17from plain.utils.functional import cached_property
18from plain.utils.module_loading import import_string
19
20from .exceptions import InvalidBasesError
21from .utils import resolve_relation
22
23
24def _get_package_label_and_model_name(model, package_label=""):
25 if isinstance(model, str):
26 split = model.split(".", 1)
27 return tuple(split) if len(split) == 2 else (package_label, split[0])
28 else:
29 return model._meta.package_label, model._meta.model_name
30
31
32def _get_related_models(m):
33 """Return all models that have a direct relationship to the given model."""
34 related_models = [
35 subclass
36 for subclass in m.__subclasses__()
37 if issubclass(subclass, models.Model)
38 ]
39 related_fields_models = set()
40 for f in m._meta.get_fields(include_parents=True, include_hidden=True):
41 if (
42 f.is_relation
43 and f.related_model is not None
44 and not isinstance(f.related_model, str)
45 ):
46 related_fields_models.add(f.model)
47 related_models.append(f.related_model)
48 return related_models
49
50
def get_related_models_tuples(model):
    """
    Return the set of (package_label, model_name) tuples identifying every
    model directly related to the given model.
    """
    related_keys = set()
    for related in _get_related_models(model):
        opts = related._meta
        related_keys.add((opts.package_label, opts.model_name))
    return related_keys
60
61
def get_related_models_recursive(model):
    """
    Return the (package_label, model_name) keys of all models that have a
    direct or indirect relationship to the given model.

    Relationships are either defined by explicit relational fields, like
    ForeignKey, ManyToManyField or OneToOneField, or by inheriting from another
    model (a superclass is related to its subclasses, but not vice versa). Note,
    however, that a model inheriting from a concrete model is also related to
    its superclass through the implicit *_ptr OneToOneField on the subclass.

    The given model's own key is never part of the returned set.
    """
    visited = set()
    pending = _get_related_models(model)
    position = 0
    # ``pending`` keeps growing while it is walked, so advance by index
    # until the end of the (extended) list is reached.
    while position < len(pending):
        candidate = pending[position]
        position += 1
        key = (candidate._meta.package_label, candidate._meta.model_name)
        if key not in visited:
            visited.add(key)
            pending.extend(_get_related_models(candidate))
    own_key = (model._meta.package_label, model._meta.model_name)
    return visited - {own_key}
85
86
class ProjectState:
    """
    Represent the entire project's overall state. This is the item that is
    passed around - do it here rather than at the app level so that cross-app
    FKs/etc. resolve properly.

    ``models`` maps ``(package_label, model_name_lower)`` keys to model states.
    Rendered model classes live in the lazily created ``packages`` registry;
    most mutators only touch that registry when it has already been created.
    """

    def __init__(self, models=None, real_packages=None):
        """
        Create a state.

        ``models`` is an optional mapping of model keys to model states
        (an empty dict is used when falsy). ``real_packages`` is an optional
        set of package labels; passing anything but a set trips an assertion.
        """
        self.models = models or {}
        # Packages to include from main registry, usually unmigrated ones
        if real_packages is None:
            real_packages = set()
        else:
            assert isinstance(real_packages, set)
        self.real_packages = real_packages
        self.is_delayed = False
        # {remote_model_key: {model_key: {field_name: field}}}
        self._relations = None

    @property
    def relations(self):
        """Return the relation map, resolving it on first access."""
        if self._relations is None:
            self.resolve_fields_and_relations()
        return self._relations

    def add_model(self, model_state):
        """Store ``model_state`` and refresh relations/rendered models if built."""
        model_key = model_state.package_label, model_state.name_lower
        self.models[model_key] = model_state
        if self._relations is not None:
            self.resolve_model_relations(model_key)
        if "packages" in self.__dict__:  # hasattr would cache the property
            self.reload_model(*model_key)

    def remove_model(self, package_label, model_name):
        """Drop a model from the state, its relations and the rendered registry."""
        model_key = package_label, model_name
        del self.models[model_key]
        if self._relations is not None:
            self._relations.pop(model_key, None)
            # Call list() since _relations can change size during iteration.
            for related_model_key, model_relations in list(self._relations.items()):
                model_relations.pop(model_key, None)
                if not model_relations:
                    del self._relations[related_model_key]
        if "packages" in self.__dict__:  # hasattr would cache the property
            self.packages.unregister_model(*model_key)
            # Need to do this explicitly since unregister_model() doesn't clear
            # the cache automatically (#24513)
            self.packages.clear_cache()

    def rename_model(self, package_label, old_name, new_name):
        """Rename a model and repoint every field that referenced the old name."""
        # Add a new model.
        old_name_lower = old_name.lower()
        new_name_lower = new_name.lower()
        renamed_model = self.models[package_label, old_name_lower].clone()
        renamed_model.name = new_name
        self.models[package_label, new_name_lower] = renamed_model
        # Repoint all fields pointing to the old model to the new one.
        old_model_tuple = (package_label, old_name_lower)
        new_remote_model = f"{package_label}.{new_name}"
        to_reload = set()
        for model_state, name, field, reference in get_references(
            self, old_model_tuple
        ):
            changed_field = None
            if reference.to:
                changed_field = field.clone()
                changed_field.remote_field.model = new_remote_model
            if reference.through:
                if changed_field is None:
                    changed_field = field.clone()
                changed_field.remote_field.through = new_remote_model
            if changed_field:
                model_state.fields[name] = changed_field
                to_reload.add((model_state.package_label, model_state.name_lower))
        if self._relations is not None:
            old_name_key = package_label, old_name_lower
            new_name_key = package_label, new_name_lower
            if old_name_key in self._relations:
                self._relations[new_name_key] = self._relations.pop(old_name_key)
            for model_relations in self._relations.values():
                if old_name_key in model_relations:
                    model_relations[new_name_key] = model_relations.pop(old_name_key)
        # Reload models related to old model before removing the old model.
        self.reload_models(to_reload, delay=True)
        # Remove the old model.
        self.remove_model(package_label, old_name_lower)
        self.reload_model(package_label, new_name_lower, delay=True)

    def alter_model_options(self, package_label, model_name, options, option_keys=None):
        """
        Merge ``options`` into the model's options.

        Any key listed in ``option_keys`` that is absent from ``options`` is
        removed from the model's options.
        """
        model_state = self.models[package_label, model_name]
        model_state.options = {**model_state.options, **options}
        if option_keys:
            for key in option_keys:
                if key not in options:
                    model_state.options.pop(key, False)
        self.reload_model(package_label, model_name, delay=True)

    def remove_model_options(
        self, package_label, model_name, option_name, value_to_remove
    ):
        """Remove ``value_to_remove`` (compared as a tuple) from a list-valued option."""
        model_state = self.models[package_label, model_name]
        if objs := model_state.options.get(option_name):
            model_state.options[option_name] = [
                obj for obj in objs if tuple(obj) != tuple(value_to_remove)
            ]
        self.reload_model(package_label, model_name, delay=True)

    def alter_model_managers(self, package_label, model_name, managers):
        """Replace the model's managers with a list copy of ``managers``."""
        model_state = self.models[package_label, model_name]
        model_state.managers = list(managers)
        self.reload_model(package_label, model_name, delay=True)

    def _append_option(self, package_label, model_name, option_name, obj):
        """Append ``obj`` to a list-valued option by replacing the list."""
        model_state = self.models[package_label, model_name]
        model_state.options[option_name] = [*model_state.options[option_name], obj]
        self.reload_model(package_label, model_name, delay=True)

    def _remove_option(self, package_label, model_name, option_name, obj_name):
        """Drop the entries named ``obj_name`` from a list-valued option."""
        model_state = self.models[package_label, model_name]
        objs = model_state.options[option_name]
        model_state.options[option_name] = [obj for obj in objs if obj.name != obj_name]
        self.reload_model(package_label, model_name, delay=True)

    def add_index(self, package_label, model_name, index):
        """Add ``index`` to the model's "indexes" option."""
        self._append_option(package_label, model_name, "indexes", index)

    def remove_index(self, package_label, model_name, index_name):
        """Remove the index named ``index_name`` from the model."""
        self._remove_option(package_label, model_name, "indexes", index_name)

    def rename_index(self, package_label, model_name, old_index_name, new_index_name):
        """Rename an index; the matching index is cloned rather than mutated."""
        model_state = self.models[package_label, model_name]
        objs = model_state.options["indexes"]

        new_indexes = []
        for obj in objs:
            if obj.name == old_index_name:
                obj = obj.clone()
                obj.name = new_index_name
            new_indexes.append(obj)

        model_state.options["indexes"] = new_indexes
        self.reload_model(package_label, model_name, delay=True)

    def add_constraint(self, package_label, model_name, constraint):
        """Add ``constraint`` to the model's "constraints" option."""
        self._append_option(package_label, model_name, "constraints", constraint)

    def remove_constraint(self, package_label, model_name, constraint_name):
        """Remove the constraint named ``constraint_name`` from the model."""
        self._remove_option(package_label, model_name, "constraints", constraint_name)

    def add_field(self, package_label, model_name, name, field, preserve_default):
        """Add ``field`` under ``name`` to the model."""
        # If preserve default is off, don't use the default for future state.
        if not preserve_default:
            field = field.clone()
            field.default = NOT_PROVIDED
        model_key = package_label, model_name
        self.models[model_key].fields[name] = field
        if self._relations is not None:
            self.resolve_model_field_relations(model_key, name, field)
        # Delay rendering of relationships if it's not a relational field.
        delay = not field.is_relation
        self.reload_model(*model_key, delay=delay)

    def remove_field(self, package_label, model_name, name):
        """Remove the field ``name`` from the model."""
        model_key = package_label, model_name
        model_state = self.models[model_key]
        old_field = model_state.fields.pop(name)
        if self._relations is not None:
            self.resolve_model_field_relations(model_key, name, old_field)
        # Delay rendering of relationships if it's not a relational field.
        delay = not old_field.is_relation
        self.reload_model(*model_key, delay=delay)

    def alter_field(self, package_label, model_name, name, field, preserve_default):
        """Replace the field ``name`` on the model with ``field``."""
        # If preserve default is off, don't use the default for future state.
        if not preserve_default:
            field = field.clone()
            field.default = NOT_PROVIDED
        model_key = package_label, model_name
        fields = self.models[model_key].fields
        if self._relations is not None:
            # Unregister the old relation before registering the new one.
            old_field = fields.pop(name)
            if old_field.is_relation:
                self.resolve_model_field_relations(model_key, name, old_field)
            fields[name] = field
            if field.is_relation:
                self.resolve_model_field_relations(model_key, name, field)
        else:
            fields[name] = field
        # TODO: investigate if old relational fields must be reloaded or if
        # it's sufficient if the new field is (#27737).
        # Delay rendering of relationships if it's not a relational field and
        # not referenced by a foreign key.
        delay = not field.is_relation and not field_is_referenced(
            self, model_key, (name, field)
        )
        self.reload_model(*model_key, delay=delay)

    def rename_field(self, package_label, model_name, old_name, new_name):
        """
        Rename a field and fix ``from_fields``/``to_fields`` that mention it.

        Raises FieldDoesNotExist when the model has no field ``old_name``.
        """
        model_key = package_label, model_name
        model_state = self.models[model_key]
        # Rename the field.
        fields = model_state.fields
        try:
            found = fields.pop(old_name)
        except KeyError as exc:
            raise FieldDoesNotExist(
                f"{package_label}.{model_name} has no field named '{old_name}'"
            ) from exc
        fields[new_name] = found
        for field in fields.values():
            # Fix from_fields to refer to the new field.
            from_fields = getattr(field, "from_fields", None)
            if from_fields:
                field.from_fields = tuple(
                    [
                        new_name if from_field_name == old_name else from_field_name
                        for from_field_name in from_fields
                    ]
                )

        # Fix to_fields to refer to the new field.
        delay = True
        references = get_references(self, model_key, (old_name, found))
        for *_, field, reference in references:
            # Any reference at all forces an immediate (non-delayed) reload.
            delay = False
            if reference.to:
                remote_field, to_fields = reference.to
                if getattr(remote_field, "field_name", None) == old_name:
                    remote_field.field_name = new_name
                if to_fields:
                    field.to_fields = tuple(
                        [
                            new_name if to_field_name == old_name else to_field_name
                            for to_field_name in to_fields
                        ]
                    )
        if self._relations is not None:
            old_name_lower = old_name.lower()
            new_name_lower = new_name.lower()
            for to_model in self._relations.values():
                if old_name_lower in to_model[model_key]:
                    field = to_model[model_key].pop(old_name_lower)
                    field.name = new_name_lower
                    to_model[model_key][new_name_lower] = field
        self.reload_model(*model_key, delay=delay)

    def _find_reload_model(self, package_label, model_name, delay=False):
        """
        Return the set of model keys that must be re-rendered together with
        the given model (which is always included).

        With ``delay`` set only directly related models are collected and
        ``is_delayed`` is switched on; otherwise relations are followed
        recursively.
        """
        if delay:
            self.is_delayed = True

        related_models = set()

        try:
            old_model = self.packages.get_model(package_label, model_name)
        except LookupError:
            pass
        else:
            # Get all relations to and from the old model before reloading,
            # as _meta.packages may change
            if delay:
                related_models = get_related_models_tuples(old_model)
            else:
                related_models = get_related_models_recursive(old_model)

        # Get all outgoing references from the model to be rendered
        model_state = self.models[(package_label, model_name)]
        # Directly related models are the models pointed to by ForeignKeys,
        # OneToOneFields, and ManyToManyFields.
        direct_related_models = set()
        for field in model_state.fields.values():
            if field.is_relation:
                if field.remote_field.model == RECURSIVE_RELATIONSHIP_CONSTANT:
                    continue
                rel_package_label, rel_model_name = _get_package_label_and_model_name(
                    field.related_model, package_label
                )
                direct_related_models.add((rel_package_label, rel_model_name.lower()))

        # For all direct related models recursively get all related models.
        related_models.update(direct_related_models)
        for rel_package_label, rel_model_name in direct_related_models:
            try:
                rel_model = self.packages.get_model(rel_package_label, rel_model_name)
            except LookupError:
                pass
            else:
                if delay:
                    related_models.update(get_related_models_tuples(rel_model))
                else:
                    related_models.update(get_related_models_recursive(rel_model))

        # Include the model itself
        related_models.add((package_label, model_name))

        return related_models

    def reload_model(self, package_label, model_name, delay=False):
        """Re-render one model and its related models; no-op until ``packages`` exists."""
        if "packages" in self.__dict__:  # hasattr would cache the property
            related_models = self._find_reload_model(package_label, model_name, delay)
            self._reload(related_models)

    def reload_models(self, models, delay=True):
        """Re-render several models at once; no-op until ``packages`` exists."""
        if "packages" in self.__dict__:  # hasattr would cache the property
            related_models = set()
            for package_label, model_name in models:
                related_models.update(
                    self._find_reload_model(package_label, model_name, delay)
                )
            self._reload(related_models)

    def _reload(self, related_models):
        """Unregister the given model keys and render them again from state."""
        # Unregister all related models
        with self.packages.bulk_update():
            for rel_package_label, rel_model_name in related_models:
                self.packages.unregister_model(rel_package_label, rel_model_name)

        states_to_be_rendered = []
        # Gather all models states of those models that will be rerendered.
        # This includes:
        # 1. All related models of unmigrated packages
        for model_state in self.packages.real_models:
            if (model_state.package_label, model_state.name_lower) in related_models:
                states_to_be_rendered.append(model_state)

        # 2. All related models of migrated packages
        for rel_package_label, rel_model_name in related_models:
            try:
                model_state = self.models[rel_package_label, rel_model_name]
            except KeyError:
                pass
            else:
                states_to_be_rendered.append(model_state)

        # Render all models
        self.packages.render_multiple(states_to_be_rendered)

    def update_model_field_relation(
        self,
        model,
        model_key,
        field_name,
        field,
        concretes,
    ):
        """
        Add or remove one entry of the relation map.

        The entry is added when ``field_name`` is currently a field of the
        model identified by ``model_key`` and removed otherwise.
        """
        remote_model_key = resolve_relation(model, *model_key)
        if (
            remote_model_key[0] not in self.real_packages
            and remote_model_key in concretes
        ):
            remote_model_key = concretes[remote_model_key]
        relations_to_remote_model = self._relations[remote_model_key]
        if field_name in self.models[model_key].fields:
            # The assert holds because it's a new relation, or an altered
            # relation, in which case references have been removed by
            # alter_field().
            assert field_name not in relations_to_remote_model[model_key]
            relations_to_remote_model[model_key][field_name] = field
        else:
            del relations_to_remote_model[model_key][field_name]
            if not relations_to_remote_model[model_key]:
                del relations_to_remote_model[model_key]

    def resolve_model_field_relations(
        self,
        model_key,
        field_name,
        field,
        concretes=None,
    ):
        """
        Update the relation map for ``field``: once for its remote model and,
        when the remote field has a truthy ``through``, once more for that
        model. Fields without a ``remote_field`` are ignored.
        """
        remote_field = field.remote_field
        if not remote_field:
            return
        if concretes is None:
            concretes = self._get_concrete_models_mapping()

        self.update_model_field_relation(
            remote_field.model,
            model_key,
            field_name,
            field,
            concretes,
        )

        through = getattr(remote_field, "through", None)
        if not through:
            return
        self.update_model_field_relation(
            through, model_key, field_name, field, concretes
        )

    def resolve_model_relations(self, model_key, concretes=None):
        """Update the relation map for every field of the given model."""
        if concretes is None:
            concretes = self._get_concrete_models_mapping()

        model_state = self.models[model_key]
        for field_name, field in model_state.fields.items():
            self.resolve_model_field_relations(model_key, field_name, field, concretes)

    def resolve_fields_and_relations(self):
        """Set each field's ``name`` and rebuild the relation map from scratch."""
        # Resolve fields.
        for model_state in self.models.values():
            for field_name, field in model_state.fields.items():
                field.name = field_name
        # Resolve relations.
        # {remote_model_key: {model_key: {field_name: field}}}
        self._relations = defaultdict(partial(defaultdict, dict))
        concretes = self._get_concrete_models_mapping()

        for model_key in concretes:
            self.resolve_model_relations(model_key, concretes)

    def get_concrete_model_key(self, model):
        """
        Return the concrete model key for ``model``.

        ``model`` is converted with ``make_model_tuple``; a KeyError is raised
        when the resulting key is not part of this state.
        """
        # The mapping is a plain dict, so it must be used as-is rather than
        # tuple-unpacked.
        concrete_models_mapping = self._get_concrete_models_mapping()
        model_key = make_model_tuple(model)
        return concrete_models_mapping[model_key]

    def _get_concrete_models_mapping(self):
        """Return a dict mapping every model key in this state to itself."""
        return {model_key: model_key for model_key in self.models}

    def clone(self):
        """
        Return a copy of this ProjectState.

        Model states are cloned individually and an already rendered
        ``packages`` registry is cloned too; ``real_packages`` is shared.
        """
        new_state = ProjectState(
            models={k: v.clone() for k, v in self.models.items()},
            real_packages=self.real_packages,
        )
        if "packages" in self.__dict__:
            new_state.packages = self.packages.clone()
        new_state.is_delayed = self.is_delayed
        return new_state

    def clear_delayed_packages_cache(self):
        """Discard the cached ``packages`` registry if a delayed reload happened."""
        if self.is_delayed and "packages" in self.__dict__:
            del self.__dict__["packages"]

    @cached_property
    def packages(self):
        """Registry of model classes rendered from this state (built on first use)."""
        return StatePackages(self.real_packages, self.models)

    @classmethod
    def from_packages(cls, packages):
        """Take a Packages registry and return a ProjectState matching it."""
        app_models = {}
        for model in packages.get_models(include_swapped=True):
            model_state = ModelState.from_model(model)
            app_models[
                (model_state.package_label, model_state.name_lower)
            ] = model_state
        return cls(app_models)

    def __eq__(self, other):
        """Two states are equal when their models and real packages are equal."""
        if not isinstance(other, ProjectState):
            return NotImplemented
        return self.models == other.models and self.real_packages == other.real_packages
544
545
class PackageConfigStub(PackageConfig):
    """Minimal PackageConfig stand-in: it only carries a label and a dict of models."""

    def __init__(self, label):
        self.packages = None
        self.models = {}
        # A package label and a package name are different things, so reusing
        # the label as the name is technically wrong. Migrations never look at
        # the package name though; it merely has to be unique, and the label is.
        self.label = self.name = label

    def import_models(self):
        """Point ``models`` at this label's entry in the owning registry's ``all_models``."""
        registry = self.packages
        self.models = registry.all_models[self.label]
560
561
class StatePackages(Packages):
    """
    Packages registry subclass tailored to migration state, where models are
    added to and removed from the registry dynamically.
    """

    def __init__(self, real_packages, models, ignore_swappable=False):
        # Every model of a package listed in real_packages takes part in the
        # render. The original model classes are not reused because they hold
        # references to the global Packages object. Their FKs/M2Ms are left
        # out as well (exclude_rels), since they only cause trouble with
        # partial states that lack the dependencies.
        self.real_models = [
            ModelState.from_model(real_model, exclude_rels=True)
            for label in real_packages
            for real_model in global_packages.get_package_config(label).get_models()
        ]
        # One stub config per package label seen in either source.
        state_labels = {state.package_label for state in models.values()}
        stubs = [
            PackageConfigStub(label)
            for label in sorted([*real_packages, *state_labels])
        ]
        super().__init__(stubs)

        # clone() copies registries whenever Plain duplicates a StatePackages
        # before updating it, and a real lock would get in the way of that.
        self._lock = None

        self.render_multiple([*models.values(), *self.real_models])

        # Rendering must not leave any lazy operation pending.
        from plain.models.preflight import _check_lazy_references

        if ignore_swappable:
            ignore = {make_model_tuple(settings.AUTH_USER_MODEL)}
        else:
            ignore = set()
        errors = _check_lazy_references(self, ignore=ignore)
        if errors:
            raise ValueError("\n".join(error.msg for error in errors))

    @contextmanager
    def bulk_update(self):
        """
        Context manager that marks the registry as not ready for the duration
        of the block, then restores the flag and clears the cache once.
        """
        # Clearing caches once at the end is cheaper than doing it for every
        # single model change made inside the block.
        previous_ready = self.ready
        self.ready = False
        try:
            yield
        finally:
            self.ready = previous_ready
            self.clear_cache()

    def render_multiple(self, model_states):
        """
        Render all ``model_states`` into this registry.

        States whose bases cannot be resolved yet are retried in further
        passes. InvalidBasesError is raised as soon as a pass renders nothing,
        which means there is a base dependency loop or a missing base.
        """
        if not model_states:
            return
        # A single bulk update keeps model caches from expiring per render.
        with self.bulk_update():
            pending = model_states
            while pending:
                deferred = []
                for state in pending:
                    try:
                        state.render(self)
                    except InvalidBasesError:
                        deferred.append(state)
                if len(deferred) == len(pending):
                    # No progress during this pass.
                    raise InvalidBasesError(
                        "Cannot resolve bases for %r\nThis can happen if you are "
                        "inheriting models from an app with migrations (e.g. "
                        "contrib.auth)\n in an app with no migrations"
                        % deferred
                    )
                pending = deferred

    def clone(self):
        """Return a clone of this registry."""
        duplicate = StatePackages([], {})
        duplicate.all_models = copy.deepcopy(self.all_models)

        for label in self.package_configs:
            stub = PackageConfigStub(label)
            stub.packages = duplicate
            stub.import_models()
            duplicate.package_configs[label] = stub

        # real_models never change, so sharing the list is sufficient.
        duplicate.real_models = self.real_models
        return duplicate

    def register_model(self, package_label, model):
        """Record ``model`` under ``package_label``, creating a stub config if needed."""
        model_name = model._meta.model_name
        self.all_models[package_label][model_name] = model
        if package_label not in self.package_configs:
            stub = PackageConfigStub(package_label)
            self.package_configs[package_label] = stub
            stub.packages = self
        self.package_configs[package_label].models[model_name] = model
        self.do_pending_operations(model)
        self.clear_cache()

    def unregister_model(self, package_label, model_name):
        """Forget a model; a model that is not registered is silently ignored."""
        try:
            del self.all_models[package_label][model_name]
            del self.package_configs[package_label].models[model_name]
        except KeyError:
            pass
672
673
class ModelState:
    """
    Represent a Plain Model. Don't use the actual Model class as it's not
    designed to have its options changed - instead, mutate this one and then
    render it into a Model as required.

    Note that while you are allowed to mutate .fields, you are not allowed
    to mutate the Field instances inside there themselves - you must instead
    assign new ones, as these are not detached during a clone.
    """

    def __init__(
        self, package_label, name, fields, options=None, bases=None, managers=None
    ):
        """
        Build a model state.

        ``fields`` is anything ``dict()`` accepts (name/field pairs or a
        mapping). A truthy ``options`` dict is stored as-is, not copied, and
        gains default "indexes" and "constraints" entries. ``bases`` defaults
        to ``(models.Model,)`` and ``managers`` to an empty list.

        Raises ValueError when a field is already bound to a model, when a
        relation refers to a model class instead of a string, or when an
        index has no name.
        """
        self.package_label = package_label
        self.name = name
        self.fields = dict(fields)
        self.options = options or {}
        self.options.setdefault("indexes", [])
        self.options.setdefault("constraints", [])
        self.bases = bases or (models.Model,)
        self.managers = managers or []
        for field_name, field in self.fields.items():
            # Sanity-check that fields are NOT already bound to a model.
            if hasattr(field, "model"):
                raise ValueError(
                    "ModelState.fields cannot be bound to a model - "
                    f'"{field_name}" is.'
                )
            # Sanity-check that relation fields are NOT referring to a model class.
            if field.is_relation and hasattr(field.related_model, "_meta"):
                raise ValueError(
                    "ModelState.fields cannot refer to a model class - "
                    f'"{field_name}.to" does. '
                    "Use a string reference instead."
                )
            if field.many_to_many and hasattr(field.remote_field.through, "_meta"):
                raise ValueError(
                    "ModelState.fields cannot refer to a model class - "
                    f'"{field_name}.through" '
                    "does. Use a string reference instead."
                )
        # Sanity-check that indexes have their name set.
        for index in self.options["indexes"]:
            if not index.name:
                raise ValueError(
                    "Indexes passed to ModelState require a name attribute. "
                    f"{index!r} doesn't have one."
                )

    @cached_property
    def name_lower(self):
        """Lower-cased model name (cached on first access)."""
        return self.name.lower()

    def get_field(self, field_name):
        """
        Return the field called ``field_name``.

        "_order" is looked up under the "order_with_respect_to" option when
        that option is set. Raises KeyError for an unknown field.
        """
        if field_name == "_order":
            field_name = self.options.get("order_with_respect_to", field_name)
        return self.fields[field_name]

    @classmethod
    def from_model(cls, model, exclude_rels=False):
        """
        Given a model, return a ModelState representing it.

        With ``exclude_rels`` set, fields having a ``remote_field``, all
        many-to-many fields and the "order_with_respect_to" option are left
        out. Raises TypeError when a field cannot be cloned.
        """
        # Deconstruct the fields
        fields = []
        for field in model._meta.local_fields:
            if getattr(field, "remote_field", None) and exclude_rels:
                continue
            if isinstance(field, models.OrderWrt):
                continue
            name = field.name
            try:
                fields.append((name, field.clone()))
            except TypeError as e:
                raise TypeError(
                    f"Couldn't reconstruct field {name} on {model._meta.label}: {e}"
                ) from e
        if not exclude_rels:
            for field in model._meta.local_many_to_many:
                name = field.name
                try:
                    fields.append((name, field.clone()))
                except TypeError as e:
                    raise TypeError(
                        f"Couldn't reconstruct m2m field {name} on "
                        f"{model._meta.object_name}: {e}"
                    ) from e
        # Extract the options
        options = {}
        for name in DEFAULT_NAMES:
            # Ignore some special options
            if name in ["packages", "package_label"]:
                continue
            elif name in model._meta.original_attrs:
                if name == "indexes":
                    indexes = [idx.clone() for idx in model._meta.indexes]
                    for index in indexes:
                        if not index.name:
                            index.set_name_with_model(model)
                    options["indexes"] = indexes
                elif name == "constraints":
                    options["constraints"] = [
                        con.clone() for con in model._meta.constraints
                    ]
                else:
                    options[name] = model._meta.original_attrs[name]
        # If we're ignoring relationships, remove all field-listing model
        # options (that option basically just means "make a stub model")
        if exclude_rels:
            if "order_with_respect_to" in options:
                del options["order_with_respect_to"]
        # Private fields are ignored, so remove options that refer to them.
        elif options.get("order_with_respect_to") in {
            field.name for field in model._meta.private_fields
        }:
            del options["order_with_respect_to"]

        def flatten_bases(model):
            """Return ``model``'s bases with abstract models replaced by their own bases."""
            bases = []
            for base in model.__bases__:
                if hasattr(base, "_meta") and base._meta.abstract:
                    bases.extend(flatten_bases(base))
                else:
                    bases.append(base)
            return bases

        # We can't rely on __mro__ directly because we only want to flatten
        # abstract models and not the whole tree. However by recursing on
        # __bases__ we may end up with duplicates and ordering issues, we
        # therefore discard any duplicates and reorder the bases according
        # to their index in the MRO.
        flattened_bases = sorted(
            set(flatten_bases(model)), key=lambda x: model.__mro__.index(x)
        )

        # Make our record
        bases = tuple(
            (base._meta.label_lower if hasattr(base, "_meta") else base)
            for base in flattened_bases
        )
        # Ensure at least one base inherits from models.Model
        if not any(
            (isinstance(base, str) or issubclass(base, models.Model)) for base in bases
        ):
            bases = (models.Model,)

        managers = []
        manager_names = set()
        default_manager_shim = None
        for manager in model._meta.managers:
            if manager.name in manager_names:
                # Skip overridden managers.
                continue
            elif manager.use_in_migrations:
                # Copy managers usable in migrations.
                new_manager = copy.copy(manager)
                new_manager._set_creation_counter()
            elif manager is model._base_manager or manager is model._default_manager:
                # Shim custom managers used as default and base managers.
                new_manager = models.Manager()
                new_manager.model = manager.model
                new_manager.name = manager.name
                if manager is model._default_manager:
                    default_manager_shim = new_manager
            else:
                continue
            manager_names.add(manager.name)
            managers.append((manager.name, new_manager))

        # Ignore a shimmed default manager called objects if it's the only one.
        if managers == [("objects", default_manager_shim)]:
            managers = []

        # Construct the new ModelState
        return cls(
            model._meta.package_label,
            model._meta.object_name,
            fields,
            options,
            bases,
            managers,
        )

    def construct_managers(self):
        """Deep-clone the managers using deconstruction; yields (name, manager) pairs."""
        # Sort all managers by their creation counter
        sorted_managers = sorted(self.managers, key=lambda v: v[1].creation_counter)
        for mgr_name, manager in sorted_managers:
            as_manager, manager_path, qs_path, args, kwargs = manager.deconstruct()
            if as_manager:
                qs_class = import_string(qs_path)
                yield mgr_name, qs_class.as_manager()
            else:
                manager_class = import_string(manager_path)
                yield mgr_name, manager_class(*args, **kwargs)

    def clone(self):
        """
        Return a copy of this ModelState.

        The fields dict, options dict and managers list are copied shallowly;
        the field, option and manager objects themselves are shared.
        """
        return self.__class__(
            package_label=self.package_label,
            name=self.name,
            fields=dict(self.fields),
            # Since options are shallow-copied here, operations such as
            # AddIndex must replace their option (e.g 'indexes') rather
            # than mutating it.
            options=dict(self.options),
            bases=self.bases,
            managers=list(self.managers),
        )

    def render(self, packages):
        """
        Create a Model object from our current state into the given packages.

        Raises InvalidBasesError when a string base cannot be looked up in
        ``packages``.
        """
        # First, make a Meta object
        meta_contents = {
            "package_label": self.package_label,
            "packages": packages,
            **self.options,
        }
        meta = type("Meta", (), meta_contents)
        # Then, work out our bases
        try:
            bases = tuple(
                (packages.get_model(base) if isinstance(base, str) else base)
                for base in self.bases
            )
        except LookupError as exc:
            raise InvalidBasesError(
                f"Cannot resolve one or more bases from {self.bases!r}"
            ) from exc
        # Clone fields for the body, add other bits.
        body = {name: field.clone() for name, field in self.fields.items()}
        body["Meta"] = meta
        body["__module__"] = "__fake__"

        # Restore managers
        body.update(self.construct_managers())
        # Then, make a Model object (packages.register_model is called in __new__)
        return type(self.name, bases, body)

    def get_index_by_name(self, name):
        """Return the index called ``name``; raise ValueError if there is none."""
        for index in self.options["indexes"]:
            if index.name == name:
                return index
        raise ValueError(f"No index named {name} on model {self.name}")

    def get_constraint_by_name(self, name):
        """Return the constraint called ``name``; raise ValueError if there is none."""
        for constraint in self.options["constraints"]:
            if constraint.name == name:
                return constraint
        raise ValueError(f"No constraint named {name} on model {self.name}")

    def __repr__(self):
        return f"<{self.__class__.__name__}: '{self.package_label}.{self.name}'>"

    def __eq__(self, other):
        """
        Compare label, name, fields, options, bases and managers.

        Fields are compared by name and by their deconstructed form without
        its first element.
        """
        if not isinstance(other, ModelState):
            return NotImplemented
        return (
            (self.package_label == other.package_label)
            and (self.name == other.name)
            and (len(self.fields) == len(other.fields))
            and all(
                k1 == k2 and f1.deconstruct()[1:] == f2.deconstruct()[1:]
                for (k1, f1), (k2, f2) in zip(
                    sorted(self.fields.items()),
                    sorted(other.fields.items()),
                )
            )
            and (self.options == other.options)
            and (self.bases == other.bases)
            and (self.managers == other.managers)
        )