1"""HTML utilities suitable for global use."""
2
3from __future__ import annotations
4
5import html
6import json
7from html.parser import HTMLParser
8from typing import Any
9
10from plain.utils.functional import Promise, keep_lazy, keep_lazy_text
11from plain.utils.safestring import SafeString, mark_safe
12
13
@keep_lazy(SafeString)
def escape(text: Any) -> SafeString:
    """
    Encode ampersands, angle brackets and both quote characters in *text*
    so the result can be embedded in HTML.

    The input is escaped unconditionally -- even if it is already escaped
    and marked safe -- so double-escaping can occur. Use
    conditional_escape() when that is undesirable.
    """
    as_text = str(text)
    escaped = html.escape(as_text, quote=True)
    return SafeString(escaped)
25
26
27_json_script_escapes = {
28 ord(">"): "\\u003E",
29 ord("<"): "\\u003C",
30 ord("&"): "\\u0026",
31}
32
33
def json_script(
    value: Any,
    element_id: str | None = None,
    nonce: str = "",
    encoder: type[json.JSONEncoder] | None = None,
) -> SafeString:
    """
    Serialize *value* to JSON and wrap it in a ``<script type="application/json">`` tag.

    The characters ``<``, ``>`` and ``&`` in the JSON are replaced with
    their JSON unicode escapes, so the payload is safe to be output anywhere
    except inside a tag attribute.

    Args:
        value: The data to encode as JSON.
        element_id: Optional ID attribute for the script tag. It is HTML-escaped
            (unless it is already marked safe) before being placed in the
            attribute.
        nonce: Optional CSP nonce for inline script tags; escaped the same way
            as ``element_id``.
        encoder: Optional custom JSON encoder class; defaults to
            PlainJSONEncoder.

    Returns:
        The complete ``<script>`` element, marked safe.
    """
    from plain.json import PlainJSONEncoder

    json_str = json.dumps(value, cls=encoder or PlainJSONEncoder).translate(
        _json_script_escapes
    )
    # The attribute values are interpolated inside double quotes; escape them
    # so a stray quote or angle bracket cannot break out of the attribute and
    # inject markup. conditional_escape leaves already-safe values untouched.
    id_attr = f' id="{conditional_escape(element_id)}"' if element_id else ""
    nonce_attr = f' nonce="{conditional_escape(nonce)}"' if nonce else ""
    return mark_safe(
        f'<script{id_attr}{nonce_attr} type="application/json">{json_str}</script>'
    )
61
62
def conditional_escape(text: Any) -> SafeString | str:
    """
    Escape *text* for HTML unless it already knows how to render itself.

    Objects exposing an ``__html__`` method -- Plain's SafeData as well as
    third-party types such as markupsafe's -- are returned through that
    method untouched; anything else is passed to escape(). Lazy ``Promise``
    objects are converted to ``str`` before the check.
    """
    if isinstance(text, Promise):
        # Resolve the lazy object to plain text before inspecting it.
        text = str(text)
    if not hasattr(text, "__html__"):
        return escape(text)
    return text.__html__()  # ty: ignore[call-non-callable]
76
77
def format_html(format_string: str, *args: Any, **kwargs: Any) -> SafeString:
    """
    Build a small HTML fragment the way str.format would, but safely.

    Every positional and keyword argument is passed through
    conditional_escape() before interpolation, and the finished string is
    marked safe. Prefer this over str.format or % interpolation whenever
    the output is HTML.
    """
    escaped_kwargs = {name: conditional_escape(val) for name, val in kwargs.items()}
    escaped_args = [conditional_escape(arg) for arg in args]
    rendered = format_string.format(*escaped_args, **escaped_kwargs)
    return mark_safe(rendered)
87
88
class MLStripper(HTMLParser):
    """
    HTML parser that keeps text content and discards the tags around it.

    The parser is built with ``convert_charrefs=False``, so character and
    entity references reach the handlers below and are written back out
    verbatim (``&name;`` / ``&#num;``) rather than being decoded.
    """

    def __init__(self) -> None:
        # HTMLParser.__init__ already calls reset(), so no extra reset is needed.
        super().__init__(convert_charrefs=False)
        # Text fragments collected in document order.
        self.fed: list[str] = []

    def handle_data(self, data: str) -> None:
        self.fed.append(data)

    def handle_entityref(self, name: str) -> None:
        self.fed.append("&" + name + ";")

    def handle_charref(self, name: str) -> None:
        self.fed.append("&#" + name + ";")

    def get_data(self) -> str:
        """Return everything collected so far as one string."""
        collected = self.fed
        return "".join(collected)
106
107
def _strip_once(value: str) -> str:
    """
    Run a single MLStripper pass over *value* and return the remaining text.

    Used internally by strip_tags(), which may call it repeatedly.
    """
    stripper = MLStripper()
    stripper.feed(value)
    # close() forces processing of any data the parser is still buffering.
    stripper.close()
    text = stripper.get_data()
    return text
116
117
@keep_lazy_text
def strip_tags(value: Any) -> str:
    """Return the given HTML with all tags stripped."""
    text = str(value)
    # Usually a single _strip_once pass is enough. The "<"/">" test is not
    # needed for correctness; it just skips the parser for tag-free input.
    while "<" in text and ">" in text:
        stripped = _strip_once(text)
        if stripped.count("<") == text.count("<"):
            # This pass removed no further tags; return the previous text.
            return text
        text = stripped
    return text
131
132
def avoid_wrapping(value: str) -> str:
    """
    Swap every ordinary space in *value* for a non-breaking space (U+00A0)
    so the phrase is not wrapped in the middle.
    """
    return "\xa0".join(value.split(" "))