1"""
2Code to manage the creation and SQL rendering of 'where' constraints.
3"""
4
5from __future__ import annotations
6
7import operator
8from functools import cached_property, reduce
9from typing import TYPE_CHECKING, Any
10
11from plain.models.exceptions import EmptyResultSet, FullResultSet
12from plain.models.expressions import Case, ResolvableExpression, When
13from plain.models.lookups import Exact
14from plain.utils import tree
15
16if TYPE_CHECKING:
17 from plain.models.backends.base.base import BaseDatabaseWrapper
18 from plain.models.lookups import Lookup
19 from plain.models.sql.compiler import SQLCompiler
20
# Connectors used to join the children of a WhereNode in rendered SQL.
AND, OR, XOR = "AND", "OR", "XOR"
25
26
class WhereNode(tree.Node):
    """
    An SQL WHERE clause.

    The class is tied to the Query class that created it (in order to create
    the correct SQL).

    A child is usually an expression producing boolean values. Most likely the
    expression is a Lookup instance.

    However, a child could also be any class with as_sql() and either
    relabeled_clone() method or relabel_aliases() and clone() methods and
    contains_aggregate attribute.
    """

    default = AND
    resolved = False
    conditional = True

    def split_having_qualify(
        self, negated: bool = False, must_group_by: bool = False
    ) -> tuple[WhereNode | None, WhereNode | None, WhereNode | None]:
        """
        Return three possibly None nodes: one for those parts of self that
        should be included in the WHERE clause, one for those parts of self
        that must be included in the HAVING clause, and one for those parts
        that refer to window functions.

        ``negated`` is the negation state inherited from enclosing nodes.
        When ``must_group_by`` is true, a disjunction that mixes window
        function predicates with plain WHERE predicates cannot be pushed to
        QUALIFY and raises NotImplementedError.
        """
        if not self.contains_aggregate and not self.contains_over_clause:
            return self, None, None
        in_negated = negated ^ self.negated
        # Whether or not children must be connected in the same filtering
        # clause (WHERE > HAVING > QUALIFY) to maintain logical semantics.
        must_remain_connected = (
            (in_negated and self.connector == AND)
            or (not in_negated and self.connector == OR)
            or self.connector == XOR
        )
        if (
            must_remain_connected
            and self.contains_aggregate
            and not self.contains_over_clause
        ):
            # It's much cheaper to short-circuit and stash everything in the
            # HAVING clause than split children if possible.
            return None, self, None
        where_parts = []
        having_parts = []
        qualify_parts = []
        for c in self.children:
            if hasattr(c, "split_having_qualify"):
                where_part, having_part, qualify_part = c.split_having_qualify(
                    in_negated, must_group_by
                )
                if where_part is not None:
                    where_parts.append(where_part)
                if having_part is not None:
                    having_parts.append(having_part)
                if qualify_part is not None:
                    qualify_parts.append(qualify_part)
            elif c.contains_over_clause:
                qualify_parts.append(c)
            elif c.contains_aggregate:
                having_parts.append(c)
            else:
                where_parts.append(c)
        if must_remain_connected and qualify_parts:
            # Disjunctive heterogeneous predicates can be pushed down to
            # qualify as long as no conditional aggregation is involved.
            if where_parts and must_group_by:
                # In theory this should only be enforced when dealing with
                # where_parts containing predicates against multi-valued
                # relationships that could affect aggregation results but this
                # is complex to infer properly.
                raise NotImplementedError(
                    "Heterogeneous disjunctive predicates against window functions are "
                    "not implemented when performing conditional aggregation."
                )
            return None, None, self
        where_node = (
            self.create(where_parts, self.connector, self.negated)
            if where_parts
            else None
        )
        having_node = (
            self.create(having_parts, self.connector, self.negated)
            if having_parts
            else None
        )
        qualify_node = (
            self.create(qualify_parts, self.connector, self.negated)
            if qualify_parts
            else None
        )
        return where_node, having_node, qualify_node

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Return the SQL version of the where clause and the params to be
        substituted in.

        Raise EmptyResultSet if this node can't match anything and
        FullResultSet if it matches everything. A node whose children yield
        no SQL at all is treated as matching everything.
        """
        result = []
        result_params = []
        # How many "full" / "empty" children decide the whole node. For AND,
        # a single empty child empties the node and every child must be full
        # for it to be full. Otherwise a single full child makes the node
        # full and every child must be empty for it to be empty.
        # NOTE(review): for native XOR a single always-true child does not
        # make the node always-true (TRUE XOR b == NOT b); confirm whether
        # XOR nodes can receive such children on native-XOR backends.
        if self.connector == AND:
            full_needed, empty_needed = len(self.children), 1
        else:
            full_needed, empty_needed = 1, len(self.children)

        if self.connector == XOR and not connection.features.supports_logical_xor:
            # Convert if the database doesn't support XOR:
            #   a XOR b XOR c XOR ...
            # to:
            #   (a OR b OR c OR ...) AND MOD(a + b + c + ..., 2) == 1
            # A chained XOR is true when an odd number of operands are true
            # (the semantics of native XOR, e.g. MySQL). With two operands
            # the sum is 0, 1 or 2, so "== 1" already means "odd" and the
            # modulo is skipped to keep that common SQL unchanged.
            lhs = self.__class__(self.children, OR)
            rhs_sum = reduce(
                operator.add,
                (Case(When(c, then=1), default=0) for c in self.children),
            )
            rhs_parity = rhs_sum % 2 if len(self.children) > 2 else rhs_sum
            rhs = Exact(1, rhs_parity)
            return self.__class__([lhs, rhs], AND, self.negated).as_sql(
                compiler, connection
            )

        for child in self.children:
            try:
                sql, params = compiler.compile(child)
            except EmptyResultSet:
                empty_needed -= 1
            except FullResultSet:
                full_needed -= 1
            else:
                if sql:
                    result.append(sql)
                    result_params.extend(params)
                else:
                    # A child compiling to empty SQL matches everything.
                    full_needed -= 1
            # Short-circuit as soon as the counts decide whether this node
            # matches nothing or everything (negation swaps the outcome).
            if empty_needed == 0:
                if self.negated:
                    raise FullResultSet
                else:
                    raise EmptyResultSet
            if full_needed == 0:
                if self.negated:
                    raise EmptyResultSet
                else:
                    raise FullResultSet
        conn = f" {self.connector} "
        sql_string = conn.join(result)
        if not sql_string:
            raise FullResultSet
        if self.negated:
            # Some backends (Oracle at least) need parentheses around the inner
            # SQL in the negated case, even if the inner SQL contains just a
            # single expression.
            sql_string = f"NOT ({sql_string})"
        elif len(result) > 1 or self.resolved:
            sql_string = f"({sql_string})"
        return sql_string, result_params

    def get_group_by_cols(self) -> list[Any]:
        """Return the GROUP BY columns contributed by every child, in order."""
        cols = []
        for child in self.children:
            cols.extend(child.get_group_by_cols())
        return cols

    def get_source_expressions(self) -> list[Any]:
        """Return a shallow copy of the children list."""
        return self.children[:]

    def set_source_expressions(self, children: list[Any]) -> None:
        """Replace the children; the replacement must have the same length."""
        assert len(children) == len(self.children)
        self.children = children

    def relabel_aliases(self, change_map: dict[str, str]) -> None:
        """
        Relabel the alias values of any children. 'change_map' is a dictionary
        mapping old (current) alias values to the new values.
        """
        for pos, child in enumerate(self.children):
            if hasattr(child, "relabel_aliases"):
                # For example another WhereNode
                child.relabel_aliases(change_map)
            elif hasattr(child, "relabeled_clone"):
                self.children[pos] = child.relabeled_clone(change_map)

    def clone(self) -> WhereNode:
        """
        Return a copy of this node. Children that define clone() are cloned;
        other children are shared with the original.
        """
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            if hasattr(child, "clone"):
                child = child.clone()
            clone.children.append(child)
        return clone

    def relabeled_clone(self, change_map: dict[str, str]) -> WhereNode:
        """Return a clone with aliases relabeled according to change_map."""
        clone = self.clone()
        clone.relabel_aliases(change_map)
        return clone

    def replace_expressions(self, replacements: dict[Any, Any]) -> WhereNode:
        """
        Return the (truthy) replacement mapped to this node, if any; otherwise
        a copy whose children have had replace_expressions() applied.
        """
        if replacement := replacements.get(self):
            return replacement
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            clone.children.append(child.replace_expressions(replacements))
        return clone

    def get_refs(self) -> set[Any]:
        """Return the union of get_refs() over all children."""
        refs = set()
        for child in self.children:
            refs |= child.get_refs()
        return refs

    @classmethod
    def _contains_aggregate(cls, obj: Any) -> bool:
        # Recurse through nested tree nodes; leaves report their own flag.
        if isinstance(obj, tree.Node):
            return any(cls._contains_aggregate(c) for c in obj.children)
        return obj.contains_aggregate

    @cached_property
    def contains_aggregate(self) -> bool:
        return self._contains_aggregate(self)

    @classmethod
    def _contains_over_clause(cls, obj: Any) -> bool:
        # Recurse through nested tree nodes; leaves report their own flag.
        if isinstance(obj, tree.Node):
            return any(cls._contains_over_clause(c) for c in obj.children)
        return obj.contains_over_clause

    @cached_property
    def contains_over_clause(self) -> bool:
        return self._contains_over_clause(self)

    @property
    def is_summary(self) -> bool:
        return any(child.is_summary for child in self.children)

    @staticmethod
    def _resolve_leaf(expr: Any, query: Any, *args: Any, **kwargs: Any) -> Any:
        """Resolve expr against query if it is resolvable; else return it as is."""
        if isinstance(expr, ResolvableExpression):
            expr = expr.resolve_expression(query, *args, **kwargs)
        return expr

    @classmethod
    def _resolve_node(cls, node: Any, query: Any, *args: Any, **kwargs: Any) -> None:
        """Resolve, in place, the lhs/rhs of node and all its descendants."""
        if hasattr(node, "children"):
            for child in node.children:
                cls._resolve_node(child, query, *args, **kwargs)
        if hasattr(node, "lhs"):
            node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
        if hasattr(node, "rhs"):
            node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)

    def resolve_expression(self, *args: Any, **kwargs: Any) -> WhereNode:
        """
        Return a resolved clone. The clone is flagged ``resolved``, which makes
        as_sql() wrap its output in parentheses.
        """
        clone = self.clone()
        clone._resolve_node(clone, *args, **kwargs)
        clone.resolved = True
        return clone

    @cached_property
    def output_field(self) -> Any:
        from plain.models.fields import BooleanField

        return BooleanField()

    @property
    def _output_field_or_none(self) -> Any:
        return self.output_field

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: list[Any]
    ) -> tuple[str, list[Any]]:
        # Wrap filters with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params

    def get_db_converters(self, connection: BaseDatabaseWrapper) -> list[Any]:
        return self.output_field.get_db_converters(connection)

    def get_lookup(self, lookup: str) -> type[Lookup] | None:
        return self.output_field.get_lookup(lookup)

    def leaves(self) -> Any:
        """Yield every non-WhereNode descendant, depth-first, left to right."""
        for child in self.children:
            if isinstance(child, WhereNode):
                yield from child.leaves()
            else:
                yield child
326
327
class NothingNode:
    """
    Leaf that can never match a row.

    Compiling it always raises EmptyResultSet.
    """

    # A constant leaf: it holds neither aggregates nor window expressions.
    contains_aggregate = contains_over_clause = False

    def as_sql(
        self,
        compiler: SQLCompiler | None = None,
        connection: BaseDatabaseWrapper | None = None,
    ) -> tuple[str, list[Any]]:
        """Signal an always-empty condition; never returns."""
        raise EmptyResultSet
340
341
class ExtraWhere:
    """
    Raw SQL fragments used verbatim and joined with AND.

    Each fragment is parenthesized before joining. ``params`` may be None;
    as_sql() always returns a fresh list of params.
    """

    # Opaque SQL: assume it contains no aggregates and no window functions.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(self, sqls: list[str], params: list[Any] | None):
        self.sqls = sqls
        self.params = params

    def as_sql(
        self,
        compiler: SQLCompiler | None = None,
        connection: BaseDatabaseWrapper | None = None,
    ) -> tuple[str, list[Any]]:
        wrapped = (f"({fragment})" for fragment in self.sqls)
        joined = " AND ".join(wrapped)
        params_copy = list(self.params) if self.params else []
        return joined, params_copy
358
359
class SubqueryConstraint:
    """
    Constrain ``columns`` of ``alias`` against the result of a subquery.

    The subquery's ordering is cleared on construction. When rendered, the
    subquery is restricted to ``targets`` and compiled through its own
    compiler's as_subquery_condition().
    """

    # Aggregates or window functions inside the subquery are of no concern
    # to the outer query, so neither is reported.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(
        self, alias: str, columns: list[str], targets: list[Any], query_object: Any
    ):
        self.alias = alias
        self.columns = columns
        self.targets = targets
        # Strip ordering from the subquery before keeping a reference to it.
        query_object.clear_ordering(clear_default=True)
        self.query_object = query_object

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        subquery = self.query_object
        subquery.set_values(self.targets)
        return subquery.get_compiler().as_subquery_condition(
            self.alias, self.columns, compiler
        )