1"""
  2Code to manage the creation and SQL rendering of 'where' constraints.
  3"""
  4
  5from __future__ import annotations
  6
  7import operator
  8from functools import cached_property, reduce
  9from typing import TYPE_CHECKING, Any
 10
 11from plain.models.exceptions import EmptyResultSet, FullResultSet
 12from plain.models.expressions import Case, ResolvableExpression, When
 13from plain.models.lookups import Exact
 14from plain.utils import tree
 15
 16if TYPE_CHECKING:
 17    from plain.models.backends.base.base import BaseDatabaseWrapper
 18    from plain.models.lookups import Lookup
 19    from plain.models.sql.compiler import SQLCompiler
 20
 21# Connection types
 22AND = "AND"
 23OR = "OR"
 24XOR = "XOR"
 25
 26
 27class WhereNode(tree.Node):
 28    """
 29    An SQL WHERE clause.
 30
 31    The class is tied to the Query class that created it (in order to create
 32    the correct SQL).
 33
 34    A child is usually an expression producing boolean values. Most likely the
 35    expression is a Lookup instance.
 36
 37    However, a child could also be any class with as_sql() and either
 38    relabeled_clone() method or relabel_aliases() and clone() methods and
 39    contains_aggregate attribute.
 40    """
 41
 42    default = AND
 43    resolved = False
 44    conditional = True
 45
 46    def split_having_qualify(
 47        self, negated: bool = False, must_group_by: bool = False
 48    ) -> tuple[WhereNode | None, WhereNode | None, WhereNode | None]:
 49        """
 50        Return three possibly None nodes: one for those parts of self that
 51        should be included in the WHERE clause, one for those parts of self
 52        that must be included in the HAVING clause, and one for those parts
 53        that refer to window functions.
 54        """
 55        if not self.contains_aggregate and not self.contains_over_clause:
 56            return self, None, None
 57        in_negated = negated ^ self.negated
 58        # Whether or not children must be connected in the same filtering
 59        # clause (WHERE > HAVING > QUALIFY) to maintain logical semantic.
 60        must_remain_connected = (
 61            (in_negated and self.connector == AND)
 62            or (not in_negated and self.connector == OR)
 63            or self.connector == XOR
 64        )
 65        if (
 66            must_remain_connected
 67            and self.contains_aggregate
 68            and not self.contains_over_clause
 69        ):
 70            # It's must cheaper to short-circuit and stash everything in the
 71            # HAVING clause than split children if possible.
 72            return None, self, None
 73        where_parts = []
 74        having_parts = []
 75        qualify_parts = []
 76        for c in self.children:
 77            if hasattr(c, "split_having_qualify"):
 78                where_part, having_part, qualify_part = c.split_having_qualify(
 79                    in_negated, must_group_by
 80                )
 81                if where_part is not None:
 82                    where_parts.append(where_part)
 83                if having_part is not None:
 84                    having_parts.append(having_part)
 85                if qualify_part is not None:
 86                    qualify_parts.append(qualify_part)
 87            elif c.contains_over_clause:
 88                qualify_parts.append(c)
 89            elif c.contains_aggregate:
 90                having_parts.append(c)
 91            else:
 92                where_parts.append(c)
 93        if must_remain_connected and qualify_parts:
 94            # Disjunctive heterogeneous predicates can be pushed down to
 95            # qualify as long as no conditional aggregation is involved.
 96            if not where_parts or (where_parts and not must_group_by):
 97                return None, None, self
 98            elif where_parts:
 99                # In theory this should only be enforced when dealing with
100                # where_parts containing predicates against multi-valued
101                # relationships that could affect aggregation results but this
102                # is complex to infer properly.
103                raise NotImplementedError(
104                    "Heterogeneous disjunctive predicates against window functions are "
105                    "not implemented when performing conditional aggregation."
106                )
107        where_node = (
108            self.create(where_parts, self.connector, self.negated)
109            if where_parts
110            else None
111        )
112        having_node = (
113            self.create(having_parts, self.connector, self.negated)
114            if having_parts
115            else None
116        )
117        qualify_node = (
118            self.create(qualify_parts, self.connector, self.negated)
119            if qualify_parts
120            else None
121        )
122        return where_node, having_node, qualify_node
123
124    def as_sql(
125        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
126    ) -> tuple[str, list[Any]]:
127        """
128        Return the SQL version of the where clause and the value to be
129        substituted in. Return '', [] if this node matches everything,
130        None, [] if this node is empty, and raise EmptyResultSet if this
131        node can't match anything.
132        """
133        result = []
134        result_params = []
135        if self.connector == AND:
136            full_needed, empty_needed = len(self.children), 1
137        else:
138            full_needed, empty_needed = 1, len(self.children)
139
140        if self.connector == XOR and not connection.features.supports_logical_xor:
141            # Convert if the database doesn't support XOR:
142            #   a XOR b XOR c XOR ...
143            # to:
144            #   (a OR b OR c OR ...) AND (a + b + c + ...) == 1
145            lhs = self.__class__(self.children, OR)
146            rhs_sum = reduce(
147                operator.add,
148                (Case(When(c, then=1), default=0) for c in self.children),
149            )
150            rhs = Exact(1, rhs_sum)
151            return self.__class__([lhs, rhs], AND, self.negated).as_sql(
152                compiler, connection
153            )
154
155        for child in self.children:
156            try:
157                sql, params = compiler.compile(child)
158            except EmptyResultSet:
159                empty_needed -= 1
160            except FullResultSet:
161                full_needed -= 1
162            else:
163                if sql:
164                    result.append(sql)
165                    result_params.extend(params)
166                else:
167                    full_needed -= 1
168            # Check if this node matches nothing or everything.
169            # First check the amount of full nodes and empty nodes
170            # to make this node empty/full.
171            # Now, check if this node is full/empty using the
172            # counts.
173            if empty_needed == 0:
174                if self.negated:
175                    raise FullResultSet
176                else:
177                    raise EmptyResultSet
178            if full_needed == 0:
179                if self.negated:
180                    raise EmptyResultSet
181                else:
182                    raise FullResultSet
183        conn = f" {self.connector} "
184        sql_string = conn.join(result)
185        if not sql_string:
186            raise FullResultSet
187        if self.negated:
188            # Some backends (Oracle at least) need parentheses around the inner
189            # SQL in the negated case, even if the inner SQL contains just a
190            # single expression.
191            sql_string = f"NOT ({sql_string})"
192        elif len(result) > 1 or self.resolved:
193            sql_string = f"({sql_string})"
194        return sql_string, result_params
195
196    def get_group_by_cols(self) -> list[Any]:
197        cols = []
198        for child in self.children:
199            cols.extend(child.get_group_by_cols())
200        return cols
201
202    def get_source_expressions(self) -> list[Any]:
203        return self.children[:]
204
205    def set_source_expressions(self, children: list[Any]) -> None:
206        assert len(children) == len(self.children)
207        self.children = children
208
209    def relabel_aliases(self, change_map: dict[str, str]) -> None:
210        """
211        Relabel the alias values of any children. 'change_map' is a dictionary
212        mapping old (current) alias values to the new values.
213        """
214        for pos, child in enumerate(self.children):
215            if hasattr(child, "relabel_aliases"):
216                # For example another WhereNode
217                child.relabel_aliases(change_map)
218            elif hasattr(child, "relabeled_clone"):
219                self.children[pos] = child.relabeled_clone(change_map)
220
221    def clone(self) -> WhereNode:
222        clone = self.create(connector=self.connector, negated=self.negated)
223        for child in self.children:
224            if hasattr(child, "clone"):
225                child = child.clone()
226            clone.children.append(child)
227        return clone
228
229    def relabeled_clone(self, change_map: dict[str, str]) -> WhereNode:
230        clone = self.clone()
231        clone.relabel_aliases(change_map)
232        return clone
233
234    def replace_expressions(self, replacements: dict[Any, Any]) -> WhereNode:
235        if replacement := replacements.get(self):
236            return replacement
237        clone = self.create(connector=self.connector, negated=self.negated)
238        for child in self.children:
239            clone.children.append(child.replace_expressions(replacements))
240        return clone
241
242    def get_refs(self) -> set[Any]:
243        refs = set()
244        for child in self.children:
245            refs |= child.get_refs()
246        return refs
247
248    @classmethod
249    def _contains_aggregate(cls, obj: Any) -> bool:
250        if isinstance(obj, tree.Node):
251            return any(cls._contains_aggregate(c) for c in obj.children)
252        return obj.contains_aggregate
253
254    @cached_property
255    def contains_aggregate(self) -> bool:
256        return self._contains_aggregate(self)
257
258    @classmethod
259    def _contains_over_clause(cls, obj: Any) -> bool:
260        if isinstance(obj, tree.Node):
261            return any(cls._contains_over_clause(c) for c in obj.children)
262        return obj.contains_over_clause
263
264    @cached_property
265    def contains_over_clause(self) -> bool:
266        return self._contains_over_clause(self)
267
268    @property
269    def is_summary(self) -> bool:
270        return any(child.is_summary for child in self.children)
271
272    @staticmethod
273    def _resolve_leaf(expr: Any, query: Any, *args: Any, **kwargs: Any) -> Any:
274        if isinstance(expr, ResolvableExpression):
275            expr = expr.resolve_expression(query, *args, **kwargs)
276        return expr
277
278    @classmethod
279    def _resolve_node(cls, node: Any, query: Any, *args: Any, **kwargs: Any) -> None:
280        if hasattr(node, "children"):
281            for child in node.children:
282                cls._resolve_node(child, query, *args, **kwargs)
283        if hasattr(node, "lhs"):
284            node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
285        if hasattr(node, "rhs"):
286            node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)
287
288    def resolve_expression(self, *args: Any, **kwargs: Any) -> WhereNode:
289        clone = self.clone()
290        clone._resolve_node(clone, *args, **kwargs)
291        clone.resolved = True
292        return clone
293
294    @cached_property
295    def output_field(self) -> Any:
296        from plain.models.fields import BooleanField
297
298        return BooleanField()
299
300    @property
301    def _output_field_or_none(self) -> Any:
302        return self.output_field
303
304    def select_format(
305        self, compiler: SQLCompiler, sql: str, params: list[Any]
306    ) -> tuple[str, list[Any]]:
307        # Wrap filters with a CASE WHEN expression if a database backend
308        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
309        # BY list.
310        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
311            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
312        return sql, params
313
314    def get_db_converters(self, connection: BaseDatabaseWrapper) -> list[Any]:
315        return self.output_field.get_db_converters(connection)
316
317    def get_lookup(self, lookup: str) -> type[Lookup] | None:
318        return self.output_field.get_lookup(lookup)
319
320    def leaves(self) -> Any:
321        for child in self.children:
322            if isinstance(child, WhereNode):
323                yield from child.leaves()
324            else:
325                yield child
326
327
class NothingNode:
    """
    A leaf condition that can never match any row.

    Compiling it always raises EmptyResultSet, whatever compiler or
    connection is passed.
    """

    # Nothing inside, so certainly no window expressions or aggregates.
    contains_over_clause = False
    contains_aggregate = False

    def as_sql(
        self,
        compiler: SQLCompiler | None = None,
        connection: BaseDatabaseWrapper | None = None,
    ) -> tuple[str, list[Any]]:
        """Unconditionally signal an empty result; arguments are ignored."""
        raise EmptyResultSet()
340
341
class ExtraWhere:
    """
    Raw SQL fragments combined into one condition with AND.

    Each fragment is wrapped in parentheses before joining, and ``params``
    (when given) are returned as a new list in their original order.
    """

    # The raw SQL is opaque to us, so assume it uses neither aggregates nor
    # window functions.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(self, sqls: list[str], params: list[Any] | None):
        self.sqls = sqls
        self.params = params

    def as_sql(
        self,
        compiler: SQLCompiler | None = None,
        connection: BaseDatabaseWrapper | None = None,
    ) -> tuple[str, list[Any]]:
        """Return the parenthesized fragments joined by AND and a params list."""
        clause = " AND ".join(f"({fragment})" for fragment in self.sqls)
        bound = list(self.params) if self.params else []
        return clause, bound
358
359
class SubqueryConstraint:
    """
    Wrap ``query_object`` as a condition on ``columns`` of table ``alias``.

    The subquery's ordering is cleared on construction. The SQL itself is
    produced by the subquery compiler's as_subquery_condition().
    """

    # Aggregates or windows used inside the subquery don't concern the outer
    # query, so both flags are reported as False.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(
        self, alias: str, columns: list[str], targets: list[Any], query_object: Any
    ):
        # Drop any ordering, including the default one, from the subquery.
        query_object.clear_ordering(clear_default=True)
        self.query_object = query_object
        self.alias = alias
        self.columns = columns
        self.targets = targets

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Select ``targets`` in the subquery, then compile it as a condition."""
        subquery = self.query_object
        subquery.set_values(self.targets)
        subquery_compiler = subquery.get_compiler()
        return subquery_compiler.as_subquery_condition(
            self.alias, self.columns, compiler
        )