1"""
  2URL-safe signed JSON objects using HMAC/SHA-256.
  3
  4Use TimestampSigner for signing with expiration:
  5
  6    TimestampSigner(salt="my-salt").sign_object({"key": "value"})
  7    TimestampSigner(salt="my-salt").unsign_object(token, max_age=3600)
  8
  9Use Signer for signing without expiration:
 10
 11    Signer(salt="my-salt").sign_object({"key": "value"})
 12    Signer(salt="my-salt").unsign_object(token)
 13"""
 14
 15from __future__ import annotations
 16
 17import base64
 18import datetime
 19import hmac
 20import json
 21import time
 22import zlib
 23from typing import Any
 24
 25from plain.runtime import settings
 26from plain.utils.crypto import salted_hmac
 27from plain.utils.encoding import force_bytes
 28from plain.utils.regex_helper import _lazy_re_compile
 29
# Separators matching this pattern are rejected by Signer.__init__ (it raises
# ValueError when this matches): the empty string, or any string made only of
# characters from the class below. Presumably these are excluded because they
# could collide with the URL-safe base64 / base62 text this module produces.
# NOTE(review): the ``A-z`` range also spans the ASCII punctuation between "Z"
# and "a" ([, \, ], ^, _, `), so separators built only from those characters
# are rejected too, even though the error message says "A-z0-9-_=".
_SEP_UNSAFE = _lazy_re_compile(r"^[A-z0-9-_=]*$")
# Digit alphabet for b62_encode/b62_decode: a character's index is its value.
BASE62_ALPHABET = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
 32
 33
 34class BadSignature(Exception):
 35    """Signature does not match."""
 36
 37    pass
 38
 39
 40class SignatureExpired(BadSignature):
 41    """Signature timestamp is older than required max_age."""
 42
 43    pass
 44
 45
 46def b62_encode(s: int) -> str:
 47    if s == 0:
 48        return "0"
 49    sign = "-" if s < 0 else ""
 50    s = abs(s)
 51    encoded = ""
 52    while s > 0:
 53        s, remainder = divmod(s, 62)
 54        encoded = BASE62_ALPHABET[remainder] + encoded
 55    return sign + encoded
 56
 57
 58def b62_decode(s: str) -> int:
 59    if s == "0":
 60        return 0
 61    sign = 1
 62    if s[0] == "-":
 63        s = s[1:]
 64        sign = -1
 65    decoded = 0
 66    for digit in s:
 67        decoded = decoded * 62 + BASE62_ALPHABET.index(digit)
 68    return sign * decoded
 69
 70
 71def b64_encode(s: bytes) -> bytes:
 72    return base64.urlsafe_b64encode(s).strip(b"=")
 73
 74
 75def b64_decode(s: bytes) -> bytes:
 76    pad = b"=" * (-len(s) % 4)
 77    return base64.urlsafe_b64decode(s + pad)
 78
 79
 80def base64_hmac(salt: str, value: str, key: str, algorithm: str = "sha1") -> str:
 81    return b64_encode(
 82        salted_hmac(salt, value, key, algorithm=algorithm).digest()
 83    ).decode()
 84
 85
 86class JSONSerializer:
 87    """
 88    Simple wrapper around json used by Signer.sign_object and
 89    Signer.unsign_object.
 90    """
 91
 92    def dumps(self, obj: Any) -> bytes:
 93        return json.dumps(obj, separators=(",", ":")).encode("latin-1")
 94
 95    def loads(self, data: bytes) -> Any:
 96        return json.loads(data.decode("latin-1"))
 97
 98
 99class Signer:
100    def __init__(
101        self,
102        *,
103        key: str | None = None,
104        sep: str = ":",
105        salt: str | None = None,
106        algorithm: str = "sha256",
107        fallback_keys: list[str] | None = None,
108    ) -> None:
109        self.key = key or settings.SECRET_KEY
110        self.fallback_keys = (
111            fallback_keys
112            if fallback_keys is not None
113            else settings.SECRET_KEY_FALLBACKS
114        )
115        self.sep = sep
116        self.salt = salt or f"{self.__class__.__module__}.{self.__class__.__name__}"
117        self.algorithm = algorithm
118
119        if _SEP_UNSAFE.match(self.sep):
120            raise ValueError(
121                f"Unsafe Signer separator: {sep!r} (cannot be empty or consist of "
122                "only A-z0-9-_=)",
123            )
124
125    def signature(self, value: str, key: str | None = None) -> str:
126        key = key or self.key
127        return base64_hmac(self.salt + "signer", value, key, algorithm=self.algorithm)
128
129    def sign(self, value: str) -> str:
130        return f"{value}{self.sep}{self.signature(value)}"
131
132    def unsign(self, signed_value: str) -> str:
133        if self.sep not in signed_value:
134            raise BadSignature(f'No "{self.sep}" found in value')
135        value, sig = signed_value.rsplit(self.sep, 1)
136        for key in [self.key, *self.fallback_keys]:
137            if hmac.compare_digest(
138                force_bytes(sig), force_bytes(self.signature(value, key))
139            ):
140                return value
141        raise BadSignature(f'Signature "{sig}" does not match')
142
143    def sign_object(
144        self,
145        obj: Any,
146        serializer: type[JSONSerializer] = JSONSerializer,
147        compress: bool = False,
148    ) -> str:
149        """
150        Return URL-safe, hmac signed base64 compressed JSON string.
151
152        If compress is True (not the default), check if compressing using zlib
153        can save some space. Prepend a '.' to signify compression. This is
154        included in the signature, to protect against zip bombs.
155
156        The serializer is expected to return a bytestring.
157        """
158        data = serializer().dumps(obj)
159        # Flag for if it's been compressed or not.
160        is_compressed = False
161
162        if compress:
163            # Avoid zlib dependency unless compress is being used.
164            compressed = zlib.compress(data)
165            if len(compressed) < (len(data) - 1):
166                data = compressed
167                is_compressed = True
168        base64d = b64_encode(data).decode()
169        if is_compressed:
170            base64d = "." + base64d
171        return self.sign(base64d)
172
173    def unsign_object(
174        self,
175        signed_obj: str,
176        serializer: type[JSONSerializer] = JSONSerializer,
177        **kwargs: Any,
178    ) -> Any:
179        # Signer.unsign() returns str but base64 and zlib compression operate
180        # on bytes.
181        base64d = self.unsign(signed_obj, **kwargs).encode()
182        decompress = base64d[:1] == b"."
183        if decompress:
184            # It's compressed; uncompress it first.
185            base64d = base64d[1:]
186        data = b64_decode(base64d)
187        if decompress:
188            data = zlib.decompress(data)
189        return serializer().loads(data)
190
191
class TimestampSigner:
    """A signer that includes a timestamp for max_age validation.

    Uses composition rather than inheritance since the interface
    intentionally differs from Signer (unsign accepts max_age parameter).

    Signed values look like ``value<sep>timestamp<sep>signature``, where the
    timestamp is the signing time in whole seconds, base62-encoded.
    """

    def __init__(
        self,
        *,
        key: str | None = None,
        sep: str = ":",
        salt: str | None = None,
        algorithm: str = "sha256",
        fallback_keys: list[str] | None = None,
    ) -> None:
        # Compute default salt here to preserve backwards compatibility.
        # When TimestampSigner inherited from Signer, the default salt was
        # "plain.signing.TimestampSigner". Now that we use composition,
        # we must set it explicitly rather than letting Signer compute its own.
        if salt is None:
            salt = f"{self.__class__.__module__}.{self.__class__.__name__}"
        self._signer = Signer(
            key=key,
            sep=sep,
            salt=salt,
            algorithm=algorithm,
            fallback_keys=fallback_keys,
        )

    @property
    def sep(self) -> str:
        """Separator shared with the wrapped Signer."""
        return self._signer.sep

    def timestamp(self) -> str:
        """Return the current Unix time in whole seconds, base62-encoded."""
        return b62_encode(int(time.time()))

    def sign(self, value: str) -> str:
        """Append the current timestamp to *value*, then sign the result."""
        value = f"{value}{self.sep}{self.timestamp()}"
        return self._signer.sign(value)

    def unsign(
        self, value: str, max_age: int | float | datetime.timedelta | None = None
    ) -> str:
        """
        Retrieve original value and check it wasn't signed more
        than max_age seconds ago.

        Raises BadSignature if the signature is invalid, or if the signed
        payload carries no well-formed timestamp (e.g. a token produced by a
        plain Signer sharing the same key and salt). Raises SignatureExpired
        if max_age is given and the timestamp is older than that.
        """
        result = self._signer.unsign(value)
        # A validly signed payload may still lack a timestamp; surface that
        # as BadSignature rather than leaking ValueError/IndexError, so
        # callers catching BadSignature are not surprised.
        value, found_sep, timestamp = result.rpartition(self.sep)
        if not found_sep:
            raise BadSignature(f'No "{self.sep}" found in timestamped value')
        try:
            ts = b62_decode(timestamp)
        except (ValueError, IndexError) as exc:
            # b62_decode raises IndexError on "" and ValueError on characters
            # outside the base62 alphabet.
            raise BadSignature(f'Timestamp "{timestamp}" is not valid') from exc
        if max_age is not None:
            if isinstance(max_age, datetime.timedelta):
                max_age = max_age.total_seconds()
            # Check timestamp is not older than max_age
            age = time.time() - ts
            if age > max_age:
                raise SignatureExpired(f"Signature age {age} > {max_age} seconds")
        return value

    def sign_object(
        self,
        obj: Any,
        serializer: type[JSONSerializer] = JSONSerializer,
        compress: bool = False,
    ) -> str:
        """
        Return URL-safe, hmac signed base64 compressed JSON string.

        If compress is True (not the default), the zlib-compressed form is
        used when it is at least two bytes shorter than the raw serialization,
        and a '.' is prepended to signify compression. This marker is
        included in the signature, to protect against zip bombs.

        The serializer is expected to return a bytestring.
        """
        data = serializer().dumps(obj)
        is_compressed = False

        if compress:
            compressed = zlib.compress(data)
            if len(compressed) < (len(data) - 1):
                data = compressed
                is_compressed = True
        base64d = b64_encode(data).decode()
        if is_compressed:
            base64d = "." + base64d
        return self.sign(base64d)

    def unsign_object(
        self,
        signed_obj: str,
        serializer: type[JSONSerializer] = JSONSerializer,
        max_age: int | float | datetime.timedelta | None = None,
    ) -> Any:
        """Unsign and decode an object, optionally checking max_age.

        Raises BadSignature / SignatureExpired exactly as ``unsign`` does.
        """
        base64d = self.unsign(signed_obj, max_age=max_age).encode()
        decompress = base64d[:1] == b"."
        if decompress:
            base64d = base64d[1:]
        data = b64_decode(base64d)
        if decompress:
            data = zlib.decompress(data)
        return serializer().loads(data)