from __future__ import annotations

import html
from typing import Any

import jinja2
from jinja2 import meta, nodes
from jinja2.ext import Extension
from jinja2.nodes import CallBlock, Node
from jinja2.parser import Parser
from jinja2.runtime import Context

from plain.runtime import settings
from plain.templates import register_template_extension
from plain.templates.jinja.extensions import InclusionTagExtension
15
16
@register_template_extension
class HTMXJSExtension(InclusionTagExtension):
    """``{% htmx_js %}`` tag backed by the ``htmx/js.html`` template.

    Presumably ``InclusionTagExtension`` renders ``template_name`` with the
    dict returned by ``get_context`` -- confirm against the base class.
    """

    tags = {"htmx_js"}
    template_name = "htmx/js.html"

    def get_context(
        self, context: Context, *args: Any, **kwargs: Any
    ) -> dict[str, Any]:
        """Build the context handed to ``htmx/js.html``.

        Returns a dict with:
        - ``DEBUG``: the current ``settings.DEBUG`` value;
        - ``extensions``: the tag's ``extensions`` keyword argument (``[]`` if absent);
        - ``csp_nonce``: ``request.csp_nonce`` when the template context holds a
          truthy ``request``, otherwise ``None``.
        """
        request = context.get("request")

        debug_flag = settings.DEBUG
        htmx_extensions = kwargs.get("extensions", [])
        if request:
            nonce = request.csp_nonce  # type: ignore[attr-defined]
        else:
            nonce = None

        return dict(DEBUG=debug_flag, extensions=htmx_extensions, csp_nonce=nonce)
31
32
@register_template_extension
class HTMXFragmentExtension(Extension):
    """``{% htmxfragment "name" key=value ... %}...{% endhtmxfragment %}`` tag.

    At render time the tag body is wrapped in an element carrying a
    ``plain-hx-fragment`` attribute (see ``_render_htmx_fragment``). At parse
    time the generated node is recorded in ``environment.htmx_fragment_nodes``,
    keyed by the parser's template name and then by fragment name, so that
    ``find_template_fragment`` can later render just the body.
    """

    tags = {"htmxfragment"}

    def __init__(self, environment: jinja2.Environment):
        super().__init__(environment)
        # Shared registry: {template name: {fragment name: CallBlock node}}.
        # ``extend`` only adds the attribute if the environment lacks it.
        environment.extend(htmx_fragment_nodes={})

    def parse(self, parser: Parser) -> Node:
        """Parse the tag into a ``CallBlock`` that calls ``_render_htmx_fragment``.

        The fragment name must be a constant because it is used as a registry
        key at parse time. It may be followed by ``key=value`` keyword
        arguments, optionally separated by commas.

        Raises:
            jinja2.TemplateSyntaxError: if the name is not a constant, or an
                argument is not of the form ``key=value``.
        """
        lineno = next(parser.stream).lineno

        fragment_name = parser.parse_expression()
        if not isinstance(fragment_name, nodes.Const):
            parser.fail("htmxfragment name must be a constant", lineno)

        kwargs = []

        while parser.stream.current.type != "block_end":
            # Allow (but don't require) commas between keyword arguments.
            if parser.stream.skip_if("comma"):
                continue
            # ``expect`` raises a syntax error on any other token instead of
            # spinning forever on a token this loop doesn't understand.
            key = parser.stream.expect("name").value
            parser.stream.expect("assign")
            value = parser.parse_expression()
            kwargs.append(nodes.Keyword(key, value))

        body = parser.parse_statements(("name:endhtmxfragment",), drop_needle=True)

        call = self.call_method(
            "_render_htmx_fragment",
            args=[fragment_name, nodes.ContextReference()],
            kwargs=kwargs,
        )

        node = CallBlock(call, [], [], body).set_lineno(lineno)

        # Store a reference to the node for later
        self.environment.htmx_fragment_nodes.setdefault(parser.name, {})[  # type: ignore[attr-defined]
            fragment_name.value
        ] = node

        return node

    def _render_htmx_fragment(
        self, fragment_name: str, context: dict[str, Any], caller: Any, **kwargs: Any
    ) -> str:
        """Render the wrapper element for a fragment.

        Keyword arguments from the tag become HTML attributes; names starting
        with ``hx_`` have their underscores turned into dashes. Two keywords
        also control rendering:

        - ``lazy``: when truthy, render an empty element that requests its
          content via ``hx-get`` on ``load from:body`` instead of the body.
        - ``as``: the element's tag name (default ``"div"``).

        Attribute values are HTML-escaped; a value equal to ``""`` renders as
        a bare attribute (e.g. ``hx_boost=""`` -> ``hx-boost``).
        """

        def escape_attr(value: Any) -> str:
            # Values that are already HTML-safe (e.g. markupsafe.Markup)
            # expose ``__html__``; use it rather than escaping them twice.
            html_method = getattr(value, "__html__", None)
            if html_method is not None:
                return str(html_method())
            return html.escape(str(value), quote=True)

        def attrs_to_str(attrs: dict[str, Any]) -> str:
            return " ".join(
                k if v == "" else f'{k}="{escape_attr(v)}"' for k, v in attrs.items()
            )

        render_lazy = kwargs.get("lazy", False)
        as_element = kwargs.get("as", "div")
        # NOTE(review): ``lazy`` and ``as`` are not removed here, so they are
        # also emitted as HTML attributes -- confirm this is intended.
        attrs = {
            (k.replace("_", "-") if k.startswith("hx_") else k): v
            for k, v in kwargs.items()
        }
        name_attr = escape_attr(fragment_name)

        # Lazy fragments default to swapping their outerHTML. Eager ones swap
        # innerHTML so we can re-run hx calls inside the fragment automatically
        # (render_template_fragment won't render this part of the node again,
        # just the inner nodes).
        attrs.setdefault("hx-swap", "outerHTML" if render_lazy else "innerHTML")
        attrs.setdefault("hx-target", "this")
        attrs.setdefault("hx-indicator", "this")

        if render_lazy:
            attrs_str = attrs_to_str(attrs)
            return f'<{as_element} plain-hx-fragment="{name_attr}" hx-get hx-trigger="load from:body" {attrs_str}></{as_element}>'

        # Add an id that you can use to target the fragment from outside the fragment
        attrs.setdefault("id", f"plain-hx-fragment-{fragment_name}")
        attrs_str = attrs_to_str(attrs)
        return f'<{as_element} plain-hx-fragment="{name_attr}" {attrs_str}>{caller()}</{as_element}>'
110
111
def render_template_fragment(
    *, template: jinja2.Template, fragment_name: str, context: dict[str, Any]
) -> str:
    """Render only the body of the ``{% htmxfragment %}`` named *fragment_name*.

    The fragment is resolved relative to *template* via
    ``find_template_fragment`` and rendered with *context*. Lookup errors
    raised by ``find_template_fragment`` propagate unchanged.
    """
    return find_template_fragment(template, fragment_name).render(context)
117
118
def find_template_fragment(
    template: jinja2.Template, fragment_name: str
) -> jinja2.Template:
    """Return a standalone template built from the body of *fragment_name*.

    Fragments are looked up in ``environment.htmx_fragment_nodes``, which
    ``HTMXFragmentExtension.parse`` fills in as templates are parsed:

    1. fragments registered for *template* itself;
    2. otherwise, a fragment with this name registered for any template;
    3. if still nothing matched and the environment has a loader, re-parse
       *template* (registering its own fragments under its name) and load the
       templates it statically references, then repeat steps 1 and 2.

    Raises:
        jinja2.TemplateNotFound: if no fragment matches, or if the name matches
            fragments in more than one template.
    """
    environment = template.environment
    registry = environment.htmx_fragment_nodes  # type: ignore[attr-defined]

    # Look in this template for the fragment
    callblock_node = registry.get(template.name, {}).get(fragment_name)

    if callblock_node is None:
        # Look in other templates for this fragment
        matching_callblock_nodes = _find_fragment_nodes(registry, fragment_name)

        if not matching_callblock_nodes and environment.loader and template.name:
            # If we still haven't found anything, it's possible that we're in a
            # different/new worker/process and haven't parsed the related
            # templates yet.
            source, filename, _ = environment.loader.get_source(
                environment, template.name
            )
            # Parse under the real name so this template's fragments land in its
            # own registry entry rather than a shared ``None`` entry, which could
            # later be mistaken for another template's fragment.
            ast = environment.parse(source, name=template.name, filename=filename)
            for ref in meta.find_referenced_templates(ast):
                # ``ref`` is None for dynamic (non-constant) template names.
                if ref is not None and ref not in registry:
                    # Trigger them to parse
                    environment.get_template(ref)

            # Now look again, preferring this template's own fragments
            callblock_node = registry.get(template.name, {}).get(fragment_name)
            if callblock_node is None:
                matching_callblock_nodes = _find_fragment_nodes(
                    registry, fragment_name
                )

        if callblock_node is None:
            if len(matching_callblock_nodes) > 1:
                raise jinja2.TemplateNotFound(
                    f"Fragment {fragment_name} found in multiple templates. Use a more specific name."
                )
            if not matching_callblock_nodes:
                raise jinja2.TemplateNotFound(
                    f"Fragment {fragment_name} not found in any templates"
                )
            callblock_node = matching_callblock_nodes[0]

    # Create a new template from the node
    template_node = nodes.Template(callblock_node.body)
    return environment.from_string(template_node)  # type: ignore[arg-type]


def _find_fragment_nodes(
    registry: dict[Any, dict[str, Any]], fragment_name: str
) -> list[Any]:
    """Return every registered fragment node named *fragment_name*, across all templates."""
    return [
        fragments[fragment_name]
        for fragments in registry.values()
        if fragment_name in fragments
    ]