1from __future__ import annotations
2
3from typing import TYPE_CHECKING, Any
4
5from plain.models.lookups import (
6 Exact,
7 GreaterThan,
8 GreaterThanOrEqual,
9 In,
10 IsNull,
11 LessThan,
12 LessThanOrEqual,
13)
14
15if TYPE_CHECKING:
16 from plain.models.backends.base.base import BaseDatabaseWrapper
17 from plain.models.sql.compiler import SQLCompiler
18
19
class MultiColSource:
    """Bundle an alias with the target/source fields of a multi-column relation.

    Instances are what the related lookups in this module check for (via
    ``isinstance``) to decide whether a multi-column WHERE clause is needed.
    """

    # Class-level flags; always False for this source.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(self, alias: str, targets: Any, sources: Any, field: Any) -> None:
        """Store the alias, the target and source fields, and the relation field.

        ``output_field`` is set to the same object as ``field``.
        """
        self.targets = targets
        self.sources = sources
        self.field = field
        self.alias = alias
        self.output_field = field

    def __repr__(self) -> str:
        """Return ``ClassName(alias, field)``."""
        return f"{type(self).__name__}({self.alias}, {self.field})"

    def relabeled_clone(self, relabels: dict[str, str]) -> MultiColSource:
        """Return a new instance whose alias is mapped through ``relabels``.

        An alias missing from ``relabels`` is kept as is; targets, sources and
        field are passed to the clone unchanged.
        """
        new_alias = relabels.get(self.alias, self.alias)
        return type(self)(new_alias, self.targets, self.sources, self.field)

    def get_lookup(self, lookup: str) -> Any:
        """Delegate the lookup resolution to ``output_field``."""
        return self.output_field.get_lookup(lookup)

    def resolve_expression(self, *args: Any, **kwargs: Any) -> MultiColSource:
        """Return ``self``; all arguments are ignored."""
        return self
46
47
def get_normalized_value(value: Any, lhs: Any) -> tuple[Any, ...]:
    """Return ``value`` as a tuple suitable for comparison against ``lhs``.

    * A ``Model`` instance is converted to a tuple holding one attribute value
      per target field of the last path info of ``lhs.output_field``.
    * A tuple is returned unchanged.
    * Any other value is wrapped in a one-element tuple.

    Raises:
        ValueError: if ``value`` is a ``Model`` instance whose ``id`` is None.
    """
    from plain.models import Model

    if not isinstance(value, Model):
        # Plain values: tuples pass through untouched, everything else is wrapped.
        return value if isinstance(value, tuple) else (value,)

    if value.id is None:
        raise ValueError("Model instances passed to related filters must be saved.")

    collected = []
    for field in lhs.output_field.path_infos[-1].target_fields:
        # Walk the chain of remote fields until we reach a field whose model
        # ``value`` is an instance of, or until there is no remote field left.
        while not isinstance(value, field.model) and field.remote_field:
            remote = field.remote_field
            field = remote.model._model_meta.get_field(remote.field_name)
        try:
            collected.append(getattr(value, field.attname))
        except AttributeError:
            # A case like Restaurant.query.filter(place=restaurant_instance),
            # where place is a OneToOneField and the primary key of Restaurant:
            # fall back to the instance's own id.
            return (value.id,)
    return tuple(collected)
71
72
class RelatedIn(In):
    """``In`` lookup variant used on relation fields.

    With a single-column left-hand side the right-hand side is normalized and
    the work is delegated to ``In``. With a ``MultiColSource`` left-hand side
    the SQL is built here as an explicit multi-column WHERE clause.
    """

    def get_prep_lookup(self) -> list[Any]:  # type: ignore[misc]
        """Prepare ``self.rhs`` for a single-column relation, then defer to ``In``.

        Nothing is prepared here when the left-hand side is a
        ``MultiColSource``.
        """
        lhs = self.lhs
        if isinstance(lhs, MultiColSource):
            return super().get_prep_lookup()

        if self.rhs_is_direct_value():
            # Single-column relation: keep the first component of each
            # normalized value.
            self.rhs = [get_normalized_value(item, lhs)[0] for item in self.rhs]
            # The related field's get_prep_value() must run as well. Consider a
            # ForeignKey to an IntegerField given the value 'abc': the
            # ForeignKey itself does not validate non-integers, so validation
            # has to go through the target field.
            output_field = lhs.output_field
            if hasattr(output_field, "path_infos"):
                # Only one target field can exist here, otherwise the direct
                # value branch would not have been reached.
                target_field = output_field.path_infos[-1].target_fields[-1]
                self.rhs = [target_field.get_prep_value(item) for item in self.rhs]
        elif not (
            getattr(self.rhs, "has_select_fields", True)
            or getattr(lhs.field.target_field, "primary_key", False)
        ):
            # A case like Restaurant.query.filter(place__in=restaurant_qs),
            # where place is a OneToOneField and the primary key of Restaurant,
            # selects the relation field itself; otherwise the relation's
            # target field is selected.
            lhs_is_own_primary_key = (
                getattr(lhs.output_field, "primary_key", False)
                and lhs.output_field.model == self.rhs.model
            )
            selected_name = (
                lhs.field.name
                if lhs_is_own_primary_key
                else lhs.field.target_field.name
            )
            self.rhs.set_values([selected_name])
        return super().get_prep_lookup()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:  # type: ignore[misc]
        """Compile the lookup, building a multi-column clause when required.

        For a ``MultiColSource`` left-hand side the result is either an
        OR-combined list of ``(col1 = val1 AND col2 = val2 AND ...)`` groups
        (direct values) or a ``SubqueryConstraint`` (values that need to be
        compiled to SQL). Any other left-hand side is handled by ``In``.
        """
        lhs = self.lhs
        if not isinstance(lhs, MultiColSource):
            return super().as_sql(compiler, connection)

        from plain.models.sql.where import (
            AND,
            OR,
            SubqueryConstraint,
            WhereNode,
        )

        where = WhereNode(connector=OR)
        if self.rhs_is_direct_value():
            rows = [get_normalized_value(item, lhs) for item in self.rhs]
            for row in rows:
                # One AND-group per value, matching every column pair exactly.
                row_node = WhereNode()
                for source, target, component in zip(lhs.sources, lhs.targets, row):
                    exact_lookup = target.get_lookup("exact")
                    column = target.get_col(lhs.alias, source)
                    row_node.add(exact_lookup(column, component), AND)
                where.add(row_node, OR)
        else:
            target_columns = [target.column for target in lhs.targets]
            source_names = [source.name for source in lhs.sources]
            constraint = SubqueryConstraint(
                lhs.alias, target_columns, source_names, self.rhs
            )
            where.add(constraint, AND)
        return where.as_sql(compiler, connection)
149
150
class RelatedLookupMixin:
    """Mixin adapting a single-value lookup for use on relation fields.

    It is meant to be combined with a lookup class that supplies ``lhs``,
    ``rhs``, ``prepare_rhs``, ``lookup_name`` and ``rhs_is_direct_value()``.
    """

    def get_prep_lookup(self) -> Any:  # type: ignore[misc]
        """Normalize a direct right-hand side value, then defer to the parent.

        Normalization is skipped when the left-hand side is a
        ``MultiColSource`` or the right-hand side has ``resolve_expression``.
        """
        lhs = self.lhs
        needs_normalizing = not (
            isinstance(lhs, MultiColSource) or hasattr(self.rhs, "resolve_expression")
        )
        if needs_normalizing:
            # Single-column relation: keep the first normalized component.
            self.rhs = get_normalized_value(self.rhs, lhs)[0]
            # The related field's get_prep_value() must run as well. Consider a
            # ForeignKey to an IntegerField given the value 'abc': the
            # ForeignKey itself does not validate non-integers, so validation
            # has to go through the target field.
            if self.prepare_rhs and hasattr(lhs.output_field, "path_infos"):
                # Only one target field can exist here, otherwise the direct
                # value branch would not have been reached.
                target_field = lhs.output_field.path_infos[-1].target_fields[-1]
                self.rhs = target_field.get_prep_value(self.rhs)

        return super().get_prep_lookup()  # type: ignore[misc]

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:  # type: ignore[misc]
        """Compile the lookup, AND-combining one lookup per column pair.

        The per-column expansion only applies to a ``MultiColSource``
        left-hand side, whose right-hand side must be a direct value; any
        other left-hand side is compiled by the parent class.
        """
        lhs = self.lhs
        if not isinstance(lhs, MultiColSource):
            return super().as_sql(compiler, connection)

        assert self.rhs_is_direct_value()
        self.rhs = get_normalized_value(self.rhs, lhs)
        from plain.models.sql.where import AND, WhereNode

        where = WhereNode()
        for target, source, component in zip(lhs.targets, lhs.sources, self.rhs):
            # Use the target field's own implementation of this lookup.
            lookup_class = target.get_lookup(self.lookup_name)
            column = target.get_col(lhs.alias, source)
            where.add(lookup_class(column, component), AND)
        return where.as_sql(compiler, connection)
188
189
class RelatedExact(RelatedLookupMixin, Exact):
    """``Exact`` lookup combined with ``RelatedLookupMixin``."""
192
193
class RelatedLessThan(RelatedLookupMixin, LessThan):
    """``LessThan`` lookup combined with ``RelatedLookupMixin``."""
196
197
class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
    """``GreaterThan`` lookup combined with ``RelatedLookupMixin``."""
200
201
class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
    """``GreaterThanOrEqual`` lookup combined with ``RelatedLookupMixin``."""
204
205
class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
    """``LessThanOrEqual`` lookup combined with ``RelatedLookupMixin``."""
208
209
class RelatedIsNull(RelatedLookupMixin, IsNull):
    """``IsNull`` lookup combined with ``RelatedLookupMixin``."""