Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from typing import Any
  4
  5import jinja2
  6from jinja2 import meta, nodes
  7from jinja2.ext import Extension
  8from jinja2.nodes import CallBlock, Node
  9from jinja2.parser import Parser
 10
 11from plain.runtime import settings
 12from plain.templates import register_template_extension
 13from plain.templates.jinja.extensions import InclusionTagExtension
 14
 15
 16@register_template_extension
 17class HTMXJSExtension(InclusionTagExtension):
 18    tags = {"htmx_js"}
 19    template_name = "htmx/js.html"
 20
 21    def get_context(
 22        self, context: dict[str, Any], *args: Any, **kwargs: Any
 23    ) -> dict[str, Any]:
 24        return {
 25            "DEBUG": settings.DEBUG,
 26            "extensions": kwargs.get("extensions", []),
 27            "csp_nonce": context.get("request").csp_nonce,
 28        }
 29
 30
 31@register_template_extension
 32class HTMXFragmentExtension(Extension):
 33    tags = {"htmxfragment"}
 34
 35    def __init__(self, environment: jinja2.Environment):
 36        super().__init__(environment)
 37        environment.extend(htmx_fragment_nodes={})
 38
 39    def parse(self, parser: Parser) -> Node:
 40        lineno = next(parser.stream).lineno
 41
 42        fragment_name = parser.parse_expression()
 43
 44        kwargs = []
 45
 46        while parser.stream.current.type != "block_end":
 47            if parser.stream.current.type == "name":
 48                key = parser.stream.current.value
 49                parser.stream.skip()
 50                parser.stream.expect("assign")
 51                value = parser.parse_expression()
 52                kwargs.append(nodes.Keyword(key, value))
 53
 54        body = parser.parse_statements(("name:endhtmxfragment",), drop_needle=True)
 55
 56        call = self.call_method(
 57            "_render_htmx_fragment",
 58            args=[fragment_name, nodes.ContextReference()],
 59            kwargs=kwargs,
 60        )
 61
 62        node = CallBlock(call, [], [], body).set_lineno(lineno)
 63
 64        # Store a reference to the node for later
 65        self.environment.htmx_fragment_nodes.setdefault(parser.name, {})[  # type: ignore[attr-defined]
 66            fragment_name.value  # type: ignore[attr-defined]
 67        ] = node
 68
 69        return node
 70
 71    def _render_htmx_fragment(
 72        self, fragment_name: str, context: dict[str, Any], caller: Any, **kwargs: Any
 73    ) -> str:
 74        def attrs_to_str(attrs: dict[str, Any]) -> str:
 75            parts = []
 76            for k, v in attrs.items():
 77                if v == "":
 78                    parts.append(k)
 79                else:
 80                    parts.append(f'{k}="{v}"')
 81            return " ".join(parts)
 82
 83        render_lazy = kwargs.get("lazy", False)
 84        as_element = kwargs.get("as", "div")
 85        attrs = {}
 86        for k, v in kwargs.items():
 87            if k.startswith("hx_"):
 88                attrs[k.replace("_", "-")] = v
 89            else:
 90                attrs[k] = v
 91
 92        if render_lazy:
 93            attrs.setdefault("hx-swap", "outerHTML")
 94            attrs.setdefault("hx-target", "this")
 95            attrs.setdefault("hx-indicator", "this")
 96            attrs_str = attrs_to_str(attrs)
 97            return f'<{as_element} plain-hx-fragment="{fragment_name}" hx-get hx-trigger="load from:body" {attrs_str}></{as_element}>'
 98        else:
 99            # Swap innerHTML so we can re-run hx calls inside the fragment automatically
100            # (render_template_fragment won't render this part of the node again, just the inner nodes)
101            attrs.setdefault("hx-swap", "innerHTML")
102            attrs.setdefault("hx-target", "this")
103            attrs.setdefault("hx-indicator", "this")
104            # Add an id that you can use to target the fragment from outside the fragment
105            attrs.setdefault("id", f"plain-hx-fragment-{fragment_name}")
106            attrs_str = attrs_to_str(attrs)
107            return f'<{as_element} plain-hx-fragment="{fragment_name}" {attrs_str}>{caller()}</{as_element}>'
108
109
def render_template_fragment(
    *, template: jinja2.Template, fragment_name: str, context: dict[str, Any]
) -> str:
    """Render a single named htmx fragment and return the resulting string.

    Args:
        template: The template used as the starting point for the lookup.
        fragment_name: Name of the fragment to render.
        context: Variables passed to the fragment's ``render`` call.

    Returns:
        The rendered output of the template that ``find_template_fragment``
        builds for the fragment.
    """
    fragment_template = find_template_fragment(template, fragment_name)
    rendered = fragment_template.render(context)
    return rendered
115
116
def _matching_fragment_nodes(
    environment: jinja2.Environment, fragment_name: str
) -> list[CallBlock]:
    """Collect the nodes registered as ``fragment_name`` across all templates."""
    return [
        fragments[fragment_name]
        for fragments in environment.htmx_fragment_nodes.values()  # type: ignore[attr-defined]
        if fragment_name in fragments
    ]


def find_template_fragment(
    template: jinja2.Template, fragment_name: str
) -> jinja2.Template:
    """Build a template containing only the body of the named fragment.

    The fragment is looked up first under ``template.name`` in the
    environment's ``htmx_fragment_nodes`` registry, then in every registered
    template. If nothing is registered yet, the template's source is parsed
    and the templates it references directly are loaded before looking again.

    Args:
        template: The template to start the lookup from.
        fragment_name: Name of the fragment to find.

    Returns:
        A new template created from the fragment's body nodes.

    Raises:
        jinja2.TemplateNotFound: If no fragment has that name, or if the name
            is only found outside ``template`` and matches in more than one
            template.
    """
    environment = template.environment

    # Look in this template for the fragment
    callblock_node = environment.htmx_fragment_nodes.get(  # type: ignore[attr-defined]
        template.name, {}
    ).get(fragment_name)

    if callblock_node is None:
        # Look in other templates for this fragment
        candidates = _matching_fragment_nodes(environment, fragment_name)

        # If we still haven't found anything, it's possible that we're
        # in a different/new worker/process and haven't parsed the related templates yet
        if not candidates and environment.loader and template.name:
            source, filename, _ = environment.loader.get_source(
                environment, template.name
            )
            # Parse under the template's own name: the fragment extension
            # registers nodes by parser name, and an unnamed parse would file
            # this template's fragments a second time under ``None``.
            ast = environment.parse(source, template.name, filename)
            for ref in meta.find_referenced_templates(ast):
                # ``None`` marks a dynamic extends/include that cannot be
                # resolved without rendering, so there is nothing to load.
                if ref is None:
                    continue
                if ref not in environment.htmx_fragment_nodes:  # type: ignore[attr-defined]
                    # Trigger them to parse
                    environment.get_template(ref)

            # Now look again
            candidates = _matching_fragment_nodes(environment, fragment_name)

        if len(candidates) > 1:
            raise jinja2.TemplateNotFound(
                f"Fragment {fragment_name} found in multiple templates. Use a more specific name."
            )
        if not candidates:
            raise jinja2.TemplateNotFound(
                f"Fragment {fragment_name} not found in any templates"
            )
        callblock_node = candidates[0]

    # Create a new template from the node
    template_node = nodes.Template(callblock_node.body)  # type: ignore[attr-defined]
    return environment.from_string(template_node)  # type: ignore[arg-type]