Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2Code to manage the creation and SQL rendering of 'where' constraints.
  3"""
  4
  5from __future__ import annotations
  6
  7import operator
  8from functools import cached_property, reduce
  9from typing import TYPE_CHECKING, Any
 10
 11from plain.models.exceptions import EmptyResultSet, FullResultSet
 12from plain.models.expressions import Case, When
 13from plain.models.lookups import Exact
 14from plain.utils import tree
 15
 16if TYPE_CHECKING:
 17    from plain.models.backends.base.base import BaseDatabaseWrapper
 18    from plain.models.sql.compiler import SQLCompiler
 19
 20# Connection types
 21AND = "AND"
 22OR = "OR"
 23XOR = "XOR"
 24
 25
 26class WhereNode(tree.Node):
 27    """
 28    An SQL WHERE clause.
 29
 30    The class is tied to the Query class that created it (in order to create
 31    the correct SQL).
 32
 33    A child is usually an expression producing boolean values. Most likely the
 34    expression is a Lookup instance.
 35
 36    However, a child could also be any class with as_sql() and either
 37    relabeled_clone() method or relabel_aliases() and clone() methods and
 38    contains_aggregate attribute.
 39    """
 40
 41    default = AND
 42    resolved = False
 43    conditional = True
 44
 45    def split_having_qualify(
 46        self, negated: bool = False, must_group_by: bool = False
 47    ) -> tuple[WhereNode | None, WhereNode | None, WhereNode | None]:
 48        """
 49        Return three possibly None nodes: one for those parts of self that
 50        should be included in the WHERE clause, one for those parts of self
 51        that must be included in the HAVING clause, and one for those parts
 52        that refer to window functions.
 53        """
 54        if not self.contains_aggregate and not self.contains_over_clause:
 55            return self, None, None
 56        in_negated = negated ^ self.negated
 57        # Whether or not children must be connected in the same filtering
 58        # clause (WHERE > HAVING > QUALIFY) to maintain logical semantic.
 59        must_remain_connected = (
 60            (in_negated and self.connector == AND)
 61            or (not in_negated and self.connector == OR)
 62            or self.connector == XOR
 63        )
 64        if (
 65            must_remain_connected
 66            and self.contains_aggregate
 67            and not self.contains_over_clause
 68        ):
 69            # It's must cheaper to short-circuit and stash everything in the
 70            # HAVING clause than split children if possible.
 71            return None, self, None
 72        where_parts = []
 73        having_parts = []
 74        qualify_parts = []
 75        for c in self.children:
 76            if hasattr(c, "split_having_qualify"):
 77                where_part, having_part, qualify_part = c.split_having_qualify(
 78                    in_negated, must_group_by
 79                )
 80                if where_part is not None:
 81                    where_parts.append(where_part)
 82                if having_part is not None:
 83                    having_parts.append(having_part)
 84                if qualify_part is not None:
 85                    qualify_parts.append(qualify_part)
 86            elif c.contains_over_clause:
 87                qualify_parts.append(c)
 88            elif c.contains_aggregate:
 89                having_parts.append(c)
 90            else:
 91                where_parts.append(c)
 92        if must_remain_connected and qualify_parts:
 93            # Disjunctive heterogeneous predicates can be pushed down to
 94            # qualify as long as no conditional aggregation is involved.
 95            if not where_parts or (where_parts and not must_group_by):
 96                return None, None, self
 97            elif where_parts:
 98                # In theory this should only be enforced when dealing with
 99                # where_parts containing predicates against multi-valued
100                # relationships that could affect aggregation results but this
101                # is complex to infer properly.
102                raise NotImplementedError(
103                    "Heterogeneous disjunctive predicates against window functions are "
104                    "not implemented when performing conditional aggregation."
105                )
106        where_node = (
107            self.create(where_parts, self.connector, self.negated)
108            if where_parts
109            else None
110        )
111        having_node = (
112            self.create(having_parts, self.connector, self.negated)
113            if having_parts
114            else None
115        )
116        qualify_node = (
117            self.create(qualify_parts, self.connector, self.negated)
118            if qualify_parts
119            else None
120        )
121        return where_node, having_node, qualify_node
122
123    def as_sql(
124        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
125    ) -> tuple[str, list[Any]]:
126        """
127        Return the SQL version of the where clause and the value to be
128        substituted in. Return '', [] if this node matches everything,
129        None, [] if this node is empty, and raise EmptyResultSet if this
130        node can't match anything.
131        """
132        result = []
133        result_params = []
134        if self.connector == AND:
135            full_needed, empty_needed = len(self.children), 1
136        else:
137            full_needed, empty_needed = 1, len(self.children)
138
139        if self.connector == XOR and not connection.features.supports_logical_xor:
140            # Convert if the database doesn't support XOR:
141            #   a XOR b XOR c XOR ...
142            # to:
143            #   (a OR b OR c OR ...) AND (a + b + c + ...) == 1
144            lhs = self.__class__(self.children, OR)
145            rhs_sum = reduce(
146                operator.add,
147                (Case(When(c, then=1), default=0) for c in self.children),
148            )
149            rhs = Exact(1, rhs_sum)
150            return self.__class__([lhs, rhs], AND, self.negated).as_sql(
151                compiler, connection
152            )
153
154        for child in self.children:
155            try:
156                sql, params = compiler.compile(child)
157            except EmptyResultSet:
158                empty_needed -= 1
159            except FullResultSet:
160                full_needed -= 1
161            else:
162                if sql:
163                    result.append(sql)
164                    result_params.extend(params)
165                else:
166                    full_needed -= 1
167            # Check if this node matches nothing or everything.
168            # First check the amount of full nodes and empty nodes
169            # to make this node empty/full.
170            # Now, check if this node is full/empty using the
171            # counts.
172            if empty_needed == 0:
173                if self.negated:
174                    raise FullResultSet
175                else:
176                    raise EmptyResultSet
177            if full_needed == 0:
178                if self.negated:
179                    raise EmptyResultSet
180                else:
181                    raise FullResultSet
182        conn = f" {self.connector} "
183        sql_string = conn.join(result)
184        if not sql_string:
185            raise FullResultSet
186        if self.negated:
187            # Some backends (Oracle at least) need parentheses around the inner
188            # SQL in the negated case, even if the inner SQL contains just a
189            # single expression.
190            sql_string = f"NOT ({sql_string})"
191        elif len(result) > 1 or self.resolved:
192            sql_string = f"({sql_string})"
193        return sql_string, result_params
194
195    def get_group_by_cols(self) -> list[Any]:
196        cols = []
197        for child in self.children:
198            cols.extend(child.get_group_by_cols())
199        return cols
200
201    def get_source_expressions(self) -> list[Any]:
202        return self.children[:]
203
204    def set_source_expressions(self, children: list[Any]) -> None:
205        assert len(children) == len(self.children)
206        self.children = children
207
208    def relabel_aliases(self, change_map: dict[str, str]) -> None:
209        """
210        Relabel the alias values of any children. 'change_map' is a dictionary
211        mapping old (current) alias values to the new values.
212        """
213        for pos, child in enumerate(self.children):
214            if hasattr(child, "relabel_aliases"):
215                # For example another WhereNode
216                child.relabel_aliases(change_map)
217            elif hasattr(child, "relabeled_clone"):
218                self.children[pos] = child.relabeled_clone(change_map)
219
220    def clone(self) -> WhereNode:
221        clone = self.create(connector=self.connector, negated=self.negated)
222        for child in self.children:
223            if hasattr(child, "clone"):
224                child = child.clone()
225            clone.children.append(child)
226        return clone
227
228    def relabeled_clone(self, change_map: dict[str, str]) -> WhereNode:
229        clone = self.clone()
230        clone.relabel_aliases(change_map)
231        return clone
232
233    def replace_expressions(self, replacements: dict[Any, Any]) -> WhereNode:
234        if replacement := replacements.get(self):
235            return replacement
236        clone = self.create(connector=self.connector, negated=self.negated)
237        for child in self.children:
238            clone.children.append(child.replace_expressions(replacements))
239        return clone
240
241    def get_refs(self) -> set[Any]:
242        refs = set()
243        for child in self.children:
244            refs |= child.get_refs()
245        return refs
246
247    @classmethod
248    def _contains_aggregate(cls, obj: Any) -> bool:
249        if isinstance(obj, tree.Node):
250            return any(cls._contains_aggregate(c) for c in obj.children)
251        return obj.contains_aggregate
252
253    @cached_property
254    def contains_aggregate(self) -> bool:
255        return self._contains_aggregate(self)
256
257    @classmethod
258    def _contains_over_clause(cls, obj: Any) -> bool:
259        if isinstance(obj, tree.Node):
260            return any(cls._contains_over_clause(c) for c in obj.children)
261        return obj.contains_over_clause
262
263    @cached_property
264    def contains_over_clause(self) -> bool:
265        return self._contains_over_clause(self)
266
267    @property
268    def is_summary(self) -> bool:
269        return any(child.is_summary for child in self.children)
270
271    @staticmethod
272    def _resolve_leaf(expr: Any, query: Any, *args: Any, **kwargs: Any) -> Any:
273        if hasattr(expr, "resolve_expression"):
274            expr = expr.resolve_expression(query, *args, **kwargs)
275        return expr
276
277    @classmethod
278    def _resolve_node(cls, node: Any, query: Any, *args: Any, **kwargs: Any) -> None:
279        if hasattr(node, "children"):
280            for child in node.children:
281                cls._resolve_node(child, query, *args, **kwargs)
282        if hasattr(node, "lhs"):
283            node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
284        if hasattr(node, "rhs"):
285            node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)
286
287    def resolve_expression(self, *args: Any, **kwargs: Any) -> WhereNode:
288        clone = self.clone()
289        clone._resolve_node(clone, *args, **kwargs)
290        clone.resolved = True
291        return clone
292
293    @cached_property
294    def output_field(self) -> Any:
295        from plain.models.fields import BooleanField
296
297        return BooleanField()
298
299    @property
300    def _output_field_or_none(self) -> Any:
301        return self.output_field
302
303    def select_format(
304        self, compiler: SQLCompiler, sql: str, params: list[Any]
305    ) -> tuple[str, list[Any]]:
306        # Wrap filters with a CASE WHEN expression if a database backend
307        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
308        # BY list.
309        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
310            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
311        return sql, params
312
313    def get_db_converters(self, connection: BaseDatabaseWrapper) -> list[Any]:
314        return self.output_field.get_db_converters(connection)
315
316    def get_lookup(self, lookup: str) -> Any:
317        return self.output_field.get_lookup(lookup)
318
319    def leaves(self) -> Any:
320        for child in self.children:
321            if isinstance(child, WhereNode):
322                yield from child.leaves()
323            else:
324                yield child
325
326
327class NothingNode:
328    """A node that matches nothing."""
329
330    contains_aggregate = False
331    contains_over_clause = False
332
333    def as_sql(
334        self,
335        compiler: SQLCompiler | None = None,
336        connection: BaseDatabaseWrapper | None = None,
337    ) -> tuple[str, list[Any]]:
338        raise EmptyResultSet
339
340
341class ExtraWhere:
342    # The contents are a black box - assume no aggregates or windows are used.
343    contains_aggregate = False
344    contains_over_clause = False
345
346    def __init__(self, sqls: list[str], params: list[Any] | None):
347        self.sqls = sqls
348        self.params = params
349
350    def as_sql(
351        self,
352        compiler: SQLCompiler | None = None,
353        connection: BaseDatabaseWrapper | None = None,
354    ) -> tuple[str, list[Any]]:
355        sqls = [f"({sql})" for sql in self.sqls]
356        return " AND ".join(sqls), list(self.params or ())
357
358
359class SubqueryConstraint:
360    # Even if aggregates or windows would be used in a subquery,
361    # the outer query isn't interested about those.
362    contains_aggregate = False
363    contains_over_clause = False
364
365    def __init__(
366        self, alias: str, columns: list[str], targets: list[Any], query_object: Any
367    ):
368        self.alias = alias
369        self.columns = columns
370        self.targets = targets
371        query_object.clear_ordering(clear_default=True)
372        self.query_object = query_object
373
374    def as_sql(
375        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
376    ) -> tuple[str, list[Any]]:
377        query = self.query_object
378        query.set_values(self.targets)
379        query_compiler = query.get_compiler()
380        return query_compiler.as_subquery_condition(self.alias, self.columns, compiler)