1from __future__ import annotations
  2
  3import copy
  4import inspect
  5from collections import defaultdict
  6from collections.abc import Iterable
  7from functools import cached_property
  8from typing import TYPE_CHECKING, Any, Literal, overload
  9
 10from plain.postgres.exceptions import FieldDoesNotExist
 11from plain.postgres.query import QuerySet
 12from plain.postgres.registry import models_registry as default_models_registry
 13from plain.utils.datastructures import ImmutableList
 14
 15if TYPE_CHECKING:
 16    from plain.postgres.base import Model
 17    from plain.postgres.fields import Field
 18    from plain.postgres.fields.related import ManyToManyField, RelatedField
 19    from plain.postgres.fields.reverse_related import ForeignObjectRel
 20
# Fallback returned by Meta._populate_directed_relation_graph() when the
# calling Meta instance has no "_relation_tree" entry in its __dict__.
EMPTY_RELATION_TREE = ()

# %-style template for the warning handed to ImmutableList. The single "%s"
# placeholder is filled by make_immutable_fields_list() with the name of the
# property or method that produced the list.
IMMUTABLE_WARNING = (
    "The return type of '%s' should never be mutated. If you want to manipulate this "
    "list for your own use, make a copy first."
)
 27
 28
 29def make_immutable_fields_list[T](name: str, data: Iterable[T]) -> ImmutableList[T]:
 30    return ImmutableList(data, warning=IMMUTABLE_WARNING % name)
 31
 32
class Meta:
    """
    Model metadata descriptor and container.

    The same class plays two roles:

    - The object built through ``__init__`` is the descriptor. It holds the
      optional registry override and a per-model cache, and ``__get__``
      hands out the cached metadata for the accessing model class.
    - The objects built by ``_create_and_cache`` (via ``__new__``, without
      ``__init__``) are the actual per-model metadata instances carrying
      ``model``, ``models_registry``, ``local_fields`` and friends.
    """

    # Attribute names dropped from the instance __dict__ by
    # _expire_cache(forward=True).
    FORWARD_PROPERTIES = {
        "fields",
        "many_to_many",
        "concrete_fields",
        "local_concrete_fields",
        "_non_pk_concrete_field_names",
        "_forward_fields_map",
        "base_queryset",
    }
    # Attribute names dropped from the instance __dict__ by
    # _expire_cache(reverse=True).
    REVERSE_PROPERTIES = {"related_objects", "fields_map", "_relation_tree"}

    # Type annotations for attributes set in _create_and_cache
    # These exist on cached instances, not on the descriptor itself
    model: type[Model]
    models_registry: Any
    _get_fields_cache: dict[Any, Any]
    local_fields: list[Field]
    local_many_to_many: list[ManyToManyField]
 59
 60    def __init__(self, models_registry: Any | None = None):
 61        """
 62        Initialize the descriptor with optional configuration.
 63
 64        This is called ONCE when defining the base Model class.
 65        The descriptor then creates cached instances per model subclass.
 66        """
 67        self._models_registry = models_registry
 68        self._cache: dict[type[Model], Meta] = {}
 69
 70    def __get__(self, instance: Any, owner: type[Model]) -> Meta:
 71        """
 72        Descriptor protocol - returns cached Meta instance for the model class.
 73
 74        This is called when accessing Model._model_meta and returns a per-class
 75        cached instance created by _create_and_cache().
 76
 77        Can be accessed from both class and instances:
 78        - MyModel._model_meta (class access)
 79        - my_instance._model_meta (instance access - returns class's metadata)
 80        """
 81        # Allow instance access - just return the class's metadata
 82        if instance is not None:
 83            owner = instance.__class__
 84
 85        # Skip for the base Model class - return descriptor
 86        if owner.__name__ == "Model" and owner.__module__ == "plain.postgres.base":
 87            return self
 88
 89        # Return cached instance or create new one
 90        if owner not in self._cache:
 91            # Create the instance and cache it BEFORE field contribution
 92            # to avoid infinite recursion when fields access cls._model_meta
 93            return self._create_and_cache(owner)
 94
 95        return self._cache[owner]
 96
 97    def _create_and_cache(self, model: type[Model]) -> Meta:
 98        """Create Meta instance and cache it before field contribution."""
 99        # Create instance without calling __init__
100        instance = Meta.__new__(Meta)
101
102        # Initialize basic model-specific state
103        instance.model = model
104        instance.models_registry = self._models_registry or default_models_registry
105        instance._get_fields_cache = {}
106        instance.local_fields = []
107        instance.local_many_to_many = []
108
109        # Cache the instance BEFORE processing fields to prevent recursion
110        self._cache[model] = instance
111
112        # Now process fields - they can safely access cls._model_meta
113        seen_attrs = set()
114        for klass in model.__mro__:
115            for attr_name in list(klass.__dict__.keys()):
116                if attr_name.startswith("_") or attr_name in seen_attrs:
117                    continue
118                seen_attrs.add(attr_name)
119
120                attr_value = klass.__dict__[attr_name]
121
122                if not inspect.isclass(attr_value) and hasattr(
123                    attr_value, "contribute_to_class"
124                ):
125                    if attr_name not in model.__dict__:
126                        field = copy.deepcopy(attr_value)
127                    else:
128                        field = attr_value
129                    field.contribute_to_class(model, attr_name)
130
131        # Sort fields: primary key first, then alphabetically by name
132        instance.local_fields.sort(key=lambda f: (not f.primary_key, f.name))
133        instance.local_many_to_many.sort(key=lambda f: f.name)
134
135        # Set index names now that fields are contributed
136        # Trigger model_options descriptor to ensure it's initialized
137        # (accessing it will cache the instance)
138        for index in model.model_options.indexes:
139            if not index.name:
140                index.set_name_with_model(model)
141
142        return instance
143
144    @property
145    def base_queryset(self) -> QuerySet:
146        """
147        The base queryset is used by Plain's internal operations like cascading
148        deletes, migrations, and related object lookups. It provides access to
149        all objects in the database without any filtering, ensuring Plain can
150        always see the complete dataset when performing framework operations.
151
152        Unlike user-defined querysets which may filter results (e.g. only active
153        objects), the base queryset must never filter out rows to prevent
154        incomplete results in related queries.
155        """
156        return QuerySet.from_model(self.model)
157
158    def add_field(self, field: Field) -> None:
159        from plain.postgres.fields.related import ManyToManyField, RelatedField
160
161        if isinstance(field, ManyToManyField):
162            self.local_many_to_many.append(field)
163        else:
164            self.local_fields.append(field)
165
166        # If the field being added is a relation to another known field,
167        # expire the cache on this field and the forward cache on the field
168        # being referenced, because there will be new relationships in the
169        # cache. Otherwise, expire the cache of references *to* this field.
170        # The mechanism for getting at the related model is slightly odd -
171        # ideally, we'd just ask for field.related_model. However, related_model
172        # is a cached property, and all the models haven't been loaded yet, so
173        # we need to make sure we don't cache a string reference.
174        if isinstance(field, RelatedField) and field.remote_field.model:
175            try:
176                field.remote_field.model._model_meta._expire_cache(forward=False)
177            except AttributeError:
178                pass
179            self._expire_cache()
180        else:
181            self._expire_cache(reverse=False)
182
183    @cached_property
184    def fields(self) -> ImmutableList[Field]:
185        from plain.postgres.fields.related import RelatedField
186
187        """
188        Return a list of all forward fields on the model and its parents,
189        excluding ManyToManyFields.
190
191        Private API intended only to be used by Plain itself; get_fields()
192        combined with filtering of field properties is the public API for
193        obtaining this field list.
194        """
195
196        # For legacy reasons, the fields property should only contain forward
197        # fields that are not private or with a m2m cardinality.
198        def is_not_an_m2m_field(f: Any) -> bool:
199            from plain.postgres.fields.related import ManyToManyField
200
201            return not isinstance(f, ManyToManyField)
202
203        def is_not_a_generic_relation(f: Any) -> bool:
204            from plain.postgres.fields.related import ForeignKeyField, ManyToManyField
205
206            # Only ForeignKeyField and ManyToManyField are valid RelatedFields
207            # Anything else is a generic relation
208            if not isinstance(f, RelatedField):
209                return True
210            return isinstance(f, ForeignKeyField | ManyToManyField)
211
212        return make_immutable_fields_list(
213            "fields",
214            (
215                f
216                for f in self._get_fields(reverse=False)
217                if is_not_an_m2m_field(f) and is_not_a_generic_relation(f)
218            ),
219        )
220
221    @cached_property
222    def concrete_fields(self) -> ImmutableList[Field]:
223        """
224        Return a list of all concrete fields on the model and its parents.
225
226        Private API intended only to be used by Plain itself; get_fields()
227        combined with filtering of field properties is the public API for
228        obtaining this field list.
229        """
230        return make_immutable_fields_list(
231            "concrete_fields", (f for f in self.fields if f.concrete)
232        )
233
234    @cached_property
235    def local_concrete_fields(self) -> ImmutableList[Field]:
236        """
237        Return a list of all concrete fields on the model.
238
239        Private API intended only to be used by Plain itself; get_fields()
240        combined with filtering of field properties is the public API for
241        obtaining this field list.
242        """
243        return make_immutable_fields_list(
244            "local_concrete_fields", (f for f in self.local_fields if f.concrete)
245        )
246
247    @cached_property
248    def many_to_many(self) -> ImmutableList[Field]:
249        """
250        Return a list of all many to many fields on the model and its parents.
251
252        Private API intended only to be used by Plain itself; get_fields()
253        combined with filtering of field properties is the public API for
254        obtaining this list.
255        """
256        from plain.postgres.fields.related import ManyToManyField
257
258        return make_immutable_fields_list(
259            "many_to_many",
260            (
261                f
262                for f in self._get_fields(reverse=False)
263                if isinstance(f, ManyToManyField)
264            ),
265        )
266
267    @cached_property
268    def related_objects(self) -> ImmutableList[ForeignObjectRel]:
269        """
270        Return all related objects pointing to the current model. The related
271        objects can come from a one-to-one, one-to-many, or many-to-many field
272        relation type.
273
274        Private API intended only to be used by Plain itself; get_fields()
275        combined with filtering of field properties is the public API for
276        obtaining this field list.
277        """
278        from plain.postgres.fields.reverse_related import ForeignKeyRel, ManyToManyRel
279
280        all_related_fields = self._get_fields(forward=False, reverse=True)
281        return make_immutable_fields_list(
282            "related_objects",
283            (
284                obj
285                for obj in all_related_fields
286                if isinstance(obj, ManyToManyRel | ForeignKeyRel)
287            ),
288        )
289
290    @cached_property
291    def _forward_fields_map(self) -> dict[str, Field]:
292        res = {}
293        fields = self._get_fields(reverse=False)
294        for field in fields:
295            res[field.name] = field
296            # Due to the way Plain's internals work, get_field() should also
297            # be able to fetch a field by attname. In the case of a concrete
298            # field with relation, includes the *_id name too
299            try:
300                res[field.attname] = field
301            except AttributeError:
302                pass
303        return res
304
305    @cached_property
306    def fields_map(self) -> dict[str, Field | ForeignObjectRel]:
307        res = {}
308        fields = self._get_fields(forward=False, reverse=True)
309        for field in fields:
310            res[field.name] = field
311            # Due to the way Plain's internals work, get_field() should also
312            # be able to fetch a field by attname. In the case of a concrete
313            # field with relation, includes the *_id name too
314            try:
315                res[field.attname] = field
316            except AttributeError:
317                pass
318        return res
319
320    def get_field(self, field_name: str) -> Field | ForeignObjectRel:
321        """
322        Return a field instance given the name of a forward or reverse field.
323        """
324        try:
325            # In order to avoid premature loading of the relation tree
326            # (expensive) we prefer checking if the field is a forward field.
327            return self._forward_fields_map[field_name]
328        except KeyError:
329            # If the app registry is not ready, reverse fields are
330            # unavailable, therefore we throw a FieldDoesNotExist exception.
331            if not self.models_registry.ready:
332                raise FieldDoesNotExist(
333                    f"{self.model} has no field named '{field_name}'. The app cache isn't ready yet, "
334                    "so if this is an auto-created related field, it won't "
335                    "be available yet."
336                )
337
338        try:
339            # Retrieve field instance by name from cached or just-computed
340            # field map.
341            return self.fields_map[field_name]
342        except KeyError:
343            raise FieldDoesNotExist(f"{self.model} has no field named '{field_name}'")
344
345    def get_forward_field(self, field_name: str) -> Field:
346        """
347        Return a forward field instance given the field name.
348
349        Raises FieldDoesNotExist if the field doesn't exist or is a reverse relation.
350        """
351        try:
352            return self._forward_fields_map[field_name]
353        except KeyError:
354            raise FieldDoesNotExist(
355                f"{self.model} has no forward field named '{field_name}'"
356            )
357
358    def get_reverse_relation(self, field_name: str) -> ForeignObjectRel:
359        """
360        Return a reverse relation instance given the field name.
361
362        Raises FieldDoesNotExist if the field doesn't exist or is a forward field.
363        """
364        # If the app registry is not ready, reverse fields are unavailable
365        if not self.models_registry.ready:
366            raise FieldDoesNotExist(
367                f"{self.model} has no reverse relation named '{field_name}'. The app cache isn't ready yet."
368            )
369
370        # Check if it's a forward field first
371        if field_name in self._forward_fields_map:
372            raise FieldDoesNotExist(
373                f"'{field_name}' is a forward field, not a reverse relation"
374            )
375
376        try:
377            return self.fields_map[field_name]  # type: ignore[return-type]
378        except KeyError:
379            raise FieldDoesNotExist(
380                f"{self.model} has no reverse relation named '{field_name}'"
381            )
382
383    def _populate_directed_relation_graph(self) -> list[RelatedField]:
384        from plain.postgres.fields.related import RelatedField
385
386        """
387        This method is used by each model to find its reverse objects. As this
388        method is very expensive and is accessed frequently (it looks up every
389        field in a model, in every app), it is computed on first access and then
390        is set as a property on every model.
391        """
392        related_objects_graph: defaultdict[str, list[Any]] = defaultdict(list)
393
394        all_models = self.models_registry.get_models()
395        for model in all_models:
396            meta = model._model_meta
397
398            fields_with_relations = (
399                f
400                for f in meta._get_fields(reverse=False)
401                if isinstance(f, RelatedField)
402            )
403            for f in fields_with_relations:
404                if not isinstance(f.remote_field.model, str):
405                    remote_label = f.remote_field.model.model_options.label
406                    related_objects_graph[remote_label].append(f)
407
408        for model in all_models:
409            # Set the relation_tree using the internal __dict__. In this way
410            # we avoid calling the cached property. In attribute lookup,
411            # __dict__ takes precedence over a data descriptor (such as
412            # @cached_property). This means that the _model_meta._relation_tree is
413            # only called if related_objects is not in __dict__.
414            related_objects = related_objects_graph[model.model_options.label]
415            model._model_meta.__dict__["_relation_tree"] = related_objects
416        # It seems it is possible that self is not in all_models, so guard
417        # against that with default for get().
418        return self.__dict__.get("_relation_tree", EMPTY_RELATION_TREE)
419
420    @cached_property
421    def _relation_tree(self) -> list[RelatedField]:
422        return self._populate_directed_relation_graph()
423
424    def _expire_cache(self, forward: bool = True, reverse: bool = True) -> None:
425        # This method is usually called by packages.cache_clear(), when the
426        # registry is finalized, or when a new field is added.
427        if forward:
428            for cache_key in self.FORWARD_PROPERTIES:
429                if cache_key in self.__dict__:
430                    delattr(self, cache_key)
431        if reverse:
432            for cache_key in self.REVERSE_PROPERTIES:
433                if cache_key in self.__dict__:
434                    delattr(self, cache_key)
435        self._get_fields_cache = {}
436
437    @overload
438    def get_fields(
439        self, include_reverse: Literal[False] = False
440    ) -> ImmutableList[Field]: ...
441
442    @overload
443    def get_fields(
444        self, include_reverse: Literal[True]
445    ) -> ImmutableList[Field | ForeignObjectRel]: ...
446
447    def get_fields(
448        self, include_reverse: bool = False
449    ) -> ImmutableList[Field | ForeignObjectRel]:
450        """
451        Return a list of fields associated to the model.
452
453        By default, returns only forward fields (fields explicitly defined on
454        this model). Set include_reverse=True to also include reverse relations
455        (fields from other models that point to this model).
456
457        Args:
458            include_reverse: Include reverse relation fields (fields from other
459                           models pointing to this model). Needed for framework
460                           operations like migrations and deletion cascading.
461        """
462        return self._get_fields(reverse=include_reverse)
463
464    @overload
465    def _get_fields(
466        self,
467        *,
468        forward: Literal[True] = True,
469        reverse: Literal[False],
470        seen_models: set[type[Any]] | None = None,
471    ) -> ImmutableList[Field]: ...
472
473    @overload
474    def _get_fields(
475        self,
476        *,
477        forward: Literal[False],
478        reverse: Literal[True] = True,
479        seen_models: set[type[Any]] | None = None,
480    ) -> ImmutableList[ForeignObjectRel]: ...
481
482    @overload
483    def _get_fields(
484        self,
485        *,
486        forward: bool = True,
487        reverse: bool = True,
488        seen_models: set[type[Any]] | None = None,
489    ) -> ImmutableList[Field | ForeignObjectRel]: ...
490
491    def _get_fields(
492        self,
493        *,
494        forward: bool = True,
495        reverse: bool = True,
496        seen_models: set[type[Any]] | None = None,
497    ) -> ImmutableList[Field | ForeignObjectRel]:
498        """
499        Internal helper function to return fields of the model.
500
501        Args:
502            forward: If True, fields defined on this model are returned.
503            reverse: If True, reverse relations (fields from other models
504                    pointing to this model) are returned.
505            seen_models: Track visited models to prevent duplicates in recursion.
506        """
507
508        # This helper function is used to allow recursion in ``get_fields()``
509        # implementation and to provide a fast way for Plain's internals to
510        # access specific subsets of fields.
511
512        # We must keep track of which models we have already seen. Otherwise we
513        # could include the same field multiple times from different models.
514        topmost_call = seen_models is None
515        if seen_models is None:
516            seen_models = set()
517        seen_models.add(self.model)
518
519        # Creates a cache key composed of all arguments
520        cache_key = (forward, reverse, topmost_call)
521
522        try:
523            # In order to avoid list manipulation. Always return a shallow copy
524            # of the results.
525            return self._get_fields_cache[cache_key]
526        except KeyError:
527            pass
528
529        fields = []
530
531        if reverse:
532            # Tree is computed once and cached until the app cache is expired.
533            # It is composed of a list of fields from other models pointing to
534            # the current model (reverse relations).
535            all_fields = self._relation_tree
536            for field in all_fields:
537                fields.append(field.remote_field)
538
539        if forward:
540            fields += self.local_fields
541            fields += self.local_many_to_many
542
543        # In order to avoid list manipulation. Always
544        # return a shallow copy of the results
545        fields = make_immutable_fields_list("get_fields()", fields)
546
547        # Store result into cache for later access
548        self._get_fields_cache[cache_key] = fields
549        return fields
550
551    @cached_property
552    def _property_names(self) -> frozenset[str]:
553        """Return a set of the names of the properties defined on the model."""
554        names = []
555        for name in dir(self.model):
556            attr = inspect.getattr_static(self.model, name)
557            if isinstance(attr, property):
558                names.append(name)
559        return frozenset(names)
560
561    @cached_property
562    def _non_pk_concrete_field_names(self) -> frozenset[str]:
563        """
564        Return a set of the non-primary key concrete field names defined on the model.
565        """
566        names = []
567        for field in self.concrete_fields:
568            if not field.primary_key:
569                names.append(field.name)
570                if field.name != field.attname:
571                    names.append(field.attname)
572        return frozenset(names)
573
574    @cached_property
575    def db_returning_fields(self) -> list[Field]:
576        """
577        Private API intended only to be used by Plain itself.
578        Fields to be returned after a database insert.
579        """
580        return [
581            field
582            for field in self._get_fields(forward=True, reverse=False)
583            if getattr(field, "db_returning", False)
584        ]