1"""
2URL-safe signed JSON objects using HMAC/SHA-256.
3
4Use TimestampSigner for signing with expiration:
5
6 TimestampSigner(salt="my-salt").sign_object({"key": "value"})
7 TimestampSigner(salt="my-salt").unsign_object(token, max_age=3600)
8
9Use Signer for signing without expiration:
10
11 Signer(salt="my-salt").sign_object({"key": "value"})
12 Signer(salt="my-salt").unsign_object(token)
13"""
14
15from __future__ import annotations
16
17import base64
18import datetime
19import hmac
20import json
21import time
22import zlib
23from typing import Any
24
25from plain.runtime import settings
26from plain.utils.crypto import salted_hmac
27from plain.utils.encoding import force_bytes
28from plain.utils.regex_helper import _lazy_re_compile
29
# Signer.__init__ rejects any separator fully matched by this pattern (including
# the empty string, since "*" allows zero characters): unsign() splits on the
# last separator, so the separator must not be buildable from characters that
# can appear in the URL-safe base64 signature (A-Z a-z 0-9 - _) or padding (=).
# NOTE: the range "A-z" also spans the ASCII characters between "Z" and "a"
# ([ \ ] ^ _ and backtick), so separators made only of those are rejected too.
_SEP_UNSAFE = _lazy_re_compile(r"^[A-z0-9-_=]*$")
# Digit alphabet for b62_encode/b62_decode: a character's index is its value.
BASE62_ALPHABET = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
32
33
class BadSignature(Exception):
    """Raised when a signed value cannot be verified (missing separator or mismatched signature)."""
38
39
class SignatureExpired(BadSignature):
    """Raised when a correctly signed timestamped value is older than the allowed max_age."""
44
45
def b62_encode(s: int) -> str:
    """
    Encode the integer ``s`` as a base-62 string over BASE62_ALPHABET.

    Zero encodes as "0"; negative values get a leading "-".
    """
    if s == 0:
        return "0"
    prefix = "-" if s < 0 else ""
    remaining = abs(s)
    digits: list[str] = []
    while remaining > 0:
        remaining, index = divmod(remaining, 62)
        digits.append(BASE62_ALPHABET[index])
    # divmod yields the least-significant digit first, so reverse at the end.
    return prefix + "".join(reversed(digits))
56
57
def b62_decode(s: str) -> int:
    """
    Decode a base-62 string (as produced by b62_encode) back into an integer.

    A leading "-" marks a negative number.

    Raises ValueError if ``s`` is empty, is a bare "-", or contains a
    character outside BASE62_ALPHABET.
    """
    sign = 1
    digits = s
    if digits.startswith("-"):
        sign = -1
        digits = digits[1:]
    # Previously "" crashed with IndexError and "-" silently decoded to 0;
    # reject both with the same error type used for invalid digits.
    if not digits:
        raise ValueError(f"Invalid base62 value: {s!r}")
    decoded = 0
    for digit in digits:
        # str.index raises ValueError for characters outside the alphabet.
        decoded = decoded * 62 + BASE62_ALPHABET.index(digit)
    return sign * decoded
69
70
def b64_encode(s: bytes) -> bytes:
    """Return the URL-safe base64 encoding of ``s`` with the "=" padding removed."""
    encoded = base64.urlsafe_b64encode(s)
    # Padding only ever appears at the end of base64 output.
    return encoded.rstrip(b"=")
73
74
def b64_decode(s: bytes) -> bytes:
    """Decode unpadded URL-safe base64, restoring the "=" padding first."""
    missing = (4 - len(s) % 4) % 4
    padded = s + b"=" * missing
    return base64.urlsafe_b64decode(padded)
78
79
def base64_hmac(salt: str, value: str, key: str, algorithm: str = "sha1") -> str:
    """
    Return salted_hmac(salt, value, key) as unpadded URL-safe base64 text.

    ``algorithm`` is passed straight through to salted_hmac.
    """
    mac = salted_hmac(salt, value, key, algorithm=algorithm)
    encoded = b64_encode(mac.digest())
    return encoded.decode()
84
85
class JSONSerializer:
    """
    Compact JSON serializer used by the sign_object/unsign_object methods.

    dumps() returns bytes and loads() accepts bytes. json.dumps escapes
    non-ASCII characters by default, so dumps() output is plain ASCII even
    though it is encoded with the latin-1 codec.
    """

    def dumps(self, obj: Any) -> bytes:
        text = json.dumps(obj, separators=(",", ":"))
        return text.encode("latin-1")

    def loads(self, data: bytes) -> Any:
        text = data.decode("latin-1")
        return json.loads(text)
97
98
class Signer:
    """
    Sign and verify strings with a salted HMAC.

    ``sign()`` produces ``<value><sep><signature>`` where the signature is
    unpadded URL-safe base64. ``unsign()`` checks the signature against
    ``self.key`` first and then each key in ``self.fallback_keys``, so values
    signed with a key that is still listed as a fallback keep verifying.
    """

    def __init__(
        self,
        *,
        key: str | None = None,
        sep: str = ":",
        salt: str | None = None,
        algorithm: str = "sha256",
        fallback_keys: list[str] | None = None,
    ) -> None:
        """
        Configure the signer.

        key: signing key; any falsy value falls back to settings.SECRET_KEY.
        sep: separator between value and signature.
        salt: namespace mixed into the HMAC; a falsy value falls back to the
            dotted path of the concrete class.
        algorithm: hash algorithm name passed through to base64_hmac.
        fallback_keys: additional keys accepted by unsign(); None means
            settings.SECRET_KEY_FALLBACKS, while an explicit list (even an
            empty one) is used as-is.

        Raises ValueError if ``sep`` is empty or made only of characters
        matched by _SEP_UNSAFE.
        """
        self.key = key or settings.SECRET_KEY
        self.fallback_keys = (
            fallback_keys
            if fallback_keys is not None
            else settings.SECRET_KEY_FALLBACKS
        )
        self.sep = sep
        self.salt = salt or f"{self.__class__.__module__}.{self.__class__.__name__}"
        self.algorithm = algorithm

        # unsign() splits on the last separator, so the separator must not be
        # made of characters that can appear in the base64 signature.
        if _SEP_UNSAFE.match(self.sep):
            raise ValueError(
                f"Unsafe Signer separator: {sep!r} (cannot be empty or consist of "
                "only A-z0-9-_=)",
            )

    def signature(self, value: str, key: str | None = None) -> str:
        """Return the HMAC signature of ``value``; a falsy ``key`` means self.key."""
        key = key or self.key
        return base64_hmac(self.salt + "signer", value, key, algorithm=self.algorithm)

    def sign(self, value: str) -> str:
        """Return ``value`` followed by the separator and its signature."""
        return f"{value}{self.sep}{self.signature(value)}"

    def unsign(self, signed_value: str) -> str:
        """
        Verify ``signed_value`` and return the original value.

        Raises BadSignature if the separator is missing or if neither the
        primary key nor any fallback key produces a matching signature.
        """
        if self.sep not in signed_value:
            raise BadSignature(f'No "{self.sep}" found in value')
        value, sig = signed_value.rsplit(self.sep, 1)
        for key in [self.key, *self.fallback_keys]:
            # Constant-time comparison so timing does not reveal how much of
            # the signature matched.
            if hmac.compare_digest(
                force_bytes(sig), force_bytes(self.signature(value, key))
            ):
                return value
        raise BadSignature(f'Signature "{sig}" does not match')

    def sign_object(
        self,
        obj: Any,
        serializer: type[JSONSerializer] = JSONSerializer,
        compress: bool = False,
    ) -> str:
        """
        Serialize ``obj`` and return it as a signed, URL-safe base64 string.

        If ``compress`` is True (not the default), the serialized bytes are
        zlib-compressed when that makes them at least two bytes shorter, and
        the payload is marked with a leading '.'. The marker is part of the
        signed value, and unsign_object() verifies the signature before
        decompressing, so an unsigned zip bomb never reaches zlib.

        The serializer's ``dumps()`` must return bytes.
        """
        data = serializer().dumps(obj)
        # Flag for if it's been compressed or not.
        is_compressed = False

        if compress:
            compressed = zlib.compress(data)
            # Keep the compressed form only if it still wins after paying for
            # the one-character '.' marker.
            if len(compressed) < (len(data) - 1):
                data = compressed
                is_compressed = True
        base64d = b64_encode(data).decode()
        if is_compressed:
            base64d = "." + base64d
        return self.sign(base64d)

    def unsign_object(
        self,
        signed_obj: str,
        serializer: type[JSONSerializer] = JSONSerializer,
        **kwargs: Any,
    ) -> Any:
        """
        Verify ``signed_obj`` and return the deserialized object.

        ``kwargs`` are forwarded unchanged to ``self.unsign()``. Signer.unsign()
        accepts no extra keyword arguments, so passing any raises TypeError
        unless a subclass overrides unsign(); use TimestampSigner.unsign_object
        for ``max_age`` checks.

        Raises BadSignature if verification fails.
        """
        # Signer.unsign() returns str but base64 and zlib compression operate
        # on bytes.
        base64d = self.unsign(signed_obj, **kwargs).encode()
        decompress = base64d[:1] == b"."
        if decompress:
            # It's compressed; strip the marker before base64-decoding.
            base64d = base64d[1:]
        data = b64_decode(base64d)
        if decompress:
            data = zlib.decompress(data)
        return serializer().loads(data)
190
191
class TimestampSigner:
    """A signer that includes a timestamp for max_age validation.

    Uses composition rather than inheritance since the interface
    intentionally differs from Signer (unsign accepts max_age parameter).

    ``sign()`` appends ``<sep><base-62 Unix time>`` to the value and signs the
    combined string with the wrapped Signer. ``unsign()`` verifies it, strips
    the timestamp and optionally enforces ``max_age``.
    """

    def __init__(
        self,
        *,
        key: str | None = None,
        sep: str = ":",
        salt: str | None = None,
        algorithm: str = "sha256",
        fallback_keys: list[str] | None = None,
    ) -> None:
        # Compute default salt here to preserve backwards compatibility.
        # When TimestampSigner inherited from Signer, the default salt was
        # "plain.signing.TimestampSigner". Now that we use composition,
        # we must set it explicitly rather than letting Signer compute its own.
        if salt is None:
            salt = f"{self.__class__.__module__}.{self.__class__.__name__}"
        self._signer = Signer(
            key=key,
            sep=sep,
            salt=salt,
            algorithm=algorithm,
            fallback_keys=fallback_keys,
        )

    @property
    def sep(self) -> str:
        """Separator of the wrapped Signer; it also separates the timestamp."""
        return self._signer.sep

    def timestamp(self) -> str:
        """Return the current Unix time in whole seconds, base-62 encoded."""
        return b62_encode(int(time.time()))

    def sign(self, value: str) -> str:
        """Append the current timestamp to ``value`` and sign the result."""
        value = f"{value}{self.sep}{self.timestamp()}"
        return self._signer.sign(value)

    def unsign(
        self, value: str, max_age: int | float | datetime.timedelta | None = None
    ) -> str:
        """
        Retrieve original value and check it wasn't signed more
        than max_age seconds ago.

        ``max_age`` may be seconds (int/float) or a timedelta; None skips the
        age check.

        Raises BadSignature if the signature is invalid or the verified payload
        does not end in a decodable timestamp, and SignatureExpired (a
        BadSignature subclass) if the value is older than ``max_age``.
        """
        result = self._signer.unsign(value)
        # A payload can verify yet carry no usable timestamp, e.g. one produced
        # by a plain Signer sharing this key and salt. Reject it as
        # BadSignature instead of leaking ValueError/IndexError to callers.
        value, found_sep, timestamp = result.rpartition(self.sep)
        if not found_sep:
            raise BadSignature(f'No "{self.sep}" found in timestamped value')
        try:
            ts = b62_decode(timestamp)
        except (ValueError, IndexError) as err:
            raise BadSignature(f'Invalid timestamp "{timestamp}"') from err
        if max_age is not None:
            if isinstance(max_age, datetime.timedelta):
                max_age = max_age.total_seconds()
            # Check timestamp is not older than max_age
            age = time.time() - ts
            if age > max_age:
                raise SignatureExpired(f"Signature age {age} > {max_age} seconds")
        return value

    def sign_object(
        self,
        obj: Any,
        serializer: type[JSONSerializer] = JSONSerializer,
        compress: bool = False,
    ) -> str:
        """
        Serialize ``obj`` and return it as a signed, timestamped, URL-safe
        base64 string.

        If ``compress`` is True (not the default), the serialized bytes are
        zlib-compressed when that makes them at least two bytes shorter, and
        the payload is marked with a leading '.'. The marker is part of the
        signed value, and unsign_object() verifies the signature before
        decompressing, so an unsigned zip bomb never reaches zlib.

        The serializer's ``dumps()`` must return bytes.
        """
        data = serializer().dumps(obj)
        is_compressed = False

        if compress:
            compressed = zlib.compress(data)
            # Keep the compressed form only if it still wins after paying for
            # the one-character '.' marker.
            if len(compressed) < (len(data) - 1):
                data = compressed
                is_compressed = True
        base64d = b64_encode(data).decode()
        if is_compressed:
            base64d = "." + base64d
        return self.sign(base64d)

    def unsign_object(
        self,
        signed_obj: str,
        serializer: type[JSONSerializer] = JSONSerializer,
        max_age: int | float | datetime.timedelta | None = None,
    ) -> Any:
        """
        Unsign and decode an object, optionally checking max_age.

        Raises BadSignature (or SignatureExpired) as described in unsign().
        """
        base64d = self.unsign(signed_obj, max_age=max_age).encode()
        decompress = base64d[:1] == b"."
        if decompress:
            # Strip the compression marker before base64-decoding.
            base64d = base64d[1:]
        data = b64_decode(base64d)
        if decompress:
            data = zlib.decompress(data)
        return serializer().loads(data)