v0.145.0
  1from __future__ import annotations
  2
  3import time
  4import zlib
  5from typing import Any
  6
  7from plain.signing import (
  8    JSONSerializer,
  9    SignatureExpired,
 10    Signer,
 11    b62_decode,
 12    b62_encode,
 13    b64_decode,
 14    b64_encode,
 15)
 16
 17
 18class ExpiringSigner:
 19    """A signer with an embedded expiration (vs max age unsign).
 20
 21    Uses composition rather than inheritance since the interface
 22    intentionally differs from Signer (requires expires_in parameter).
 23    """
 24
 25    def __init__(
 26        self,
 27        *,
 28        key: str | None = None,
 29        sep: str = ":",
 30        salt: str | None = None,
 31        algorithm: str = "sha256",
 32        fallback_keys: list[str] | None = None,
 33    ) -> None:
 34        # Compute default salt here to preserve backwards compatibility.
 35        # When ExpiringSigner inherited from Signer, the default salt was
 36        # "plain.loginlink.signing.ExpiringSigner". Now that we use composition,
 37        # we must set it explicitly rather than letting Signer compute its own.
 38        if salt is None:
 39            salt = f"{self.__class__.__module__}.{self.__class__.__name__}"
 40        self._signer = Signer(
 41            key=key,
 42            sep=sep,
 43            salt=salt,
 44            algorithm=algorithm,
 45            fallback_keys=fallback_keys,
 46        )
 47
 48    @property
 49    def sep(self) -> str:
 50        return self._signer.sep
 51
 52    def sign(self, value: str, expires_in: int) -> str:
 53        timestamp = b62_encode(int(time.time() + expires_in))
 54        value = f"{value}{self.sep}{timestamp}"
 55        return self._signer.sign(value)
 56
 57    def unsign(self, signed_value: str) -> str:
 58        """
 59        Retrieve original value and check the expiration hasn't passed.
 60        """
 61        result = self._signer.unsign(signed_value)
 62        value, timestamp = result.rsplit(self.sep, 1)
 63        ts = b62_decode(timestamp)
 64        if ts < time.time():
 65            raise SignatureExpired("Signature expired")
 66        return value
 67
 68    def sign_object(
 69        self,
 70        obj: Any,
 71        *,
 72        expires_in: int,
 73        serializer: type = JSONSerializer,
 74        compress: bool = False,
 75    ) -> str:
 76        """
 77        Return URL-safe, hmac signed base64 compressed JSON string.
 78
 79        If compress is True (not the default), check if compressing using zlib
 80        can save some space. Prepend a '.' to signify compression. This is
 81        included in the signature, to protect against zip bombs.
 82
 83        The serializer is expected to return a bytestring.
 84        """
 85        data = serializer().dumps(obj)
 86        # Flag for if it's been compressed or not.
 87        is_compressed = False
 88
 89        if compress:
 90            # Avoid zlib dependency unless compress is being used.
 91            compressed = zlib.compress(data)
 92            if len(compressed) < (len(data) - 1):
 93                data = compressed
 94                is_compressed = True
 95        base64d = b64_encode(data).decode()
 96        if is_compressed:
 97            base64d = "." + base64d
 98        return self.sign(base64d, expires_in)
 99
100    def unsign_object(self, signed_obj: str, serializer: type = JSONSerializer) -> Any:
101        # Signer.unsign() returns str but base64 and zlib compression operate
102        # on bytes.
103        base64d = self.unsign(signed_obj).encode()
104        decompress = base64d[:1] == b"."
105        if decompress:
106            # It's compressed; uncompress it first.
107            base64d = base64d[1:]
108        data = b64_decode(base64d)
109        if decompress:
110            data = zlib.decompress(data)
111        return serializer().loads(data)