1from __future__ import annotations
2
3import time
4import zlib
5from typing import Any
6
7from plain.signing import (
8 JSONSerializer,
9 SignatureExpired,
10 Signer,
11 b62_decode,
12 b62_encode,
13 b64_decode,
14 b64_encode,
15)
16
17
class ExpiringSigner:
    """Sign values together with an embedded expiration time.

    The expiry is chosen when signing and travels inside the signed
    payload, as opposed to passing a max age when unsigning.

    A Signer is wrapped (composition) instead of subclassed because the
    interface is intentionally different: signing requires ``expires_in``.
    """

    def __init__(
        self,
        *,
        key: str | None = None,
        sep: str = ":",
        salt: str | None = None,
        algorithm: str = "sha256",
        fallback_keys: list[str] | None = None,
    ) -> None:
        # Backwards compatibility: while ExpiringSigner still inherited from
        # Signer, the default salt was "plain.loginlink.signing.ExpiringSigner".
        # With composition the wrapped Signer would compute a default of its
        # own, so the salt is derived here and handed over explicitly.
        if salt is None:
            owner = self.__class__
            salt = f"{owner.__module__}.{owner.__name__}"
        self._signer = Signer(
            key=key,
            sep=sep,
            salt=salt,
            algorithm=algorithm,
            fallback_keys=fallback_keys,
        )

    @property
    def sep(self) -> str:
        """Separator used by the wrapped signer."""
        return self._signer.sep

    def sign(self, value: str, expires_in: int) -> str:
        """
        Sign ``value`` with an expiry ``expires_in`` seconds from now.

        The base62-encoded expiry timestamp is appended to the value
        (joined by the separator) before the wrapped signer signs it.
        """
        expires_at = int(time.time() + expires_in)
        encoded_expiry = b62_encode(expires_at)
        payload = f"{value}{self.sep}{encoded_expiry}"
        return self._signer.sign(payload)

    def unsign(self, signed_value: str) -> str:
        """
        Verify the signature and return the original value.

        Raises SignatureExpired when the embedded expiry lies in the past.
        """
        verified = self._signer.unsign(signed_value)
        # The expiry is the last separator-delimited field; splitting from
        # the right keeps any separators inside the value itself intact.
        payload, encoded_expiry = verified.rsplit(self.sep, 1)
        expires_at = b62_decode(encoded_expiry)
        if expires_at >= time.time():
            return payload
        raise SignatureExpired("Signature expired")

    def sign_object(
        self,
        obj: Any,
        *,
        expires_in: int,
        serializer: type = JSONSerializer,
        compress: bool = False,
    ) -> str:
        """
        Serialize ``obj`` and return it as a signed, base64-encoded string.

        With ``compress=True`` (off by default) the serialized bytes are
        run through zlib and the compressed form is used only when it
        actually saves space. A leading '.' marks a compressed payload; the
        marker is part of what gets signed, to protect against zip bombs.

        The serializer is expected to return a bytestring.
        """
        data = serializer().dumps(obj)
        # Marker prepended to the base64 text; stays empty when uncompressed.
        marker = ""

        if compress:
            deflated = zlib.compress(data)
            # Only worth it if we still come out ahead after adding the
            # one-character marker.
            if len(deflated) + 1 < len(data):
                data = deflated
                marker = "."

        encoded = marker + b64_encode(data).decode()
        return self.sign(encoded, expires_in)

    def unsign_object(self, signed_obj: str, serializer: type = JSONSerializer) -> Any:
        """
        Reverse of sign_object(): verify, decode and deserialize.

        The serializer is expected to accept a bytestring.
        """
        # unsign() hands back str, while base64 and zlib work on bytes.
        encoded = self.unsign(signed_obj).encode()
        if encoded.startswith(b"."):
            # Leading '.' means the payload was compressed before encoding.
            data = zlib.decompress(b64_decode(encoded[1:]))
        else:
            data = b64_decode(encoded)
        return serializer().loads(data)
112
113
def dumps(
    obj: Any,
    *,
    expires_in: int,
    key: str | None = None,
    salt: str = "plain.loginlink",
    serializer: type = JSONSerializer,
    compress: bool = False,
) -> str:
    """
    Serialize ``obj`` into a URL-safe, hmac signed, base64 encoded string
    that carries its own expiration.

    If key is None, settings.SECRET_KEY is used instead. The hmac algorithm
    is the default Signer algorithm.

    With compress=True (not the default), zlib compression is applied when
    it saves space. A '.' is prepended to signify compression; it is covered
    by the signature, to protect against zip bombs.

    Salt namespaces the hash, so a signed string is only valid within a
    given namespace. Leaving it at the default value, or re-using one salt
    across different parts of your application without good cause, is a
    security risk.

    The serializer is expected to return a bytestring.
    """
    signer = ExpiringSigner(key=key, salt=salt)
    return signer.sign_object(
        obj,
        expires_in=expires_in,
        serializer=serializer,
        compress=compress,
    )
142
143
def loads(
    s: str,
    *,
    key: str | None = None,
    salt: str = "plain.loginlink",
    serializer: type = JSONSerializer,
    fallback_keys: list[str] | None = None,
) -> Any:
    """
    Reverse of dumps(): verify ``s`` and return the deserialized object.

    Raises BadSignature if the signature fails.

    The serializer is expected to accept a bytestring.
    """
    signer = ExpiringSigner(key=key, salt=salt, fallback_keys=fallback_keys)
    return signer.unsign_object(s, serializer=serializer)