v0.146.0
  1"""
  2Various data structures used in query construction.
  3
  4Factored out from plain.postgres.query to avoid making the main module very
  5large and/or so that they can be used by other modules without getting into
  6circular import difficulties.
  7"""
  8
  9from __future__ import annotations
 10
 11import functools
 12import inspect
 13from collections.abc import Callable, Generator
 14from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, Self
 15
 16import psycopg
 17
 18from plain.logs import get_framework_logger
 19from plain.postgres.constants import LOOKUP_SEP
 20from plain.postgres.exceptions import FieldError
 21from plain.utils import tree
 22
 23if TYPE_CHECKING:
 24    from plain.postgres.base import Model
 25    from plain.postgres.connection import DatabaseConnection
 26    from plain.postgres.fields import Field
 27    from plain.postgres.fields.related import ForeignKeyField
 28    from plain.postgres.fields.reverse_related import ForeignObjectRel
 29    from plain.postgres.lookups import Lookup, Transform
 30    from plain.postgres.meta import Meta
 31    from plain.postgres.sql.compiler import SQLCompiler
 32    from plain.postgres.sql.where import WhereNode
 33
# Framework logger for this module; Q.check() uses it to report the
# database errors it catches before falling back to returning True.
logger = get_framework_logger()
 35
 36
class PathInfo(NamedTuple):
    """Information about a relation path when converting lookups (fk__somecol).

    Describes the relation in Model terms (Meta and Fields for both
    sides of the relation). The join_field is the field backing the relation.
    Being a NamedTuple, the field order below is part of the positional
    interface.
    """

    # Meta of the model on the starting side of this path step.
    from_meta: Meta
    # Meta of the model on the far side of this path step.
    to_meta: Meta
    # Fields on the far side that the step resolves to (presumably the join
    # target columns — confirm against callers).
    target_fields: tuple[Field, ...]
    # The forward field or reverse relation object that backs the join.
    join_field: ForeignKeyField | ForeignObjectRel
    # NOTE(review): presumably True when the step is many-to-many /
    # multi-valued — confirm.
    m2m: bool
    # NOTE(review): presumably True for a forward relation and False for a
    # reverse one — confirm.
    direct: bool
    # FilteredRelation applied to this step, if any.
    filtered_relation: FilteredRelation | None
 51
 52
 53def subclasses(cls: type) -> Generator[type]:
 54    yield cls
 55    for subclass in cls.__subclasses__():
 56        yield from subclasses(subclass)
 57
 58
 59class Q(tree.Node):
 60    """
 61    Encapsulate filters as objects that can then be combined logically (using
 62    `&` and `|`).
 63    """
 64
 65    # Connection types
 66    AND = "AND"
 67    OR = "OR"
 68    XOR = "XOR"
 69    default = AND
 70    conditional = True
 71
 72    def __init__(
 73        self,
 74        *args: Any,
 75        _connector: str | None = None,
 76        _negated: bool = False,
 77        **kwargs: Any,
 78    ) -> None:
 79        super().__init__(
 80            children=[*args, *sorted(kwargs.items())],
 81            connector=_connector,
 82            negated=_negated,
 83        )
 84
 85    def _combine(self, other: Any, conn: str) -> Q:
 86        if getattr(other, "conditional", False) is False:
 87            raise TypeError(other)
 88        if not self:
 89            return other.copy()
 90        if not other and isinstance(other, Q):
 91            return self.copy()
 92
 93        obj = self.create(connector=conn)
 94        obj.add(self, conn)
 95        obj.add(other, conn)
 96        return obj
 97
 98    def __or__(self, other: Any) -> Q:
 99        return self._combine(other, self.OR)
100
101    def __and__(self, other: Any) -> Q:
102        return self._combine(other, self.AND)
103
104    def __xor__(self, other: Any) -> Q:
105        return self._combine(other, self.XOR)
106
107    def __invert__(self) -> Q:
108        obj = self.copy()
109        obj.negate()
110        return obj
111
112    def resolve_expression(
113        self,
114        query: Any = None,
115        allow_joins: bool = True,
116        reuse: Any = None,
117        summarize: bool = False,
118        for_save: bool = False,
119    ) -> WhereNode:
120        # We must promote any new joins to left outer joins so that when Q is
121        # used as an expression, rows aren't filtered due to joins.
122        clause, joins = query._add_q(
123            self,
124            reuse,
125            allow_joins=allow_joins,
126            split_subq=False,
127            check_filterable=False,
128            summarize=summarize,
129        )
130        query.promote_joins(joins)
131        return clause
132
133    def flatten(self) -> Generator[Any]:
134        """
135        Recursively yield this Q object and all subexpressions, in depth-first
136        order.
137        """
138        yield self
139        for child in self.children:
140            if isinstance(child, tuple):
141                # Use the lookup.
142                child = child[1]
143            if hasattr(child, "flatten"):
144                yield from child.flatten()
145            else:
146                yield child
147
148    def check(self, against: dict[str, Any]) -> bool:
149        """
150        Do a database query to check if the expressions of the Q instance
151        matches against the expressions.
152        """
153        # Avoid circular imports.
154        from plain.postgres.expressions import ResolvableExpression, Value
155        from plain.postgres.fields import BooleanField
156        from plain.postgres.functions import Coalesce
157        from plain.postgres.sql import SINGLE, Query
158
159        query = Query(None)
160        for name, value in against.items():
161            if not isinstance(value, ResolvableExpression):
162                value = Value(value)
163            query.add_annotation(value, name, select=False)
164        query.add_annotation(Value(1), "_check")
165        # This will raise a FieldError if a field is missing in "against".
166        query.add_q(Q(Coalesce(self, True, output_field=BooleanField())))
167        compiler = query.get_compiler()
168        try:
169            return compiler.execute_sql(SINGLE) is not None
170        except psycopg.DatabaseError as e:
171            logger.warning(
172                "Got a database error calling check()",
173                extra={"expression": repr(self), "error": str(e)},
174            )
175            return True
176
177    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
178        path = f"{self.__class__.__module__}.{self.__class__.__name__}"
179        if path.startswith("plain.postgres.query_utils"):
180            path = path.replace("plain.postgres.query_utils", "plain.postgres")
181        args = tuple(self.children)
182        kwargs: dict[str, Any] = {}
183        if self.connector != self.default:
184            kwargs["_connector"] = self.connector
185        if self.negated:
186            kwargs["_negated"] = True
187        return path, args, kwargs
188
189
class class_or_instance_method:
    """
    Descriptor that dispatches to one of two callables depending on access.

    Looked up on a class, it returns ``class_method`` partially applied to
    that class; looked up on an instance, it returns ``instance_method``
    partially applied to that instance. RegisterLookupMixin uses it so that
    names such as ``get_lookups`` work at both levels.
    """

    def __init__(self, class_method: Any, instance_method: Any) -> None:
        self.class_method = class_method
        self.instance_method = instance_method

    def __get__(self, instance: Any, owner: type) -> Any:
        if instance is not None:
            return functools.partial(self.instance_method, instance)
        return functools.partial(self.class_method, owner)
204
205
class RegisterLookupMixin:
    """
    Registry of lookups and transforms, usable per class and per instance.

    Class-level registrations are stored in each class's own
    ``class_lookups`` dict and merged along the MRO, with more-derived
    classes taking precedence. Instance-level registrations are stored in
    ``instance_lookups`` and override class-level entries for that instance.
    """

    class_lookups: ClassVar[dict[str, type[Lookup | Transform]]]

    def _get_lookup(self, lookup_name: str) -> type[Lookup | Transform] | None:
        """Return whatever is registered under ``lookup_name``, or None."""
        return self.get_lookups().get(lookup_name, None)

    @functools.cache
    def get_class_lookups(cls: type[Self]) -> dict[str, type[Lookup | Transform]]:
        """
        Merge the ``class_lookups`` defined directly on each class in the MRO.

        Results are memoized by functools.cache; the registration helpers
        invalidate it through _clear_cached_class_lookups().
        """
        # Read each class's own __dict__ rather than the inherited attribute;
        # MRO order makes nearer classes win in merge_dicts().
        class_lookups = [
            parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls)
        ]
        return cls.merge_dicts(class_lookups)

    def get_instance_lookups(self) -> dict[str, type[Lookup | Transform]]:
        """
        Return the class lookups overlaid with this instance's registrations.

        NOTE: without instance registrations the memoized class-level dict
        itself is returned, so callers must not mutate the result.
        """
        class_lookups = self.get_class_lookups()
        if instance_lookups := getattr(self, "instance_lookups", None):
            return {**class_lookups, **instance_lookups}
        return class_lookups

    # Order matters: get_lookups must wrap the plain cached function (called
    # with the class as its only argument) before get_class_lookups is
    # rebound to a classmethod just below.
    get_lookups = class_or_instance_method(get_class_lookups, get_instance_lookups)
    get_class_lookups: ClassVar[classmethod[Any, ..., Any]] = classmethod(
        get_class_lookups
    )

    def get_lookup(self, lookup_name: str) -> type[Lookup] | None:
        """
        Return the Lookup registered as ``lookup_name``, or None.

        When nothing is registered here, defers to ``self.output_field`` if
        present. Returns None when the registered entry is not a Lookup
        subclass.
        """
        from plain.postgres.lookups import Lookup

        found = self._get_lookup(lookup_name)
        if found is None:
            # output_field is a Field which inherits from RegisterLookupMixin
            if output_field := getattr(self, "output_field", None):
                return output_field.get_lookup(lookup_name)
        if found is not None and not issubclass(found, Lookup):
            return None
        return found

    def get_transform(
        self, lookup_name: str
    ) -> type[Transform] | Callable[..., Any] | None:
        """
        Return the Transform registered as ``lookup_name``, or None.

        When nothing is registered here, defers to ``self.output_field`` if
        present. Returns None when the registered entry is not a Transform
        subclass.
        """
        from plain.postgres.lookups import Transform

        found = self._get_lookup(lookup_name)
        if found is None:
            # output_field is a Field which inherits from RegisterLookupMixin
            if output_field := getattr(self, "output_field", None):
                return output_field.get_transform(lookup_name)
        if found is not None and not issubclass(found, Transform):
            return None
        return found

    @staticmethod
    def merge_dicts(
        dicts: list[dict[str, type[Lookup | Transform]]],
    ) -> dict[str, type[Lookup | Transform]]:
        """
        Merge dicts in reverse to preference the order of the original list. e.g.,
        merge_dicts([a, b]) will preference the keys in 'a' over those in 'b'.
        """
        merged: dict[str, type[Lookup | Transform]] = {}
        for d in reversed(dicts):
            merged.update(d)
        return merged

    @classmethod
    def _clear_cached_class_lookups(cls: type[Self]) -> None:
        """Invalidate memoized get_class_lookups() for cls and its subclasses."""
        for subclass in subclasses(cls):
            if cached := getattr(subclass, "get_class_lookups", None):
                cached.cache_clear()

    def register_class_lookup(
        cls: type[Self],
        lookup: type[Lookup | Transform],
        lookup_name: str | None = None,
    ) -> type[Lookup | Transform]:
        """
        Register ``lookup`` on ``cls`` (and its subclasses) and return it.

        ``lookup_name`` defaults to ``lookup.lookup_name``.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        assert lookup_name is not None, "lookup_name must be set on the lookup class"
        # Give cls its own dict so registrations don't leak into a parent's.
        if "class_lookups" not in cls.__dict__:
            cls.class_lookups = {}
        cls.class_lookups[lookup_name] = lookup
        cls._clear_cached_class_lookups()
        return lookup

    def register_instance_lookup(
        self, lookup: type[Lookup | Transform], lookup_name: str | None = None
    ) -> type[Lookup | Transform]:
        """
        Register ``lookup`` on this instance only and return it.

        ``lookup_name`` defaults to ``lookup.lookup_name``.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        # Same guard as register_class_lookup(): never register under None,
        # where the entry could never be found by name.
        assert lookup_name is not None, "lookup_name must be set on the lookup class"
        if "instance_lookups" not in self.__dict__:
            self.instance_lookups = {}
        self.instance_lookups[lookup_name] = lookup
        return lookup

    # Registers on the class when accessed via the class, or on the instance
    # when accessed via an instance.
    register_lookup = class_or_instance_method(
        register_class_lookup, register_instance_lookup
    )
    register_class_lookup: ClassVar[classmethod[Any, ..., Any]] = classmethod(
        register_class_lookup
    )

    def _unregister_class_lookup(
        cls: type[Self],
        lookup: type[Lookup | Transform],
        lookup_name: str | None = None,
    ) -> None:
        """
        Remove given lookup from cls lookups. For use in tests only as it's
        not thread-safe.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        assert lookup_name is not None, "lookup_name must be set on the lookup class"
        del cls.class_lookups[lookup_name]
        cls._clear_cached_class_lookups()

    def _unregister_instance_lookup(
        self, lookup: type[Lookup | Transform], lookup_name: str | None = None
    ) -> None:
        """
        Remove given lookup from instance lookups. For use in tests only as
        it's not thread-safe.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        assert lookup_name is not None, "lookup_name must be set on the lookup class"
        del self.instance_lookups[lookup_name]

    _unregister_lookup = class_or_instance_method(
        _unregister_class_lookup, _unregister_instance_lookup
    )
    _unregister_class_lookup: ClassVar[classmethod[Any, ..., Any]] = classmethod(
        _unregister_class_lookup
    )
338
339
def select_related_descend(
    field: Any,
    restricted: bool | None,
    requested: dict[str, Any] | None,
    select_mask: Any,
    reverse: bool = False,
) -> bool:
    """
    Decide whether select_related() should follow ``field`` one level deeper.

    Shared by the query construction code (compiler.get_related_selections())
    and the model instance creation code (compiler.klass_info).

    Arguments:
     * field - the field being considered.
     * restricted - truthy when select_related() was given an explicit set of
       relations (the ``requested`` dictionary).
     * requested - the select_related() dictionary; required when restricted.
     * select_mask - the dictionary of selected fields.
     * reverse - True when checking a reverse relation.

    Raises FieldError when a requested relation is also deferred by the
    select mask.
    """
    from plain.postgres.fields.related import RelatedField

    # Only relation fields can be descended into.
    if not isinstance(field, RelatedField):
        return False

    if not restricted:
        # Without an explicit relation list, nullable relations are skipped.
        return not field.allow_null

    assert requested is not None, "requested must be provided when restricted=True"
    # Reverse relations are requested by their related query name, forward
    # relations by the field name.
    requested_key = field.related_query_name() if reverse else field.name
    if requested_key not in requested:
        return False

    if select_mask and field.name in requested and field not in select_mask:
        raise FieldError(
            f"Field {field.model.model_options.object_name}.{field.name} cannot be both "
            "deferred and traversed using select_related at the same time."
        )
    return True
384
385
386def refs_expression(
387    lookup_parts: list[str], annotations: dict[str, Any]
388) -> tuple[str | None, tuple[str, ...]]:
389    """
390    Check if the lookup_parts contains references to the given annotations set.
391    Because the LOOKUP_SEP is contained in the default annotation names, check
392    each prefix of the lookup_parts for a match.
393    """
394    for n in range(1, len(lookup_parts) + 1):
395        level_n_lookup = LOOKUP_SEP.join(lookup_parts[0:n])
396        if annotations.get(level_n_lookup):
397            return level_n_lookup, tuple(lookup_parts[n:])
398    return None, ()
399
400
def check_rel_lookup_compatibility(
    model: type[Model], target_meta: Meta, field: Field | ForeignObjectRel
) -> bool:
    """
    Return whether ``model`` may be used with a lookup targeting ``target_meta``.

    ``model`` is compatible when it is the target's model, or when ``field``
    is a primary key and ``model`` is the model that field belongs to.
    """
    matches_target = model == target_meta.model
    if matches_target:
        return matches_target

    # Primary-key fields get a second chance: a queryset like
    # `Model.query.filter(id__in=Model.query.all())` resolves `id__in` through
    # the PK field, whose target meta is the remote model. Allow the match
    # against the field's own model so the subquery (later reduced to
    # `.values("id")`) is accepted.
    is_primary_key = getattr(field, "primary_key", False)
    if not is_primary_key:
        return is_primary_key
    return model == field.model._model_meta.model
420
421
class FilteredRelation:
    """Add extra filtering to the ON clause of the SQL join for a relation."""

    def __init__(self, relation_name: str, *, condition: Q = Q()) -> None:
        # NOTE(review): the default Q() is created once and shared by every
        # instance built without an explicit condition; it must not be
        # mutated in place.
        if not relation_name:
            raise ValueError("relation_name cannot be empty.")
        self.relation_name = relation_name
        self.alias: str | None = None
        if not isinstance(condition, Q):
            raise ValueError("condition argument must be a Q() instance.")
        self.condition = condition
        self.path: list[str] = []

    def __eq__(self, other: object) -> bool:
        if isinstance(other, self.__class__):
            return (
                self.relation_name == other.relation_name
                and self.alias == other.alias
                and self.condition == other.condition
            )
        return NotImplemented

    def clone(self) -> FilteredRelation:
        """Return a copy sharing the condition but with its own path list."""
        duplicate = FilteredRelation(self.relation_name, condition=self.condition)
        duplicate.alias = self.alias
        duplicate.path = self.path[:]
        return duplicate

    def as_sql(self, compiler: SQLCompiler, connection: DatabaseConnection) -> Any:
        """Compile the condition for use in Join.filtered_relation."""
        query = compiler.query
        reusable_aliases = set(self.path)
        where = query.build_filtered_relation_q(self.condition, reuse=reusable_aliases)
        return compiler.compile(where)