1"""
2Code to manage the creation and SQL rendering of 'where' constraints.
3"""
4
5from __future__ import annotations
6
7import operator
8from functools import cached_property, reduce
9from typing import TYPE_CHECKING, Any
10
11from plain.models.exceptions import EmptyResultSet, FullResultSet
12from plain.models.expressions import Case, When
13from plain.models.lookups import Exact
14from plain.utils import tree
15
16if TYPE_CHECKING:
17 from plain.models.backends.base.base import BaseDatabaseWrapper
18 from plain.models.sql.compiler import SQLCompiler
19
# Connector names used to join the children of a WhereNode; WhereNode.as_sql()
# places them verbatim (surrounded by spaces) between the children's SQL.
AND, OR, XOR = "AND", "OR", "XOR"
24
25
class WhereNode(tree.Node):
    """
    An SQL WHERE clause.

    The class is tied to the Query class that created it (in order to create
    the correct SQL).

    A child is usually an expression producing boolean values. Most likely the
    expression is a Lookup instance.

    However, a child could also be any class with as_sql() and either
    relabeled_clone() method or relabel_aliases() and clone() methods and
    contains_aggregate attribute.
    """

    default = AND
    # Set on the clone returned by resolve_expression(); as_sql() then wraps
    # the rendered SQL in parentheses even when there is a single child.
    resolved = False
    conditional = True

    def split_having_qualify(
        self, negated: bool = False, must_group_by: bool = False
    ) -> tuple[WhereNode | None, WhereNode | None, WhereNode | None]:
        """
        Return three possibly None nodes: one for those parts of self that
        should be included in the WHERE clause, one for those parts of self
        that must be included in the HAVING clause, and one for those parts
        that refer to window functions.

        ``negated`` is the negation accumulated from enclosing nodes. When
        ``must_group_by`` is true, a disjunction mixing plain (WHERE-level)
        predicates with window-function predicates is rejected with
        NotImplementedError.
        """
        if not self.contains_aggregate and not self.contains_over_clause:
            return self, None, None
        in_negated = negated ^ self.negated
        # Whether or not children must be connected in the same filtering
        # clause (WHERE > HAVING > QUALIFY) to maintain logical semantics.
        must_remain_connected = (
            (in_negated and self.connector == AND)
            or (not in_negated and self.connector == OR)
            or self.connector == XOR
        )
        if (
            must_remain_connected
            and self.contains_aggregate
            and not self.contains_over_clause
        ):
            # It's much cheaper to short-circuit and stash everything in the
            # HAVING clause than split children if possible.
            return None, self, None
        where_parts = []
        having_parts = []
        qualify_parts = []
        for c in self.children:
            if hasattr(c, "split_having_qualify"):
                where_part, having_part, qualify_part = c.split_having_qualify(
                    in_negated, must_group_by
                )
                if where_part is not None:
                    where_parts.append(where_part)
                if having_part is not None:
                    having_parts.append(having_part)
                if qualify_part is not None:
                    qualify_parts.append(qualify_part)
            elif c.contains_over_clause:
                qualify_parts.append(c)
            elif c.contains_aggregate:
                having_parts.append(c)
            else:
                where_parts.append(c)
        if must_remain_connected and qualify_parts:
            # Disjunctive heterogeneous predicates can be pushed down to
            # qualify as long as no conditional aggregation is involved.
            if not where_parts or not must_group_by:
                return None, None, self
            # In theory this should only be enforced when dealing with
            # where_parts containing predicates against multi-valued
            # relationships that could affect aggregation results but this
            # is complex to infer properly.
            raise NotImplementedError(
                "Heterogeneous disjunctive predicates against window functions are "
                "not implemented when performing conditional aggregation."
            )
        where_node = (
            self.create(where_parts, self.connector, self.negated)
            if where_parts
            else None
        )
        having_node = (
            self.create(having_parts, self.connector, self.negated)
            if having_parts
            else None
        )
        qualify_node = (
            self.create(qualify_parts, self.connector, self.negated)
            if qualify_parts
            else None
        )
        return where_node, having_node, qualify_node

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Return the SQL version of the where clause and the parameters to be
        substituted in.

        Raise EmptyResultSet if this node can't match anything and
        FullResultSet if this node matches everything.
        """
        if self.connector == XOR and not connection.features.supports_logical_xor:
            # Convert if the database doesn't support XOR:
            #   a XOR b XOR c XOR ...
            # to:
            #   (a OR b OR c OR ...) AND MOD(a + b + c + ..., 2) == 1
            # An n-ary XOR is true when an odd number of operands are true
            # (the semantics of native XOR), so the count of true operands is
            # reduced modulo 2. With at most two operands the count is 0, 1
            # or 2 and "== 1" alone is already the parity test, so the
            # simpler SQL is kept in that case.
            lhs = self.__class__(self.children, OR)
            rhs_sum = reduce(
                operator.add,
                (Case(When(c, then=1), default=0) for c in self.children),
            )
            if len(self.children) > 2:
                rhs_sum = rhs_sum % 2
            rhs = Exact(1, rhs_sum)
            return self.__class__([lhs, rhs], AND, self.negated).as_sql(
                compiler, connection
            )

        # How many children must be full (match everything) / empty (match
        # nothing) for the whole node to be full / empty: a single empty
        # child empties an AND, a single full child fills an OR.
        # NOTE(review): a natively supported XOR uses OR's thresholds here, so
        # one FullResultSet child makes the node full although TRUE XOR b is
        # NOT b — confirm whether that case can arise.
        if self.connector == AND:
            full_needed, empty_needed = len(self.children), 1
        else:
            full_needed, empty_needed = 1, len(self.children)

        result = []
        result_params = []
        for child in self.children:
            try:
                sql, params = compiler.compile(child)
            except EmptyResultSet:
                empty_needed -= 1
            except FullResultSet:
                full_needed -= 1
            else:
                if sql:
                    result.append(sql)
                    result_params.extend(params)
                else:
                    # A child rendering no SQL restricts nothing.
                    full_needed -= 1
            # Stop as soon as the counts decide the outcome of the whole
            # node; negation swaps "matches nothing" and "matches everything".
            if empty_needed == 0:
                if self.negated:
                    raise FullResultSet
                else:
                    raise EmptyResultSet
            if full_needed == 0:
                if self.negated:
                    raise EmptyResultSet
                else:
                    raise FullResultSet
        conn = f" {self.connector} "
        sql_string = conn.join(result)
        if not sql_string:
            # Only reachable for a node without children, which restricts
            # nothing.
            raise FullResultSet
        if self.negated:
            # Some backends (Oracle at least) need parentheses around the inner
            # SQL in the negated case, even if the inner SQL contains just a
            # single expression.
            sql_string = f"NOT ({sql_string})"
        elif len(result) > 1 or self.resolved:
            sql_string = f"({sql_string})"
        return sql_string, result_params

    def get_group_by_cols(self) -> list[Any]:
        """Return the concatenated GROUP BY columns of all children."""
        cols = []
        for child in self.children:
            cols.extend(child.get_group_by_cols())
        return cols

    def get_source_expressions(self) -> list[Any]:
        """Return a shallow copy of the children list."""
        return self.children[:]

    def set_source_expressions(self, children: list[Any]) -> None:
        """Replace the children; the count must match the current one."""
        assert len(children) == len(self.children)
        self.children = children

    def relabel_aliases(self, change_map: dict[str, str]) -> None:
        """
        Relabel the alias values of any children. 'change_map' is a dictionary
        mapping old (current) alias values to the new values.
        """
        for pos, child in enumerate(self.children):
            if hasattr(child, "relabel_aliases"):
                # For example another WhereNode
                child.relabel_aliases(change_map)
            elif hasattr(child, "relabeled_clone"):
                self.children[pos] = child.relabeled_clone(change_map)

    def clone(self) -> WhereNode:
        """
        Return a copy of this node. Children that have a clone() method are
        cloned; other children are shared with the original.
        """
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            if hasattr(child, "clone"):
                child = child.clone()
            clone.children.append(child)
        return clone

    def relabeled_clone(self, change_map: dict[str, str]) -> WhereNode:
        """Return a clone whose aliases are relabeled through change_map."""
        clone = self.clone()
        clone.relabel_aliases(change_map)
        return clone

    def replace_expressions(self, replacements: dict[Any, Any]) -> WhereNode:
        """
        Return replacements[self] when present (and truthy); otherwise a new
        node whose children are each passed through replace_expressions().
        """
        if replacement := replacements.get(self):
            return replacement
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            clone.children.append(child.replace_expressions(replacements))
        return clone

    def get_refs(self) -> set[Any]:
        """Return the union of get_refs() over all children."""
        refs = set()
        for child in self.children:
            refs |= child.get_refs()
        return refs

    @classmethod
    def _contains_aggregate(cls, obj: Any) -> bool:
        # Recurse through tree nodes; leaves report their own flag.
        if isinstance(obj, tree.Node):
            return any(cls._contains_aggregate(c) for c in obj.children)
        return obj.contains_aggregate

    @cached_property
    def contains_aggregate(self) -> bool:
        # Cached per instance: computed once from the children present at
        # first access.
        return self._contains_aggregate(self)

    @classmethod
    def _contains_over_clause(cls, obj: Any) -> bool:
        # Recurse through tree nodes; leaves report their own flag.
        if isinstance(obj, tree.Node):
            return any(cls._contains_over_clause(c) for c in obj.children)
        return obj.contains_over_clause

    @cached_property
    def contains_over_clause(self) -> bool:
        # Cached per instance, like contains_aggregate.
        return self._contains_over_clause(self)

    @property
    def is_summary(self) -> bool:
        """True if any child reports is_summary."""
        return any(child.is_summary for child in self.children)

    @staticmethod
    def _resolve_leaf(expr: Any, query: Any, *args: Any, **kwargs: Any) -> Any:
        # Resolve expressions against the query; plain values pass through.
        if hasattr(expr, "resolve_expression"):
            expr = expr.resolve_expression(query, *args, **kwargs)
        return expr

    @classmethod
    def _resolve_node(cls, node: Any, query: Any, *args: Any, **kwargs: Any) -> None:
        # Walk the tree in place, resolving each node's lhs/rhs if present.
        if hasattr(node, "children"):
            for child in node.children:
                cls._resolve_node(child, query, *args, **kwargs)
        if hasattr(node, "lhs"):
            node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
        if hasattr(node, "rhs"):
            node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)

    def resolve_expression(self, *args: Any, **kwargs: Any) -> WhereNode:
        """Return a clone with lhs/rhs values resolved, marked as resolved."""
        clone = self.clone()
        clone._resolve_node(clone, *args, **kwargs)
        clone.resolved = True
        return clone

    @cached_property
    def output_field(self) -> Any:
        from plain.models.fields import BooleanField

        return BooleanField()

    @property
    def _output_field_or_none(self) -> Any:
        return self.output_field

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: list[Any]
    ) -> tuple[str, list[Any]]:
        # Wrap filters with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params

    def get_db_converters(self, connection: BaseDatabaseWrapper) -> list[Any]:
        return self.output_field.get_db_converters(connection)

    def get_lookup(self, lookup: str) -> Any:
        return self.output_field.get_lookup(lookup)

    def leaves(self) -> Any:
        """Yield every non-WhereNode descendant, depth-first."""
        for child in self.children:
            if isinstance(child, WhereNode):
                yield from child.leaves()
            else:
                yield child
325
326
class NothingNode:
    """A where-clause child that can never match any row."""

    # Flags read by WhereNode when deciding between WHERE, HAVING and QUALIFY.
    contains_over_clause = False
    contains_aggregate = False

    def as_sql(
        self,
        compiler: SQLCompiler | None = None,
        connection: BaseDatabaseWrapper | None = None,
    ) -> tuple[str, list[Any]]:
        """Never produce SQL; always signal an empty result instead."""
        raise EmptyResultSet()
339
340
class ExtraWhere:
    """
    Raw SQL fragments used verbatim in a where clause.

    Every fragment in ``sqls`` is wrapped in parentheses and the fragments are
    joined with AND; ``params`` (possibly None) are returned as a new list.
    """

    # The SQL is opaque to the ORM, so assume it uses no aggregates or
    # window functions.
    contains_over_clause = False
    contains_aggregate = False

    def __init__(self, sqls: list[str], params: list[Any] | None):
        self.sqls = sqls
        self.params = params

    def as_sql(
        self,
        compiler: SQLCompiler | None = None,
        connection: BaseDatabaseWrapper | None = None,
    ) -> tuple[str, list[Any]]:
        joined = " AND ".join(f"({fragment})" for fragment in self.sqls)
        source_params = self.params or ()
        return joined, list(source_params)
357
358
class SubqueryConstraint:
    """
    Where-clause child whose SQL comes from a subquery's compiler.

    ``query_object`` has its ordering cleared on construction; at render time
    it is restricted to ``targets`` and compiled via
    ``as_subquery_condition(alias, columns, compiler)``.
    """

    # Aggregates or window functions inside the subquery are of no concern
    # to the outer query, so neither is reported.
    contains_over_clause = False
    contains_aggregate = False

    def __init__(
        self, alias: str, columns: list[str], targets: list[Any], query_object: Any
    ):
        self.alias, self.columns, self.targets = alias, columns, targets
        # Drop any ordering, including the default one, from the subquery.
        query_object.clear_ordering(clear_default=True)
        self.query_object = query_object

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        subquery = self.query_object
        # Select the target values on the subquery before compiling it.
        subquery.set_values(self.targets)
        return subquery.get_compiler().as_subquery_condition(
            self.alias, self.columns, compiler
        )