1from __future__ import annotations
  2
  3import functools
  4import warnings
  5from collections import defaultdict
  6from collections.abc import Callable
  7from functools import partial
  8from typing import TYPE_CHECKING
  9
 10if TYPE_CHECKING:
 11    from plain.postgres.base import Model
 12
 13
 14class ModelsRegistryNotReady(Exception):
 15    """The plain.postgres registry is not populated yet"""
 16
 17    pass
 18
 19
 20class ModelsRegistry:
 21    def __init__(self) -> None:
 22        # Mapping of app labels => model names => model classes. Every time a
 23        # model is imported, ModelBase.__new__ calls packages.register_model which
 24        # creates an entry in all_models. All imported models are registered,
 25        # regardless of whether they're defined in an installed application
 26        # and whether the registry has been populated. Since it isn't possible
 27        # to reimport a module safely (it could reexecute initialization code)
 28        # all_models is never overridden or reset.
 29        self.all_models: defaultdict[str, dict[str, type[Model]]] = defaultdict(dict)
 30
 31        # Maps ("package_label", "modelname") tuples to lists of functions to be
 32        # called when the corresponding model is ready. Used by this class's
 33        # `lazy_model_operation()` and `do_pending_operations()` methods.
 34        self._pending_operations: defaultdict[
 35            tuple[str, str], list[Callable[[type[Model]], None]]
 36        ] = defaultdict(list)
 37
 38        self.ready: bool = False
 39
 40    def check_ready(self) -> None:
 41        """Raise an exception if all models haven't been imported yet."""
 42        if not self.ready:
 43            raise ModelsRegistryNotReady("Models aren't loaded yet.")
 44
 45    # This method is performance-critical at least for Plain's test suite.
 46    @functools.cache
 47    def get_models(self, *, package_label: str = "") -> list[type[Model]]:
 48        """
 49        Return a list of all installed models.
 50
 51        By default, the following models aren't included:
 52
 53        - auto-created models for many-to-many relations without
 54          an explicit intermediate table,
 55
 56        Set the corresponding keyword argument to True to include such models.
 57        """
 58
 59        self.check_ready()
 60
 61        models = []
 62
 63        # Get models for a single package
 64        if package_label:
 65            package_models = self.all_models[package_label]
 66            for model in package_models.values():
 67                models.append(model)
 68            return models
 69
 70        # Get models for all packages
 71        for package_models in self.all_models.values():
 72            for model in package_models.values():
 73                models.append(model)
 74
 75        return models
 76
 77    def get_model(
 78        self,
 79        package_label: str,
 80        model_name: str | None = None,
 81        require_ready: bool = True,
 82    ) -> type[Model]:
 83        """
 84        Return the model matching the given package_label and model_name.
 85
 86        As a shortcut, package_label may be in the form <package_label>.<model_name>.
 87
 88        model_name is case-insensitive.
 89
 90        Raise LookupError if no application exists with this label, or no
 91        model exists with this name in the application. Raise ValueError if
 92        called with a single argument that doesn't contain exactly one dot.
 93        """
 94
 95        if require_ready:
 96            self.check_ready()
 97
 98        if model_name is None:
 99            package_label, model_name = package_label.split(".")
100
101        package_models = self.all_models[package_label]
102        return package_models[model_name.lower()]
103
104    def register_model(self, package_label: str, model: type[Model]) -> None:
105        # Since this method is called when models are imported, it cannot
106        # perform imports because of the risk of import loops. It mustn't
107        # call get_package_config().
108        model_name = model.model_options.model_name
109        app_models = self.all_models[package_label]
110        if model_name in app_models:
111            if (
112                model.__name__ == app_models[model_name].__name__
113                and model.__module__ == app_models[model_name].__module__
114            ):
115                warnings.warn(
116                    f"Model '{package_label}.{model_name}' was already registered. Reloading models is not "
117                    "advised as it can lead to inconsistencies, most notably with "
118                    "related models.",
119                    RuntimeWarning,
120                    stacklevel=2,
121                )
122            else:
123                raise RuntimeError(
124                    f"Conflicting '{model_name}' models in application '{package_label}': {app_models[model_name]} and {model}."
125                )
126        app_models[model_name] = model
127        self.do_pending_operations(model)
128        self.clear_cache()
129
130    def _get_registered_model(self, package_label: str, model_name: str) -> type[Model]:
131        """
132        Similar to get_model(), but doesn't require that an app exists with
133        the given package_label.
134
135        It's safe to call this method at import time, even while the registry
136        is being populated.
137        """
138        model = self.all_models[package_label].get(model_name.lower())
139        if model is None:
140            raise LookupError(f"Model '{package_label}.{model_name}' not registered.")
141        return model
142
143    def clear_cache(self) -> None:
144        """
145        Clear all internal caches, for methods that alter the app registry.
146
147        This is mostly used in tests.
148        """
149        # Call expire cache on each model. This will purge
150        # the relation tree and the fields cache.
151        self.get_models.cache_clear()
152        if self.ready:
153            # Circumvent self.get_models() to prevent that the cache is refilled.
154            # This particularly prevents that an empty value is cached while cloning.
155            for package_models in self.all_models.values():
156                for model in package_models.values():
157                    model._model_meta._expire_cache()
158
159    def lazy_model_operation(
160        self, function: Callable[..., None], *model_keys: tuple[str, str]
161    ) -> None:
162        """
163        Take a function and a number of ("package_label", "modelname") tuples, and
164        when all the corresponding models have been imported and registered,
165        call the function with the model classes as its arguments.
166
167        The function passed to this method must accept exactly n models as
168        arguments, where n=len(model_keys).
169        """
170        # Base case: no arguments, just execute the function.
171        if not model_keys:
172            function()
173        # Recursive case: take the head of model_keys, wait for the
174        # corresponding model class to be imported and registered, then apply
175        # that argument to the supplied function. Pass the resulting partial
176        # to lazy_model_operation() along with the remaining model args and
177        # repeat until all models are loaded and all arguments are applied.
178        else:
179            next_model, *more_models = model_keys
180
181            # This will be executed after the class corresponding to next_model
182            # has been imported and registered.
183            def apply_next_model(model: type[Model]) -> None:
184                next_function = partial(function, model)
185                self.lazy_model_operation(next_function, *more_models)
186
187            # If the model has already been imported and registered, partially
188            # apply it to the function now. If not, add it to the list of
189            # pending operations for the model, where it will be executed with
190            # the model class as its sole argument once the model is ready.
191            try:
192                model_class = self._get_registered_model(*next_model)
193            except LookupError:
194                self._pending_operations[next_model].append(apply_next_model)
195            else:
196                apply_next_model(model_class)
197
198    def do_pending_operations(self, model: type[Model]) -> None:
199        """
200        Take a newly-prepared model and pass it to each function waiting for
201        it. This is called at the very end of Models.register_model().
202        """
203        key = model.model_options.package_label, model.model_options.model_name
204        for function in self._pending_operations.pop(key, []):
205            function(model)
206
207
# Shared, module-level ModelsRegistry instance for this process.
models_registry = ModelsRegistry()
209
210
211# Decorator to register a model (using the internal registry for the correct state).
212def register_model[M: "Model"](model_class: type[M]) -> type[M]:
213    model_class._model_meta.models_registry.register_model(
214        model_class.model_options.package_label,
215        model_class,
216    )
217    return model_class