v0.148.0
  1from abc import ABC, abstractmethod
  2from collections import defaultdict
  3from typing import Any, Literal
  4
  5from plain.admin.dates import DatetimeRangeAliases
  6from plain.postgres.aggregates import Count
  7from plain.postgres.functions import (
  8    TruncDate,
  9    TruncMonth,
 10)
 11
 12from .base import Card
 13
 14
 15class ChartCard(Card, ABC):
 16    template_name = "admin/cards/chart.html"
 17
 18    def get_template_context(self) -> dict[str, Any]:
 19        context = super().get_template_context()
 20        context["chart_data"] = self.get_chart_data()
 21        return context
 22
 23    @abstractmethod
 24    def get_chart_data(self) -> dict: ...
 25
 26
 27class TrendCard(ChartCard):
 28    """
 29    A card that renders a trend chart.
 30    Primarily intended for use with models, but it can also be customized.
 31    """
 32
 33    model = None
 34    datetime_field = None
 35    group_field: str | None = None
 36    group_labels: dict[str, str] | None = None
 37    # CSS color values resolved by charts.js. `var(--chart-N)` reads the
 38    # admin's chart palette so charts retheme automatically (incl. dark mode).
 39    default_group_colors: list[str] = [
 40        "var(--chart-1)",
 41        "var(--chart-2)",
 42        "var(--chart-3)",
 43        "var(--chart-4)",
 44        "var(--chart-5)",
 45    ]
 46    group_colors: dict[str, str] | None = None
 47    aggregates: tuple[Literal["sum", "avg", "max"], ...] = ("sum",)
 48    default_filter = DatetimeRangeAliases.SINCE_30_DAYS_AGO
 49
 50    filters = DatetimeRangeAliases
 51
 52    def get_current_filter(self) -> str:
 53        if s := super().get_current_filter():
 54            return s
 55        return self.default_filter.value
 56
 57    def get_trend_data(self) -> dict[str, int] | dict[str, dict[str, int]]:
 58        """Return trend data, optionally grouped by group_field.
 59
 60        Without group_field: {date_str: count}
 61        With group_field: {group_label: {date_str: count}}
 62        """
 63        if not self.model or not self.datetime_field:
 64            raise NotImplementedError(
 65                "model and datetime_field must be set, or get_trend_data must be overridden"
 66            )
 67
 68        datetime_range = DatetimeRangeAliases.to_range(self.get_current_filter())
 69        filter_kwargs = {f"{self.datetime_field}__range": datetime_range.as_tuple()}
 70
 71        if datetime_range.total_days() < 300:
 72            truncator = TruncDate
 73            iterator = datetime_range.iter_days
 74        else:
 75            truncator = TruncMonth
 76            iterator = datetime_range.iter_months
 77
 78        value_fields = ["chart_date"]
 79        if self.group_field:
 80            value_fields.append(self.group_field)
 81
 82        rows = (
 83            self.model.query.filter(**filter_kwargs)
 84            .annotate(chart_date=truncator(self.datetime_field))
 85            .values(*value_fields)
 86            .annotate(chart_date_count=Count("id"))
 87        )
 88
 89        dates = list(iterator())
 90
 91        if not self.group_field:
 92            date_values: defaultdict[Any, int] = defaultdict(int)
 93            for row in rows:
 94                date_values[row["chart_date"]] = row["chart_date_count"]
 95            return {date.strftime("%Y-%m-%d"): date_values[date] for date in dates}
 96
 97        groups: dict[str, defaultdict[Any, int]] = defaultdict(lambda: defaultdict(int))
 98        for row in rows:
 99            raw = row[self.group_field]
100            raw_value = "Unknown" if raw is None else str(raw)
101            groups[raw_value][row["chart_date"]] = row["chart_date_count"]
102
103        return {
104            group: {date.strftime("%Y-%m-%d"): counts[date] for date in dates}
105            for group, counts in sorted(groups.items())
106        }
107
108    def get_chart_data(self) -> dict:
109        data = self.get_trend_data()
110
111        if self.group_field:
112            return self._build_grouped_chart(data)
113
114        return self._build_single_chart(data)
115
116    def _build_single_chart(self, data: dict) -> dict:
117        return {
118            "type": "bar",
119            "data": {
120                "labels": list(data.keys()),
121                "datasets": [
122                    {
123                        "label": self.title,
124                        "data": list(data.values()),
125                        "backgroundColor": "var(--chart-1)",
126                        "borderRadius": {"topLeft": 2, "topRight": 2},
127                        "borderSkipped": False,
128                        "categoryPercentage": 0.9,
129                        "barPercentage": 1.0,
130                    },
131                ],
132            },
133            **self._chart_options(stacked=False),
134            "plain": self._plain_meta(),
135        }
136
137    def _build_grouped_chart(self, data: dict) -> dict:
138        if not data:
139            return self._build_single_chart({})
140
141        labels = list(next(iter(data.values())).keys())
142
143        group_labels = self.group_labels or {}
144
145        datasets = []
146        for i, (raw_name, date_counts) in enumerate(data.items()):
147            display_name = group_labels.get(raw_name, raw_name)
148            if self.group_colors and raw_name in self.group_colors:
149                color = self.group_colors[raw_name]
150            else:
151                color = self.default_group_colors[i % len(self.default_group_colors)]
152            datasets.append(
153                {
154                    "label": str(display_name),
155                    "data": list(date_counts.values()),
156                    "backgroundColor": color,
157                    "categoryPercentage": 0.9,
158                    "barPercentage": 1.0,
159                }
160            )
161
162        return {
163            "type": "bar",
164            "data": {
165                "labels": labels,
166                "datasets": datasets,
167            },
168            **self._chart_options(stacked=True),
169            "plain": self._plain_meta(),
170        }
171
172    def _plain_meta(self) -> dict:
173        return {
174            "aggregates": list(self.aggregates),
175        }
176
177    def _chart_options(self, *, stacked: bool) -> dict:
178        return {
179            "options": {
180                "responsive": True,
181                "maintainAspectRatio": False,
182                "animation": {
183                    "duration": 600,
184                    "easing": "easeOutQuart",
185                },
186                "interaction": {
187                    "mode": "index",
188                    "intersect": False,
189                    "axis": "x",
190                },
191                "plugins": {
192                    "legend": {"display": False},
193                    "tooltip": {"enabled": False},
194                },
195                "scales": {
196                    "x": {
197                        "display": False,
198                        "grid": {"display": False},
199                        "stacked": stacked,
200                    },
201                    "y": {
202                        "beginAtZero": True,
203                        "display": False,
204                        "stacked": stacked,
205                    },
206                },
207                "layout": {
208                    "padding": {"top": 4, "bottom": 0, "left": 0, "right": 0},
209                },
210            },
211        }