Plain is headed towards 1.0! Subscribe for development updates →

  1import jinja2
  2from jinja2 import meta, nodes
  3from jinja2.ext import Extension
  4
  5from plain.runtime import settings
  6from plain.templates import register_template_extension
  7from plain.templates.jinja.extensions import InclusionTagExtension
  8
  9
 10@register_template_extension
 11class HTMXJSExtension(InclusionTagExtension):
 12    tags = {"htmx_js"}
 13    template_name = "htmx/js.html"
 14
 15    def get_context(self, context, *args, **kwargs):
 16        return {
 17            "csrf_header": settings.CSRF_HEADER_NAME,
 18            "csrf_token": context["csrf_token"],
 19            "DEBUG": settings.DEBUG,
 20            "extensions": kwargs.get("extensions", []),
 21        }
 22
 23
 24@register_template_extension
 25class HTMXFragmentExtension(Extension):
 26    tags = {"htmxfragment"}
 27
 28    def __init__(self, environment):
 29        super().__init__(environment)
 30        environment.extend(htmx_fragment_nodes={})
 31
 32    def parse(self, parser):
 33        lineno = next(parser.stream).lineno
 34
 35        fragment_name = parser.parse_expression()
 36
 37        kwargs = []
 38
 39        while parser.stream.current.type != "block_end":
 40            if parser.stream.current.type == "name":
 41                key = parser.stream.current.value
 42                parser.stream.skip()
 43                parser.stream.expect("assign")
 44                value = parser.parse_expression()
 45                kwargs.append(nodes.Keyword(key, value))
 46
 47        body = parser.parse_statements(["name:endhtmxfragment"], drop_needle=True)
 48
 49        call = self.call_method(
 50            "_render_htmx_fragment",
 51            args=[fragment_name, jinja2.nodes.ContextReference()],
 52            kwargs=kwargs,
 53        )
 54
 55        node = jinja2.nodes.CallBlock(call, [], [], body).set_lineno(lineno)
 56
 57        # Store a reference to the node for later
 58        self.environment.htmx_fragment_nodes.setdefault(parser.name, {})[
 59            fragment_name.value
 60        ] = node
 61
 62        return node
 63
 64    def _render_htmx_fragment(self, fragment_name, context, caller, **kwargs):
 65        def attrs_to_str(attrs):
 66            parts = []
 67            for k, v in attrs.items():
 68                if v == "":
 69                    parts.append(k)
 70                else:
 71                    parts.append(f'{k}="{v}"')
 72            return " ".join(parts)
 73
 74        render_lazy = kwargs.get("lazy", False)
 75        as_element = kwargs.get("as", "div")
 76        attrs = {}
 77        for k, v in kwargs.items():
 78            if k.startswith("hx_"):
 79                attrs[k.replace("_", "-")] = v
 80            else:
 81                attrs[k] = v
 82
 83        if render_lazy:
 84            attrs.setdefault("hx-swap", "outerHTML")
 85            attrs.setdefault("hx-target", "this")
 86            attrs.setdefault("hx-indicator", "this")
 87            attrs_str = attrs_to_str(attrs)
 88            return f'<{as_element} plain-hx-fragment="{fragment_name}" hx-get hx-trigger="plainhtmx:load from:body" {attrs_str}></{as_element}>'
 89        else:
 90            # Swap innerHTML so we can re-run hx calls inside the fragment automatically
 91            # (render_template_fragment won't render this part of the node again, just the inner nodes)
 92            attrs.setdefault("hx-swap", "innerHTML")
 93            attrs.setdefault("hx-target", "this")
 94            attrs.setdefault("hx-indicator", "this")
 95            # Add an id that you can use to target the fragment from outside the fragment
 96            attrs.setdefault("id", f"plain-hx-fragment-{fragment_name}")
 97            attrs_str = attrs_to_str(attrs)
 98            return f'<{as_element} plain-hx-fragment="{fragment_name}" {attrs_str}>{caller()}</{as_element}>'
 99
100
def render_template_fragment(*, template, fragment_name, context):
    """Render just the fragment named ``fragment_name`` with ``context``.

    The fragment is located via ``find_template_fragment`` starting from
    ``template``. Only the fragment's body is rendered, not its wrapper element.
    """
    fragment_template = find_template_fragment(template, fragment_name)
    rendered = fragment_template.render(context)
    return rendered
104
105
def _find_fragment_nodes(environment, fragment_name):
    """Return every recorded node named ``fragment_name``, across all parsed templates."""
    return [
        fragments[fragment_name]
        for fragments in environment.htmx_fragment_nodes.values()
        if fragment_name in fragments
    ]


def find_template_fragment(template: jinja2.Template, fragment_name: str):
    """Return a standalone template that renders only ``fragment_name``'s body.

    The search runs in three stages:

    1. Look in the fragments recorded for ``template`` itself.
    2. Look in every template the environment has parsed so far.
    3. Re-parse ``template``'s source, load the templates it references
       directly (so their fragments get recorded), then search again.

    Raises:
        jinja2.TemplateNotFound: if the fragment is not found, or if it is
            found in more than one template (ambiguous name).
    """
    environment = template.environment

    # Look in this template for the fragment
    callblock_node = environment.htmx_fragment_nodes.get(template.name, {}).get(
        fragment_name
    )

    if callblock_node is None:
        # Look in other templates for this fragment
        matching_callblock_nodes = _find_fragment_nodes(environment, fragment_name)

        if not matching_callblock_nodes:
            # If we still haven't found anything, it's possible that we're
            # in a different/new worker/process and haven't parsed the related
            # templates yet. Parsing under the real name records this
            # template's fragments under its own key rather than under None,
            # where they could later be double-counted as duplicates.
            source, filename, _ = environment.loader.get_source(
                environment, template.name
            )
            ast = environment.parse(source, template.name, filename)
            for ref in meta.find_referenced_templates(ast):
                # None marks a dynamic reference whose name can't be known
                # statically, so there is nothing to load for it.
                if ref is not None and ref not in environment.htmx_fragment_nodes:
                    # Trigger them to parse
                    environment.get_template(ref)

            # Now look again
            matching_callblock_nodes = _find_fragment_nodes(environment, fragment_name)

        if len(matching_callblock_nodes) > 1:
            raise jinja2.TemplateNotFound(
                f"Fragment {fragment_name} found in multiple templates. Use a more specific name."
            )
        if not matching_callblock_nodes:
            raise jinja2.TemplateNotFound(
                f"Fragment {fragment_name} not found in any templates"
            )
        callblock_node = matching_callblock_nodes[0]

    # Create a new template from the node's body only (not its wrapper element)
    template_node = jinja2.nodes.Template(callblock_node.body)
    return environment.from_string(template_node)