from __future__ import annotations

import html
from typing import Any

import jinja2
from jinja2 import nodes
from jinja2.ext import Extension
from jinja2.nodes import CallBlock, Node
from jinja2.parser import Parser
from jinja2.runtime import Context

from plain.runtime import settings
from plain.templates import register_template_extension
from plain.templates.jinja.extensions import InclusionTagExtension
15
16
@register_template_extension
class HTMXJSExtension(InclusionTagExtension):
    """Inclusion tag ``{% htmx_js %}`` that renders ``htmx/js.html``."""

    tags = {"htmx_js"}
    template_name = "htmx/js.html"

    def get_context(
        self, context: Context, *args: Any, **kwargs: Any
    ) -> dict[str, Any]:
        """Build the context handed to ``htmx/js.html``.

        Keys:
            DEBUG: ``settings.DEBUG``.
            extensions: the tag's ``extensions`` keyword argument, or an
                empty list when it was not given.
            csp_nonce: ``request.csp_nonce`` for the ``request`` found in the
                template context, or ``None`` when there is no (truthy)
                request.
        """
        request = context.get("request")
        return dict(
            DEBUG=settings.DEBUG,
            extensions=kwargs.get("extensions", []),
            csp_nonce=None if not request else request.csp_nonce,
        )
31
32
33class _FragmentFound(Exception):
34 """Raised to short-circuit template rendering once the target fragment is found."""
35
36 def __init__(self, content: str) -> None:
37 self.content = content
38
39
@register_template_extension
class HTMXFragmentExtension(Extension):
    """Jinja extension implementing the ``{% htmxfragment %}`` block tag.

    Syntax::

        {% htmxfragment <name-expr> [key=value[,] ...] %}
            ...
        {% endhtmxfragment %}

    The body is wrapped in an element (``div`` unless ``as`` is given) that
    carries a ``plain-hx-fragment`` attribute plus htmx attributes. The other
    keyword arguments become HTML attributes on that element. For ``hx_*``
    names, underscores are replaced with dashes. ``lazy`` and ``as`` are
    consumed by the tag itself. ``render_template_fragment`` drives fragment
    targeting through the ``_htmx_target_fragment`` context entry.
    """

    tags = {"htmxfragment"}

    def parse(self, parser: Parser) -> Node:
        """Parse the tag into a ``CallBlock`` that calls ``_render_htmx_fragment``.

        Raises:
            jinja2.TemplateSyntaxError: if an argument after the fragment name
                is not of the form ``name=expression``, or the tag is not
                terminated.
        """
        lineno = next(parser.stream).lineno

        fragment_name = parser.parse_expression()

        kwargs = []
        while parser.stream.current.type != "block_end":
            # Allow optional commas between keyword arguments.
            if parser.stream.skip_if("comma"):
                continue
            # expect() consumes the token or raises TemplateSyntaxError, so a
            # stray token (or EOF) can never stall this loop.
            key_token = parser.stream.expect("name")
            parser.stream.expect("assign")
            value = parser.parse_expression()
            kwargs.append(
                nodes.Keyword(key_token.value, value, lineno=key_token.lineno)
            )

        body = parser.parse_statements(("name:endhtmxfragment",), drop_needle=True)

        call = self.call_method(
            "_render_htmx_fragment",
            args=[fragment_name, nodes.ContextReference()],
            kwargs=kwargs,
        )

        callblock = CallBlock(call, [], [], body)
        callblock.set_lineno(lineno)

        return callblock

    @staticmethod
    def _attrs_to_str(attrs: dict[str, Any]) -> str:
        """Serialize ``attrs`` as space-separated HTML attributes.

        An empty-string value produces a bare attribute (``key``). Any other
        value is rendered as ``key="value"`` with the value HTML-escaped, so
        quotes inside it (e.g. JSON for ``hx-vals``) cannot terminate the
        attribute early or inject new attributes.
        """
        return " ".join(
            key if value == "" else f'{key}="{html.escape(str(value))}"'
            for key, value in attrs.items()
        )

    def _render_htmx_fragment(
        self, fragment_name: str, context: dict[str, Any], caller: Any, **kwargs: Any
    ) -> str:
        """Render one occurrence of an ``htmxfragment`` block.

        Args:
            fragment_name: Value of the tag's first expression.
            context: The template context (passed via ``ContextReference``).
            caller: Callable that renders the tag body.
            **kwargs: The tag's keyword arguments (``lazy``, ``as`` and HTML
                attributes).

        Returns:
            The wrapped fragment HTML; or, while a targeted render is still
            searching, either the bare body or ``""``.

        Raises:
            _FragmentFound: when this fragment is the one being targeted by
                ``render_template_fragment``; carries the rendered body.
        """
        # Two-phase fragment targeting (see render_template_fragment):
        # Phase 1 skips non-target bodies, phase 2 renders them for nesting.
        # Once the target is found, "found" is set so child fragments render
        # normally with their wrapper elements.
        target_state = context.get("_htmx_target_fragment")
        if target_state is not None and not target_state["found"]:
            if str(fragment_name) == target_state["name"]:
                target_state["found"] = True
                content = caller()
                raise _FragmentFound(content)
            if target_state["render_bodies"]:
                return caller()
            return ""

        render_lazy = kwargs.get("lazy", False)
        as_element = kwargs.get("as", "div")
        attrs: dict[str, Any] = {}
        for key, value in kwargs.items():
            if key in ("lazy", "as"):
                continue
            if key.startswith("hx_"):
                key = key.replace("_", "-")
            attrs[key] = value

        name_attr = html.escape(str(fragment_name))

        if render_lazy:
            attrs.setdefault("hx-trigger", "load from:body")
            attrs.setdefault("hx-swap", "outerHTML")
            attrs.setdefault("hx-target", "this")
            attrs.setdefault("hx-indicator", "this")
            # Emit a bare ``hx-get`` first, but let an explicit ``hx_get``
            # kwarg replace it in place. Emitting both would create a
            # duplicate attribute, and HTML parsers keep the first (empty) one.
            attrs_str = self._attrs_to_str({"hx-get": "", **attrs})
            return (
                f'<{as_element} plain-hx-fragment="{name_attr}" {attrs_str}>'
                f"</{as_element}>"
            )

        # Swap innerHTML so we can re-run hx calls inside the fragment automatically
        attrs.setdefault("hx-swap", "innerHTML")
        attrs.setdefault("hx-target", "this")
        attrs.setdefault("hx-indicator", "this")
        # Add an id that you can use to target the fragment from outside the fragment
        attrs.setdefault("id", f"plain-hx-fragment-{fragment_name}")
        attrs_str = self._attrs_to_str(attrs)
        return (
            f'<{as_element} plain-hx-fragment="{name_attr}" {attrs_str}>'
            f"{caller()}</{as_element}>"
        )
126
127
def render_template_fragment(
    *, template: jinja2.Template, fragment_name: str, context: dict[str, Any]
) -> str:
    """Render just the fragment named ``fragment_name`` out of ``template``.

    The template is rendered at most twice. Each pass injects a fresh
    ``_htmx_target_fragment`` state dict into a copy of ``context``:

    * pass one leaves every non-target fragment body unrendered, which is
      enough to reach fragments at the top level or inside loops;
    * pass two renders non-target bodies as well, so a target nested inside
      another fragment can be reached.

    The fragment extension raises ``_FragmentFound`` with the rendered body
    as soon as the target is hit, and that body is returned.

    Raises:
        jinja2.TemplateNotFound: if neither pass encounters the fragment.
    """
    for include_bodies in (False, True):
        state = dict(name=fragment_name, found=False, render_bodies=include_bodies)
        render_context = dict(context)
        render_context["_htmx_target_fragment"] = state
        try:
            template.render(render_context)
        except _FragmentFound as hit:
            return hit.content

    message = f"Fragment '{fragment_name}' not found in template {template.name}"
    raise jinja2.TemplateNotFound(message)