1from collections import defaultdict
2
3from plain.admin.dates import DatetimeRangeAliases
4from plain.models import Count
5from plain.models.functions import (
6 TruncDate,
7 TruncMonth,
8)
9
10from .base import Card
11
12
class ChartCard(Card):
    """
    Base card that renders a chart.

    Subclasses provide the chart payload by implementing ``get_chart_data``;
    the payload is exposed to the template as ``chart_data``.
    """

    template_name = "admin/cards/chart.html"

    def get_template_context(self):
        """Extend the parent template context with the ``chart_data`` payload."""
        ctx = super().get_template_context()
        payload = self.get_chart_data()
        ctx["chart_data"] = payload
        return ctx

    def get_chart_data(self) -> dict:
        """Return the chart payload. Must be implemented by subclasses."""
        raise NotImplementedError
23
24
class TrendCard(ChartCard):
    """
    A card that renders a trend chart.

    Primarily intended for use with models: set ``model`` and
    ``datetime_field`` and the card counts the rows that fall into each day
    (or each month, when the selected range spans 300 days or more) and
    plots them as bars with a linear trend line on top. To chart other data,
    override ``get_trend_data`` to return an ordered ``{label: value}``
    mapping.
    """

    # Model whose rows are counted. Required unless get_trend_data is overridden.
    model = None
    # Name of the model's datetime field used to filter and bucket rows.
    datetime_field = None
    # Range alias used when no display has been selected.
    default_display = DatetimeRangeAliases.SINCE_30_DAYS_AGO

    # The set of range aliases this card can display.
    displays = DatetimeRangeAliases

    def get_description(self):
        """Return a "<start> to <end>" string for the currently selected range."""
        datetime_range = DatetimeRangeAliases.to_range(self.get_current_display())
        return f"{datetime_range.start} to {datetime_range.end}"

    def get_current_display(self):
        """
        Return the selected range as a ``DatetimeRangeAliases`` member.

        Falls back to ``default_display`` when the base card reports no
        selection (a falsy value).
        """
        if s := super().get_current_display():
            return DatetimeRangeAliases.from_value(s)
        return self.default_display

    def get_trend_data(self) -> dict[str, int | float]:
        """
        Return row counts per time bucket for the current datetime range.

        Rows of ``model`` whose ``datetime_field`` lies within the range are
        grouped per day, or per month when the range spans 300 days or more.
        Every bucket yielded by the range's iterator appears in the result;
        buckets with no rows get 0.

        :return: ``{"YYYY-MM-DD": count}`` in the range iterator's order.
        :raises NotImplementedError: if ``model`` or ``datetime_field`` is
            unset and this method has not been overridden.
        """
        if not self.model or not self.datetime_field:
            raise NotImplementedError(
                "model and datetime_field must be set, or get_trend_data must be overridden"
            )

        datetime_range = DatetimeRangeAliases.to_range(self.get_current_display())

        filter_kwargs = {f"{self.datetime_field}__range": datetime_range.as_tuple()}

        # Daily bars for shorter ranges; switch to monthly bars for long ones.
        if datetime_range.total_days() < 300:
            truncator = TruncDate
            iterator = datetime_range.iter_days
        else:
            truncator = TruncMonth
            iterator = datetime_range.iter_months

        # GROUP BY the truncated date and count rows in each group.
        counts_by_date = (
            self.model.objects.filter(**filter_kwargs)
            .annotate(chart_date=truncator(self.datetime_field))
            .values("chart_date")
            .annotate(chart_date_count=Count("id"))
        )

        # Buckets missing from the query result read back as 0.
        date_values = defaultdict(int)
        for row in counts_by_date:
            date_values[row["chart_date"]] = row["chart_date_count"]

        # NOTE(review): zero-filling relies on the truncated DB values being
        # equal to (and hashing like) the values yielded by iter_days /
        # iter_months — confirm for TruncMonth on datetime fields.
        return {
            bucket.strftime("%Y-%m-%d"): date_values[bucket] for bucket in iterator()
        }

    @staticmethod
    def _calculate_trend_line(values):
        """
        Fit ``y = slope * x + intercept`` to ``values`` by least squares.

        ``x`` is the point's index (0, 1, 2, ...).

        :param values: Numeric y-values in chart order.
        :return: The fitted line evaluated at each index (same length as
            ``values``); a copy of ``values`` when there are fewer than two
            points; ``[]`` when the fitted line is zero everywhere.
        """
        n = len(values)
        if n < 2:
            # Not enough points to fit a line; plot the data as-is (copied so
            # the bar and line datasets don't share one list object).
            return list(values)

        x_mean = (n - 1) / 2  # mean of the indices 0..n-1
        y_mean = sum(values) / n

        numerator = sum((x - x_mean) * (y - y_mean) for x, y in enumerate(values))
        denominator = sum((x - x_mean) ** 2 for x in range(n))
        slope = numerator / denominator if denominator != 0 else 0
        intercept = y_mean - slope * x_mean

        trend = [slope * x + intercept for x in range(n)]

        # An all-zero line adds nothing to the chart, so omit it entirely.
        if all(v == 0 for v in trend):
            return []
        return trend

    def get_chart_data(self) -> dict:
        """
        Build the chart configuration for this card.

        The structure follows Chart.js conventions. It has a bar dataset of
        ``get_trend_data()`` values, overlaid with a line dataset holding
        their linear trend.
        """
        data = self.get_trend_data()
        trend_labels = list(data.keys())
        trend_data = list(data.values())

        return {
            "type": "bar",
            "data": {
                "labels": trend_labels,
                "datasets": [
                    {
                        "data": trend_data,
                    },
                    {
                        # Trend line drawn over the bars.
                        "data": self._calculate_trend_line(trend_data),
                        "type": "line",
                        "borderColor": "rgba(0, 0, 0, 0.3)",
                        "borderWidth": 2,
                        "fill": False,
                        "pointRadius": 0,  # draw the line without point markers
                    },
                ],
            },
            "options": {
                # Hide the legend.
                "plugins": {"legend": {"display": False}},
                "scales": {
                    # Hide the x-axis.
                    "x": {
                        "display": False,
                    },
                    # Include 0 on the y-axis unless the data goes below it.
                    "y": {
                        "suggestedMin": 0,
                    },
                },
                "maintainAspectRatio": False,
                "elements": {
                    "bar": {"borderRadius": "3", "backgroundColor": "rgb(28, 25, 23)"}
                },
            },
        }