1from __future__ import annotations
 2
 3from collections.abc import Callable
 4from typing import Any
 5
 6from plain.http import ResponseBase
 7from plain.utils.cache import patch_vary_headers
 8from plain.views import TemplateView
 9
10from .templates import render_template_fragment
11
12__all__ = ["HTMXView"]
13
14
class HTMXView(TemplateView):
    """View with HTMX-specific functionality.

    - For an HTMX request (``HX-Request: true``) that names a fragment via the
      ``Plain-HX-Fragment`` header, only that fragment of the template is
      rendered instead of the whole template.
    - HTMX requests are dispatched to ``htmx_{method}`` handlers (for example
      ``htmx_post``) when defined. When the ``Plain-HX-Action`` header is set,
      they go to ``htmx_{method}_{action}``. Otherwise the regular handler
      lookup of the parent view is used.
    - Every response varies on the HTMX request headers, because those headers
      change what is rendered for the same URL.
    """

    def render_template(self) -> str:
        """Render the full template, or only the requested fragment for HTMX.

        Returns:
            The rendered HTML string.
        """
        template = self.get_template()
        context = self.get_template_context()

        # Read the fragment name once (and only for HTMX requests) so the
        # value that is checked is exactly the value that is rendered.
        fragment_name = self.get_htmx_fragment_name() if self.is_htmx_request() else ""

        if fragment_name:
            return render_template_fragment(
                template=template._jinja_template,
                fragment_name=fragment_name,
                context=context,
            )

        return template.render(context)

    def get_response(self) -> ResponseBase:
        """Build the response and mark it as varying on the HTMX headers."""
        response = super().get_response()
        # The same URL can yield a full page, a single fragment, or an action
        # result depending on these request headers, so HTTP caches must key
        # on them too, not just on the URL/cookies.
        patch_vary_headers(
            response, ["HX-Request", "Plain-HX-Fragment", "Plain-HX-Action"]
        )
        return response

    def get_request_handler(self) -> Callable[[], Any] | None:
        """Return the handler for this request, preferring HTMX-specific ones.

        For HTMX requests with a method:

        - If the ``Plain-HX-Action`` header is set, the view must define
          ``htmx_{method}_{action}``; there is no fallback.
        - Otherwise ``htmx_{method}`` is used when the view defines it.

        In every other case the parent class's handler lookup is used.

        Raises:
            AttributeError: An action was requested but the view does not
                define the matching ``htmx_{method}_{action}`` handler.
        """
        if self.is_htmx_request() and self.request.method:
            # e.g. "htmx_post" for a POST request.
            method = f"htmx_{self.request.method.lower()}"

            if action := self.get_htmx_action_name():
                handler_name = f"{method}_{action}"
                # An explicitly requested action must have a handler. The
                # action name comes from a request header, so report both
                # the expected handler and the header value to make a
                # template/view mismatch easy to track down.
                action_handler = getattr(self, handler_name, None)
                if action_handler is None:
                    raise AttributeError(
                        f"{type(self).__name__} has no {handler_name!r} handler "
                        f"for Plain-HX-Action {action!r}"
                    )
                return action_handler

            if handler := getattr(self, method, None):
                # A plain HTMX request (no action) may use a custom
                # htmx_{method} handler, or fall back to the regular
                # handler below if none is defined.
                return handler

        return super().get_request_handler()

    def is_htmx_request(self) -> bool:
        """Return True when the ``HX-Request`` header is exactly ``"true"``."""
        return self.request.headers.get("HX-Request") == "true"

    def get_htmx_fragment_name(self) -> str:
        """Return the requested fragment name, or ``""`` when none was sent."""
        # A custom header that we pass with the {% htmxfragment %} tag
        return self.request.headers.get("Plain-HX-Fragment", "")

    def get_htmx_action_name(self) -> str:
        """Return the requested action name, or ``""`` when none was sent."""
        return self.request.headers.get("Plain-HX-Action", "")