1from __future__ import annotations
2
3import datetime
4from decimal import Decimal
5from types import NoneType
6from typing import Any
7from urllib.parse import quote
8
9from plain.utils.functional import Promise
10
11
class PlainUnicodeDecodeError(UnicodeDecodeError):
    """UnicodeDecodeError that also carries the object whose conversion failed.

    force_str() raises this in place of a plain UnicodeDecodeError so that the
    error message shows the offending value and its type.
    """

    def __init__(self, obj: Any, *args: Any) -> None:
        # *args are UnicodeDecodeError's own positional arguments:
        # (encoding, object, start, end, reason).
        self.obj = obj
        super().__init__(*args)

    def __str__(self) -> str:
        return f"{super().__str__()}. You passed in {self.obj!r} ({type(self.obj)})"

    def __reduce__(self) -> tuple[Any, ...]:
        # BaseException.__reduce__ rebuilds the exception as type(self)(*self.args),
        # but self.args holds only the five UnicodeDecodeError arguments, not
        # `obj`. Without this override, pickling/copying would shift every
        # argument by one and fail with a TypeError on reconstruction.
        return (type(self), (self.obj, *self.args), self.__dict__)
19
20
21_PROTECTED_TYPES = (
22 NoneType,
23 int,
24 float,
25 Decimal,
26 datetime.datetime,
27 datetime.date,
28 datetime.time,
29)
30
31
def is_protected_type(obj: Any) -> bool:
    """Return True when ``obj`` is an instance of one of ``_PROTECTED_TYPES``.

    force_str() and force_bytes() use this check to hand such objects back
    untouched when called with ``strings_only=True``.
    """
    protected_types = _PROTECTED_TYPES
    return isinstance(obj, protected_types)
39
40
def force_str(
    s: Any, encoding: str = "utf-8", strings_only: bool = False, errors: str = "strict"
) -> str | Any:
    """
    Return a ``str`` version of ``s``, resolving lazy objects (such as
    ``Promise`` instances) to real strings.

    - ``str`` instances (including subclasses) are returned unchanged.
    - ``bytes`` are decoded with ``encoding`` and ``errors``.
    - Anything else is converted with ``str()``.

    If ``strings_only`` is True, objects of protected types (see
    ``is_protected_type()``) are returned unchanged instead of being converted.

    Raises PlainUnicodeDecodeError (a UnicodeDecodeError subclass that also
    records the input object) if the conversion raises UnicodeDecodeError.
    """
    # Handle the common case first for performance reasons.
    if issubclass(type(s), str):
        return s
    if strings_only and is_protected_type(s):
        return s
    try:
        if isinstance(s, bytes):
            s = str(s, encoding, errors)
        else:
            s = str(s)
    except UnicodeDecodeError as e:
        # Chain explicitly so tracebacks show the decode failure as the direct
        # cause instead of "another exception occurred during handling".
        raise PlainUnicodeDecodeError(s, *e.args) from e
    return s
63
64
def force_bytes(
    s: Any, encoding: str = "utf-8", strings_only: bool = False, errors: str = "strict"
) -> bytes | Any:
    """
    Return a ``bytes`` version of ``s``, resolving lazy objects to real
    strings before encoding them.

    - ``bytes`` input is treated as UTF-8: it is returned as-is when
      ``encoding`` is exactly ``"utf-8"``, otherwise it is decoded as UTF-8
      and re-encoded with ``encoding``.
    - ``memoryview`` input is copied into a new ``bytes`` object.
    - Anything else is converted with ``str()`` and encoded with ``encoding``.

    If ``strings_only`` is True, objects of protected types (see
    ``is_protected_type()``) are returned unchanged.
    """
    # bytes is the most frequent input, so it is checked first.
    if isinstance(s, bytes):
        if encoding != "utf-8":
            # Transcode from UTF-8 into the requested target encoding.
            s = s.decode("utf-8", errors).encode(encoding, errors)
        return s
    if strings_only and is_protected_type(s):
        return s
    # A memoryview must be copied via bytes(); str() would only yield its repr.
    return bytes(s) if isinstance(s, memoryview) else str(s).encode(encoding, errors)
85
86
def iri_to_uri(iri: str | bytes | Promise | None) -> str | None:
    """
    Convert an Internationalized Resource Identifier (IRI) portion to a URI
    portion that is suitable for inclusion in a URL.

    This is the algorithm from RFC 3987 Section 3.1, slightly simplified since
    the input is assumed to be a string rather than an arbitrary byte stream.

    Take an IRI (string or UTF-8 bytes, e.g. '/I ♥ Plain/' or
    b'/I \xe2\x99\xa5 Plain/') and return a string containing the encoded
    result with ASCII chars only (e.g. '/I%20%E2%99%A5%20Plain/').

    None is returned unchanged; lazy ``Promise`` objects are converted with
    ``str()`` before quoting.
    """
    # The list of safe characters here is constructed from the "reserved" and
    # "unreserved" characters specified in RFC 3986 Sections 2.2 and 2.3:
    #     reserved    = gen-delims / sub-delims
    #     gen-delims  = ":" / "/" / "?" / "#" / "[" / "]" / "@"
    #     sub-delims  = "!" / "$" / "&" / "'" / "(" / ")"
    #                   / "*" / "+" / "," / ";" / "="
    #     unreserved  = ALPHA / DIGIT / "-" / "." / "_" / "~"
    # Of the unreserved characters, urllib.parse.quote() already considers all
    # but the ~ safe.
    # The % character is also added to the list of safe characters here, as the
    # end of RFC 3987 Section 3.1 specifically mentions that % must not be
    # converted.
    if iri is None:
        return None
    if isinstance(iri, Promise):
        iri = str(iri)
    # quote() accepts both str and bytes, which is what makes the documented
    # UTF-8 bytes input work.
    return quote(iri, safe="/#%[]=:;$&()+,!?*@'~")
116
117
118# List of byte values that uri_to_iri() decodes from percent encoding.
119# First, the unreserved characters from RFC 3986:
120_ascii_ranges = [[45, 46, 95, 126], range(65, 91), range(97, 123)]
121_hextobyte = {
122 (fmt % char).encode(): bytes((char,))
123 for ascii_range in _ascii_ranges
124 for char in ascii_range
125 for fmt in ["%02x", "%02X"]
126}
127# And then everything above 128, because bytes ≥ 128 are part of multibyte
128# Unicode characters.
129_hexdig = "0123456789ABCDEFabcdef"
130_hextobyte.update(
131 {(a + b).encode(): bytes.fromhex(a + b) for a in _hexdig[8:] for b in _hexdig}
132)
133
134
def punycode(domain: str) -> str:
    """Encode ``domain`` with Python's ``idna`` codec and return it as ASCII text.

    Labels containing non-ASCII characters become Punycode (``xn--...``);
    all-ASCII labels pass through unchanged. The codec raises UnicodeError for
    domains it cannot encode (for example, empty or over-long labels).
    """
    ascii_form = domain.encode("idna")
    return ascii_form.decode("ascii")