Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2Various data structures used in query construction.
  3
  4Factored out from plain.models.query to avoid making the main module very
  5large and/or so that they can be used by other modules without getting into
  6circular import difficulties.
  7"""
  8
  9from __future__ import annotations
 10
 11import functools
 12import inspect
 13import logging
 14from collections.abc import Callable, Generator
 15from typing import TYPE_CHECKING, Any, NamedTuple, Self
 16
 17from plain.models.constants import LOOKUP_SEP
 18from plain.models.db import DatabaseError, db_connection
 19from plain.models.exceptions import FieldError
 20from plain.utils import tree
 21
 22if TYPE_CHECKING:
 23    from plain.models.backends.base.base import BaseDatabaseWrapper
 24    from plain.models.base import Model
 25    from plain.models.fields import Field
 26    from plain.models.fields.related import ForeignKeyField
 27    from plain.models.fields.reverse_related import ForeignObjectRel
 28    from plain.models.lookups import Lookup, Transform
 29    from plain.models.meta import Meta
 30    from plain.models.sql.compiler import SQLCompiler
 31    from plain.models.sql.where import WhereNode
 32
 33logger = logging.getLogger("plain.models")
 34
 35
class PathInfo(NamedTuple):
    """Information about a relation path when converting lookups (fk__somecol).

    Describes the relation in Model terms (Meta and Fields for both
    sides of the relation). The join_field is the field backing the relation.

    As a NamedTuple, instances are immutable and also unpack positionally,
    so the order of the fields below is part of the interface.
    """

    # Meta of the model on the "from" side of this relation step.
    from_meta: Meta
    # Meta of the model on the "to" side of this relation step.
    to_meta: Meta
    # Field(s) targeted on the "to" side of the relation; presumably the
    # columns the join points at — confirm against the code building PathInfo.
    target_fields: tuple[Field, ...]
    # The field backing the relation: a forward ForeignKeyField or a reverse
    # ForeignObjectRel.
    join_field: ForeignKeyField | ForeignObjectRel
    # NOTE(review): presumably True when this step crosses a many-to-many
    # relation (inferred from the name) — confirm.
    m2m: bool
    # NOTE(review): presumably True for a forward relation and False for a
    # reverse one (inferred from the name) — confirm.
    direct: bool
    # The FilteredRelation attached to this step, or None when unfiltered.
    filtered_relation: FilteredRelation | None
 50
 51
 52def subclasses(cls: type) -> Generator[type, None, None]:
 53    yield cls
 54    for subclass in cls.__subclasses__():
 55        yield from subclasses(subclass)
 56
 57
 58class Q(tree.Node):
 59    """
 60    Encapsulate filters as objects that can then be combined logically (using
 61    `&` and `|`).
 62    """
 63
 64    # Connection types
 65    AND = "AND"
 66    OR = "OR"
 67    XOR = "XOR"
 68    default = AND
 69    conditional = True
 70
 71    def __init__(
 72        self,
 73        *args: Any,
 74        _connector: str | None = None,
 75        _negated: bool = False,
 76        **kwargs: Any,
 77    ) -> None:
 78        super().__init__(
 79            children=[*args, *sorted(kwargs.items())],
 80            connector=_connector,
 81            negated=_negated,
 82        )
 83
 84    def _combine(self, other: Any, conn: str) -> Q:
 85        if getattr(other, "conditional", False) is False:
 86            raise TypeError(other)
 87        if not self:
 88            return other.copy()
 89        if not other and isinstance(other, Q):
 90            return self.copy()
 91
 92        obj = self.create(connector=conn)
 93        obj.add(self, conn)
 94        obj.add(other, conn)
 95        return obj
 96
 97    def __or__(self, other: Any) -> Q:
 98        return self._combine(other, self.OR)
 99
100    def __and__(self, other: Any) -> Q:
101        return self._combine(other, self.AND)
102
103    def __xor__(self, other: Any) -> Q:
104        return self._combine(other, self.XOR)
105
106    def __invert__(self) -> Q:
107        obj = self.copy()
108        obj.negate()
109        return obj
110
111    def resolve_expression(
112        self,
113        query: Any = None,
114        allow_joins: bool = True,
115        reuse: Any = None,
116        summarize: bool = False,
117        for_save: bool = False,
118    ) -> WhereNode:
119        # We must promote any new joins to left outer joins so that when Q is
120        # used as an expression, rows aren't filtered due to joins.
121        clause, joins = query._add_q(
122            self,
123            reuse,
124            allow_joins=allow_joins,
125            split_subq=False,
126            check_filterable=False,
127            summarize=summarize,
128        )
129        query.promote_joins(joins)
130        return clause
131
132    def flatten(self) -> Generator[Any, None, None]:
133        """
134        Recursively yield this Q object and all subexpressions, in depth-first
135        order.
136        """
137        yield self
138        for child in self.children:
139            if isinstance(child, tuple):
140                # Use the lookup.
141                child = child[1]
142            if hasattr(child, "flatten"):
143                yield from child.flatten()
144            else:
145                yield child
146
147    def check(self, against: dict[str, Any]) -> bool:
148        """
149        Do a database query to check if the expressions of the Q instance
150        matches against the expressions.
151        """
152        # Avoid circular imports.
153        from plain.models.expressions import ResolvableExpression, Value
154        from plain.models.fields import BooleanField
155        from plain.models.functions import Coalesce
156        from plain.models.sql import Query
157        from plain.models.sql.constants import SINGLE
158
159        query = Query(None)
160        for name, value in against.items():
161            if not isinstance(value, ResolvableExpression):
162                value = Value(value)
163            query.add_annotation(value, name, select=False)
164        query.add_annotation(Value(1), "_check")
165        # This will raise a FieldError if a field is missing in "against".
166        if db_connection.features.supports_comparing_boolean_expr:
167            query.add_q(Q(Coalesce(self, True, output_field=BooleanField())))
168        else:
169            query.add_q(self)
170        compiler = query.get_compiler()
171        try:
172            return compiler.execute_sql(SINGLE) is not None
173        except DatabaseError as e:
174            logger.warning("Got a database error calling check() on %r: %s", self, e)
175            return True
176
177    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
178        path = f"{self.__class__.__module__}.{self.__class__.__name__}"
179        if path.startswith("plain.models.query_utils"):
180            path = path.replace("plain.models.query_utils", "plain.models")
181        args = tuple(self.children)
182        kwargs: dict[str, Any] = {}
183        if self.connector != self.default:
184            kwargs["_connector"] = self.connector
185        if self.negated:
186            kwargs["_negated"] = True
187        return path, args, kwargs
188
189
class class_or_instance_method:
    """
    Descriptor that binds one of two callables depending on how it is read.

    Read from the class, it returns ``class_method`` partially applied to the
    owner class. Read from an instance, it returns ``instance_method``
    partially applied to that instance. RegisterLookupMixin uses this so a
    method works on both a models.Field class and a field instance.
    """

    def __init__(self, class_method: Any, instance_method: Any) -> None:
        self.class_method = class_method
        self.instance_method = instance_method

    def __get__(self, instance: Any, owner: type) -> Any:
        if instance is not None:
            return functools.partial(self.instance_method, instance)
        return functools.partial(self.class_method, owner)
204
205
class RegisterLookupMixin:
    """
    Lets a class, or an individual instance, register lookups and transforms
    by name, and resolve them with get_lookup() / get_transform().

    Class-level registrations are kept in each class's *own*
    ``class_lookups`` dict and merged along the MRO. Instance-level
    registrations are kept in the instance's ``instance_lookups`` dict and
    override class-level entries of the same name.

    The class body depends on statement order. Each
    ``class_or_instance_method`` dispatcher is built from the plain functions
    first. Only afterwards are the class-side names rebound to
    ``classmethod`` objects.
    """

    def _get_lookup(self, lookup_name: str) -> type[Lookup | Transform] | None:
        # Look ``lookup_name`` up in get_lookups() (class-level or
        # instance-level, depending on how this is accessed); None if absent.
        return self.get_lookups().get(lookup_name, None)

    @functools.cache
    def get_class_lookups(cls: type[Self]) -> dict[str, type[Lookup | Transform]]:
        """
        Return the lookups registered on ``cls`` and its ancestors.

        Walks ``inspect.getmro(cls)`` and collects only each class's own
        ``class_lookups`` (read from ``__dict__``, so an inherited dict is not
        picked up twice). Entries from classes earlier in the MRO win on name
        clashes (see merge_dicts()).

        The result is memoized by functools.cache (keyed on ``cls``) and
        invalidated by _clear_cached_class_lookups(). The cached dict itself
        is returned, so callers must not mutate it.
        """
        class_lookups = [
            parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls)
        ]
        return cls.merge_dicts(class_lookups)

    def get_instance_lookups(self) -> dict[str, type[Lookup | Transform]]:
        """
        Return the class-level lookups overlaid with this instance's own
        ``instance_lookups``; instance entries win on name clashes.

        When the instance has no ``instance_lookups``, or they are empty, the
        cached class-level dict is returned as-is rather than copied.
        """
        class_lookups = self.get_class_lookups()
        if instance_lookups := getattr(self, "instance_lookups", None):
            return {**class_lookups, **instance_lookups}
        return class_lookups

    # Read from the class, get_lookups() calls get_class_lookups(cls). Read
    # from an instance, it calls get_instance_lookups(self). It is built from
    # the plain (cached) function, before get_class_lookups is rebound to a
    # classmethod on the next line.
    get_lookups = class_or_instance_method(get_class_lookups, get_instance_lookups)
    get_class_lookups = classmethod(get_class_lookups)  # type: ignore[assignment]

    def get_lookup(self, lookup_name: str) -> type[Lookup] | None:
        """
        Return the Lookup subclass registered as ``lookup_name``, or None.

        If nothing is registered under that name here and this object has a
        truthy ``output_field``, the search is delegated to that field.
        A registered entry that is not a Lookup subclass yields None.
        """
        # Local import, presumably to avoid a circular import with lookups.
        from plain.models.lookups import Lookup

        found = self._get_lookup(lookup_name)
        if found is None:
            # output_field is a Field which inherits from RegisterLookupMixin
            if output_field := getattr(self, "output_field", None):
                return output_field.get_lookup(lookup_name)
        if found is not None and not issubclass(found, Lookup):
            return None
        return found

    def get_transform(
        self, lookup_name: str
    ) -> type[Transform] | Callable[..., Any] | None:
        """
        Return the Transform subclass registered as ``lookup_name``, or None.

        If nothing is registered under that name here and this object has a
        truthy ``output_field``, the search is delegated to that field.
        A registered entry that is not a Transform subclass yields None.
        """
        # Local import, presumably to avoid a circular import with lookups.
        from plain.models.lookups import Transform

        found = self._get_lookup(lookup_name)
        if found is None:
            # output_field is a Field which inherits from RegisterLookupMixin
            if output_field := getattr(self, "output_field", None):
                return output_field.get_transform(lookup_name)
        if found is not None and not issubclass(found, Transform):
            return None
        return found

    @staticmethod
    def merge_dicts(
        dicts: list[dict[str, type[Lookup | Transform]]],
    ) -> dict[str, type[Lookup | Transform]]:
        """
        Merge dicts so that earlier items in the list take precedence, e.g.
        merge_dicts([a, b]) prefers the keys in 'a' over those in 'b'.

        The dicts are applied in reverse order, so earlier items overwrite
        later ones. A new dict is returned and the inputs are left untouched.
        """
        merged: dict[str, type[Lookup | Transform]] = {}
        for d in reversed(dicts):
            merged.update(d)
        return merged

    @classmethod
    def _clear_cached_class_lookups(cls: type[Self]) -> None:
        """
        Invalidate the get_class_lookups() cache for ``cls`` and every
        transitive subclass. A registration on ``cls`` changes what its
        subclasses see through the MRO.
        """
        # NOTE(review): unless a subclass overrides get_class_lookups, every
        # class shares the single functools.cache wrapper defined above, so
        # the first cache_clear() already drops all entries. The loop looks
        # redundant but harmless — confirm before simplifying.
        for subclass in subclasses(cls):
            subclass.get_class_lookups.cache_clear()  # type: ignore[attr-defined]

    def register_class_lookup(
        cls: type[Self],
        lookup: type[Lookup | Transform],
        lookup_name: str | None = None,
    ) -> type[Lookup | Transform]:
        """
        Register ``lookup`` on ``cls`` under ``lookup_name`` (defaulting to
        ``lookup.lookup_name``) and clear the class-lookup caches.

        Returns ``lookup`` unchanged, so this can also be used as a decorator.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name  # type: ignore[attr-defined]
        # Give cls its own dict instead of mutating one inherited from a
        # parent class, which would leak the registration to the parent and
        # its other subclasses.
        if "class_lookups" not in cls.__dict__:
            cls.class_lookups = {}  # type: ignore[misc]
        cls.class_lookups[lookup_name] = lookup  # type: ignore[attr-defined]
        cls._clear_cached_class_lookups()
        return lookup

    def register_instance_lookup(
        self, lookup: type[Lookup | Transform], lookup_name: str | None = None
    ) -> type[Lookup | Transform]:
        """
        Register ``lookup`` on this instance only, under ``lookup_name``
        (defaulting to ``lookup.lookup_name``), and return ``lookup``.

        No cache needs clearing because get_instance_lookups() is not memoized.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name  # type: ignore[attr-defined]
        # Create the per-instance dict on first registration.
        if "instance_lookups" not in self.__dict__:
            self.instance_lookups = {}
        self.instance_lookups[lookup_name] = lookup
        return lookup

    # Called on the class, register_lookup() registers class-wide. Called on
    # an instance, it registers for that instance only. It is built from the
    # plain function, before register_class_lookup is rebound to a classmethod.
    register_lookup = class_or_instance_method(
        register_class_lookup, register_instance_lookup
    )
    register_class_lookup = classmethod(register_class_lookup)  # type: ignore[assignment]

    def _unregister_class_lookup(
        cls: type[Self],
        lookup: type[Lookup | Transform],
        lookup_name: str | None = None,
    ) -> None:
        """
        Remove given lookup from cls lookups. For use in tests only as it's
        not thread-safe.

        Raises KeyError if nothing is registered under that name.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name  # type: ignore[attr-defined]
        # NOTE(review): ``cls.class_lookups`` resolves through the MRO, so if
        # cls has no dict of its own this deletes from an ancestor's dict.
        del cls.class_lookups[lookup_name]  # type: ignore[attr-defined]
        cls._clear_cached_class_lookups()

    def _unregister_instance_lookup(
        self, lookup: type[Lookup | Transform], lookup_name: str | None = None
    ) -> None:
        """
        Remove given lookup from instance lookups. For use in tests only as
        it's not thread-safe.

        Raises KeyError if nothing is registered under that name.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name  # type: ignore[attr-defined]
        del self.instance_lookups[lookup_name]

    # Class/instance dispatch for unregistering, mirroring register_lookup.
    _unregister_lookup = class_or_instance_method(
        _unregister_class_lookup, _unregister_instance_lookup
    )
    _unregister_class_lookup = classmethod(_unregister_class_lookup)  # type: ignore[assignment]
327
328
def select_related_descend(
    field: Any,
    restricted: bool | None,
    requested: dict[str, Any] | None,
    select_mask: Any,
    reverse: bool = False,
) -> bool:
    """
    Decide whether select_related() should descend into ``field``.

    Shared by the query construction code (compiler.get_related_selections())
    and the model instance creation code (compiler.klass_info).

    Arguments:
     * field - the field to be checked; anything that is not a RelatedField
       is never descended into.
     * restricted - whether the relations to follow have been explicitly
       restricted through ``requested``.
     * requested - the select_related() dictionary; required when
       ``restricted`` is truthy.
     * select_mask - the dictionary of selected fields.
     * reverse - True when checking a reverse select_related relation.

    Without restriction, only non-nullable relations are followed. With
    restriction, the relation must be named in ``requested``.

    Raises FieldError if a requested relation is also excluded by
    ``select_mask`` (deferred and traversed at the same time).
    """
    from plain.models.fields.related import RelatedField

    if not isinstance(field, RelatedField):
        return False

    if not restricted:
        # Unrestricted: follow a relation only when it cannot be null.
        return not field.allow_null

    assert requested is not None, "requested must be provided when restricted=True"
    # Reverse relations are requested by their related query name; forward
    # ones by the field name.
    requested_key = field.related_query_name() if reverse else field.name
    if requested_key not in requested:
        return False

    if select_mask and field.name in requested and field not in select_mask:
        raise FieldError(
            f"Field {field.model.model_options.object_name}.{field.name} cannot be both "
            "deferred and traversed using select_related at the same time."
        )
    return True
373
374
def refs_expression(
    lookup_parts: list[str], annotations: dict[str, Any]
) -> tuple[str | None, tuple[str, ...]]:
    """
    Find the shortest prefix of ``lookup_parts`` that names an annotation.

    Annotation names may themselves contain LOOKUP_SEP (default annotation
    names do), so every prefix is tried, shortest first.

    Returns ``(annotation_name, remaining_parts)`` for the first prefix whose
    entry in ``annotations`` is truthy, or ``(None, ())`` when none matches.
    """
    for length in range(1, len(lookup_parts) + 1):
        name = LOOKUP_SEP.join(lookup_parts[:length])
        if not annotations.get(name):
            continue
        return name, tuple(lookup_parts[length:])
    return None, ()
388
389
def check_rel_lookup_compatibility(
    model: type[Model], target_meta: Meta, field: Field | ForeignObjectRel
) -> bool:
    """
    Check that ``model`` is compatible with ``target_meta``.

    Compatibility is OK if either:
      1) ``model == target_meta.model``, or
      2) ``field`` is a primary key and ``model == field.model._model_meta.model``
         (i.e. the query targets the model that owns the primary key).

    Models are compared with ``==`` only; this function does not itself walk
    proxy or parent/child inheritance. Always returns a ``bool``.
    """

    def matches(meta: Meta) -> bool:
        return bool(model == meta.model)

    # If the field is a primary key, then doing a query against the field's
    # own model is ok, too. Consider the case:
    #
    #     class Restaurant(models.Model):
    #         place = OneToOneField(Place, primary_key=True)
    #
    #     Restaurant.query.filter(id__in=Restaurant.query.all())
    #
    # If we didn't have the primary key check, then id__in (== place__in)
    # would give Place's meta as the target meta, but Restaurant isn't
    # compatible with that. This logic applies only to primary keys, as when
    # doing __in=qs, we are going to turn this into __in=qs.values('id')
    # later on.
    if matches(target_meta):
        return True
    # bool() so a falsy non-bool primary_key (e.g. None) still yields False.
    is_primary_key = bool(getattr(field, "primary_key", False))
    return is_primary_key and matches(field.model._model_meta)
415
416
class FilteredRelation:
    """
    A relation whose SQL join carries extra filtering in its ON clause.

    ``relation_name`` names the relation to join, and ``condition`` (a Q) is
    compiled into the join's ON clause by as_sql().
    """

    def __init__(self, relation_name: str, *, condition: Q = Q()) -> None:
        # NOTE(review): the default Q() is a single instance shared by every
        # call that omits ``condition``. This class never mutates it directly —
        # confirm callers such as build_filtered_relation_q don't either.
        if not relation_name:
            raise ValueError("relation_name cannot be empty.")
        self.relation_name = relation_name
        self.alias: str | None = None
        if not isinstance(condition, Q):
            raise ValueError("condition argument must be a Q() instance.")
        self.condition = condition
        self.path: list[str] = []

    def __eq__(self, other: object) -> bool:
        # Equality covers relation_name, alias and condition; ``path`` does
        # not take part. Defining __eq__ without __hash__ makes instances
        # unhashable.
        if not isinstance(other, self.__class__):
            return NotImplemented
        same_relation = (
            self.relation_name == other.relation_name and self.alias == other.alias
        )
        return same_relation and self.condition == other.condition

    def clone(self) -> FilteredRelation:
        """Copy this relation, sharing ``condition`` but copying the ``path`` list."""
        duplicate = FilteredRelation(self.relation_name, condition=self.condition)
        duplicate.alias = self.alias
        duplicate.path = self.path[:]
        return duplicate

    def resolve_expression(self, *args: Any, **kwargs: Any) -> Any:
        """
        Exists only so QuerySet.annotate(), which accepts expression-like
        arguments (objects with resolve_expression()), accepts this object.
        It is never meant to be called.
        """
        raise NotImplementedError("FilteredRelation.resolve_expression() is unused.")

    def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any:
        """
        Compile ``condition`` for the join's ON clause, passing the aliases in
        ``path`` as the joins that may be reused.
        """
        reusable_aliases = set(self.path)
        where = compiler.query.build_filtered_relation_q(
            self.condition, reuse=reusable_aliases
        )
        return compiler.compile(where)