1from __future__ import annotations
2
3import copy
4from collections import defaultdict
5from contextlib import contextmanager
6from functools import cached_property, partial
7from typing import TYPE_CHECKING, Any, cast
8
9from plain import models
10from plain.models.exceptions import FieldDoesNotExist
11from plain.models.fields import NOT_PROVIDED
12from plain.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT, RelatedField
13from plain.models.meta import Meta
14from plain.models.migrations.utils import field_is_referenced, get_references
15from plain.models.registry import ModelsRegistry
16from plain.models.registry import models_registry as global_models
17from plain.packages import packages_registry
18
19from .exceptions import InvalidBasesError
20from .utils import resolve_relation
21
22if TYPE_CHECKING:
23 from collections.abc import Generator, Iterable
24
25 from plain.models.fields import Field
26
27
28def _get_package_label_and_model_name(
29 model: str | type[models.Model], package_label: str = ""
30) -> tuple[str, str]:
31 if isinstance(model, str):
32 split = model.split(".", 1)
33 return (
34 cast(tuple[str, str], tuple(split))
35 if len(split) == 2
36 else (package_label, split[0])
37 )
38 else:
39 return model.model_options.package_label, model.model_options.model_name
40
41
def _get_related_models(m: type[models.Model]) -> list[type[models.Model]]:
    """
    Return all models that have a direct relationship to the given model.

    The result contains, in order:

    * every direct subclass of ``m`` that is itself a model, then
    * the ``related_model`` of every forward (``RelatedField``) or reverse
      (``ForeignObjectRel``) relation on ``m`` whose target has already been
      resolved to a class (unresolved string references are skipped).

    The list may contain duplicates and may include ``m`` itself.
    """
    from plain.models.fields.reverse_related import ForeignObjectRel

    related_models = [
        subclass
        for subclass in m.__subclasses__()
        if issubclass(subclass, models.Model)
    ]
    for f in m._model_meta.get_fields(include_reverse=True):
        if isinstance(f, RelatedField | ForeignObjectRel) and not isinstance(
            f.related_model, str
        ):
            related_models.append(f.related_model)
    return related_models
59
60
def get_related_models_tuples(model: type[models.Model]) -> set[tuple[str, str]]:
    """
    Return a set of ``(package_label, model_name)`` tuples, one for each
    model directly related to ``model`` (as found by ``_get_related_models``).
    """
    keys: set[tuple[str, str]] = set()
    for related in _get_related_models(model):
        opts = related.model_options
        keys.add((opts.package_label, opts.model_name))
    return keys
70
71
def get_related_models_recursive(model: type[models.Model]) -> set[tuple[str, str]]:
    """
    Return all models that have a direct or indirect relationship
    to the given model.

    Relationships are either defined by explicit relational fields, like
    ForeignKeyField or ManyToManyField, or by inheriting from another
    model (a superclass is related to its subclasses, but not vice versa).

    The result is a set of ``(package_label, model_name)`` tuples and never
    contains ``model``'s own key.
    """
    visited: set[tuple[str, str]] = set()
    # Traversal order does not matter: only the set of reachable keys is
    # returned, and each model's neighbours are expanded at most once.
    to_visit = list(_get_related_models(model))
    while to_visit:
        candidate = to_visit.pop()
        key = (
            candidate.model_options.package_label,
            candidate.model_options.model_name,
        )
        if key in visited:
            continue
        visited.add(key)
        to_visit.extend(_get_related_models(candidate))
    own_key = (model.model_options.package_label, model.model_options.model_name)
    return visited - {own_key}
93
94
class ProjectState:
    """
    Represent the entire project's overall state. This is the item that is
    passed around - do it here rather than at the package level so that
    cross-package FKs/etc. resolve properly.

    ``models`` maps ``(package_label, model_name_lower)`` keys to ModelState
    objects. Two derived structures are built lazily and, once they exist,
    are updated incrementally by the mutating methods below:

    * ``relations``: ``{remote_model_key: {model_key: {field_name: field}}}``
    * ``models_registry``: a StateModelsRegistry of rendered model classes
      (a cached property, so its presence is tested via ``self.__dict__``).
    """

    def __init__(
        self,
        models: dict[tuple[str, str], ModelState] | None = None,
        real_packages: set[str] | None = None,
    ):
        self.models = models or {}
        # Packages to include from main registry, usually unmigrated ones
        if real_packages is None:
            real_packages = set()
        else:
            assert isinstance(real_packages, set)
        self.real_packages = real_packages
        # Set once any reload has been delayed; see clear_delayed_models_cache().
        self.is_delayed = False
        # {remote_model_key: {model_key: {field_name: field}}}
        self._relations: (
            dict[tuple[str, str], dict[tuple[str, str], dict[str, Field]]] | None
        ) = None

    @property
    def relations(
        self,
    ) -> dict[tuple[str, str], dict[tuple[str, str], dict[str, Field]]]:
        """Return the relation map, computing it on first access."""
        if self._relations is None:
            self.resolve_fields_and_relations()
        assert self._relations is not None
        return self._relations

    def add_model(self, model_state: ModelState) -> None:
        """Add ``model_state`` and update any already-built derived state."""
        model_key = model_state.package_label, model_state.name_lower
        self.models[model_key] = model_state
        if self._relations is not None:
            self.resolve_model_relations(model_key)
        if "models_registry" in self.__dict__:  # hasattr would cache the property
            self.reload_model(*model_key)

    def remove_model(self, package_label: str, model_name: str) -> None:
        """
        Remove a model state, drop every relation entry involving it, and
        unregister it from the rendered registry if one exists.

        Raises KeyError if the model is not in ``self.models``.
        """
        model_key = package_label, model_name
        del self.models[model_key]
        if self._relations is not None:
            self._relations.pop(model_key, None)
            # Call list() since _relations can change size during iteration.
            for related_model_key, model_relations in list(self._relations.items()):
                model_relations.pop(model_key, None)
                if not model_relations:
                    del self._relations[related_model_key]
        if "models_registry" in self.__dict__:  # hasattr would cache the property
            self.models_registry.unregister_model(*model_key)
            # Need to do this explicitly since unregister_model() doesn't clear
            # the cache automatically (#24513)
            self.models_registry.clear_cache()

    def rename_model(self, package_label: str, old_name: str, new_name: str) -> None:
        """
        Rename a model, repointing every relational field (``to`` and
        ``through`` references) that targeted the old name.
        """
        # Add a new model.
        old_name_lower = old_name.lower()
        new_name_lower = new_name.lower()
        renamed_model = self.models[package_label, old_name_lower].clone()
        renamed_model.name = new_name
        self.models[package_label, new_name_lower] = renamed_model
        # Repoint all fields pointing to the old model to the new one.
        old_model_tuple = (package_label, old_name_lower)
        new_remote_model = f"{package_label}.{new_name}"
        to_reload = set()
        for model_state, name, field, reference in get_references(
            self, old_model_tuple
        ):
            if not isinstance(field, RelatedField):
                continue
            changed_field = None
            if reference.to:
                changed_field = field.clone()
                assert changed_field.remote_field is not None
                changed_field.remote_field.model = new_remote_model  # type: ignore[assignment]
            if reference.through:
                if changed_field is None:
                    changed_field = field.clone()
                assert changed_field.remote_field is not None
                changed_field.remote_field.through = new_remote_model  # type: ignore[attr-defined]
            if changed_field:
                # Fields must be replaced, not mutated (see ModelState).
                model_state.fields[name] = changed_field
                to_reload.add((model_state.package_label, model_state.name_lower))
        if self._relations is not None:
            old_name_key = package_label, old_name_lower
            new_name_key = package_label, new_name_lower
            if old_name_key in self._relations:
                self._relations[new_name_key] = self._relations.pop(old_name_key)
            for model_relations in self._relations.values():
                if old_name_key in model_relations:
                    model_relations[new_name_key] = model_relations.pop(old_name_key)
        # Reload models related to old model before removing the old model.
        self.reload_models(to_reload, delay=True)
        # Remove the old model.
        self.remove_model(package_label, old_name_lower)
        self.reload_model(package_label, new_name_lower, delay=True)

    def alter_model_options(
        self,
        package_label: str,
        model_name: str,
        options: dict[str, Any],
        option_keys: Iterable[str] | None = None,
    ) -> None:
        """
        Merge ``options`` into the model's options. Any key listed in
        ``option_keys`` but absent from ``options`` is removed.
        """
        model_state = self.models[package_label, model_name]
        model_state.options = {**model_state.options, **options}
        if option_keys:
            for key in option_keys:
                if key not in options:
                    model_state.options.pop(key, False)
        self.reload_model(package_label, model_name, delay=True)

    def _append_option(
        self, package_label: str, model_name: str, option_name: str, obj: Any
    ) -> None:
        """Append ``obj`` to a list-valued option, replacing (not mutating) the list."""
        model_state = self.models[package_label, model_name]
        # Options are shallow-copied on clone, so build a new list.
        model_state.options[option_name] = [*model_state.options[option_name], obj]
        self.reload_model(package_label, model_name, delay=True)

    def _remove_option(
        self, package_label: str, model_name: str, option_name: str, obj_name: str
    ) -> None:
        """Drop every entry named ``obj_name`` from a list-valued option."""
        model_state = self.models[package_label, model_name]
        objs = model_state.options[option_name]
        model_state.options[option_name] = [obj for obj in objs if obj.name != obj_name]
        self.reload_model(package_label, model_name, delay=True)

    def add_index(self, package_label: str, model_name: str, index: Any) -> None:
        """Append ``index`` to the model's ``indexes`` option."""
        self._append_option(package_label, model_name, "indexes", index)

    def remove_index(
        self, package_label: str, model_name: str, index_name: str
    ) -> None:
        """Remove the index named ``index_name`` from the model."""
        self._remove_option(package_label, model_name, "indexes", index_name)

    def rename_index(
        self,
        package_label: str,
        model_name: str,
        old_index_name: str,
        new_index_name: str,
    ) -> None:
        """Rename an index, cloning it rather than mutating the shared instance."""
        model_state = self.models[package_label, model_name]
        objs = model_state.options["indexes"]

        new_indexes = []
        for obj in objs:
            if obj.name == old_index_name:
                obj = obj.clone()
                obj.name = new_index_name
            new_indexes.append(obj)

        model_state.options["indexes"] = new_indexes
        self.reload_model(package_label, model_name, delay=True)

    def add_constraint(
        self, package_label: str, model_name: str, constraint: Any
    ) -> None:
        """Append ``constraint`` to the model's ``constraints`` option."""
        self._append_option(package_label, model_name, "constraints", constraint)

    def remove_constraint(
        self, package_label: str, model_name: str, constraint_name: str
    ) -> None:
        """Remove the constraint named ``constraint_name`` from the model."""
        self._remove_option(package_label, model_name, "constraints", constraint_name)

    def add_field(
        self,
        package_label: str,
        model_name: str,
        name: str,
        field: Field,
        preserve_default: bool,
    ) -> None:
        """
        Add ``field`` to the model as ``name``. When ``preserve_default`` is
        false, a clone without a default is stored instead.
        """
        # If preserve default is off, don't use the default for future state.
        if not preserve_default:
            field = field.clone()
            field.default = NOT_PROVIDED
        model_key = package_label, model_name
        self.models[model_key].fields[name] = field
        if self._relations is not None:
            self.resolve_model_field_relations(model_key, name, field)
        # Delay rendering of relationships if it's not a relational field.
        delay = not isinstance(field, RelatedField)
        self.reload_model(*model_key, delay=delay)

    def remove_field(self, package_label: str, model_name: str, name: str) -> None:
        """Remove field ``name`` from the model and drop its relation entries."""
        model_key = package_label, model_name
        model_state = self.models[model_key]
        old_field = model_state.fields.pop(name)
        if self._relations is not None:
            # The field is no longer in model_state.fields, so this removes
            # its relation entries rather than adding them.
            self.resolve_model_field_relations(model_key, name, old_field)
        # Delay rendering of relationships if it's not a relational field.
        delay = not isinstance(old_field, RelatedField)
        self.reload_model(*model_key, delay=delay)

    def alter_field(
        self,
        package_label: str,
        model_name: str,
        name: str,
        field: Field,
        preserve_default: bool,
    ) -> None:
        """
        Replace field ``name`` with ``field``. When ``preserve_default`` is
        false, a clone without a default is stored instead.
        """
        if not preserve_default:
            field = field.clone()
            field.default = NOT_PROVIDED
        model_key = package_label, model_name
        fields = self.models[model_key].fields
        if self._relations is not None:
            # Remove the old field's relations first (it must be absent from
            # ``fields`` for that), then register the new field's relations.
            old_field = fields.pop(name)
            if isinstance(old_field, RelatedField):
                self.resolve_model_field_relations(model_key, name, old_field)
            fields[name] = field
            if isinstance(field, RelatedField):
                self.resolve_model_field_relations(model_key, name, field)
        else:
            fields[name] = field
        # TODO: investigate if old relational fields must be reloaded or if
        # it's sufficient if the new field is (#27737).
        # Delay rendering of relationships if it's not a relational field and
        # not referenced by a foreign key.
        delay = not isinstance(field, RelatedField) and not field_is_referenced(
            self, model_key, (name, field)
        )
        self.reload_model(*model_key, delay=delay)

    def rename_field(
        self, package_label: str, model_name: str, old_name: str, new_name: str
    ) -> None:
        """
        Rename field ``old_name`` to ``new_name`` on the given model.

        Raises FieldDoesNotExist if the model has no field named ``old_name``.
        """
        model_key = package_label, model_name
        model_state = self.models[model_key]
        # Rename the field.
        fields = model_state.fields
        try:
            found = fields.pop(old_name)
        except KeyError:
            raise FieldDoesNotExist(
                f"{package_label}.{model_name} has no field named '{old_name}'"
            )
        fields[new_name] = found
        # Rendering can only be delayed when nothing references the field.
        # get_references() is consumed as an iterator elsewhere (e.g. in
        # rename_model) and a generator object is always truthy, so
        # bool(references) cannot answer this; check for a first item instead.
        references = get_references(self, model_key, (old_name, found))
        delay = not any(True for _ in references)
        if self._relations is not None:
            old_name_lower = old_name.lower()
            new_name_lower = new_name.lower()
            for to_model in self._relations.values():
                # Use .get(): the inner mappings are defaultdicts, and
                # indexing would insert empty entries for unrelated models.
                model_relations = to_model.get(model_key)
                if model_relations and old_name_lower in model_relations:
                    relation_field = model_relations.pop(old_name_lower)
                    relation_field.name = new_name_lower
                    model_relations[new_name_lower] = relation_field
        self.reload_model(*model_key, delay=delay)

    def _find_reload_model(
        self, package_label: str, model_name: str, delay: bool = False
    ) -> set[tuple[str, str]]:
        """
        Return the keys of every model to re-render when the given model
        changes, including the model itself.

        With ``delay`` set, only direct relations are followed and
        ``is_delayed`` is flagged; otherwise relations are followed
        recursively.
        """
        if delay:
            self.is_delayed = True

        related_models: set[tuple[str, str]] = set()

        try:
            old_model = self.models_registry.get_model(package_label, model_name)
        except LookupError:
            pass
        else:
            # Get all relations to and from the old model before reloading,
            # as _model_meta.models_registry may change
            if delay:
                related_models = get_related_models_tuples(old_model)
            else:
                related_models = get_related_models_recursive(old_model)

        # Get all outgoing references from the model to be rendered
        model_state = self.models[(package_label, model_name)]
        # Directly related models are the models pointed to by ForeignKeys and ManyToManyFields.
        direct_related_models = set()
        for field in model_state.fields.values():
            if isinstance(field, RelatedField):
                if field.remote_field.model == RECURSIVE_RELATIONSHIP_CONSTANT:
                    continue
                rel_package_label, rel_model_name = _get_package_label_and_model_name(
                    field.related_model,
                    package_label,
                )
                direct_related_models.add((rel_package_label, rel_model_name.lower()))

        # For all direct related models recursively get all related models.
        related_models.update(direct_related_models)
        for rel_package_label, rel_model_name in direct_related_models:
            try:
                rel_model = self.models_registry.get_model(
                    rel_package_label, rel_model_name
                )
            except LookupError:
                pass
            else:
                if delay:
                    related_models.update(get_related_models_tuples(rel_model))
                else:
                    related_models.update(get_related_models_recursive(rel_model))

        # Include the model itself
        related_models.add((package_label, model_name))

        return related_models

    def reload_model(
        self, package_label: str, model_name: str, delay: bool = False
    ) -> None:
        """Re-render one model and its related models, if a registry exists."""
        if "models_registry" in self.__dict__:  # hasattr would cache the property
            related_models = self._find_reload_model(package_label, model_name, delay)
            self._reload(related_models)

    def reload_models(self, models: set[tuple[str, str]], delay: bool = True) -> None:
        """Re-render several models (and their related models) in one pass."""
        if "models_registry" in self.__dict__:  # hasattr would cache the property
            related_models = set()
            for package_label, model_name in models:
                related_models.update(
                    self._find_reload_model(package_label, model_name, delay)
                )
            self._reload(related_models)

    def _reload(self, related_models: set[tuple[str, str]]) -> None:
        """
        Unregister ``related_models`` from the registry, then re-render those
        that still have a state (real-package or migrated).
        """
        # Unregister all related models
        with self.models_registry.bulk_update():
            for rel_package_label, rel_model_name in related_models:
                self.models_registry.unregister_model(rel_package_label, rel_model_name)

        states_to_be_rendered = []
        # Gather all models states of those models that will be rerendered.
        # This includes:
        # 1. All related models of unmigrated packages
        for model_state in self.models_registry.real_models:
            if (model_state.package_label, model_state.name_lower) in related_models:
                states_to_be_rendered.append(model_state)

        # 2. All related models of migrated packages
        for rel_package_label, rel_model_name in related_models:
            try:
                model_state = self.models[rel_package_label, rel_model_name]
            except KeyError:
                pass
            else:
                states_to_be_rendered.append(model_state)

        # Render all models
        self.models_registry.render_multiple(states_to_be_rendered)

    def update_model_field_relation(
        self,
        model: str | type[models.Model],
        model_key: tuple[str, str],
        field_name: str,
        field: Field,
        concretes: dict[tuple[str, str], tuple[str, str]],
    ) -> None:
        """
        Add or remove the relation entry from ``model_key.field_name`` to
        ``model``: it is added while the field exists on the model state and
        removed once it doesn't.
        """
        assert self._relations is not None
        remote_model_key = resolve_relation(model, *model_key)
        if (
            remote_model_key[0] not in self.real_packages
            and remote_model_key in concretes
        ):
            remote_model_key = concretes[remote_model_key]
        relations_to_remote_model = self._relations[remote_model_key]
        if field_name in self.models[model_key].fields:
            # The assert holds because it's a new relation, or an altered
            # relation, in which case references have been removed by
            # alter_field().
            assert field_name not in relations_to_remote_model[model_key]
            relations_to_remote_model[model_key][field_name] = field
        else:
            del relations_to_remote_model[model_key][field_name]
            if not relations_to_remote_model[model_key]:
                del relations_to_remote_model[model_key]

    def resolve_model_field_relations(
        self,
        model_key: tuple[str, str],
        field_name: str,
        field: Field,
        concretes: dict[tuple[str, str], tuple[str, str]] | None = None,
    ) -> None:
        """
        Update relation entries for ``field`` (its target model and, if
        present, its ``through`` model). Non-relational fields are ignored.
        """
        # Only process fields that have relations
        if not isinstance(field, RelatedField):
            return None
        remote_field = field.remote_field
        if not remote_field:
            return None
        if concretes is None:
            concretes = self._get_concrete_models_mapping()

        self.update_model_field_relation(
            remote_field.model,
            model_key,
            field_name,
            field,
            concretes,
        )

        through = getattr(remote_field, "through", None)
        if not through:
            return None
        self.update_model_field_relation(
            through, model_key, field_name, field, concretes
        )

    def resolve_model_relations(
        self,
        model_key: tuple[str, str],
        concretes: dict[tuple[str, str], tuple[str, str]] | None = None,
    ) -> None:
        """Update relation entries for every field of the given model."""
        if concretes is None:
            concretes = self._get_concrete_models_mapping()

        model_state = self.models[model_key]
        for field_name, field in model_state.fields.items():
            self.resolve_model_field_relations(model_key, field_name, field, concretes)

    def resolve_fields_and_relations(self) -> None:
        """Set every field's ``name`` and rebuild ``_relations`` from scratch."""
        # Resolve fields.
        for model_state in self.models.values():
            for field_name, field in model_state.fields.items():
                field.name = field_name
        # Resolve relations.
        # {remote_model_key: {model_key: {field_name: field}}}
        self._relations = defaultdict(partial(defaultdict, dict))
        concretes = self._get_concrete_models_mapping()

        for model_key in concretes:
            self.resolve_model_relations(model_key, concretes)

    def _get_concrete_models_mapping(self) -> dict[tuple[str, str], tuple[str, str]]:
        """Map each model key to its concrete model key (currently itself)."""
        return {model_key: model_key for model_key in self.models}

    def clone(self) -> ProjectState:
        """Return an exact copy of this ProjectState."""
        new_state = ProjectState(
            models={k: v.clone() for k, v in self.models.items()},
            real_packages=self.real_packages,
        )
        if "models_registry" in self.__dict__:
            new_state.models_registry = self.models_registry.clone()
        new_state.is_delayed = self.is_delayed
        return new_state

    def clear_delayed_models_cache(self) -> None:
        """Drop the cached registry if any reload was delayed."""
        if self.is_delayed and "models_registry" in self.__dict__:
            del self.__dict__["models_registry"]

    @cached_property
    def models_registry(self) -> StateModelsRegistry:
        """Registry of rendered model classes, built on first access."""
        return StateModelsRegistry(self.real_packages, self.models)

    @classmethod
    def from_models_registry(cls, models_registry: ModelsRegistry) -> ProjectState:
        """Take a models registry and return a ProjectState matching it."""
        app_models = {}
        for model in models_registry.get_models():
            model_state = ModelState.from_model(model)
            app_models[(model_state.package_label, model_state.name_lower)] = (
                model_state
            )
        return cls(app_models)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ProjectState):
            return NotImplemented
        return self.models == other.models and self.real_packages == other.real_packages
575
576
class StateModelsRegistry(ModelsRegistry):
    """
    Models registry that renders ModelState objects and better handles
    dynamic model additions and removals than the global registry.
    """

    def __init__(
        self,
        real_packages: set[str],
        models: dict[tuple[str, str], ModelState],
    ):
        # Every package in real_packages has all of its models included in
        # the render. The original model classes are not reused because some
        # of their attributes refer to the global registry. Relational fields
        # of real packages are excluded: they cause trouble with partial
        # states, which lack the necessary dependencies.
        self.real_models: list[ModelState] = [
            ModelState.from_model(real_model, exclude_rels=True)
            for package_label in real_packages
            for real_model in global_models.get_models(package_label=package_label)
        ]

        super().__init__()

        self.render_multiple(list(models.values()) + self.real_models)

        self.ready = True

        # No lazy operations should still be pending once everything is rendered.
        from plain.models.preflight import _check_lazy_references

        errors = _check_lazy_references(self, packages_registry)
        if errors:
            raise ValueError("\n".join(error.fix for error in errors))

    @contextmanager
    def bulk_update(self) -> Generator[None, None, None]:
        """
        Suspend ``ready`` while the body runs, then restore it and clear the
        caches once, instead of clearing them after every single change.
        """
        previous_ready = self.ready
        self.ready = False
        try:
            yield
        finally:
            self.ready = previous_ready
            self.clear_cache()

    def render_multiple(self, model_states: list[ModelState]) -> None:
        """
        Render ``model_states`` into this registry.

        States whose bases cannot be resolved yet are retried in later
        passes. If a pass renders nothing, there is a base dependency loop or
        a missing base, and InvalidBasesError is raised.
        """
        if not model_states:
            return
        # One cache clear for the whole batch rather than per model.
        with self.bulk_update():
            pending = model_states
            while pending:
                deferred = []
                for state in pending:
                    try:
                        state.render(self)
                    except InvalidBasesError:
                        deferred.append(state)
                no_progress = len(deferred) == len(pending)
                if no_progress:
                    raise InvalidBasesError(
                        f"Cannot resolve bases for {deferred!r}\nThis can happen if you are "
                        "inheriting models from an app with migrations (e.g. "
                        "contrib.auth)\n in an app with no migrations"
                    )
                pending = deferred

    def clone(self) -> StateModelsRegistry:
        """Return a clone of this registry."""
        duplicate = StateModelsRegistry(set(), {})
        duplicate.all_models = copy.deepcopy(self.all_models)
        # real_models never change, so sharing the list is safe.
        duplicate.real_models = self.real_models
        return duplicate

    def register_model(self, package_label: str, model: type[models.Model]) -> None:
        """
        Store ``model`` under ``package_label``, run lazy operations waiting
        on it, then clear the registry caches.
        """
        package_models = self.all_models[package_label]
        package_models[model.model_options.model_name] = model
        self.do_pending_operations(model)
        self.clear_cache()

    def unregister_model(self, package_label: str, model_name: str) -> None:
        """Forget a model. Missing entries are silently ignored; caches are not cleared."""
        try:
            package_models = self.all_models[package_label]
            del package_models[model_name]
        except KeyError:
            pass
666
667
class ModelState:
    """
    Represent a Plain Model. Don't use the actual Model class as it's not
    designed to have its options changed - instead, mutate this one and then
    render it into a Model as required.

    Note that while you are allowed to mutate .fields, you are not allowed
    to mutate the Field instances inside there themselves - you must instead
    assign new ones, as these are not detached during a clone.
    """

    def __init__(
        self,
        package_label: str,
        name: str,
        fields: Iterable[tuple[str, Field]],
        options: dict[str, Any] | None = None,
        bases: tuple[str | type[models.Model], ...] | None = None,
    ):
        """
        Build a model state.

        ``fields`` must be unbound fields whose relations use string
        references. ``options`` is copied, and ``indexes`` and
        ``constraints`` default to empty lists. ``bases`` defaults to
        ``(models.Model,)``.

        Raises ValueError if a field is bound to a model, if a relation
        targets a model class instead of a string, or if an index has no name.
        """
        self.package_label = package_label
        self.name = name
        self.fields: dict[str, Field] = dict(fields)
        # Copy so the setdefault() calls below don't mutate a dict that the
        # caller still owns.
        self.options = dict(options) if options else {}
        self.options.setdefault("indexes", [])
        self.options.setdefault("constraints", [])
        self.bases = bases or (models.Model,)
        # Use a separate loop variable so the ``name`` argument isn't shadowed.
        for field_name, field in self.fields.items():
            # Sanity-check that fields are NOT already bound to a model.
            if hasattr(field, "model"):
                raise ValueError(
                    f'ModelState.fields cannot be bound to a model - "{field_name}" is.'
                )
            # Sanity-check that relation fields are NOT referring to a model class.
            if isinstance(field, RelatedField) and hasattr(
                field.related_model, "_model_meta"
            ):
                raise ValueError(
                    f'ModelState.fields cannot refer to a model class - "{field_name}.to" does. '
                    "Use a string reference instead."
                )
            from plain.models.fields.related import ManyToManyField

            if isinstance(field, ManyToManyField) and hasattr(
                field.remote_field.through, "_model_meta"
            ):
                raise ValueError(
                    f'ModelState.fields cannot refer to a model class - "{field_name}.through" '
                    "does. Use a string reference instead."
                )
        # Sanity-check that indexes have their name set.
        for index in self.options["indexes"]:
            if not index.name:
                raise ValueError(
                    "Indexes passed to ModelState require a name attribute. "
                    f"{index!r} doesn't have one."
                )

    @cached_property
    def name_lower(self) -> str:
        """Lower-cased model name, as used in ProjectState keys."""
        return self.name.lower()

    def get_field(self, field_name: str) -> Field:
        """Return the field named ``field_name`` (KeyError if absent)."""
        return self.fields[field_name]

    @classmethod
    def from_model(
        cls, model: type[models.Model], exclude_rels: bool = False
    ) -> ModelState:
        """
        Given a model, return a ModelState representing it.

        Local fields are cloned into the state. Local many-to-many fields are
        also cloned unless ``exclude_rels`` is true. When it is true, local
        fields with a truthy ``remote_field`` are skipped as well.

        Raises TypeError if a field cannot be cloned.
        """
        # Deconstruct the fields
        fields = []
        for field in model._model_meta.local_fields:
            if getattr(field, "remote_field", None) and exclude_rels:
                continue
            name = field.name
            try:
                fields.append((name, field.clone()))
            except TypeError as e:
                raise TypeError(
                    f"Couldn't reconstruct field {name} on {model.model_options.label}: {e}"
                ) from e
        if not exclude_rels:
            for field in model._model_meta.local_many_to_many:
                name = field.name
                try:
                    fields.append((name, field.clone()))
                except TypeError as e:
                    raise TypeError(
                        f"Couldn't reconstruct m2m field {name} on {model.model_options.object_name}: {e}"
                    ) from e

        # Keep the model's direct bases, deduplicated and ordered by their
        # position in the MRO.
        flattened_bases = sorted(set(model.__bases__), key=model.__mro__.index)

        # Make our record: model bases become "label.model" strings, except
        # models.Model itself, which is kept as a class.
        bases = tuple(
            (
                base.model_options.label_lower
                if not isinstance(base, str)
                and base is not models.Model
                and hasattr(base, "_model_meta")
                else base
            )
            for base in flattened_bases
        )
        # Ensure at least one base inherits from models.Model
        if not any(
            (isinstance(base, str) or issubclass(base, models.Model)) for base in bases
        ):
            bases = (models.Model,)

        # Construct the new ModelState
        return cls(
            model.model_options.package_label,
            model.model_options.object_name,
            fields,
            model.model_options.export_for_migrations(),
            bases,
        )

    def clone(self) -> ModelState:
        """Return an exact copy of this ModelState."""
        return self.__class__(
            package_label=self.package_label,
            name=self.name,
            fields=dict(self.fields),
            # Since options are shallow-copied here, operations such as
            # AddIndex must replace their option (e.g 'indexes') rather
            # than mutating it.
            options=dict(self.options),
            bases=self.bases,
        )

    def render(self, models_registry: ModelsRegistry) -> type[models.Model]:
        """
        Create a Model class from this state and register it in
        ``models_registry``.

        Raises InvalidBasesError if a string base can't be found in the
        registry yet.
        """
        # Create Options instance with metadata
        meta_options = models.Options(
            package_label=self.package_label,
            **self.options,
        )
        # Then, work out our bases
        try:
            bases = tuple(
                (models_registry.get_model(base) if isinstance(base, str) else base)
                for base in self.bases
            )
        except LookupError:
            raise InvalidBasesError(
                f"Cannot resolve one or more bases from {self.bases!r}"
            )
        # Clone fields for the body, add other bits.
        body = {name: field.clone() for name, field in self.fields.items()}
        body["model_options"] = meta_options
        # Point the model's meta at the given registry, not the global one.
        body["_model_meta"] = Meta(models_registry=models_registry)
        body["__module__"] = "__fake__"

        # Then, make a Model object (models_registry.register_model is called in __new__)
        model_class = cast(type[models.Model], type(self.name, bases, body))
        from plain.models import register_model

        # Register it to the models_registry associated with the model meta
        # (could probably do this directly right here too...)
        register_model(model_class)

        return model_class

    def get_index_by_name(self, name: str) -> Any:
        """Return the index named ``name``; raise ValueError if there is none."""
        for index in self.options["indexes"]:
            if index.name == name:
                return index
        raise ValueError(f"No index named {name} on model {self.name}")

    def get_constraint_by_name(self, name: str) -> Any:
        """Return the constraint named ``name``; raise ValueError if there is none."""
        for constraint in self.options["constraints"]:
            if constraint.name == name:
                return constraint
        raise ValueError(f"No constraint named {name} on model {self.name}")

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: '{self.package_label}.{self.name}'>"

    def __eq__(self, other: object) -> bool:
        # Fields are compared by name and by their deconstructed form (minus
        # the leading name element), not by identity.
        if not isinstance(other, ModelState):
            return NotImplemented
        return (
            (self.package_label == other.package_label)
            and (self.name == other.name)
            and (len(self.fields) == len(other.fields))
            and all(
                k1 == k2 and f1.deconstruct()[1:] == f2.deconstruct()[1:]
                for (k1, f1), (k2, f2) in zip(
                    sorted(self.fields.items()),
                    sorted(other.fields.items()),
                )
            )
            and (self.options == other.options)
            and (self.bases == other.bases)
        )