1import jinja2
2from jinja2 import meta, nodes
3from jinja2.ext import Extension
4
5from plain.runtime import settings
6from plain.templates import register_template_extension
7from plain.templates.jinja.extensions import InclusionTagExtension
8
9
@register_template_extension
class HTMXJSExtension(InclusionTagExtension):
    """Template tag ``htmx_js`` backed by the ``htmx/js.html`` template.

    ``InclusionTagExtension`` is defined elsewhere; presumably it renders
    ``template_name`` with the mapping returned by ``get_context``.
    """

    tags = {"htmx_js"}
    template_name = "htmx/js.html"

    def get_context(self, context, *args, **kwargs):
        """Return the variables handed to ``htmx/js.html``.

        ``DEBUG`` mirrors ``settings.DEBUG``. ``extensions`` is taken from the
        tag's ``extensions`` keyword argument and defaults to an empty list.
        Positional arguments and the incoming ``context`` are ignored.
        """
        debug_flag = settings.DEBUG
        requested_extensions = kwargs.get("extensions", [])
        return dict(DEBUG=debug_flag, extensions=requested_extensions)
20
21
@register_template_extension
class HTMXFragmentExtension(Extension):
    """Jinja extension implementing the ``{% htmxfragment %}`` block tag.

    Syntax::

        {% htmxfragment "name" lazy=True as="section" hx_trigger="..." %}
            ...
        {% endhtmxfragment %}

    At render time the body is wrapped in an element (``div`` by default)
    that carries a ``plain-hx-fragment`` attribute. At parse time every
    fragment's ``CallBlock`` node is recorded on the environment, so that the
    fragment body can later be rendered on its own.
    """

    tags = {"htmxfragment"}

    def __init__(self, environment):
        super().__init__(environment)
        # Registry shared by every template of this environment and filled in
        # by parse(): {template name: {fragment name: CallBlock node}}.
        environment.extend(htmx_fragment_nodes={})

    def parse(self, parser):
        """Parse one ``htmxfragment`` tag and its body into a ``CallBlock``.

        The first argument is the fragment name. It must be a constant,
        because it is used as a registry key while parsing. It may be
        followed by ``key=value`` keyword arguments, optionally separated
        by commas.

        Raises:
            jinja2.TemplateSyntaxError: if the name is not a constant or an
                argument is not of the form ``key=value``.
        """
        lineno = next(parser.stream).lineno

        fragment_name = parser.parse_expression()
        if not isinstance(fragment_name, nodes.Const):
            # The name is needed right now to register the node, so variables
            # or computed expressions cannot be supported (previously this
            # surfaced as an AttributeError on ``.value`` below).
            parser.fail(
                "htmxfragment expects a constant fragment name (e.g. a string literal)",
                lineno,
            )

        kwargs = []
        while parser.stream.current.type != "block_end":
            # Tolerate commas between arguments. Each iteration now either
            # consumes a token (expect) or raises. Before, an unexpected token
            # such as a comma was never consumed and this loop never ended.
            parser.stream.skip_if("comma")
            key = parser.stream.expect("name").value
            parser.stream.expect("assign")
            value = parser.parse_expression()
            kwargs.append(nodes.Keyword(key, value))

        body = parser.parse_statements(["name:endhtmxfragment"], drop_needle=True)

        call = self.call_method(
            "_render_htmx_fragment",
            args=[fragment_name, nodes.ContextReference()],
            kwargs=kwargs,
        )

        node = nodes.CallBlock(call, [], [], body).set_lineno(lineno)

        # Record the node under (name of the template being parsed, fragment
        # name) so its body can be rendered by itself later.
        self.environment.htmx_fragment_nodes.setdefault(parser.name, {})[
            fragment_name.value
        ] = node

        return node

    def _render_htmx_fragment(self, fragment_name, context, caller, **kwargs):
        """Render the wrapper element for a fragment at template render time.

        Args:
            fragment_name: The constant name given in the tag.
            context: The active template context (passed via
                ``ContextReference``; not used here).
            caller: Callable supplied by the ``CallBlock`` that renders the
                fragment body.
            **kwargs: Keyword arguments from the tag. ``lazy`` selects the
                placeholder mode and ``as`` the element name (default
                ``div``). In keys starting with ``hx_``, every ``_`` is
                replaced by ``-``.

        Returns:
            The wrapper HTML as a string. In lazy mode the element is empty
            and carries ``hx-get`` with ``hx-trigger="load from:body"``.
            Otherwise it contains the rendered body.
        """

        def attrs_to_str(attrs):
            # An empty-string value renders as a bare (boolean) attribute.
            parts = []
            for k, v in attrs.items():
                if v == "":
                    parts.append(k)
                else:
                    parts.append(f'{k}="{v}"')
            return " ".join(parts)

        render_lazy = kwargs.get("lazy", False)
        as_element = kwargs.get("as", "div")
        # NOTE(review): ``lazy`` and ``as`` are not filtered out here, so they
        # are also emitted as HTML attributes. Values are interpolated without
        # HTML escaping, so untrusted data must not be passed as attributes.
        attrs = {}
        for k, v in kwargs.items():
            if k.startswith("hx_"):
                attrs[k.replace("_", "-")] = v
            else:
                attrs[k] = v

        if render_lazy:
            attrs.setdefault("hx-swap", "outerHTML")
            attrs.setdefault("hx-target", "this")
            attrs.setdefault("hx-indicator", "this")
            attrs_str = attrs_to_str(attrs)
            return f'<{as_element} plain-hx-fragment="{fragment_name}" hx-get hx-trigger="load from:body" {attrs_str}></{as_element}>'
        else:
            # Swap innerHTML so we can re-run hx calls inside the fragment automatically
            # (render_template_fragment won't render this part of the node again, just the inner nodes)
            attrs.setdefault("hx-swap", "innerHTML")
            attrs.setdefault("hx-target", "this")
            attrs.setdefault("hx-indicator", "this")
            # Add an id that you can use to target the fragment from outside the fragment
            attrs.setdefault("id", f"plain-hx-fragment-{fragment_name}")
            attrs_str = attrs_to_str(attrs)
            return f'<{as_element} plain-hx-fragment="{fragment_name}" {attrs_str}>{caller()}</{as_element}>'
97
98
def render_template_fragment(*, template, fragment_name, context):
    """Render just the body of the htmx fragment ``fragment_name``.

    The fragment is located with ``find_template_fragment`` starting from
    ``template``. The resulting standalone template is rendered with
    ``context``, and the output string is returned.
    """
    fragment_template = find_template_fragment(
        template=template,
        fragment_name=fragment_name,
    )
    rendered = fragment_template.render(context)
    return rendered
102
103
def _collect_fragment_nodes(environment, fragment_name):
    """Return every registered CallBlock node named ``fragment_name``, across all templates."""
    return [
        fragments[fragment_name]
        for fragments in environment.htmx_fragment_nodes.values()
        if fragment_name in fragments
    ]


def _parse_referenced_templates(template):
    """Parse ``template`` and load the templates it statically references, registering their fragments."""
    environment = template.environment
    source, filename, _uptodate = environment.loader.get_source(
        environment, template.name
    )
    # Pass the real name and filename. The fragment extension keys its
    # registry by the parser's template name, so parsing anonymously would
    # file this template's fragments under ``None``, where fragments from
    # different templates overwrite each other.
    ast = environment.parse(source, template.name, filename)
    for ref in meta.find_referenced_templates(ast):
        # find_referenced_templates yields None for dynamic references
        # (e.g. ``{% include some_variable %}``); those cannot be loaded here.
        if ref is None:
            continue
        if ref not in environment.htmx_fragment_nodes:
            # Loading the template parses it, which registers its fragments.
            environment.get_template(ref)


def find_template_fragment(
    template: jinja2.Template, fragment_name: str
) -> jinja2.Template:
    """Build a standalone template containing only the body of a fragment.

    Lookup order:

    1. Fragments registered for ``template`` itself.
    2. The fragment name across every template this environment has parsed
       so far.
    3. If still not found (e.g. a fresh process that has not parsed the
       related templates yet): parse ``template`` and the templates it
       statically references, then search all templates again.

    Raises:
        jinja2.TemplateNotFound: if the fragment is not found, or if steps 2/3
            find it in more than one template.
    """
    environment = template.environment

    # Fast path: the fragment is defined in the requested template.
    callblock_node = environment.htmx_fragment_nodes.get(template.name, {}).get(
        fragment_name
    )

    if callblock_node is None:
        matching_callblock_nodes = _collect_fragment_nodes(environment, fragment_name)

        if not matching_callblock_nodes:
            _parse_referenced_templates(template)
            matching_callblock_nodes = _collect_fragment_nodes(
                environment, fragment_name
            )

        if len(matching_callblock_nodes) > 1:
            raise jinja2.TemplateNotFound(
                f"Fragment {fragment_name} found in multiple templates. Use a more specific name."
            )
        if not matching_callblock_nodes:
            raise jinja2.TemplateNotFound(
                f"Fragment {fragment_name} not found in any templates"
            )
        callblock_node = matching_callblock_nodes[0]

    # Compile a new template whose only content is the fragment's body (the
    # wrapper element produced by the CallBlock is intentionally left out).
    template_node = jinja2.nodes.Template(callblock_node.body)
    return environment.from_string(template_node)