1"""
2Code to manage the creation and SQL rendering of 'where' constraints.
3"""
4
5from __future__ import annotations
6
7import operator
8from functools import cached_property, reduce
9from typing import TYPE_CHECKING, Any
10
11from plain.models.exceptions import EmptyResultSet, FullResultSet
12from plain.models.expressions import Case, ResolvableExpression, When
13from plain.models.lookups import Exact
14from plain.utils import tree
15
16if TYPE_CHECKING:
17 from plain.models.lookups import Lookup
18 from plain.models.postgres.wrapper import DatabaseWrapper
19 from plain.models.sql.compiler import SQLCompiler
20
# Connector keywords a WhereNode can use to join its children.
AND, OR, XOR = "AND", "OR", "XOR"
25
26
class WhereNode(tree.Node):
    """
    An SQL WHERE clause.

    The class is tied to the Query class that created it (in order to create
    the correct SQL).

    A child is usually an expression producing boolean values. Most likely the
    expression is a Lookup instance.

    However, a child could also be any class with as_sql() and either
    relabeled_clone() method or relabel_aliases() and clone() methods and
    contains_aggregate attribute.
    """

    default = AND
    # Flipped to True on the clone returned by resolve_expression(); as_sql()
    # then always wraps the rendered SQL in parentheses.
    resolved = False
    conditional = True

    def split_having_qualify(
        self, negated: bool = False, must_group_by: bool = False
    ) -> tuple[WhereNode | None, WhereNode | None, WhereNode | None]:
        """
        Return three possibly None nodes: one for those parts of self that
        should be included in the WHERE clause, one for those parts of self
        that must be included in the HAVING clause, and one for those parts
        that refer to window functions.

        `negated` is the negation state inherited from the enclosing node; it
        is combined with this node's own `negated` flag and passed down to
        child nodes. `must_group_by` only matters when window-function
        predicates are mixed with plain WHERE predicates in a group that must
        stay connected: in that case NotImplementedError is raised instead of
        pushing the whole node down to QUALIFY.
        """
        if not self.contains_aggregate and not self.contains_over_clause:
            return self, None, None
        in_negated = negated ^ self.negated
        # Whether or not children must be connected in the same filtering
        # clause (WHERE > HAVING > QUALIFY) to maintain logical semantics.
        must_remain_connected = (
            (in_negated and self.connector == AND)
            or (not in_negated and self.connector == OR)
            or self.connector == XOR
        )
        if (
            must_remain_connected
            and self.contains_aggregate
            and not self.contains_over_clause
        ):
            # It's much cheaper to short-circuit and stash everything in the
            # HAVING clause than split children if possible.
            return None, self, None
        where_parts = []
        having_parts = []
        qualify_parts = []
        for c in self.children:
            if hasattr(c, "split_having_qualify"):
                # Nested node: let it split itself and collect whatever
                # pieces it produced.
                where_part, having_part, qualify_part = c.split_having_qualify(
                    in_negated, must_group_by
                )
                if where_part is not None:
                    where_parts.append(where_part)
                if having_part is not None:
                    having_parts.append(having_part)
                if qualify_part is not None:
                    qualify_parts.append(qualify_part)
            elif c.contains_over_clause:
                qualify_parts.append(c)
            elif c.contains_aggregate:
                having_parts.append(c)
            else:
                where_parts.append(c)
        if must_remain_connected and qualify_parts:
            # Disjunctive heterogeneous predicates can be pushed down to
            # qualify as long as no conditional aggregation is involved.
            if not where_parts or not must_group_by:
                return None, None, self
            # In theory this should only be enforced when dealing with
            # where_parts containing predicates against multi-valued
            # relationships that could affect aggregation results but this
            # is complex to infer properly.
            raise NotImplementedError(
                "Heterogeneous disjunctive predicates against window functions are "
                "not implemented when performing conditional aggregation."
            )
        where_node = (
            self.create(where_parts, self.connector, self.negated)
            if where_parts
            else None
        )
        having_node = (
            self.create(having_parts, self.connector, self.negated)
            if having_parts
            else None
        )
        qualify_node = (
            self.create(qualify_parts, self.connector, self.negated)
            if qualify_parts
            else None
        )
        return where_node, having_node, qualify_node

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Return the SQL version of the where clause and the values to be
        substituted in, as a ``(sql, params)`` tuple.

        Raise FullResultSet if this node matches everything (which includes
        the case where no child produced any SQL) and EmptyResultSet if this
        node can't match anything.
        """
        if self.connector == XOR:
            # PostgreSQL doesn't have a native XOR operator, so convert:
            #   a XOR b XOR c XOR ...
            # to:
            #   (a OR b OR c OR ...) AND (a + b + c + ...) % 2 == 1
            # An n-ary XOR is true when an odd number of its operands are
            # true, not when exactly one of them is.
            lhs = self.__class__(self.children, OR)
            rhs_sum = reduce(
                operator.add,
                (Case(When(c, then=1), default=0) for c in self.children),
            )
            if len(self.children) > 2:
                # With at most two operands the sum equals 1 exactly when the
                # parity is odd, so the modulo is only needed beyond that.
                rhs_sum = rhs_sum % 2
            rhs = Exact(1, rhs_sum)
            return self.__class__([lhs, rhs], AND, self.negated).as_sql(
                compiler, connection
            )

        result = []
        result_params = []
        # Number of children that must match everything (full) or nothing
        # (empty) before the node as a whole does: AND is empty as soon as
        # one child is empty and full only if all are full; OR is the
        # opposite.
        if self.connector == AND:
            full_needed, empty_needed = len(self.children), 1
        else:
            full_needed, empty_needed = 1, len(self.children)

        for child in self.children:
            try:
                sql, params = compiler.compile(child)
            except EmptyResultSet:
                empty_needed -= 1
            except FullResultSet:
                full_needed -= 1
            else:
                if sql:
                    result.append(sql)
                    result_params.extend(params)
                else:
                    # A child rendering to no SQL places no restriction.
                    full_needed -= 1
            # After each child, use the counts to check whether this node
            # already matches nothing or everything; negation swaps the two
            # outcomes.
            if empty_needed == 0:
                if self.negated:
                    raise FullResultSet
                else:
                    raise EmptyResultSet
            if full_needed == 0:
                if self.negated:
                    raise EmptyResultSet
                else:
                    raise FullResultSet
        conn = f" {self.connector} "
        sql_string = conn.join(result)
        if not sql_string:
            raise FullResultSet
        if self.negated:
            sql_string = f"NOT ({sql_string})"
        elif len(result) > 1 or self.resolved:
            sql_string = f"({sql_string})"
        return sql_string, result_params

    def get_group_by_cols(self) -> list[Any]:
        """Return the concatenated get_group_by_cols() of every child."""
        return [col for child in self.children for col in child.get_group_by_cols()]

    def get_source_expressions(self) -> list[Any]:
        """Return a shallow copy of the list of children."""
        return self.children[:]

    def set_source_expressions(self, children: list[Any]) -> None:
        """Replace the children with `children`, which must be the same length."""
        assert len(children) == len(self.children)
        self.children = children

    def relabel_aliases(self, change_map: dict[str, str]) -> None:
        """
        Relabel the alias values of any children. 'change_map' is a dictionary
        mapping old (current) alias values to the new values.

        Children exposing relabel_aliases() are updated in place; children
        exposing only relabeled_clone() are replaced by the relabeled clone.
        Any other child is left untouched.
        """
        for pos, child in enumerate(self.children):
            if hasattr(child, "relabel_aliases"):
                # For example another WhereNode
                child.relabel_aliases(change_map)
            elif hasattr(child, "relabeled_clone"):
                self.children[pos] = child.relabeled_clone(change_map)

    def clone(self) -> WhereNode:
        """
        Return a new node with the same connector and negation. Children that
        have a clone() method are cloned; the others are shared with self.
        """
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            if hasattr(child, "clone"):
                child = child.clone()
            clone.children.append(child)
        return clone

    def relabeled_clone(self, change_map: dict[str, str]) -> WhereNode:
        """Return a clone of this node with `change_map` applied to its aliases."""
        clone = self.clone()
        clone.relabel_aliases(change_map)
        return clone

    def replace_expressions(self, replacements: dict[Any, Any]) -> WhereNode:
        """
        Return the replacement registered for this node in `replacements` if
        there is a truthy one; otherwise return a new node whose children are
        the result of calling replace_expressions() on each child.
        """
        if replacement := replacements.get(self):
            return replacement
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            clone.children.append(child.replace_expressions(replacements))
        return clone

    def get_refs(self) -> set[Any]:
        """Return the union of the get_refs() sets of every child."""
        refs = set()
        for child in self.children:
            refs |= child.get_refs()
        return refs

    @classmethod
    def _contains_aggregate(cls, obj: Any) -> bool:
        """Recurse through tree nodes; leaves report via `contains_aggregate`."""
        if isinstance(obj, tree.Node):
            return any(cls._contains_aggregate(c) for c in obj.children)
        return obj.contains_aggregate

    @cached_property
    def contains_aggregate(self) -> bool:
        # Cached on first access: later changes to the children are not
        # reflected on this instance.
        return self._contains_aggregate(self)

    @classmethod
    def _contains_over_clause(cls, obj: Any) -> bool:
        """Recurse through tree nodes; leaves report via `contains_over_clause`."""
        if isinstance(obj, tree.Node):
            return any(cls._contains_over_clause(c) for c in obj.children)
        return obj.contains_over_clause

    @cached_property
    def contains_over_clause(self) -> bool:
        # Cached on first access: later changes to the children are not
        # reflected on this instance.
        return self._contains_over_clause(self)

    @property
    def is_summary(self) -> bool:
        """True if any direct child has a truthy `is_summary`."""
        return any(child.is_summary for child in self.children)

    @staticmethod
    def _resolve_leaf(expr: Any, query: Any, *args: Any, **kwargs: Any) -> Any:
        """Resolve `expr` against `query` if it is resolvable, else return it as is."""
        if isinstance(expr, ResolvableExpression):
            expr = expr.resolve_expression(query, *args, **kwargs)
        return expr

    @classmethod
    def _resolve_node(cls, node: Any, query: Any, *args: Any, **kwargs: Any) -> None:
        """
        Resolve `node` in place: recurse into its children, if any, then
        resolve its `lhs` and `rhs` attributes when present.
        """
        if hasattr(node, "children"):
            for child in node.children:
                cls._resolve_node(child, query, *args, **kwargs)
        if hasattr(node, "lhs"):
            node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
        if hasattr(node, "rhs"):
            node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)

    def resolve_expression(self, *args: Any, **kwargs: Any) -> WhereNode:
        """Return a clone of this node, resolved and flagged as `resolved`."""
        clone = self.clone()
        clone._resolve_node(clone, *args, **kwargs)
        clone.resolved = True
        return clone

    @cached_property
    def output_field(self) -> Any:
        # Imported here rather than at module level -- presumably to avoid a
        # circular import; confirm before moving it.
        from plain.models.fields import BooleanField

        return BooleanField()

    @property
    def _output_field_or_none(self) -> Any:
        return self.output_field

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: list[Any]
    ) -> tuple[str, list[Any]]:
        """Return `sql` and `params` unchanged."""
        # Boolean expressions work directly in SELECT
        return sql, params

    def get_db_converters(self, connection: DatabaseWrapper) -> list[Any]:
        """Delegate to the output field's converters."""
        return self.output_field.get_db_converters(connection)

    def get_lookup(self, lookup: str) -> type[Lookup] | None:
        """Delegate lookup resolution to the output field."""
        return self.output_field.get_lookup(lookup)

    def leaves(self) -> Any:
        """Yield every descendant that is not itself a WhereNode, depth-first."""
        for child in self.children:
            if isinstance(child, WhereNode):
                yield from child.leaves()
            else:
                yield child
319
320
class NothingNode:
    """Where-clause child that can never match anything."""

    # Involves neither an aggregate nor a window function.
    contains_aggregate = contains_over_clause = False

    def as_sql(
        self,
        compiler: SQLCompiler | None = None,
        connection: DatabaseWrapper | None = None,
    ) -> tuple[str, list[Any]]:
        """Never return SQL: unconditionally raise EmptyResultSet."""
        raise EmptyResultSet
333
334
class ExtraWhere:
    """Raw SQL condition strings rendered as a parenthesized AND chain."""

    # The SQL strings are opaque here, so they are assumed to contain neither
    # aggregates nor window functions.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(self, sqls: list[str], params: list[Any] | None):
        self.sqls = sqls
        self.params = params

    def as_sql(
        self,
        compiler: SQLCompiler | None = None,
        connection: DatabaseWrapper | None = None,
    ) -> tuple[str, list[Any]]:
        """
        Return every SQL string wrapped in parentheses and joined with AND,
        together with a fresh list of the parameters (empty if none were set).
        """
        wrapped = (f"({fragment})" for fragment in self.sqls)
        bound = list(self.params) if self.params else []
        return " AND ".join(wrapped), bound
351
352
class SubqueryConstraint:
    """
    Where-clause child whose SQL comes from the wrapped query's compiler via
    as_subquery_condition().
    """

    # Whatever aggregates or window functions the subquery may use are of no
    # concern to the outer query.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(
        self, alias: str, columns: list[str], targets: list[Any], query_object: Any
    ):
        self.alias, self.columns, self.targets = alias, columns, targets
        # Note: this strips the ordering from the query object passed in.
        query_object.clear_ordering(clear_default=True)
        self.query_object = query_object

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Restrict the wrapped query to the target values, then let its own
        compiler build the condition for `alias`/`columns` against the outer
        `compiler`.
        """
        subquery = self.query_object
        subquery.set_values(self.targets)
        return subquery.get_compiler().as_subquery_condition(
            self.alias, self.columns, compiler
        )