Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from typing import TYPE_CHECKING, Any
  4
  5from plain.models.exceptions import FieldError, FullResultSet
  6from plain.models.expressions import Case, Func, Star, Value, When
  7from plain.models.fields import IntegerField
  8from plain.models.functions.comparison import Coalesce
  9from plain.models.functions.mixins import (
 10    FixDurationInputMixin,
 11    NumericOutputFieldMixin,
 12)
 13
 14if TYPE_CHECKING:
 15    from plain.models.backends.base.base import BaseDatabaseWrapper
 16    from plain.models.expressions import Expression
 17    from plain.models.query_utils import Q
 18    from plain.models.sql.compiler import SQLCompiler
 19
 20__all__ = [
 21    "Aggregate",
 22    "Avg",
 23    "Count",
 24    "Max",
 25    "Min",
 26    "StdDev",
 27    "Sum",
 28    "Variance",
 29]
 30
 31
 32class Aggregate(Func):
 33    template = "%(function)s(%(distinct)s%(expressions)s)"
 34    contains_aggregate = True
 35    name = None
 36    filter_template = "%s FILTER (WHERE %%(filter)s)"
 37    window_compatible = True
 38    allow_distinct = False
 39    empty_result_set_value = None
 40
 41    def __init__(
 42        self,
 43        *expressions: Any,
 44        distinct: bool = False,
 45        filter: Q | Expression | None = None,
 46        default: Any = None,
 47        **extra: Any,
 48    ) -> None:
 49        if distinct and not self.allow_distinct:
 50            raise TypeError(f"{self.__class__.__name__} does not allow distinct.")
 51        if default is not None and self.empty_result_set_value is not None:
 52            raise TypeError(f"{self.__class__.__name__} does not allow default.")
 53        self.distinct = distinct
 54        self.filter = filter
 55        self.default = default
 56        super().__init__(*expressions, **extra)
 57
 58    def get_source_fields(self) -> list[Any]:
 59        # Don't return the filter expression since it's not a source field.
 60        return [e._output_field_or_none for e in super().get_source_expressions()]
 61
 62    def get_source_expressions(self) -> list[Expression]:
 63        source_expressions = super().get_source_expressions()
 64        if self.filter:
 65            return source_expressions + [self.filter]
 66        return source_expressions
 67
 68    def set_source_expressions(self, exprs: list[Expression]) -> list[Expression]:
 69        self.filter = self.filter and exprs.pop()
 70        return super().set_source_expressions(exprs)
 71
 72    def resolve_expression(
 73        self,
 74        query: Any = None,
 75        allow_joins: bool = True,
 76        reuse: Any = None,
 77        summarize: bool = False,
 78        for_save: bool = False,
 79    ) -> Expression:
 80        # Aggregates are not allowed in UPDATE queries, so ignore for_save
 81        c = super().resolve_expression(query, allow_joins, reuse, summarize)
 82        c.filter = c.filter and c.filter.resolve_expression(
 83            query, allow_joins, reuse, summarize
 84        )
 85        if not summarize:
 86            # Call Aggregate.get_source_expressions() to avoid
 87            # returning self.filter and including that in this loop.
 88            expressions = super(Aggregate, c).get_source_expressions()
 89            for index, expr in enumerate(expressions):
 90                if expr.contains_aggregate:
 91                    before_resolved = self.get_source_expressions()[index]
 92                    name = (
 93                        before_resolved.name
 94                        if hasattr(before_resolved, "name")
 95                        else repr(before_resolved)
 96                    )
 97                    raise FieldError(
 98                        f"Cannot compute {c.name}('{name}'): '{name}' is an aggregate"
 99                    )
100        if (default := c.default) is None:
101            return c
102        if hasattr(default, "resolve_expression"):
103            default = default.resolve_expression(query, allow_joins, reuse, summarize)
104            if default._output_field_or_none is None:
105                default.output_field = c._output_field_or_none
106        else:
107            default = Value(default, c._output_field_or_none)
108        c.default = None  # Reset the default argument before wrapping.
109        coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
110        coalesce.is_summary = c.is_summary
111        return coalesce
112
113    @property
114    def default_alias(self) -> str:
115        expressions = self.get_source_expressions()
116        if len(expressions) == 1 and hasattr(expressions[0], "name"):
117            return f"{expressions[0].name}__{self.name.lower()}"
118        raise TypeError("Complex expressions require an alias")
119
120    def get_group_by_cols(self) -> list[Any]:
121        return []
122
123    def as_sql(
124        self,
125        compiler: SQLCompiler,
126        connection: BaseDatabaseWrapper,
127        **extra_context: Any,
128    ) -> tuple[str, tuple[Any, ...]]:
129        extra_context["distinct"] = "DISTINCT " if self.distinct else ""
130        if self.filter:
131            if connection.features.supports_aggregate_filter_clause:
132                try:
133                    filter_sql, filter_params = self.filter.as_sql(compiler, connection)
134                except FullResultSet:
135                    pass
136                else:
137                    template = self.filter_template % extra_context.get(
138                        "template", self.template
139                    )
140                    sql, params = super().as_sql(
141                        compiler,
142                        connection,
143                        template=template,
144                        filter=filter_sql,
145                        **extra_context,
146                    )
147                    return sql, (*params, *filter_params)
148            else:
149                copy = self.copy()
150                copy.filter = None
151                source_expressions = copy.get_source_expressions()
152                condition = When(self.filter, then=source_expressions[0])
153                copy.set_source_expressions([Case(condition)] + source_expressions[1:])
154                return super(Aggregate, copy).as_sql(
155                    compiler, connection, **extra_context
156                )
157        return super().as_sql(compiler, connection, **extra_context)
158
159    def _get_repr_options(self) -> dict[str, Any]:
160        options = super()._get_repr_options()
161        if self.distinct:
162            options["distinct"] = self.distinct
163        if self.filter:
164            options["filter"] = self.filter
165        return options
166
167
class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
    """``AVG()`` aggregate; supports ``distinct=True``."""

    allow_distinct = True
    name = "Avg"
    function = "AVG"
172
173
class Count(Aggregate):
    """``COUNT()`` aggregate with an integer result.

    ``Count("*")`` is shorthand for ``Count(Star())``. Counting a star cannot
    be combined with ``filter``. Because ``empty_result_set_value`` is 0,
    ``Aggregate.__init__`` rejects an explicit ``default``.
    """

    name = "Count"
    function = "COUNT"
    allow_distinct = True
    empty_result_set_value = 0
    output_field = IntegerField()

    def __init__(
        self, expression: Any, filter: Q | Expression | None = None, **extra: Any
    ) -> None:
        target = Star() if expression == "*" else expression
        # A bare star has no column to apply the filter condition to.
        if filter is not None and isinstance(target, Star):
            raise ValueError("Star cannot be used with filter. Please specify a field.")
        super().__init__(target, filter=filter, **extra)
189
190
class Max(Aggregate):
    """``MAX()`` aggregate."""

    name = "Max"
    function = "MAX"
194
195
class Min(Aggregate):
    """``MIN()`` aggregate."""

    name = "Min"
    function = "MIN"
199
200
class StdDev(NumericOutputFieldMixin, Aggregate):
    """Standard-deviation aggregate.

    ``sample=True`` selects ``STDDEV_SAMP``; otherwise ``STDDEV_POP`` is used.
    """

    name = "StdDev"

    def __init__(self, expression: Any, sample: bool = False, **extra: Any) -> None:
        # The SQL function is chosen per instance, before the parent initializer runs.
        if sample:
            self.function = "STDDEV_SAMP"
        else:
            self.function = "STDDEV_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self) -> dict[str, Any]:
        options = dict(super()._get_repr_options())
        options["sample"] = self.function == "STDDEV_SAMP"
        return options
210
211
class Sum(FixDurationInputMixin, Aggregate):
    """``SUM()`` aggregate; supports ``distinct=True``."""

    allow_distinct = True
    name = "Sum"
    function = "SUM"
216
217
class Variance(NumericOutputFieldMixin, Aggregate):
    """Variance aggregate.

    ``sample=True`` selects ``VAR_SAMP``; otherwise ``VAR_POP`` is used.
    """

    name = "Variance"

    def __init__(self, expression: Any, sample: bool = False, **extra: Any) -> None:
        # The SQL function is chosen per instance, before the parent initializer runs.
        if sample:
            self.function = "VAR_SAMP"
        else:
            self.function = "VAR_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self) -> dict[str, Any]:
        options = dict(super()._get_repr_options())
        options["sample"] = self.function == "VAR_SAMP"
        return options