Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import json
  4from collections.abc import Callable, Sequence
  5from typing import TYPE_CHECKING, Any
  6
  7from plain import exceptions, preflight
  8from plain.models import expressions, lookups
  9from plain.models.backends.guards import is_mysql_connection, is_sqlite_connection
 10from plain.models.constants import LOOKUP_SEP
 11from plain.models.db import NotSupportedError, db_connection
 12from plain.models.fields import TextField
 13from plain.models.lookups import (
 14    FieldGetDbPrepValueMixin,
 15    Lookup,
 16    PostgresOperatorLookup,
 17    Transform,
 18)
 19
 20from . import Field
 21
 22if TYPE_CHECKING:
 23    from plain.models.backends.base.base import BaseDatabaseWrapper
 24    from plain.models.sql.compiler import SQLCompiler
 25    from plain.preflight.results import PreflightResult
 26
 27__all__ = ["JSONField"]
 28
 29
 30class JSONField(Field):
 31    empty_strings_allowed = False
 32    description = "A JSON object"
 33    default_error_messages = {
 34        "invalid": "Value must be valid JSON.",
 35    }
 36    _default_fix = ("dict", "{}")
 37
 38    def __init__(
 39        self,
 40        *,
 41        encoder: type[json.JSONEncoder] | None = None,
 42        decoder: type[json.JSONDecoder] | None = None,
 43        **kwargs: Any,
 44    ):
 45        if encoder and not callable(encoder):
 46            raise ValueError("The encoder parameter must be a callable object.")
 47        if decoder and not callable(decoder):
 48            raise ValueError("The decoder parameter must be a callable object.")
 49        self.encoder = encoder
 50        self.decoder = decoder
 51        super().__init__(**kwargs)
 52
 53    def _check_default(self) -> list[PreflightResult]:
 54        if (
 55            self.has_default()
 56            and self.default is not None
 57            and not callable(self.default)
 58        ):
 59            return [
 60                preflight.PreflightResult(
 61                    fix=(
 62                        f"{self.__class__.__name__} default should be a callable instead of an instance "
 63                        "so that it's not shared between all field instances. "
 64                        "Use a callable instead, e.g., use `{}` instead of "
 65                        "`{}`.".format(*self._default_fix)
 66                    ),
 67                    obj=self,
 68                    id="fields.invalid_choice_mixin_default",
 69                    warning=True,
 70                )
 71            ]
 72        else:
 73            return []
 74
 75    def preflight(self, **kwargs: Any) -> list[PreflightResult]:
 76        errors = super().preflight(**kwargs)
 77        errors.extend(self._check_default())
 78        errors.extend(self._check_supported())
 79        return errors
 80
 81    def _check_supported(self) -> list[PreflightResult]:
 82        errors = []
 83
 84        if (
 85            self.model.model_options.required_db_vendor
 86            and self.model.model_options.required_db_vendor != db_connection.vendor
 87        ):
 88            return errors
 89
 90        if not (
 91            "supports_json_field" in self.model.model_options.required_db_features
 92            or db_connection.features.supports_json_field
 93        ):
 94            errors.append(
 95                preflight.PreflightResult(
 96                    fix=f"{db_connection.display_name} does not support JSONFields. Consider using a TextField with JSON serialization or upgrade to a database that supports JSON fields.",
 97                    obj=self.model,
 98                    id="fields.json_field_unsupported",
 99                )
100            )
101        return errors
102
103    def deconstruct(self) -> tuple[str | None, str, list[Any], dict[str, Any]]:
104        name, path, args, kwargs = super().deconstruct()
105        if self.encoder is not None:
106            kwargs["encoder"] = self.encoder
107        if self.decoder is not None:
108            kwargs["decoder"] = self.decoder
109        return name, path, args, kwargs
110
111    def from_db_value(
112        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
113    ) -> Any:
114        if value is None:
115            return value
116        # Some backends (SQLite at least) extract non-string values in their
117        # SQL datatypes.
118        if isinstance(expression, KeyTransform) and not isinstance(value, str):
119            return value
120        try:
121            return json.loads(value, cls=self.decoder)
122        except json.JSONDecodeError:
123            return value
124
125    def get_internal_type(self) -> str:
126        return "JSONField"
127
128    def get_db_prep_value(
129        self, value: Any, connection: BaseDatabaseWrapper, prepared: bool = False
130    ) -> Any:
131        if isinstance(value, expressions.Value) and isinstance(
132            value.output_field, JSONField
133        ):
134            value = value.value
135        elif hasattr(value, "as_sql"):
136            return value
137        return connection.ops.adapt_json_value(value, self.encoder)
138
139    def get_db_prep_save(self, value: Any, connection: BaseDatabaseWrapper) -> Any:
140        if value is None:
141            return value
142        return self.get_db_prep_value(value, connection)
143
144    def get_transform(
145        self, lookup_name: str
146    ) -> type[Transform] | Callable[..., Any] | None:
147        # Always returns a transform (never None in practice)
148        transform = super().get_transform(lookup_name)
149        if transform:
150            return transform
151        return KeyTransformFactory(lookup_name)
152
153    def validate(self, value: Any, model_instance: Any) -> None:
154        super().validate(value, model_instance)
155        try:
156            json.dumps(value, cls=self.encoder)
157        except TypeError:
158            raise exceptions.ValidationError(
159                self.error_messages["invalid"],
160                code="invalid",
161                params={"value": value},
162            )
163
164    def value_to_string(self, obj: Any) -> Any:
165        return self.value_from_object(obj)
166
167
def compile_json_path(key_transforms: list[Any], include_root: bool = True) -> str:
    """Build a JSON path string such as ``$."a"[0]`` from a sequence of keys.

    Keys accepted by ``int()`` become array indexes (``[n]``). Every other
    key becomes a JSON-quoted object member (``."key"``). When
    ``include_root`` is false, the leading ``$`` is omitted so the result can
    be appended to another path.
    """
    segments: list[str] = []
    if include_root:
        segments.append("$")
    for key in key_transforms:
        try:
            index = int(key)
        except ValueError:
            # Not an integer: emit a quoted, JSON-escaped member name.
            segments.append("." + json.dumps(key))
        else:
            segments.append(f"[{index}]")
    return "".join(segments)
179
180
class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """``contains`` lookup: the JSON document contains the given value.

    The PostgreSQL operator is ``@>`` (``postgres_operator``). The generic
    ``as_sql`` path emits ``JSON_CONTAINS(lhs, rhs)``.
    """

    lookup_name = "contains"
    postgres_operator = "@>"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile to ``JSON_CONTAINS(lhs, rhs)``.

        Raises:
            NotSupportedError: if the backend lacks
                ``supports_json_field_contains``.
        """
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contains lookup is not supported on this database backend."
            )
        target_sql, target_params = self.process_lhs(compiler, connection)
        needle_sql, needle_params = self.process_rhs(compiler, connection)
        return (
            f"JSON_CONTAINS({target_sql}, {needle_sql})",
            (*target_params, *needle_params),
        )
196
197
class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """``contained_by`` lookup: the JSON document is contained in the value.

    This mirrors ``contains`` with the operands swapped. The PostgreSQL
    operator is ``<@``; the generic path emits ``JSON_CONTAINS(rhs, lhs)``.
    """

    lookup_name = "contained_by"
    postgres_operator = "<@"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile to ``JSON_CONTAINS(rhs, lhs)`` with parameters in that order.

        Raises:
            NotSupportedError: if the backend lacks
                ``supports_json_field_contains``.
        """
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contained_by lookup is not supported on this database backend."
            )
        inner_sql, inner_params = self.process_lhs(compiler, connection)
        outer_sql, outer_params = self.process_rhs(compiler, connection)
        return (
            f"JSON_CONTAINS({outer_sql}, {inner_sql})",
            (*outer_params, *inner_params),
        )
213
214
class HasKeyLookup(PostgresOperatorLookup):
    """Base lookup testing for the presence of one or more JSON keys.

    On MySQL and SQLite, one SQL condition (built from a backend template) is
    emitted per requested key. Subclasses that accept several keys set
    ``logical_operator`` to join those conditions. PostgreSQL goes through
    ``PostgresOperatorLookup`` using ``postgres_operator``.
    """

    # Joins the per-key conditions (e.g. " AND "); None means one condition.
    logical_operator: str | None = None

    def compile_json_path_final_key(self, key_transform: Any) -> str:
        """Render the last path segment as a quoted object member.

        Unlike ``compile_json_path``, integer-like keys are not turned into
        array indexes here.
        """
        return f".{json.dumps(key_transform)}"

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile one ``template`` condition per requested key.

        ``template`` receives the left-hand SQL through ``%s`` and must
        contain ``%%s``, which becomes the placeholder for the key's JSON
        path. When ``logical_operator`` joins several conditions, the
        left-hand SQL appears once per condition. Its parameters are
        therefore repeated in front of each path parameter so placeholders
        and parameters stay aligned.
        """
        # Process JSON path from the left-hand side.
        if isinstance(self.lhs, KeyTransform):
            lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
                compiler, connection
            )
            lhs_json_path = compile_json_path(lhs_key_transforms)
        else:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            lhs_json_path = "$"
        effective_template = template or "%s"
        condition = effective_template % lhs
        # Process JSON path from the right-hand side.
        rhs = self.rhs
        rhs_params: list[str] = []
        if not isinstance(rhs, list | tuple):
            rhs = [rhs]
        for key in rhs:
            if isinstance(key, KeyTransform):
                *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
            else:
                rhs_key_transforms = [key]
            *rhs_key_transforms, final_key = rhs_key_transforms
            rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
            rhs_json_path += self.compile_json_path_final_key(final_key)
            rhs_params.append(lhs_json_path + rhs_json_path)
        if self.logical_operator:
            # One condition per key, each repeating the left-hand SQL: emit
            # the left-hand params again before every path parameter.
            sql = f"({self.logical_operator.join([condition] * len(rhs_params))})"
            params = tuple(
                param for path in rhs_params for param in (*lhs_params, path)
            )
            return sql, params
        return condition, tuple(lhs_params) + tuple(rhs_params)

    def as_mysql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """MySQL: test each path with ``JSON_CONTAINS_PATH(..., 'one', path)``."""
        return self.as_sql(
            compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
        )

    def as_postgresql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """PostgreSQL: fold a KeyTransform rhs into the lhs, then use the operator.

        All but the last key of a KeyTransform rhs are wrapped around
        ``self.lhs`` (mutating this lookup), leaving only the final key as
        ``self.rhs`` for the operator test.
        """
        if isinstance(self.rhs, KeyTransform):
            *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
            for key in rhs_key_transforms[:-1]:
                self.lhs = KeyTransform(key, self.lhs)
            self.rhs = rhs_key_transforms[-1]
        return super().as_postgresql(compiler, connection)

    def as_sqlite(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """SQLite: a key exists when ``JSON_TYPE(doc, path)`` is not NULL."""
        return self.as_sql(
            compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
        )
281
282
class HasKey(HasKeyLookup):
    """``has_key`` lookup: the JSON document contains a single given key."""

    lookup_name = "has_key"
    postgres_operator = "?"
    # NOTE(review): presumably keeps the key as given instead of running it
    # through the base Lookup's rhs preparation — confirm in lookups.Lookup.
    prepare_rhs = False
287
288
class HasKeys(HasKeyLookup):
    """``has_keys`` lookup: every listed key is present (conditions ANDed)."""

    lookup_name = "has_keys"
    postgres_operator = "?&"
    logical_operator = " AND "

    def get_prep_lookup(self) -> list[str]:
        """Coerce each requested key to ``str``."""
        return list(map(str, self.rhs))
296
297
class HasAnyKeys(HasKeys):
    """``has_any_keys`` lookup: at least one listed key is present (ORed)."""

    lookup_name = "has_any_keys"
    postgres_operator = "?|"
    logical_operator = " OR "
302
303
class HasKeyOrArrayIndex(HasKey):
    """HasKey variant whose final key may also be an array index.

    The final segment goes through ``compile_json_path``, so integer-like
    keys become ``[n]`` rather than quoted members.
    """

    def compile_json_path_final_key(self, key_transform: Any) -> str:
        final_segment = compile_json_path([key_transform], include_root=False)
        return final_segment
307
308
class CaseInsensitiveMixin(Lookup):
    """
    Mixin to allow case-insensitive comparison of JSON values on MySQL.
    MySQL handles strings used in JSON context using the utf8mb4_bin collation.
    Because utf8mb4_bin is a binary collation, comparison of JSON values is
    case-sensitive, so on MySQL both operands are wrapped in ``LOWER()``.
    Other vendors get the operands unchanged.
    """

    def process_lhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Any = None
    ) -> tuple[str, list[Any]]:
        sql, params = super().process_lhs(compiler, connection, lhs)
        if connection.vendor != "mysql":
            return sql, params
        return f"LOWER({sql})", params

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        rhs, rhs_params = super().process_rhs(compiler, connection)
        # Only a single SQL string can be wrapped; lists pass through.
        if isinstance(rhs, str) and connection.vendor == "mysql":
            return f"LOWER({rhs})", rhs_params
        return rhs, rhs_params
335
336
class JSONExact(lookups.Exact):
    """Exact match on a whole JSON value; ``None`` matches JSON ``null``."""

    can_use_none_as_rhs = True

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """Map a ``None`` parameter to ``"null"``; on MySQL wrap placeholders.

        On MySQL each placeholder becomes ``JSON_EXTRACT(%s, '$')``.
        Non-string rhs values are returned untouched.
        """
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if not isinstance(rhs, str):
            return rhs, rhs_params
        # Treat None lookup values as null.
        if rhs == "%s" and rhs_params == [None]:
            rhs_params = ["null"]
        if connection.vendor == "mysql":
            rhs = rhs % (("JSON_EXTRACT(%s, '$')",) * len(rhs_params))
        return rhs, rhs_params
354
355
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
    """Case-insensitive ``icontains`` on a JSON field (LOWER() on MySQL)."""
358
359
# Lookups available directly on a JSONField (e.g. ``field__has_key``),
# registered in this order.
for _json_lookup in (
    DataContains,
    ContainedBy,
    HasKey,
    HasKeys,
    HasAnyKeys,
    JSONExact,
    JSONIContains,
):
    JSONField.register_lookup(_json_lookup)
del _json_lookup
367
368
class KeyTransform(Transform):
    """Extract the value stored under ``key_name`` from a JSON expression.

    Chained transforms (``field__a__b``) are collapsed into a single key
    list by ``preprocess_lhs``. Each backend method then renders that list
    as a JSON path or operator chain.
    """

    postgres_operator = "->"
    postgres_nested_operator = "#>"

    def __init__(self, key_name: str, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Keys are stored as strings; integer-like strings are later treated
        # as array indexes by compile_json_path().
        self.key_name = str(key_name)

    def preprocess_lhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...], list[str]]:
        """Return ``(sql, params, keys)`` for the innermost non-key expression.

        ``keys`` lists the key names from the outermost KeyTransform ancestor
        down to this one.
        """
        keys: list[str] = []
        node: Any = self
        while isinstance(node, KeyTransform):
            keys.append(node.key_name)
            node = node.lhs
        keys.reverse()
        base_sql, base_params = compiler.compile(node)
        return base_sql, base_params, keys

    def as_mysql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """MySQL: ``JSON_EXTRACT(doc, path)``."""
        base_sql, base_params, keys = self.preprocess_lhs(compiler, connection)
        return f"JSON_EXTRACT({base_sql}, %s)", (*base_params, compile_json_path(keys))

    def as_postgresql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """PostgreSQL: ``#>`` with a key array when nested, else ``->``.

        A single integer-like key is passed as an ``int`` (array index).
        """
        base_sql, base_params, keys = self.preprocess_lhs(compiler, connection)
        if len(keys) > 1:
            return (
                f"({base_sql} {self.postgres_nested_operator} %s)",
                (*base_params, keys),
            )
        try:
            key_param: int | str = int(self.key_name)
        except ValueError:
            key_param = self.key_name
        return f"({base_sql} {self.postgres_operator} %s)", (*base_params, key_param)

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        """SQLite: return ``JSON_TYPE`` for the listed datatypes, else ``JSON_EXTRACT``.

        When ``JSON_TYPE(doc, path)`` is one of
        ``connection.ops.jsonfield_datatype_values``, that type name is the
        result; otherwise ``JSON_EXTRACT(doc, path)`` is.
        """
        # as_sqlite is only called for SQLite connections
        assert is_sqlite_connection(connection)
        base_sql, base_params, keys = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(keys)
        datatype_values = ",".join(
            repr(datatype) for datatype in connection.ops.jsonfield_datatype_values
        )
        sql = (
            f"(CASE WHEN JSON_TYPE({base_sql}, %s) IN ({datatype_values}) "
            f"THEN JSON_TYPE({base_sql}, %s) ELSE JSON_EXTRACT({base_sql}, %s) END)"
        )
        # The document SQL appears three times, each followed by the path.
        return sql, (*base_params, json_path) * 3
425
426
class KeyTextTransform(KeyTransform):
    """KeyTransform variant that yields the extracted value as text.

    Uses the PostgreSQL text operators ``->>`` / ``#>>`` and a ``TextField``
    output field.
    """

    postgres_operator = "->>"
    postgres_nested_operator = "#>>"
    output_field = TextField()

    def as_mysql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """MySQL: ``doc ->> path``; MariaDB: ``JSON_UNQUOTE(JSON_EXTRACT(...))``."""
        # as_mysql is only called for MySQL connections
        assert is_mysql_connection(connection)
        if not connection.mysql_is_mariadb:
            base_sql, base_params, keys = self.preprocess_lhs(compiler, connection)
            return f"({base_sql} ->> %s)", (*base_params, compile_json_path(keys))
        # MariaDB doesn't support -> and ->> operators (see MDEV-13594).
        extract_sql, extract_params = super().as_mysql(compiler, connection)
        return f"JSON_UNQUOTE({extract_sql})", extract_params

    @classmethod
    def from_lookup(cls, lookup: str) -> Any:
        """Build nested transforms from ``"field__key__subkey"``.

        Raises:
            ValueError: if ``lookup`` names only a field and no keys.
        """
        field_name, *keys = lookup.split(LOOKUP_SEP)
        if not keys:
            raise ValueError("Lookup must contain key or index transforms.")
        expression: Any = field_name
        for key in keys:
            expression = cls(key, expression)
        return expression


# Shorthand: KT("field__key__subkey") builds a KeyTextTransform chain.
KT = KeyTextTransform.from_lookup
457
458
class KeyTransformTextLookupMixin(Lookup):
    """
    Mixin for lookups that compare a JSONField key's value as text.

    The KeyTransform lhs is replaced by an equivalent KeyTextTransform (same
    key, source expressions and extra). On PostgreSQL this uses the ->>
    operator instead of casting key values to text.
    """

    def __init__(self, key_transform: Any, *args: Any, **kwargs: Any):
        """Wrap ``key_transform`` as text, then initialize the lookup.

        Raises:
            TypeError: if ``key_transform`` is not a KeyTransform.
        """
        if not isinstance(key_transform, KeyTransform):
            raise TypeError(
                "Transform should be an instance of KeyTransform in order to "
                "use this lookup."
            )
        as_text = KeyTextTransform(
            key_transform.key_name,
            *key_transform.source_expressions,
            **key_transform.extra,
        )
        super().__init__(as_text, *args, **kwargs)
479
480
class KeyTransformIsNull(lookups.IsNull):
    """``isnull`` on a JSON key; on SQLite, tests whether the key is absent.

    key__isnull=False is the same as has_key='key'.
    """

    def as_sqlite(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """Reuse HasKeyOrArrayIndex with an IS NULL / IS NOT NULL template."""
        comparison = "IS NULL" if self.rhs else "IS NOT NULL"
        template = f"JSON_TYPE(%s, %%s) {comparison}"
        presence = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
        return presence.as_sql(compiler, connection, template=template)
494
495
class KeyTransformIn(lookups.In):
    """``in`` lookup on a JSON key, decoding literal values where needed."""

    def resolve_expression_parameter(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        sql: str,
        param: Any,
    ) -> tuple[str, list[Any]]:
        """Adjust the per-value SQL for backends without native JSON columns.

        Literal values (no ``as_sql``) become ``JSON_EXTRACT(%s, '$')`` on
        MySQL. On SQLite they do too, unless the value is one of
        ``jsonfield_datatype_values``. On MariaDB the result is also wrapped
        in ``JSON_UNQUOTE``.
        """
        sql, params = super().resolve_expression_parameter(
            compiler,
            connection,
            sql,
            param,
        )
        is_literal = not hasattr(param, "as_sql")
        if is_literal and not connection.features.has_native_json_field:
            if connection.vendor == "mysql":
                sql = "JSON_EXTRACT(%s, '$')"
            elif (
                is_sqlite_connection(connection)
                and params[0] not in connection.ops.jsonfield_datatype_values
            ):
                sql = "JSON_EXTRACT(%s, '$')"
        if is_mysql_connection(connection) and connection.mysql_is_mariadb:
            sql = f"JSON_UNQUOTE({sql})"
        return sql, list(params)
525
526
class KeyTransformExact(JSONExact):
    """Exact match of a JSON key's value against a value or another key."""

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """Compile the rhs; on SQLite, decode non-datatype params via JSON_EXTRACT."""
        if isinstance(self.rhs, KeyTransform):
            # Key-to-key comparison: super(lookups.Exact, ...) resumes the MRO
            # after lookups.Exact, skipping JSONExact's null/JSON_EXTRACT
            # handling.
            return super(lookups.Exact, self).process_rhs(compiler, connection)
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if not (isinstance(rhs, str) and is_sqlite_connection(connection)):
            return rhs, rhs_params
        # Params equal to one of the SQLite datatype names stay raw; any
        # other value is compared as extracted JSON.
        raw_values = connection.ops.jsonfield_datatype_values
        placeholders = tuple(
            "%s" if value in raw_values else "JSON_EXTRACT(%s, '$')"
            for value in rhs_params
        )
        return rhs % placeholders, rhs_params
547
548
class KeyTransformIExact(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
    """Case-insensitive exact match on a JSON key's text value."""
553
554
class KeyTransformIContains(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
    """Case-insensitive substring match on a JSON key's text value."""
559
560
class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    """Prefix match on a JSON key's text value."""
563
564
class KeyTransformIStartsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
    """Case-insensitive prefix match on a JSON key's text value."""
569
570
class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    """Suffix match on a JSON key's text value."""
573
574
class KeyTransformIEndsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
    """Case-insensitive suffix match on a JSON key's text value."""
579
580
class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    """Regular-expression match on a JSON key's text value."""
583
584
class KeyTransformIRegex(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
    """Case-insensitive regular-expression match on a JSON key's text value."""
589
590
class KeyTransformNumericLookupMixin(Lookup):
    """Mixin for ordering comparisons (<, <=, >, >=) on a JSON key's value."""

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """JSON-decode the rhs params on backends without native JSON columns.

        NOTE(review): the params are presumably JSON text produced by the
        field's get_db_prep_value (adapt_json_value); decoding them lets the
        comparison use the underlying numbers — confirm.
        """
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if isinstance(rhs, str) and not connection.features.has_native_json_field:
            rhs_params = list(map(json.loads, rhs_params))
        return rhs, rhs_params
602
603
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
    """Less-than comparison on a JSON key's value."""
606
607
class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
    """Less-than-or-equal comparison on a JSON key's value."""
610
611
class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
    """Greater-than comparison on a JSON key's value."""
614
615
class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
    """Greater-than-or-equal comparison on a JSON key's value."""
618
619
# Lookups available after a key transform (e.g. ``field__key__icontains``),
# registered in this order: text/equality lookups, then numeric ones.
for _key_lookup in (
    KeyTransformIn,
    KeyTransformExact,
    KeyTransformIExact,
    KeyTransformIsNull,
    KeyTransformIContains,
    KeyTransformStartsWith,
    KeyTransformIStartsWith,
    KeyTransformEndsWith,
    KeyTransformIEndsWith,
    KeyTransformRegex,
    KeyTransformIRegex,
    KeyTransformLt,
    KeyTransformLte,
    KeyTransformGt,
    KeyTransformGte,
):
    KeyTransform.register_lookup(_key_lookup)
del _key_lookup
636
637
class KeyTransformFactory:
    """Callable that builds KeyTransform instances for a fixed key name.

    JSONField.get_transform() returns one of these for lookup names that are
    not registered transforms, so such names are treated as JSON keys.
    """

    def __init__(self, key_name: str):
        self.key_name = key_name

    def __call__(self, *expressions: Any, **options: Any) -> KeyTransform:
        """Return ``KeyTransform(self.key_name, *expressions, **options)``."""
        return KeyTransform(self.key_name, *expressions, **options)