1from abc import ABC, abstractmethod
2from collections import defaultdict
3from typing import Any
4
5from plain.admin.dates import DatetimeRangeAliases
6from plain.models.aggregates import Count
7from plain.models.functions import (
8 TruncDate,
9 TruncMonth,
10)
11
12from .base import Card
13
14
class ChartCard(Card, ABC):
    """Abstract card that renders a chart.

    Concrete subclasses build the chart configuration in ``get_chart_data``;
    that dict is handed to the template under the ``chart_data`` key.
    """

    template_name = "admin/cards/chart.html"

    def get_template_context(self) -> dict[str, Any]:
        ctx = super().get_template_context()
        # Expose the subclass-provided chart configuration to the template.
        ctx.update(chart_data=self.get_chart_data())
        return ctx

    @abstractmethod
    def get_chart_data(self) -> dict:
        """Return the chart configuration dict rendered by the template."""
25
26
class TrendCard(ChartCard):
    """
    A card that renders a trend chart.
    Primarily intended for use with models, but it can also be customized.

    By default, rows of ``model`` whose ``datetime_field`` falls inside the
    selected preset range are counted per day (or per month for long ranges)
    and rendered as a bar chart. Override ``get_trend_data`` to chart
    something other than a model row count.
    """

    # Model whose rows are counted; required unless get_trend_data is overridden.
    model = None
    # Name of the datetime field on ``model`` used to filter and bucket rows.
    datetime_field = None
    default_preset = DatetimeRangeAliases.SINCE_30_DAYS_AGO

    presets = DatetimeRangeAliases

    # Ranges spanning at least this many days are bucketed by month instead
    # of by day, keeping the number of bars manageable.
    monthly_threshold_days = 300

    def get_description(self) -> str:
        """Describe the selected range, e.g. "Jan 01, 2024 to Jan 31, 2024"."""
        datetime_range = DatetimeRangeAliases.to_range(self.get_current_preset())
        start = datetime_range.start.strftime("%b %d, %Y")
        end = datetime_range.end.strftime("%b %d, %Y")
        return f"{start} to {end}"

    def get_current_preset(self) -> str:
        """Return the selected preset, falling back to ``default_preset``."""
        if s := super().get_current_preset():
            return s
        return self.default_preset.value

    def get_trend_data(self) -> dict[str, int]:
        """Return an ordered ``{"YYYY-MM-DD": count}`` mapping for the range.

        One key is produced per day, or per month when the range spans at
        least ``monthly_threshold_days`` days. Buckets with no matching rows
        are included with a count of 0.

        Raises:
            NotImplementedError: if ``model`` or ``datetime_field`` is unset
                and this method has not been overridden.
        """
        if not self.model or not self.datetime_field:
            raise NotImplementedError(
                "model and datetime_field must be set, "
                "or get_trend_data must be overridden"
            )

        datetime_range = DatetimeRangeAliases.to_range(self.get_current_preset())

        filter_kwargs = {f"{self.datetime_field}__range": datetime_range.as_tuple()}

        # The truncation used in SQL and the Python-side iterator must agree,
        # otherwise the zero-filling lookup below would never match a row.
        if datetime_range.total_days() < self.monthly_threshold_days:
            truncator = TruncDate
            iterator = datetime_range.iter_days
        else:
            truncator = TruncMonth
            iterator = datetime_range.iter_months

        counts_by_date = (
            self.model.query.filter(**filter_kwargs)
            .annotate(chart_date=truncator(self.datetime_field))
            .values("chart_date")
            .annotate(chart_date_count=Count("id"))
        )

        # Missing buckets read as 0, which zero-fills gaps in the series.
        date_values = defaultdict(int)

        for row in counts_by_date:
            date_values[row["chart_date"]] = row["chart_date_count"]

        return {
            bucket.strftime("%Y-%m-%d"): date_values[bucket] for bucket in iterator()
        }

    def get_chart_data(self) -> dict:
        """Build the bar-chart configuration from ``get_trend_data``."""
        data = self.get_trend_data()
        trend_labels = list(data.keys())
        trend_data = list(data.values())

        return {
            "type": "bar",
            "data": {
                "labels": trend_labels,
                "datasets": [
                    {
                        "data": trend_data,
                        # Gradient will be applied via JS - this is the fallback
                        "backgroundColor": "rgba(168, 162, 158, 0.7)",  # stone-400
                        "hoverBackgroundColor": "rgba(120, 113, 108, 0.9)",  # stone-500
                        "borderRadius": {"topLeft": 4, "topRight": 4},
                        "borderSkipped": False,
                    },
                ],
            },
            "options": {
                "responsive": True,
                "maintainAspectRatio": False,
                "animation": {
                    "duration": 600,
                    "easing": "easeOutQuart",
                },
                "plugins": {
                    "legend": {"display": False},
                    "tooltip": {
                        "enabled": True,
                        "backgroundColor": "rgba(41, 37, 36, 0.95)",  # stone-800
                        "titleColor": "rgba(255, 255, 255, 0.7)",
                        "bodyColor": "#ffffff",
                        "bodyFont": {"size": 13, "weight": "600"},
                        "titleFont": {"size": 11},
                        "padding": {"x": 12, "y": 8},
                        "cornerRadius": 6,
                        "displayColors": False,
                    },
                },
                "scales": {
                    "x": {
                        "display": False,
                        "grid": {"display": False},
                    },
                    "y": {
                        "beginAtZero": True,
                        "display": True,
                        "position": "right",
                        "grid": {
                            "display": True,
                            "color": "rgba(0, 0, 0, 0.04)",
                            "drawTicks": False,
                        },
                        "border": {"display": False},
                        "ticks": {
                            "display": False,
                            "maxTicksLimit": 4,
                        },
                    },
                },
                "layout": {
                    "padding": {"top": 4, "bottom": 0, "left": 0, "right": 0},
                },
            },
        }