1from __future__ import annotations
2
3import copy
4from collections import defaultdict
5from contextlib import contextmanager
6from functools import cached_property, partial
7from typing import TYPE_CHECKING, Any, cast
8
9from plain import models
10from plain.models.exceptions import FieldDoesNotExist
11from plain.models.fields import NOT_PROVIDED
12from plain.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT, RelatedField
13from plain.models.meta import Meta
14from plain.models.migrations.utils import field_is_referenced, get_references
15from plain.models.registry import ModelsRegistry
16from plain.models.registry import models_registry as global_models
17from plain.packages import packages_registry
18
19from .exceptions import InvalidBasesError
20from .utils import resolve_relation
21
22if TYPE_CHECKING:
23 from collections.abc import Generator, Iterable
24
25 from plain.models.fields import Field
26
27
28def _get_package_label_and_model_name(
29 model: str | type[models.Model], package_label: str = ""
30) -> tuple[str, str]:
31 if isinstance(model, str):
32 split = model.split(".", 1)
33 return (
34 cast(tuple[str, str], tuple(split))
35 if len(split) == 2
36 else (package_label, split[0])
37 )
38 else:
39 return model.model_options.package_label, model.model_options.model_name
40
41
def _get_related_models(m: type[models.Model]) -> list[type[models.Model]]:
    """
    Return all models that have a direct relationship to the given model.

    The result is the model's direct subclasses that are Models, followed by
    the related model of every field returned by
    ``get_fields(include_reverse=True)`` that is a RelatedField or a
    ForeignObjectRel. Only relations whose target is already a class are
    included; unresolved string references are skipped. Nothing is
    de-duplicated, so the list may contain repeats.
    """
    from plain.models.fields.reverse_related import ForeignObjectRel

    related_models = [
        subclass
        for subclass in m.__subclasses__()
        if issubclass(subclass, models.Model)
    ]
    for f in m._model_meta.get_fields(include_reverse=True):
        if (
            isinstance(f, RelatedField | ForeignObjectRel)
            and f.related_model is not None
            and not isinstance(f.related_model, str)
        ):
            related_models.append(f.related_model)
    return related_models
61
62
def get_related_models_tuples(model: type[models.Model]) -> set[tuple[str, str]]:
    """
    Return a set of (package_label, model_name) tuples for all models directly
    related to the given model, as found by _get_related_models().

    Duplicates from _get_related_models() collapse into a single entry.
    """
    return {
        (rel_mod.model_options.package_label, rel_mod.model_options.model_name)
        for rel_mod in _get_related_models(model)
    }
72
73
def get_related_models_recursive(model: type[models.Model]) -> set[tuple[str, str]]:
    """
    Return the (package_label, model_name) keys of every model that is
    directly or indirectly related to ``model``, excluding ``model`` itself.

    Relationships are either defined by explicit relational fields, like
    ForeignKeyField or ManyToManyField, or by inheriting from another
    model (a superclass is related to its subclasses, but not vice versa).
    The relationship graph is walked breadth-first, starting from
    _get_related_models(model).
    """
    visited: set[tuple[str, str]] = set()
    frontier = _get_related_models(model)
    # ``frontier`` grows while it is iterated: each newly visited model
    # appends its own related models to the end of the list.
    for candidate in frontier:
        options = candidate.model_options
        key = (options.package_label, options.model_name)
        if key in visited:
            continue
        visited.add(key)
        frontier.extend(_get_related_models(candidate))
    own_options = model.model_options
    return visited - {(own_options.package_label, own_options.model_name)}
95
96
class ProjectState:
    """
    Represent the entire project's overall state. This is the item that is
    passed around - do it here rather than at the package level so that
    cross-package FKs/etc. resolve properly.

    ``models`` maps ``(package_label, model_name_lower)`` to a ModelState.
    Two derived structures are built lazily and kept in sync by the mutating
    methods once they exist:

    * ``_relations`` (exposed through ``relations``), a reverse index of
      relational fields: {remote_model_key: {model_key: {field_name: field}}}.
    * ``models_registry`` (a cached_property), a StateModelsRegistry holding
      the rendered model classes. Methods test
      ``"models_registry" in self.__dict__`` instead of using ``hasattr()``
      so that the check itself doesn't trigger and cache a render.
    """

    def __init__(
        self,
        models: dict[tuple[str, str], ModelState] | None = None,
        real_packages: set[str] | None = None,
    ):
        self.models = models or {}
        # Packages to include from main registry, usually unmigrated ones
        if real_packages is None:
            real_packages = set()
        else:
            assert isinstance(real_packages, set)
        self.real_packages = real_packages
        # Set by _find_reload_model() whenever a delayed reload is requested;
        # read by clear_delayed_models_cache().
        self.is_delayed = False
        # Built on first access of ``relations``.
        # {remote_model_key: {model_key: {field_name: field}}}
        self._relations: (
            dict[tuple[str, str], dict[tuple[str, str], dict[str, Field]]] | None
        ) = None

    @property
    def relations(
        self,
    ) -> dict[tuple[str, str], dict[tuple[str, str], dict[str, Field]]]:
        """Return the reverse relation index, building it on first access."""
        if self._relations is None:
            self.resolve_fields_and_relations()
        assert self._relations is not None
        return self._relations

    def add_model(self, model_state: ModelState) -> None:
        """Add ``model_state``, updating relations and rendered models if built."""
        model_key = model_state.package_label, model_state.name_lower
        self.models[model_key] = model_state
        if self._relations is not None:
            self.resolve_model_relations(model_key)
        if "models_registry" in self.__dict__:  # hasattr would cache the property
            self.reload_model(*model_key)

    def remove_model(self, package_label: str, model_name: str) -> None:
        """Remove a model and every relation-index entry that mentions it."""
        model_key = package_label, model_name
        del self.models[model_key]
        if self._relations is not None:
            self._relations.pop(model_key, None)
            # Call list() since _relations can change size during iteration.
            for related_model_key, model_relations in list(self._relations.items()):
                model_relations.pop(model_key, None)
                if not model_relations:
                    del self._relations[related_model_key]
        if "models_registry" in self.__dict__:  # hasattr would cache the property
            self.models_registry.unregister_model(*model_key)
            # Need to do this explicitly since unregister_model() doesn't clear
            # the cache automatically (#24513)
            self.models_registry.clear_cache()

    def rename_model(self, package_label: str, old_name: str, new_name: str) -> None:
        """
        Rename a model and repoint relational fields (``to`` and ``through``)
        that referenced the old name.
        """
        # Add a new model.
        old_name_lower = old_name.lower()
        new_name_lower = new_name.lower()
        renamed_model = self.models[package_label, old_name_lower].clone()
        renamed_model.name = new_name
        self.models[package_label, new_name_lower] = renamed_model
        # Repoint all fields pointing to the old model to the new one.
        old_model_tuple = (package_label, old_name_lower)
        new_remote_model = f"{package_label}.{new_name}"
        to_reload = set()
        for model_state, name, field, reference in get_references(
            self, old_model_tuple
        ):
            if not isinstance(field, RelatedField):
                continue
            changed_field = None
            if reference.to:
                changed_field = field.clone()
                assert changed_field.remote_field is not None
                changed_field.remote_field.model = new_remote_model  # type: ignore[assignment]
            if reference.through:
                if changed_field is None:
                    changed_field = field.clone()
                assert changed_field.remote_field is not None
                changed_field.remote_field.through = new_remote_model  # type: ignore[assignment]
            if changed_field:
                # Assign a clone instead of mutating the shared Field instance.
                model_state.fields[name] = changed_field
                to_reload.add((model_state.package_label, model_state.name_lower))
        if self._relations is not None:
            old_name_key = package_label, old_name_lower
            new_name_key = package_label, new_name_lower
            if old_name_key in self._relations:
                self._relations[new_name_key] = self._relations.pop(old_name_key)
            for model_relations in self._relations.values():
                if old_name_key in model_relations:
                    model_relations[new_name_key] = model_relations.pop(old_name_key)
        # Reload models related to old model before removing the old model.
        self.reload_models(to_reload, delay=True)
        # Remove the old model.
        self.remove_model(package_label, old_name_lower)
        self.reload_model(package_label, new_name_lower, delay=True)

    def alter_model_options(
        self,
        package_label: str,
        model_name: str,
        options: dict[str, Any],
        option_keys: Iterable[str] | None = None,
    ) -> None:
        """
        Merge ``options`` into the model's options.

        Keys listed in ``option_keys`` but absent from ``options`` are removed.
        """
        model_state = self.models[package_label, model_name]
        model_state.options = {**model_state.options, **options}
        if option_keys:
            for key in option_keys:
                if key not in options:
                    model_state.options.pop(key, False)
        self.reload_model(package_label, model_name, delay=True)

    def _append_option(
        self, package_label: str, model_name: str, option_name: str, obj: Any
    ) -> None:
        """
        Append ``obj`` to the list option ``option_name``.

        A new list is assigned rather than appending in place because
        ModelState.clone() only shallow-copies ``options``.
        """
        model_state = self.models[package_label, model_name]
        model_state.options[option_name] = [*model_state.options[option_name], obj]
        self.reload_model(package_label, model_name, delay=True)

    def _remove_option(
        self, package_label: str, model_name: str, option_name: str, obj_name: str
    ) -> None:
        """Drop entries whose ``name`` is ``obj_name`` from list option ``option_name``."""
        model_state = self.models[package_label, model_name]
        objs = model_state.options[option_name]
        model_state.options[option_name] = [obj for obj in objs if obj.name != obj_name]
        self.reload_model(package_label, model_name, delay=True)

    def add_index(self, package_label: str, model_name: str, index: Any) -> None:
        """Add ``index`` to the model's ``indexes`` option."""
        self._append_option(package_label, model_name, "indexes", index)

    def remove_index(
        self, package_label: str, model_name: str, index_name: str
    ) -> None:
        """Remove the index named ``index_name`` from the model."""
        self._remove_option(package_label, model_name, "indexes", index_name)

    def rename_index(
        self,
        package_label: str,
        model_name: str,
        old_index_name: str,
        new_index_name: str,
    ) -> None:
        """Rename an index, replacing the matching entry with a renamed clone."""
        model_state = self.models[package_label, model_name]
        objs = model_state.options["indexes"]

        new_indexes = []
        for obj in objs:
            if obj.name == old_index_name:
                # Clone so the shallow-copied option list of other states
                # doesn't see the rename.
                obj = obj.clone()
                obj.name = new_index_name
            new_indexes.append(obj)

        model_state.options["indexes"] = new_indexes
        self.reload_model(package_label, model_name, delay=True)

    def add_constraint(
        self, package_label: str, model_name: str, constraint: Any
    ) -> None:
        """Add ``constraint`` to the model's ``constraints`` option."""
        self._append_option(package_label, model_name, "constraints", constraint)

    def remove_constraint(
        self, package_label: str, model_name: str, constraint_name: str
    ) -> None:
        """Remove the constraint named ``constraint_name`` from the model."""
        self._remove_option(package_label, model_name, "constraints", constraint_name)

    def add_field(
        self,
        package_label: str,
        model_name: str,
        name: str,
        field: Field,
        preserve_default: bool,
    ) -> None:
        """
        Add ``field`` to the model under ``name``.

        When ``preserve_default`` is false, a clone with its default reset to
        NOT_PROVIDED is stored instead, so the default isn't kept in future
        state.
        """
        if not preserve_default:
            field = field.clone()
            field.default = NOT_PROVIDED
        model_key = package_label, model_name
        self.models[model_key].fields[name] = field
        if self._relations is not None:
            self.resolve_model_field_relations(model_key, name, field)
        # Delay rendering of relationships if it's not a relational field.
        delay = not isinstance(field, RelatedField)
        self.reload_model(*model_key, delay=delay)

    def remove_field(self, package_label: str, model_name: str, name: str) -> None:
        """Remove field ``name`` from the model and from the relation index."""
        model_key = package_label, model_name
        model_state = self.models[model_key]
        old_field = model_state.fields.pop(name)
        if self._relations is not None:
            # The field is no longer on the model state, so this removes
            # its relation entries.
            self.resolve_model_field_relations(model_key, name, old_field)
        # Delay rendering of relationships if it's not a relational field.
        delay = not isinstance(old_field, RelatedField)
        self.reload_model(*model_key, delay=delay)

    def alter_field(
        self,
        package_label: str,
        model_name: str,
        name: str,
        field: Field,
        preserve_default: bool,
    ) -> None:
        """
        Replace field ``name`` on the model with ``field``.

        ``preserve_default`` behaves as in add_field().
        """
        if not preserve_default:
            field = field.clone()
            field.default = NOT_PROVIDED
        model_key = package_label, model_name
        fields = self.models[model_key].fields
        if self._relations is not None:
            # Pop first so the old field's relations are removed, then add
            # the new field's relations once it is in place.
            old_field = fields.pop(name)
            if isinstance(old_field, RelatedField):
                self.resolve_model_field_relations(model_key, name, old_field)
            fields[name] = field
            if isinstance(field, RelatedField):
                self.resolve_model_field_relations(model_key, name, field)
        else:
            fields[name] = field
        # TODO: investigate if old relational fields must be reloaded or if
        # it's sufficient if the new field is (#27737).
        # Delay rendering of relationships if it's not a relational field and
        # not referenced by a foreign key.
        delay = not isinstance(field, RelatedField) and not field_is_referenced(
            self, model_key, (name, field)
        )
        self.reload_model(*model_key, delay=delay)

    def rename_field(
        self, package_label: str, model_name: str, old_name: str, new_name: str
    ) -> None:
        """
        Rename field ``old_name`` to ``new_name`` on the model.

        Raises:
            FieldDoesNotExist: if the model has no field named ``old_name``.
        """
        model_key = package_label, model_name
        model_state = self.models[model_key]
        # Rename the field.
        fields = model_state.fields
        try:
            found = fields.pop(old_name)
        except KeyError:
            raise FieldDoesNotExist(
                f"{package_label}.{model_name} has no field named '{old_name}'"
            )
        fields[new_name] = found
        # Rendering may be delayed only if nothing references this field.
        # Use field_is_referenced() rather than bool(get_references(...)):
        # get_references() is consumed as an iterator elsewhere (see
        # rename_model()), and an iterator object is always truthy, which
        # made ``delay`` unconditionally False.
        delay = not field_is_referenced(self, model_key, (old_name, found))
        if self._relations is not None:
            old_name_lower = old_name.lower()
            new_name_lower = new_name.lower()
            for to_model in self._relations.values():
                if old_name_lower in to_model[model_key]:
                    field = to_model[model_key].pop(old_name_lower)
                    field.name = new_name_lower
                    to_model[model_key][new_name_lower] = field
        self.reload_model(*model_key, delay=delay)

    def _find_reload_model(
        self, package_label: str, model_name: str, delay: bool = False
    ) -> set[tuple[str, str]]:
        """
        Return the keys of the models to re-render when the given model changes.

        The result contains:
        * the models related to the currently rendered version of the model;
        * the models its relational fields point to, plus their related
          models;
        * the model itself.

        Related models are collected one level deep when ``delay`` is true
        and recursively otherwise. Passing ``delay=True`` also sets
        ``is_delayed``.
        """
        if delay:
            self.is_delayed = True

        related_models: set[tuple[str, str]] = set()

        try:
            old_model = self.models_registry.get_model(package_label, model_name)
        except LookupError:
            pass
        else:
            # Get all relations to and from the old model before reloading,
            # as _model_meta.models_registry may change
            if delay:
                related_models = get_related_models_tuples(old_model)
            else:
                related_models = get_related_models_recursive(old_model)

        # Get all outgoing references from the model to be rendered
        model_state = self.models[(package_label, model_name)]
        # Directly related models are the models pointed to by ForeignKeys and ManyToManyFields.
        direct_related_models = set()
        for field in model_state.fields.values():
            if isinstance(field, RelatedField):
                if field.remote_field.model == RECURSIVE_RELATIONSHIP_CONSTANT:
                    continue
                rel_package_label, rel_model_name = _get_package_label_and_model_name(
                    field.related_model,
                    package_label,
                )
                direct_related_models.add((rel_package_label, rel_model_name.lower()))

        # For all direct related models recursively get all related models.
        related_models.update(direct_related_models)
        for rel_package_label, rel_model_name in direct_related_models:
            try:
                rel_model = self.models_registry.get_model(
                    rel_package_label, rel_model_name
                )
            except LookupError:
                pass
            else:
                if delay:
                    related_models.update(get_related_models_tuples(rel_model))
                else:
                    related_models.update(get_related_models_recursive(rel_model))

        # Include the model itself
        related_models.add((package_label, model_name))

        return related_models

    def reload_model(
        self, package_label: str, model_name: str, delay: bool = False
    ) -> None:
        """Re-render one model and its related models if a registry exists."""
        if "models_registry" in self.__dict__:  # hasattr would cache the property
            related_models = self._find_reload_model(package_label, model_name, delay)
            self._reload(related_models)

    def reload_models(self, models: set[tuple[str, str]], delay: bool = True) -> None:
        """Re-render several models, and their related models, in one pass."""
        if "models_registry" in self.__dict__:  # hasattr would cache the property
            related_models = set()
            for package_label, model_name in models:
                related_models.update(
                    self._find_reload_model(package_label, model_name, delay)
                )
            self._reload(related_models)

    def _reload(self, related_models: set[tuple[str, str]]) -> None:
        """Unregister ``related_models`` from the registry and render them again."""
        # Unregister all related models
        with self.models_registry.bulk_update():
            for rel_package_label, rel_model_name in related_models:
                self.models_registry.unregister_model(rel_package_label, rel_model_name)

        states_to_be_rendered = []
        # Gather all models states of those models that will be rerendered.
        # This includes:
        # 1. All related models of unmigrated packages
        for model_state in self.models_registry.real_models:
            if (model_state.package_label, model_state.name_lower) in related_models:
                states_to_be_rendered.append(model_state)

        # 2. All related models of migrated packages
        for rel_package_label, rel_model_name in related_models:
            try:
                model_state = self.models[rel_package_label, rel_model_name]
            except KeyError:
                pass
            else:
                states_to_be_rendered.append(model_state)

        # Render all models
        self.models_registry.render_multiple(states_to_be_rendered)

    def update_model_field_relation(
        self,
        model: str | type[models.Model],
        model_key: tuple[str, str],
        field_name: str,
        field: Field,
        concretes: dict[tuple[str, str], tuple[str, str]],
    ) -> None:
        """
        Add or remove ``field`` in the relation-index entry for the model that
        ``model`` resolves to.

        The field is added while ``field_name`` is still present on the model
        state; otherwise its entry (and any now-empty parent entry) is removed.
        """
        assert self._relations is not None
        remote_model_key = resolve_relation(model, *model_key)
        if (
            remote_model_key[0] not in self.real_packages
            and remote_model_key in concretes
        ):
            remote_model_key = concretes[remote_model_key]
        relations_to_remote_model = self._relations[remote_model_key]
        if field_name in self.models[model_key].fields:
            # The assert holds because it's a new relation, or an altered
            # relation, in which case references have been removed by
            # alter_field().
            assert field_name not in relations_to_remote_model[model_key]
            relations_to_remote_model[model_key][field_name] = field
        else:
            del relations_to_remote_model[model_key][field_name]
            if not relations_to_remote_model[model_key]:
                del relations_to_remote_model[model_key]

    def resolve_model_field_relations(
        self,
        model_key: tuple[str, str],
        field_name: str,
        field: Field,
        concretes: dict[tuple[str, str], tuple[str, str]] | None = None,
    ) -> None:
        """
        Update the relation index for one field.

        Both ``remote_field.model`` and, if present, ``remote_field.through``
        are updated. Fields that aren't RelatedFields, or that have no
        ``remote_field``, are ignored.
        """
        # Only process fields that have relations
        if not isinstance(field, RelatedField):
            return None
        remote_field = field.remote_field
        if not remote_field:
            return None
        if concretes is None:
            concretes = self._get_concrete_models_mapping()

        self.update_model_field_relation(
            remote_field.model,
            model_key,
            field_name,
            field,
            concretes,
        )

        through = getattr(remote_field, "through", None)
        if not through:
            return None
        self.update_model_field_relation(
            through, model_key, field_name, field, concretes
        )

    def resolve_model_relations(
        self,
        model_key: tuple[str, str],
        concretes: dict[tuple[str, str], tuple[str, str]] | None = None,
    ) -> None:
        """Update the relation index for every field of the given model."""
        if concretes is None:
            concretes = self._get_concrete_models_mapping()

        model_state = self.models[model_key]
        for field_name, field in model_state.fields.items():
            self.resolve_model_field_relations(model_key, field_name, field, concretes)

    def resolve_fields_and_relations(self) -> None:
        """Set each field's ``name`` from its key and rebuild ``_relations``."""
        # Resolve fields.
        for model_state in self.models.values():
            for field_name, field in model_state.fields.items():
                field.name = field_name
        # Resolve relations.
        # {remote_model_key: {model_key: {field_name: field}}}
        self._relations = defaultdict(partial(defaultdict, dict))
        concretes = self._get_concrete_models_mapping()

        for model_key in concretes:
            self.resolve_model_relations(model_key, concretes)

    def _get_concrete_models_mapping(self) -> dict[tuple[str, str], tuple[str, str]]:
        """Map every model key to itself: each model is its own concrete model."""
        return {model_key: model_key for model_key in self.models}

    def clone(self) -> ProjectState:
        """Return an exact copy of this ProjectState."""
        new_state = ProjectState(
            models={k: v.clone() for k, v in self.models.items()},
            real_packages=self.real_packages,
        )
        if "models_registry" in self.__dict__:
            new_state.models_registry = self.models_registry.clone()
        new_state.is_delayed = self.is_delayed
        return new_state

    def clear_delayed_models_cache(self) -> None:
        """
        Discard the cached ``models_registry`` if a delayed reload was
        requested, so that the next access re-renders it.
        """
        if self.is_delayed and "models_registry" in self.__dict__:
            del self.__dict__["models_registry"]

    @cached_property
    def models_registry(self) -> StateModelsRegistry:
        """Render all models into a StateModelsRegistry (cached per instance)."""
        return StateModelsRegistry(self.real_packages, self.models)

    @classmethod
    def from_models_registry(cls, models_registry: ModelsRegistry) -> ProjectState:
        """Take a ModelsRegistry and return a ProjectState matching it."""
        app_models = {}
        for model in models_registry.get_models():
            model_state = ModelState.from_model(model)
            app_models[(model_state.package_label, model_state.name_lower)] = (
                model_state
            )
        return cls(app_models)

    def __eq__(self, other: object) -> bool:
        """
        Compare by ``models`` and ``real_packages`` only.

        The relation index and the rendered registry are ignored.
        """
        if not isinstance(other, ProjectState):
            return NotImplemented
        return self.models == other.models and self.real_packages == other.real_packages
577
578
class StateModelsRegistry(ModelsRegistry):
    """
    Models registry used while computing migration state.

    Unlike the global registry, its model classes are rendered from
    ModelState objects and may be unregistered and rendered again as the
    project state changes.
    """

    def __init__(
        self,
        real_packages: set[str],
        models: dict[tuple[str, str], ModelState],
    ):
        # Models from real packages are rebuilt from ModelState copies rather
        # than reusing the global model classes, because those classes refer
        # to the global registry. Their relational fields are dropped
        # (exclude_rels=True): with partial states their dependencies may be
        # missing.
        self.real_models: list[ModelState] = [
            ModelState.from_model(real_model, exclude_rels=True)
            for label in real_packages
            for real_model in global_models.get_models(package_label=label)
        ]

        super().__init__()

        self.render_multiple([*models.values(), *self.real_models])

        self.ready = True

        # Rendering above should have resolved every lazy reference.
        from plain.models.preflight import _check_lazy_references

        errors = _check_lazy_references(self, packages_registry)
        if errors:
            raise ValueError("\n".join(error.fix for error in errors))

    @contextmanager
    def bulk_update(self) -> Generator[None, None, None]:
        """
        Suspend per-change cache clearing for the duration of the block.

        ``ready`` is set to False while the block runs. On exit it is
        restored and the cache is cleared once.
        """
        previous_ready = self.ready
        self.ready = False
        try:
            yield
        finally:
            self.ready = previous_ready
            self.clear_cache()

    def render_multiple(self, model_states: list[ModelState]) -> None:
        """
        Render ``model_states`` into this registry, retrying out-of-order bases.

        Models whose bases can't be resolved yet are retried on the next
        pass. If a full pass makes no progress (a base dependency loop or a
        missing base), InvalidBasesError is raised.
        """
        if not model_states:
            return None
        # One cache clear for the whole batch instead of one per model.
        with self.bulk_update():
            pending = model_states
            while pending:
                deferred = []
                for state in pending:
                    try:
                        state.render(self)
                    except InvalidBasesError:
                        deferred.append(state)
                no_progress = len(deferred) == len(pending)
                if no_progress:
                    raise InvalidBasesError(
                        f"Cannot resolve bases for {deferred!r}\nThis can happen if you are "
                        "inheriting models from an app with migrations (e.g. "
                        "contrib.auth)\n in an app with no migrations"
                    )
                pending = deferred

    def clone(self) -> StateModelsRegistry:
        """Return a clone of this registry."""
        duplicate = StateModelsRegistry(set(), {})
        duplicate.all_models = copy.deepcopy(self.all_models)
        # Real models never change, so the clone can share the same list.
        duplicate.real_models = self.real_models
        return duplicate

    def register_model(self, package_label: str, model: type[models.Model]) -> None:
        """Register ``model`` under ``package_label`` and run pending operations."""
        package_models = self.all_models[package_label]
        package_models[model.model_options.model_name] = model
        self.do_pending_operations(model)
        self.clear_cache()

    def unregister_model(self, package_label: str, model_name: str) -> None:
        """Remove a model if present. The cache is NOT cleared here."""
        try:
            del self.all_models[package_label][model_name]
        except KeyError:
            # Already absent: nothing to do.
            pass
668
669
class ModelState:
    """
    Represent a Plain Model. Don't use the actual Model class as it's not
    designed to have its options changed - instead, mutate this one and then
    render it into a Model as required.

    Note that while you are allowed to mutate .fields, you are not allowed
    to mutate the Field instances inside there themselves - you must instead
    assign new ones, as these are not detached during a clone.
    """

    def __init__(
        self,
        package_label: str,
        name: str,
        fields: Iterable[tuple[str, Field]],
        options: dict[str, Any] | None = None,
        bases: tuple[str | type[models.Model], ...] | None = None,
    ):
        """
        Build the state and sanity-check its fields and indexes.

        Raises:
            ValueError: if a field is already bound to a model, if a
                relational field refers to a model class instead of a
                string, or if an index has no name.
        """
        self.package_label = package_label
        self.name = name
        self.fields: dict[str, Field] = dict(fields)
        # A non-empty ``options`` dict is stored by reference, so the
        # setdefault() calls below also add these keys to the caller's dict.
        self.options = options or {}
        self.options.setdefault("indexes", [])
        self.options.setdefault("constraints", [])
        self.bases = bases or (models.Model,)

        from plain.models.fields.related import ManyToManyField

        for field_name, field in self.fields.items():
            # Sanity-check that fields are NOT already bound to a model.
            if hasattr(field, "model"):
                raise ValueError(
                    f'ModelState.fields cannot be bound to a model - "{field_name}" is.'
                )
            # Sanity-check that relation fields are NOT referring to a model class.
            if isinstance(field, RelatedField) and hasattr(
                field.related_model, "_model_meta"
            ):
                raise ValueError(
                    f'ModelState.fields cannot refer to a model class - "{field_name}.to" does. '
                    "Use a string reference instead."
                )
            if isinstance(field, ManyToManyField) and hasattr(
                field.remote_field.through, "_model_meta"
            ):
                raise ValueError(
                    f'ModelState.fields cannot refer to a model class - "{field_name}.through" '
                    "does. Use a string reference instead."
                )
        # Sanity-check that indexes have their name set.
        for index in self.options["indexes"]:
            if not index.name:
                raise ValueError(
                    "Indexes passed to ModelState require a name attribute. "
                    f"{index!r} doesn't have one."
                )

    @cached_property
    def name_lower(self) -> str:
        """Lower-cased model name, used in ProjectState.models keys."""
        return self.name.lower()

    def get_field(self, field_name: str) -> Field:
        """Return the field stored under ``field_name`` (KeyError if absent)."""
        return self.fields[field_name]

    @classmethod
    def from_model(
        cls, model: type[models.Model], exclude_rels: bool = False
    ) -> ModelState:
        """
        Given a model, return a ModelState representing it.

        Local fields (and, unless ``exclude_rels`` is set, local many-to-many
        fields) are cloned so the state doesn't share Field instances with
        ``model``. With ``exclude_rels`` set, local fields that have a
        ``remote_field`` are skipped as well.

        Raises:
            TypeError: if a field cannot be cloned.
        """
        # Deconstruct the fields
        fields: list[tuple[str, Field]] = []
        for field in model._model_meta.local_fields:
            if getattr(field, "remote_field", None) and exclude_rels:
                continue
            name = field.name
            try:
                fields.append((name, field.clone()))
            except TypeError as e:
                raise TypeError(
                    f"Couldn't reconstruct field {name} on {model.model_options.label}: {e}"
                ) from e
        if not exclude_rels:
            for field in model._model_meta.local_many_to_many:
                name = field.name
                try:
                    fields.append((name, field.clone()))
                except TypeError as e:
                    raise TypeError(
                        f"Couldn't reconstruct m2m field {name} on {model.model_options.object_name}: {e}"
                    ) from e

        # Only the model's direct bases are recorded. They are de-duplicated
        # and ordered by their position in the MRO.
        flattened_bases = sorted(set(model.__bases__), key=model.__mro__.index)

        # Make our record: model bases other than models.Model are stored as
        # "package_label.model_name" strings so the state holds no classes.
        bases = tuple(
            (
                base.model_options.label_lower
                if not isinstance(base, str)
                and base is not models.Model
                and hasattr(base, "_model_meta")
                else base
            )
            for base in flattened_bases
        )
        # Ensure at least one base inherits from models.Model
        if not any(
            (isinstance(base, str) or issubclass(base, models.Model)) for base in bases
        ):
            bases = (models.Model,)

        # Construct the new ModelState
        return cls(
            model.model_options.package_label,
            model.model_options.object_name,
            fields,
            model.model_options.export_for_migrations(),
            bases,
        )

    def clone(self) -> ModelState:
        """Return an exact copy of this ModelState."""
        return self.__class__(
            package_label=self.package_label,
            name=self.name,
            fields=dict(self.fields),
            # Since options are shallow-copied here, operations such as
            # AddIndex must replace their option (e.g 'indexes') rather
            # than mutating it.
            options=dict(self.options),
            bases=self.bases,
        )

    def render(self, models_registry: ModelsRegistry) -> type[models.Model]:
        """
        Create a Model class from this state and register it.

        String bases are looked up in ``models_registry``. Every field is
        cloned, so the rendered class never owns this state's Field objects.

        Raises:
            InvalidBasesError: if a string base can't be found in
                ``models_registry``.
        """
        # Create Options instance with metadata
        meta_options = models.Options(
            package_label=self.package_label,
            **self.options,
        )
        # Then, work out our bases
        try:
            bases = tuple(
                (models_registry.get_model(base) if isinstance(base, str) else base)
                for base in self.bases
            )
        except LookupError:
            raise InvalidBasesError(
                f"Cannot resolve one or more bases from {self.bases!r}"
            )
        # Clone fields for the body, add other bits.
        body = {name: field.clone() for name, field in self.fields.items()}
        body["model_options"] = meta_options
        # Give the model a Meta bound to the supplied registry.
        body["_model_meta"] = Meta(models_registry=models_registry)
        body["__module__"] = "__fake__"

        # Then, make a Model object.
        # NOTE(review): it isn't visible here whether class creation itself
        # registers the model; registration is done explicitly below.
        model_class = cast(type[models.Model], type(self.name, bases, body))
        from plain.models import register_model

        # Register it to the models_registry associated with the model meta
        # (could probably do this directly right here too...)
        register_model(model_class)

        return model_class

    def get_index_by_name(self, name: str) -> Any:
        """
        Return the index named ``name``.

        Raises:
            ValueError: if there is no such index.
        """
        for index in self.options["indexes"]:
            if index.name == name:
                return index
        raise ValueError(f"No index named {name} on model {self.name}")

    def get_constraint_by_name(self, name: str) -> Any:
        """
        Return the constraint named ``name``.

        Raises:
            ValueError: if there is no such constraint.
        """
        for constraint in self.options["constraints"]:
            if constraint.name == name:
                return constraint
        raise ValueError(f"No constraint named {name} on model {self.name}")

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: '{self.package_label}.{self.name}'>"

    def __eq__(self, other: object) -> bool:
        """
        Compare two states.

        Label, name, options and bases must be equal. Fields must have the
        same names and equal ``deconstruct()`` output, ignoring its first
        element.
        """
        if not isinstance(other, ModelState):
            return NotImplemented
        return (
            (self.package_label == other.package_label)
            and (self.name == other.name)
            and (len(self.fields) == len(other.fields))
            and all(
                k1 == k2 and f1.deconstruct()[1:] == f2.deconstruct()[1:]
                for (k1, f1), (k2, f2) in zip(
                    sorted(self.fields.items()),
                    sorted(other.fields.items()),
                )
            )
            and (self.options == other.options)
            and (self.bases == other.bases)
        )