1from __future__ import annotations
2
3from typing import TYPE_CHECKING, Any
4
5from plain.models.lookups import (
6 Exact,
7 GreaterThan,
8 GreaterThanOrEqual,
9 In,
10 IsNull,
11 LessThan,
12 LessThanOrEqual,
13 Lookup,
14)
15
16if TYPE_CHECKING:
17 from plain.models.backends.base.base import BaseDatabaseWrapper
18 from plain.models.sql.compiler import SQLCompiler
19
20
class MultiColSource:
    """Left-hand side of a lookup that spans a multi-column relation.

    Bundles the table alias with the relation's paired target/source fields
    so that related lookups can emit one column comparison per pair.
    """

    contains_aggregate = False
    contains_over_clause = False

    def __init__(self, alias: str, targets: Any, sources: Any, field: Any) -> None:
        self.alias = alias
        self.field = field
        self.targets = targets
        self.sources = sources
        # Lookups are resolved against the relation field itself.
        self.output_field = field

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"{cls_name}({self.alias}, {self.field})"

    def relabeled_clone(self, relabels: dict[str, str]) -> MultiColSource:
        """Return a copy whose alias is remapped via ``relabels`` when it has an entry."""
        new_alias = relabels.get(self.alias, self.alias)
        return type(self)(new_alias, self.targets, self.sources, self.field)

    def get_lookup(self, lookup: str) -> type[Lookup] | None:
        """Delegate lookup-class resolution to the output (relation) field."""
        output_field = self.output_field
        return output_field.get_lookup(lookup)

    def resolve_expression(self, *args: Any, **kwargs: Any) -> MultiColSource:
        """Nothing to resolve; the source is returned unchanged."""
        return self
47
48
def get_normalized_value(value: Any, lhs: Any) -> tuple[Any, ...]:
    """Normalize the right-hand side of a related lookup into a tuple.

    Model instances are reduced to the values of the relation's target
    columns, following relation fields until reaching a field whose model
    ``value`` is an instance of (or a non-relation field). Tuples are
    returned as-is; any other value is wrapped in a 1-tuple.

    Raises:
        ValueError: if ``value`` is a model instance whose ``id`` is None.
    """
    from plain.models import Model
    from plain.models.fields.related import RelatedField

    if not isinstance(value, Model):
        return value if isinstance(value, tuple) else (value,)

    if value.id is None:
        raise ValueError("Model instances passed to related filters must be saved.")

    normalized: list[Any] = []
    target_fields = lhs.output_field.path_infos[-1].target_fields
    for field in target_fields:
        # Hop from a relation field to the field it points at until the
        # field belongs to value's model (or is not a relation at all).
        while not isinstance(value, field.model) and isinstance(
            field, RelatedField
        ):
            remote = field.remote_field
            field = remote.model._model_meta.get_field(remote.field_name)
        try:
            normalized.append(getattr(value, field.attname))
        except AttributeError:
            # e.g. Restaurant.query.filter(place=restaurant_instance), where
            # ``place`` is a OneToOneField that is Restaurant's primary key.
            return (value.id,)
    return tuple(normalized)
75
76
class RelatedIn(In):
    """``in`` lookup for relation fields.

    Right-hand values are translated to the related column's values, and
    multi-column relations get a hand-built WHERE clause in ``as_sql``.
    """

    def get_prep_lookup(self) -> list[Any]:
        """Prepare ``self.rhs`` for a single-column relation, then defer to ``In``.

        * Direct values: each item is normalized to the related column value
          and, when the output field exposes ``path_infos``, passed through the
          target field's ``get_prep_value``.
        * Otherwise, when the RHS reports ``has_select_fields`` false and the
          relation's target field is not a primary key, the RHS is restricted
          via ``set_values`` to the column the relation points at.

        ``MultiColSource`` left-hand sides are left alone here; ``as_sql``
        expands them.
        """
        if isinstance(self.lhs, MultiColSource):
            return super().get_prep_lookup()

        if self.rhs_is_direct_value():
            # Single-column relation: reduce each item to its column value.
            self.rhs = [
                get_normalized_value(item, self.lhs)[0] for item in self.rhs
            ]
            # A ForeignKeyField doesn't validate non-integers itself (e.g. 'abc'
            # for a ForeignKeyField to an IntegerField), so the relation's
            # target field must prepare -- and thereby validate -- each value.
            if hasattr(self.lhs.output_field, "path_infos"):
                # Only single-column relations reach this branch, so the last
                # target field is the only one.
                final_path = self.lhs.output_field.path_infos[-1]
                prep = final_path.target_fields[-1].get_prep_value
                self.rhs = [prep(item) for item in self.rhs]
            return super().get_prep_lookup()

        no_explicit_select = not getattr(self.rhs, "has_select_fields", True)
        if no_explicit_select and not getattr(
            self.lhs.field.target_field, "primary_key", False
        ):
            lhs_is_own_pk = (
                getattr(self.lhs.output_field, "primary_key", False)
                and self.lhs.output_field.model == self.rhs.model
            )
            # lhs_is_own_pk covers e.g. Restaurant.query.filter(place__in=qs)
            # where ``place`` is a OneToOneField that is Restaurant's primary key.
            column = (
                self.lhs.field.name
                if lhs_is_own_pk
                else self.lhs.field.target_field.name
            )
            self.rhs.set_values([column])
        return super().get_prep_lookup()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Compile the lookup, expanding multi-column relations by hand.

        For a ``MultiColSource`` LHS, direct values become an OR of
        ``(col1 = v1 AND col2 = v2 ...)`` groups; any other RHS is wrapped in
        a ``SubqueryConstraint``. Other LHS kinds use ``In.as_sql``.
        """
        if not isinstance(self.lhs, MultiColSource):
            return super().as_sql(compiler, connection)

        from plain.models.sql.where import (
            AND,
            OR,
            SubqueryConstraint,
            WhereNode,
        )

        multi = self.lhs
        root_constraint = WhereNode(connector=OR)
        if self.rhs_is_direct_value():
            rows = [get_normalized_value(value, multi) for value in self.rhs]
            for row in rows:
                row_constraint = WhereNode()
                for source, target, val in zip(multi.sources, multi.targets, row):
                    exact = target.get_lookup("exact")
                    column = target.get_col(multi.alias, source)
                    row_constraint.add(exact(column, val), AND)
                root_constraint.add(row_constraint, OR)
        else:
            target_columns = [target.column for target in multi.targets]
            source_names = [source.name for source in multi.sources]
            subquery = SubqueryConstraint(
                multi.alias, target_columns, source_names, self.rhs
            )
            root_constraint.add(subquery, AND)
        return root_constraint.as_sql(compiler, connection)
153
154
class RelatedLookupMixin(Lookup):
    """Shared behavior for comparison lookups against relation fields.

    Combined with a concrete lookup (``Exact``, ``LessThan``, ...), it
    normalizes the right-hand side via ``get_normalized_value`` and expands
    multi-column relations into one comparison per column.
    """

    # Type hints for attributes/methods expected from Lookup base class
    lhs: Any
    rhs: Any
    prepare_rhs: bool
    lookup_name: str | None

    def get_prep_lookup(self) -> Any:
        """Normalize a single-column RHS, then defer to the concrete lookup.

        Right-hand sides with ``resolve_expression`` and ``MultiColSource``
        left-hand sides are left unchanged here.
        """
        if not isinstance(self.lhs, MultiColSource) and not hasattr(
            self.rhs, "resolve_expression"
        ):
            # If we get here, we are dealing with single-column relations.
            self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
            # We need to run the related field's get_prep_value(). Consider
            # case ForeignKeyField to IntegerField given value 'abc'. The
            # ForeignKeyField itself doesn't have validation for non-integers,
            # so we must run validation using the target field.
            if self.prepare_rhs and hasattr(self.lhs.output_field, "path_infos"):
                # Get the target field. We can safely assume there is only one
                # as we don't get to the direct value branch otherwise.
                target_field = self.lhs.output_field.path_infos[-1].target_fields[-1]
                self.rhs = target_field.get_prep_value(self.rhs)

        return super().get_prep_lookup()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Compile the lookup; a ``MultiColSource`` LHS is ANDed per column.

        Raises:
            ValueError: if the LHS spans multiple columns and the RHS is not
                a direct value.
        """
        if isinstance(self.lhs, MultiColSource):
            # Explicit check instead of ``assert``: asserts are stripped under
            # ``python -O``, and a non-direct RHS would then be wrapped in a
            # 1-tuple and zipped against the targets, silently constraining
            # only the first column.
            if not self.rhs_is_direct_value():
                raise ValueError(
                    f"Related lookup {self.lookup_name!r} on a multi-column "
                    "relation only supports direct values."
                )
            self.rhs = get_normalized_value(self.rhs, self.lhs)
            from plain.models.sql.where import AND, WhereNode

            root_constraint = WhereNode()
            for target, source, val in zip(
                self.lhs.targets, self.lhs.sources, self.rhs
            ):
                lookup_class = target.get_lookup(self.lookup_name)
                root_constraint.add(
                    lookup_class(target.get_col(self.lhs.alias, source), val), AND
                )
            sql, params = root_constraint.as_sql(compiler, connection)
            return sql, list(params)
        sql, params = super().as_sql(compiler, connection)
        return sql, list(params)
200
201
class RelatedExact(RelatedLookupMixin, Exact):
    """``exact`` lookup for relation fields, with related-value normalization."""
204
205
class RelatedLessThan(RelatedLookupMixin, LessThan):
    """``lt`` lookup for relation fields, with related-value normalization."""
208
209
class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
    """``gt`` lookup for relation fields, with related-value normalization."""
212
213
class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
    """``gte`` lookup for relation fields, with related-value normalization."""
216
217
class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
    """``lte`` lookup for relation fields, with related-value normalization."""
220
221
class RelatedIsNull(RelatedLookupMixin, IsNull):
    """``isnull`` lookup for relation fields, built on ``RelatedLookupMixin``."""