Plain is headed towards 1.0! Subscribe for development updates →

  1from collections import defaultdict
  2from typing import Any
  3
  4from plain.admin.dates import DatetimeRangeAliases
  5from plain.models import Count
  6from plain.models.functions import (
  7    TruncDate,
  8    TruncMonth,
  9)
 10
 11from .base import Card
 12
 13
 14class ChartCard(Card):
 15    template_name = "admin/cards/chart.html"
 16
 17    def get_template_context(self) -> dict[str, Any]:
 18        context = super().get_template_context()
 19        context["chart_data"] = self.get_chart_data()
 20        return context
 21
 22    def get_chart_data(self) -> dict:
 23        raise NotImplementedError
 24
 25
 26class TrendCard(ChartCard):
 27    """
 28    A card that renders a trend chart.
 29    Primarily intended for use with models, but it can also be customized.
 30    """
 31
 32    model = None
 33    datetime_field = None
 34    default_display = DatetimeRangeAliases.SINCE_30_DAYS_AGO
 35
 36    displays = DatetimeRangeAliases
 37
 38    def get_description(self) -> str:
 39        datetime_range = DatetimeRangeAliases.to_range(self.get_current_display())
 40        return f"{datetime_range.start} to {datetime_range.end}"
 41
 42    def get_current_display(self) -> DatetimeRangeAliases:
 43        if s := super().get_current_display():
 44            return DatetimeRangeAliases.from_value(s)
 45        return self.default_display
 46
 47    def get_trend_data(self) -> dict[str, int]:
 48        if not self.model or not self.datetime_field:
 49            raise NotImplementedError(
 50                "model and datetime_field must be set, or get_values must be overridden"
 51            )
 52
 53        datetime_range = DatetimeRangeAliases.to_range(self.get_current_display())
 54
 55        filter_kwargs = {f"{self.datetime_field}__range": datetime_range.as_tuple()}
 56
 57        if datetime_range.total_days() < 300:
 58            truncator = TruncDate
 59            iterator = datetime_range.iter_days
 60        else:
 61            truncator = TruncMonth
 62            iterator = datetime_range.iter_months
 63
 64        counts_by_date = (
 65            self.model.query.filter(**filter_kwargs)
 66            .annotate(chart_date=truncator(self.datetime_field))
 67            .values("chart_date")
 68            .annotate(chart_date_count=Count("id"))
 69        )
 70
 71        # Will do the zero filling for us on key access
 72        date_values = defaultdict(int)
 73
 74        for row in counts_by_date:
 75            date_values[row["chart_date"]] = row["chart_date_count"]
 76
 77        return {date.strftime("%Y-%m-%d"): date_values[date] for date in iterator()}
 78
 79    def get_chart_data(self) -> dict:
 80        data = self.get_trend_data()
 81        trend_labels = list(data.keys())
 82        trend_data = list(data.values())
 83
 84        def calculate_trend_line(data: list[int | float]) -> list[int | float]:
 85            """
 86            Calculate a trend line using basic linear regression.
 87            :param data: A list of numeric values representing the y-axis.
 88            :return: A list of trend line values (same length as data).
 89            """
 90            if not data or len(data) < 2:
 91                return (
 92                    data  # Return the data as-is if not enough points for a trend line
 93                )
 94
 95            n = len(data)
 96            x = list(range(n))
 97            y = data
 98
 99            # Calculate the means of x and y
100            x_mean = sum(x) / n
101            y_mean = sum(y) / n
102
103            # Calculate the slope (m) and y-intercept (b) of the line: y = mx + b
104            numerator = sum((x[i] - x_mean) * (y[i] - y_mean) for i in range(n))
105            denominator = sum((x[i] - x_mean) ** 2 for i in range(n))
106            slope = numerator / denominator if denominator != 0 else 0
107            intercept = y_mean - slope * x_mean
108
109            # Calculate the trend line values
110            trend = [slope * xi + intercept for xi in x]
111
112            # if it's all zeros, return nothing
113            if all(v == 0 for v in trend):
114                return []
115
116            return trend
117
118        return {
119            "type": "bar",
120            "data": {
121                "labels": trend_labels,
122                "datasets": [
123                    {
124                        "data": trend_data,
125                    },
126                    {
127                        "data": calculate_trend_line(trend_data),
128                        "type": "line",
129                        "borderColor": "rgba(255, 255, 255, 0.3)",
130                        "borderWidth": 2,
131                        "fill": False,
132                        "pointRadius": 0,  # Optional: Hide points
133                    },
134                ],
135            },
136            # Hide the scales
137            "options": {
138                "plugins": {"legend": {"display": False}},
139                "scales": {
140                    "x": {
141                        "display": False,
142                    },
143                    "y": {
144                        "suggestedMin": 0,
145                    },
146                },
147                "maintainAspectRatio": False,
148                "elements": {
149                    "bar": {"borderRadius": "3", "backgroundColor": "#d6d6d6"}
150                },
151            },
152        }