from __future__ import annotations

import html
from typing import Any

import jinja2
from jinja2 import nodes
from jinja2.ext import Extension
from jinja2.nodes import CallBlock, Node
from jinja2.parser import Parser
from jinja2.runtime import Context

from plain.runtime import settings
from plain.templates import register_template_extension
from plain.templates.jinja.extensions import InclusionTagExtension
 15
 16
 17@register_template_extension
 18class HTMXJSExtension(InclusionTagExtension):
 19    tags = {"htmx_js"}
 20    template_name = "htmx/js.html"
 21
 22    def get_context(
 23        self, context: Context, *args: Any, **kwargs: Any
 24    ) -> dict[str, Any]:
 25        request = context.get("request")
 26        return {
 27            "DEBUG": settings.DEBUG,
 28            "extensions": kwargs.get("extensions", []),
 29            "csp_nonce": request.csp_nonce if request else None,
 30        }
 31
 32
 33class _FragmentFound(Exception):
 34    """Raised to short-circuit template rendering once the target fragment is found."""
 35
 36    def __init__(self, content: str) -> None:
 37        self.content = content
 38
 39
 40@register_template_extension
 41class HTMXFragmentExtension(Extension):
 42    tags = {"htmxfragment"}
 43
 44    def parse(self, parser: Parser) -> Node:
 45        lineno = next(parser.stream).lineno
 46
 47        fragment_name = parser.parse_expression()
 48
 49        kwargs = []
 50
 51        while parser.stream.current.type != "block_end":
 52            if parser.stream.current.type == "name":
 53                key = parser.stream.current.value
 54                parser.stream.skip()
 55                parser.stream.expect("assign")
 56                value = parser.parse_expression()
 57                kwargs.append(nodes.Keyword(key, value))
 58
 59        body = parser.parse_statements(("name:endhtmxfragment",), drop_needle=True)
 60
 61        call = self.call_method(
 62            "_render_htmx_fragment",
 63            args=[fragment_name, nodes.ContextReference()],
 64            kwargs=kwargs,
 65        )
 66
 67        callblock = CallBlock(call, [], [], body)
 68        callblock.set_lineno(lineno)
 69
 70        return callblock
 71
 72    def _render_htmx_fragment(
 73        self, fragment_name: str, context: dict[str, Any], caller: Any, **kwargs: Any
 74    ) -> str:
 75        # Two-phase fragment targeting (see render_template_fragment):
 76        # Phase 1 skips non-target bodies, phase 2 renders them for nesting.
 77        # Once the target is found, "found" is set so child fragments render
 78        # normally with their wrapper divs.
 79        target_state = context.get("_htmx_target_fragment")
 80        if target_state is not None and not target_state["found"]:
 81            if str(fragment_name) == target_state["name"]:
 82                target_state["found"] = True
 83                content = caller()
 84                raise _FragmentFound(content)
 85            elif target_state["render_bodies"]:
 86                return caller()
 87            else:
 88                return ""
 89
 90        def attrs_to_str(attrs: dict[str, Any]) -> str:
 91            parts = []
 92            for k, v in attrs.items():
 93                if v == "":
 94                    parts.append(k)
 95                else:
 96                    parts.append(f'{k}="{v}"')
 97            return " ".join(parts)
 98
 99        render_lazy = kwargs.get("lazy", False)
100        as_element = kwargs.get("as", "div")
101        attrs = {}
102        for k, v in kwargs.items():
103            if k in ("lazy", "as"):
104                continue
105            if k.startswith("hx_"):
106                attrs[k.replace("_", "-")] = v
107            else:
108                attrs[k] = v
109
110        if render_lazy:
111            attrs.setdefault("hx-trigger", "load from:body")
112            attrs.setdefault("hx-swap", "outerHTML")
113            attrs.setdefault("hx-target", "this")
114            attrs.setdefault("hx-indicator", "this")
115            attrs_str = attrs_to_str(attrs)
116            return f'<{as_element} plain-hx-fragment="{fragment_name}" hx-get {attrs_str}></{as_element}>'
117        else:
118            # Swap innerHTML so we can re-run hx calls inside the fragment automatically
119            attrs.setdefault("hx-swap", "innerHTML")
120            attrs.setdefault("hx-target", "this")
121            attrs.setdefault("hx-indicator", "this")
122            # Add an id that you can use to target the fragment from outside the fragment
123            attrs.setdefault("id", f"plain-hx-fragment-{fragment_name}")
124            attrs_str = attrs_to_str(attrs)
125            return f'<{as_element} plain-hx-fragment="{fragment_name}" {attrs_str}>{caller()}</{as_element}>'
126
127
def render_template_fragment(
    *, template: jinja2.Template, fragment_name: str, context: dict[str, Any]
) -> str:
    """Return the rendered body of the fragment named *fragment_name*.

    The template is rendered at most twice, each time with a shared
    ``_htmx_target_fragment`` state dict injected into the context:

    * Pass one leaves ``render_bodies`` off. Non-target fragment bodies are
      skipped, which is cheap and finds top-level and loop fragments.
    * Pass two turns ``render_bodies`` on. Non-target bodies are rendered so
      that fragments nested inside other fragments can be reached.

    The fragment tag raises ``_FragmentFound`` as soon as the target renders.
    That exception aborts the pass and supplies the result.

    Raises:
        jinja2.TemplateNotFound: If neither pass encounters the fragment.
    """
    for renders_bodies in (False, True):
        state = {
            "name": fragment_name,
            "found": False,
            "render_bodies": renders_bodies,
        }
        pass_context = {**context, "_htmx_target_fragment": state}
        try:
            template.render(pass_context)
        except _FragmentFound as hit:
            return hit.content

    message = f"Fragment '{fragment_name}' not found in template {template.name}"
    raise jinja2.TemplateNotFound(message)