Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from typing import TYPE_CHECKING, Any
  4
  5from plain.models.exceptions import FieldError, FullResultSet
  6from plain.models.expressions import (
  7    Case,
  8    Func,
  9    ResolvableExpression,
 10    Star,
 11    Value,
 12    When,
 13)
 14from plain.models.fields import IntegerField
 15from plain.models.functions.comparison import Coalesce
 16from plain.models.functions.mixins import (
 17    FixDurationInputMixin,
 18    NumericOutputFieldMixin,
 19)
 20
 21if TYPE_CHECKING:
 22    from collections.abc import Sequence
 23
 24    from plain.models.backends.base.base import BaseDatabaseWrapper
 25    from plain.models.expressions import Expression
 26    from plain.models.query_utils import Q
 27    from plain.models.sql.compiler import SQLCompiler
 28
 29
 30__all__ = [
 31    "Aggregate",
 32    "Avg",
 33    "Count",
 34    "Max",
 35    "Min",
 36    "StdDev",
 37    "Sum",
 38    "Variance",
 39]
 40
 41
 42class Aggregate(Func):
 43    template = "%(function)s(%(distinct)s%(expressions)s)"
 44    contains_aggregate = True
 45    name = None
 46    filter_template = "%s FILTER (WHERE %%(filter)s)"
 47    window_compatible = True
 48    allow_distinct = False
 49    empty_result_set_value = None
 50
 51    def __init__(
 52        self,
 53        *expressions: Any,
 54        distinct: bool = False,
 55        filter: Q | Expression | None = None,
 56        default: Any = None,
 57        **extra: Any,
 58    ) -> None:
 59        if distinct and not self.allow_distinct:
 60            raise TypeError(f"{self.__class__.__name__} does not allow distinct.")
 61        if default is not None and self.empty_result_set_value is not None:
 62            raise TypeError(f"{self.__class__.__name__} does not allow default.")
 63        self.distinct = distinct
 64        self.filter = filter
 65        self.default = default
 66        super().__init__(*expressions, **extra)
 67
 68    def get_source_fields(self) -> list[Any]:
 69        # Don't return the filter expression since it's not a source field.
 70        return [e._output_field_or_none for e in super().get_source_expressions()]
 71
 72    def get_source_expressions(self) -> list[Expression]:
 73        source_expressions = super().get_source_expressions()
 74        if self.filter:
 75            return source_expressions + [self.filter]
 76        return source_expressions
 77
 78    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
 79        exprs_list = list(exprs)
 80        self.filter = self.filter and exprs_list.pop()
 81        super().set_source_expressions(exprs_list)
 82
 83    def resolve_expression(  # type: ignore[override]
 84        self,
 85        query: Any = None,
 86        allow_joins: bool = True,
 87        reuse: Any = None,
 88        summarize: bool = False,
 89        for_save: bool = False,
 90    ) -> Expression:
 91        # Aggregates are not allowed in UPDATE queries, so ignore for_save
 92        c = super().resolve_expression(query, allow_joins, reuse, summarize)
 93        if c.filter is not None:
 94            c.filter = c.filter.resolve_expression(query, allow_joins, reuse, summarize)
 95        if not summarize:
 96            # Call Aggregate.get_source_expressions() to avoid
 97            # returning self.filter and including that in this loop.
 98            expressions = super(Aggregate, c).get_source_expressions()
 99            for index, expr in enumerate(expressions):
100                if expr.contains_aggregate:
101                    before_resolved = self.get_source_expressions()[index]
102                    name = (
103                        before_resolved.name
104                        if hasattr(before_resolved, "name")
105                        else repr(before_resolved)
106                    )
107                    raise FieldError(
108                        f"Cannot compute {c.name}('{name}'): '{name}' is an aggregate"
109                    )
110        if (default := c.default) is None:
111            return c
112        if isinstance(default, ResolvableExpression):
113            default = default.resolve_expression(query, allow_joins, reuse, summarize)
114            if default._output_field_or_none is None:
115                default.output_field = c._output_field_or_none
116        else:
117            default = Value(default, c._output_field_or_none)
118        c.default = None  # Reset the default argument before wrapping.
119        coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
120        coalesce.is_summary = c.is_summary
121        return coalesce
122
123    @property
124    def default_alias(self) -> str:
125        expressions = self.get_source_expressions()
126        if len(expressions) == 1 and hasattr(expressions[0], "name"):
127            if self.name is None:
128                raise TypeError("Aggregate subclasses must define a name")
129            return f"{expressions[0].name}__{self.name.lower()}"
130        raise TypeError("Complex expressions require an alias")
131
132    def get_group_by_cols(self) -> list[Any]:
133        return []
134
135    def as_sql(
136        self,
137        compiler: SQLCompiler,
138        connection: BaseDatabaseWrapper,
139        function: str | None = None,
140        template: str | None = None,
141        arg_joiner: str | None = None,
142        **extra_context: Any,
143    ) -> tuple[str, list[Any]]:
144        extra_context["distinct"] = "DISTINCT " if self.distinct else ""
145        if self.filter is not None:
146            if connection.features.supports_aggregate_filter_clause:
147                try:
148                    filter_sql, filter_params = self.filter.as_sql(compiler, connection)  # type: ignore[possibly-missing-attribute]
149                except FullResultSet:
150                    pass
151                else:
152                    filter_template = self.filter_template % extra_context.get(
153                        "template", template or self.template
154                    )
155                    sql, params = super().as_sql(
156                        compiler,
157                        connection,
158                        function=function,
159                        template=filter_template,
160                        arg_joiner=arg_joiner,
161                        filter=filter_sql,
162                        **extra_context,
163                    )
164                    return sql, [*params, *filter_params]
165            else:
166                copy = self.copy()
167                copy.filter = None
168                source_expressions = copy.get_source_expressions()
169                condition = When(self.filter, then=source_expressions[0])
170                copy.set_source_expressions([Case(condition)] + source_expressions[1:])
171                return super(Aggregate, copy).as_sql(
172                    compiler,
173                    connection,
174                    function=function,
175                    template=template,
176                    arg_joiner=arg_joiner,
177                    **extra_context,
178                )
179        return super().as_sql(
180            compiler,
181            connection,
182            function=function,
183            template=template,
184            arg_joiner=arg_joiner,
185            **extra_context,
186        )
187
188    def _get_repr_options(self) -> dict[str, Any]:
189        options = super()._get_repr_options()
190        if self.distinct:
191            options["distinct"] = self.distinct
192        if self.filter:
193            options["filter"] = self.filter
194        return options
195
196
class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
    """SQL ``AVG()`` aggregate; accepts ``distinct=True``."""

    name = "Avg"
    function = "AVG"
    allow_distinct = True
201
202
class Count(Aggregate):
    """SQL ``COUNT()`` aggregate; ``Count("*")`` is shorthand for ``Count(Star())``."""

    name = "Count"
    function = "COUNT"
    allow_distinct = True
    output_field = IntegerField()
    # Non-None, so Aggregate.__init__ rejects a `default` for Count.
    empty_result_set_value = 0

    def __init__(
        self, expression: Any, filter: Q | Expression | None = None, **extra: Any
    ) -> None:
        """Build a COUNT over *expression*.

        Raises:
            ValueError: *expression* is ``"*"``/``Star()`` and a filter is given.
        """
        target = Star() if expression == "*" else expression
        if filter is not None and isinstance(target, Star):
            raise ValueError("Star cannot be used with filter. Please specify a field.")
        super().__init__(target, filter=filter, **extra)
218
219
class Max(Aggregate):
    """SQL ``MAX()`` aggregate."""

    name = "Max"
    function = "MAX"
223
224
class Min(Aggregate):
    """SQL ``MIN()`` aggregate."""

    name = "Min"
    function = "MIN"
228
229
class StdDev(NumericOutputFieldMixin, Aggregate):
    """Standard deviation: ``STDDEV_POP()`` by default, ``STDDEV_SAMP()`` if *sample*."""

    name = "StdDev"

    def __init__(self, expression: Any, sample: bool = False, **extra: Any) -> None:
        # The SQL function is chosen per instance, before Aggregate/Func init.
        if sample:
            self.function = "STDDEV_SAMP"
        else:
            self.function = "STDDEV_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self) -> dict[str, Any]:
        """Report ``sample`` in the repr, derived from the chosen SQL function."""
        options = dict(super()._get_repr_options())
        options["sample"] = self.function == "STDDEV_SAMP"
        return options
239
240
class Sum(FixDurationInputMixin, Aggregate):
    """SQL ``SUM()`` aggregate; accepts ``distinct=True``."""

    name = "Sum"
    function = "SUM"
    allow_distinct = True
245
246
class Variance(NumericOutputFieldMixin, Aggregate):
    """Variance: ``VAR_POP()`` by default, ``VAR_SAMP()`` if *sample*."""

    name = "Variance"

    def __init__(self, expression: Any, sample: bool = False, **extra: Any) -> None:
        # The SQL function is chosen per instance, before Aggregate/Func init.
        if sample:
            self.function = "VAR_SAMP"
        else:
            self.function = "VAR_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self) -> dict[str, Any]:
        """Report ``sample`` in the repr, derived from the chosen SQL function."""
        options = dict(super()._get_repr_options())
        options["sample"] = self.function == "VAR_SAMP"
        return options