Plain is headed towards 1.0! Subscribe for development updates →

  1import bisect
  2import copy
  3import inspect
  4from collections import defaultdict
  5from functools import cached_property
  6
  7from plain.exceptions import FieldDoesNotExist
  8from plain.models import models_registry
  9from plain.models.constraints import UniqueConstraint
 10from plain.models.db import db_connection
 11from plain.models.fields import BigAutoField
 12from plain.models.manager import Manager
 13from plain.utils.datastructures import ImmutableList
 14
 15PROXY_PARENTS = object()
 16
 17EMPTY_RELATION_TREE = ()
 18
 19IMMUTABLE_WARNING = (
 20    "The return type of '%s' should never be mutated. If you want to manipulate this "
 21    "list for your own use, make a copy first."
 22)
 23
 24DEFAULT_NAMES = (
 25    "db_table",
 26    "db_table_comment",
 27    "ordering",
 28    "get_latest_by",
 29    "package_label",
 30    "models_registry",
 31    "default_related_name",
 32    "required_db_features",
 33    "required_db_vendor",
 34    "base_manager_name",
 35    "default_manager_name",
 36    "indexes",
 37    "constraints",
 38)
 39
 40
 41def make_immutable_fields_list(name, data):
 42    return ImmutableList(data, warning=IMMUTABLE_WARNING % name)
 43
 44
 45class Options:
 46    FORWARD_PROPERTIES = {
 47        "fields",
 48        "many_to_many",
 49        "concrete_fields",
 50        "local_concrete_fields",
 51        "_non_pk_concrete_field_names",
 52        "_forward_fields_map",
 53        "managers",
 54        "managers_map",
 55        "base_manager",
 56        "default_manager",
 57    }
 58    REVERSE_PROPERTIES = {"related_objects", "fields_map", "_relation_tree"}
 59
 60    default_models_registry = models_registry
 61
 62    def __init__(self, meta, package_label=None):
 63        self._get_fields_cache = {}
 64        self.local_fields = []
 65        self.local_many_to_many = []
 66        self.local_managers = []
 67        self.base_manager_name = None
 68        self.default_manager_name = None
 69        self.model_name = None
 70        self.db_table = ""
 71        self.db_table_comment = ""
 72        self.ordering = []
 73        self.indexes = []
 74        self.constraints = []
 75        self.object_name = None
 76        self.package_label = package_label
 77        self.get_latest_by = None
 78        self.required_db_features = []
 79        self.required_db_vendor = None
 80        self.meta = meta
 81        self.pk = None
 82        self.auto_field = None
 83
 84        # For any non-abstract class, the concrete class is the model
 85        # in the end of the proxy_for_model chain. In particular, for
 86        # concrete models, the concrete_model is always the class itself.
 87        self.concrete_model = None
 88
 89        # List of all lookups defined in ForeignKey 'limit_choices_to' options
 90        # from *other* models. Needed for some admin checks. Internal use only.
 91        self.related_fkey_lookups = []
 92
 93        # A custom app registry to use, if you're making a separate model set.
 94        self.models_registry = self.default_models_registry
 95
 96        self.default_related_name = None
 97
 98    @property
 99    def label(self):
100        return f"{self.package_label}.{self.object_name}"
101
102    @property
103    def label_lower(self):
104        return f"{self.package_label}.{self.model_name}"
105
106    def contribute_to_class(self, cls, name):
107        from plain.models.backends.utils import truncate_name
108
109        cls._meta = self
110        self.model = cls
111        # First, construct the default values for these options.
112        self.object_name = cls.__name__
113        self.model_name = self.object_name.lower()
114
115        # Store the original user-defined values for each option,
116        # for use when serializing the model definition
117        self.original_attrs = {}
118
119        # Next, apply any overridden values from 'class Meta'.
120        if self.meta:
121            meta_attrs = self.meta.__dict__.copy()
122            for name in self.meta.__dict__:
123                # Ignore any private attributes that Plain doesn't care about.
124                # NOTE: We can't modify a dictionary's contents while looping
125                # over it, so we loop over the *original* dictionary instead.
126                if name.startswith("_"):
127                    del meta_attrs[name]
128            for attr_name in DEFAULT_NAMES:
129                if attr_name in meta_attrs:
130                    setattr(self, attr_name, meta_attrs.pop(attr_name))
131                    self.original_attrs[attr_name] = getattr(self, attr_name)
132                elif hasattr(self.meta, attr_name):
133                    setattr(self, attr_name, getattr(self.meta, attr_name))
134                    self.original_attrs[attr_name] = getattr(self, attr_name)
135
136            # Package label/class name interpolation for names of constraints and
137            # indexes.
138            for attr_name in {"constraints", "indexes"}:
139                objs = getattr(self, attr_name, [])
140                setattr(self, attr_name, self._format_names_with_class(cls, objs))
141
142            # Any leftover attributes must be invalid.
143            if meta_attrs != {}:
144                raise TypeError(
145                    "'class Meta' got invalid attribute(s): {}".format(
146                        ",".join(meta_attrs)
147                    )
148                )
149
150        del self.meta
151
152        # If the db_table wasn't provided, use the package_label + model_name.
153        if not self.db_table:
154            self.db_table = f"{self.package_label}_{self.model_name}"
155            self.db_table = truncate_name(
156                self.db_table,
157                db_connection.ops.max_name_length(),
158            )
159
160    def _format_names_with_class(self, cls, objs):
161        """Package label/class name interpolation for object names."""
162        new_objs = []
163        for obj in objs:
164            obj = obj.clone()
165            obj.name = obj.name % {
166                "package_label": cls._meta.package_label.lower(),
167                "class": cls.__name__.lower(),
168            }
169            new_objs.append(obj)
170        return new_objs
171
172    def _prepare(self, model):
173        if self.pk is None:
174            auto = BigAutoField(primary_key=True, auto_created=True)
175            model.add_to_class("id", auto)
176
177    def add_manager(self, manager):
178        self.local_managers.append(manager)
179        self._expire_cache()
180
181    def add_field(self, field, private=False):
182        # Insert the given field in the order in which it was created, using
183        # the "creation_counter" attribute of the field.
184        # Move many-to-many related fields from self.fields into
185        # self.many_to_many.
186        if field.is_relation and field.many_to_many:
187            bisect.insort(self.local_many_to_many, field)
188        else:
189            bisect.insort(self.local_fields, field)
190            self.setup_pk(field)
191
192        # If the field being added is a relation to another known field,
193        # expire the cache on this field and the forward cache on the field
194        # being referenced, because there will be new relationships in the
195        # cache. Otherwise, expire the cache of references *to* this field.
196        # The mechanism for getting at the related model is slightly odd -
197        # ideally, we'd just ask for field.related_model. However, related_model
198        # is a cached property, and all the models haven't been loaded yet, so
199        # we need to make sure we don't cache a string reference.
200        if (
201            field.is_relation
202            and hasattr(field.remote_field, "model")
203            and field.remote_field.model
204        ):
205            try:
206                field.remote_field.model._meta._expire_cache(forward=False)
207            except AttributeError:
208                pass
209            self._expire_cache()
210        else:
211            self._expire_cache(reverse=False)
212
213    def setup_pk(self, field):
214        if not self.pk and field.primary_key:
215            self.pk = field
216
217    def __repr__(self):
218        return f"<Options for {self.object_name}>"
219
220    def __str__(self):
221        return self.label_lower
222
223    def can_migrate(self, connection):
224        """
225        Return True if the model can/should be migrated on the given
226        `connection` object.
227        """
228        if self.required_db_vendor:
229            return self.required_db_vendor == connection.vendor
230        if self.required_db_features:
231            return all(
232                getattr(connection.features, feat, False)
233                for feat in self.required_db_features
234            )
235        return True
236
237    @cached_property
238    def managers(self):
239        managers = []
240        seen_managers = set()
241        bases = (b for b in self.model.mro() if hasattr(b, "_meta"))
242        for depth, base in enumerate(bases):
243            for manager in base._meta.local_managers:
244                if manager.name in seen_managers:
245                    continue
246
247                manager = copy.copy(manager)
248                manager.model = self.model
249                seen_managers.add(manager.name)
250                managers.append((depth, manager.creation_counter, manager))
251
252        return make_immutable_fields_list(
253            "managers",
254            (m[2] for m in sorted(managers)),
255        )
256
257    @cached_property
258    def managers_map(self):
259        return {manager.name: manager for manager in self.managers}
260
261    @cached_property
262    def base_manager(self):
263        base_manager_name = self.base_manager_name
264        if not base_manager_name:
265            # Get the first parent's base_manager_name if there's one.
266            for parent in self.model.mro()[1:]:
267                if hasattr(parent, "_meta"):
268                    if parent._base_manager.name != "_base_manager":
269                        base_manager_name = parent._base_manager.name
270                    break
271
272        if base_manager_name:
273            try:
274                return self.managers_map[base_manager_name]
275            except KeyError:
276                raise ValueError(
277                    f"{self.object_name} has no manager named {base_manager_name!r}"
278                )
279
280        manager = Manager()
281        manager.name = "_base_manager"
282        manager.model = self.model
283        manager.auto_created = True
284        return manager
285
286    @cached_property
287    def default_manager(self):
288        default_manager_name = self.default_manager_name
289        if not default_manager_name and not self.local_managers:
290            # Get the first parent's default_manager_name if there's one.
291            for parent in self.model.mro()[1:]:
292                if hasattr(parent, "_meta"):
293                    default_manager_name = parent._meta.default_manager_name
294                    break
295
296        if default_manager_name:
297            try:
298                return self.managers_map[default_manager_name]
299            except KeyError:
300                raise ValueError(
301                    f"{self.object_name} has no manager named {default_manager_name!r}"
302                )
303
304        if self.managers:
305            return self.managers[0]
306
307    @cached_property
308    def fields(self):
309        """
310        Return a list of all forward fields on the model and its parents,
311        excluding ManyToManyFields.
312
313        Private API intended only to be used by Plain itself; get_fields()
314        combined with filtering of field properties is the public API for
315        obtaining this field list.
316        """
317
318        # For legacy reasons, the fields property should only contain forward
319        # fields that are not private or with a m2m cardinality. Therefore we
320        # pass these three filters as filters to the generator.
321        # The third lambda is a longwinded way of checking f.related_model - we don't
322        # use that property directly because related_model is a cached property,
323        # and all the models may not have been loaded yet; we don't want to cache
324        # the string reference to the related_model.
325        def is_not_an_m2m_field(f):
326            return not (f.is_relation and f.many_to_many)
327
328        def is_not_a_generic_relation(f):
329            return not (f.is_relation and f.one_to_many)
330
331        def is_not_a_generic_foreign_key(f):
332            return not (
333                f.is_relation
334                and f.many_to_one
335                and not (hasattr(f.remote_field, "model") and f.remote_field.model)
336            )
337
338        return make_immutable_fields_list(
339            "fields",
340            (
341                f
342                for f in self._get_fields(reverse=False)
343                if is_not_an_m2m_field(f)
344                and is_not_a_generic_relation(f)
345                and is_not_a_generic_foreign_key(f)
346            ),
347        )
348
349    @cached_property
350    def concrete_fields(self):
351        """
352        Return a list of all concrete fields on the model and its parents.
353
354        Private API intended only to be used by Plain itself; get_fields()
355        combined with filtering of field properties is the public API for
356        obtaining this field list.
357        """
358        return make_immutable_fields_list(
359            "concrete_fields", (f for f in self.fields if f.concrete)
360        )
361
362    @cached_property
363    def local_concrete_fields(self):
364        """
365        Return a list of all concrete fields on the model.
366
367        Private API intended only to be used by Plain itself; get_fields()
368        combined with filtering of field properties is the public API for
369        obtaining this field list.
370        """
371        return make_immutable_fields_list(
372            "local_concrete_fields", (f for f in self.local_fields if f.concrete)
373        )
374
375    @cached_property
376    def many_to_many(self):
377        """
378        Return a list of all many to many fields on the model and its parents.
379
380        Private API intended only to be used by Plain itself; get_fields()
381        combined with filtering of field properties is the public API for
382        obtaining this list.
383        """
384        return make_immutable_fields_list(
385            "many_to_many",
386            (
387                f
388                for f in self._get_fields(reverse=False)
389                if f.is_relation and f.many_to_many
390            ),
391        )
392
393    @cached_property
394    def related_objects(self):
395        """
396        Return all related objects pointing to the current model. The related
397        objects can come from a one-to-one, one-to-many, or many-to-many field
398        relation type.
399
400        Private API intended only to be used by Plain itself; get_fields()
401        combined with filtering of field properties is the public API for
402        obtaining this field list.
403        """
404        all_related_fields = self._get_fields(
405            forward=False, reverse=True, include_hidden=True
406        )
407        return make_immutable_fields_list(
408            "related_objects",
409            (
410                obj
411                for obj in all_related_fields
412                if not obj.hidden or obj.field.many_to_many
413            ),
414        )
415
416    @cached_property
417    def _forward_fields_map(self):
418        res = {}
419        fields = self._get_fields(reverse=False)
420        for field in fields:
421            res[field.name] = field
422            # Due to the way Plain's internals work, get_field() should also
423            # be able to fetch a field by attname. In the case of a concrete
424            # field with relation, includes the *_id name too
425            try:
426                res[field.attname] = field
427            except AttributeError:
428                pass
429        return res
430
431    @cached_property
432    def fields_map(self):
433        res = {}
434        fields = self._get_fields(forward=False, include_hidden=True)
435        for field in fields:
436            res[field.name] = field
437            # Due to the way Plain's internals work, get_field() should also
438            # be able to fetch a field by attname. In the case of a concrete
439            # field with relation, includes the *_id name too
440            try:
441                res[field.attname] = field
442            except AttributeError:
443                pass
444        return res
445
446    def get_field(self, field_name):
447        """
448        Return a field instance given the name of a forward or reverse field.
449        """
450        try:
451            # In order to avoid premature loading of the relation tree
452            # (expensive) we prefer checking if the field is a forward field.
453            return self._forward_fields_map[field_name]
454        except KeyError:
455            # If the app registry is not ready, reverse fields are
456            # unavailable, therefore we throw a FieldDoesNotExist exception.
457            if not self.models_registry.ready:
458                raise FieldDoesNotExist(
459                    f"{self.object_name} has no field named '{field_name}'. The app cache isn't ready yet, "
460                    "so if this is an auto-created related field, it won't "
461                    "be available yet."
462                )
463
464        try:
465            # Retrieve field instance by name from cached or just-computed
466            # field map.
467            return self.fields_map[field_name]
468        except KeyError:
469            raise FieldDoesNotExist(
470                f"{self.object_name} has no field named '{field_name}'"
471            )
472
473    def _populate_directed_relation_graph(self):
474        """
475        This method is used by each model to find its reverse objects. As this
476        method is very expensive and is accessed frequently (it looks up every
477        field in a model, in every app), it is computed on first access and then
478        is set as a property on every model.
479        """
480        related_objects_graph = defaultdict(list)
481
482        all_models = self.models_registry.get_models()
483        for model in all_models:
484            opts = model._meta
485
486            fields_with_relations = (
487                f
488                for f in opts._get_fields(reverse=False)
489                if f.is_relation and f.related_model is not None
490            )
491            for f in fields_with_relations:
492                if not isinstance(f.remote_field.model, str):
493                    remote_label = f.remote_field.model._meta.concrete_model._meta.label
494                    related_objects_graph[remote_label].append(f)
495
496        for model in all_models:
497            # Set the relation_tree using the internal __dict__. In this way
498            # we avoid calling the cached property. In attribute lookup,
499            # __dict__ takes precedence over a data descriptor (such as
500            # @cached_property). This means that the _meta._relation_tree is
501            # only called if related_objects is not in __dict__.
502            related_objects = related_objects_graph[
503                model._meta.concrete_model._meta.label
504            ]
505            model._meta.__dict__["_relation_tree"] = related_objects
506        # It seems it is possible that self is not in all_models, so guard
507        # against that with default for get().
508        return self.__dict__.get("_relation_tree", EMPTY_RELATION_TREE)
509
510    @cached_property
511    def _relation_tree(self):
512        return self._populate_directed_relation_graph()
513
514    def _expire_cache(self, forward=True, reverse=True):
515        # This method is usually called by packages.cache_clear(), when the
516        # registry is finalized, or when a new field is added.
517        if forward:
518            for cache_key in self.FORWARD_PROPERTIES:
519                if cache_key in self.__dict__:
520                    delattr(self, cache_key)
521        if reverse:
522            for cache_key in self.REVERSE_PROPERTIES:
523                if cache_key in self.__dict__:
524                    delattr(self, cache_key)
525        self._get_fields_cache = {}
526
527    def get_fields(self, include_hidden=False):
528        """
529        Return a list of fields associated to the model. By default, include
530        forward and reverse fields, fields derived from inheritance, but not
531        hidden fields. The returned fields can be changed using the parameters:
532
533        - include_hidden:  include fields that have a related_name that
534                           starts with a "+"
535        """
536        return self._get_fields(include_hidden=include_hidden)
537
538    def _get_fields(
539        self,
540        forward=True,
541        reverse=True,
542        include_hidden=False,
543        seen_models=None,
544    ):
545        """
546        Internal helper function to return fields of the model.
547        * If forward=True, then fields defined on this model are returned.
548        * If reverse=True, then relations pointing to this model are returned.
549        * If include_hidden=True, then fields with is_hidden=True are returned.
550        """
551
552        # This helper function is used to allow recursion in ``get_fields()``
553        # implementation and to provide a fast way for Plain's internals to
554        # access specific subsets of fields.
555
556        # We must keep track of which models we have already seen. Otherwise we
557        # could include the same field multiple times from different models.
558        topmost_call = seen_models is None
559        if topmost_call:
560            seen_models = set()
561        seen_models.add(self.model)
562
563        # Creates a cache key composed of all arguments
564        cache_key = (forward, reverse, include_hidden, topmost_call)
565
566        try:
567            # In order to avoid list manipulation. Always return a shallow copy
568            # of the results.
569            return self._get_fields_cache[cache_key]
570        except KeyError:
571            pass
572
573        fields = []
574
575        if reverse:
576            # Tree is computed once and cached until the app cache is expired.
577            # It is composed of a list of fields pointing to the current model
578            # from other models.
579            all_fields = self._relation_tree
580            for field in all_fields:
581                # If hidden fields should be included or the relation is not
582                # intentionally hidden, add to the fields dict.
583                if include_hidden or not field.remote_field.hidden:
584                    fields.append(field.remote_field)
585
586        if forward:
587            fields += self.local_fields
588            fields += self.local_many_to_many
589
590        # In order to avoid list manipulation. Always
591        # return a shallow copy of the results
592        fields = make_immutable_fields_list("get_fields()", fields)
593
594        # Store result into cache for later access
595        self._get_fields_cache[cache_key] = fields
596        return fields
597
598    @cached_property
599    def total_unique_constraints(self):
600        """
601        Return a list of total unique constraints. Useful for determining set
602        of fields guaranteed to be unique for all rows.
603        """
604        return [
605            constraint
606            for constraint in self.constraints
607            if (
608                isinstance(constraint, UniqueConstraint)
609                and constraint.condition is None
610                and not constraint.contains_expressions
611            )
612        ]
613
614    @cached_property
615    def _property_names(self):
616        """Return a set of the names of the properties defined on the model."""
617        names = []
618        for name in dir(self.model):
619            attr = inspect.getattr_static(self.model, name)
620            if isinstance(attr, property):
621                names.append(name)
622        return frozenset(names)
623
624    @cached_property
625    def _non_pk_concrete_field_names(self):
626        """
627        Return a set of the non-pk concrete field names defined on the model.
628        """
629        names = []
630        for field in self.concrete_fields:
631            if not field.primary_key:
632                names.append(field.name)
633                if field.name != field.attname:
634                    names.append(field.attname)
635        return frozenset(names)
636
637    @cached_property
638    def db_returning_fields(self):
639        """
640        Private API intended only to be used by Plain itself.
641        Fields to be returned after a database insert.
642        """
643        return [
644            field
645            for field in self._get_fields(forward=True, reverse=False)
646            if getattr(field, "db_returning", False)
647        ]