1"""
  2Various data structures used in query construction.
  3
  4Factored out from plain.postgres.query to avoid making the main module very
  5large and/or so that they can be used by other modules without getting into
  6circular import difficulties.
  7"""
  8
  9from __future__ import annotations
 10
 11import functools
 12import inspect
 13import logging
 14from collections.abc import Callable, Generator
 15from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, Self
 16
 17from plain.postgres.constants import LOOKUP_SEP
 18from plain.postgres.db import DatabaseError
 19from plain.postgres.exceptions import FieldError
 20from plain.utils import tree
 21
 22if TYPE_CHECKING:
 23    from plain.postgres.base import Model
 24    from plain.postgres.connection import DatabaseConnection
 25    from plain.postgres.fields import Field
 26    from plain.postgres.fields.related import ForeignKeyField
 27    from plain.postgres.fields.reverse_related import ForeignObjectRel
 28    from plain.postgres.lookups import Lookup, Transform
 29    from plain.postgres.meta import Meta
 30    from plain.postgres.sql.compiler import SQLCompiler
 31    from plain.postgres.sql.where import WhereNode
 32
 33logger = logging.getLogger("plain.postgres")
 34
 35
class PathInfo(NamedTuple):
    """Information about a relation path when converting lookups (fk__somecol).

    Describes the relation in Model terms (Meta and Fields for both
    sides of the relation). The join_field is the field backing the relation.
    Each instance describes one step of such a path.
    """

    # Meta of the model this path step starts from.
    from_meta: Meta
    # Meta of the model this path step leads to.
    to_meta: Meta
    # Field(s) the join resolves to — presumably on the to_meta side; confirm with callers.
    target_fields: tuple[Field, ...]
    # The relation field (forward FK) or reverse relation object backing the join.
    join_field: ForeignKeyField | ForeignObjectRel
    # NOTE(review): presumably True when this step can match multiple rows
    # (many-to-many / reverse FK) — confirm against the code building PathInfo.
    m2m: bool
    # NOTE(review): presumably True for a forward relation, False for a reverse one — confirm.
    direct: bool
    # FilteredRelation attached to this step, or None when the join is unfiltered.
    filtered_relation: FilteredRelation | None
 50
 51
 52def subclasses(cls: type) -> Generator[type]:
 53    yield cls
 54    for subclass in cls.__subclasses__():
 55        yield from subclasses(subclass)
 56
 57
 58class Q(tree.Node):
 59    """
 60    Encapsulate filters as objects that can then be combined logically (using
 61    `&` and `|`).
 62    """
 63
 64    # Connection types
 65    AND = "AND"
 66    OR = "OR"
 67    XOR = "XOR"
 68    default = AND
 69    conditional = True
 70
 71    def __init__(
 72        self,
 73        *args: Any,
 74        _connector: str | None = None,
 75        _negated: bool = False,
 76        **kwargs: Any,
 77    ) -> None:
 78        super().__init__(
 79            children=[*args, *sorted(kwargs.items())],
 80            connector=_connector,
 81            negated=_negated,
 82        )
 83
 84    def _combine(self, other: Any, conn: str) -> Q:
 85        if getattr(other, "conditional", False) is False:
 86            raise TypeError(other)
 87        if not self:
 88            return other.copy()
 89        if not other and isinstance(other, Q):
 90            return self.copy()
 91
 92        obj = self.create(connector=conn)
 93        obj.add(self, conn)
 94        obj.add(other, conn)
 95        return obj
 96
 97    def __or__(self, other: Any) -> Q:
 98        return self._combine(other, self.OR)
 99
100    def __and__(self, other: Any) -> Q:
101        return self._combine(other, self.AND)
102
103    def __xor__(self, other: Any) -> Q:
104        return self._combine(other, self.XOR)
105
106    def __invert__(self) -> Q:
107        obj = self.copy()
108        obj.negate()
109        return obj
110
111    def resolve_expression(
112        self,
113        query: Any = None,
114        allow_joins: bool = True,
115        reuse: Any = None,
116        summarize: bool = False,
117        for_save: bool = False,
118    ) -> WhereNode:
119        # We must promote any new joins to left outer joins so that when Q is
120        # used as an expression, rows aren't filtered due to joins.
121        clause, joins = query._add_q(
122            self,
123            reuse,
124            allow_joins=allow_joins,
125            split_subq=False,
126            check_filterable=False,
127            summarize=summarize,
128        )
129        query.promote_joins(joins)
130        return clause
131
132    def flatten(self) -> Generator[Any]:
133        """
134        Recursively yield this Q object and all subexpressions, in depth-first
135        order.
136        """
137        yield self
138        for child in self.children:
139            if isinstance(child, tuple):
140                # Use the lookup.
141                child = child[1]
142            if hasattr(child, "flatten"):
143                yield from child.flatten()
144            else:
145                yield child
146
147    def check(self, against: dict[str, Any]) -> bool:
148        """
149        Do a database query to check if the expressions of the Q instance
150        matches against the expressions.
151        """
152        # Avoid circular imports.
153        from plain.postgres.expressions import ResolvableExpression, Value
154        from plain.postgres.fields import BooleanField
155        from plain.postgres.functions import Coalesce
156        from plain.postgres.sql import SINGLE, Query
157
158        query = Query(None)
159        for name, value in against.items():
160            if not isinstance(value, ResolvableExpression):
161                value = Value(value)
162            query.add_annotation(value, name, select=False)
163        query.add_annotation(Value(1), "_check")
164        # This will raise a FieldError if a field is missing in "against".
165        query.add_q(Q(Coalesce(self, True, output_field=BooleanField())))
166        compiler = query.get_compiler()
167        try:
168            return compiler.execute_sql(SINGLE) is not None
169        except DatabaseError as e:
170            logger.warning("Got a database error calling check() on %r: %s", self, e)
171            return True
172
173    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
174        path = f"{self.__class__.__module__}.{self.__class__.__name__}"
175        if path.startswith("plain.postgres.query_utils"):
176            path = path.replace("plain.postgres.query_utils", "plain.postgres")
177        args = tuple(self.children)
178        kwargs: dict[str, Any] = {}
179        if self.connector != self.default:
180            kwargs["_connector"] = self.connector
181        if self.negated:
182            kwargs["_negated"] = True
183        return path, args, kwargs
184
185
class class_or_instance_method:
    """
    Descriptor that dispatches on how it is accessed.

    On the owner class, it returns ``class_method`` with the class bound as
    the first argument. On an instance, it returns ``instance_method`` with
    the instance bound. RegisterLookupMixin uses it to expose per-class and
    per-instance variants of the same method name.
    """

    def __init__(self, class_method: Any, instance_method: Any) -> None:
        self.class_method = class_method
        self.instance_method = instance_method

    def __get__(self, instance: Any, owner: type) -> Any:
        if instance is not None:
            return functools.partial(self.instance_method, instance)
        return functools.partial(self.class_method, owner)
200
201
class RegisterLookupMixin:
    """
    Mixin for objects that resolve lookups and transforms by name.

    Lookups can be registered on a class, where they are visible to that
    class and its subclasses, or on a single instance. ``get_lookups``,
    ``register_lookup`` and ``_unregister_lookup`` are
    ``class_or_instance_method`` descriptors. Called on the class, they act
    on class-level lookups; called on an instance, they act on that instance.
    """

    # Per-class registry written by register_class_lookup(). Parent classes'
    # registries are merged in by get_class_lookups().
    class_lookups: ClassVar[dict[str, type[Lookup | Transform]]]

    def _get_lookup(self, lookup_name: str) -> type[Lookup | Transform] | None:
        """Return whatever is registered as *lookup_name*, or None."""
        return self.get_lookups().get(lookup_name)

    @functools.cache
    def get_class_lookups(cls: type[Self]) -> dict[str, type[Lookup | Transform]]:
        """
        Return the class-level lookups visible from *cls*.

        Merges the ``class_lookups`` dict defined directly on each class in
        the MRO. Classes nearer to *cls* take precedence. The result is
        memoized and shared between callers, so treat it as read-only.
        _clear_cached_class_lookups() resets the memo.
        """
        class_lookups = [
            parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls)
        ]
        return cls.merge_dicts(class_lookups)

    def get_instance_lookups(self) -> dict[str, type[Lookup | Transform]]:
        """
        Return the class-level lookups overlaid with this instance's own
        registrations. Instance entries win on name clashes.

        With no instance registrations, the shared cached class dict itself
        is returned.
        """
        class_lookups = self.get_class_lookups()
        if instance_lookups := getattr(self, "instance_lookups", None):
            return {**class_lookups, **instance_lookups}
        return class_lookups

    # Built from the plain (cached) function *before* it is rebound as a
    # classmethod below. The descriptor supplies the owner class itself.
    get_lookups = class_or_instance_method(get_class_lookups, get_instance_lookups)
    get_class_lookups: ClassVar[classmethod[Any, ..., Any]] = classmethod(
        get_class_lookups
    )

    def get_lookup(self, lookup_name: str) -> type[Lookup] | None:
        """
        Return the Lookup subclass registered as *lookup_name*.

        If nothing is registered under that name and this object has a truthy
        ``output_field``, the search is delegated to it. Returns None when the
        name is unknown or is registered to something that is not a Lookup.
        """
        from plain.postgres.lookups import Lookup

        found = self._get_lookup(lookup_name)
        if found is None:
            # output_field is a Field which inherits from RegisterLookupMixin
            if output_field := getattr(self, "output_field", None):
                return output_field.get_lookup(lookup_name)
        if found is not None and not issubclass(found, Lookup):
            return None
        return found

    def get_transform(
        self, lookup_name: str
    ) -> type[Transform] | Callable[..., Any] | None:
        """
        Return the Transform subclass registered as *lookup_name*.

        Delegates to a truthy ``output_field`` when the name is not
        registered here. Returns None when the name is unknown or is
        registered to something that is not a Transform.
        """
        from plain.postgres.lookups import Transform

        found = self._get_lookup(lookup_name)
        if found is None:
            # output_field is a Field which inherits from RegisterLookupMixin
            if output_field := getattr(self, "output_field", None):
                return output_field.get_transform(lookup_name)
        if found is not None and not issubclass(found, Transform):
            return None
        return found

    @staticmethod
    def merge_dicts(
        dicts: list[dict[str, type[Lookup | Transform]]],
    ) -> dict[str, type[Lookup | Transform]]:
        """
        Merge dicts in reverse to preference the order of the original list. e.g.,
        merge_dicts([a, b]) will preference the keys in 'a' over those in 'b'.
        """
        merged: dict[str, type[Lookup | Transform]] = {}
        for d in reversed(dicts):
            merged.update(d)
        return merged

    @classmethod
    def _clear_cached_class_lookups(cls: type[Self]) -> None:
        """Reset the memoized get_class_lookups() results for cls and its subclasses."""
        for subclass in subclasses(cls):
            if cached := getattr(subclass, "get_class_lookups", None):
                cached.cache_clear()

    def register_class_lookup(
        cls: type[Self],
        lookup: type[Lookup | Transform],
        lookup_name: str | None = None,
    ) -> type[Lookup | Transform]:
        """
        Register *lookup* on *cls* under *lookup_name*, which defaults to
        ``lookup.lookup_name``.

        *cls* gets its own ``class_lookups`` dict on first registration, so a
        parent's registry is never mutated. Returns *lookup* unchanged, which
        allows decorator-style use.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        assert lookup_name is not None, "lookup_name must be set on the lookup class"
        if "class_lookups" not in cls.__dict__:
            cls.class_lookups = {}
        cls.class_lookups[lookup_name] = lookup
        cls._clear_cached_class_lookups()
        return lookup

    def register_instance_lookup(
        self, lookup: type[Lookup | Transform], lookup_name: str | None = None
    ) -> type[Lookup | Transform]:
        """
        Register *lookup* on this instance only, under *lookup_name*, which
        defaults to ``lookup.lookup_name``. Returns *lookup* unchanged.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        # Same guard as register_class_lookup. Without it, a lookup class that
        # has no lookup_name would be silently stored under the key None,
        # where no name-based lookup can ever reach it.
        assert lookup_name is not None, "lookup_name must be set on the lookup class"
        if "instance_lookups" not in self.__dict__:
            self.instance_lookups = {}
        self.instance_lookups[lookup_name] = lookup
        return lookup

    register_lookup = class_or_instance_method(
        register_class_lookup, register_instance_lookup
    )
    register_class_lookup: ClassVar[classmethod[Any, ..., Any]] = classmethod(
        register_class_lookup
    )

    def _unregister_class_lookup(
        cls: type[Self],
        lookup: type[Lookup | Transform],
        lookup_name: str | None = None,
    ) -> None:
        """
        Remove given lookup from cls lookups. For use in tests only as it's
        not thread-safe.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        assert lookup_name is not None, "lookup_name must be set on the lookup class"
        del cls.class_lookups[lookup_name]
        cls._clear_cached_class_lookups()

    def _unregister_instance_lookup(
        self, lookup: type[Lookup | Transform], lookup_name: str | None = None
    ) -> None:
        """
        Remove given lookup from instance lookups. For use in tests only as
        it's not thread-safe.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        assert lookup_name is not None, "lookup_name must be set on the lookup class"
        del self.instance_lookups[lookup_name]

    _unregister_lookup = class_or_instance_method(
        _unregister_class_lookup, _unregister_instance_lookup
    )
    _unregister_class_lookup: ClassVar[classmethod[Any, ..., Any]] = classmethod(
        _unregister_class_lookup
    )
334
335
def select_related_descend(
    field: Any,
    restricted: bool | None,
    requested: dict[str, Any] | None,
    select_mask: Any,
    reverse: bool = False,
) -> bool:
    """
    Decide whether select_related() should follow *field* one level deeper.

    Shared by query construction (compiler.get_related_selections()) and
    model instance creation (compiler.klass_info).

    Arguments:
     * field - the field being considered; anything that is not a
       RelatedField is never followed.
     * restricted - truthy when select_related() was given an explicit list
       of relations (the ``requested`` mapping).
     * requested - the select_related() dictionary; must be provided when
       ``restricted`` is truthy.
     * select_mask - the dictionary of selected fields.
     * reverse - True when checking a reverse select-related relation.

    Raises FieldError when a requested relation is also deferred by
    ``select_mask``.
    """
    from plain.postgres.fields.related import RelatedField

    if not isinstance(field, RelatedField):
        return False

    if not restricted:
        # No explicit relation list: only non-nullable relations are followed.
        return not field.allow_null

    assert requested is not None, "requested must be provided when restricted=True"
    # Reverse relations are requested by their related query name, forward
    # relations by the field name.
    requested_key = field.related_query_name() if reverse else field.name
    if requested_key not in requested:
        return False

    if select_mask and field.name in requested and field not in select_mask:
        raise FieldError(
            f"Field {field.model.model_options.object_name}.{field.name} cannot be both "
            "deferred and traversed using select_related at the same time."
        )
    return True
380
381
def refs_expression(
    lookup_parts: list[str], annotations: dict[str, Any]
) -> tuple[str | None, tuple[str, ...]]:
    """
    Find the shortest prefix of *lookup_parts* that names an annotation.

    Annotation names may themselves contain LOOKUP_SEP, so every prefix is
    tried, from shortest to longest. Only annotations whose value is truthy
    count. Returns ``(matched_name, remaining_parts)``, or ``(None, ())``
    when nothing matches.
    """
    candidates = (
        (LOOKUP_SEP.join(lookup_parts[:end]), end)
        for end in range(1, len(lookup_parts) + 1)
    )
    for name, end in candidates:
        if annotations.get(name):
            return name, tuple(lookup_parts[end:])
    return None, ()
395
396
def check_rel_lookup_compatibility(
    model: type[Model], target_meta: Meta, field: Field | ForeignObjectRel
) -> bool:
    """
    Check that *model* is compatible with *target_meta* for a relational lookup.

    Compatibility holds if either:
      1) *model* equals ``target_meta.model``; or
      2) *field* is a primary key and *model* equals the model *field*
         belongs to (``field.model``).

    Models are compared with ``==`` only; no parent/child relationship
    between models is considered. Always returns a bool.
    """

    def matches(meta: Meta) -> bool:
        return model == meta.model

    if matches(target_meta):
        return True
    # For a primary-key field, a queryset of the field's own model is also
    # acceptable, because ``__in=qs`` is later turned into
    # ``__in=qs.values('id')``. This matters when the target meta derived from
    # the field points at a different model (e.g. a primary key that is itself
    # a relation). It applies only to primary keys. bool() keeps the return
    # type honest if ``primary_key`` holds a non-bool falsy value.
    return bool(getattr(field, "primary_key", False)) and matches(
        field.model._model_meta
    )
422
423
class FilteredRelation:
    """
    A relation name paired with a Q() condition that is compiled into the ON
    clause of the relation's SQL join.
    """

    def __init__(self, relation_name: str, *, condition: Q = Q()) -> None:
        if not relation_name:
            raise ValueError("relation_name cannot be empty.")
        self.relation_name = relation_name
        # None until an alias is assigned (presumably by the query — confirm).
        self.alias: str | None = None
        if not isinstance(condition, Q):
            raise ValueError("condition argument must be a Q() instance.")
        self.condition = condition
        # Aliases used as the join-reuse set when the condition is compiled.
        self.path: list[str] = []

    def __eq__(self, other: object) -> bool:
        # Equality considers relation_name, alias and condition, but not path.
        if isinstance(other, self.__class__):
            return (
                self.relation_name == other.relation_name
                and self.alias == other.alias
                and self.condition == other.condition
            )
        return NotImplemented

    def clone(self) -> FilteredRelation:
        """Return a copy that shares the condition but has its own path list."""
        duplicate = FilteredRelation(self.relation_name, condition=self.condition)
        duplicate.alias = self.alias
        duplicate.path = list(self.path)
        return duplicate

    def as_sql(self, compiler: SQLCompiler, connection: DatabaseConnection) -> Any:
        """Compile the condition for use in the join (Join.filtered_relation)."""
        filtered_where = compiler.query.build_filtered_relation_q(
            self.condition, reuse=set(self.path)
        )
        return compiler.compile(filtered_where)