Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2Various data structures used in query construction.
  3
  4Factored out from plain.models.query to avoid making the main module very
  5large and/or so that they can be used by other modules without getting into
  6circular import difficulties.
  7"""
  8
  9from __future__ import annotations
 10
 11import functools
 12import inspect
 13import logging
 14from collections import namedtuple
 15from collections.abc import Generator
 16from typing import TYPE_CHECKING, Any
 17
 18from plain.models.constants import LOOKUP_SEP
 19from plain.models.db import DatabaseError, db_connection
 20from plain.models.exceptions import FieldError
 21from plain.utils import tree
 22
 23if TYPE_CHECKING:
 24    from plain.models.backends.base.base import BaseDatabaseWrapper
 25    from plain.models.base import Model
 26    from plain.models.fields import Field
 27    from plain.models.meta import Meta
 28    from plain.models.sql.compiler import SQLCompiler
 29
 30logger = logging.getLogger("plain.models")
 31
 32# PathInfo is used when converting lookups (fk__somecol). The contents
 33# describe the relation in Model terms (Meta and Fields for both
 34# sides of the relation). The join_field is the field backing the relation.
 35PathInfo = namedtuple(
 36    "PathInfo",
 37    "from_meta to_meta target_fields join_field m2m direct filtered_relation",
 38)
 39
 40
 41def subclasses(cls: type) -> Generator[type, None, None]:
 42    yield cls
 43    for subclass in cls.__subclasses__():
 44        yield from subclasses(subclass)
 45
 46
 47class Q(tree.Node):
 48    """
 49    Encapsulate filters as objects that can then be combined logically (using
 50    `&` and `|`).
 51    """
 52
 53    # Connection types
 54    AND = "AND"
 55    OR = "OR"
 56    XOR = "XOR"
 57    default = AND
 58    conditional = True
 59
 60    def __init__(
 61        self,
 62        *args: Any,
 63        _connector: str | None = None,
 64        _negated: bool = False,
 65        **kwargs: Any,
 66    ) -> None:
 67        super().__init__(
 68            children=[*args, *sorted(kwargs.items())],
 69            connector=_connector,
 70            negated=_negated,
 71        )
 72
 73    def _combine(self, other: Any, conn: str) -> Q:
 74        if getattr(other, "conditional", False) is False:
 75            raise TypeError(other)
 76        if not self:
 77            return other.copy()
 78        if not other and isinstance(other, Q):
 79            return self.copy()
 80
 81        obj = self.create(connector=conn)
 82        obj.add(self, conn)
 83        obj.add(other, conn)
 84        return obj
 85
 86    def __or__(self, other: Any) -> Q:
 87        return self._combine(other, self.OR)
 88
 89    def __and__(self, other: Any) -> Q:
 90        return self._combine(other, self.AND)
 91
 92    def __xor__(self, other: Any) -> Q:
 93        return self._combine(other, self.XOR)
 94
 95    def __invert__(self) -> Q:
 96        obj = self.copy()
 97        obj.negate()
 98        return obj
 99
100    def resolve_expression(
101        self,
102        query: Any = None,
103        allow_joins: bool = True,
104        reuse: Any = None,
105        summarize: bool = False,
106        for_save: bool = False,
107    ) -> Any:
108        # We must promote any new joins to left outer joins so that when Q is
109        # used as an expression, rows aren't filtered due to joins.
110        clause, joins = query._add_q(  # type: ignore[union-attr]
111            self,
112            reuse,
113            allow_joins=allow_joins,
114            split_subq=False,
115            check_filterable=False,
116            summarize=summarize,
117        )
118        query.promote_joins(joins)  # type: ignore[union-attr]
119        return clause
120
121    def flatten(self) -> Generator[Any, None, None]:
122        """
123        Recursively yield this Q object and all subexpressions, in depth-first
124        order.
125        """
126        yield self
127        for child in self.children:
128            if isinstance(child, tuple):
129                # Use the lookup.
130                child = child[1]
131            if hasattr(child, "flatten"):
132                yield from child.flatten()
133            else:
134                yield child
135
136    def check(self, against: dict[str, Any]) -> bool:
137        """
138        Do a database query to check if the expressions of the Q instance
139        matches against the expressions.
140        """
141        # Avoid circular imports.
142        from plain.models.expressions import Value
143        from plain.models.fields import BooleanField
144        from plain.models.functions import Coalesce
145        from plain.models.sql import Query
146        from plain.models.sql.constants import SINGLE
147
148        query = Query(None)  # type: ignore[arg-type]
149        for name, value in against.items():
150            if not hasattr(value, "resolve_expression"):
151                value = Value(value)
152            query.add_annotation(value, name, select=False)
153        query.add_annotation(Value(1), "_check")
154        # This will raise a FieldError if a field is missing in "against".
155        if db_connection.features.supports_comparing_boolean_expr:
156            query.add_q(Q(Coalesce(self, True, output_field=BooleanField())))
157        else:
158            query.add_q(self)
159        compiler = query.get_compiler()
160        try:
161            return compiler.execute_sql(SINGLE) is not None
162        except DatabaseError as e:
163            logger.warning("Got a database error calling check() on %r: %s", self, e)
164            return True
165
166    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
167        path = f"{self.__class__.__module__}.{self.__class__.__name__}"
168        if path.startswith("plain.models.query_utils"):
169            path = path.replace("plain.models.query_utils", "plain.models")
170        args = tuple(self.children)
171        kwargs: dict[str, Any] = {}
172        if self.connector != self.default:
173            kwargs["_connector"] = self.connector
174        if self.negated:
175            kwargs["_negated"] = True
176        return path, args, kwargs
177
178
179class DeferredAttribute:
180    """
181    A wrapper for a deferred-loading field. When the value is read from this
182    object the first time, the query is executed.
183    """
184
185    def __init__(self, field: Any) -> None:
186        self.field = field
187
188    def __get__(self, instance: Any, cls: type | None = None) -> Any:
189        """
190        Retrieve and caches the value from the datastore on the first lookup.
191        Return the cached value.
192        """
193        if instance is None:
194            return self
195        data = instance.__dict__
196        field_name = self.field.attname
197        if field_name not in data:
198            instance.refresh_from_db(fields=[field_name])
199        return data[field_name]
200
201
class class_or_instance_method:
    """
    Descriptor that picks an implementation based on how it is accessed.

    Accessed on an instance, it returns ``instance_method`` bound (via
    functools.partial) to that instance. Accessed on the class, it returns
    ``class_method`` bound to the owner class. RegisterLookupMixin uses it
    for its class-or-instance lookup registries.
    """

    def __init__(self, class_method: Any, instance_method: Any) -> None:
        self.class_method = class_method
        self.instance_method = instance_method

    def __get__(self, instance: Any, owner: type) -> Any:
        if instance is not None:
            return functools.partial(self.instance_method, instance)
        return functools.partial(self.class_method, owner)
216
217
class RegisterLookupMixin:
    """
    Mixin that keeps registries of lookup/transform classes keyed by name.

    Registrations can be made on a class or on a single instance:

    - A class registration is stored in that class's own ``class_lookups``
      dict and merged across the MRO.
    - An instance registration is stored in the instance's
      ``instance_lookups`` dict and layered on top of the class-level lookups
      for that instance only.
    """

    def _get_lookup(self, lookup_name: str) -> type | None:
        # get_lookups resolves to the class-level table when accessed on a
        # class, and to the class + instance table when accessed on an instance.
        return self.get_lookups().get(lookup_name, None)

    # Memoized by functools.cache, keyed on the class argument. At this point
    # it is still a plain function: it is wrapped by class_or_instance_method
    # (get_lookups) below and only afterwards rebound as a classmethod.
    @functools.cache
    def get_class_lookups(cls: type) -> dict[str, type]:
        # Gather the "class_lookups" dict defined directly on each class in the
        # MRO (most-derived first). merge_dicts gives earlier dicts precedence,
        # so a subclass's registration overrides a parent's for the same name.
        class_lookups = [
            parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls)
        ]
        return cls.merge_dicts(class_lookups)  # type: ignore[attr-defined]

    def get_instance_lookups(self) -> dict[str, type]:
        """
        Return the class-level lookups overlaid with any lookups registered on
        this instance. Instance entries win on name clashes.
        """
        class_lookups = self.get_class_lookups()
        if instance_lookups := getattr(self, "instance_lookups", None):
            return {**class_lookups, **instance_lookups}
        # No (or an empty) instance registry: this returns the dict produced by
        # the cached get_class_lookups itself, so callers must not mutate it.
        return class_lookups

    # Class access -> get_class_lookups(cls); instance access ->
    # get_instance_lookups(self).
    get_lookups = class_or_instance_method(get_class_lookups, get_instance_lookups)
    # Rebind as a classmethod only after get_lookups captured the raw function.
    get_class_lookups = classmethod(get_class_lookups)  # type: ignore[assignment]

    def get_lookup(self, lookup_name: str) -> type | None:
        """
        Return the Lookup subclass registered under lookup_name, or None.

        If nothing is registered here and this object has an ``output_field``,
        the search is delegated to ``output_field.get_lookup``. A registered
        entry that is not a Lookup subclass yields None.
        """
        from plain.models.lookups import Lookup

        found = self._get_lookup(lookup_name)
        if found is None and hasattr(self, "output_field"):
            return self.output_field.get_lookup(lookup_name)
        if found is not None and not issubclass(found, Lookup):
            return None
        return found

    def get_transform(self, lookup_name: str) -> type | None:
        """
        Return the Transform subclass registered under lookup_name, or None.

        If nothing is registered here and this object has an ``output_field``,
        the search is delegated to ``output_field.get_transform``. A registered
        entry that is not a Transform subclass yields None.
        """
        from plain.models.lookups import Transform

        found = self._get_lookup(lookup_name)
        if found is None and hasattr(self, "output_field"):
            return self.output_field.get_transform(lookup_name)
        if found is not None and not issubclass(found, Transform):
            return None
        return found

    @staticmethod
    def merge_dicts(dicts: list[dict[str, type]]) -> dict[str, type]:
        """
        Merge dicts so that keys in earlier dicts take precedence over the same
        keys in later ones. For example, merge_dicts([a, b]) prefers the values
        in 'a' over those in 'b'. Returns a new dict.
        """
        merged: dict[str, type] = {}
        # Apply from last to first so earlier dicts overwrite later ones.
        for d in reversed(dicts):
            merged.update(d)
        return merged

    @classmethod
    def _clear_cached_class_lookups(cls) -> None:
        # Invalidate memoized get_class_lookups results for this class and all
        # of its subclasses. Unless a subclass overrides get_class_lookups,
        # they all share one cached function, so each cache_clear() empties the
        # cache for every class.
        for subclass in subclasses(cls):
            subclass.get_class_lookups.cache_clear()  # type: ignore[attr-defined]

    def register_class_lookup(
        cls: type, lookup: type, lookup_name: str | None = None
    ) -> type:
        # Register `lookup` on this class under lookup_name (defaulting to
        # lookup.lookup_name) and return it unchanged, which also allows use as
        # a class decorator.
        if lookup_name is None:
            lookup_name = lookup.lookup_name  # type: ignore[attr-defined]
        if "class_lookups" not in cls.__dict__:
            # Give this class its own dict rather than mutating one inherited
            # from a parent.
            cls.class_lookups = {}  # type: ignore[attr-defined]
        cls.class_lookups[lookup_name] = lookup  # type: ignore[attr-defined]
        # Merged results for this class and its subclasses are now stale.
        cls._clear_cached_class_lookups()  # type: ignore[attr-defined]
        return lookup

    def register_instance_lookup(
        self, lookup: type, lookup_name: str | None = None
    ) -> type:
        # Register `lookup` on this instance only, under lookup_name
        # (defaulting to lookup.lookup_name), and return it unchanged.
        if lookup_name is None:
            lookup_name = lookup.lookup_name  # type: ignore[attr-defined]
        if "instance_lookups" not in self.__dict__:
            self.instance_lookups = {}
        self.instance_lookups[lookup_name] = lookup
        return lookup

    # Called on a class -> register_class_lookup; on an instance ->
    # register_instance_lookup.
    register_lookup = class_or_instance_method(
        register_class_lookup, register_instance_lookup
    )
    # Rebind as a classmethod only after register_lookup captured the function.
    register_class_lookup = classmethod(register_class_lookup)  # type: ignore[assignment]

    def _unregister_class_lookup(
        cls: type, lookup: type, lookup_name: str | None = None
    ) -> None:
        """
        Remove given lookup from cls lookups. For use in tests only as it's
        not thread-safe.
        """
        # NOTE(review): `cls.class_lookups` is resolved through the MRO, so if
        # cls has no class_lookups of its own this deletes from a parent's
        # dict. A missing name raises KeyError.
        if lookup_name is None:
            lookup_name = lookup.lookup_name  # type: ignore[attr-defined]
        del cls.class_lookups[lookup_name]  # type: ignore[attr-defined]
        cls._clear_cached_class_lookups()  # type: ignore[attr-defined]

    def _unregister_instance_lookup(
        self, lookup: type, lookup_name: str | None = None
    ) -> None:
        """
        Remove given lookup from instance lookups. For use in tests only as
        it's not thread-safe.

        Raises KeyError if lookup_name isn't registered, or AttributeError if
        no instance_lookups dict exists.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name  # type: ignore[attr-defined]
        del self.instance_lookups[lookup_name]

    # Called on a class -> _unregister_class_lookup; on an instance ->
    # _unregister_instance_lookup.
    _unregister_lookup = class_or_instance_method(
        _unregister_class_lookup, _unregister_instance_lookup
    )
    _unregister_class_lookup = classmethod(_unregister_class_lookup)  # type: ignore[assignment]
327
328
def select_related_descend(
    field: Any,
    restricted: bool,
    requested: dict[str, Any],
    select_mask: Any,
    reverse: bool = False,
) -> bool:
    """
    Decide whether select_related() should follow ``field`` one level deeper.

    Shared by the query construction code (compiler.get_related_selections())
    and the model instance creation code (compiler.klass_info).

    Arguments:
     * field - the field to be checked.
     * restricted - True when select_related() was given an explicit set of
       relations to follow (``requested``).
     * requested - the select_related() dictionary.
     * select_mask - the dictionary of selected fields.
     * reverse - True when checking a reverse select_related relation.

    Raises FieldError when, in restricted mode, the field is named in
    ``requested`` but absent from a non-empty ``select_mask``. In that case
    it would be both deferred and traversed.
    """
    if not field.remote_field:
        # No remote side, so there is nothing to descend into.
        return False

    if not restricted:
        # Without an explicit list, only non-nullable relations are followed.
        return not field.allow_null

    # Reverse relations are requested by their related query name.
    requested_key = field.related_query_name() if reverse else field.name
    if requested_key not in requested:
        return False

    if select_mask and field.name in requested and field not in select_mask:
        raise FieldError(
            f"Field {field.model.model_options.object_name}.{field.name} cannot be both "
            "deferred and traversed using select_related at the same time."
        )
    return True
370
371
def refs_expression(
    lookup_parts: list[str], annotations: dict[str, Any]
) -> tuple[str | None, tuple[str, ...]]:
    """
    Find the shortest prefix of ``lookup_parts`` that names an annotation.

    Annotation names may themselves contain LOOKUP_SEP, so each prefix (joined
    with LOOKUP_SEP) is tried in turn, shortest first.

    Return ``(annotation_name, remaining_parts)`` for the first prefix whose
    annotation value is truthy, or ``(None, ())`` when none matches.
    """
    total = len(lookup_parts)
    for cut in range(1, total + 1):
        candidate = LOOKUP_SEP.join(lookup_parts[0:cut])
        if annotations.get(candidate):
            return candidate, tuple(lookup_parts[cut:])
    return None, ()
385
386
def check_rel_lookup_compatibility(
    model: type[Model], target_meta: Meta, field: Field
) -> bool:
    """
    Check that ``model`` is compatible with ``target_meta``.

    Compatibility is OK if either:
      1) ``model`` compares equal to ``target_meta.model``, or
      2) ``field`` is a primary key (``field.primary_key`` is truthy) and
         ``model`` compares equal to ``field.model._model_meta.model``.

    Always returns a real ``bool``.
    """

    def matches(meta: Meta) -> bool:
        return model == meta.model

    # If the field is a primary key, then doing a query against the field's
    # model is ok, too. Consider the case:
    # class Restaurant(models.Model):
    #     place = OneToOneField(Place, primary_key=True)
    # Restaurant.query.filter(id__in=Restaurant.query.all()).
    # If we didn't have the primary key check, then id__in (== place__in) would
    # give Place's meta as the target meta, but Restaurant isn't compatible
    # with that. This logic applies only to primary keys, as when doing __in=qs,
    # we are going to turn this into __in=qs.values('id') later on.
    return bool(
        matches(target_meta)
        or (
            getattr(field, "primary_key", False)
            and matches(field.model._model_meta)
        )
    )
412
413
class FilteredRelation:
    """
    A relation whose SQL join carries extra filtering, given by ``condition``,
    in its ON clause.
    """

    def __init__(self, relation_name: str, *, condition: Q = Q()) -> None:
        # relation_name is validated before condition, so an empty name is
        # reported first when both are invalid.
        if not relation_name:
            raise ValueError("relation_name cannot be empty.")
        if not isinstance(condition, Q):
            raise ValueError("condition argument must be a Q() instance.")
        self.relation_name = relation_name
        self.condition = condition
        # alias starts unset and path starts empty; clone() copies both.
        self.alias: str | None = None
        self.path: list[str] = []

    def __eq__(self, other: object) -> bool:
        # path is intentionally not part of equality.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return (
            self.relation_name == other.relation_name
            and self.alias == other.alias
            and self.condition == other.condition
        )

    def clone(self) -> FilteredRelation:
        """Return a copy that shares the condition but owns its own path list."""
        duplicate = FilteredRelation(self.relation_name, condition=self.condition)
        duplicate.alias = self.alias
        duplicate.path = self.path[:]
        return duplicate

    def resolve_expression(self, *args: Any, **kwargs: Any) -> Any:
        """
        Present only so QuerySet.annotate(), which requires expression-like
        arguments (having resolve_expression()), accepts this object. It is
        never meant to be called.
        """
        raise NotImplementedError("FilteredRelation.resolve_expression() is unused.")

    def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any:
        # The condition is resolved against the compiler's query (from
        # Join.filtered_relation), reusing the joins along this relation's path.
        where = compiler.query.build_filtered_relation_q(
            self.condition, reuse=set(self.path)
        )
        return compiler.compile(where)