Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from typing import Any
  4
  5import jinja2
  6from jinja2 import meta, nodes
  7from jinja2.ext import Extension
  8from jinja2.nodes import CallBlock, Node
  9from jinja2.parser import Parser
 10
 11from plain.runtime import settings
 12from plain.templates import register_template_extension
 13from plain.templates.jinja.extensions import InclusionTagExtension
 14
 15
 16@register_template_extension
 17class HTMXJSExtension(InclusionTagExtension):
 18    tags = {"htmx_js"}
 19    template_name = "htmx/js.html"
 20
 21    def get_context(
 22        self, context: dict[str, Any], *args: Any, **kwargs: Any
 23    ) -> dict[str, Any]:
 24        request = context.get("request")
 25        return {
 26            "DEBUG": settings.DEBUG,
 27            "extensions": kwargs.get("extensions", []),
 28            "csp_nonce": request.csp_nonce if request else None,  # type: ignore[attr-defined]
 29        }
 30
 31
 32@register_template_extension
 33class HTMXFragmentExtension(Extension):
 34    tags = {"htmxfragment"}
 35
 36    def __init__(self, environment: jinja2.Environment):
 37        super().__init__(environment)
 38        environment.extend(htmx_fragment_nodes={})
 39
 40    def parse(self, parser: Parser) -> Node:
 41        lineno = next(parser.stream).lineno
 42
 43        fragment_name = parser.parse_expression()
 44
 45        kwargs = []
 46
 47        while parser.stream.current.type != "block_end":
 48            if parser.stream.current.type == "name":
 49                key = parser.stream.current.value
 50                parser.stream.skip()
 51                parser.stream.expect("assign")
 52                value = parser.parse_expression()
 53                kwargs.append(nodes.Keyword(key, value))
 54
 55        body = parser.parse_statements(("name:endhtmxfragment",), drop_needle=True)
 56
 57        call = self.call_method(
 58            "_render_htmx_fragment",
 59            args=[fragment_name, nodes.ContextReference()],
 60            kwargs=kwargs,
 61        )
 62
 63        node = CallBlock(call, [], [], body).set_lineno(lineno)
 64
 65        # Store a reference to the node for later
 66        self.environment.htmx_fragment_nodes.setdefault(parser.name, {})[  # type: ignore[attr-defined]
 67            fragment_name.value  # type: ignore[attr-defined]
 68        ] = node
 69
 70        return node
 71
 72    def _render_htmx_fragment(
 73        self, fragment_name: str, context: dict[str, Any], caller: Any, **kwargs: Any
 74    ) -> str:
 75        def attrs_to_str(attrs: dict[str, Any]) -> str:
 76            parts = []
 77            for k, v in attrs.items():
 78                if v == "":
 79                    parts.append(k)
 80                else:
 81                    parts.append(f'{k}="{v}"')
 82            return " ".join(parts)
 83
 84        render_lazy = kwargs.get("lazy", False)
 85        as_element = kwargs.get("as", "div")
 86        attrs = {}
 87        for k, v in kwargs.items():
 88            if k.startswith("hx_"):
 89                attrs[k.replace("_", "-")] = v
 90            else:
 91                attrs[k] = v
 92
 93        if render_lazy:
 94            attrs.setdefault("hx-swap", "outerHTML")
 95            attrs.setdefault("hx-target", "this")
 96            attrs.setdefault("hx-indicator", "this")
 97            attrs_str = attrs_to_str(attrs)
 98            return f'<{as_element} plain-hx-fragment="{fragment_name}" hx-get hx-trigger="load from:body" {attrs_str}></{as_element}>'
 99        else:
100            # Swap innerHTML so we can re-run hx calls inside the fragment automatically
101            # (render_template_fragment won't render this part of the node again, just the inner nodes)
102            attrs.setdefault("hx-swap", "innerHTML")
103            attrs.setdefault("hx-target", "this")
104            attrs.setdefault("hx-indicator", "this")
105            # Add an id that you can use to target the fragment from outside the fragment
106            attrs.setdefault("id", f"plain-hx-fragment-{fragment_name}")
107            attrs_str = attrs_to_str(attrs)
108            return f'<{as_element} plain-hx-fragment="{fragment_name}" {attrs_str}>{caller()}</{as_element}>'
109
110
def render_template_fragment(
    *, template: jinja2.Template, fragment_name: str, context: dict[str, Any]
) -> str:
    """Render only the body of the fragment named ``fragment_name``.

    The fragment is resolved with ``find_template_fragment``, which raises
    ``jinja2.TemplateNotFound`` when it can't be found or is ambiguous.
    The resulting template is rendered with ``context``.
    """
    return find_template_fragment(template, fragment_name).render(context)
116
117
def find_template_fragment(
    template: jinja2.Template, fragment_name: str
) -> jinja2.Template:
    """Return a standalone template that renders only a fragment's body.

    Lookup order:

    1. Fragments registered for ``template`` itself.
    2. Fragments registered for any template the environment has parsed.
    3. If still missing (e.g. a new worker/process whose registry is not yet
       populated), re-parse ``template``'s source, load the templates it
       statically references so their fragments get registered, then search
       again.

    Raises:
        jinja2.TemplateNotFound: if no template defines the fragment, or if
            more than one does (the name is ambiguous).
    """
    environment = template.environment
    registry = environment.htmx_fragment_nodes  # type: ignore[attr-defined]

    def matching_nodes() -> list[Any]:
        # Every registered fragment node with this name, across all templates.
        return [
            fragments[fragment_name]
            for fragments in registry.values()
            if fragment_name in fragments
        ]

    # Look in this template for the fragment
    callblock_node = registry.get(template.name, {}).get(fragment_name)

    if callblock_node is None:
        # Look in other templates for this fragment
        matches = matching_nodes()

        if not matches and environment.loader and template.name:
            # It's possible that we're in a different/new worker/process and
            # haven't parsed the related templates yet. Re-parsing the source
            # re-runs the extension's parse() (registering this template's
            # fragments), and loading referenced templates registers theirs.
            source = environment.loader.get_source(environment, template.name)[0]
            ast = environment.parse(source)
            for ref in meta.find_referenced_templates(ast):
                # ``None`` marks a dynamic include/extends whose name can't be
                # known statically; there is nothing to load for it.
                if ref is not None and ref not in registry:
                    # Trigger them to parse
                    environment.get_template(ref)

            # Now look again
            matches = matching_nodes()

        if len(matches) == 1:
            callblock_node = matches[0]
        elif len(matches) > 1:
            raise jinja2.TemplateNotFound(
                f"Fragment {fragment_name} found in multiple templates. Use a more specific name."
            )
        else:
            raise jinja2.TemplateNotFound(
                f"Fragment {fragment_name} not found in any templates"
            )

    # Create a new template from the node's body (the wrapper element itself
    # is not re-rendered, only its inner nodes)
    template_node = nodes.Template(callblock_node.body)  # type: ignore[attr-defined]
    return environment.from_string(template_node)  # type: ignore[arg-type]