Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2Code to manage the creation and SQL rendering of 'where' constraints.
  3"""
  4import operator
  5from functools import reduce
  6
  7from plain.exceptions import EmptyResultSet, FullResultSet
  8from plain.models.expressions import Case, When
  9from plain.models.lookups import Exact
 10from plain.utils import tree
 11from plain.utils.functional import cached_property
 12
 13# Connection types
 14AND = "AND"
 15OR = "OR"
 16XOR = "XOR"
 17
 18
 19class WhereNode(tree.Node):
 20    """
 21    An SQL WHERE clause.
 22
 23    The class is tied to the Query class that created it (in order to create
 24    the correct SQL).
 25
 26    A child is usually an expression producing boolean values. Most likely the
 27    expression is a Lookup instance.
 28
 29    However, a child could also be any class with as_sql() and either
 30    relabeled_clone() method or relabel_aliases() and clone() methods and
 31    contains_aggregate attribute.
 32    """
 33
 34    default = AND
 35    resolved = False
 36    conditional = True
 37
 38    def split_having_qualify(self, negated=False, must_group_by=False):
 39        """
 40        Return three possibly None nodes: one for those parts of self that
 41        should be included in the WHERE clause, one for those parts of self
 42        that must be included in the HAVING clause, and one for those parts
 43        that refer to window functions.
 44        """
 45        if not self.contains_aggregate and not self.contains_over_clause:
 46            return self, None, None
 47        in_negated = negated ^ self.negated
 48        # Whether or not children must be connected in the same filtering
 49        # clause (WHERE > HAVING > QUALIFY) to maintain logical semantic.
 50        must_remain_connected = (
 51            (in_negated and self.connector == AND)
 52            or (not in_negated and self.connector == OR)
 53            or self.connector == XOR
 54        )
 55        if (
 56            must_remain_connected
 57            and self.contains_aggregate
 58            and not self.contains_over_clause
 59        ):
 60            # It's must cheaper to short-circuit and stash everything in the
 61            # HAVING clause than split children if possible.
 62            return None, self, None
 63        where_parts = []
 64        having_parts = []
 65        qualify_parts = []
 66        for c in self.children:
 67            if hasattr(c, "split_having_qualify"):
 68                where_part, having_part, qualify_part = c.split_having_qualify(
 69                    in_negated, must_group_by
 70                )
 71                if where_part is not None:
 72                    where_parts.append(where_part)
 73                if having_part is not None:
 74                    having_parts.append(having_part)
 75                if qualify_part is not None:
 76                    qualify_parts.append(qualify_part)
 77            elif c.contains_over_clause:
 78                qualify_parts.append(c)
 79            elif c.contains_aggregate:
 80                having_parts.append(c)
 81            else:
 82                where_parts.append(c)
 83        if must_remain_connected and qualify_parts:
 84            # Disjunctive heterogeneous predicates can be pushed down to
 85            # qualify as long as no conditional aggregation is involved.
 86            if not where_parts or (where_parts and not must_group_by):
 87                return None, None, self
 88            elif where_parts:
 89                # In theory this should only be enforced when dealing with
 90                # where_parts containing predicates against multi-valued
 91                # relationships that could affect aggregation results but this
 92                # is complex to infer properly.
 93                raise NotImplementedError(
 94                    "Heterogeneous disjunctive predicates against window functions are "
 95                    "not implemented when performing conditional aggregation."
 96                )
 97        where_node = (
 98            self.create(where_parts, self.connector, self.negated)
 99            if where_parts
100            else None
101        )
102        having_node = (
103            self.create(having_parts, self.connector, self.negated)
104            if having_parts
105            else None
106        )
107        qualify_node = (
108            self.create(qualify_parts, self.connector, self.negated)
109            if qualify_parts
110            else None
111        )
112        return where_node, having_node, qualify_node
113
114    def as_sql(self, compiler, connection):
115        """
116        Return the SQL version of the where clause and the value to be
117        substituted in. Return '', [] if this node matches everything,
118        None, [] if this node is empty, and raise EmptyResultSet if this
119        node can't match anything.
120        """
121        result = []
122        result_params = []
123        if self.connector == AND:
124            full_needed, empty_needed = len(self.children), 1
125        else:
126            full_needed, empty_needed = 1, len(self.children)
127
128        if self.connector == XOR and not connection.features.supports_logical_xor:
129            # Convert if the database doesn't support XOR:
130            #   a XOR b XOR c XOR ...
131            # to:
132            #   (a OR b OR c OR ...) AND (a + b + c + ...) == 1
133            lhs = self.__class__(self.children, OR)
134            rhs_sum = reduce(
135                operator.add,
136                (Case(When(c, then=1), default=0) for c in self.children),
137            )
138            rhs = Exact(1, rhs_sum)
139            return self.__class__([lhs, rhs], AND, self.negated).as_sql(
140                compiler, connection
141            )
142
143        for child in self.children:
144            try:
145                sql, params = compiler.compile(child)
146            except EmptyResultSet:
147                empty_needed -= 1
148            except FullResultSet:
149                full_needed -= 1
150            else:
151                if sql:
152                    result.append(sql)
153                    result_params.extend(params)
154                else:
155                    full_needed -= 1
156            # Check if this node matches nothing or everything.
157            # First check the amount of full nodes and empty nodes
158            # to make this node empty/full.
159            # Now, check if this node is full/empty using the
160            # counts.
161            if empty_needed == 0:
162                if self.negated:
163                    raise FullResultSet
164                else:
165                    raise EmptyResultSet
166            if full_needed == 0:
167                if self.negated:
168                    raise EmptyResultSet
169                else:
170                    raise FullResultSet
171        conn = " %s " % self.connector
172        sql_string = conn.join(result)
173        if not sql_string:
174            raise FullResultSet
175        if self.negated:
176            # Some backends (Oracle at least) need parentheses around the inner
177            # SQL in the negated case, even if the inner SQL contains just a
178            # single expression.
179            sql_string = "NOT (%s)" % sql_string
180        elif len(result) > 1 or self.resolved:
181            sql_string = "(%s)" % sql_string
182        return sql_string, result_params
183
184    def get_group_by_cols(self):
185        cols = []
186        for child in self.children:
187            cols.extend(child.get_group_by_cols())
188        return cols
189
190    def get_source_expressions(self):
191        return self.children[:]
192
193    def set_source_expressions(self, children):
194        assert len(children) == len(self.children)
195        self.children = children
196
197    def relabel_aliases(self, change_map):
198        """
199        Relabel the alias values of any children. 'change_map' is a dictionary
200        mapping old (current) alias values to the new values.
201        """
202        for pos, child in enumerate(self.children):
203            if hasattr(child, "relabel_aliases"):
204                # For example another WhereNode
205                child.relabel_aliases(change_map)
206            elif hasattr(child, "relabeled_clone"):
207                self.children[pos] = child.relabeled_clone(change_map)
208
209    def clone(self):
210        clone = self.create(connector=self.connector, negated=self.negated)
211        for child in self.children:
212            if hasattr(child, "clone"):
213                child = child.clone()
214            clone.children.append(child)
215        return clone
216
217    def relabeled_clone(self, change_map):
218        clone = self.clone()
219        clone.relabel_aliases(change_map)
220        return clone
221
222    def replace_expressions(self, replacements):
223        if replacement := replacements.get(self):
224            return replacement
225        clone = self.create(connector=self.connector, negated=self.negated)
226        for child in self.children:
227            clone.children.append(child.replace_expressions(replacements))
228        return clone
229
230    def get_refs(self):
231        refs = set()
232        for child in self.children:
233            refs |= child.get_refs()
234        return refs
235
236    @classmethod
237    def _contains_aggregate(cls, obj):
238        if isinstance(obj, tree.Node):
239            return any(cls._contains_aggregate(c) for c in obj.children)
240        return obj.contains_aggregate
241
242    @cached_property
243    def contains_aggregate(self):
244        return self._contains_aggregate(self)
245
246    @classmethod
247    def _contains_over_clause(cls, obj):
248        if isinstance(obj, tree.Node):
249            return any(cls._contains_over_clause(c) for c in obj.children)
250        return obj.contains_over_clause
251
252    @cached_property
253    def contains_over_clause(self):
254        return self._contains_over_clause(self)
255
256    @property
257    def is_summary(self):
258        return any(child.is_summary for child in self.children)
259
260    @staticmethod
261    def _resolve_leaf(expr, query, *args, **kwargs):
262        if hasattr(expr, "resolve_expression"):
263            expr = expr.resolve_expression(query, *args, **kwargs)
264        return expr
265
266    @classmethod
267    def _resolve_node(cls, node, query, *args, **kwargs):
268        if hasattr(node, "children"):
269            for child in node.children:
270                cls._resolve_node(child, query, *args, **kwargs)
271        if hasattr(node, "lhs"):
272            node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
273        if hasattr(node, "rhs"):
274            node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)
275
276    def resolve_expression(self, *args, **kwargs):
277        clone = self.clone()
278        clone._resolve_node(clone, *args, **kwargs)
279        clone.resolved = True
280        return clone
281
282    @cached_property
283    def output_field(self):
284        from plain.models.fields import BooleanField
285
286        return BooleanField()
287
288    @property
289    def _output_field_or_none(self):
290        return self.output_field
291
292    def select_format(self, compiler, sql, params):
293        # Wrap filters with a CASE WHEN expression if a database backend
294        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
295        # BY list.
296        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
297            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
298        return sql, params
299
300    def get_db_converters(self, connection):
301        return self.output_field.get_db_converters(connection)
302
303    def get_lookup(self, lookup):
304        return self.output_field.get_lookup(lookup)
305
306    def leaves(self):
307        for child in self.children:
308            if isinstance(child, WhereNode):
309                yield from child.leaves()
310            else:
311                yield child
312
313
class NothingNode:
    """
    A where-clause child that can never match any row.

    Compiling it always raises EmptyResultSet instead of producing SQL.
    """

    # Never routed to HAVING or QUALIFY when a WhereNode splits its children.
    contains_over_clause = False
    contains_aggregate = False

    def as_sql(self, compiler=None, connection=None):
        # No SQL is emitted; the "matches nothing" outcome is signalled.
        raise EmptyResultSet
322
323
class ExtraWhere:
    """
    A raw SQL WHERE fragment given as literal strings plus parameters.

    Each string in ``sqls`` is parenthesized and the pieces are joined with
    AND; ``params`` is returned as a list (empty when falsy).
    """

    # The SQL is opaque, so it is assumed to use neither aggregates nor
    # window functions.
    contains_over_clause = False
    contains_aggregate = False

    def __init__(self, sqls, params):
        self.sqls, self.params = sqls, params

    def as_sql(self, compiler=None, connection=None):
        combined = " AND ".join("(%s)" % fragment for fragment in self.sqls)
        bound = [] if not self.params else list(self.params)
        return combined, bound
336
337
class SubqueryConstraint:
    """
    A WHERE condition built from a subquery.

    SQL generation is delegated to as_subquery_condition() on the subquery's
    compiler, after the subquery's selected values are set to ``targets``.
    """

    # Aggregates or window functions inside the subquery don't matter to how
    # the outer query splits its own filtering clauses.
    contains_over_clause = False
    contains_aggregate = False

    def __init__(self, alias, columns, targets, query_object):
        self.alias = alias
        self.columns = columns
        self.targets = targets
        # Ordering is dropped from the subquery up front (clear_default=True).
        query_object.clear_ordering(clear_default=True)
        self.query_object = query_object

    def as_sql(self, compiler, connection):
        subquery = self.query_object
        subquery.set_values(self.targets)
        subquery_compiler = subquery.get_compiler(connection=connection)
        return subquery_compiler.as_subquery_condition(
            self.alias, self.columns, compiler
        )