1from plain.models.lookups import (
2 Exact,
3 GreaterThan,
4 GreaterThanOrEqual,
5 In,
6 IsNull,
7 LessThan,
8 LessThanOrEqual,
9)
10
11
class MultiColSource:
    """Left-hand side placeholder for a relation spanning several columns.

    Stores the join ``alias``, the parallel ``targets`` and ``sources``
    sequences, and the relation ``field``. ``output_field`` is the very same
    object as ``field``, so lookups resolve against the relation field.
    """

    contains_aggregate = False
    contains_over_clause = False

    def __init__(self, alias, targets, sources, field):
        self.targets = targets
        self.sources = sources
        self.field = field
        self.alias = alias
        # Lookup resolution (get_lookup) goes through the relation field.
        self.output_field = field

    def __repr__(self):
        class_name = self.__class__.__name__
        return f"{class_name}({self.alias}, {self.field})"

    def relabeled_clone(self, relabels):
        """Return a copy whose alias is mapped through *relabels*.

        The alias is kept unchanged when *relabels* has no entry for it;
        targets, sources and field are shared with the original.
        """
        new_alias = relabels.get(self.alias, self.alias)
        return self.__class__(new_alias, self.targets, self.sources, self.field)

    def get_lookup(self, lookup):
        """Delegate lookup-class resolution to the relation field."""
        relation_field = self.output_field
        return relation_field.get_lookup(lookup)

    def resolve_expression(self, *args, **kwargs):
        """Nothing to resolve; the source is returned as-is."""
        return self
38
39
def get_normalized_value(value, lhs):
    """Return *value* as a tuple of raw column values to compare with *lhs*.

    * A saved model instance yields one value per target field of the last
      entry in ``lhs.output_field.path_infos``. Each target field is followed
      through its ``remote_field`` chain until it belongs to the instance's
      model, then read via its ``attname``. If that attribute is missing, the
      result falls back to ``(value.pk,)``.
    * A tuple is returned unchanged.
    * Anything else is wrapped in a 1-tuple.

    Raises:
        ValueError: if *value* is a model instance whose ``pk`` is None.
    """
    from plain.models import Model

    if not isinstance(value, Model):
        return value if isinstance(value, tuple) else (value,)

    if value.pk is None:
        raise ValueError("Model instances passed to related filters must be saved.")

    collected = []
    for field in lhs.output_field.path_infos[-1].target_fields:
        # Walk across relations until the field lives on value's model.
        while not isinstance(value, field.model) and field.remote_field:
            field = field.remote_field.model._meta.get_field(
                field.remote_field.field_name
            )
        try:
            column_value = getattr(value, field.attname)
        except AttributeError:
            # A case like Restaurant.objects.filter(place=restaurant_instance),
            # where place is a OneToOneField and the primary key of Restaurant.
            return (value.pk,)
        collected.append(column_value)
    return tuple(collected)
63
64
class RelatedIn(In):
    """``__in`` lookup applied across a relation.

    Single-column relations are normalized in ``get_prep_lookup()``.
    Multi-column relations (a ``MultiColSource`` lhs) are expanded into an
    explicit WHERE tree in ``as_sql()``.
    """

    def get_prep_lookup(self):
        """Normalize the rhs before the generic ``In`` preparation runs.

        Only single-column relations are touched here:

        * A direct iterable of values has each item reduced to a scalar with
          ``get_normalized_value()``. When the lhs output field exposes
          ``path_infos``, each item is also passed through the final target
          field's ``get_prep_value()``.
        * A queryset-like rhs whose ``has_select_fields`` is falsy is narrowed
          with ``set_values()`` to the column the relation points at. This is
          skipped when the relation's target field is a primary key.
        """
        if isinstance(self.lhs, MultiColSource):
            # Multi-column relations are compiled wholesale in as_sql().
            return super().get_prep_lookup()

        if self.rhs_is_direct_value():
            # If we get here, we are dealing with single-column relations.
            scalars = [get_normalized_value(item, self.lhs)[0] for item in self.rhs]
            self.rhs = scalars
            # The relation field itself may not validate values (e.g. a
            # ForeignKey to an IntegerField given 'abc'), so the target
            # field's get_prep_value() is applied instead.
            output_field = self.lhs.output_field
            if hasattr(output_field, "path_infos"):
                # Only single-column relations reach this branch, so the last
                # target field is the only one.
                last_target = output_field.path_infos[-1].target_fields[-1]
                self.rhs = [last_target.get_prep_value(item) for item in scalars]
            return super().get_prep_lookup()

        if getattr(self.rhs, "has_select_fields", True):
            return super().get_prep_lookup()

        related_target = self.lhs.field.target_field
        if getattr(related_target, "primary_key", False):
            return super().get_prep_lookup()

        output_field = self.lhs.output_field
        if (
            getattr(output_field, "primary_key", False)
            and output_field.model == self.rhs.model
        ):
            # A case like Restaurant.objects.filter(place__in=restaurant_qs),
            # where place is a OneToOneField and the primary key of Restaurant.
            column = self.lhs.field.name
        else:
            column = related_target.name
        self.rhs.set_values([column])
        return super().get_prep_lookup()

    def as_sql(self, compiler, connection):
        """Compile the lookup.

        For a ``MultiColSource`` lhs, the result depends on the rhs:

        * Direct values become OR-combined groups of AND-combined per-column
          ``exact`` lookups: ``(c1 = v1 AND c2 = v2) OR ...``.
        * Any other rhs is wrapped in a ``SubqueryConstraint``.

        Every other lhs falls back to ``In.as_sql()``.
        """
        if not isinstance(self.lhs, MultiColSource):
            return super().as_sql(compiler, connection)

        from plain.models.sql.where import (
            AND,
            OR,
            SubqueryConstraint,
            WhereNode,
        )

        source = self.lhs
        root_constraint = WhereNode(connector=OR)
        if self.rhs_is_direct_value():
            # Normalize every value up front, before any node is built.
            keys = [get_normalized_value(item, source) for item in self.rhs]
            for key in keys:
                conjunction = WhereNode()
                for col_source, col_target, component in zip(
                    source.sources, source.targets, key
                ):
                    exact = col_target.get_lookup("exact")
                    conjunction.add(
                        exact(col_target.get_col(source.alias, col_source), component),
                        AND,
                    )
                root_constraint.add(conjunction, OR)
        else:
            subquery = SubqueryConstraint(
                source.alias,
                [col_target.column for col_target in source.targets],
                [col_source.name for col_source in source.sources],
                self.rhs,
            )
            root_constraint.add(subquery, AND)
        return root_constraint.as_sql(compiler, connection)
139
140
class RelatedLookupMixin:
    """Make a scalar comparison lookup usable on a relation.

    Mixed into the ``Related*`` lookup classes (exact, lt, gt, lte, gte,
    isnull). For a single-column relation, the rhs is normalized in
    ``get_prep_lookup()``. For a multi-column relation (a ``MultiColSource``
    lhs), ``as_sql()`` expands the lookup into one per-column comparison,
    AND-combined.
    """

    def get_prep_lookup(self):
        """Normalize a single-column rhs, then defer to the base lookup.

        Normalization is skipped when the lhs is a ``MultiColSource`` or the
        rhs is an expression (anything with ``resolve_expression``).
        Otherwise the rhs, which may be a model instance, is reduced to its
        first key component via ``get_normalized_value()``. When
        ``prepare_rhs`` is set and the lhs output field exposes
        ``path_infos``, that value is then prepared by the final target
        field's ``get_prep_value()``.
        """
        if not isinstance(self.lhs, MultiColSource) and not hasattr(
            self.rhs, "resolve_expression"
        ):
            # If we get here, we are dealing with single-column relations.
            self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
            # We need to run the related field's get_prep_value(). Consider case
            # ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
            # doesn't have validation for non-integers, so we must run validation
            # using the target field.
            if self.prepare_rhs and hasattr(self.lhs.output_field, "path_infos"):
                # Get the target field. We can safely assume there is only one
                # as we don't get to the direct value branch otherwise.
                target_field = self.lhs.output_field.path_infos[-1].target_fields[-1]
                self.rhs = target_field.get_prep_value(self.rhs)

        return super().get_prep_lookup()

    def as_sql(self, compiler, connection):
        """Compile the lookup.

        For a ``MultiColSource`` lhs, the rhs must be a direct value. It is
        normalized into a tuple (stored back on ``self.rhs``). Each
        (target, source, value) triple becomes a lookup of the same
        ``lookup_name`` on that target column, and the lookups are
        AND-combined in a ``WhereNode``. Any other lhs uses the base
        implementation.

        Raises:
            ValueError: if the lhs is a ``MultiColSource`` and the rhs is not
                a direct value (e.g. an expression or subquery).
        """
        if isinstance(self.lhs, MultiColSource):
            # Explicit check instead of ``assert``: assertions are stripped
            # under ``python -O``. An expression rhs would then be wrapped as a
            # 1-tuple by get_normalized_value(), and zip() would silently
            # compare only the first column of the relation.
            if not self.rhs_is_direct_value():
                raise ValueError(
                    f"The '{self.lookup_name}' lookup on a multi-column relation "
                    "only supports direct values, not expressions or subqueries."
                )
            self.rhs = get_normalized_value(self.rhs, self.lhs)
            from plain.models.sql.where import AND, WhereNode

            root_constraint = WhereNode()
            for target, source, val in zip(
                self.lhs.targets, self.lhs.sources, self.rhs
            ):
                lookup_class = target.get_lookup(self.lookup_name)
                root_constraint.add(
                    lookup_class(target.get_col(self.lhs.alias, source), val), AND
                )
            return root_constraint.as_sql(compiler, connection)
        return super().as_sql(compiler, connection)
176
177
class RelatedExact(RelatedLookupMixin, Exact):
    """``exact`` lookup that also accepts model instances and multi-column relations."""
180
181
class RelatedLessThan(RelatedLookupMixin, LessThan):
    """``lt`` lookup across a relation, with rhs handling from RelatedLookupMixin."""
184
185
class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
    """``gt`` lookup across a relation, with rhs handling from RelatedLookupMixin."""
188
189
class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
    """``gte`` lookup across a relation, with rhs handling from RelatedLookupMixin."""
192
193
class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
    """``lte`` lookup across a relation, with rhs handling from RelatedLookupMixin."""
196
197
class RelatedIsNull(RelatedLookupMixin, IsNull):
    """``isnull`` lookup across a relation, routed through RelatedLookupMixin."""