Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from typing import TYPE_CHECKING, Any
  4
  5from plain.models.lookups import (
  6    Exact,
  7    GreaterThan,
  8    GreaterThanOrEqual,
  9    In,
 10    IsNull,
 11    LessThan,
 12    LessThanOrEqual,
 13    Lookup,
 14)
 15
 16if TYPE_CHECKING:
 17    from plain.models.backends.base.base import BaseDatabaseWrapper
 18    from plain.models.sql.compiler import SQLCompiler
 19
 20
 21class MultiColSource:
 22    contains_aggregate = False
 23    contains_over_clause = False
 24
 25    def __init__(self, alias: str, targets: Any, sources: Any, field: Any) -> None:
 26        self.targets, self.sources, self.field, self.alias = (
 27            targets,
 28            sources,
 29            field,
 30            alias,
 31        )
 32        self.output_field = self.field
 33
 34    def __repr__(self) -> str:
 35        return f"{self.__class__.__name__}({self.alias}, {self.field})"
 36
 37    def relabeled_clone(self, relabels: dict[str, str]) -> MultiColSource:
 38        return self.__class__(
 39            relabels.get(self.alias, self.alias), self.targets, self.sources, self.field
 40        )
 41
 42    def get_lookup(self, lookup: str) -> type[Lookup] | None:
 43        return self.output_field.get_lookup(lookup)
 44
 45    def resolve_expression(self, *args: Any, **kwargs: Any) -> MultiColSource:
 46        return self
 47
 48
 49def get_normalized_value(value: Any, lhs: Any) -> tuple[Any, ...]:
 50    from plain.models import Model
 51    from plain.models.fields.related import RelatedField
 52
 53    if isinstance(value, Model):
 54        if value.id is None:
 55            raise ValueError("Model instances passed to related filters must be saved.")
 56        value_list = []
 57        sources = lhs.output_field.path_infos[-1].target_fields
 58        for source in sources:
 59            while not isinstance(value, source.model) and isinstance(
 60                source, RelatedField
 61            ):
 62                source = source.remote_field.model._model_meta.get_field(
 63                    source.remote_field.field_name
 64                )
 65            try:
 66                value_list.append(getattr(value, source.attname))
 67            except AttributeError:
 68                # A case like Restaurant.query.filter(place=restaurant_instance),
 69                # where place is a OneToOneField and the primary key of Restaurant.
 70                return (value.id,)
 71        return tuple(value_list)
 72    if not isinstance(value, tuple):
 73        return (value,)
 74    return value
 75
 76
 77class RelatedIn(In):
 78    def get_prep_lookup(self) -> list[Any]:
 79        if not isinstance(self.lhs, MultiColSource):
 80            if self.rhs_is_direct_value():
 81                # If we get here, we are dealing with single-column relations.
 82                self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
 83                # We need to run the related field's get_prep_value(). Consider
 84                # case ForeignKeyField to IntegerField given value 'abc'. The
 85                # ForeignKeyField itself doesn't have validation for non-integers,
 86                # so we must run validation using the target field.
 87                if hasattr(self.lhs.output_field, "path_infos"):
 88                    # Run the target field's get_prep_value. We can safely
 89                    # assume there is only one as we don't get to the direct
 90                    # value branch otherwise.
 91                    target_field = self.lhs.output_field.path_infos[-1].target_fields[
 92                        -1
 93                    ]
 94                    self.rhs = [target_field.get_prep_value(v) for v in self.rhs]
 95            elif not getattr(self.rhs, "has_select_fields", True) and not getattr(
 96                self.lhs.field.target_field, "primary_key", False
 97            ):
 98                if (
 99                    getattr(self.lhs.output_field, "primary_key", False)
100                    and self.lhs.output_field.model == self.rhs.model
101                ):
102                    # A case like
103                    # Restaurant.query.filter(place__in=restaurant_qs), where
104                    # place is a OneToOneField and the primary key of
105                    # Restaurant.
106                    target_field = self.lhs.field.name
107                else:
108                    target_field = self.lhs.field.target_field.name
109                self.rhs.set_values([target_field])
110        return super().get_prep_lookup()
111
112    def as_sql(
113        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
114    ) -> tuple[str, list[Any]]:
115        if isinstance(self.lhs, MultiColSource):
116            # For multicolumn lookups we need to build a multicolumn where clause.
117            # This clause is either a SubqueryConstraint (for values that need
118            # to be compiled to SQL) or an OR-combined list of
119            # (col1 = val1 AND col2 = val2 AND ...) clauses.
120            from plain.models.sql.where import (
121                AND,
122                OR,
123                SubqueryConstraint,
124                WhereNode,
125            )
126
127            root_constraint = WhereNode(connector=OR)
128            if self.rhs_is_direct_value():
129                values = [get_normalized_value(value, self.lhs) for value in self.rhs]
130                for value in values:
131                    value_constraint = WhereNode()
132                    for source, target, val in zip(
133                        self.lhs.sources, self.lhs.targets, value
134                    ):
135                        lookup_class = target.get_lookup("exact")
136                        lookup = lookup_class(
137                            target.get_col(self.lhs.alias, source), val
138                        )
139                        value_constraint.add(lookup, AND)
140                    root_constraint.add(value_constraint, OR)
141            else:
142                root_constraint.add(
143                    SubqueryConstraint(
144                        self.lhs.alias,
145                        [target.column for target in self.lhs.targets],
146                        [source.name for source in self.lhs.sources],
147                        self.rhs,
148                    ),
149                    AND,
150                )
151            return root_constraint.as_sql(compiler, connection)
152        return super().as_sql(compiler, connection)
153
154
class RelatedLookupMixin(Lookup):
    """Mixin adapting a lookup to relations.

    Normalizes the right-hand side for single-column relations and builds an
    AND-combined clause per column when the left-hand side is a
    ``MultiColSource``.
    """

    # Type hints for attributes/methods expected from Lookup base class
    lhs: Any
    rhs: Any
    prepare_rhs: bool
    lookup_name: str | None

    def get_prep_lookup(self) -> Any:
        needs_normalizing = not (
            isinstance(self.lhs, MultiColSource)
            or hasattr(self.rhs, "resolve_expression")
        )
        if needs_normalizing:
            # Single-column relation: keep the first (only) normalized
            # component of the value.
            self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
            # The related field's get_prep_value() must run too. Consider a
            # ForeignKeyField to an IntegerField given the value 'abc': the
            # ForeignKeyField itself does not validate non-integers, so the
            # target field has to do it.
            if self.prepare_rhs and hasattr(self.lhs.output_field, "path_infos"):
                # Only one target field can exist here, otherwise the
                # direct-value branch would not have been reached.
                last_path = self.lhs.output_field.path_infos[-1]
                target_field = last_path.target_fields[-1]
                self.rhs = target_field.get_prep_value(self.rhs)

        return super().get_prep_lookup()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        if isinstance(self.lhs, MultiColSource):
            assert self.rhs_is_direct_value()
            self.rhs = get_normalized_value(self.rhs, self.lhs)
            from plain.models.sql.where import AND, WhereNode

            # One lookup per (target, source, value) triple, all joined by AND.
            where = WhereNode()
            for target, source, val in zip(
                self.lhs.targets, self.lhs.sources, self.rhs
            ):
                lookup_cls = target.get_lookup(self.lookup_name)
                column = target.get_col(self.lhs.alias, source)
                where.add(lookup_cls(column, val), AND)
            compiled = where.as_sql(compiler, connection)
        else:
            compiled = super().as_sql(compiler, connection)

        sql, params = compiled
        # Params are always handed back as a list.
        return sql, list(params)
200
201
class RelatedExact(RelatedLookupMixin, Exact):
    """``Exact`` combined with the related-value handling of RelatedLookupMixin."""
204
205
class RelatedLessThan(RelatedLookupMixin, LessThan):
    """``LessThan`` combined with the related-value handling of RelatedLookupMixin."""
208
209
class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
    """``GreaterThan`` combined with the related-value handling of RelatedLookupMixin."""
212
213
class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
    """``GreaterThanOrEqual`` combined with the related-value handling of RelatedLookupMixin."""
216
217
class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
    """``LessThanOrEqual`` combined with the related-value handling of RelatedLookupMixin."""
220
221
class RelatedIsNull(RelatedLookupMixin, IsNull):
    """``IsNull`` combined with the related-value handling of RelatedLookupMixin."""