1"""
2Code to manage the creation and SQL rendering of 'where' constraints.
3"""
4import operator
5from functools import reduce
6
7from plain.exceptions import EmptyResultSet, FullResultSet
8from plain.models.expressions import Case, When
9from plain.models.lookups import Exact
10from plain.utils import tree
11from plain.utils.functional import cached_property
12
# Connectors that join a WhereNode's children. The string values are
# interpolated verbatim into the generated SQL between child predicates.
AND, OR, XOR = "AND", "OR", "XOR"
17
18
class WhereNode(tree.Node):
    """
    An SQL WHERE clause.

    The class is tied to the Query class that created it (in order to create
    the correct SQL).

    A child is usually an expression producing boolean values. Most likely the
    expression is a Lookup instance.

    However, a child could also be any class with as_sql() and either
    relabeled_clone() method or relabel_aliases() and clone() methods and
    contains_aggregate attribute.
    """

    # Connector used to join children when none is given explicitly.
    default = AND
    # Set to True on the clone produced by resolve_expression(); as_sql()
    # then always wraps its (non-negated) output in parentheses.
    resolved = False
    # The node evaluates to a boolean (see output_field); presumably checked
    # by callers that require a condition expression -- confirm at call sites.
    conditional = True

    def split_having_qualify(self, negated=False, must_group_by=False):
        """
        Return three possibly None nodes: one for those parts of self that
        should be included in the WHERE clause, one for those parts of self
        that must be included in the HAVING clause, and one for those parts
        that refer to window functions.

        ``negated`` is the negation state inherited from enclosing nodes.
        ``must_group_by`` is forwarded to nested nodes; when it is true, a
        disjunctive mix of plain predicates and window-function predicates
        raises NotImplementedError.
        """
        if not self.contains_aggregate and not self.contains_over_clause:
            # Nothing here needs HAVING or QUALIFY.
            return self, None, None
        in_negated = negated ^ self.negated
        # Whether or not children must be connected in the same filtering
        # clause (WHERE > HAVING > QUALIFY) to maintain logical semantic.
        # A negated AND is a disjunction by De Morgan's law, so it cannot be
        # split across clauses any more than a plain OR or XOR can.
        must_remain_connected = (
            (in_negated and self.connector == AND)
            or (not in_negated and self.connector == OR)
            or self.connector == XOR
        )
        if (
            must_remain_connected
            and self.contains_aggregate
            and not self.contains_over_clause
        ):
            # It's much cheaper to short-circuit and stash everything in the
            # HAVING clause than split children if possible.
            return None, self, None
        where_parts = []
        having_parts = []
        qualify_parts = []
        for c in self.children:
            if hasattr(c, "split_having_qualify"):
                # Nested node: recurse with the accumulated negation state.
                where_part, having_part, qualify_part = c.split_having_qualify(
                    in_negated, must_group_by
                )
                if where_part is not None:
                    where_parts.append(where_part)
                if having_part is not None:
                    having_parts.append(having_part)
                if qualify_part is not None:
                    qualify_parts.append(qualify_part)
            elif c.contains_over_clause:
                qualify_parts.append(c)
            elif c.contains_aggregate:
                having_parts.append(c)
            else:
                where_parts.append(c)
        if must_remain_connected and qualify_parts:
            # Disjunctive heterogeneous predicates can be pushed down to
            # qualify as long as no conditional aggregation is involved.
            if not where_parts or not must_group_by:
                return None, None, self
            # In theory this should only be enforced when dealing with
            # where_parts containing predicates against multi-valued
            # relationships that could affect aggregation results but this
            # is complex to infer properly.
            raise NotImplementedError(
                "Heterogeneous disjunctive predicates against window functions are "
                "not implemented when performing conditional aggregation."
            )
        where_node = (
            self.create(where_parts, self.connector, self.negated)
            if where_parts
            else None
        )
        having_node = (
            self.create(having_parts, self.connector, self.negated)
            if having_parts
            else None
        )
        qualify_node = (
            self.create(qualify_parts, self.connector, self.negated)
            if qualify_parts
            else None
        )
        return where_node, having_node, qualify_node

    def as_sql(self, compiler, connection):
        """
        Return ``(sql, params)`` for this node's children joined by the
        connector.

        Raise FullResultSet if this node matches everything (including when
        no child produced any SQL), and EmptyResultSet if it can't match
        anything. Negation swaps the two outcomes.
        """
        result = []
        result_params = []
        # Number of "matches everything" / "matches nothing" children needed
        # before the whole node is decided. AND needs every child full but
        # only one empty; OR/XOR is the mirror image.
        if self.connector == AND:
            full_needed, empty_needed = len(self.children), 1
        else:
            full_needed, empty_needed = 1, len(self.children)

        if self.connector == XOR and not connection.features.supports_logical_xor:
            # Convert if the database doesn't support XOR:
            #   a XOR b XOR c XOR ...
            # to:
            #   (a OR b OR c OR ...) AND MOD(a + b + c + ..., 2) == 1
            # An n-ary XOR is true when an odd number of operands are true,
            # not when exactly one is.
            lhs = self.__class__(self.children, OR)
            rhs_sum = reduce(
                operator.add,
                (Case(When(c, then=1), default=0) for c in self.children),
            )
            if len(self.children) > 2:
                # For two operands "sum == 1" already equals parity; wider
                # XORs need the modulo. ``%`` builds a MOD combination.
                rhs_sum = rhs_sum % 2
            rhs = Exact(1, rhs_sum)
            return self.__class__([lhs, rhs], AND, self.negated).as_sql(
                compiler, connection
            )

        for child in self.children:
            try:
                sql, params = compiler.compile(child)
            except EmptyResultSet:
                empty_needed -= 1
            except FullResultSet:
                full_needed -= 1
            else:
                if sql:
                    result.append(sql)
                    result_params.extend(params)
                else:
                    # A child rendering no SQL places no restriction.
                    full_needed -= 1
            # Short-circuit as soon as the counts decide the outcome, so the
            # remaining children are not compiled needlessly.
            if empty_needed == 0:
                if self.negated:
                    raise FullResultSet
                else:
                    raise EmptyResultSet
            if full_needed == 0:
                if self.negated:
                    raise EmptyResultSet
                else:
                    raise FullResultSet
        conn = " %s " % self.connector
        sql_string = conn.join(result)
        if not sql_string:
            # No children at all: nothing restricts the result.
            raise FullResultSet
        if self.negated:
            # Some backends (Oracle at least) need parentheses around the inner
            # SQL in the negated case, even if the inner SQL contains just a
            # single expression.
            sql_string = "NOT (%s)" % sql_string
        elif len(result) > 1 or self.resolved:
            sql_string = "(%s)" % sql_string
        return sql_string, result_params

    def get_group_by_cols(self):
        """Return the GROUP BY columns contributed by every child, in order."""
        cols = []
        for child in self.children:
            cols.extend(child.get_group_by_cols())
        return cols

    def get_source_expressions(self):
        """Return a shallow copy of the children list."""
        return self.children[:]

    def set_source_expressions(self, children):
        """Replace the children; the number of children must not change."""
        assert len(children) == len(self.children)
        self.children = children

    def relabel_aliases(self, change_map):
        """
        Relabel the alias values of any children. 'change_map' is a dictionary
        mapping old (current) alias values to the new values.
        """
        for pos, child in enumerate(self.children):
            if hasattr(child, "relabel_aliases"):
                # For example another WhereNode: relabel it in place.
                child.relabel_aliases(change_map)
            elif hasattr(child, "relabeled_clone"):
                # Immutable-style children are replaced by a relabeled copy.
                self.children[pos] = child.relabeled_clone(change_map)

    def clone(self):
        """
        Return a copy of this node; children that support clone() are
        cloned, other children are shared with the original.
        """
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            if hasattr(child, "clone"):
                child = child.clone()
            clone.children.append(child)
        return clone

    def relabeled_clone(self, change_map):
        """Return a clone of this node with aliases relabeled per change_map."""
        clone = self.clone()
        clone.relabel_aliases(change_map)
        return clone

    def replace_expressions(self, replacements):
        """
        Return ``replacements[self]`` if present (and truthy); otherwise a new
        node whose children have had replace_expressions() applied.
        """
        if replacement := replacements.get(self):
            return replacement
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            clone.children.append(child.replace_expressions(replacements))
        return clone

    def get_refs(self):
        """Return the union of every child's get_refs() set."""
        refs = set()
        for child in self.children:
            refs |= child.get_refs()
        return refs

    @classmethod
    def _contains_aggregate(cls, obj):
        # Recurse through nested nodes; leaves report their own flag.
        if isinstance(obj, tree.Node):
            return any(cls._contains_aggregate(c) for c in obj.children)
        return obj.contains_aggregate

    @cached_property
    def contains_aggregate(self):
        """Whether any descendant leaf contains an aggregate (cached)."""
        return self._contains_aggregate(self)

    @classmethod
    def _contains_over_clause(cls, obj):
        # Recurse through nested nodes; leaves report their own flag.
        if isinstance(obj, tree.Node):
            return any(cls._contains_over_clause(c) for c in obj.children)
        return obj.contains_over_clause

    @cached_property
    def contains_over_clause(self):
        """Whether any descendant leaf contains a window expression (cached)."""
        return self._contains_over_clause(self)

    @property
    def is_summary(self):
        """True if any child reports is_summary."""
        return any(child.is_summary for child in self.children)

    @staticmethod
    def _resolve_leaf(expr, query, *args, **kwargs):
        # Resolve expressions; plain values are returned unchanged.
        if hasattr(expr, "resolve_expression"):
            expr = expr.resolve_expression(query, *args, **kwargs)
        return expr

    @classmethod
    def _resolve_node(cls, node, query, *args, **kwargs):
        # Walk the tree in place, resolving each node's lhs/rhs (if any).
        if hasattr(node, "children"):
            for child in node.children:
                cls._resolve_node(child, query, *args, **kwargs)
        if hasattr(node, "lhs"):
            node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
        if hasattr(node, "rhs"):
            node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)

    def resolve_expression(self, *args, **kwargs):
        """
        Return a resolved clone of this node. The arguments (query first) are
        forwarded to each leaf's resolve_expression().
        """
        clone = self.clone()
        clone._resolve_node(clone, *args, **kwargs)
        clone.resolved = True
        return clone

    @cached_property
    def output_field(self):
        """A WHERE node always evaluates to a boolean."""
        from plain.models.fields import BooleanField

        return BooleanField()

    @property
    def _output_field_or_none(self):
        return self.output_field

    def select_format(self, compiler, sql, params):
        """Return (sql, params) suitable for use in a SELECT list."""
        # Wrap filters with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params

    def get_db_converters(self, connection):
        """Delegate to the boolean output field's converters."""
        return self.output_field.get_db_converters(connection)

    def get_lookup(self, lookup):
        """Delegate lookup resolution to the boolean output field."""
        return self.output_field.get_lookup(lookup)

    def leaves(self):
        """Yield every non-WhereNode descendant, depth-first, left to right."""
        for child in self.children:
            if isinstance(child, WhereNode):
                yield from child.leaves()
            else:
                yield child
312
313
class NothingNode:
    """Leaf node that can never match a row; compiling it always fails."""

    # Never involves window expressions or aggregates.
    contains_over_clause = False
    contains_aggregate = False

    def as_sql(self, compiler=None, connection=None):
        # There is no SQL for "match nothing"; signal it to the caller instead.
        raise EmptyResultSet
322
323
class ExtraWhere:
    """
    Caller-supplied raw SQL fragments, rendered as the AND of each fragment
    wrapped in parentheses.
    """

    # The SQL is opaque to the ORM, so it is assumed to contain neither
    # window expressions nor aggregates.
    contains_over_clause = False
    contains_aggregate = False

    def __init__(self, sqls, params):
        # sqls: iterable of SQL fragments; params: their bind values, which
        # may be None or empty.
        self.sqls = sqls
        self.params = params

    def as_sql(self, compiler=None, connection=None):
        joined = " AND ".join("(%s)" % fragment for fragment in self.sqls)
        bound = list(self.params) if self.params else []
        return joined, bound
336
337
class SubqueryConstraint:
    """
    Constraint that restricts ``alias``.``columns`` to the values selected
    by ``query_object`` (rendered by the subquery's compiler).
    """

    # Aggregates or windows inside the subquery do not affect the outer
    # query's clause placement, so both flags stay False.
    contains_over_clause = False
    contains_aggregate = False

    def __init__(self, alias, columns, targets, query_object):
        # Ordering is irrelevant inside the subquery, so drop it up front.
        query_object.clear_ordering(clear_default=True)
        self.alias, self.columns, self.targets = alias, columns, targets
        self.query_object = query_object

    def as_sql(self, compiler, connection):
        subquery = self.query_object
        # Select only the target columns, then let the subquery's own compiler
        # render the condition against the outer alias/columns.
        subquery.set_values(self.targets)
        return subquery.get_compiler(connection=connection).as_subquery_condition(
            self.alias, self.columns, compiler
        )