1from __future__ import annotations
2
3import base64
4import json
5from functools import cache
6from typing import TYPE_CHECKING, Any
7
8try:
9 from cryptography.fernet import Fernet, InvalidToken, MultiFernet
10 from cryptography.hazmat.primitives import hashes
11 from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
12except ImportError:
13 Fernet = None # type: ignore[assignment,misc]
14 InvalidToken = None # type: ignore[assignment,misc]
15 MultiFernet = None # type: ignore[assignment,misc]
16 hashes = None # type: ignore[assignment]
17 PBKDF2HMAC = None # type: ignore[assignment]
18
19from plain import exceptions, preflight
20from plain.runtime import settings
21from plain.utils.encoding import force_bytes
22
23from . import Field
24
25if TYPE_CHECKING:
26 from collections.abc import Callable
27
28 from plain.models.lookups import Lookup, Transform
29 from plain.models.postgres.connection import DatabaseConnection
30 from plain.preflight.results import PreflightResult
31
# Names exported by ``from ... import *``: the two public encrypted field types.
__all__ = ["EncryptedTextField", "EncryptedJSONField"]
36
37# Fixed salt for key derivation — changing this would invalidate all encrypted data.
38# This is not secret; it ensures the derived encryption key is distinct from
39# keys derived for other purposes (e.g., signing) even from the same SECRET_KEY.
40_KDF_SALT = b"plain.models.fields.encrypted"
41
42# Prefix for encrypted values in the database.
43# Makes encrypted data self-describing and distinguishable from plaintext.
44_ENCRYPTED_PREFIX = "$fernet$"
45
46
def _derive_fernet_key(secret: str) -> bytes:
    """Stretch ``secret`` into a key in the format ``Fernet`` expects.

    Runs PBKDF2-HMAC-SHA256 (480,000 iterations, 32-byte output) over the
    secret with the module's fixed ``_KDF_SALT``. The raw key is returned
    urlsafe-base64 encoded, so the same secret always yields the same key.

    Raises:
        ImportError: if the optional ``cryptography`` package is not installed.
    """
    if PBKDF2HMAC is None:
        raise ImportError(
            "The 'cryptography' package is required to use encrypted fields. "
            "Install it with: pip install cryptography"
        )
    derivation = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=_KDF_SALT,
        iterations=480_000,
    )
    raw_key = derivation.derive(force_bytes(secret))
    return base64.urlsafe_b64encode(raw_key)
61
62
@cache
def _get_fernet(secret_key: str, fallbacks: tuple[str, ...]) -> MultiFernet:
    """Return a memoized ``MultiFernet`` for ``secret_key`` plus ``fallbacks``.

    ``secret_key`` is placed first, so it is the key used to encrypt.
    Every key, including the fallbacks, is tried when decrypting, which is
    what allows SECRET_KEY to be rotated.

    The memo is keyed on the arguments. Passing a different SECRET_KEY
    (e.g. one overridden in tests) therefore builds a fresh instance rather
    than reusing a stale one.
    """
    derived_keys = [_derive_fernet_key(s) for s in (secret_key, *fallbacks)]
    return MultiFernet([Fernet(key) for key in derived_keys])
76
77
78def _encrypt(value: str) -> str:
79 """Encrypt a string and return a self-describing database value."""
80 if value == "":
81 return value
82 f = _get_fernet(settings.SECRET_KEY, tuple(settings.SECRET_KEY_FALLBACKS))
83 token = f.encrypt(force_bytes(value))
84 return _ENCRYPTED_PREFIX + token.decode("ascii")
85
86
def _decrypt(value: str) -> str:
    """Turn a stored database value back into its plaintext string.

    Values without ``_ENCRYPTED_PREFIX`` are returned unchanged. They are
    treated as plaintext written before the field became encrypted, which
    allows a gradual migration from plaintext to encrypted storage.

    Raises:
        ValueError: if the value carries the prefix but cannot be decrypted
            with SECRET_KEY or any of SECRET_KEY_FALLBACKS. The underlying
            ``InvalidToken`` is attached as ``__cause__``.
    """
    if not value.startswith(_ENCRYPTED_PREFIX):
        return value
    token = value.removeprefix(_ENCRYPTED_PREFIX)
    f = _get_fernet(settings.SECRET_KEY, tuple(settings.SECRET_KEY_FALLBACKS))
    try:
        plaintext = f.decrypt(token.encode("ascii"))
    except InvalidToken as exc:
        # Chain explicitly so the traceback shows the cause rather than a
        # confusing "during handling of the above exception" message.
        raise ValueError(
            "Could not decrypt field value. The SECRET_KEY (and SECRET_KEY_FALLBACKS) "
            "may have changed since this data was encrypted."
        ) from exc
    return plaintext.decode("utf-8")
105
106
107# isnull is obviously needed. exact is required so that `filter(field=None)`
108# works — the ORM resolves "exact" first and then rewrites None to isnull.
109# Exact lookups on non-None values will silently return no results (since
110# ciphertext is non-deterministic), but blocking exact entirely would break
111# the None/isnull path.
112_ALLOWED_LOOKUPS = {"isnull", "exact"}
113
114
class EncryptedFieldMixin:
    """Lookup restrictions and preflight checks shared by encrypted fields.

    Encrypted values are non-deterministic, so the database cannot compare
    them meaningfully. Accordingly this mixin:

    - exposes only the ``isnull`` and ``exact`` lookups;
    - exposes no transforms;
    - provides ``_check_encrypted_constraints`` so subclasses can report the
      field being named in a model constraint or index.

    It must be combined with ``Field`` as a co-base class. ``Field`` supplies
    ``name``, ``model`` and the ``get_lookup`` implementation delegated to
    via ``super()``.
    """

    # Declared for type checkers; the Field co-base class provides these.
    name: str
    model: Any

    def get_lookup(self, lookup_name: str) -> type[Lookup] | None:
        """Delegate allowed lookup names to the parent; reject all others."""
        if lookup_name in _ALLOWED_LOOKUPS:
            parent_get_lookup = getattr(super(), "get_lookup")
            return parent_get_lookup(lookup_name)
        return None

    def get_transform(
        self, lookup_name: str
    ) -> type[Transform] | Callable[..., Any] | None:
        """Always return ``None``: no transform is meaningful on ciphertext."""
        return None

    def _encrypted_usage_result(
        self, kind: str, label: str, check_id: str
    ) -> PreflightResult:
        """Build the preflight result for this field being used in a ``kind``."""
        return preflight.PreflightResult(
            fix=(
                f"'{self.model.__name__}.{self.name}' is an encrypted field "
                f"and cannot be used in {kind} '{label}'. "
                "Encrypted values are non-deterministic."
            ),
            obj=self,
            id=check_id,
        )

    def _check_encrypted_constraints(self) -> list[PreflightResult]:
        """Report every model constraint and index that names this field.

        Returns an empty list while the field is not attached to a model.
        Constraints are matched on their ``fields`` attribute; constraints
        without one are skipped. Index field names are compared after
        stripping leading ``-`` (descending-order) characters.
        """
        if not hasattr(self, "model"):
            return []

        found: list[PreflightResult] = []

        for constraint in self.model.model_options.constraints:
            if self.name in getattr(constraint, "fields", ()):
                found.append(
                    self._encrypted_usage_result(
                        "constraint",
                        constraint.name,
                        "fields.encrypted_in_constraint",
                    )
                )

        for index in self.model.model_options.indexes:
            bare_names = [entry.lstrip("-") for entry in getattr(index, "fields", ())]
            if self.name in bare_names:
                found.append(
                    self._encrypted_usage_result(
                        "index",
                        index.name,
                        "fields.encrypted_in_index",
                    )
                )

        return found
179
180
class EncryptedTextField(EncryptedFieldMixin, Field[str]):
    """Text field whose value is stored Fernet-encrypted.

    Fernet (AES-128-CBC authenticated with HMAC-SHA256) is keyed from a
    value derived from SECRET_KEY. Ciphertext is longer than the plaintext
    and its length is unpredictable, so the column is always ``text``
    regardless of ``max_length``. ``max_length`` is only enforced against
    the plaintext, in :meth:`validate`.
    """

    description = "Encrypted text"

    def get_internal_type(self) -> str:
        # Ciphertext can exceed any plaintext max_length, so always use text.
        return "TextField"

    def to_python(self, value: Any) -> str | None:
        """Return strings and ``None`` untouched; coerce anything else via ``str``."""
        if value is None or isinstance(value, str):
            return value
        return str(value)

    def validate(self, value: Any, model_instance: Any) -> None:
        """Run the base validation, then check the plaintext length."""
        super().validate(value, model_instance)
        limit = self.max_length
        if limit is None or value is None:
            return
        length = len(value)
        if length > limit:
            raise exceptions.ValidationError(
                f"Ensure this value has at most {limit} characters (it has {length}).",
                code="max_length",
            )

    def get_prep_value(self, value: Any) -> Any:
        prepped = super().get_prep_value(value)
        return prepped if prepped is None else self.to_python(prepped)

    def get_db_prep_value(
        self, value: Any, connection: DatabaseConnection, prepared: bool = False
    ) -> Any:
        """Encrypt the prepared value; ``None`` is passed through unencrypted."""
        db_value = super().get_db_prep_value(value, connection, prepared)
        return None if db_value is None else _encrypt(db_value)

    def from_db_value(
        self, value: Any, expression: Any, connection: DatabaseConnection
    ) -> str | None:
        """Decrypt a column value read from the database (``None`` stays ``None``)."""
        return None if value is None else _decrypt(value)

    def deconstruct(self) -> tuple[str | None, str, list[Any], dict[str, Any]]:
        name, _, args, kwargs = super().deconstruct()
        # Field.deconstruct() would shorten "plain.models.fields.encrypted" to
        # "plain.models.encrypted", which does not exist, so rebuild the real
        # import path of this class.
        klass = self.__class__
        return name, f"{klass.__module__}.{klass.__qualname__}", args, kwargs

    def preflight(self, **kwargs: Any) -> list[PreflightResult]:
        results = super().preflight(**kwargs)
        results.extend(self._check_encrypted_constraints())
        return results
248
249
class EncryptedJSONField(EncryptedFieldMixin, Field):
    """JSON field whose serialized value is stored Fernet-encrypted.

    On write, the value is serialized with ``json.dumps`` (using ``encoder``
    if given), encrypted, and stored in a ``text`` column. On read, the
    column is decrypted and parsed with ``json.loads`` (using ``decoder`` if
    given). A Python ``None`` is handed to the database unchanged rather
    than being serialized as JSON ``null``.
    """

    empty_strings_allowed = False
    description = "Encrypted JSON"
    default_error_messages = {
        "invalid": "Value must be valid JSON.",
    }
    # (callable to suggest, literal it replaces) for the mutable-default warning.
    _default_fix = ("dict", "{}")

    def __init__(
        self,
        *,
        encoder: type[json.JSONEncoder] | None = None,
        decoder: type[json.JSONDecoder] | None = None,
        **kwargs: Any,
    ):
        """
        Args:
            encoder: Optional ``json.JSONEncoder`` subclass, passed as ``cls``
                to ``json.dumps`` when validating and saving.
            decoder: Optional ``json.JSONDecoder`` subclass, passed as ``cls``
                to ``json.loads`` when loading.
            **kwargs: Forwarded to the base field constructor.

        Raises:
            ValueError: if ``encoder`` or ``decoder`` is given but is not callable.
        """
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        if decoder and not callable(decoder):
            raise ValueError("The decoder parameter must be a callable object.")
        self.encoder = encoder
        self.decoder = decoder
        super().__init__(**kwargs)

    def get_internal_type(self) -> str:
        # Store as text, not jsonb: the column holds ciphertext, not JSON.
        return "TextField"

    def deconstruct(self) -> tuple[str | None, str, list[Any], dict[str, Any]]:
        name, path, args, kwargs = super().deconstruct()
        # Field.deconstruct() would shorten the path to a nonexistent module,
        # so rebuild the real import path of this class.
        path = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
        if self.encoder is not None:
            kwargs["encoder"] = self.encoder
        if self.decoder is not None:
            kwargs["decoder"] = self.decoder
        return name, path, args, kwargs

    def validate(self, value: Any, model_instance: Any) -> None:
        """Run the base validation, then ensure ``value`` is JSON-serializable.

        Raises:
            ValidationError: (code ``"invalid"``) if ``json.dumps`` rejects
                the value.
        """
        super().validate(value, model_instance)
        try:
            json.dumps(value, cls=self.encoder)
        except (TypeError, ValueError) as exc:
            # TypeError: an object that is not JSON-serializable.
            # ValueError: e.g. a circular reference, or an encoder that
            # rejects the value. Previously this escaped as an unhandled error.
            raise exceptions.ValidationError(
                self.error_messages["invalid"],
                code="invalid",
                params={"value": value},
            ) from exc

    def get_db_prep_value(
        self, value: Any, connection: DatabaseConnection, prepared: bool = False
    ) -> Any:
        """Serialize to JSON and encrypt; ``None`` is passed through unchanged."""
        value = super().get_db_prep_value(value, connection, prepared)
        if value is None:
            return value
        json_str = json.dumps(value, cls=self.encoder)
        return _encrypt(json_str)

    def from_db_value(
        self, value: Any, expression: Any, connection: DatabaseConnection
    ) -> Any:
        """Decrypt a column value and parse it as JSON (``None`` stays ``None``).

        Raises:
            ValueError: if decryption fails, or if the decrypted text is not
                valid JSON.
        """
        if value is None:
            return value
        decrypted = _decrypt(value)
        try:
            return json.loads(decrypted, cls=self.decoder)
        except json.JSONDecodeError as exc:
            raise ValueError(
                "Encrypted field contains data that is not valid JSON. "
                "The stored value may be corrupt."
            ) from exc

    def _check_default(self) -> list[PreflightResult]:
        """Warn when ``default`` is a non-callable (hence shared, mutable) value."""
        if not (
            self.has_default()
            and self.default is not None
            and not callable(self.default)
        ):
            return []
        suggested, replaced = self._default_fix[0], self._default_fix[1]
        return [
            preflight.PreflightResult(
                fix=(
                    f"{self.__class__.__name__} default should be a callable instead of an instance "
                    "so that it's not shared between all field instances. "
                    f"Use a callable instead, e.g., use `{suggested}` instead of "
                    f"`{replaced}`."
                ),
                obj=self,
                id="fields.encrypted_mutable_default",
                warning=True,
            )
        ]

    def preflight(self, **kwargs: Any) -> list[PreflightResult]:
        errors = super().preflight(**kwargs)
        errors.extend(self._check_default())
        errors.extend(self._check_encrypted_constraints())
        return errors