Plain is headed towards 1.0! Subscribe for development updates →

  1from collections import defaultdict
  2
  3from plain.admin.dates import DatetimeRangeAliases
  4from plain.models import Count
  5from plain.models.functions import (
  6    TruncDate,
  7    TruncMonth,
  8)
  9
 10from .base import Card
 11
 12
 13class ChartCard(Card):
 14    template_name = "admin/cards/chart.html"
 15
 16    def get_template_context(self):
 17        context = super().get_template_context()
 18        context["chart_data"] = self.get_chart_data()
 19        return context
 20
 21    def get_chart_data(self) -> dict:
 22        raise NotImplementedError
 23
 24
 25class TrendCard(ChartCard):
 26    """
 27    A card that renders a trend chart.
 28    Primarily intended for use with models, but it can also be customized.
 29    """
 30
 31    model = None
 32    datetime_field = None
 33    default_display = DatetimeRangeAliases.SINCE_30_DAYS_AGO
 34
 35    displays = DatetimeRangeAliases
 36
 37    def get_description(self):
 38        datetime_range = DatetimeRangeAliases.to_range(self.get_current_display())
 39        return f"{datetime_range.start} to {datetime_range.end}"
 40
 41    def get_current_display(self):
 42        if s := super().get_current_display():
 43            return DatetimeRangeAliases.from_value(s)
 44        return self.default_display
 45
 46    def get_trend_data(self) -> list[int | float]:
 47        if not self.model or not self.datetime_field:
 48            raise NotImplementedError(
 49                "model and datetime_field must be set, or get_values must be overridden"
 50            )
 51
 52        datetime_range = DatetimeRangeAliases.to_range(self.get_current_display())
 53
 54        filter_kwargs = {f"{self.datetime_field}__range": datetime_range.as_tuple()}
 55
 56        if datetime_range.total_days() < 300:
 57            truncator = TruncDate
 58            iterator = datetime_range.iter_days
 59        else:
 60            truncator = TruncMonth
 61            iterator = datetime_range.iter_months
 62
 63        counts_by_date = (
 64            self.model.objects.filter(**filter_kwargs)
 65            .annotate(chart_date=truncator(self.datetime_field))
 66            .values("chart_date")
 67            .annotate(chart_date_count=Count("id"))
 68        )
 69
 70        # Will do the zero filling for us on key access
 71        date_values = defaultdict(int)
 72
 73        for row in counts_by_date:
 74            date_values[row["chart_date"]] = row["chart_date_count"]
 75
 76        return {date.strftime("%Y-%m-%d"): date_values[date] for date in iterator()}
 77
 78    def get_chart_data(self) -> dict:
 79        data = self.get_trend_data()
 80        trend_labels = list(data.keys())
 81        trend_data = list(data.values())
 82
 83        def calculate_trend_line(data):
 84            """
 85            Calculate a trend line using basic linear regression.
 86            :param data: A list of numeric values representing the y-axis.
 87            :return: A list of trend line values (same length as data).
 88            """
 89            if not data or len(data) < 2:
 90                return (
 91                    data  # Return the data as-is if not enough points for a trend line
 92                )
 93
 94            n = len(data)
 95            x = list(range(n))
 96            y = data
 97
 98            # Calculate the means of x and y
 99            x_mean = sum(x) / n
100            y_mean = sum(y) / n
101
102            # Calculate the slope (m) and y-intercept (b) of the line: y = mx + b
103            numerator = sum((x[i] - x_mean) * (y[i] - y_mean) for i in range(n))
104            denominator = sum((x[i] - x_mean) ** 2 for i in range(n))
105            slope = numerator / denominator if denominator != 0 else 0
106            intercept = y_mean - slope * x_mean
107
108            # Calculate the trend line values
109            trend = [slope * xi + intercept for xi in x]
110
111            # if it's all zeros, return nothing
112            if all(v == 0 for v in trend):
113                return []
114
115            return trend
116
117        return {
118            "type": "bar",
119            "data": {
120                "labels": trend_labels,
121                "datasets": [
122                    {
123                        "data": trend_data,
124                    },
125                    {
126                        "data": calculate_trend_line(trend_data),
127                        "type": "line",
128                        "borderColor": "rgba(0, 0, 0, 0.3)",
129                        "borderWidth": 2,
130                        "fill": False,
131                        "pointRadius": 0,  # Optional: Hide points
132                    },
133                ],
134            },
135            # Hide the label
136            # "options": {"legend": {"display": False}},
137            # Hide the scales
138            "options": {
139                "plugins": {"legend": {"display": False}},
140                "scales": {
141                    "x": {
142                        "display": False,
143                    },
144                    "y": {
145                        "suggestedMin": 0,
146                    },
147                },
148                "maintainAspectRatio": False,
149                "elements": {
150                    "bar": {"borderRadius": "3", "backgroundColor": "rgb(28, 25, 23)"}
151                },
152            },
153        }