Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2Classes to represent the definitions of aggregate functions.
  3"""
  4
  5from plain.exceptions import FieldError, FullResultSet
  6from plain.models.expressions import Case, Func, Star, Value, When
  7from plain.models.fields import IntegerField
  8from plain.models.functions.comparison import Coalesce
  9from plain.models.functions.mixins import (
 10    FixDurationInputMixin,
 11    NumericOutputFieldMixin,
 12)
 13
 14__all__ = [
 15    "Aggregate",
 16    "Avg",
 17    "Count",
 18    "Max",
 19    "Min",
 20    "StdDev",
 21    "Sum",
 22    "Variance",
 23]
 24
 25
 26class Aggregate(Func):
 27    template = "%(function)s(%(distinct)s%(expressions)s)"
 28    contains_aggregate = True
 29    name = None
 30    filter_template = "%s FILTER (WHERE %%(filter)s)"
 31    window_compatible = True
 32    allow_distinct = False
 33    empty_result_set_value = None
 34
 35    def __init__(
 36        self, *expressions, distinct=False, filter=None, default=None, **extra
 37    ):
 38        if distinct and not self.allow_distinct:
 39            raise TypeError(f"{self.__class__.__name__} does not allow distinct.")
 40        if default is not None and self.empty_result_set_value is not None:
 41            raise TypeError(f"{self.__class__.__name__} does not allow default.")
 42        self.distinct = distinct
 43        self.filter = filter
 44        self.default = default
 45        super().__init__(*expressions, **extra)
 46
 47    def get_source_fields(self):
 48        # Don't return the filter expression since it's not a source field.
 49        return [e._output_field_or_none for e in super().get_source_expressions()]
 50
 51    def get_source_expressions(self):
 52        source_expressions = super().get_source_expressions()
 53        if self.filter:
 54            return source_expressions + [self.filter]
 55        return source_expressions
 56
 57    def set_source_expressions(self, exprs):
 58        self.filter = self.filter and exprs.pop()
 59        return super().set_source_expressions(exprs)
 60
 61    def resolve_expression(
 62        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
 63    ):
 64        # Aggregates are not allowed in UPDATE queries, so ignore for_save
 65        c = super().resolve_expression(query, allow_joins, reuse, summarize)
 66        c.filter = c.filter and c.filter.resolve_expression(
 67            query, allow_joins, reuse, summarize
 68        )
 69        if not summarize:
 70            # Call Aggregate.get_source_expressions() to avoid
 71            # returning self.filter and including that in this loop.
 72            expressions = super(Aggregate, c).get_source_expressions()
 73            for index, expr in enumerate(expressions):
 74                if expr.contains_aggregate:
 75                    before_resolved = self.get_source_expressions()[index]
 76                    name = (
 77                        before_resolved.name
 78                        if hasattr(before_resolved, "name")
 79                        else repr(before_resolved)
 80                    )
 81                    raise FieldError(
 82                        f"Cannot compute {c.name}('{name}'): '{name}' is an aggregate"
 83                    )
 84        if (default := c.default) is None:
 85            return c
 86        if hasattr(default, "resolve_expression"):
 87            default = default.resolve_expression(query, allow_joins, reuse, summarize)
 88            if default._output_field_or_none is None:
 89                default.output_field = c._output_field_or_none
 90        else:
 91            default = Value(default, c._output_field_or_none)
 92        c.default = None  # Reset the default argument before wrapping.
 93        coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
 94        coalesce.is_summary = c.is_summary
 95        return coalesce
 96
 97    @property
 98    def default_alias(self):
 99        expressions = self.get_source_expressions()
100        if len(expressions) == 1 and hasattr(expressions[0], "name"):
101            return f"{expressions[0].name}__{self.name.lower()}"
102        raise TypeError("Complex expressions require an alias")
103
104    def get_group_by_cols(self):
105        return []
106
107    def as_sql(self, compiler, connection, **extra_context):
108        extra_context["distinct"] = "DISTINCT " if self.distinct else ""
109        if self.filter:
110            if connection.features.supports_aggregate_filter_clause:
111                try:
112                    filter_sql, filter_params = self.filter.as_sql(compiler, connection)
113                except FullResultSet:
114                    pass
115                else:
116                    template = self.filter_template % extra_context.get(
117                        "template", self.template
118                    )
119                    sql, params = super().as_sql(
120                        compiler,
121                        connection,
122                        template=template,
123                        filter=filter_sql,
124                        **extra_context,
125                    )
126                    return sql, (*params, *filter_params)
127            else:
128                copy = self.copy()
129                copy.filter = None
130                source_expressions = copy.get_source_expressions()
131                condition = When(self.filter, then=source_expressions[0])
132                copy.set_source_expressions([Case(condition)] + source_expressions[1:])
133                return super(Aggregate, copy).as_sql(
134                    compiler, connection, **extra_context
135                )
136        return super().as_sql(compiler, connection, **extra_context)
137
138    def _get_repr_options(self):
139        options = super()._get_repr_options()
140        if self.distinct:
141            options["distinct"] = self.distinct
142        if self.filter:
143            options["filter"] = self.filter
144        return options
145
146
class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
    """SQL ``AVG()`` over the source expression; ``distinct=True`` is allowed."""

    name = "Avg"
    function = "AVG"
    allow_distinct = True
151
152
class Count(Aggregate):
    """
    SQL ``COUNT()`` with an integer output field.

    The string ``"*"`` is shorthand for ``Star()``. A ``Star`` cannot be
    combined with ``filter``. Because ``empty_result_set_value`` is ``0``,
    Aggregate.__init__ rejects a ``default`` argument.
    """

    name = "Count"
    function = "COUNT"
    allow_distinct = True
    output_field = IntegerField()
    empty_result_set_value = 0

    def __init__(self, expression, filter=None, **extra):
        if expression == "*":
            expression = Star()
        # Guard: COUNT(*) has no column to attach a filter condition to.
        star_with_filter = filter is not None and isinstance(expression, Star)
        if star_with_filter:
            raise ValueError("Star cannot be used with filter. Please specify a field.")
        super().__init__(expression, filter=filter, **extra)
166
167
class Max(Aggregate):
    """SQL ``MAX()`` over the source expression."""

    name = "Max"
    function = "MAX"
171
172
class Min(Aggregate):
    """SQL ``MIN()`` over the source expression."""

    name = "Min"
    function = "MIN"
176
177
class StdDev(NumericOutputFieldMixin, Aggregate):
    """
    Standard deviation of the source expression.

    ``sample=False`` (the default) renders ``STDDEV_POP``;
    ``sample=True`` renders ``STDDEV_SAMP``.
    """

    name = "StdDev"

    def __init__(self, expression, sample=False, **extra):
        # The SQL function is chosen per instance, before the base init runs.
        if sample:
            self.function = "STDDEV_SAMP"
        else:
            self.function = "STDDEV_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        # Expose which variant was chosen so repr() can tell them apart.
        options = dict(super()._get_repr_options())
        options["sample"] = self.function == "STDDEV_SAMP"
        return options
187
188
class Sum(FixDurationInputMixin, Aggregate):
    """SQL ``SUM()`` over the source expression; ``distinct=True`` is allowed."""

    name = "Sum"
    function = "SUM"
    allow_distinct = True
193
194
class Variance(NumericOutputFieldMixin, Aggregate):
    """
    Variance of the source expression.

    ``sample=False`` (the default) renders ``VAR_POP``;
    ``sample=True`` renders ``VAR_SAMP``.
    """

    name = "Variance"

    def __init__(self, expression, sample=False, **extra):
        # The SQL function is chosen per instance, before the base init runs.
        if sample:
            self.function = "VAR_SAMP"
        else:
            self.function = "VAR_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        # Expose which variant was chosen so repr() can tell them apart.
        options = dict(super()._get_repr_options())
        options["sample"] = self.function == "VAR_SAMP"
        return options