1from abc import ABC, abstractmethod
2from collections import defaultdict
3from typing import Any, Literal
4
5from plain.admin.dates import DatetimeRangeAliases
6from plain.postgres.aggregates import Count
7from plain.postgres.functions import (
8 TruncDate,
9 TruncMonth,
10)
11
12from .base import Card
13
14
class ChartCard(Card, ABC):
    """Abstract card that renders a chart.

    Concrete subclasses supply the chart payload through ``get_chart_data``.
    That payload is exposed to the ``admin/cards/chart.html`` template under
    the ``chart_data`` context key.
    """

    template_name = "admin/cards/chart.html"

    def get_template_context(self) -> dict[str, Any]:
        """Return the parent card context extended with ``chart_data``."""
        ctx = super().get_template_context()
        chart_payload = self.get_chart_data()
        ctx["chart_data"] = chart_payload
        return ctx

    @abstractmethod
    def get_chart_data(self) -> dict:
        """Return the chart payload placed in the template context as ``chart_data``."""
25
26
class TrendCard(ChartCard):
    """
    A card that renders a trend chart.
    Primarily intended for use with models, but it can also be customized.

    When ``model`` and ``datetime_field`` are set, the card counts the rows
    whose ``datetime_field`` falls in the currently selected date range. It
    buckets them per day, or per month for ranges of 300 days or more.

    Setting ``group_field`` splits each bucket into one stacked bar series per
    distinct value of that field. Override ``get_trend_data`` to chart data
    from another source.
    """

    # Model whose rows are counted; required unless get_trend_data is overridden.
    model = None
    # Name of the datetime field used both for range filtering and bucketing.
    datetime_field = None
    # Optional field whose values split the counts into separate series.
    group_field: str | None = None
    # Optional display labels keyed by the stringified group value.
    group_labels: dict[str, str] | None = None
    # CSS color values resolved by charts.js. `var(--chart-N)` reads the
    # admin's chart palette so charts retheme automatically (incl. dark mode).
    default_group_colors: list[str] = [
        "var(--chart-1)",
        "var(--chart-2)",
        "var(--chart-3)",
        "var(--chart-4)",
        "var(--chart-5)",
    ]
    # Optional per-group colors keyed by the stringified group value. Groups
    # not listed fall back to default_group_colors, cycled by position.
    group_colors: dict[str, str] | None = None
    # Copied into the chart payload's "plain" metadata (see _plain_meta).
    aggregates: tuple[Literal["sum", "avg", "max"], ...] = ("sum",)
    default_filter = DatetimeRangeAliases.SINCE_30_DAYS_AGO

    filters = DatetimeRangeAliases

    def get_current_filter(self) -> str:
        """Return the active filter, or ``default_filter``'s value if none is set."""
        if s := super().get_current_filter():
            return s
        return self.default_filter.value

    def get_trend_data(self) -> dict[str, int] | dict[str, dict[str, int]]:
        """Return trend data, optionally grouped by group_field.

        Without group_field: {date_str: count}
        With group_field: {group_label: {date_str: count}}

        Date keys are formatted as ``YYYY-MM-DD``. They cover every bucket in
        the current range, and buckets without matching rows count as 0.

        Group keys are ``str()`` of the raw field value, or ``"Unknown"`` for
        None. Groups are returned in sorted key order. Raw values that map to
        the same key have their counts summed.

        Raises:
            NotImplementedError: if ``model`` or ``datetime_field`` is unset.
        """
        if not self.model or not self.datetime_field:
            raise NotImplementedError(
                "model and datetime_field must be set, or get_trend_data must be overridden"
            )

        datetime_range = DatetimeRangeAliases.to_range(self.get_current_filter())
        filter_kwargs = {f"{self.datetime_field}__range": datetime_range.as_tuple()}

        # Daily buckets for shorter ranges; switch to monthly buckets once the
        # range reaches 300 days.
        if datetime_range.total_days() < 300:
            truncator = TruncDate
            iterator = datetime_range.iter_days
        else:
            truncator = TruncMonth
            iterator = datetime_range.iter_months

        value_fields = ["chart_date"]
        if self.group_field:
            value_fields.append(self.group_field)

        # One row per (bucket[, group value]) carrying the matching-row count.
        rows = (
            self.model.query.filter(**filter_kwargs)
            .annotate(chart_date=truncator(self.datetime_field))
            .values(*value_fields)
            .annotate(chart_date_count=Count("id"))
        )

        # Every bucket in the range, so empty buckets still show up as 0.
        # NOTE(review): the lookups below assume these values compare equal to
        # the truncated ``chart_date`` values returned by the query. Confirm
        # this holds for the TruncMonth / iter_months pairing.
        dates = list(iterator())

        if not self.group_field:
            date_values: defaultdict[Any, int] = defaultdict(int)
            for row in rows:
                # Sum rather than assign so a repeated bucket row is never dropped.
                date_values[row["chart_date"]] += row["chart_date_count"]
            return self._format_counts(date_values, dates)

        groups: defaultdict[str, defaultdict[Any, int]] = defaultdict(
            lambda: defaultdict(int)
        )
        for row in rows:
            raw = row[self.group_field]
            group_key = "Unknown" if raw is None else str(raw)
            # Distinct raw values can share a key (e.g. None and the string
            # "Unknown"). Assigning here would silently keep only the last
            # row's count, so accumulate instead.
            groups[group_key][row["chart_date"]] += row["chart_date_count"]

        return {
            group: self._format_counts(counts, dates)
            for group, counts in sorted(groups.items())
        }

    @staticmethod
    def _format_counts(counts: dict[Any, int], dates: list[Any]) -> dict[str, int]:
        """Map each bucket in ``dates`` (as ``YYYY-MM-DD``) to its count, defaulting to 0."""
        return {date.strftime("%Y-%m-%d"): counts.get(date, 0) for date in dates}

    def get_chart_data(self) -> dict:
        """Build the chart payload: stacked bars when grouped, a single series otherwise."""
        data = self.get_trend_data()

        if self.group_field:
            return self._build_grouped_chart(data)

        return self._build_single_chart(data)

    def _build_single_chart(self, data: dict) -> dict:
        """Build a one-series bar chart from ``{label: value}`` data, labelled with the card title."""
        return {
            "type": "bar",
            "data": {
                "labels": list(data.keys()),
                "datasets": [
                    {
                        "label": self.title,
                        "data": list(data.values()),
                        "backgroundColor": "var(--chart-1)",
                        "borderRadius": {"topLeft": 2, "topRight": 2},
                        "borderSkipped": False,
                        "categoryPercentage": 0.9,
                        "barPercentage": 1.0,
                    },
                ],
            },
            **self._chart_options(stacked=False),
            "plain": self._plain_meta(),
        }

    def _build_grouped_chart(self, data: dict) -> dict:
        """Build a stacked bar chart with one dataset per group.

        ``data`` is ``{group_key: {label: value}}``. When it is empty, this
        falls back to an empty single-series chart.
        """
        if not data:
            return self._build_single_chart({})

        # Assumes every group has the same labels in the same order, as
        # get_trend_data produces; the first group's keys become the x labels.
        labels = list(next(iter(data.values())).keys())

        group_labels = self.group_labels or {}

        datasets = []
        for i, (raw_name, date_counts) in enumerate(data.items()):
            display_name = group_labels.get(raw_name, raw_name)
            if self.group_colors and raw_name in self.group_colors:
                color = self.group_colors[raw_name]
            else:
                # Cycle through the default palette by group position.
                color = self.default_group_colors[i % len(self.default_group_colors)]
            datasets.append(
                {
                    "label": str(display_name),
                    "data": list(date_counts.values()),
                    "backgroundColor": color,
                    "categoryPercentage": 0.9,
                    "barPercentage": 1.0,
                }
            )

        return {
            "type": "bar",
            "data": {
                "labels": labels,
                "datasets": datasets,
            },
            **self._chart_options(stacked=True),
            "plain": self._plain_meta(),
        }

    def _plain_meta(self) -> dict:
        """Return extra, non-Chart.js metadata stored under the payload's "plain" key."""
        return {
            "aggregates": list(self.aggregates),
        }

    def _chart_options(self, *, stacked: bool) -> dict:
        """Return the shared ``options`` section.

        Axes, legend and tooltip are hidden. ``stacked`` toggles stacking on
        both axes.
        """
        return {
            "options": {
                "responsive": True,
                "maintainAspectRatio": False,
                "animation": {
                    "duration": 600,
                    "easing": "easeOutQuart",
                },
                "interaction": {
                    "mode": "index",
                    "intersect": False,
                    "axis": "x",
                },
                "plugins": {
                    "legend": {"display": False},
                    "tooltip": {"enabled": False},
                },
                "scales": {
                    "x": {
                        "display": False,
                        "grid": {"display": False},
                        "stacked": stacked,
                    },
                    "y": {
                        "beginAtZero": True,
                        "display": False,
                        "stacked": stacked,
                    },
                },
                "layout": {
                    "padding": {"top": 4, "bottom": 0, "left": 0, "right": 0},
                },
            },
        }