Plain is headed towards 1.0! Subscribe for development updates →

  1import jinja2
  2from jinja2 import meta, nodes
  3from jinja2.ext import Extension
  4
  5from plain.runtime import settings
  6from plain.templates import register_template_extension
  7from plain.templates.jinja.extensions import InclusionTagExtension
  8
  9
 10@register_template_extension
 11class HTMXJSExtension(InclusionTagExtension):
 12    tags = {"htmx_js"}
 13    template_name = "htmx/js.html"
 14
 15    def get_context(self, context, *args, **kwargs):
 16        return {
 17            "DEBUG": settings.DEBUG,
 18            "extensions": kwargs.get("extensions", []),
 19        }
 20
 21
 22@register_template_extension
 23class HTMXFragmentExtension(Extension):
 24    tags = {"htmxfragment"}
 25
 26    def __init__(self, environment):
 27        super().__init__(environment)
 28        environment.extend(htmx_fragment_nodes={})
 29
 30    def parse(self, parser):
 31        lineno = next(parser.stream).lineno
 32
 33        fragment_name = parser.parse_expression()
 34
 35        kwargs = []
 36
 37        while parser.stream.current.type != "block_end":
 38            if parser.stream.current.type == "name":
 39                key = parser.stream.current.value
 40                parser.stream.skip()
 41                parser.stream.expect("assign")
 42                value = parser.parse_expression()
 43                kwargs.append(nodes.Keyword(key, value))
 44
 45        body = parser.parse_statements(["name:endhtmxfragment"], drop_needle=True)
 46
 47        call = self.call_method(
 48            "_render_htmx_fragment",
 49            args=[fragment_name, jinja2.nodes.ContextReference()],
 50            kwargs=kwargs,
 51        )
 52
 53        node = jinja2.nodes.CallBlock(call, [], [], body).set_lineno(lineno)
 54
 55        # Store a reference to the node for later
 56        self.environment.htmx_fragment_nodes.setdefault(parser.name, {})[
 57            fragment_name.value
 58        ] = node
 59
 60        return node
 61
 62    def _render_htmx_fragment(self, fragment_name, context, caller, **kwargs):
 63        def attrs_to_str(attrs):
 64            parts = []
 65            for k, v in attrs.items():
 66                if v == "":
 67                    parts.append(k)
 68                else:
 69                    parts.append(f'{k}="{v}"')
 70            return " ".join(parts)
 71
 72        render_lazy = kwargs.get("lazy", False)
 73        as_element = kwargs.get("as", "div")
 74        attrs = {}
 75        for k, v in kwargs.items():
 76            if k.startswith("hx_"):
 77                attrs[k.replace("_", "-")] = v
 78            else:
 79                attrs[k] = v
 80
 81        if render_lazy:
 82            attrs.setdefault("hx-swap", "outerHTML")
 83            attrs.setdefault("hx-target", "this")
 84            attrs.setdefault("hx-indicator", "this")
 85            attrs_str = attrs_to_str(attrs)
 86            return f'<{as_element} plain-hx-fragment="{fragment_name}" hx-get hx-trigger="load from:body" {attrs_str}></{as_element}>'
 87        else:
 88            # Swap innerHTML so we can re-run hx calls inside the fragment automatically
 89            # (render_template_fragment won't render this part of the node again, just the inner nodes)
 90            attrs.setdefault("hx-swap", "innerHTML")
 91            attrs.setdefault("hx-target", "this")
 92            attrs.setdefault("hx-indicator", "this")
 93            # Add an id that you can use to target the fragment from outside the fragment
 94            attrs.setdefault("id", f"plain-hx-fragment-{fragment_name}")
 95            attrs_str = attrs_to_str(attrs)
 96            return f'<{as_element} plain-hx-fragment="{fragment_name}" {attrs_str}>{caller()}</{as_element}>'
 97
 98
 99def render_template_fragment(*, template, fragment_name, context):
100    template = find_template_fragment(template, fragment_name)
101    return template.render(context)
102
103
def find_template_fragment(template: jinja2.Template, fragment_name: str):
    """Return a new template built from the body of the named htmx fragment.

    Lookup order:
      1. fragments registered for ``template.name``;
      2. fragments registered for any template in the environment;
      3. if nothing matched, the template source is re-parsed and each
         template it references is loaded (which registers its fragments),
         then step 2 is repeated.

    Raises:
        jinja2.TemplateNotFound: if no fragment with that name is found, or
            if the name only matches fragments in more than one other template.
    """
    environment = template.environment
    registry = environment.htmx_fragment_nodes

    def matching_nodes():
        # Every registered node with this fragment name, across all templates.
        return [
            fragments[fragment_name]
            for fragments in registry.values()
            if fragment_name in fragments
        ]

    # Look in this template for the fragment
    callblock_node = registry.get(template.name, {}).get(fragment_name)

    if callblock_node is None:
        # Look in other templates for this fragment
        matches = matching_nodes()

        if not matches:
            # If we still haven't found anything, it's possible that we're
            # in a different/new worker/process and haven't parsed the related templates yet
            source = environment.loader.get_source(environment, template.name)[0]
            ast = environment.parse(source)
            for ref in meta.find_referenced_templates(ast):
                # ``None`` is yielded for dynamic includes/extends whose
                # target is unknown until render time; nothing to load.
                if ref is None or ref in registry:
                    continue
                try:
                    # Trigger them to parse
                    environment.get_template(ref)
                except jinja2.TemplateNotFound:
                    # A reference that cannot be loaded (e.g. an include
                    # marked "ignore missing") cannot hold the fragment.
                    continue

            # Now look again
            matches = matching_nodes()

        if len(matches) > 1:
            raise jinja2.TemplateNotFound(
                f"Fragment {fragment_name} found in multiple templates. Use a more specific name."
            )
        if not matches:
            raise jinja2.TemplateNotFound(
                f"Fragment {fragment_name} not found in any templates"
            )
        callblock_node = matches[0]

    # Create a new template from the node
    template_node = jinja2.nodes.Template(callblock_node.body)
    return environment.from_string(template_node)