Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import json
  4from typing import TYPE_CHECKING, Any, cast
  5
  6from plain import exceptions, preflight
  7from plain.models import expressions, lookups
  8from plain.models.constants import LOOKUP_SEP
  9from plain.models.db import NotSupportedError, db_connection
 10from plain.models.fields import TextField
 11from plain.models.lookups import (
 12    FieldGetDbPrepValueMixin,
 13    PostgresOperatorLookup,
 14    Transform,
 15)
 16
 17from . import Field
 18from .mixins import CheckFieldDefaultMixin
 19
 20if TYPE_CHECKING:
 21    from plain.models.backends.base.base import BaseDatabaseWrapper
 22    from plain.models.backends.mysql.base import MySQLDatabaseWrapper
 23    from plain.models.backends.sqlite3.base import SQLiteDatabaseWrapper
 24    from plain.models.sql.compiler import SQLCompiler
 25    from plain.preflight.results import PreflightResult
 26
# Only the field class is exported; the lookups and transforms defined below
# are attached to it through register_lookup() calls at import time.
__all__ = ["JSONField"]
 28
 29
 30class JSONField(CheckFieldDefaultMixin, Field):
 31    empty_strings_allowed = False
 32    description = "A JSON object"
 33    default_error_messages = {
 34        "invalid": "Value must be valid JSON.",
 35    }
 36    _default_fix = ("dict", "{}")
 37
 38    def __init__(
 39        self,
 40        *,
 41        encoder: type[json.JSONEncoder] | None = None,
 42        decoder: type[json.JSONDecoder] | None = None,
 43        **kwargs: Any,
 44    ):
 45        if encoder and not callable(encoder):
 46            raise ValueError("The encoder parameter must be a callable object.")
 47        if decoder and not callable(decoder):
 48            raise ValueError("The decoder parameter must be a callable object.")
 49        self.encoder = encoder
 50        self.decoder = decoder
 51        super().__init__(**kwargs)
 52
 53    def preflight(self, **kwargs: Any) -> list[PreflightResult]:
 54        errors = super().preflight(**kwargs)
 55        errors.extend(self._check_supported())
 56        return errors
 57
 58    def _check_supported(self) -> list[PreflightResult]:
 59        errors = []
 60
 61        if (
 62            self.model.model_options.required_db_vendor
 63            and self.model.model_options.required_db_vendor != db_connection.vendor
 64        ):
 65            return errors
 66
 67        if not (
 68            "supports_json_field" in self.model.model_options.required_db_features
 69            or db_connection.features.supports_json_field
 70        ):
 71            errors.append(
 72                preflight.PreflightResult(
 73                    fix=f"{db_connection.display_name} does not support JSONFields. Consider using a TextField with JSON serialization or upgrade to a database that supports JSON fields.",
 74                    obj=self.model,
 75                    id="fields.json_field_unsupported",
 76                )
 77            )
 78        return errors
 79
 80    def deconstruct(self) -> tuple[str, str, list[Any], dict[str, Any]]:
 81        name, path, args, kwargs = super().deconstruct()
 82        if self.encoder is not None:
 83            kwargs["encoder"] = self.encoder
 84        if self.decoder is not None:
 85            kwargs["decoder"] = self.decoder
 86        return name, path, args, kwargs
 87
 88    def from_db_value(
 89        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
 90    ) -> Any:
 91        if value is None:
 92            return value
 93        # Some backends (SQLite at least) extract non-string values in their
 94        # SQL datatypes.
 95        if isinstance(expression, KeyTransform) and not isinstance(value, str):
 96            return value
 97        try:
 98            return json.loads(value, cls=self.decoder)
 99        except json.JSONDecodeError:
100            return value
101
102    def get_internal_type(self) -> str:
103        return "JSONField"
104
105    def get_db_prep_value(
106        self, value: Any, connection: BaseDatabaseWrapper, prepared: bool = False
107    ) -> Any:
108        if isinstance(value, expressions.Value) and isinstance(
109            value.output_field, JSONField
110        ):
111            value = value.value
112        elif hasattr(value, "as_sql"):
113            return value
114        return connection.ops.adapt_json_value(value, self.encoder)
115
116    def get_db_prep_save(self, value: Any, connection: BaseDatabaseWrapper) -> Any:
117        if value is None:
118            return value
119        return self.get_db_prep_value(value, connection)
120
121    def get_transform(self, name: str) -> KeyTransformFactory | type[Transform]:
122        transform = super().get_transform(name)
123        if transform:
124            return transform
125        return KeyTransformFactory(name)
126
127    def validate(self, value: Any, model_instance: Any) -> None:
128        super().validate(value, model_instance)
129        try:
130            json.dumps(value, cls=self.encoder)
131        except TypeError:
132            raise exceptions.ValidationError(
133                self.error_messages["invalid"],
134                code="invalid",
135                params={"value": value},
136            )
137
138    def value_to_string(self, obj: Any) -> Any:
139        return self.value_from_object(obj)
140
141
def compile_json_path(key_transforms: list[Any], include_root: bool = True) -> str:
    """
    Build a JSON path expression from a sequence of keys.

    Every key accepted by ``int()`` becomes an array index (``[n]``). Any
    other key becomes a member access whose name is quoted with
    ``json.dumps`` (``."key"``). The path is prefixed with ``$`` unless
    ``include_root`` is false.

    >>> compile_json_path(["a", "0"])
    '$."a"[0]'
    """
    segments = []
    for key in key_transforms:
        try:
            index = int(key)
        except ValueError:
            # Not integer-like: address it as an object member.
            segments.append(f".{json.dumps(key)}")
            continue
        segments.append(f"[{index}]")
    root = "$" if include_root else ""
    return root + "".join(segments)
153
154
class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """
    ``contains`` lookup: the left-hand JSON document contains the right-hand one.

    ``postgres_operator`` (``@>``) is presumably used by PostgresOperatorLookup
    on PostgreSQL. Other backends compile to ``JSON_CONTAINS(lhs, rhs)`` and
    must advertise ``supports_json_field_contains``.
    """

    lookup_name = "contains"
    postgres_operator = "@>"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Raises:
            NotSupportedError: If the backend cannot evaluate JSON containment.
        """
        if connection.features.supports_json_field_contains:
            lhs_sql, lhs_args = self.process_lhs(compiler, connection)
            rhs_sql, rhs_args = self.process_rhs(compiler, connection)
            return (
                f"JSON_CONTAINS({lhs_sql}, {rhs_sql})",
                (*lhs_args, *rhs_args),
            )
        raise NotSupportedError(
            "contains lookup is not supported on this database backend."
        )
170
171
class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """
    ``contained_by`` lookup: the left-hand JSON document is contained in the
    right-hand one.

    This is the mirror of ``contains``: on non-PostgreSQL backends it compiles
    to ``JSON_CONTAINS(rhs, lhs)``, with the rhs params placed first to match.
    """

    lookup_name = "contained_by"
    postgres_operator = "<@"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Raises:
            NotSupportedError: If the backend cannot evaluate JSON containment.
        """
        if connection.features.supports_json_field_contains:
            lhs_sql, lhs_args = self.process_lhs(compiler, connection)
            rhs_sql, rhs_args = self.process_rhs(compiler, connection)
            # Operands are swapped, so the parameter order is swapped as well.
            return (
                f"JSON_CONTAINS({rhs_sql}, {lhs_sql})",
                (*rhs_args, *lhs_args),
            )
        raise NotSupportedError(
            "contained_by lookup is not supported on this database backend."
        )
187
188
class HasKeyLookup(PostgresOperatorLookup):
    """
    Base for the ``has_key`` / ``has_keys`` / ``has_any_keys`` lookups.

    On MySQL and SQLite, each right-hand key is compiled into a JSON path,
    prefixed with the left-hand key path if any. That path is bound as a
    parameter in a vendor-specific SQL ``template``. When
    ``logical_operator`` is set, one condition is emitted per key and the
    conditions are joined with it. On PostgreSQL, a KeyTransform right-hand
    side is flattened first and compilation is then delegated to
    ``PostgresOperatorLookup``.
    """

    # Joins the per-key conditions (" AND " / " OR "); None for a single key.
    logical_operator: str | None = None

    def compile_json_path_final_key(self, key_transform: Any) -> str:
        # Compile the final key without interpreting ints as array elements.
        return f".{json.dumps(key_transform)}"

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Compile the lookup with a vendor-specific ``template``.

        Args:
            template: SQL containing ``%s`` for the left-hand expression and
                ``%%s`` for the JSON path parameter.

        Returns:
            ``(sql, params)``. When several keys are combined, the left-hand
            SQL is repeated once per key. Its params are therefore repeated
            too, interleaved with each key's path in placeholder order.

        Raises:
            NotSupportedError: If no ``template`` is given, i.e. the backend
                has no vendor-specific implementation of this lookup.
        """
        if template is None:
            raise NotSupportedError(
                f"{self.lookup_name} lookup is not supported on this database backend."
            )
        # Process JSON path from the left-hand side.
        if isinstance(self.lhs, KeyTransform):
            lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
                compiler, connection
            )
            lhs_json_path = compile_json_path(lhs_key_transforms)
        else:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            lhs_json_path = "$"
        sql = template % lhs
        # Process JSON path from the right-hand side.
        rhs = self.rhs
        rhs_params: list[str] = []
        if not isinstance(rhs, list | tuple):
            rhs = [rhs]
        for key in rhs:
            if isinstance(key, KeyTransform):
                *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
            else:
                rhs_key_transforms = [key]
            *rhs_key_transforms, final_key = rhs_key_transforms
            rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
            rhs_json_path += self.compile_json_path_final_key(final_key)
            rhs_params.append(lhs_json_path + rhs_json_path)
        # Add condition for each key.
        if self.logical_operator:
            sql = f"({self.logical_operator.join([sql] * len(rhs_params))})"
            # Every repeated condition consumes its own copy of the lhs params
            # followed by its path param. Emitting lhs params only once would
            # misalign placeholders whenever the lhs carries parameters.
            params = tuple(
                param for path in rhs_params for param in (*lhs_params, path)
            )
        else:
            params = tuple(lhs_params) + tuple(rhs_params)
        return sql, params

    def as_mysql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        return self.as_sql(
            compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
        )

    def as_postgresql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Flatten a KeyTransform right-hand side before delegating.

        All but its last key are applied to ``self.lhs`` as KeyTransforms, and
        the last key becomes the new ``self.rhs``. Note that this mutates the
        lookup in place.
        """
        if isinstance(self.rhs, KeyTransform):
            *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
            for key in rhs_key_transforms[:-1]:
                self.lhs = KeyTransform(key, self.lhs)
            self.rhs = rhs_key_transforms[-1]
        return super().as_postgresql(compiler, connection)

    def as_sqlite(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        return self.as_sql(
            compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
        )
254
255
class HasKey(HasKeyLookup):
    """``has_key``: the JSON value contains the given key."""

    lookup_name = "has_key"
    postgres_operator = "?"
    # NOTE(review): presumably tells the base Lookup not to prepare the rhs
    # through the field, so the raw key is used -- confirm in Lookup.
    prepare_rhs = False


class HasKeys(HasKeyLookup):
    """``has_keys``: every key in the right-hand sequence is present."""

    lookup_name = "has_keys"
    postgres_operator = "?&"
    logical_operator = " AND "

    def get_prep_lookup(self) -> list[str]:
        # Normalize every requested key to its string form.
        return list(map(str, self.rhs))


class HasAnyKeys(HasKeys):
    """``has_any_keys``: at least one key in the right-hand sequence is present."""

    lookup_name = "has_any_keys"
    postgres_operator = "?|"
    logical_operator = " OR "


class HasKeyOrArrayIndex(HasKey):
    """
    ``HasKey`` variant whose final key may address an array element.

    An integer-like final key compiles to ``[n]`` instead of ``."n"``.
    """

    def compile_json_path_final_key(self, key_transform: Any) -> str:
        return compile_json_path([key_transform], include_root=False)
280
281
class CaseInsensitiveMixin:
    """
    Lower-case both sides of a JSON comparison on MySQL.

    MySQL compares strings in a JSON context with the binary utf8mb4_bin
    collation, which makes the comparison case-sensitive. Wrapping both the
    left- and right-hand SQL in ``LOWER()`` restores case-insensitive
    matching. Other vendors are left untouched.
    """

    def process_lhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        sql, params = super().process_lhs(compiler, connection)  # type: ignore[misc]
        wrapped = f"LOWER({sql})" if connection.vendor == "mysql" else sql
        return wrapped, params

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        sql, params = super().process_rhs(compiler, connection)  # type: ignore[misc]
        wrapped = f"LOWER({sql})" if connection.vendor == "mysql" else sql
        return wrapped, params
305
306
class JSONExact(lookups.Exact):
    """
    ``exact`` lookup for whole JSONField values.

    A ``None`` right-hand side is allowed and is bound as the JSON text
    ``"null"``. On MySQL every placeholder is wrapped in
    ``JSON_EXTRACT(%s, '$')``. NOTE(review): presumably this makes MySQL
    compare the parameter as JSON rather than as a plain string; confirm.
    """

    can_use_none_as_rhs = True

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        rhs, rhs_params = super().process_rhs(compiler, connection)
        # A lone placeholder bound to None means JSON null, not SQL NULL.
        if rhs == "%s" and rhs_params == [None]:
            rhs_params = ["null"]
        if connection.vendor != "mysql":
            return rhs, rhs_params
        wrappers = tuple("JSON_EXTRACT(%s, '$')" for _ in rhs_params)
        return rhs % wrappers, rhs_params
321
322
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
    """``icontains`` on a whole JSONField value; both sides are lower-cased on MySQL."""

    pass
325
326
# Lookups usable directly on a JSONField (e.g. ``data__has_key``). They are
# registered in this exact order; the loop variable is removed afterwards so
# the module namespace is unchanged.
for _json_lookup in (
    DataContains,
    ContainedBy,
    HasKey,
    HasKeys,
    HasAnyKeys,
    JSONExact,
    JSONIContains,
):
    JSONField.register_lookup(_json_lookup)
del _json_lookup
334
335
class KeyTransform(Transform):
    """
    Transform that extracts one key (or array index) from a JSON value.

    Chained lookups such as ``data__a__b`` produce nested ``KeyTransform``
    instances. ``preprocess_lhs`` flattens that chain into a list of keys, so
    each backend can emit a single path expression.
    """

    postgres_operator = "->"
    postgres_nested_operator = "#>"

    def __init__(self, key_name: str, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Stored as text; integer-like keys are re-interpreted per backend.
        self.key_name = str(key_name)

    def preprocess_lhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...], list[str]]:
        """
        Collapse the chain of nested KeyTransforms ending at this one.

        Returns:
            ``(sql, params, keys)``. ``sql``/``params`` come from compiling the
            innermost non-KeyTransform expression. ``keys`` runs from the
            first applied key down to this transform's key.
        """
        key_transforms = [self.key_name]
        previous = self.lhs
        while isinstance(previous, KeyTransform):
            key_transforms.insert(0, previous.key_name)
            previous = previous.lhs
        lhs, params = compiler.compile(previous)
        return lhs, params, key_transforms

    def as_mysql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        return f"JSON_EXTRACT({lhs}, %s)", tuple(params) + (json_path,)

    def as_postgresql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Use ``#>`` with the whole key list for nested paths. For a single
        key, use ``->`` with an int (array index) when the key is
        integer-like, otherwise with the string key.
        """
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        if len(key_transforms) > 1:
            sql = f"({lhs} {self.postgres_nested_operator} %s)"
            return sql, tuple(params) + (key_transforms,)
        try:
            lookup = int(self.key_name)
        except ValueError:
            lookup = self.key_name
        return f"({lhs} {self.postgres_operator} %s)", tuple(params) + (lookup,)

    def as_sqlite(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Extract the key on SQLite.

        If ``JSON_TYPE`` at the path is one of the backend's
        ``jsonfield_datatype_values``, that type name is returned. Otherwise
        ``JSON_EXTRACT`` at the path is returned. NOTE(review): presumably
        this keeps JSON true/false/null distinguishable; confirm against the
        SQLite backend ops.
        """
        # The cast target must be quoted: SQLiteDatabaseWrapper is imported
        # only under TYPE_CHECKING, and ``from __future__ import annotations``
        # does not defer cast() arguments, so a bare name raises NameError.
        sqlite_connection = cast("SQLiteDatabaseWrapper", connection)
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        datatype_values = ",".join(
            [
                repr(datatype)
                for datatype in sqlite_connection.ops.jsonfield_datatype_values  # type: ignore[attr-defined]
            ]
        )
        # The lhs/path pair appears three times in the SQL, so params repeat 3x.
        return (
            f"(CASE WHEN JSON_TYPE({lhs}, %s) IN ({datatype_values}) "
            f"THEN JSON_TYPE({lhs}, %s) ELSE JSON_EXTRACT({lhs}, %s) END)"
        ), (tuple(params) + (json_path,)) * 3
391
392
class KeyTextTransform(KeyTransform):
    """
    ``KeyTransform`` that yields the extracted value as text.

    PostgreSQL uses ``->>``/``#>>``. MySQL uses ``->>``, and MariaDB uses
    ``JSON_UNQUOTE(JSON_EXTRACT(...))``. The output field is a ``TextField``.
    """

    postgres_operator = "->>"
    postgres_nested_operator = "#>>"
    output_field = TextField()

    def as_mysql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        # Quoted cast target: MySQLDatabaseWrapper exists only under
        # TYPE_CHECKING, so a bare name would raise NameError at runtime.
        mysql_connection = cast("MySQLDatabaseWrapper", connection)
        if mysql_connection.mysql_is_mariadb:
            # MariaDB doesn't support -> and ->> operators (see MDEV-13594).
            sql, params = super().as_mysql(compiler, connection)
            return f"JSON_UNQUOTE({sql})", params
        else:
            lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
            json_path = compile_json_path(key_transforms)
            return f"({lhs} ->> %s)", tuple(params) + (json_path,)

    @classmethod
    def from_lookup(cls, lookup: str) -> Any:
        """
        Build nested KeyTextTransforms from a ``field<SEP>key<SEP>key`` string.

        The string is split on ``LOOKUP_SEP``. The first part is the field
        expression, and each following part wraps it in another transform.

        Raises:
            ValueError: If the lookup contains no key after the field name.
        """
        transform, *keys = lookup.split(LOOKUP_SEP)
        if not keys:
            raise ValueError("Lookup must contain key or index transforms.")
        for key in keys:
            transform = cls(key, transform)
        return transform
419
420
# Shorthand for KeyTextTransform.from_lookup: builds a text-valued key
# extraction from a LOOKUP_SEP-separated "field, key, key..." string.
KT = KeyTextTransform.from_lookup
422
423
class KeyTransformTextLookupMixin:
    """
    Mixin for text lookups applied to a JSONField key lookup.

    The incoming ``KeyTransform`` is replaced by an equivalent
    ``KeyTextTransform`` built from the same key, source expressions and
    extras. The text lookup therefore operates on the key's text value (e.g.
    ``->>`` on PostgreSQL) instead of a cast of the JSON representation.
    """

    def __init__(self, key_transform: Any, *args: Any, **kwargs: Any):
        if isinstance(key_transform, KeyTransform):
            as_text = KeyTextTransform(
                key_transform.key_name,
                *key_transform.source_expressions,
                **key_transform.extra,
            )
            super().__init__(as_text, *args, **kwargs)  # type: ignore[misc]
            return
        raise TypeError(
            "Transform should be an instance of KeyTransform in order to "
            "use this lookup."
        )
444
445
class KeyTransformIsNull(lookups.IsNull):
    """
    ``isnull`` on a JSON key.

    ``key__isnull=False`` is the same as ``has_key="key"``. On SQLite it is
    compiled through ``HasKeyOrArrayIndex`` with a ``JSON_TYPE`` test, so
    integer-like keys also match array indexes. Other backends use the
    inherited ``IsNull`` behavior.
    """

    def as_sqlite(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        comparison = "IS NULL" if self.rhs else "IS NOT NULL"
        key_check = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
        return key_check.as_sql(
            compiler,
            connection,
            template=f"JSON_TYPE(%s, %%s) {comparison}",
        )
459
460
class KeyTransformIn(lookups.In):
    """
    ``in`` lookup on a JSON key.

    On backends without a native JSON column type, each literal value is
    wrapped in ``JSON_EXTRACT(%s, '$')``. On SQLite this is skipped for values
    listed in ``jsonfield_datatype_values``. On MariaDB the result is
    additionally wrapped in ``JSON_UNQUOTE``.
    """

    def resolve_expression_parameter(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        sql: str,
        param: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        sql, params = super().resolve_expression_parameter(
            compiler,
            connection,
            sql,
            param,
        )
        if (
            not hasattr(param, "as_sql")
            and not connection.features.has_native_json_field
        ):
            if connection.vendor == "mysql":
                sql = "JSON_EXTRACT(%s, '$')"
            elif connection.vendor == "sqlite":
                # Quoted: SQLiteDatabaseWrapper is a TYPE_CHECKING-only import,
                # so a bare name in cast() raises NameError at runtime.
                sqlite_connection = cast("SQLiteDatabaseWrapper", connection)
                if params[0] not in sqlite_connection.ops.jsonfield_datatype_values:  # type: ignore[attr-defined]
                    sql = "JSON_EXTRACT(%s, '$')"
        if connection.vendor == "mysql":
            # Quoted for the same reason as above.
            mysql_connection = cast("MySQLDatabaseWrapper", connection)
            if mysql_connection.mysql_is_mariadb:
                sql = f"JSON_UNQUOTE({sql})"
        return sql, params
490
491
class KeyTransformExact(JSONExact):
    """
    ``exact`` lookup on a JSON key.

    When the right-hand side is itself a KeyTransform, ``Exact``'s rhs
    processing is skipped and the next class in the MRO after
    ``lookups.Exact`` handles it. Otherwise ``JSONExact`` processing applies.
    On SQLite, each param not listed in ``jsonfield_datatype_values`` is then
    wrapped in ``JSON_EXTRACT(%s, '$')``.
    """

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        if isinstance(self.rhs, KeyTransform):
            return super(lookups.Exact, self).process_rhs(compiler, connection)
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor == "sqlite":
            # Quoted: SQLiteDatabaseWrapper is a TYPE_CHECKING-only import, so
            # a bare name in cast() would raise NameError at runtime.
            sqlite_connection = cast("SQLiteDatabaseWrapper", connection)
            func = []
            for value in rhs_params:
                if value in sqlite_connection.ops.jsonfield_datatype_values:  # type: ignore[attr-defined]
                    func.append("%s")
                else:
                    func.append("JSON_EXTRACT(%s, '$')")
            rhs %= tuple(func)
        return rhs, rhs_params
509
510
class KeyTransformIExact(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
    """``iexact`` against a key's text value (lower-cased on MySQL)."""

    pass


class KeyTransformIContains(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
    """``icontains`` against a key's text value (lower-cased on MySQL)."""

    pass


class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    """``startswith`` against a key's text value."""

    pass


class KeyTransformIStartsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
    """``istartswith`` against a key's text value (lower-cased on MySQL)."""

    pass


class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    """``endswith`` against a key's text value."""

    pass


class KeyTransformIEndsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
    """``iendswith`` against a key's text value (lower-cased on MySQL)."""

    pass


class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    """``regex`` against a key's text value."""

    pass


class KeyTransformIRegex(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
    """``iregex`` against a key's text value (lower-cased on MySQL)."""

    pass
551
552
class KeyTransformNumericLookupMixin:
    """
    Decode the right-hand params of ordering lookups on a JSON key.

    On backends without a native JSON column type, every rhs param is passed
    through ``json.loads``. The comparison then uses the decoded value (e.g.
    the number ``5`` rather than the JSON text ``"5"``).
    """

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        rhs, rhs_params = super().process_rhs(compiler, connection)  # type: ignore[misc]
        if connection.features.has_native_json_field:
            return rhs, rhs_params
        return rhs, list(map(json.loads, rhs_params))
561
562
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
    """``lt`` on a JSON key; rhs JSON-decoded on non-native-JSON backends."""

    pass


class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
    """``lte`` on a JSON key; rhs JSON-decoded on non-native-JSON backends."""

    pass


class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
    """``gt`` on a JSON key; rhs JSON-decoded on non-native-JSON backends."""

    pass


class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
    """``gte`` on a JSON key; rhs JSON-decoded on non-native-JSON backends."""

    pass
577
578
# Lookups usable on a JSON key (e.g. ``data__key__icontains``), registered in
# this exact order: text/equality lookups first, then numeric comparisons.
# The loop variable is removed afterwards so the module namespace is unchanged.
for _key_lookup in (
    KeyTransformIn,
    KeyTransformExact,
    KeyTransformIExact,
    KeyTransformIsNull,
    KeyTransformIContains,
    KeyTransformStartsWith,
    KeyTransformIStartsWith,
    KeyTransformEndsWith,
    KeyTransformIEndsWith,
    KeyTransformRegex,
    KeyTransformIRegex,
    KeyTransformLt,
    KeyTransformLte,
    KeyTransformGt,
    KeyTransformGte,
):
    KeyTransform.register_lookup(_key_lookup)
del _key_lookup
595
596
class KeyTransformFactory:
    """
    Callable returned by ``JSONField.get_transform`` for unregistered names.

    It remembers the JSON key. When called with the transform's remaining
    arguments (presumably the left-hand expression), it returns a
    ``KeyTransform`` for that key.
    """

    def __init__(self, key_name: str):
        self.key_name = key_name

    def __call__(self, *args: Any, **kwargs: Any) -> KeyTransform:
        key = self.key_name
        return KeyTransform(key, *args, **kwargs)