1from __future__ import annotations
2
3from typing import TYPE_CHECKING, Any
4
5from plain.models.exceptions import FieldError, FullResultSet
6from plain.models.expressions import Case, Func, Star, Value, When
7from plain.models.fields import IntegerField
8from plain.models.functions.comparison import Coalesce
9from plain.models.functions.mixins import (
10 FixDurationInputMixin,
11 NumericOutputFieldMixin,
12)
13
14if TYPE_CHECKING:
15 from plain.models.backends.base.base import BaseDatabaseWrapper
16 from plain.models.expressions import Expression
17 from plain.models.query_utils import Q
18 from plain.models.sql.compiler import SQLCompiler
19
# Public names exported by ``from <this module> import *``: the aggregate
# base class and the concrete aggregates defined below.
__all__ = [
    "Aggregate",
    "Avg",
    "Count",
    "Max",
    "Min",
    "StdDev",
    "Sum",
    "Variance",
]
30
31
class Aggregate(Func):
    """Base class for SQL aggregate functions (``COUNT``, ``SUM``, ...).

    Subclasses set ``function`` (the SQL function name) and ``name`` (used in
    error messages and in the default alias), and may opt into ``DISTINCT``
    with ``allow_distinct = True``.

    Keyword arguments accepted by ``__init__``:
        distinct: Render ``DISTINCT`` inside the call. Raises ``TypeError``
            unless the subclass sets ``allow_distinct``.
        filter: A ``Q`` object or expression restricting the aggregated rows.
            Rendered as ``FILTER (WHERE ...)`` when the backend supports it,
            otherwise emulated by wrapping the first argument in
            ``CASE WHEN``.
        default: Value substituted when the aggregate evaluates to ``NULL``;
            implemented by wrapping the resolved aggregate in ``Coalesce``.
            Raises ``TypeError`` for subclasses whose
            ``empty_result_set_value`` is not ``None``.
    """

    template = "%(function)s(%(distinct)s%(expressions)s)"
    contains_aggregate = True
    name = None
    # ``%s`` receives the base template; ``%%(filter)s`` survives that first
    # substitution as the ``%(filter)s`` placeholder filled in by as_sql().
    filter_template = "%s FILTER (WHERE %%(filter)s)"
    window_compatible = True
    allow_distinct = False
    # A non-None value forbids ``default=`` (see __init__). Presumably this is
    # the value reported for an empty result set (Count uses 0) — confirm in Func.
    empty_result_set_value = None

    def __init__(
        self,
        *expressions: Any,
        distinct: bool = False,
        filter: Q | Expression | None = None,
        default: Any = None,
        **extra: Any,
    ) -> None:
        if distinct and not self.allow_distinct:
            raise TypeError(f"{self.__class__.__name__} does not allow distinct.")
        if default is not None and self.empty_result_set_value is not None:
            raise TypeError(f"{self.__class__.__name__} does not allow default.")
        self.distinct = distinct
        self.filter = filter
        self.default = default
        super().__init__(*expressions, **extra)

    def get_source_fields(self) -> list[Any]:
        """Return the output fields of the aggregate's arguments."""
        # Don't return the filter expression since it's not a source field.
        return [e._output_field_or_none for e in super().get_source_expressions()]

    def get_source_expressions(self) -> list[Expression]:
        """Return the arguments, followed by the filter when one is set.

        Including the filter lets generic traversals of source expressions
        see it; set_source_expressions() pops it back off the end.
        """
        source_expressions = super().get_source_expressions()
        if self.filter:
            return source_expressions + [self.filter]
        return source_expressions

    def set_source_expressions(self, exprs: list[Expression]) -> list[Expression]:
        """Inverse of get_source_expressions(); note: mutates ``exprs``."""
        # When a filter is set it is the last item (see get_source_expressions).
        self.filter = self.filter and exprs.pop()
        return super().set_source_expressions(exprs)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Expression:
        """Resolve the aggregate and its filter against ``query``.

        Raises:
            FieldError: outside summarize mode, if an argument already
                contains an aggregate (aggregates cannot be nested).

        Returns the resolved copy, or a ``Coalesce(copy, default)`` wrapper
        when ``default`` was given.
        """
        # Aggregates are not allowed in UPDATE queries, so ignore for_save
        c = super().resolve_expression(query, allow_joins, reuse, summarize)
        c.filter = c.filter and c.filter.resolve_expression(
            query, allow_joins, reuse, summarize
        )
        if not summarize:
            # Call Aggregate.get_source_expressions() to avoid
            # returning self.filter and including that in this loop.
            expressions = super(Aggregate, c).get_source_expressions()
            for index, expr in enumerate(expressions):
                if expr.contains_aggregate:
                    # Report the unresolved argument: it still carries the
                    # user-facing name.
                    before_resolved = self.get_source_expressions()[index]
                    name = (
                        before_resolved.name
                        if hasattr(before_resolved, "name")
                        else repr(before_resolved)
                    )
                    raise FieldError(
                        f"Cannot compute {c.name}('{name}'): '{name}' is an aggregate"
                    )
        if (default := c.default) is None:
            return c
        if hasattr(default, "resolve_expression"):
            default = default.resolve_expression(query, allow_joins, reuse, summarize)
            if default._output_field_or_none is None:
                default.output_field = c._output_field_or_none
        else:
            default = Value(default, c._output_field_or_none)
        c.default = None  # Reset the default argument before wrapping.
        coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
        coalesce.is_summary = c.is_summary
        return coalesce

    @property
    def default_alias(self) -> str:
        """Alias such as ``"price__sum"`` for a single named argument.

        The filter counts as a source expression here, so a filtered
        aggregate always needs an explicit alias.

        Raises:
            TypeError: if there is not exactly one named source expression.
        """
        expressions = self.get_source_expressions()
        if len(expressions) == 1 and hasattr(expressions[0], "name"):
            return f"{expressions[0].name}__{self.name.lower()}"
        raise TypeError("Complex expressions require an alias")

    def get_group_by_cols(self) -> list[Any]:
        """Aggregates contribute no GROUP BY columns."""
        return []

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile the aggregate, including its filter, to ``(sql, params)``.

        With a filter on a backend supporting ``FILTER (WHERE ...)``, the
        (possibly caller-supplied) template is wrapped in ``filter_template``
        and the filter's params are appended after the aggregate's own.
        Backends without that support get the filter emulated as
        ``CASE WHEN <filter> THEN <first argument>``.
        """
        extra_context["distinct"] = "DISTINCT " if self.distinct else ""
        if self.filter:
            if connection.features.supports_aggregate_filter_clause:
                try:
                    filter_sql, filter_params = self.filter.as_sql(compiler, connection)
                except FullResultSet:
                    # Presumably the filter restricts nothing; fall through
                    # and render the aggregate without a FILTER clause.
                    pass
                else:
                    # Pop (rather than get) any caller-supplied template so it
                    # is not passed to Func.as_sql() twice — once wrapped via
                    # ``template=`` and again through ``**extra_context`` —
                    # which would raise TypeError.
                    template = self.filter_template % extra_context.pop(
                        "template", self.template
                    )
                    sql, params = super().as_sql(
                        compiler,
                        connection,
                        template=template,
                        filter=filter_sql,
                        **extra_context,
                    )
                    return sql, (*params, *filter_params)
            else:
                # Emulate FILTER on a filter-less copy: wrap the first argument
                # in CASE WHEN <filter> THEN <arg>. No default is given, so
                # non-matching rows contribute NULL.
                copy = self.copy()
                copy.filter = None
                source_expressions = copy.get_source_expressions()
                condition = When(self.filter, then=source_expressions[0])
                copy.set_source_expressions([Case(condition)] + source_expressions[1:])
                return super(Aggregate, copy).as_sql(
                    compiler, connection, **extra_context
                )
        return super().as_sql(compiler, connection, **extra_context)

    def _get_repr_options(self) -> dict[str, Any]:
        """Add ``distinct``/``filter`` to the repr when they are set."""
        options = super()._get_repr_options()
        if self.distinct:
            options["distinct"] = self.distinct
        if self.filter:
            options["filter"] = self.filter
        return options
166
167
class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
    """SQL ``AVG`` aggregate; accepts ``distinct=True``."""

    allow_distinct = True
    name = "Avg"
    function = "AVG"
172
173
class Count(Aggregate):
    """SQL ``COUNT`` aggregate with an ``IntegerField`` result.

    The literal ``"*"`` is converted to ``Star()``; a ``Star`` cannot be
    combined with ``filter``. Because ``empty_result_set_value`` is 0,
    ``default=`` is rejected by ``Aggregate.__init__``.
    """

    name = "Count"
    function = "COUNT"
    allow_distinct = True
    output_field = IntegerField()
    empty_result_set_value = 0

    def __init__(
        self, expression: Any, filter: Q | Expression | None = None, **extra: Any
    ) -> None:
        target = Star() if expression == "*" else expression
        # A filtered count must name a concrete field to test.
        if filter is not None and isinstance(target, Star):
            raise ValueError("Star cannot be used with filter. Please specify a field.")
        super().__init__(target, filter=filter, **extra)
189
190
class Max(Aggregate):
    """SQL ``MAX`` aggregate."""

    name = "Max"
    function = "MAX"
194
195
class Min(Aggregate):
    """SQL ``MIN`` aggregate."""

    name = "Min"
    function = "MIN"
199
200
class StdDev(NumericOutputFieldMixin, Aggregate):
    """Standard-deviation aggregate.

    A truthy ``sample`` selects ``STDDEV_SAMP``; otherwise ``STDDEV_POP``
    is used. The chosen form is shown as ``sample`` in the repr.
    """

    name = "StdDev"

    def __init__(self, expression: Any, sample: bool = False, **extra: Any) -> None:
        if sample:
            self.function = "STDDEV_SAMP"
        else:
            self.function = "STDDEV_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self) -> dict[str, Any]:
        options = dict(super()._get_repr_options())
        options["sample"] = self.function == "STDDEV_SAMP"
        return options
210
211
class Sum(FixDurationInputMixin, Aggregate):
    """SQL ``SUM`` aggregate; accepts ``distinct=True``."""

    allow_distinct = True
    name = "Sum"
    function = "SUM"
216
217
class Variance(NumericOutputFieldMixin, Aggregate):
    """Variance aggregate.

    A truthy ``sample`` selects ``VAR_SAMP``; otherwise ``VAR_POP`` is used.
    The chosen form is shown as ``sample`` in the repr.
    """

    name = "Variance"

    def __init__(self, expression: Any, sample: bool = False, **extra: Any) -> None:
        if sample:
            self.function = "VAR_SAMP"
        else:
            self.function = "VAR_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self) -> dict[str, Any]:
        options = dict(super()._get_repr_options())
        options["sample"] = self.function == "VAR_SAMP"
        return options