1from __future__ import annotations
2
3from typing import Any
4
5import jinja2
6from jinja2 import meta, nodes
7from jinja2.ext import Extension
8from jinja2.nodes import CallBlock, Node
9from jinja2.parser import Parser
10
11from plain.runtime import settings
12from plain.templates import register_template_extension
13from plain.templates.jinja.extensions import InclusionTagExtension
14
15
@register_template_extension
class HTMXJSExtension(InclusionTagExtension):
    """Inclusion tag ``{% htmx_js %}`` backed by the ``htmx/js.html`` template."""

    tags = {"htmx_js"}
    template_name = "htmx/js.html"

    def get_context(
        self, context: dict[str, Any], *args: Any, **kwargs: Any
    ) -> dict[str, Any]:
        """Build the context handed to ``htmx/js.html``.

        Args:
            context: The calling template's context; only ``request`` is read.
            *args: Positional tag arguments (ignored).
            **kwargs: Tag keyword arguments; ``extensions`` defaults to ``[]``.

        Returns:
            A dict with ``DEBUG`` (from settings), ``extensions`` and
            ``csp_nonce`` (the request's nonce, or ``None`` without a request).
        """
        current_request = context.get("request")

        nonce = None
        if current_request:
            nonce = current_request.csp_nonce  # type: ignore[attr-defined]

        return {
            "DEBUG": settings.DEBUG,
            "extensions": kwargs.get("extensions", []),
            "csp_nonce": nonce,
        }
30
31
@register_template_extension
class HTMXFragmentExtension(Extension):
    """Jinja extension for ``{% htmxfragment "name" ... %}...{% endhtmxfragment %}``.

    Each fragment renders as a wrapper element carrying htmx attributes. At
    parse time the fragment's ``CallBlock`` node is also recorded in
    ``environment.htmx_fragment_nodes`` (keyed by template name, then by
    fragment name) so the fragment body can later be rendered on its own.
    """

    tags = {"htmxfragment"}

    def __init__(self, environment: jinja2.Environment):
        super().__init__(environment)
        # ``Environment.extend`` only sets attributes that do not exist yet, so
        # an already-present registry on the environment is kept, not replaced.
        environment.extend(htmx_fragment_nodes={})

    def parse(self, parser: Parser) -> Node:
        """Parse an ``htmxfragment`` tag and its body.

        Syntax: ``{% htmxfragment "name" key=value [,] key=value ... %}``.
        The fragment name must be a constant expression (it is used as a
        lookup key at parse time). Keyword arguments may optionally be
        separated by commas and are forwarded to ``_render_htmx_fragment``.

        Raises:
            jinja2.TemplateSyntaxError: if the name is not a constant or an
                argument is not of the form ``name=expression``.
        """
        lineno = next(parser.stream).lineno

        fragment_name = parser.parse_expression()
        if not isinstance(fragment_name, nodes.Const):
            # The name keys the fragment registry below, so it cannot depend
            # on the render-time context.
            parser.fail(
                "htmxfragment name must be a constant, e.g. a string literal",
                lineno,
            )

        kwargs = []
        while parser.stream.current.type != "block_end":
            # Commas between keyword arguments are accepted but not required.
            if parser.stream.skip_if("comma"):
                continue
            # ``expect`` raises TemplateSyntaxError on any other token (or EOF)
            # instead of spinning on a token that is never consumed.
            key = parser.stream.expect("name").value
            parser.stream.expect("assign")
            value = parser.parse_expression()
            kwargs.append(nodes.Keyword(key, value))

        body = parser.parse_statements(("name:endhtmxfragment",), drop_needle=True)

        call = self.call_method(
            "_render_htmx_fragment",
            args=[fragment_name, nodes.ContextReference()],
            kwargs=kwargs,
        )

        node = CallBlock(call, [], [], body).set_lineno(lineno)

        # Store a reference to the node so the fragment body can be rendered
        # by itself later (keyed by the name of the template being parsed).
        self.environment.htmx_fragment_nodes.setdefault(parser.name, {})[  # type: ignore[attr-defined]
            fragment_name.value
        ] = node

        return node

    def _render_htmx_fragment(
        self, fragment_name: str, context: dict[str, Any], caller: Any, **kwargs: Any
    ) -> str:
        """Render the wrapper element for one fragment.

        Args:
            fragment_name: Value written to the ``plain-hx-fragment`` attribute.
            context: The template context passed by the tag (not used here).
            caller: Renders the fragment body when called.
            **kwargs: Keyword arguments from the tag. ``lazy`` (default
                ``False``) renders an empty element with ``hx-get`` and
                ``hx-trigger="load from:body"`` instead of the body; ``as``
                (default ``"div"``) sets the element name. Every keyword is
                also emitted as an attribute, with underscores in ``hx_``
                names turned into dashes (``hx_push_url`` -> ``hx-push-url``).

        Returns:
            The element's HTML as a string.
        """

        def attrs_to_str(attrs: dict[str, Any]) -> str:
            # An empty-string value renders as a bare boolean attribute.
            # NOTE(review): values are not HTML-escaped; a value containing
            # a double quote produces broken markup.
            parts = []
            for k, v in attrs.items():
                if v == "":
                    parts.append(k)
                else:
                    parts.append(f'{k}="{v}"')
            return " ".join(parts)

        render_lazy = kwargs.get("lazy", False)
        as_element = kwargs.get("as", "div")
        # NOTE(review): ``lazy`` and ``as`` are not removed here, so they are
        # emitted as HTML attributes too; confirm whether that is intended.
        attrs = {}
        for k, v in kwargs.items():
            if k.startswith("hx_"):
                attrs[k.replace("_", "-")] = v
            else:
                attrs[k] = v

        if render_lazy:
            attrs.setdefault("hx-swap", "outerHTML")
            attrs.setdefault("hx-target", "this")
            attrs.setdefault("hx-indicator", "this")
            attrs_str = attrs_to_str(attrs)
            return f'<{as_element} plain-hx-fragment="{fragment_name}" hx-get hx-trigger="load from:body" {attrs_str}></{as_element}>'
        else:
            # Swap innerHTML so we can re-run hx calls inside the fragment automatically
            # (render_template_fragment won't render this part of the node again, just the inner nodes)
            attrs.setdefault("hx-swap", "innerHTML")
            attrs.setdefault("hx-target", "this")
            attrs.setdefault("hx-indicator", "this")
            # Add an id that you can use to target the fragment from outside the fragment
            attrs.setdefault("id", f"plain-hx-fragment-{fragment_name}")
            attrs_str = attrs_to_str(attrs)
            return f'<{as_element} plain-hx-fragment="{fragment_name}" {attrs_str}>{caller()}</{as_element}>'
109
110
def render_template_fragment(
    *, template: jinja2.Template, fragment_name: str, context: dict[str, Any]
) -> str:
    """Render just one htmx fragment of ``template``.

    Args:
        template: Template that contains (or references) the fragment.
        fragment_name: Name given in ``{% htmxfragment "..." %}``.
        context: Variables used to render the fragment body.

    Returns:
        The rendered fragment body, without its wrapper element.
    """
    return find_template_fragment(template, fragment_name).render(context)
116
117
def find_template_fragment(
    template: jinja2.Template, fragment_name: str
) -> jinja2.Template:
    """Return a new template that renders only the body of one htmx fragment.

    Lookup order:

    1. Fragments recorded for ``template`` itself.
    2. Fragments recorded for any other parsed template; the name must be
       unique among them.
    3. If nothing was found (e.g. in a fresh worker/process where the related
       templates have not been parsed yet), ``template``'s source is
       re-parsed and the templates it references are loaded, then steps 1-2
       are repeated.

    Raises:
        jinja2.TemplateNotFound: if the fragment is not in ``template`` and
            exists in more than one other template, or is not found at all.
    """
    fragment_nodes = template.environment.htmx_fragment_nodes  # type: ignore[attr-defined]

    callblock_node = _lookup_fragment_node(
        fragment_nodes, template.name, fragment_name
    )
    if callblock_node is None:
        _parse_related_templates(template)
        callblock_node = _lookup_fragment_node(
            fragment_nodes, template.name, fragment_name
        )

    if callblock_node is None:
        raise jinja2.TemplateNotFound(
            f"Fragment {fragment_name} not found in any templates"
        )

    # Create a new template from the node's body (not the wrapper call itself)
    template_node = nodes.Template(callblock_node.body)  # type: ignore[attr-defined]
    return template.environment.from_string(template_node)  # type: ignore[arg-type]


def _lookup_fragment_node(
    fragment_nodes: dict[str | None, dict[Any, Node]],
    template_name: str | None,
    fragment_name: str,
) -> Node | None:
    """Find a recorded fragment node, preferring ``template_name``'s own fragments."""
    own_node = fragment_nodes.get(template_name, {}).get(fragment_name)
    if own_node is not None:
        return own_node

    # Look in other templates for this fragment
    matches = [
        fragments[fragment_name]
        for fragments in fragment_nodes.values()
        if fragment_name in fragments
    ]
    if len(matches) > 1:
        raise jinja2.TemplateNotFound(
            f"Fragment {fragment_name} found in multiple templates. Use a more specific name."
        )
    return matches[0] if matches else None


def _parse_related_templates(template: jinja2.Template) -> None:
    """Re-parse ``template`` and load the templates it references so their fragments are recorded."""
    environment = template.environment
    if not (environment.loader and template.name):
        return

    source, filename, _ = environment.loader.get_source(environment, template.name)
    # Passing the real name makes the extension record this template's
    # fragments under that name instead of under ``None``.
    ast = environment.parse(source, template.name, filename)

    fragment_nodes = environment.htmx_fragment_nodes  # type: ignore[attr-defined]
    for ref in meta.find_referenced_templates(ast):
        # ``None`` marks a dynamic reference (e.g. ``{% include var %}``)
        # that cannot be resolved statically.
        if ref is None or ref in fragment_nodes:
            continue
        try:
            # Trigger them to parse
            environment.get_template(ref)
        except jinja2.TemplateNotFound:
            # e.g. ``{% include "x.html" ignore missing %}``: nothing to record.
            continue