1from __future__ import annotations
  2
  3import base64
  4import json
  5from functools import cache
  6from typing import TYPE_CHECKING, Any
  7
  8try:
  9    from cryptography.fernet import Fernet, InvalidToken, MultiFernet
 10    from cryptography.hazmat.primitives import hashes
 11    from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
 12except ImportError:
 13    Fernet = None  # type: ignore[assignment,misc]
 14    InvalidToken = None  # type: ignore[assignment,misc]
 15    MultiFernet = None  # type: ignore[assignment,misc]
 16    hashes = None  # type: ignore[assignment]
 17    PBKDF2HMAC = None  # type: ignore[assignment]
 18
 19from plain import exceptions, preflight
 20from plain.runtime import settings
 21from plain.utils.encoding import force_bytes
 22
 23from . import Field
 24
 25if TYPE_CHECKING:
 26    from collections.abc import Callable
 27
 28    from plain.models.lookups import Lookup, Transform
 29    from plain.models.postgres.connection import DatabaseConnection
 30    from plain.preflight.results import PreflightResult
 31
# Public names exported by this module: the two encrypted field classes.
__all__ = [
    "EncryptedTextField",
    "EncryptedJSONField",
]

# Fixed salt for key derivation (fed to PBKDF2HMAC in _derive_fernet_key).
# Changing this would invalidate all encrypted data.
# It is not secret; it ensures the derived encryption key is distinct from
# keys derived for other purposes (e.g., signing) even from the same SECRET_KEY.
_KDF_SALT = b"plain.models.fields.encrypted"

# Prefix that _encrypt prepends to every Fernet token it stores.
# Makes encrypted data self-describing and distinguishable from plaintext:
# _decrypt returns any value lacking this prefix unchanged. Changing the prefix
# would therefore make existing ciphertext be read back as if it were plaintext.
_ENCRYPTED_PREFIX = "$fernet$"
 45
 46
 47def _derive_fernet_key(secret: str) -> bytes:
 48    """Derive a Fernet-compatible key from an arbitrary secret string."""
 49    if PBKDF2HMAC is None:
 50        raise ImportError(
 51            "The 'cryptography' package is required to use encrypted fields. "
 52            "Install it with: pip install cryptography"
 53        )
 54    kdf = PBKDF2HMAC(
 55        algorithm=hashes.SHA256(),
 56        length=32,
 57        salt=_KDF_SALT,
 58        iterations=480_000,
 59    )
 60    return base64.urlsafe_b64encode(kdf.derive(force_bytes(secret)))
 61
 62
 63@cache
 64def _get_fernet(secret_key: str, fallbacks: tuple[str, ...]) -> MultiFernet:
 65    """Build a MultiFernet from the given secret key and fallbacks.
 66
 67    The first key is used for encryption.
 68    All keys are used for decryption, enabling key rotation.
 69    Results are cached by (secret_key, fallbacks) so changing SECRET_KEY
 70    (e.g. in tests) produces a new MultiFernet automatically.
 71    """
 72    keys = [_derive_fernet_key(secret_key)]
 73    for fallback in fallbacks:
 74        keys.append(_derive_fernet_key(fallback))
 75    return MultiFernet([Fernet(k) for k in keys])
 76
 77
 78def _encrypt(value: str) -> str:
 79    """Encrypt a string and return a self-describing database value."""
 80    if value == "":
 81        return value
 82    f = _get_fernet(settings.SECRET_KEY, tuple(settings.SECRET_KEY_FALLBACKS))
 83    token = f.encrypt(force_bytes(value))
 84    return _ENCRYPTED_PREFIX + token.decode("ascii")
 85
 86
 87def _decrypt(value: str) -> str:
 88    """Decrypt a self-describing database value back to a string.
 89
 90    Gracefully handles unencrypted values — if the value doesn't have
 91    the encryption prefix, it's returned as-is. This supports gradual
 92    migration from plaintext to encrypted fields.
 93    """
 94    if not value.startswith(_ENCRYPTED_PREFIX):
 95        return value
 96    token = value[len(_ENCRYPTED_PREFIX) :]
 97    f = _get_fernet(settings.SECRET_KEY, tuple(settings.SECRET_KEY_FALLBACKS))
 98    try:
 99        return f.decrypt(token.encode("ascii")).decode("utf-8")
100    except InvalidToken:
101        raise ValueError(
102            "Could not decrypt field value. The SECRET_KEY (and SECRET_KEY_FALLBACKS) "
103            "may have changed since this data was encrypted."
104        )
105
106
# Lookups that EncryptedFieldMixin.get_lookup still permits; every other
# lookup name is rejected.
# - "isnull": NULL is never encrypted (the fields pass None straight through),
#   so NULL checks work normally.
# - "exact": required so that `filter(field=None)` works. The ORM resolves
#   "exact" first and then rewrites a None value to isnull.
# Exact lookups on non-None values silently return no results, because the
# ciphertext is non-deterministic. Blocking "exact" entirely would break the
# None/isnull path.
_ALLOWED_LOOKUPS = {"isnull", "exact"}
113
114
class EncryptedFieldMixin:
    """Common behavior shared by every encrypted field type.

    Ciphertext differs on every write, so only the lookups named in
    ``_ALLOWED_LOOKUPS`` (isnull and exact) are exposed and all transforms are
    disabled. Preflight reports an error whenever the field is named in one of
    the model's constraints or indexes.

    Intended to be combined with Field as a co-base class.
    """

    # Attributes supplied by Field (the required co-base class); declared here
    # only so type checkers know about them.
    name: str
    model: Any

    def get_lookup(self, lookup_name: str) -> type[Lookup] | None:
        """Delegate allowed lookups to the co-base class; return None otherwise."""
        if lookup_name in _ALLOWED_LOOKUPS:
            parent_get_lookup = getattr(super(), "get_lookup")
            return parent_get_lookup(lookup_name)
        return None

    def get_transform(
        self, lookup_name: str
    ) -> type[Transform] | Callable[..., Any] | None:
        """Disable every transform; encrypted values cannot be transformed in SQL."""
        return None

    def _check_encrypted_constraints(self) -> list[PreflightResult]:
        """Return one preflight error per constraint or index naming this field.

        Returns an empty list if the field is not yet attached to a model.
        """
        if not hasattr(self, "model"):
            return []

        field_name = self.name
        results: list[PreflightResult] = []

        for constraint in self.model.model_options.constraints:
            # Constraints without a ``fields`` attribute cannot name this field.
            if field_name not in getattr(constraint, "fields", ()):
                continue
            results.append(
                preflight.PreflightResult(
                    fix=(
                        f"'{self.model.__name__}.{field_name}' is an encrypted field "
                        f"and cannot be used in constraint '{constraint.name}'. "
                        "Encrypted values are non-deterministic."
                    ),
                    obj=self,
                    id="fields.encrypted_in_constraint",
                )
            )

        for index in self.model.model_options.indexes:
            # Index field names may carry a "-" prefix for descending order.
            plain_names = [entry.lstrip("-") for entry in getattr(index, "fields", ())]
            if field_name not in plain_names:
                continue
            results.append(
                preflight.PreflightResult(
                    fix=(
                        f"'{self.model.__name__}.{field_name}' is an encrypted field "
                        f"and cannot be used in index '{index.name}'. "
                        "Encrypted values are non-deterministic."
                    ),
                    obj=self,
                    id="fields.encrypted_in_index",
                )
            )

        return results
179
180
class EncryptedTextField(EncryptedFieldMixin, Field[str]):
    """A text field whose value is encrypted before it is written to the database.

    Encryption uses Fernet (AES-128-CBC + HMAC-SHA256) with a key derived from
    SECRET_KEY. The column is always ``text``, whatever ``max_length`` says,
    because the ciphertext is longer than the plaintext and its length is
    unpredictable.

    ``max_length`` is checked against the plaintext during validation only; it
    places no limit on the ciphertext stored in the database.
    """

    description = "Encrypted text"

    def get_internal_type(self) -> str:
        # Ciphertext outgrows the plaintext, so a sized column is never used.
        return "TextField"

    def to_python(self, value: Any) -> str | None:
        """Coerce ``value`` to ``str``, passing ``None`` and strings through."""
        if value is not None and not isinstance(value, str):
            return str(value)
        return value

    def validate(self, value: Any, model_instance: Any) -> None:
        """Run Field validation, then enforce ``max_length`` on the plaintext.

        Raises:
            exceptions.ValidationError: with code ``"max_length"`` when the
                plaintext is longer than ``max_length``.
        """
        super().validate(value, model_instance)
        if self.max_length is None or value is None:
            return
        if len(value) > self.max_length:
            raise exceptions.ValidationError(
                f"Ensure this value has at most {self.max_length} characters (it has {len(value)}).",
                code="max_length",
            )

    def get_prep_value(self, value: Any) -> Any:
        """Run Field preparation, then coerce any non-None result to ``str``."""
        prepped = super().get_prep_value(value)
        return None if prepped is None else self.to_python(prepped)

    def get_db_prep_value(
        self, value: Any, connection: DatabaseConnection, prepared: bool = False
    ) -> Any:
        """Prepare ``value`` via Field, then encrypt it; ``None`` stays ``None``."""
        db_value = super().get_db_prep_value(value, connection, prepared)
        return None if db_value is None else _encrypt(db_value)

    def from_db_value(
        self, value: Any, expression: Any, connection: DatabaseConnection
    ) -> str | None:
        """Pass a stored value through ``_decrypt``; ``None`` is returned as-is."""
        return None if value is None else _decrypt(value)

    def deconstruct(self) -> tuple[str | None, str, list[Any], dict[str, Any]]:
        """Return Field's deconstruction, but with this class's real import path."""
        name, _shortened_path, args, kwargs = super().deconstruct()
        # Field.deconstruct() would rewrite "plain.models.fields.encrypted" to
        # "plain.models.encrypted", a module that does not exist, so rebuild the
        # path from the class itself.
        path = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
        return name, path, args, kwargs

    def preflight(self, **kwargs: Any) -> list[PreflightResult]:
        """Extend Field's preflight results with the encrypted-field checks."""
        results = super().preflight(**kwargs)
        # In-place += extends the very list returned by Field.preflight().
        results += self._check_encrypted_constraints()
        return results
248
249
class EncryptedJSONField(EncryptedFieldMixin, Field):
    """A JSONField that encrypts its serialized value before storing in the database.

    The value is serialized with ``json.dumps`` (using ``encoder`` if given),
    encrypted, and stored in a ``text`` column. On read it is decrypted and
    parsed with ``json.loads`` (using ``decoder`` if given).

    Python ``None`` is passed through to the database unencrypted and read
    back as ``None``. A JSON ``null`` therefore cannot be told apart from a
    missing value.
    """

    empty_strings_allowed = False
    description = "Encrypted JSON"
    default_error_messages = {
        "invalid": "Value must be valid JSON.",
    }
    # (callable to suggest, literal to avoid) for the mutable-default warning.
    _default_fix = ("dict", "{}")

    def __init__(
        self,
        *,
        encoder: type[json.JSONEncoder] | None = None,
        decoder: type[json.JSONDecoder] | None = None,
        **kwargs: Any,
    ):
        """Store the optional JSON encoder/decoder classes.

        Raises:
            ValueError: if ``encoder`` or ``decoder`` is given but not callable.
        """
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        if decoder and not callable(decoder):
            raise ValueError("The decoder parameter must be a callable object.")
        self.encoder = encoder
        self.decoder = decoder
        super().__init__(**kwargs)

    def get_internal_type(self) -> str:
        # Store as text, not jsonb — we're storing encrypted ciphertext
        return "TextField"

    def deconstruct(self) -> tuple[str | None, str, list[Any], dict[str, Any]]:
        """Return Field's deconstruction plus ``encoder``/``decoder`` when set."""
        name, path, args, kwargs = super().deconstruct()
        # Override the path rewrite from Field.deconstruct() which would
        # shorten to a nonexistent module (same pattern as EncryptedTextField).
        path = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
        if self.encoder is not None:
            kwargs["encoder"] = self.encoder
        if self.decoder is not None:
            kwargs["decoder"] = self.decoder
        return name, path, args, kwargs

    def validate(self, value: Any, model_instance: Any) -> None:
        """Run Field validation, then check that ``value`` is JSON-serializable.

        Raises:
            exceptions.ValidationError: with code ``"invalid"`` if
                ``json.dumps`` rejects the value under the configured encoder.
        """
        super().validate(value, model_instance)
        try:
            json.dumps(value, cls=self.encoder)
        except (TypeError, ValueError) as exc:
            # TypeError: an object the encoder cannot serialize.
            # ValueError: e.g. a circular reference ("Circular reference
            # detected"), which previously escaped validation as a raw error.
            raise exceptions.ValidationError(
                self.error_messages["invalid"],
                code="invalid",
                params={"value": value},
            ) from exc

    def get_db_prep_value(
        self, value: Any, connection: DatabaseConnection, prepared: bool = False
    ) -> Any:
        """Serialize ``value`` to JSON and encrypt it; ``None`` stays ``None``."""
        value = super().get_db_prep_value(value, connection, prepared)
        if value is None:
            return value
        json_str = json.dumps(value, cls=self.encoder)
        return _encrypt(json_str)

    def from_db_value(
        self, value: Any, expression: Any, connection: DatabaseConnection
    ) -> Any:
        """Decrypt a stored value and parse it as JSON; ``None`` is returned as-is.

        Raises:
            ValueError: if decryption fails (from ``_decrypt``) or the
                decrypted text is not valid JSON. The JSON error is chained
                as ``__cause__``.
        """
        if value is None:
            return value
        decrypted = _decrypt(value)
        try:
            return json.loads(decrypted, cls=self.decoder)
        except json.JSONDecodeError as exc:
            raise ValueError(
                "Encrypted field contains data that is not valid JSON. "
                "The stored value may be corrupt."
            ) from exc

    def _check_default(self) -> list[PreflightResult]:
        """Warn when a non-None default is a shared mutable instance, not a callable."""
        if (
            self.has_default()
            and self.default is not None
            and not callable(self.default)
        ):
            return [
                preflight.PreflightResult(
                    fix=(
                        f"{self.__class__.__name__} default should be a callable instead of an instance "
                        "so that it's not shared between all field instances. "
                        "Use a callable instead, e.g., use `{}` instead of "
                        "`{}`.".format(*self._default_fix)
                    ),
                    obj=self,
                    id="fields.encrypted_mutable_default",
                    warning=True,
                )
            ]
        else:
            return []

    def preflight(self, **kwargs: Any) -> list[PreflightResult]:
        """Extend Field's preflight with default and constraint/index checks."""
        errors = super().preflight(**kwargs)
        errors.extend(self._check_default())
        errors.extend(self._check_encrypted_constraints())
        return errors