import datetime
from collections import defaultdict
from typing import Any

from plain.admin.dates import DatetimeRangeAliases
from plain.models import Count
from plain.models.functions import (
    TruncDate,
    TruncMonth,
)
10
11from .base import Card
12
13
class ChartCard(Card):
    """
    Base card for chart widgets.

    Subclasses implement ``get_chart_data`` to return the chart definition,
    which is handed to the template under the ``chart_data`` key.
    """

    template_name = "admin/cards/chart.html"

    def get_template_context(self) -> dict[str, Any]:
        """Return the parent card context extended with ``chart_data``."""
        ctx = super().get_template_context()
        ctx["chart_data"] = self.get_chart_data()
        return ctx

    def get_chart_data(self) -> dict:
        """Return the chart definition; every concrete subclass must override this."""
        raise NotImplementedError
24
25
class TrendCard(ChartCard):
    """
    A card that renders a trend chart.

    Primarily intended for use with models: set ``model`` and ``datetime_field``
    and the card counts matching rows per day (or per month for long ranges)
    within the selected date range. Override ``get_trend_data`` to chart
    arbitrary data instead.
    """

    # Model whose rows are counted; required unless get_trend_data is overridden.
    model = None
    # Name of the field used for range filtering and date bucketing.
    datetime_field: str | None = None
    default_display = DatetimeRangeAliases.SINCE_30_DAYS_AGO
    # Ranges spanning at least this many days are bucketed by month, not by day.
    monthly_threshold_days = 300

    displays = DatetimeRangeAliases

    def get_description(self) -> str:
        """Describe the currently selected range as "<start> to <end>"."""
        datetime_range = DatetimeRangeAliases.to_range(self.get_current_display())
        return f"{datetime_range.start} to {datetime_range.end}"

    def get_current_display(self) -> DatetimeRangeAliases:
        """Return the selected range alias, falling back to ``default_display``."""
        if s := super().get_current_display():
            return DatetimeRangeAliases.from_value(s)
        return self.default_display

    def get_trend_data(self) -> dict[str, int]:
        """
        Count ``model`` rows per bucket across the selected range.

        :return: Mapping of "YYYY-MM-DD" bucket label to row count, with one
            entry per bucket produced by the range iterator (missing buckets
            are zero-filled), in iterator order.
        :raises NotImplementedError: if ``model`` or ``datetime_field`` is unset.
        """
        if not self.model or not self.datetime_field:
            raise NotImplementedError(
                "model and datetime_field must be set, or get_trend_data must be overridden"
            )

        datetime_range = DatetimeRangeAliases.to_range(self.get_current_display())

        filter_kwargs = {f"{self.datetime_field}__range": datetime_range.as_tuple()}

        if datetime_range.total_days() < self.monthly_threshold_days:
            truncator = TruncDate
            iterator = datetime_range.iter_days
        else:
            truncator = TruncMonth
            iterator = datetime_range.iter_months

        counts_by_date = (
            self.model.query.filter(**filter_kwargs)
            .annotate(chart_date=truncator(self.datetime_field))
            .values("chart_date")
            .annotate(chart_date_count=Count("id"))
        )

        # Will do the zero filling for us on key access
        date_values: defaultdict = defaultdict(int)

        for row in counts_by_date:
            # Accumulate rather than overwrite in case two rows normalize to
            # the same calendar date.
            date_values[self._bucket_key(row["chart_date"])] += row[
                "chart_date_count"
            ]

        return {
            bucket.strftime("%Y-%m-%d"): date_values[self._bucket_key(bucket)]
            for bucket in iterator()
        }

    @staticmethod
    def _bucket_key(value):
        """
        Normalize a bucket value to a plain ``date`` for dictionary lookups.

        A ``datetime`` never compares equal to a ``date``, so mixing the two
        (e.g. a truncation that yields datetimes vs. an iterator that yields
        dates) would make every lookup miss and chart all zeros.
        """
        # datetime is a subclass of date, so it must be checked explicitly.
        if isinstance(value, datetime.datetime):
            return value.date()
        return value

    @staticmethod
    def _calculate_trend_line(data: list[int | float]) -> list[int | float]:
        """
        Fit a least-squares line through ``data`` plotted at x = 0..n-1.

        :param data: Numeric y-values in display order.
        :return: The fitted value at each x; ``data`` itself when it has fewer
            than two points; or an empty list when every fitted value is 0.
        """
        n = len(data)
        if n < 2:
            # Not enough points to fit a line; return the data as-is.
            return data

        xs = range(n)
        x_mean = sum(xs) / n
        y_mean = sum(data) / n

        # Slope (m) and intercept (b) of y = mx + b.
        numerator = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, data))
        denominator = sum((x - x_mean) ** 2 for x in xs)
        # Defensive: with n >= 2 distinct x values the denominator is positive.
        slope = numerator / denominator if denominator != 0 else 0
        intercept = y_mean - slope * x_mean

        trend = [slope * x + intercept for x in xs]

        # A flat zero line carries no information, so draw nothing.
        if all(v == 0 for v in trend):
            return []

        return trend

    def get_chart_data(self) -> dict:
        """Build a bar chart of the trend data overlaid with a linear trend line."""
        data = self.get_trend_data()
        trend_labels = list(data.keys())
        trend_data = list(data.values())

        return {
            "type": "bar",
            "data": {
                "labels": trend_labels,
                "datasets": [
                    {
                        "data": trend_data,
                    },
                    {
                        "data": self._calculate_trend_line(trend_data),
                        "type": "line",
                        "borderColor": "rgba(255, 255, 255, 0.3)",
                        "borderWidth": 2,
                        "fill": False,
                        "pointRadius": 0,  # Hide the individual line points
                    },
                ],
            },
            # Hide the legend and the x-axis; suggest a y-axis minimum of 0.
            "options": {
                "plugins": {"legend": {"display": False}},
                "scales": {
                    "x": {
                        "display": False,
                    },
                    "y": {
                        "suggestedMin": 0,
                    },
                },
                "maintainAspectRatio": False,
                "elements": {
                    "bar": {"borderRadius": "3", "backgroundColor": "#d6d6d6"}
                },
            },
        }