Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from typing import TYPE_CHECKING, Any
  4
  5from plain.models.lookups import (
  6    Exact,
  7    GreaterThan,
  8    GreaterThanOrEqual,
  9    In,
 10    IsNull,
 11    LessThan,
 12    LessThanOrEqual,
 13)
 14
 15if TYPE_CHECKING:
 16    from plain.models.backends.base.base import BaseDatabaseWrapper
 17    from plain.models.sql.compiler import SQLCompiler
 18
 19
 20class MultiColSource:
 21    contains_aggregate = False
 22    contains_over_clause = False
 23
 24    def __init__(self, alias: str, targets: Any, sources: Any, field: Any) -> None:
 25        self.targets, self.sources, self.field, self.alias = (
 26            targets,
 27            sources,
 28            field,
 29            alias,
 30        )
 31        self.output_field = self.field
 32
 33    def __repr__(self) -> str:
 34        return f"{self.__class__.__name__}({self.alias}, {self.field})"
 35
 36    def relabeled_clone(self, relabels: dict[str, str]) -> MultiColSource:
 37        return self.__class__(
 38            relabels.get(self.alias, self.alias), self.targets, self.sources, self.field
 39        )
 40
 41    def get_lookup(self, lookup: str) -> Any:
 42        return self.output_field.get_lookup(lookup)
 43
 44    def resolve_expression(self, *args: Any, **kwargs: Any) -> MultiColSource:
 45        return self
 46
 47
 48def get_normalized_value(value: Any, lhs: Any) -> tuple[Any, ...]:
 49    from plain.models import Model
 50
 51    if isinstance(value, Model):
 52        if value.id is None:
 53            raise ValueError("Model instances passed to related filters must be saved.")
 54        value_list = []
 55        sources = lhs.output_field.path_infos[-1].target_fields
 56        for source in sources:
 57            while not isinstance(value, source.model) and source.remote_field:
 58                source = source.remote_field.model._model_meta.get_field(
 59                    source.remote_field.field_name
 60                )
 61            try:
 62                value_list.append(getattr(value, source.attname))
 63            except AttributeError:
 64                # A case like Restaurant.query.filter(place=restaurant_instance),
 65                # where place is a OneToOneField and the primary key of Restaurant.
 66                return (value.id,)
 67        return tuple(value_list)
 68    if not isinstance(value, tuple):
 69        return (value,)
 70    return value
 71
 72
 73class RelatedIn(In):
 74    def get_prep_lookup(self) -> list[Any]:  # type: ignore[misc]
 75        if not isinstance(self.lhs, MultiColSource):
 76            if self.rhs_is_direct_value():
 77                # If we get here, we are dealing with single-column relations.
 78                self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
 79                # We need to run the related field's get_prep_value(). Consider
 80                # case ForeignKey to IntegerField given value 'abc'. The
 81                # ForeignKey itself doesn't have validation for non-integers,
 82                # so we must run validation using the target field.
 83                if hasattr(self.lhs.output_field, "path_infos"):
 84                    # Run the target field's get_prep_value. We can safely
 85                    # assume there is only one as we don't get to the direct
 86                    # value branch otherwise.
 87                    target_field = self.lhs.output_field.path_infos[-1].target_fields[
 88                        -1
 89                    ]
 90                    self.rhs = [target_field.get_prep_value(v) for v in self.rhs]
 91            elif not getattr(self.rhs, "has_select_fields", True) and not getattr(
 92                self.lhs.field.target_field, "primary_key", False
 93            ):
 94                if (
 95                    getattr(self.lhs.output_field, "primary_key", False)
 96                    and self.lhs.output_field.model == self.rhs.model
 97                ):
 98                    # A case like
 99                    # Restaurant.query.filter(place__in=restaurant_qs), where
100                    # place is a OneToOneField and the primary key of
101                    # Restaurant.
102                    target_field = self.lhs.field.name
103                else:
104                    target_field = self.lhs.field.target_field.name
105                self.rhs.set_values([target_field])
106        return super().get_prep_lookup()
107
108    def as_sql(
109        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
110    ) -> tuple[str, list[Any]]:  # type: ignore[misc]
111        if isinstance(self.lhs, MultiColSource):
112            # For multicolumn lookups we need to build a multicolumn where clause.
113            # This clause is either a SubqueryConstraint (for values that need
114            # to be compiled to SQL) or an OR-combined list of
115            # (col1 = val1 AND col2 = val2 AND ...) clauses.
116            from plain.models.sql.where import (
117                AND,
118                OR,
119                SubqueryConstraint,
120                WhereNode,
121            )
122
123            root_constraint = WhereNode(connector=OR)
124            if self.rhs_is_direct_value():
125                values = [get_normalized_value(value, self.lhs) for value in self.rhs]
126                for value in values:
127                    value_constraint = WhereNode()
128                    for source, target, val in zip(
129                        self.lhs.sources, self.lhs.targets, value
130                    ):
131                        lookup_class = target.get_lookup("exact")
132                        lookup = lookup_class(
133                            target.get_col(self.lhs.alias, source), val
134                        )
135                        value_constraint.add(lookup, AND)
136                    root_constraint.add(value_constraint, OR)
137            else:
138                root_constraint.add(
139                    SubqueryConstraint(
140                        self.lhs.alias,
141                        [target.column for target in self.lhs.targets],
142                        [source.name for source in self.lhs.sources],
143                        self.rhs,
144                    ),
145                    AND,
146                )
147            return root_constraint.as_sql(compiler, connection)
148        return super().as_sql(compiler, connection)
149
150
class RelatedLookupMixin:
    """Mixin that adapts a comparison lookup to relation fields.

    Placed ahead of a concrete lookup class (``Exact``, ``LessThan``, ...) in
    the bases. It reduces model-instance right-hand sides to raw column
    values. On a MultiColSource left-hand side it emits one
    ``self.lookup_name`` comparison per column and ANDs them together.
    """

    def get_prep_lookup(self) -> Any:  # type: ignore[misc]
        """Normalize a plain single-column right-hand side, then defer upward.

        Skipped for MultiColSource left-hand sides (handled in ``as_sql``) and
        for right-hand sides that have ``resolve_expression`` (expressions).
        """
        if not isinstance(self.lhs, MultiColSource) and not hasattr(
            self.rhs, "resolve_expression"
        ):
            # Single-column relation: keep the first normalized column value.
            self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
            # The relation field does not validate on its own (think of a
            # ForeignKey to an IntegerField given 'abc'), so run the target
            # field's get_prep_value(). There is exactly one target here,
            # since multi-column sources never reach this branch.
            if self.prepare_rhs and hasattr(self.lhs.output_field, "path_infos"):
                target_field = self.lhs.output_field.path_infos[-1].target_fields[-1]
                self.rhs = target_field.get_prep_value(self.rhs)

        return super().get_prep_lookup()  # type: ignore[misc]

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:  # type: ignore[misc]
        """Compile the lookup, expanding MultiColSource left-hand sides.

        For a multi-column relation the right-hand side is normalized to one
        value per column. Each (target, source, value) triple becomes a
        ``self.lookup_name`` comparison, and the comparisons are ANDed.

        Raises:
            ValueError: if the left-hand side spans several columns but the
                right-hand side is not a direct value (e.g. an expression),
                which cannot be split per column.
        """
        if isinstance(self.lhs, MultiColSource):
            # Explicit check instead of ``assert``: asserts are stripped under
            # ``python -O``. An expression would then be wrapped into a 1-tuple
            # by get_normalized_value(), and zip() would silently compare only
            # the first column.
            if not self.rhs_is_direct_value():
                raise ValueError(
                    f"The {self.lookup_name!r} lookup on a multi-column relation "
                    "requires a direct value, not an expression."
                )
            self.rhs = get_normalized_value(self.rhs, self.lhs)
            from plain.models.sql.where import AND, WhereNode

            root_constraint = WhereNode()
            for target, source, val in zip(
                self.lhs.targets, self.lhs.sources, self.rhs
            ):
                lookup_class = target.get_lookup(self.lookup_name)
                root_constraint.add(
                    lookup_class(target.get_col(self.lhs.alias, source), val), AND
                )
            return root_constraint.as_sql(compiler, connection)
        return super().as_sql(compiler, connection)
188
189
class RelatedExact(RelatedLookupMixin, Exact):
    """``Exact`` lookup with RelatedLookupMixin's relation-aware handling."""
192
193
class RelatedLessThan(RelatedLookupMixin, LessThan):
    """``LessThan`` lookup with RelatedLookupMixin's relation-aware handling."""
196
197
class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
    """``GreaterThan`` lookup with RelatedLookupMixin's relation-aware handling."""
200
201
class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
    """``GreaterThanOrEqual`` lookup with RelatedLookupMixin's relation handling."""
204
205
class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
    """``LessThanOrEqual`` lookup with RelatedLookupMixin's relation handling."""
208
209
class RelatedIsNull(RelatedLookupMixin, IsNull):
    """``IsNull`` lookup with RelatedLookupMixin's relation-aware handling."""