1from __future__ import annotations
2
3from typing import TYPE_CHECKING, Any
4
5from plain.models.exceptions import FieldError, FullResultSet
6from plain.models.expressions import (
7 Case,
8 Func,
9 ResolvableExpression,
10 Star,
11 Value,
12 When,
13)
14from plain.models.fields import IntegerField
15from plain.models.functions.comparison import Coalesce
16from plain.models.functions.mixins import (
17 FixDurationInputMixin,
18 NumericOutputFieldMixin,
19)
20
21if TYPE_CHECKING:
22 from collections.abc import Sequence
23
24 from plain.models.backends.base.base import BaseDatabaseWrapper
25 from plain.models.expressions import Expression
26 from plain.models.query_utils import Q
27 from plain.models.sql.compiler import SQLCompiler
28
29
# Public names exported by this module: the Aggregate base class and the
# concrete aggregate functions built on it.
__all__ = [
    "Aggregate", "Avg", "Count", "Max",
    "Min", "StdDev", "Sum", "Variance",
]
40
41
class Aggregate(Func):
    """Base class for SQL aggregate functions such as COUNT, SUM or MAX.

    On top of ``Func`` this adds support for ``DISTINCT``, an optional
    ``filter`` condition restricting which rows are aggregated, and a
    ``default`` value that resolve_expression() applies by wrapping the
    resolved aggregate in ``Coalesce``. Concrete subclasses set ``function``
    (the SQL function name) and ``name`` (used for default aliases and in
    error messages).
    """

    # SQL template; the ``distinct`` placeholder is filled in by as_sql().
    template = "%(function)s(%(distinct)s%(expressions)s)"
    contains_aggregate = True
    # Human-readable aggregate name used by default_alias and in FieldError
    # messages; subclasses must override it (default_alias raises otherwise).
    name = None
    # Wraps the base template: the single ``%s`` receives the base template,
    # and ``%%(filter)s`` becomes a ``%(filter)s`` placeholder after that
    # first substitution.
    filter_template = "%s FILTER (WHERE %%(filter)s)"
    window_compatible = True
    # Whether __init__ accepts distinct=True.
    allow_distinct = False
    # When not None, __init__ rejects ``default``. NOTE(review): presumably
    # the value reported for a query that cannot match any rows (Count
    # uses 0) — confirm against the compiler.
    empty_result_set_value = None

    def __init__(
        self,
        *expressions: Any,
        distinct: bool = False,
        filter: Q | Expression | None = None,
        default: Any = None,
        **extra: Any,
    ) -> None:
        """Build the aggregate.

        Args:
            *expressions: Source expressions, passed through to ``Func``.
            distinct: Emit ``DISTINCT`` inside the function call.
            filter: Optional condition restricting the aggregated rows.
            default: Value or expression that resolve_expression() wraps
                around the aggregate with ``Coalesce``.
            **extra: Extra keyword arguments forwarded to ``Func``.

        Raises:
            TypeError: If ``distinct`` is requested but the class does not
                set ``allow_distinct``, or if ``default`` is given for a class
                whose ``empty_result_set_value`` is not None.
        """
        if distinct and not self.allow_distinct:
            raise TypeError(f"{self.__class__.__name__} does not allow distinct.")
        if default is not None and self.empty_result_set_value is not None:
            raise TypeError(f"{self.__class__.__name__} does not allow default.")
        self.distinct = distinct
        self.filter = filter
        self.default = default
        super().__init__(*expressions, **extra)

    def get_source_fields(self) -> list[Any]:
        """Return the output fields of the aggregated (non-filter) expressions."""
        # Don't return the filter expression since it's not a source field.
        return [e._output_field_or_none for e in super().get_source_expressions()]

    def get_source_expressions(self) -> list[Expression]:
        """Return the aggregated expressions, followed by the filter if truthy."""
        source_expressions = super().get_source_expressions()
        if self.filter:
            return source_expressions + [self.filter]
        return source_expressions

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Inverse of get_source_expressions().

        When this aggregate currently has a truthy filter, the last item of
        *exprs* becomes the new filter and the remaining items are handed to
        ``Func``. Otherwise all items are handed to ``Func``.
        """
        # Work on a list copy: pop() must not touch the caller's sequence,
        # which may be immutable (it is only typed as Sequence).
        exprs_list = list(exprs)
        self.filter = self.filter and exprs_list.pop()
        super().set_source_expressions(exprs_list)

    def resolve_expression(  # type: ignore[override]
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Expression:
        """Resolve this aggregate and its filter against *query*.

        Unless *summarize* is set, raises FieldError if any aggregated
        argument already contains an aggregate. When a ``default`` was
        given, the returned object is ``Coalesce(aggregate, default)``
        (carrying the aggregate's ``is_summary``) rather than the aggregate
        itself.
        """
        # Aggregates are not allowed in UPDATE queries, so ignore for_save
        c = super().resolve_expression(query, allow_joins, reuse, summarize)
        if c.filter is not None:
            c.filter = c.filter.resolve_expression(query, allow_joins, reuse, summarize)
        if not summarize:
            # Call the parent (Func) get_source_expressions(), bypassing
            # Aggregate's override, so the filter is not included in this loop.
            expressions = super(Aggregate, c).get_source_expressions()
            for index, expr in enumerate(expressions):
                if expr.contains_aggregate:
                    # Look up the unresolved expression on self so the error
                    # message names what the caller originally passed in.
                    before_resolved = self.get_source_expressions()[index]
                    name = (
                        before_resolved.name
                        if hasattr(before_resolved, "name")
                        else repr(before_resolved)
                    )
                    raise FieldError(
                        f"Cannot compute {c.name}('{name}'): '{name}' is an aggregate"
                    )
        if (default := c.default) is None:
            return c
        if isinstance(default, ResolvableExpression):
            default = default.resolve_expression(query, allow_joins, reuse, summarize)
            # An untyped default expression adopts the aggregate's output field.
            if default._output_field_or_none is None:
                default.output_field = c._output_field_or_none
        else:
            # Plain Python values become literals typed like the aggregate.
            default = Value(default, c._output_field_or_none)
        c.default = None  # Reset the default argument before wrapping.
        coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
        coalesce.is_summary = c.is_summary
        return coalesce

    @property
    def default_alias(self) -> str:
        """Derive an alias of the form ``<expression name>__<lowercased name>``.

        Only possible when get_source_expressions() yields exactly one item
        and it has a ``name`` attribute. Because get_source_expressions()
        appends the filter, a filtered aggregate always needs an explicit
        alias.

        Raises:
            TypeError: If no alias can be derived, or if the class did not
                define ``name``.
        """
        expressions = self.get_source_expressions()
        if len(expressions) == 1 and hasattr(expressions[0], "name"):
            if self.name is None:
                raise TypeError("Aggregate subclasses must define a name")
            return f"{expressions[0].name}__{self.name.lower()}"
        raise TypeError("Complex expressions require an alias")

    def get_group_by_cols(self) -> list[Any]:
        """Return no GROUP BY columns; an aggregate never groups by itself."""
        return []

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile the aggregate to ``(sql, params)``.

        With a filter, emits ``... FILTER (WHERE <filter>)`` when the backend
        reports ``supports_aggregate_filter_clause``. Otherwise it wraps the
        first argument in ``Case(When(filter, then=arg))`` on a copy with the
        filter cleared. If compiling the filter raises FullResultSet, the
        filter is dropped and the plain aggregate is emitted.
        """
        extra_context["distinct"] = "DISTINCT " if self.distinct else ""
        if self.filter is not None:
            if connection.features.supports_aggregate_filter_clause:
                try:
                    filter_sql, filter_params = self.filter.as_sql(compiler, connection)  # type: ignore[possibly-missing-attribute]
                except FullResultSet:
                    # NOTE(review): presumably the filter matches every row,
                    # so fall through and emit the unfiltered aggregate.
                    pass
                else:
                    # "template" is a named parameter of this method, so it can
                    # never be a key of extra_context; this lookup therefore
                    # always resolves to ``template or self.template``.
                    filter_template = self.filter_template % extra_context.get(
                        "template", template or self.template
                    )
                    sql, params = super().as_sql(
                        compiler,
                        connection,
                        function=function,
                        template=filter_template,
                        arg_joiner=arg_joiner,
                        filter=filter_sql,
                        **extra_context,
                    )
                    # The FILTER clause follows the arguments in the template,
                    # so its params go last.
                    return sql, [*params, *filter_params]
            else:
                # No FILTER clause support: emulate it by wrapping the first
                # argument in a CASE WHEN on a filter-less copy.
                copy = self.copy()
                copy.filter = None
                source_expressions = copy.get_source_expressions()
                condition = When(self.filter, then=source_expressions[0])
                copy.set_source_expressions([Case(condition)] + source_expressions[1:])
                # Call the parent (Func) as_sql directly so the filter logic
                # above is not re-entered; extra_context still carries
                # "distinct".
                return super(Aggregate, copy).as_sql(
                    compiler,
                    connection,
                    function=function,
                    template=template,
                    arg_joiner=arg_joiner,
                    **extra_context,
                )
        return super().as_sql(
            compiler,
            connection,
            function=function,
            template=template,
            arg_joiner=arg_joiner,
            **extra_context,
        )

    def _get_repr_options(self) -> dict[str, Any]:
        """Add ``distinct`` and ``filter`` to the repr options when set."""
        options = super()._get_repr_options()
        if self.distinct:
            options["distinct"] = self.distinct
        if self.filter:
            options["filter"] = self.filter
        return options
195
196
class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
    """SQL ``AVG()`` aggregate; ``distinct=True`` is permitted."""

    name = "Avg"
    function = "AVG"
    allow_distinct = True
201
202
class Count(Aggregate):
    """SQL ``COUNT()`` aggregate with an integer result.

    The string ``"*"`` is normalised to ``Star()`` (count all rows). A
    ``Star`` argument cannot be combined with ``filter``. Because
    ``empty_result_set_value`` is 0, ``Aggregate.__init__`` rejects a
    ``default`` argument.
    """

    name = "Count"
    function = "COUNT"
    allow_distinct = True
    output_field = IntegerField()
    empty_result_set_value = 0

    def __init__(
        self, expression: Any, filter: Q | Expression | None = None, **extra: Any
    ) -> None:
        target = Star() if expression == "*" else expression
        # A filtered count must name a concrete field, not every row.
        if filter is not None and isinstance(target, Star):
            raise ValueError("Star cannot be used with filter. Please specify a field.")
        super().__init__(target, filter=filter, **extra)
218
219
class Max(Aggregate):
    """SQL ``MAX()`` aggregate."""

    name = "Max"
    function = "MAX"
223
224
class Min(Aggregate):
    """SQL ``MIN()`` aggregate."""

    name = "Min"
    function = "MIN"
228
229
class StdDev(NumericOutputFieldMixin, Aggregate):
    """Standard-deviation aggregate.

    ``sample=True`` compiles to ``STDDEV_SAMP``; otherwise ``STDDEV_POP``
    is used. The choice is reported as ``sample`` in the repr.
    """

    name = "StdDev"

    def __init__(self, expression: Any, sample: bool = False, **extra: Any) -> None:
        # The SQL function is chosen per instance, before Func initialises.
        if sample:
            self.function = "STDDEV_SAMP"
        else:
            self.function = "STDDEV_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self) -> dict[str, Any]:
        repr_options = dict(super()._get_repr_options())
        repr_options["sample"] = self.function == "STDDEV_SAMP"
        return repr_options
239
240
class Sum(FixDurationInputMixin, Aggregate):
    """SQL ``SUM()`` aggregate; ``distinct=True`` is permitted."""

    name = "Sum"
    function = "SUM"
    allow_distinct = True
245
246
class Variance(NumericOutputFieldMixin, Aggregate):
    """Variance aggregate.

    ``sample=True`` compiles to ``VAR_SAMP``; otherwise ``VAR_POP`` is
    used. The choice is reported as ``sample`` in the repr.
    """

    name = "Variance"

    def __init__(self, expression: Any, sample: bool = False, **extra: Any) -> None:
        # The SQL function is chosen per instance, before Func initialises.
        if sample:
            self.function = "VAR_SAMP"
        else:
            self.function = "VAR_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self) -> dict[str, Any]:
        repr_options = dict(super()._get_repr_options())
        repr_options["sample"] = self.function == "VAR_SAMP"
        return repr_options