Plain is headed towards 1.0! Subscribe for development updates →

from __future__ import annotations

import html
from typing import Any

import jinja2
from jinja2 import meta, nodes
from jinja2.ext import Extension
from jinja2.nodes import CallBlock, Node
from jinja2.parser import Parser
from jinja2.runtime import Context

from plain.runtime import settings
from plain.templates import register_template_extension
from plain.templates.jinja.extensions import InclusionTagExtension
 15
 16
 17@register_template_extension
 18class HTMXJSExtension(InclusionTagExtension):
 19    tags = {"htmx_js"}
 20    template_name = "htmx/js.html"
 21
 22    def get_context(
 23        self, context: Context, *args: Any, **kwargs: Any
 24    ) -> dict[str, Any]:
 25        request = context.get("request")
 26        return {
 27            "DEBUG": settings.DEBUG,
 28            "extensions": kwargs.get("extensions", []),
 29            "csp_nonce": request.csp_nonce if request else None,  # type: ignore[attr-defined]
 30        }
 31
 32
 33@register_template_extension
 34class HTMXFragmentExtension(Extension):
 35    tags = {"htmxfragment"}
 36
 37    def __init__(self, environment: jinja2.Environment):
 38        super().__init__(environment)
 39        environment.extend(htmx_fragment_nodes={})
 40
 41    def parse(self, parser: Parser) -> Node:
 42        lineno = next(parser.stream).lineno
 43
 44        fragment_name = parser.parse_expression()
 45
 46        kwargs = []
 47
 48        while parser.stream.current.type != "block_end":
 49            if parser.stream.current.type == "name":
 50                key = parser.stream.current.value
 51                parser.stream.skip()
 52                parser.stream.expect("assign")
 53                value = parser.parse_expression()
 54                kwargs.append(nodes.Keyword(key, value))
 55
 56        body = parser.parse_statements(("name:endhtmxfragment",), drop_needle=True)
 57
 58        call = self.call_method(
 59            "_render_htmx_fragment",
 60            args=[fragment_name, nodes.ContextReference()],
 61            kwargs=kwargs,
 62        )
 63
 64        node = CallBlock(call, [], [], body).set_lineno(lineno)
 65
 66        # Store a reference to the node for later
 67        self.environment.htmx_fragment_nodes.setdefault(parser.name, {})[  # type: ignore[attr-defined]
 68            fragment_name.value  # type: ignore[attr-defined]
 69        ] = node
 70
 71        return node
 72
 73    def _render_htmx_fragment(
 74        self, fragment_name: str, context: dict[str, Any], caller: Any, **kwargs: Any
 75    ) -> str:
 76        def attrs_to_str(attrs: dict[str, Any]) -> str:
 77            parts = []
 78            for k, v in attrs.items():
 79                if v == "":
 80                    parts.append(k)
 81                else:
 82                    parts.append(f'{k}="{v}"')
 83            return " ".join(parts)
 84
 85        render_lazy = kwargs.get("lazy", False)
 86        as_element = kwargs.get("as", "div")
 87        attrs = {}
 88        for k, v in kwargs.items():
 89            if k.startswith("hx_"):
 90                attrs[k.replace("_", "-")] = v
 91            else:
 92                attrs[k] = v
 93
 94        if render_lazy:
 95            attrs.setdefault("hx-swap", "outerHTML")
 96            attrs.setdefault("hx-target", "this")
 97            attrs.setdefault("hx-indicator", "this")
 98            attrs_str = attrs_to_str(attrs)
 99            return f'<{as_element} plain-hx-fragment="{fragment_name}" hx-get hx-trigger="load from:body" {attrs_str}></{as_element}>'
100        else:
101            # Swap innerHTML so we can re-run hx calls inside the fragment automatically
102            # (render_template_fragment won't render this part of the node again, just the inner nodes)
103            attrs.setdefault("hx-swap", "innerHTML")
104            attrs.setdefault("hx-target", "this")
105            attrs.setdefault("hx-indicator", "this")
106            # Add an id that you can use to target the fragment from outside the fragment
107            attrs.setdefault("id", f"plain-hx-fragment-{fragment_name}")
108            attrs_str = attrs_to_str(attrs)
109            return f'<{as_element} plain-hx-fragment="{fragment_name}" {attrs_str}>{caller()}</{as_element}>'
110
111
def render_template_fragment(
    *, template: jinja2.Template, fragment_name: str, context: dict[str, Any]
) -> str:
    """Render only the named fragment of ``template`` with ``context``.

    The fragment is resolved through ``find_template_fragment``. That function
    builds a template from the fragment's body nodes, so the output does not
    include the fragment's wrapper element.
    """
    fragment_template = find_template_fragment(template, fragment_name)
    rendered = fragment_template.render(context)
    return rendered
117
118
def find_template_fragment(
    template: jinja2.Template, fragment_name: str
) -> jinja2.Template:
    """Return a standalone template built from the named fragment's body.

    Lookup order:

    1. Fragments recorded for ``template`` itself.
    2. Fragments recorded for any template parsed so far.
    3. If still not found, re-parse ``template``'s source under its own name,
       load the templates it references statically so their fragments get
       recorded, then search all recorded fragments again.

    Raises:
        jinja2.TemplateNotFound: if the fragment is not found, or if it is
            found in more than one template (the name is ambiguous).
    """
    environment = template.environment
    registry = environment.htmx_fragment_nodes  # type: ignore[attr-defined]

    def collect_matches() -> list[CallBlock]:
        # Every recorded fragment with this name, across all parsed templates.
        return [
            fragments[fragment_name]
            for fragments in registry.values()
            if fragment_name in fragments
        ]

    # Look in this template for the fragment first.
    callblock_node = registry.get(template.name, {}).get(fragment_name)

    if callblock_node is None:
        matches = collect_matches()

        if not matches and environment.loader and template.name:
            # We may be in a different/new worker/process that hasn't parsed
            # the related templates yet. Re-parse this template under its own
            # name so its fragments are recorded under that key, rather than
            # under None where fragments from different templates would mix.
            source, filename, _ = environment.loader.get_source(
                environment, template.name
            )
            ast = environment.parse(source, name=template.name, filename=filename)
            for ref in meta.find_referenced_templates(ast):
                # ref is None for dynamic references (e.g. including a
                # variable); those cannot be resolved here.
                if ref is None or ref in registry:
                    continue
                try:
                    # Loading is meant to trigger a parse, which records the
                    # referenced template's fragments.
                    environment.get_template(ref)
                except jinja2.TemplateNotFound:
                    # e.g. `{% include ... ignore missing %}`; keep searching.
                    continue

            # Now look again.
            matches = collect_matches()

        if len(matches) > 1:
            raise jinja2.TemplateNotFound(
                f"Fragment {fragment_name} found in multiple templates. Use a more specific name."
            )
        if not matches:
            raise jinja2.TemplateNotFound(
                f"Fragment {fragment_name} not found in any templates"
            )
        callblock_node = matches[0]

    # Create a new template from the fragment's body (without its wrapper element).
    template_node = nodes.Template(callblock_node.body)  # type: ignore[attr-defined]
    return environment.from_string(template_node)  # type: ignore[arg-type]