1from __future__ import annotations
  2
  3from typing import TYPE_CHECKING, Any
  4
  5from plain.models.exceptions import FieldError, FullResultSet
  6from plain.models.expressions import (
  7    Func,
  8    ResolvableExpression,
  9    Star,
 10    Value,
 11)
 12from plain.models.fields import IntegerField
 13from plain.models.functions.comparison import Coalesce
 14from plain.models.functions.mixins import NumericOutputFieldMixin
 15
 16if TYPE_CHECKING:
 17    from collections.abc import Sequence
 18
 19    from plain.models.expressions import Expression
 20    from plain.models.postgres.wrapper import DatabaseWrapper
 21    from plain.models.query_utils import Q
 22    from plain.models.sql.compiler import SQLCompiler
 23
 24
# Public names exported by this module: the Aggregate base class and the
# concrete aggregate classes defined below it.
__all__ = [
    "Aggregate",
    "Avg",
    "Count",
    "Max",
    "Min",
    "StdDev",
    "Sum",
    "Variance",
]
 35
 36
 37class Aggregate(Func):
 38    template = "%(function)s(%(distinct)s%(expressions)s)"
 39    contains_aggregate = True
 40    name = None
 41    filter_template = "%s FILTER (WHERE %%(filter)s)"
 42    window_compatible = True
 43    allow_distinct = False
 44    empty_result_set_value = None
 45
 46    def __init__(
 47        self,
 48        *expressions: Any,
 49        distinct: bool = False,
 50        filter: Q | Expression | None = None,
 51        default: Any = None,
 52        **extra: Any,
 53    ) -> None:
 54        if distinct and not self.allow_distinct:
 55            raise TypeError(f"{self.__class__.__name__} does not allow distinct.")
 56        if default is not None and self.empty_result_set_value is not None:
 57            raise TypeError(f"{self.__class__.__name__} does not allow default.")
 58        self.distinct = distinct
 59        self.filter = filter
 60        self.default = default
 61        super().__init__(*expressions, **extra)
 62
 63    def get_source_fields(self) -> list[Any]:
 64        # Don't return the filter expression since it's not a source field.
 65        return [e._output_field_or_none for e in super().get_source_expressions()]
 66
 67    def get_source_expressions(self) -> list[Expression]:
 68        source_expressions = super().get_source_expressions()
 69        if self.filter:
 70            return source_expressions + [self.filter]
 71        return source_expressions
 72
 73    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
 74        exprs_list = list(exprs)
 75        self.filter = self.filter and exprs_list.pop()
 76        super().set_source_expressions(exprs_list)
 77
 78    def resolve_expression(  # type: ignore[override]
 79        self,
 80        query: Any = None,
 81        allow_joins: bool = True,
 82        reuse: Any = None,
 83        summarize: bool = False,
 84        for_save: bool = False,
 85    ) -> Expression:
 86        # Aggregates are not allowed in UPDATE queries, so ignore for_save
 87        c = super().resolve_expression(query, allow_joins, reuse, summarize)
 88        if c.filter is not None:
 89            c.filter = c.filter.resolve_expression(query, allow_joins, reuse, summarize)
 90        if not summarize:
 91            # Call Aggregate.get_source_expressions() to avoid
 92            # returning self.filter and including that in this loop.
 93            expressions = super(Aggregate, c).get_source_expressions()
 94            for index, expr in enumerate(expressions):
 95                if expr.contains_aggregate:
 96                    before_resolved = self.get_source_expressions()[index]
 97                    name = (
 98                        before_resolved.name
 99                        if hasattr(before_resolved, "name")
100                        else repr(before_resolved)
101                    )
102                    raise FieldError(
103                        f"Cannot compute {c.name}('{name}'): '{name}' is an aggregate"
104                    )
105        if (default := c.default) is None:
106            return c
107        if isinstance(default, ResolvableExpression):
108            default = default.resolve_expression(query, allow_joins, reuse, summarize)
109            if default._output_field_or_none is None:
110                default.output_field = c._output_field_or_none
111        else:
112            default = Value(default, c._output_field_or_none)
113        c.default = None  # Reset the default argument before wrapping.
114        coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
115        coalesce.is_summary = c.is_summary
116        return coalesce
117
118    @property
119    def default_alias(self) -> str:
120        expressions = self.get_source_expressions()
121        if len(expressions) == 1 and hasattr(expressions[0], "name"):
122            if self.name is None:
123                raise TypeError("Aggregate subclasses must define a name")
124            return f"{expressions[0].name}__{self.name.lower()}"
125        raise TypeError("Complex expressions require an alias")
126
127    def get_group_by_cols(self) -> list[Any]:
128        return []
129
130    def as_sql(
131        self,
132        compiler: SQLCompiler,
133        connection: DatabaseWrapper,
134        function: str | None = None,
135        template: str | None = None,
136        arg_joiner: str | None = None,
137        **extra_context: Any,
138    ) -> tuple[str, list[Any]]:
139        extra_context["distinct"] = "DISTINCT " if self.distinct else ""
140        if self.filter is not None:
141            # Use FILTER clause for aggregates when filter is specified
142            try:
143                filter_sql, filter_params = self.filter.as_sql(compiler, connection)  # type: ignore[union-attr]
144            except FullResultSet:
145                pass
146            else:
147                filter_template = self.filter_template % extra_context.get(
148                    "template", template or self.template
149                )
150                sql, params = super().as_sql(
151                    compiler,
152                    connection,
153                    function=function,
154                    template=filter_template,
155                    arg_joiner=arg_joiner,
156                    filter=filter_sql,
157                    **extra_context,
158                )
159                return sql, [*params, *filter_params]
160        return super().as_sql(
161            compiler,
162            connection,
163            function=function,
164            template=template,
165            arg_joiner=arg_joiner,
166            **extra_context,
167        )
168
169    def _get_repr_options(self) -> dict[str, Any]:
170        options = super()._get_repr_options()
171        if self.distinct:
172            options["distinct"] = self.distinct
173        if self.filter:
174            options["filter"] = self.filter
175        return options
176
177
class Avg(NumericOutputFieldMixin, Aggregate):
    """AVG(...) aggregate; ``distinct=True`` is accepted."""

    name = "Avg"
    function = "AVG"
    allow_distinct = True
182
183
class Count(Aggregate):
    """COUNT(...) aggregate producing an IntegerField.

    The string ``"*"`` is converted to ``Star()``, which cannot be combined
    with ``filter``. Because ``empty_result_set_value`` is not ``None``,
    ``Aggregate.__init__`` rejects a ``default`` argument.
    """

    name = "Count"
    function = "COUNT"
    output_field = IntegerField()
    allow_distinct = True
    empty_result_set_value = 0

    def __init__(
        self, expression: Any, filter: Q | Expression | None = None, **extra: Any
    ) -> None:
        """Build the COUNT, translating ``"*"`` into ``Star()``.

        Raises:
            ValueError: a star expression was combined with ``filter``.
        """
        target = Star() if expression == "*" else expression
        if filter is not None and isinstance(target, Star):
            raise ValueError("Star cannot be used with filter. Please specify a field.")
        super().__init__(target, filter=filter, **extra)
199
200
class Max(Aggregate):
    """MAX(...) aggregate; inherits ``allow_distinct = False``."""

    name = "Max"
    function = "MAX"
204
205
class Min(Aggregate):
    """MIN(...) aggregate; inherits ``allow_distinct = False``."""

    name = "Min"
    function = "MIN"
209
210
class StdDev(NumericOutputFieldMixin, Aggregate):
    """Standard deviation: STDDEV_SAMP when ``sample`` is true, else STDDEV_POP."""

    name = "StdDev"

    def __init__(self, expression: Any, sample: bool = False, **extra: Any) -> None:
        # The SQL function is chosen per instance, before Func initialisation.
        if sample:
            self.function = "STDDEV_SAMP"
        else:
            self.function = "STDDEV_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self) -> dict[str, Any]:
        """Return a copy of the inherited repr options plus ``sample``."""
        options = dict(super()._get_repr_options())
        options["sample"] = self.function == "STDDEV_SAMP"
        return options
220
221
class Sum(Aggregate):
    """SUM(...) aggregate; ``distinct=True`` is accepted."""

    name = "Sum"
    function = "SUM"
    allow_distinct = True
226
227
class Variance(NumericOutputFieldMixin, Aggregate):
    """Variance: VAR_SAMP when ``sample`` is true, else VAR_POP."""

    name = "Variance"

    def __init__(self, expression: Any, sample: bool = False, **extra: Any) -> None:
        # The SQL function is chosen per instance, before Func initialisation.
        if sample:
            self.function = "VAR_SAMP"
        else:
            self.function = "VAR_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self) -> dict[str, Any]:
        """Return a copy of the inherited repr options plus ``sample``."""
        options = dict(super()._get_repr_options())
        options["sample"] = self.function == "VAR_SAMP"
        return options