1from __future__ import annotations
2
3import copy
4from collections import defaultdict
5from contextlib import contextmanager
6from functools import cached_property, partial
7from typing import TYPE_CHECKING, Any
8
9from plain import models
10from plain.models.exceptions import FieldDoesNotExist
11from plain.models.fields import NOT_PROVIDED
12from plain.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
13from plain.models.meta import Meta
14from plain.models.migrations.utils import field_is_referenced, get_references
15from plain.models.registry import ModelsRegistry
16from plain.models.registry import models_registry as global_models
17from plain.packages import packages_registry
18
19from .exceptions import InvalidBasesError
20from .utils import resolve_relation
21
22if TYPE_CHECKING:
23 from collections.abc import Generator, Iterable
24
25 from plain.models.fields import Field
26
27
28def _get_package_label_and_model_name(
29 model: str | type[models.Model], package_label: str = ""
30) -> tuple[str, str]:
31 if isinstance(model, str):
32 split = model.split(".", 1)
33 return tuple(split) if len(split) == 2 else (package_label, split[0]) # type: ignore[return-value]
34 else:
35 return model.model_options.package_label, model.model_options.model_name
36
37
38def _get_related_models(m: type[models.Model]) -> list[type[models.Model]]:
39 """Return all models that have a direct relationship to the given model."""
40 related_models = [
41 subclass
42 for subclass in m.__subclasses__()
43 if issubclass(subclass, models.Model)
44 ]
45 related_fields_models = set()
46 for f in m._model_meta.get_fields(include_hidden=True):
47 if (
48 f.is_relation # type: ignore[attr-defined]
49 and f.related_model is not None # type: ignore[attr-defined]
50 and not isinstance(f.related_model, str) # type: ignore[attr-defined]
51 ):
52 related_fields_models.add(f.model) # type: ignore[attr-defined]
53 related_models.append(f.related_model) # type: ignore[attr-defined]
54 return related_models
55
56
def get_related_models_tuples(model: type[models.Model]) -> set[tuple[str, str]]:
    """
    Return the set of (package_label, model_name) tuples identifying every
    model directly related to the given model.
    """
    related_keys = set()
    for related in _get_related_models(model):
        related_keys.add(
            (related.model_options.package_label, related.model_options.model_name)
        )
    return related_keys
66
67
def get_related_models_recursive(model: type[models.Model]) -> set[tuple[str, str]]:
    """
    Return the (package_label, model_name) tuples of every model related to
    the given model, directly or through a chain of other models.

    A relationship comes either from an explicit relational field, like
    ForeignKey or ManyToManyField, or from inheritance (a superclass is
    related to its subclasses, but not vice versa). The given model itself is
    never part of the result.
    """
    visited = set()
    # The worklist deliberately grows while it is being walked: every newly
    # discovered model contributes its own direct relations to the end.
    worklist = _get_related_models(model)
    for candidate in worklist:
        candidate_key = (
            candidate.model_options.package_label,
            candidate.model_options.model_name,
        )
        if candidate_key not in visited:
            visited.add(candidate_key)
            worklist.extend(_get_related_models(candidate))
    own_key = (model.model_options.package_label, model.model_options.model_name)
    return visited - {own_key}
89
90
class ProjectState:
    """
    Represent the entire project's overall state. This is the item that is
    passed around - do it here rather than at the package level so that
    cross-package FKs/etc. resolve properly.
    """

    def __init__(
        self,
        models: dict[tuple[str, str], ModelState] | None = None,
        real_packages: set[str] | None = None,
    ):
        """
        ``models`` maps (package_label, lowercased model name) keys to their
        ModelState. ``real_packages`` must be a set when given.
        """
        self.models = models or {}
        # Packages to include from main registry, usually unmigrated ones
        if real_packages is None:
            real_packages = set()
        else:
            assert isinstance(real_packages, set)
        self.real_packages = real_packages
        # Becomes True as soon as a reload is requested with delay=True.
        self.is_delayed = False
        # {remote_model_key: {model_key: {field_name: field}}}
        self._relations: (
            dict[tuple[str, str], dict[tuple[str, str], dict[str, Field]]] | None
        ) = None

    @property
    def relations(
        self,
    ) -> dict[tuple[str, str], dict[tuple[str, str], dict[str, Field]]]:
        """
        Return the relations mapping, resolving it on first access.

        Shape: {remote_model_key: {model_key: {field_name: field}}}.
        """
        if self._relations is None:
            self.resolve_fields_and_relations()
        return self._relations  # type: ignore[return-value]

    def add_model(self, model_state: ModelState) -> None:
        """
        Store ``model_state`` and keep already-resolved relations and an
        already-rendered registry in sync with it.
        """
        model_key = model_state.package_label, model_state.name_lower
        self.models[model_key] = model_state
        if self._relations is not None:
            self.resolve_model_relations(model_key)
        if "models_registry" in self.__dict__:  # hasattr would cache the property
            self.reload_model(*model_key)

    def remove_model(self, package_label: str, model_name: str) -> None:
        """
        Delete the model from the state, from resolved relations (both as a
        target and as a source) and from the rendered registry, if any.

        Raises KeyError when the model is not part of the state.
        """
        model_key = package_label, model_name
        del self.models[model_key]
        if self._relations is not None:
            self._relations.pop(model_key, None)
            # Call list() since _relations can change size during iteration.
            for related_model_key, model_relations in list(self._relations.items()):
                model_relations.pop(model_key, None)
                if not model_relations:
                    del self._relations[related_model_key]
        if "models_registry" in self.__dict__:  # hasattr would cache the property
            self.models_registry.unregister_model(*model_key)
            # Need to do this explicitly since unregister_model() doesn't clear
            # the cache automatically (#24513)
            self.models_registry.clear_cache()

    def rename_model(self, package_label: str, old_name: str, new_name: str) -> None:
        """
        Rename a model: store a renamed clone, repoint every field that
        referenced the old model (as target or as ``through``), move the
        resolved relations to the new key and drop the old model.
        """
        # Add a new model.
        old_name_lower = old_name.lower()
        new_name_lower = new_name.lower()
        renamed_model = self.models[package_label, old_name_lower].clone()
        renamed_model.name = new_name
        self.models[package_label, new_name_lower] = renamed_model
        # Repoint all fields pointing to the old model to the new one.
        old_model_tuple = (package_label, old_name_lower)
        new_remote_model = f"{package_label}.{new_name}"
        to_reload = set()
        for model_state, name, field, reference in get_references(
            self, old_model_tuple
        ):
            changed_field = None
            if reference.to:
                changed_field = field.clone()
                changed_field.remote_field.model = new_remote_model  # type: ignore[attr-defined]
            if reference.through:
                if changed_field is None:
                    changed_field = field.clone()
                changed_field.remote_field.through = new_remote_model  # type: ignore[attr-defined]
            if changed_field:
                model_state.fields[name] = changed_field
                to_reload.add((model_state.package_label, model_state.name_lower))
        if self._relations is not None:
            old_name_key = package_label, old_name_lower
            new_name_key = package_label, new_name_lower
            if old_name_key in self._relations:
                self._relations[new_name_key] = self._relations.pop(old_name_key)
            for model_relations in self._relations.values():
                if old_name_key in model_relations:
                    model_relations[new_name_key] = model_relations.pop(old_name_key)
        # Reload models related to old model before removing the old model.
        self.reload_models(to_reload, delay=True)
        # Remove the old model.
        self.remove_model(package_label, old_name_lower)
        self.reload_model(package_label, new_name_lower, delay=True)

    def alter_model_options(
        self,
        package_label: str,
        model_name: str,
        options: dict[str, Any],
        option_keys: Iterable[str] | None = None,
    ) -> None:
        """
        Merge ``options`` into the model's options (as a new dict). Every key
        listed in ``option_keys`` but missing from ``options`` is removed.
        """
        model_state = self.models[package_label, model_name]
        model_state.options = {**model_state.options, **options}
        if option_keys:
            for key in option_keys:
                if key not in options:
                    model_state.options.pop(key, False)
        self.reload_model(package_label, model_name, delay=True)

    def _append_option(
        self, package_label: str, model_name: str, option_name: str, obj: Any
    ) -> None:
        """
        Append ``obj`` to the list stored under ``option_name``.

        The list is replaced rather than mutated, because
        ModelState.clone() only shallow-copies the options.
        """
        model_state = self.models[package_label, model_name]
        model_state.options[option_name] = [*model_state.options[option_name], obj]
        self.reload_model(package_label, model_name, delay=True)

    def _remove_option(
        self, package_label: str, model_name: str, option_name: str, obj_name: str
    ) -> None:
        """
        Replace the list stored under ``option_name`` with one that omits
        every object whose ``name`` equals ``obj_name``.
        """
        model_state = self.models[package_label, model_name]
        objs = model_state.options[option_name]
        model_state.options[option_name] = [
            obj
            for obj in objs
            if obj.name != obj_name  # type: ignore[attr-defined]
        ]
        self.reload_model(package_label, model_name, delay=True)

    def add_index(self, package_label: str, model_name: str, index: Any) -> None:
        """Add ``index`` to the model's "indexes" option."""
        self._append_option(package_label, model_name, "indexes", index)

    def remove_index(
        self, package_label: str, model_name: str, index_name: str
    ) -> None:
        """Remove the index named ``index_name`` from the model's options."""
        self._remove_option(package_label, model_name, "indexes", index_name)

    def rename_index(
        self,
        package_label: str,
        model_name: str,
        old_index_name: str,
        new_index_name: str,
    ) -> None:
        """
        Rename an index. The matching index is cloned before its name is
        changed, so the original index object is left untouched.
        """
        model_state = self.models[package_label, model_name]
        objs = model_state.options["indexes"]

        new_indexes = []
        for obj in objs:
            if obj.name == old_index_name:  # type: ignore[attr-defined]
                obj = obj.clone()  # type: ignore[attr-defined]
                obj.name = new_index_name  # type: ignore[attr-defined]
            new_indexes.append(obj)

        model_state.options["indexes"] = new_indexes
        self.reload_model(package_label, model_name, delay=True)

    def add_constraint(
        self, package_label: str, model_name: str, constraint: Any
    ) -> None:
        """Add ``constraint`` to the model's "constraints" option."""
        self._append_option(package_label, model_name, "constraints", constraint)

    def remove_constraint(
        self, package_label: str, model_name: str, constraint_name: str
    ) -> None:
        """Remove the constraint named ``constraint_name`` from the model."""
        self._remove_option(package_label, model_name, "constraints", constraint_name)

    def add_field(
        self,
        package_label: str,
        model_name: str,
        name: str,
        field: Field,
        preserve_default: bool,
    ) -> None:
        """
        Store ``field`` under ``name`` on the model. When ``preserve_default``
        is false, a clone without a default is stored instead.
        """
        # If preserve default is off, don't use the default for future state.
        if not preserve_default:
            field = field.clone()
            field.default = NOT_PROVIDED  # type: ignore[attr-defined]
        model_key = package_label, model_name
        self.models[model_key].fields[name] = field
        if self._relations is not None:
            self.resolve_model_field_relations(model_key, name, field)
        # Delay rendering of relationships if it's not a relational field.
        delay = not field.is_relation  # type: ignore[attr-defined]
        self.reload_model(*model_key, delay=delay)

    def remove_field(self, package_label: str, model_name: str, name: str) -> None:
        """
        Remove the field ``name`` from the model and from resolved relations.

        Raises KeyError when the model or the field does not exist.
        """
        model_key = package_label, model_name
        model_state = self.models[model_key]
        old_field = model_state.fields.pop(name)
        if self._relations is not None:
            self.resolve_model_field_relations(model_key, name, old_field)
        # Delay rendering of relationships if it's not a relational field.
        delay = not old_field.is_relation  # type: ignore[attr-defined]
        self.reload_model(*model_key, delay=delay)

    def alter_field(
        self,
        package_label: str,
        model_name: str,
        name: str,
        field: Field,
        preserve_default: bool,
    ) -> None:
        """
        Replace the field ``name`` on the model with ``field`` (or with a
        default-less clone of it when ``preserve_default`` is false).
        """
        if not preserve_default:
            field = field.clone()
            field.default = NOT_PROVIDED  # type: ignore[attr-defined]
        model_key = package_label, model_name
        fields = self.models[model_key].fields
        if self._relations is not None:
            # The old field has to be gone from ``fields`` while its relation
            # is resolved: update_model_field_relation() removes a relation
            # only when the field name is absent from the model.
            old_field = fields.pop(name)
            if old_field.is_relation:  # type: ignore[attr-defined]
                self.resolve_model_field_relations(model_key, name, old_field)
            fields[name] = field
            if field.is_relation:  # type: ignore[attr-defined]
                self.resolve_model_field_relations(model_key, name, field)
        else:
            fields[name] = field
        # TODO: investigate if old relational fields must be reloaded or if
        # it's sufficient if the new field is (#27737).
        # Delay rendering of relationships if it's not a relational field and
        # not referenced by a foreign key.
        delay = not field.is_relation and not field_is_referenced(  # type: ignore[attr-defined]
            self, model_key, (name, field)
        )
        self.reload_model(*model_key, delay=delay)

    def rename_field(
        self, package_label: str, model_name: str, old_name: str, new_name: str
    ) -> None:
        """
        Rename a field on the model and in the resolved relations.

        Raises FieldDoesNotExist when the model has no field ``old_name``.
        """
        model_key = package_label, model_name
        model_state = self.models[model_key]
        # Rename the field.
        fields = model_state.fields
        try:
            found = fields.pop(old_name)
        except KeyError:
            raise FieldDoesNotExist(
                f"{package_label}.{model_name} has no field named '{old_name}'"
            )
        fields[new_name] = found
        # Rendering may only be delayed when nothing references this field.
        # get_references() is only ever iterated in this file and may be a
        # lazy iterator, whose truthiness says nothing about whether it
        # yields anything - so probe it for a first item instead of bool().
        references = get_references(self, model_key, (old_name, found))
        delay = not any(True for _ in references)
        if self._relations is not None:
            old_name_lower = old_name.lower()
            new_name_lower = new_name.lower()
            for to_model in self._relations.values():
                # .get() instead of indexing: the inner mappings are
                # defaultdicts, so indexing would insert an empty entry for
                # model_key under every remote model.
                model_fields = to_model.get(model_key)
                if model_fields and old_name_lower in model_fields:
                    field = model_fields.pop(old_name_lower)
                    field.name = new_name_lower  # type: ignore[attr-defined]
                    model_fields[new_name_lower] = field
        self.reload_model(*model_key, delay=delay)

    def _find_reload_model(
        self, package_label: str, model_name: str, delay: bool = False
    ) -> set[tuple[str, str]]:
        """
        Return the keys of all models that must be re-rendered together with
        the given model (the model itself included).

        With ``delay`` only direct relations are followed, otherwise they are
        followed recursively. Passing ``delay`` also marks the state delayed.
        """
        if delay:
            self.is_delayed = True

        related_models: set[tuple[str, str]] = set()

        try:
            old_model = self.models_registry.get_model(package_label, model_name)
        except LookupError:
            pass
        else:
            # Get all relations to and from the old model before reloading,
            # as _model_meta.models_registry may change
            if delay:
                related_models = get_related_models_tuples(old_model)
            else:
                related_models = get_related_models_recursive(old_model)

        # Get all outgoing references from the model to be rendered
        model_state = self.models[(package_label, model_name)]
        # Directly related models are the models pointed to by ForeignKeys and ManyToManyFields.
        direct_related_models = set()
        for field in model_state.fields.values():
            if field.is_relation:  # type: ignore[attr-defined]
                if field.remote_field.model == RECURSIVE_RELATIONSHIP_CONSTANT:  # type: ignore[attr-defined]
                    continue
                rel_package_label, rel_model_name = _get_package_label_and_model_name(
                    field.related_model,
                    package_label,  # type: ignore[attr-defined]
                )
                direct_related_models.add((rel_package_label, rel_model_name.lower()))

        # For all direct related models recursively get all related models.
        related_models.update(direct_related_models)
        for rel_package_label, rel_model_name in direct_related_models:
            try:
                rel_model = self.models_registry.get_model(
                    rel_package_label, rel_model_name
                )
            except LookupError:
                pass
            else:
                if delay:
                    related_models.update(get_related_models_tuples(rel_model))
                else:
                    related_models.update(get_related_models_recursive(rel_model))

        # Include the model itself
        related_models.add((package_label, model_name))

        return related_models

    def reload_model(
        self, package_label: str, model_name: str, delay: bool = False
    ) -> None:
        """
        Re-render the model and its related models. Does nothing while the
        registry has not been rendered yet.
        """
        if "models_registry" in self.__dict__:  # hasattr would cache the property
            related_models = self._find_reload_model(package_label, model_name, delay)
            self._reload(related_models)

    def reload_models(self, models: set[tuple[str, str]], delay: bool = True) -> None:
        """
        Re-render several models and their related models in one pass. Does
        nothing while the registry has not been rendered yet.
        """
        if "models_registry" in self.__dict__:  # hasattr would cache the property
            related_models = set()
            for package_label, model_name in models:
                related_models.update(
                    self._find_reload_model(package_label, model_name, delay)
                )
            self._reload(related_models)

    def _reload(self, related_models: set[tuple[str, str]]) -> None:
        """Unregister the given models from the registry and render them anew."""
        # Unregister all related models
        with self.models_registry.bulk_update():
            for rel_package_label, rel_model_name in related_models:
                self.models_registry.unregister_model(rel_package_label, rel_model_name)

        states_to_be_rendered = []
        # Gather all models states of those models that will be rerendered.
        # This includes:
        # 1. All related models of unmigrated packages
        for model_state in self.models_registry.real_models:
            if (model_state.package_label, model_state.name_lower) in related_models:
                states_to_be_rendered.append(model_state)

        # 2. All related models of migrated packages
        for rel_package_label, rel_model_name in related_models:
            try:
                model_state = self.models[rel_package_label, rel_model_name]
            except KeyError:
                pass
            else:
                states_to_be_rendered.append(model_state)

        # Render all models
        self.models_registry.render_multiple(states_to_be_rendered)

    def update_model_field_relation(
        self,
        model: str | type[models.Model],
        model_key: tuple[str, str],
        field_name: str,
        field: Field,
        concretes: dict[tuple[str, str], tuple[str, str]],
    ) -> None:
        """
        Record or drop one relation from ``model_key.field_name`` to ``model``.

        The relation is recorded when ``field_name`` is currently a field of
        the model and dropped otherwise; an emptied per-model entry is
        deleted along with it.
        """
        remote_model_key = resolve_relation(model, *model_key)
        if (
            remote_model_key[0] not in self.real_packages
            and remote_model_key in concretes
        ):
            remote_model_key = concretes[remote_model_key]
        relations_to_remote_model = self._relations[remote_model_key]  # type: ignore[index]
        if field_name in self.models[model_key].fields:
            # The assert holds because it's a new relation, or an altered
            # relation, in which case references have been removed by
            # alter_field().
            assert field_name not in relations_to_remote_model[model_key]
            relations_to_remote_model[model_key][field_name] = field
        else:
            del relations_to_remote_model[model_key][field_name]
            if not relations_to_remote_model[model_key]:
                del relations_to_remote_model[model_key]

    def resolve_model_field_relations(
        self,
        model_key: tuple[str, str],
        field_name: str,
        field: Field,
        concretes: dict[tuple[str, str], tuple[str, str]] | None = None,
    ) -> None:
        """
        Update the relations for one field: towards its remote model and,
        when the remote field has one, towards its ``through`` model.
        Fields without a remote field are ignored.
        """
        remote_field = field.remote_field  # type: ignore[attr-defined]
        if not remote_field:
            return None
        if concretes is None:
            concretes = self._get_concrete_models_mapping()

        self.update_model_field_relation(
            remote_field.model,  # type: ignore[attr-defined]
            model_key,
            field_name,
            field,
            concretes,
        )

        through = getattr(remote_field, "through", None)
        if not through:
            return None
        self.update_model_field_relation(
            through, model_key, field_name, field, concretes
        )

    def resolve_model_relations(
        self,
        model_key: tuple[str, str],
        concretes: dict[tuple[str, str], tuple[str, str]] | None = None,
    ) -> None:
        """Update the relations for every field of the given model."""
        if concretes is None:
            concretes = self._get_concrete_models_mapping()

        model_state = self.models[model_key]
        for field_name, field in model_state.fields.items():
            self.resolve_model_field_relations(model_key, field_name, field, concretes)

    def resolve_fields_and_relations(self) -> None:
        """
        Set ``name`` on every field of every model state, then rebuild the
        relations mapping from scratch.
        """
        # Resolve fields.
        for model_state in self.models.values():
            for field_name, field in model_state.fields.items():
                field.name = field_name  # type: ignore[attr-defined]
        # Resolve relations.
        # {remote_model_key: {model_key: {field_name: field}}}
        self._relations = defaultdict(partial(defaultdict, dict))
        concretes = self._get_concrete_models_mapping()

        for model_key in concretes:
            self.resolve_model_relations(model_key, concretes)

    def _get_concrete_models_mapping(self) -> dict[tuple[str, str], tuple[str, str]]:
        """Return a mapping of every model key in the state to itself."""
        return {model_key: model_key for model_key in self.models}

    def clone(self) -> ProjectState:
        """
        Return a copy of this ProjectState.

        Model states are cloned, the ``real_packages`` set is shared, and an
        already-rendered registry is cloned too. Resolved relations are not
        carried over; the copy resolves them again on demand.
        """
        new_state = ProjectState(
            models={k: v.clone() for k, v in self.models.items()},
            real_packages=self.real_packages,
        )
        if "models_registry" in self.__dict__:
            new_state.models_registry = self.models_registry.clone()
        new_state.is_delayed = self.is_delayed
        return new_state

    def clear_delayed_models_cache(self) -> None:
        """
        Discard the rendered registry if any reload was delayed, so that the
        next access to ``models_registry`` renders it afresh.
        """
        if self.is_delayed and "models_registry" in self.__dict__:
            del self.__dict__["models_registry"]

    @cached_property
    def models_registry(self) -> StateModelsRegistry:
        """The registry rendered from this state; built on first access."""
        return StateModelsRegistry(self.real_packages, self.models)

    @classmethod
    def from_models_registry(cls, models_registry: ModelsRegistry) -> ProjectState:
        """Take a models registry and return a ProjectState matching it."""
        app_models = {}
        for model in models_registry.get_models():
            model_state = ModelState.from_model(model)
            app_models[(model_state.package_label, model_state.name_lower)] = (
                model_state
            )
        return cls(app_models)

    def __eq__(self, other: object) -> bool:
        """Two states are equal when their models and real packages match."""
        if not isinstance(other, ProjectState):
            return NotImplemented
        return self.models == other.models and self.real_packages == other.real_packages
566
567
class StateModelsRegistry(ModelsRegistry):
    """
    A ModelsRegistry subclass used for migration state, where models are
    added to and removed from the registry dynamically.
    """

    def __init__(
        self,
        real_packages: set[str],
        models: dict[tuple[str, str], ModelState],
    ):
        # Every model of the packages in real_packages takes part in the
        # render. The original model classes are not reused because some of
        # their variables refer to the global registry. FKs/M2Ms from real
        # packages are left out as well: with partial states their
        # dependencies may be missing, which just messes things up.
        self.real_models: list[ModelState] = [
            ModelState.from_model(real_model, exclude_rels=True)
            for package_label in real_packages
            for real_model in global_models.get_models(package_label=package_label)
        ]

        super().__init__()  # type: ignore[misc]

        self.render_multiple(list(models.values()) + self.real_models)

        self.ready = True

        # There shouldn't be any operations pending at this point.
        from plain.models.preflight import _check_lazy_references

        errors = _check_lazy_references(self, packages_registry)
        if errors:
            raise ValueError("\n".join(error.message for error in errors))  # type: ignore[attr-defined]

    @contextmanager
    def bulk_update(self) -> Generator[None, None, None]:
        """
        Mark the registry as not ready for the duration of the block, then
        restore the previous flag and clear the cache once.
        """
        # Clearing each model's cache on every single change is wasteful;
        # clear all caches in one go after the batch of updates instead.
        was_ready = self.ready
        self.ready = False
        try:
            yield
        finally:
            self.ready = was_ready
            self.clear_cache()

    def render_multiple(self, model_states: list[ModelState]) -> None:
        """
        Render all given model states, retrying those whose bases could not
        be resolved yet.

        Raises InvalidBasesError when a full pass renders nothing, i.e. there
        is a base dependency loop or a missing base.
        """
        if not model_states:
            return
        # One bulk update for the whole batch keeps the model caches from
        # being expired on every single render.
        with self.bulk_update():
            pending = model_states
            while pending:
                deferred = []
                for model_state in pending:
                    try:
                        model_state.render(self)
                    except InvalidBasesError:
                        # Its bases may become available later in this pass.
                        deferred.append(model_state)
                if len(deferred) == len(pending):
                    # No progress at all during this pass.
                    raise InvalidBasesError(
                        f"Cannot resolve bases for {deferred!r}\nThis can happen if you are "
                        "inheriting models from an app with migrations (e.g. "
                        "contrib.auth)\n in an app with no migrations"
                    )
                pending = deferred

    def clone(self) -> StateModelsRegistry:
        """Return a clone of this registry."""
        duplicate = StateModelsRegistry(set(), {})
        duplicate.all_models = copy.deepcopy(self.all_models)

        # The real models never change, so sharing the list is enough.
        duplicate.real_models = self.real_models
        return duplicate

    def register_model(self, package_label: str, model: type[models.Model]) -> None:
        """
        Store ``model`` under its package label and model name, run the
        operations pending on it and clear the cache.
        """
        model_name = model.model_options.model_name
        self.all_models[package_label][model_name] = model
        self.do_pending_operations(model)
        self.clear_cache()

    def unregister_model(self, package_label: str, model_name: str) -> None:
        """Remove the model from the registry; unknown models are ignored."""
        try:
            package_models = self.all_models[package_label]
            del package_models[model_name]
        except KeyError:
            pass
657
658
class ModelState:
    """
    A mutable, detached description of a Plain model.

    Model classes are not designed to have their options changed, so
    migrations edit a ModelState instead and render it into a Model whenever
    one is required.

    The ``fields`` mapping may be changed freely, but the Field instances in
    it must not be mutated themselves: clone() does not detach them, so a
    changed field has to be assigned as a new instance.
    """

    def __init__(
        self,
        package_label: str,
        name: str,
        fields: Iterable[tuple[str, Field]],
        options: dict[str, Any] | None = None,
        bases: tuple[str | type[models.Model], ...] | None = None,
    ):
        self.package_label = package_label
        self.name = name
        self.fields: dict[str, Field] = dict(fields)
        self.options = options or {}
        for list_option in ("indexes", "constraints"):
            self.options.setdefault(list_option, [])
        self.bases = bases or (models.Model,)

        for field_name, field in self.fields.items():
            # A field that is already bound to a model is not acceptable.
            if hasattr(field, "model"):
                raise ValueError(
                    f'ModelState.fields cannot be bound to a model - "{field_name}" is.'
                )
            # Relations have to point at their target by string reference,
            # never at a model class.
            if field.is_relation and hasattr(field.related_model, "_model_meta"):
                raise ValueError(
                    f'ModelState.fields cannot refer to a model class - "{field_name}.to" does. '
                    "Use a string reference instead."
                )
            # The same holds for the through model of a many-to-many field.
            if field.many_to_many and hasattr(
                field.remote_field.through, "_model_meta"
            ):
                raise ValueError(
                    f'ModelState.fields cannot refer to a model class - "{field_name}.through" '
                    "does. Use a string reference instead."
                )

        # Every index needs to carry a name.
        for index in self.options["indexes"]:
            if index.name:  # type: ignore[attr-defined]
                continue
            raise ValueError(
                "Indexes passed to ModelState require a name attribute. "
                f"{index!r} doesn't have one."
            )

    @cached_property
    def name_lower(self) -> str:
        """The model name in lowercase (cached after the first access)."""
        return self.name.lower()

    def get_field(self, field_name: str) -> Field:
        """Return the field stored under ``field_name`` (KeyError if absent)."""
        return self.fields[field_name]

    @classmethod
    def from_model(
        cls, model: type[models.Model], exclude_rels: bool = False
    ) -> ModelState:
        """
        Build a ModelState describing ``model``.

        With ``exclude_rels`` every local field that has a remote field is
        skipped, and so are all many-to-many fields.
        """
        # Clone the local fields.
        fields = []
        for local_field in model._model_meta.local_fields:
            if getattr(local_field, "remote_field", None) and exclude_rels:
                continue
            field_name = local_field.name  # type: ignore[attr-defined]
            try:
                field_copy = local_field.clone()  # type: ignore[attr-defined]
            except TypeError as e:
                raise TypeError(
                    f"Couldn't reconstruct field {field_name} on {model.model_options.label}: {e}"
                )
            fields.append((field_name, field_copy))
        if not exclude_rels:
            for m2m_field in model._model_meta.local_many_to_many:
                field_name = m2m_field.name  # type: ignore[attr-defined]
                try:
                    field_copy = m2m_field.clone()  # type: ignore[attr-defined]
                except TypeError as e:
                    raise TypeError(
                        f"Couldn't reconstruct m2m field {field_name} on {model.model_options.object_name}: {e}"
                    )
                fields.append((field_name, field_copy))

        # Only the direct bases are recorded. They are de-duplicated and put
        # in the order in which they appear in the model's MRO.
        flattened_bases = sorted(set(model.__bases__), key=model.__mro__.index)

        def as_reference(base: Any) -> Any:
            # Model classes other than models.Model itself are recorded by
            # their lowercased label; everything else is kept as it is.
            if (
                isinstance(base, str)
                or base is models.Model
                or not hasattr(base, "_model_meta")
            ):
                return base
            return base.model_options.label_lower

        bases = tuple(as_reference(base) for base in flattened_bases)
        # Make sure at least one base is a string reference or a models.Model
        # subclass.
        if all(
            not isinstance(base, str) and not issubclass(base, models.Model)
            for base in bases
        ):
            bases = (models.Model,)

        # Build the state itself.
        return cls(
            package_label=model.model_options.package_label,
            name=model.model_options.object_name,
            fields=fields,
            options=model.model_options.export_for_migrations(),
            bases=bases,
        )

    def clone(self) -> ModelState:
        """Return an exact copy of this ModelState."""
        state_class = self.__class__
        return state_class(
            package_label=self.package_label,
            name=self.name,
            fields=self.fields.items(),
            # The options are only shallow-copied, which is why operations
            # such as AddIndex have to replace an option (e.g. 'indexes')
            # instead of mutating it.
            options={**self.options},
            bases=self.bases,
        )

    def render(self, models_registry: ModelsRegistry) -> type[models.Model]:
        """
        Create a model class from this state and register it.

        Raises InvalidBasesError when a base given by string cannot be looked
        up in ``models_registry``.
        """
        # The Options instance carrying the metadata comes first.
        meta_options = models.Options(
            package_label=self.package_label,
            **self.options,
        )
        # Next, resolve string bases against the registry.
        try:
            resolved_bases = tuple(
                [
                    models_registry.get_model(base) if isinstance(base, str) else base
                    for base in self.bases
                ]
            )
        except LookupError:
            raise InvalidBasesError(
                f"Cannot resolve one or more bases from {self.bases!r}"
            )
        # The class namespace: a fresh clone of every field plus the extras.
        namespace = {
            field_name: field.clone()  # type: ignore[attr-defined]
            for field_name, field in self.fields.items()
        }
        namespace.update(
            model_options=meta_options,
            _model_meta=Meta(models_registry=models_registry),  # Use custom registry
            __module__="__fake__",
        )

        # Create the model class.
        model_class = type(self.name, resolved_bases, namespace)
        from plain.models import register_model

        # Register it with the registry associated with the model meta.
        register_model(model_class)  # type: ignore[arg-type]

        return model_class  # type: ignore[return-value]

    def get_index_by_name(self, name: str) -> Any:
        """Return the index called ``name``; raise ValueError if there is none."""
        for candidate in self.options["indexes"]:
            if candidate.name == name:  # type: ignore[attr-defined]
                break
        else:
            raise ValueError(f"No index named {name} on model {self.name}")
        return candidate

    def get_constraint_by_name(self, name: str) -> Any:
        """Return the constraint called ``name``; raise ValueError if there is none."""
        for candidate in self.options["constraints"]:
            if candidate.name == name:  # type: ignore[attr-defined]
                break
        else:
            raise ValueError(f"No constraint named {name} on model {self.name}")
        return candidate

    def __repr__(self) -> str:
        label = f"{self.package_label}.{self.name}"
        return f"<{self.__class__.__name__}: '{label}'>"

    def __eq__(self, other: object) -> bool:
        """
        Compare label, name, fields (by name and deconstructed form, without
        the deconstructed name), options and bases.
        """
        if not isinstance(other, ModelState):
            return NotImplemented
        if not (self.package_label == other.package_label):
            return False
        if not (self.name == other.name):
            return False
        if not (len(self.fields) == len(other.fields)):
            return False
        own_fields = sorted(self.fields.items())
        other_fields = sorted(other.fields.items())
        for (own_name, own_field), (other_name, other_field) in zip(
            own_fields, other_fields
        ):
            if not (
                own_name == other_name
                and own_field.deconstruct()[1:] == other_field.deconstruct()[1:]  # type: ignore[attr-defined]
            ):
                return False
        if not (self.options == other.options):
            return False
        return self.bases == other.bases