Plain is headed towards 1.0! Subscribe for development updates →

from __future__ import annotations

import time
import zlib
from typing import Any

from plain.signing import (
    BadSignature,
    JSONSerializer,
    SignatureExpired,
    Signer,
    b62_decode,
    b62_encode,
    b64_decode,
    b64_encode,
)
 16
 17
 18class ExpiringSigner:
 19    """A signer with an embedded expiration (vs max age unsign).
 20
 21    Uses composition rather than inheritance since the interface
 22    intentionally differs from Signer (requires expires_in parameter).
 23    """
 24
 25    def __init__(
 26        self,
 27        *,
 28        key: str | None = None,
 29        sep: str = ":",
 30        salt: str | None = None,
 31        algorithm: str = "sha256",
 32        fallback_keys: list[str] | None = None,
 33    ) -> None:
 34        # Compute default salt here to preserve backwards compatibility.
 35        # When ExpiringSigner inherited from Signer, the default salt was
 36        # "plain.loginlink.signing.ExpiringSigner". Now that we use composition,
 37        # we must set it explicitly rather than letting Signer compute its own.
 38        if salt is None:
 39            salt = f"{self.__class__.__module__}.{self.__class__.__name__}"
 40        self._signer = Signer(
 41            key=key,
 42            sep=sep,
 43            salt=salt,
 44            algorithm=algorithm,
 45            fallback_keys=fallback_keys,
 46        )
 47
 48    @property
 49    def sep(self) -> str:
 50        return self._signer.sep
 51
 52    def sign(self, value: str, expires_in: int) -> str:
 53        timestamp = b62_encode(int(time.time() + expires_in))
 54        value = f"{value}{self.sep}{timestamp}"
 55        return self._signer.sign(value)
 56
 57    def unsign(self, signed_value: str) -> str:
 58        """
 59        Retrieve original value and check the expiration hasn't passed.
 60        """
 61        result = self._signer.unsign(signed_value)
 62        value, timestamp = result.rsplit(self.sep, 1)
 63        ts = b62_decode(timestamp)
 64        if ts < time.time():
 65            raise SignatureExpired("Signature expired")
 66        return value
 67
 68    def sign_object(
 69        self,
 70        obj: Any,
 71        *,
 72        expires_in: int,
 73        serializer: type = JSONSerializer,
 74        compress: bool = False,
 75    ) -> str:
 76        """
 77        Return URL-safe, hmac signed base64 compressed JSON string.
 78
 79        If compress is True (not the default), check if compressing using zlib
 80        can save some space. Prepend a '.' to signify compression. This is
 81        included in the signature, to protect against zip bombs.
 82
 83        The serializer is expected to return a bytestring.
 84        """
 85        data = serializer().dumps(obj)
 86        # Flag for if it's been compressed or not.
 87        is_compressed = False
 88
 89        if compress:
 90            # Avoid zlib dependency unless compress is being used.
 91            compressed = zlib.compress(data)
 92            if len(compressed) < (len(data) - 1):
 93                data = compressed
 94                is_compressed = True
 95        base64d = b64_encode(data).decode()
 96        if is_compressed:
 97            base64d = "." + base64d
 98        return self.sign(base64d, expires_in)
 99
100    def unsign_object(self, signed_obj: str, serializer: type = JSONSerializer) -> Any:
101        # Signer.unsign() returns str but base64 and zlib compression operate
102        # on bytes.
103        base64d = self.unsign(signed_obj).encode()
104        decompress = base64d[:1] == b"."
105        if decompress:
106            # It's compressed; uncompress it first.
107            base64d = base64d[1:]
108        data = b64_decode(base64d)
109        if decompress:
110            data = zlib.decompress(data)
111        return serializer().loads(data)
112
113
def dumps(
    obj: Any,
    *,
    expires_in: int,
    key: str | None = None,
    salt: str = "plain.loginlink",
    serializer: type = JSONSerializer,
    compress: bool = False,
) -> str:
    """
    Serialize ``obj`` and sign it with an expiry ``expires_in`` seconds ahead.

    The result is a URL-safe, hmac signed, base64 string. If key is None,
    use settings.SECRET_KEY instead. The hmac algorithm is ExpiringSigner's
    default (sha256).

    If compress is True (not the default), zlib compression is used when it
    actually saves space. A leading '.' marks a compressed payload, and that
    marker is covered by the signature to protect against zip bombs.

    Salt namespaces the hash, so a signed string is only valid for a given
    namespace. Leaving this at the default value or re-using a salt value
    across different parts of your application without good cause is a
    security risk.

    The serializer is expected to return a bytestring.
    """
    signer = ExpiringSigner(key=key, salt=salt)
    signed = signer.sign_object(
        obj,
        compress=compress,
        expires_in=expires_in,
        serializer=serializer,
    )
    return signed
141    )
142
143
def loads(
    s: str,
    *,
    key: str | None = None,
    salt: str = "plain.loginlink",
    serializer: type = JSONSerializer,
    fallback_keys: list[str] | None = None,
) -> Any:
    """
    Reverse of dumps(): verify ``s``, check its expiry and deserialize it.

    Raises BadSignature if the signature fails, and SignatureExpired once the
    expiry embedded by dumps() has passed. ``fallback_keys`` is forwarded to
    the underlying Signer (presumably to accept tokens signed with older
    keys; confirm against Signer).

    The serializer is expected to accept a bytestring.
    """
    signer = ExpiringSigner(key=key, salt=salt, fallback_keys=fallback_keys)
    return signer.unsign_object(s, serializer=serializer)
161        serializer=serializer,
162    )