1from abc import ABC, abstractmethod
2from collections import defaultdict
3from typing import Any
4
5from plain.admin.dates import DatetimeRangeAliases
6from plain.postgres.aggregates import Count
7from plain.postgres.functions import (
8 TruncDate,
9 TruncMonth,
10)
11
12from .base import Card
13
14
class ChartCard(Card, ABC):
    """Abstract card that renders a chart via ``admin/cards/chart.html``.

    Concrete subclasses supply the chart payload by implementing
    ``get_chart_data``; it is exposed to the template as ``chart_data``.
    """

    template_name = "admin/cards/chart.html"

    def get_template_context(self) -> dict[str, Any]:
        """Return the parent card context extended with ``chart_data``."""
        context = super().get_template_context()
        context.update(chart_data=self.get_chart_data())
        return context

    @abstractmethod
    def get_chart_data(self) -> dict:
        """Return the chart configuration dict passed to the template."""
        ...
25
26
class TrendCard(ChartCard):
    """
    A card that renders a trend chart.
    Primarily intended for use with models, but it can also be customized.

    With ``model`` and ``datetime_field`` set, rows of ``model`` whose
    ``datetime_field`` falls within the currently selected date range are
    counted per day (or per month once the range spans
    ``monthly_threshold_days`` days or more) and rendered as a bar chart.
    Setting ``group_field`` splits the bars into stacked series, one per
    distinct value of that field. Override ``get_trend_data`` to chart data
    from any other source.
    """

    # Model whose rows are counted; required unless get_trend_data is overridden.
    model: Any = None
    # Name of the datetime field used both for range filtering and bucketing.
    datetime_field: str | None = None
    # Optional field whose values split the chart into stacked series.
    group_field: str | None = None
    # Optional display labels keyed by raw group value.
    group_labels: dict[Any, str] | None = None
    # Palette cycled through (by sorted group position) for groups that have
    # no explicit entry in group_colors.
    default_group_colors: list[dict[str, str]] = [
        {"bg": "rgba(85, 107, 68, 0.8)", "hover": "rgba(85, 107, 68, 1)"},  # sage
        {
            "bg": "rgba(74, 111, 165, 0.8)",
            "hover": "rgba(74, 111, 165, 1)",
        },  # slate blue
        {
            "bg": "rgba(176, 110, 70, 0.8)",
            "hover": "rgba(176, 110, 70, 1)",
        },  # terracotta
        {
            "bg": "rgba(82, 126, 126, 0.8)",
            "hover": "rgba(82, 126, 126, 1)",
        },  # dusty teal
        {
            "bg": "rgba(140, 100, 75, 0.8)",
            "hover": "rgba(140, 100, 75, 1)",
        },  # warm brown
        {
            "bg": "rgba(130, 100, 140, 0.8)",
            "hover": "rgba(130, 100, 140, 1)",
        },  # muted plum
    ]
    # Optional explicit {"bg": ..., "hover": ...} colors keyed by raw group value.
    group_colors: dict[Any, dict[str, str]] | None = None
    default_filter = DatetimeRangeAliases.SINCE_30_DAYS_AGO
    # Ranges spanning at least this many days are bucketed by month instead of
    # by day.
    monthly_threshold_days: int = 300

    filters = DatetimeRangeAliases

    def get_description(self) -> str:
        """Describe the selected range, e.g. "Jan 01, 2024 to Jan 31, 2024"."""
        datetime_range = DatetimeRangeAliases.to_range(self.get_current_filter())
        start = datetime_range.start.strftime("%b %d, %Y")
        end = datetime_range.end.strftime("%b %d, %Y")
        return f"{start} to {end}"

    def get_current_filter(self) -> str:
        """Return the active filter, falling back to ``default_filter``."""
        if s := super().get_current_filter():
            return s
        return self.default_filter.value

    def get_trend_data(self) -> dict[str, int] | dict[Any, dict[str, int]]:
        """Return trend data, optionally grouped by group_field.

        Without group_field: {date_str: count}
        With group_field: {group_value: {date_str: count}}

        Every bucket in the selected range is present (zero-filled) and date
        strings use the YYYY-MM-DD format. Rows whose group value is None or
        "" are collected under "Unknown"; other falsy values (0, False) are
        kept as their own groups.

        Raises:
            NotImplementedError: if ``model`` or ``datetime_field`` is unset.
        """
        if not self.model or not self.datetime_field:
            raise NotImplementedError(
                "model and datetime_field must be set, or get_trend_data must be overridden"
            )

        datetime_range = DatetimeRangeAliases.to_range(self.get_current_filter())
        filter_kwargs = {f"{self.datetime_field}__range": datetime_range.as_tuple()}

        # Short ranges get one bar per day; long ranges one bar per month.
        if datetime_range.total_days() < self.monthly_threshold_days:
            truncator = TruncDate
            iterator = datetime_range.iter_days
        else:
            truncator = TruncMonth
            iterator = datetime_range.iter_months

        value_fields = ["chart_date"]
        if self.group_field:
            value_fields.append(self.group_field)

        # Presumably (as in Django's ORM) values() followed by annotate()
        # aggregates the Count per selected (chart_date[, group]) combination.
        rows = (
            self.model.query.filter(**filter_kwargs)
            .annotate(chart_date=truncator(self.datetime_field))
            .values(*value_fields)
            .annotate(chart_date_count=Count("id"))
        )

        # NOTE(review): lookups below rely on the truncated DB values comparing
        # equal to the objects yielded by iter_days/iter_months — confirm the
        # types match, especially for TruncMonth.
        dates = list(iterator())

        if not self.group_field:
            date_values: defaultdict[Any, int] = defaultdict(int)
            for row in rows:
                date_values[row["chart_date"]] += row["chart_date_count"]
            return {date.strftime("%Y-%m-%d"): date_values[date] for date in dates}

        groups: defaultdict[Any, defaultdict[Any, int]] = defaultdict(
            lambda: defaultdict(int)
        )
        for row in rows:
            group = self._normalize_group_value(row[self.group_field])
            # Accumulate rather than assign: None and "" both map to the same
            # placeholder group, so one bucket can receive several rows.
            groups[group][row["chart_date"]] += row["chart_date_count"]

        return {
            group: {date.strftime("%Y-%m-%d"): counts[date] for date in dates}
            for group, counts in self._sorted_groups(groups)
        }

    @staticmethod
    def _normalize_group_value(value: Any) -> Any:
        """Map missing group values (None or "") to the "Unknown" placeholder.

        Other falsy values such as 0 or False are real groups and are kept.
        """
        if value is None or value == "":
            return "Unknown"
        return value

    @staticmethod
    def _sorted_groups(groups: dict[Any, Any]) -> list[tuple[Any, Any]]:
        """Return ``groups.items()`` ordered by group value.

        Falls back to ordering by each value's string form when the values are
        not mutually comparable (e.g. integers mixed with "Unknown").
        """
        try:
            return sorted(groups.items(), key=lambda item: item[0])
        except TypeError:
            return sorted(groups.items(), key=lambda item: str(item[0]))

    def get_chart_data(self) -> dict:
        """Build the chart config: stacked per-group bars or a single series."""
        data = self.get_trend_data()

        if self.group_field:
            return self._build_grouped_chart(data)

        return self._build_single_chart(data)

    def _build_single_chart(self, data: dict) -> dict:
        """Build a one-series bar chart from a {label: count} mapping."""
        return {
            "type": "bar",
            "data": {
                "labels": list(data.keys()),
                "datasets": [
                    {
                        "data": list(data.values()),
                        "backgroundColor": "rgba(168, 162, 158, 0.7)",  # stone-400
                        "hoverBackgroundColor": "rgba(120, 113, 108, 0.9)",  # stone-500
                        "borderRadius": {"topLeft": 2, "topRight": 2},
                        "borderSkipped": False,
                    },
                ],
            },
            **self._chart_options(show_legend=False, stacked=False),
        }

    def _build_grouped_chart(self, data: dict) -> dict:
        """Build a stacked bar chart from a {group: {label: count}} mapping.

        An empty mapping renders as an empty single-series chart. Labels are
        the union of all groups' labels (first-seen order); a group missing a
        label contributes 0 there so every dataset stays aligned with labels.
        """
        if not data:
            return self._build_single_chart({})

        # dict.fromkeys keeps first-seen order while de-duplicating labels.
        labels = list(
            dict.fromkeys(label for counts in data.values() for label in counts)
        )

        group_labels = self.group_labels or {}

        datasets = []
        for i, (raw_name, date_counts) in enumerate(data.items()):
            display_name = group_labels.get(raw_name, raw_name)
            if self.group_colors and raw_name in self.group_colors:
                colors = self.group_colors[raw_name]
            else:
                colors = self.default_group_colors[i % len(self.default_group_colors)]
            datasets.append(
                {
                    "label": str(display_name),
                    "data": [date_counts.get(label, 0) for label in labels],
                    "backgroundColor": colors["bg"],
                    "hoverBackgroundColor": colors["hover"],
                }
            )

        return {
            "type": "bar",
            "data": {
                "labels": labels,
                "datasets": datasets,
            },
            **self._chart_options(show_legend=True, stacked=True),
        }

    def _chart_options(self, *, show_legend: bool, stacked: bool) -> dict:
        """Return the shared ``{"options": ...}`` chart configuration.

        Args:
            show_legend: show the legend and tooltip color boxes.
            stacked: stack bars on both the x and y scales.
        """
        legend = (
            {
                "display": True,
                "position": "top",
                "align": "end",
                "labels": {
                    "boxWidth": 12,
                    "boxHeight": 12,
                    "borderRadius": 2,
                    "useBorderRadius": True,
                    "padding": 16,
                    "font": {"size": 11},
                },
            }
            if show_legend
            else {"display": False}
        )

        return {
            "options": {
                "responsive": True,
                "maintainAspectRatio": False,
                "animation": {
                    "duration": 600,
                    "easing": "easeOutQuart",
                },
                "plugins": {
                    "legend": legend,
                    "tooltip": {
                        "enabled": True,
                        "backgroundColor": "rgba(41, 37, 36, 0.95)",  # stone-800
                        "titleColor": "rgba(255, 255, 255, 0.7)",
                        "bodyColor": "#ffffff",
                        "bodyFont": {"size": 13, "weight": "600"},
                        "titleFont": {"size": 11},
                        "padding": {"x": 12, "y": 8},
                        "cornerRadius": 6,
                        "displayColors": show_legend,
                    },
                },
                "scales": {
                    "x": {
                        "display": False,
                        "grid": {"display": False},
                        "stacked": stacked,
                    },
                    "y": {
                        "beginAtZero": True,
                        "display": True,
                        "position": "right",
                        "grid": {
                            "display": True,
                            "color": "rgba(0, 0, 0, 0.04)",
                            "drawTicks": False,
                        },
                        "border": {"display": False},
                        "ticks": {
                            "display": False,
                            "maxTicksLimit": 4,
                        },
                        "stacked": stacked,
                    },
                },
                "layout": {
                    "padding": {"top": 4, "bottom": 0, "left": 0, "right": 0},
                },
            },
        }