1"""
  2Code to manage the creation and SQL rendering of 'where' constraints.
  3"""
  4
  5from __future__ import annotations
  6
  7import operator
  8from functools import cached_property, reduce
  9from typing import TYPE_CHECKING, Any
 10
 11from plain.models.exceptions import EmptyResultSet, FullResultSet
 12from plain.models.expressions import Case, ResolvableExpression, When
 13from plain.models.lookups import Exact
 14from plain.utils import tree
 15
 16if TYPE_CHECKING:
 17    from plain.models.lookups import Lookup
 18    from plain.models.postgres.wrapper import DatabaseWrapper
 19    from plain.models.sql.compiler import SQLCompiler
 20
# Connection types: the boolean connectors a WhereNode can use to combine its
# children. WhereNode.as_sql() joins the children's SQL with the connector
# string itself, except for XOR, which it rewrites in terms of AND/OR first.
AND = "AND"
OR = "OR"
XOR = "XOR"
 25
 26
 27class WhereNode(tree.Node):
 28    """
 29    An SQL WHERE clause.
 30
 31    The class is tied to the Query class that created it (in order to create
 32    the correct SQL).
 33
 34    A child is usually an expression producing boolean values. Most likely the
 35    expression is a Lookup instance.
 36
    However, a child can be any object that provides as_sql(), either a
    relabeled_clone() method or both relabel_aliases() and clone() methods,
    and the contains_aggregate and contains_over_clause attributes.
 40    """
 41
    # Default connector for this node type; presumably what tree.Node falls
    # back to when no connector is passed -- confirm against tree.Node.
    default = AND
    # Flipped to True on the clone returned by resolve_expression(); as_sql()
    # then wraps a non-negated clause in parentheses even when it consists of
    # a single condition.
    resolved = False
    # NOTE(review): presumably marks the node as a boolean (conditional)
    # expression for code outside this module; it is not read in the code
    # visible here.
    conditional = True
 45
 46    def split_having_qualify(
 47        self, negated: bool = False, must_group_by: bool = False
 48    ) -> tuple[WhereNode | None, WhereNode | None, WhereNode | None]:
 49        """
 50        Return three possibly None nodes: one for those parts of self that
 51        should be included in the WHERE clause, one for those parts of self
 52        that must be included in the HAVING clause, and one for those parts
 53        that refer to window functions.
 54        """
 55        if not self.contains_aggregate and not self.contains_over_clause:
 56            return self, None, None
 57        in_negated = negated ^ self.negated
 58        # Whether or not children must be connected in the same filtering
 59        # clause (WHERE > HAVING > QUALIFY) to maintain logical semantic.
 60        must_remain_connected = (
 61            (in_negated and self.connector == AND)
 62            or (not in_negated and self.connector == OR)
 63            or self.connector == XOR
 64        )
 65        if (
 66            must_remain_connected
 67            and self.contains_aggregate
 68            and not self.contains_over_clause
 69        ):
 70            # It's must cheaper to short-circuit and stash everything in the
 71            # HAVING clause than split children if possible.
 72            return None, self, None
 73        where_parts = []
 74        having_parts = []
 75        qualify_parts = []
 76        for c in self.children:
 77            if hasattr(c, "split_having_qualify"):
 78                where_part, having_part, qualify_part = c.split_having_qualify(
 79                    in_negated, must_group_by
 80                )
 81                if where_part is not None:
 82                    where_parts.append(where_part)
 83                if having_part is not None:
 84                    having_parts.append(having_part)
 85                if qualify_part is not None:
 86                    qualify_parts.append(qualify_part)
 87            elif c.contains_over_clause:
 88                qualify_parts.append(c)
 89            elif c.contains_aggregate:
 90                having_parts.append(c)
 91            else:
 92                where_parts.append(c)
 93        if must_remain_connected and qualify_parts:
 94            # Disjunctive heterogeneous predicates can be pushed down to
 95            # qualify as long as no conditional aggregation is involved.
 96            if not where_parts or (where_parts and not must_group_by):
 97                return None, None, self
 98            elif where_parts:
 99                # In theory this should only be enforced when dealing with
100                # where_parts containing predicates against multi-valued
101                # relationships that could affect aggregation results but this
102                # is complex to infer properly.
103                raise NotImplementedError(
104                    "Heterogeneous disjunctive predicates against window functions are "
105                    "not implemented when performing conditional aggregation."
106                )
107        where_node = (
108            self.create(where_parts, self.connector, self.negated)
109            if where_parts
110            else None
111        )
112        having_node = (
113            self.create(having_parts, self.connector, self.negated)
114            if having_parts
115            else None
116        )
117        qualify_node = (
118            self.create(qualify_parts, self.connector, self.negated)
119            if qualify_parts
120            else None
121        )
122        return where_node, having_node, qualify_node
123
124    def as_sql(
125        self, compiler: SQLCompiler, connection: DatabaseWrapper
126    ) -> tuple[str, list[Any]]:
127        """
128        Return the SQL version of the where clause and the value to be
129        substituted in. Return '', [] if this node matches everything,
130        None, [] if this node is empty, and raise EmptyResultSet if this
131        node can't match anything.
132        """
133        result = []
134        result_params = []
135        if self.connector == AND:
136            full_needed, empty_needed = len(self.children), 1
137        else:
138            full_needed, empty_needed = 1, len(self.children)
139
140        if self.connector == XOR:
141            # PostgreSQL doesn't have a native XOR operator, so convert:
142            #   a XOR b XOR c XOR ...
143            # to:
144            #   (a OR b OR c OR ...) AND (a + b + c + ...) == 1
145            lhs = self.__class__(self.children, OR)
146            rhs_sum = reduce(
147                operator.add,
148                (Case(When(c, then=1), default=0) for c in self.children),
149            )
150            rhs = Exact(1, rhs_sum)
151            return self.__class__([lhs, rhs], AND, self.negated).as_sql(
152                compiler, connection
153            )
154
155        for child in self.children:
156            try:
157                sql, params = compiler.compile(child)
158            except EmptyResultSet:
159                empty_needed -= 1
160            except FullResultSet:
161                full_needed -= 1
162            else:
163                if sql:
164                    result.append(sql)
165                    result_params.extend(params)
166                else:
167                    full_needed -= 1
168            # Check if this node matches nothing or everything.
169            # First check the amount of full nodes and empty nodes
170            # to make this node empty/full.
171            # Now, check if this node is full/empty using the
172            # counts.
173            if empty_needed == 0:
174                if self.negated:
175                    raise FullResultSet
176                else:
177                    raise EmptyResultSet
178            if full_needed == 0:
179                if self.negated:
180                    raise EmptyResultSet
181                else:
182                    raise FullResultSet
183        conn = f" {self.connector} "
184        sql_string = conn.join(result)
185        if not sql_string:
186            raise FullResultSet
187        if self.negated:
188            sql_string = f"NOT ({sql_string})"
189        elif len(result) > 1 or self.resolved:
190            sql_string = f"({sql_string})"
191        return sql_string, result_params
192
193    def get_group_by_cols(self) -> list[Any]:
194        cols = []
195        for child in self.children:
196            cols.extend(child.get_group_by_cols())
197        return cols
198
199    def get_source_expressions(self) -> list[Any]:
200        return self.children[:]
201
202    def set_source_expressions(self, children: list[Any]) -> None:
203        assert len(children) == len(self.children)
204        self.children = children
205
206    def relabel_aliases(self, change_map: dict[str, str]) -> None:
207        """
208        Relabel the alias values of any children. 'change_map' is a dictionary
209        mapping old (current) alias values to the new values.
210        """
211        for pos, child in enumerate(self.children):
212            if hasattr(child, "relabel_aliases"):
213                # For example another WhereNode
214                child.relabel_aliases(change_map)
215            elif hasattr(child, "relabeled_clone"):
216                self.children[pos] = child.relabeled_clone(change_map)
217
218    def clone(self) -> WhereNode:
219        clone = self.create(connector=self.connector, negated=self.negated)
220        for child in self.children:
221            if hasattr(child, "clone"):
222                child = child.clone()
223            clone.children.append(child)
224        return clone
225
226    def relabeled_clone(self, change_map: dict[str, str]) -> WhereNode:
227        clone = self.clone()
228        clone.relabel_aliases(change_map)
229        return clone
230
231    def replace_expressions(self, replacements: dict[Any, Any]) -> WhereNode:
232        if replacement := replacements.get(self):
233            return replacement
234        clone = self.create(connector=self.connector, negated=self.negated)
235        for child in self.children:
236            clone.children.append(child.replace_expressions(replacements))
237        return clone
238
239    def get_refs(self) -> set[Any]:
240        refs = set()
241        for child in self.children:
242            refs |= child.get_refs()
243        return refs
244
245    @classmethod
246    def _contains_aggregate(cls, obj: Any) -> bool:
247        if isinstance(obj, tree.Node):
248            return any(cls._contains_aggregate(c) for c in obj.children)
249        return obj.contains_aggregate
250
251    @cached_property
252    def contains_aggregate(self) -> bool:
253        return self._contains_aggregate(self)
254
255    @classmethod
256    def _contains_over_clause(cls, obj: Any) -> bool:
257        if isinstance(obj, tree.Node):
258            return any(cls._contains_over_clause(c) for c in obj.children)
259        return obj.contains_over_clause
260
261    @cached_property
262    def contains_over_clause(self) -> bool:
263        return self._contains_over_clause(self)
264
265    @property
266    def is_summary(self) -> bool:
267        return any(child.is_summary for child in self.children)
268
269    @staticmethod
270    def _resolve_leaf(expr: Any, query: Any, *args: Any, **kwargs: Any) -> Any:
271        if isinstance(expr, ResolvableExpression):
272            expr = expr.resolve_expression(query, *args, **kwargs)
273        return expr
274
275    @classmethod
276    def _resolve_node(cls, node: Any, query: Any, *args: Any, **kwargs: Any) -> None:
277        if hasattr(node, "children"):
278            for child in node.children:
279                cls._resolve_node(child, query, *args, **kwargs)
280        if hasattr(node, "lhs"):
281            node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
282        if hasattr(node, "rhs"):
283            node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)
284
285    def resolve_expression(self, *args: Any, **kwargs: Any) -> WhereNode:
286        clone = self.clone()
287        clone._resolve_node(clone, *args, **kwargs)
288        clone.resolved = True
289        return clone
290
291    @cached_property
292    def output_field(self) -> Any:
293        from plain.models.fields import BooleanField
294
295        return BooleanField()
296
297    @property
298    def _output_field_or_none(self) -> Any:
299        return self.output_field
300
301    def select_format(
302        self, compiler: SQLCompiler, sql: str, params: list[Any]
303    ) -> tuple[str, list[Any]]:
304        # Boolean expressions work directly in SELECT
305        return sql, params
306
307    def get_db_converters(self, connection: DatabaseWrapper) -> list[Any]:
308        return self.output_field.get_db_converters(connection)
309
310    def get_lookup(self, lookup: str) -> type[Lookup] | None:
311        return self.output_field.get_lookup(lookup)
312
313    def leaves(self) -> Any:
314        for child in self.children:
315            if isinstance(child, WhereNode):
316                yield from child.leaves()
317            else:
318                yield child
319
320
class NothingNode:
    """A node that matches nothing."""

    # A node that never matches references neither aggregates nor windows.
    contains_aggregate = False
    contains_over_clause = False

    def as_sql(
        self,
        compiler: SQLCompiler | None = None,
        connection: DatabaseWrapper | None = None,
    ) -> tuple[str, list[Any]]:
        """Never return SQL: always raise EmptyResultSet."""
        raise EmptyResultSet()
333
334
class ExtraWhere:
    """Raw SQL fragments, with their parameters, used as WHERE conditions."""

    # The contents are a black box - assume no aggregates or windows are used.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(self, sqls: list[str], params: list[Any] | None):
        self.sqls, self.params = sqls, params

    def as_sql(
        self,
        compiler: SQLCompiler | None = None,
        connection: DatabaseWrapper | None = None,
    ) -> tuple[str, list[Any]]:
        """
        Return every fragment wrapped in parentheses and joined with AND,
        together with a fresh list of the parameters (empty when there are
        none).
        """
        clause = " AND ".join([f"({fragment})" for fragment in self.sqls])
        bound = list(self.params) if self.params else []
        return clause, bound
351
352
class SubqueryConstraint:
    """A condition rendered by delegating to the compiler of a subquery."""

    # Even if aggregates or windows would be used in a subquery,
    # the outer query isn't interested about those.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(
        self, alias: str, columns: list[str], targets: list[Any], query_object: Any
    ):
        self.alias, self.columns, self.targets = alias, columns, targets
        # Ordering (including any default one) is cleared from the subquery
        # before it is stored.
        query_object.clear_ordering(clear_default=True)
        self.query_object = query_object

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Restrict the subquery to the target values and return the condition
        produced by its compiler for this constraint's alias and columns.
        """
        subquery = self.query_object
        subquery.set_values(self.targets)
        return subquery.get_compiler().as_subquery_condition(
            self.alias, self.columns, compiler
        )