1from __future__ import annotations
2
3from typing import TYPE_CHECKING, Any
4
5from plain.models.exceptions import FieldError, FullResultSet
6from plain.models.expressions import (
7 Func,
8 ResolvableExpression,
9 Star,
10 Value,
11)
12from plain.models.fields import IntegerField
13from plain.models.functions.comparison import Coalesce
14from plain.models.functions.mixins import NumericOutputFieldMixin
15
16if TYPE_CHECKING:
17 from collections.abc import Sequence
18
19 from plain.models.expressions import Expression
20 from plain.models.postgres.wrapper import DatabaseWrapper
21 from plain.models.query_utils import Q
22 from plain.models.sql.compiler import SQLCompiler
23
24
# Public names exported by star-imports: the Aggregate base class and every
# concrete aggregate defined in this module.
__all__ = ["Aggregate", "Avg", "Count", "Max", "Min", "StdDev", "Sum", "Variance"]
35
36
class Aggregate(Func):
    """Base class for SQL aggregate expressions such as ``COUNT`` and ``SUM``.

    Subclasses provide ``function`` (the SQL function name) and ``name``
    (used for default aliases and error messages) and may enable
    ``DISTINCT`` through ``allow_distinct``. An optional ``filter`` is
    rendered as a ``FILTER (WHERE ...)`` clause, and an optional ``default``
    wraps the resolved aggregate in ``Coalesce`` so a NULL result is
    replaced by that value.
    """

    template = "%(function)s(%(distinct)s%(expressions)s)"
    contains_aggregate = True
    # Subclasses must set this; default_alias raises TypeError otherwise.
    name = None
    # The leading "%s" receives the regular template; "%%(filter)s" turns into
    # "%(filter)s" after that substitution and is filled from the ``filter``
    # context value that as_sql passes to Func.as_sql.
    filter_template = "%s FILTER (WHERE %%(filter)s)"
    window_compatible = True
    allow_distinct = False
    # Subclasses that set a non-None value (Count uses 0) reject ``default``.
    empty_result_set_value = None

    def __init__(
        self,
        *expressions: Any,
        distinct: bool = False,
        filter: Q | Expression | None = None,
        default: Any = None,
        **extra: Any,
    ) -> None:
        """Store the aggregate options and pass the expressions to ``Func``.

        Args:
            *expressions: Expressions (or field references) to aggregate.
            distinct: Render ``DISTINCT``; only allowed when the subclass
                sets ``allow_distinct``.
            filter: Condition rendered as a ``FILTER (WHERE ...)`` clause.
            default: Value or expression used through ``Coalesce`` when the
                aggregate evaluates to NULL.
            **extra: Forwarded unchanged to ``Func.__init__``.

        Raises:
            TypeError: If ``distinct`` is requested but not allowed, or if
                ``default`` is given while ``empty_result_set_value`` is set.
        """
        if distinct and not self.allow_distinct:
            raise TypeError(f"{self.__class__.__name__} does not allow distinct.")
        if default is not None and self.empty_result_set_value is not None:
            raise TypeError(f"{self.__class__.__name__} does not allow default.")
        self.distinct = distinct
        self.filter = filter
        self.default = default
        super().__init__(*expressions, **extra)

    def get_source_fields(self) -> list[Any]:
        """Return the output fields of the aggregated expressions only."""
        # Don't return the filter expression since it's not a source field.
        return [e._output_field_or_none for e in super().get_source_expressions()]

    def get_source_expressions(self) -> list[Expression]:
        """Return the aggregated expressions, followed by the filter if set."""
        source_expressions = super().get_source_expressions()
        if self.filter:
            return source_expressions + [self.filter]
        return source_expressions

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Inverse of get_source_expressions().

        When a (truthy) filter is set, the last item of ``exprs`` is taken
        as the new filter. The remaining items go to ``Func``.
        """
        exprs_list = list(exprs)
        self.filter = self.filter and exprs_list.pop()
        super().set_source_expressions(exprs_list)

    def resolve_expression(  # type: ignore[override]
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Expression:
        """Resolve the aggregate and its filter against ``query``.

        Returns the resolved copy, or a ``Coalesce`` wrapping it when a
        ``default`` is set.

        Raises:
            FieldError: When not summarizing, if an aggregated expression
                is itself an aggregate.
        """
        # Aggregates are not allowed in UPDATE queries, so ignore for_save
        c = super().resolve_expression(query, allow_joins, reuse, summarize)
        if c.filter is not None:
            c.filter = c.filter.resolve_expression(query, allow_joins, reuse, summarize)
        if not summarize:
            # Use the parent (Func) implementation instead of
            # Aggregate.get_source_expressions() so that self.filter is not
            # included in this loop.
            expressions = super(Aggregate, c).get_source_expressions()
            for index, expr in enumerate(expressions):
                if expr.contains_aggregate:
                    # Report the unresolved expression, which carries the
                    # user-facing name.
                    before_resolved = self.get_source_expressions()[index]
                    name = (
                        before_resolved.name
                        if hasattr(before_resolved, "name")
                        else repr(before_resolved)
                    )
                    raise FieldError(
                        f"Cannot compute {c.name}('{name}'): '{name}' is an aggregate"
                    )
        if (default := c.default) is None:
            return c
        if isinstance(default, ResolvableExpression):
            default = default.resolve_expression(query, allow_joins, reuse, summarize)
            # Give an untyped default expression the aggregate's output field.
            if default._output_field_or_none is None:
                default.output_field = c._output_field_or_none
        else:
            default = Value(default, c._output_field_or_none)
        c.default = None  # Reset the default argument before wrapping.
        coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
        coalesce.is_summary = c.is_summary
        return coalesce

    @property
    def default_alias(self) -> str:
        """Return ``"<source name>__<lowercased aggregate name>"``.

        This is only possible when get_source_expressions() yields exactly one
        expression with a ``name`` attribute. A set filter counts as a source
        expression.

        Raises:
            TypeError: If the subclass has no ``name``, or if the expressions
                are too complex to derive an alias from.
        """
        expressions = self.get_source_expressions()
        if len(expressions) == 1 and hasattr(expressions[0], "name"):
            if self.name is None:
                raise TypeError("Aggregate subclasses must define a name")
            return f"{expressions[0].name}__{self.name.lower()}"
        raise TypeError("Complex expressions require an alias")

    def get_group_by_cols(self) -> list[Any]:
        """Aggregates contribute no GROUP BY columns."""
        return []

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile the aggregate to SQL.

        If a filter is set, it is compiled first. Its SQL is placed in a
        ``FILTER (WHERE ...)`` clause, and its params are appended after the
        aggregate's own params. If compiling the filter raises
        ``FullResultSet``, the filter is omitted and the plain aggregate is
        emitted.

        Returns:
            The SQL string and its list of query parameters.
        """
        extra_context["distinct"] = "DISTINCT " if self.distinct else ""
        if self.filter is not None:
            # Use FILTER clause for aggregates when filter is specified
            try:
                filter_sql, filter_params = self.filter.as_sql(compiler, connection)  # type: ignore[union-attr]
            except FullResultSet:
                pass
            else:
                # ``template`` is a named parameter, so it can never be a key
                # of ``extra_context``; wrap it (or the class template) directly.
                filter_template = self.filter_template % (template or self.template)
                sql, params = super().as_sql(
                    compiler,
                    connection,
                    function=function,
                    template=filter_template,
                    arg_joiner=arg_joiner,
                    filter=filter_sql,
                    **extra_context,
                )
                return sql, [*params, *filter_params]
        return super().as_sql(
            compiler,
            connection,
            function=function,
            template=template,
            arg_joiner=arg_joiner,
            **extra_context,
        )

    def _get_repr_options(self) -> dict[str, Any]:
        """Add ``distinct`` and ``filter`` to the repr options when set."""
        options = super()._get_repr_options()
        if self.distinct:
            options["distinct"] = self.distinct
        if self.filter:
            options["filter"] = self.filter
        return options
176
177
class Avg(NumericOutputFieldMixin, Aggregate):
    """SQL ``AVG`` aggregate; accepts ``distinct=True``."""

    name = "Avg"
    function = "AVG"
    allow_distinct = True
182
183
class Count(Aggregate):
    """SQL ``COUNT`` aggregate with an integer output field.

    The string ``"*"`` is shorthand for ``Star()`` (``COUNT(*)``). A star
    expression cannot be combined with ``filter``. Because
    ``empty_result_set_value`` is set, the base class rejects ``default``.
    """

    function = "COUNT"
    name = "Count"
    output_field = IntegerField()
    allow_distinct = True
    empty_result_set_value = 0

    def __init__(
        self, expression: Any, filter: Q | Expression | None = None, **extra: Any
    ) -> None:
        target = Star() if expression == "*" else expression
        # A filtered COUNT(*) is refused; the caller must name a field.
        if filter is not None and isinstance(target, Star):
            raise ValueError("Star cannot be used with filter. Please specify a field.")
        super().__init__(target, filter=filter, **extra)
199
200
class Max(Aggregate):
    """SQL ``MAX`` aggregate."""

    name = "Max"
    function = "MAX"
204
205
class Min(Aggregate):
    """SQL ``MIN`` aggregate."""

    name = "Min"
    function = "MIN"
209
210
class StdDev(NumericOutputFieldMixin, Aggregate):
    """Standard-deviation aggregate.

    Renders ``STDDEV_SAMP`` when ``sample`` is true, otherwise ``STDDEV_POP``.
    """

    name = "StdDev"

    def __init__(self, expression: Any, sample: bool = False, **extra: Any) -> None:
        # The SQL function is chosen per instance, before Func.__init__ runs.
        if sample:
            self.function = "STDDEV_SAMP"
        else:
            self.function = "STDDEV_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self) -> dict[str, Any]:
        # Copy the inherited options, then report which variant is in use.
        options = dict(super()._get_repr_options())
        options["sample"] = self.function == "STDDEV_SAMP"
        return options
220
221
class Sum(Aggregate):
    """SQL ``SUM`` aggregate; accepts ``distinct=True``."""

    name = "Sum"
    function = "SUM"
    allow_distinct = True
226
227
class Variance(NumericOutputFieldMixin, Aggregate):
    """Variance aggregate.

    Renders ``VAR_SAMP`` when ``sample`` is true, otherwise ``VAR_POP``.
    """

    name = "Variance"

    def __init__(self, expression: Any, sample: bool = False, **extra: Any) -> None:
        # The SQL function is chosen per instance, before Func.__init__ runs.
        if sample:
            self.function = "VAR_SAMP"
        else:
            self.function = "VAR_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self) -> dict[str, Any]:
        # Copy the inherited options, then report which variant is in use.
        options = dict(super()._get_repr_options())
        options["sample"] = self.function == "VAR_SAMP"
        return options