1from abc import ABC, abstractmethod
  2from collections import defaultdict
  3from typing import Any
  4
  5from plain.admin.dates import DatetimeRangeAliases
  6from plain.postgres.aggregates import Count
  7from plain.postgres.functions import (
  8    TruncDate,
  9    TruncMonth,
 10)
 11
 12from .base import Card
 13
 14
 15class ChartCard(Card, ABC):
 16    template_name = "admin/cards/chart.html"
 17
 18    def get_template_context(self) -> dict[str, Any]:
 19        context = super().get_template_context()
 20        context["chart_data"] = self.get_chart_data()
 21        return context
 22
 23    @abstractmethod
 24    def get_chart_data(self) -> dict: ...
 25
 26
 27class TrendCard(ChartCard):
 28    """
 29    A card that renders a trend chart.
 30    Primarily intended for use with models, but it can also be customized.
 31    """
 32
 33    model = None
 34    datetime_field = None
 35    group_field: str | None = None
 36    group_labels: dict[str, str] | None = None
 37    default_group_colors: list[dict[str, str]] = [
 38        {"bg": "rgba(85, 107, 68, 0.8)", "hover": "rgba(85, 107, 68, 1)"},  # sage
 39        {
 40            "bg": "rgba(74, 111, 165, 0.8)",
 41            "hover": "rgba(74, 111, 165, 1)",
 42        },  # slate blue
 43        {
 44            "bg": "rgba(176, 110, 70, 0.8)",
 45            "hover": "rgba(176, 110, 70, 1)",
 46        },  # terracotta
 47        {
 48            "bg": "rgba(82, 126, 126, 0.8)",
 49            "hover": "rgba(82, 126, 126, 1)",
 50        },  # dusty teal
 51        {
 52            "bg": "rgba(140, 100, 75, 0.8)",
 53            "hover": "rgba(140, 100, 75, 1)",
 54        },  # warm brown
 55        {
 56            "bg": "rgba(130, 100, 140, 0.8)",
 57            "hover": "rgba(130, 100, 140, 1)",
 58        },  # muted plum
 59    ]
 60    group_colors: dict[str, dict[str, str]] | None = None
 61    default_filter = DatetimeRangeAliases.SINCE_30_DAYS_AGO
 62
 63    filters = DatetimeRangeAliases
 64
 65    def get_description(self) -> str:
 66        datetime_range = DatetimeRangeAliases.to_range(self.get_current_filter())
 67        start = datetime_range.start.strftime("%b %d, %Y")
 68        end = datetime_range.end.strftime("%b %d, %Y")
 69        return f"{start} to {end}"
 70
 71    def get_current_filter(self) -> str:
 72        if s := super().get_current_filter():
 73            return s
 74        return self.default_filter.value
 75
 76    def get_trend_data(self) -> dict[str, int] | dict[str, dict[str, int]]:
 77        """Return trend data, optionally grouped by group_field.
 78
 79        Without group_field: {date_str: count}
 80        With group_field: {group_label: {date_str: count}}
 81        """
 82        if not self.model or not self.datetime_field:
 83            raise NotImplementedError(
 84                "model and datetime_field must be set, or get_trend_data must be overridden"
 85            )
 86
 87        datetime_range = DatetimeRangeAliases.to_range(self.get_current_filter())
 88        filter_kwargs = {f"{self.datetime_field}__range": datetime_range.as_tuple()}
 89
 90        if datetime_range.total_days() < 300:
 91            truncator = TruncDate
 92            iterator = datetime_range.iter_days
 93        else:
 94            truncator = TruncMonth
 95            iterator = datetime_range.iter_months
 96
 97        value_fields = ["chart_date"]
 98        if self.group_field:
 99            value_fields.append(self.group_field)
100
101        rows = (
102            self.model.query.filter(**filter_kwargs)
103            .annotate(chart_date=truncator(self.datetime_field))
104            .values(*value_fields)
105            .annotate(chart_date_count=Count("id"))
106        )
107
108        dates = list(iterator())
109
110        if not self.group_field:
111            date_values: defaultdict[Any, int] = defaultdict(int)
112            for row in rows:
113                date_values[row["chart_date"]] = row["chart_date_count"]
114            return {date.strftime("%Y-%m-%d"): date_values[date] for date in dates}
115
116        groups: dict[str, defaultdict[Any, int]] = defaultdict(lambda: defaultdict(int))
117        for row in rows:
118            raw_value = row[self.group_field] or "Unknown"
119            groups[raw_value][row["chart_date"]] = row["chart_date_count"]
120
121        return {
122            group: {date.strftime("%Y-%m-%d"): counts[date] for date in dates}
123            for group, counts in sorted(groups.items())
124        }
125
126    def get_chart_data(self) -> dict:
127        data = self.get_trend_data()
128
129        if self.group_field:
130            return self._build_grouped_chart(data)
131
132        return self._build_single_chart(data)
133
134    def _build_single_chart(self, data: dict) -> dict:
135        return {
136            "type": "bar",
137            "data": {
138                "labels": list(data.keys()),
139                "datasets": [
140                    {
141                        "data": list(data.values()),
142                        "backgroundColor": "rgba(168, 162, 158, 0.7)",  # stone-400
143                        "hoverBackgroundColor": "rgba(120, 113, 108, 0.9)",  # stone-500
144                        "borderRadius": {"topLeft": 2, "topRight": 2},
145                        "borderSkipped": False,
146                    },
147                ],
148            },
149            **self._chart_options(show_legend=False, stacked=False),
150        }
151
152    def _build_grouped_chart(self, data: dict) -> dict:
153        if not data:
154            return self._build_single_chart({})
155
156        labels = list(next(iter(data.values())).keys())
157
158        group_labels = self.group_labels or {}
159
160        datasets = []
161        for i, (raw_name, date_counts) in enumerate(data.items()):
162            display_name = group_labels.get(raw_name, raw_name)
163            if self.group_colors and raw_name in self.group_colors:
164                colors = self.group_colors[raw_name]
165            else:
166                colors = self.default_group_colors[i % len(self.default_group_colors)]
167            datasets.append(
168                {
169                    "label": str(display_name),
170                    "data": list(date_counts.values()),
171                    "backgroundColor": colors["bg"],
172                    "hoverBackgroundColor": colors["hover"],
173                }
174            )
175
176        return {
177            "type": "bar",
178            "data": {
179                "labels": labels,
180                "datasets": datasets,
181            },
182            **self._chart_options(show_legend=True, stacked=True),
183        }
184
185    def _chart_options(self, *, show_legend: bool, stacked: bool) -> dict:
186        legend = (
187            {
188                "display": True,
189                "position": "top",
190                "align": "end",
191                "labels": {
192                    "boxWidth": 12,
193                    "boxHeight": 12,
194                    "borderRadius": 2,
195                    "useBorderRadius": True,
196                    "padding": 16,
197                    "font": {"size": 11},
198                },
199            }
200            if show_legend
201            else {"display": False}
202        )
203
204        return {
205            "options": {
206                "responsive": True,
207                "maintainAspectRatio": False,
208                "animation": {
209                    "duration": 600,
210                    "easing": "easeOutQuart",
211                },
212                "plugins": {
213                    "legend": legend,
214                    "tooltip": {
215                        "enabled": True,
216                        "backgroundColor": "rgba(41, 37, 36, 0.95)",  # stone-800
217                        "titleColor": "rgba(255, 255, 255, 0.7)",
218                        "bodyColor": "#ffffff",
219                        "bodyFont": {"size": 13, "weight": "600"},
220                        "titleFont": {"size": 11},
221                        "padding": {"x": 12, "y": 8},
222                        "cornerRadius": 6,
223                        "displayColors": show_legend,
224                    },
225                },
226                "scales": {
227                    "x": {
228                        "display": False,
229                        "grid": {"display": False},
230                        "stacked": stacked,
231                    },
232                    "y": {
233                        "beginAtZero": True,
234                        "display": True,
235                        "position": "right",
236                        "grid": {
237                            "display": True,
238                            "color": "rgba(0, 0, 0, 0.04)",
239                            "drawTicks": False,
240                        },
241                        "border": {"display": False},
242                        "ticks": {
243                            "display": False,
244                            "maxTicksLimit": 4,
245                        },
246                        "stacked": stacked,
247                    },
248                },
249                "layout": {
250                    "padding": {"top": 4, "bottom": 0, "left": 0, "right": 0},
251                },
252            },
253        }