Plain is headed towards 1.0! Subscribe for development updates →

  1import json
  2
  3from plain import exceptions, preflight
  4from plain.models import expressions, lookups
  5from plain.models.constants import LOOKUP_SEP
  6from plain.models.db import NotSupportedError, db_connection
  7from plain.models.fields import TextField
  8from plain.models.lookups import (
  9    FieldGetDbPrepValueMixin,
 10    PostgresOperatorLookup,
 11    Transform,
 12)
 13
 14from . import Field
 15from .mixins import CheckFieldDefaultMixin
 16
# Public API of this module; the lookup/transform classes below are
# implementation details registered onto JSONField.
__all__ = ["JSONField"]
 18
 19
 20class JSONField(CheckFieldDefaultMixin, Field):
 21    empty_strings_allowed = False
 22    description = "A JSON object"
 23    default_error_messages = {
 24        "invalid": "Value must be valid JSON.",
 25    }
 26    _default_hint = ("dict", "{}")
 27
 28    def __init__(
 29        self,
 30        *,
 31        encoder=None,
 32        decoder=None,
 33        **kwargs,
 34    ):
 35        if encoder and not callable(encoder):
 36            raise ValueError("The encoder parameter must be a callable object.")
 37        if decoder and not callable(decoder):
 38            raise ValueError("The decoder parameter must be a callable object.")
 39        self.encoder = encoder
 40        self.decoder = decoder
 41        super().__init__(**kwargs)
 42
 43    def check(self, **kwargs):
 44        errors = super().check(**kwargs)
 45        database = kwargs.get("database", False)
 46        errors.extend(self._check_supported(database))
 47        return errors
 48
 49    def _check_supported(self, database):
 50        errors = []
 51        if database:
 52            if (
 53                self.model._meta.required_db_vendor
 54                and self.model._meta.required_db_vendor != db_connection.vendor
 55            ):
 56                return errors
 57            if not (
 58                "supports_json_field" in self.model._meta.required_db_features
 59                or db_connection.features.supports_json_field
 60            ):
 61                errors.append(
 62                    preflight.Error(
 63                        f"{db_connection.display_name} does not support JSONFields.",
 64                        obj=self.model,
 65                        id="fields.E180",
 66                    )
 67                )
 68        return errors
 69
 70    def deconstruct(self):
 71        name, path, args, kwargs = super().deconstruct()
 72        if self.encoder is not None:
 73            kwargs["encoder"] = self.encoder
 74        if self.decoder is not None:
 75            kwargs["decoder"] = self.decoder
 76        return name, path, args, kwargs
 77
 78    def from_db_value(self, value, expression, connection):
 79        if value is None:
 80            return value
 81        # Some backends (SQLite at least) extract non-string values in their
 82        # SQL datatypes.
 83        if isinstance(expression, KeyTransform) and not isinstance(value, str):
 84            return value
 85        try:
 86            return json.loads(value, cls=self.decoder)
 87        except json.JSONDecodeError:
 88            return value
 89
 90    def get_internal_type(self):
 91        return "JSONField"
 92
 93    def get_db_prep_value(self, value, connection, prepared=False):
 94        if isinstance(value, expressions.Value) and isinstance(
 95            value.output_field, JSONField
 96        ):
 97            value = value.value
 98        elif hasattr(value, "as_sql"):
 99            return value
100        return connection.ops.adapt_json_value(value, self.encoder)
101
102    def get_db_prep_save(self, value, connection):
103        if value is None:
104            return value
105        return self.get_db_prep_value(value, connection)
106
107    def get_transform(self, name):
108        transform = super().get_transform(name)
109        if transform:
110            return transform
111        return KeyTransformFactory(name)
112
113    def validate(self, value, model_instance):
114        super().validate(value, model_instance)
115        try:
116            json.dumps(value, cls=self.encoder)
117        except TypeError:
118            raise exceptions.ValidationError(
119                self.error_messages["invalid"],
120                code="invalid",
121                params={"value": value},
122            )
123
124    def value_to_string(self, obj):
125        return self.value_from_object(obj)
126
127
def compile_json_path(key_transforms, include_root=True):
    """Build a JSONPath string (e.g. ``$."a"[0]``) from a list of keys.

    Keys that parse as integers become array-index segments; everything
    else becomes a quoted member accessor.
    """
    segments = ["$"] if include_root else []
    for key in key_transforms:
        try:
            index = int(key)
        except ValueError:  # non-integer key name
            segments.extend((".", json.dumps(key)))
        else:
            segments.append(f"[{index}]")
    return "".join(segments)
139
140
class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """``field__contains`` — lhs JSON document contains the rhs fragment."""

    lookup_name = "contains"
    postgres_operator = "@>"

    def as_sql(self, compiler, connection):
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contains lookup is not supported on this database backend."
            )
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        return f"JSON_CONTAINS({lhs_sql}, {rhs_sql})", (*lhs_params, *rhs_params)
154
155
class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """``field__contained_by`` — lhs is contained in the rhs JSON document."""

    lookup_name = "contained_by"
    postgres_operator = "<@"

    def as_sql(self, compiler, connection):
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contained_by lookup is not supported on this database backend."
            )
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        # Containment is inverted relative to DataContains: rhs contains lhs.
        return f"JSON_CONTAINS({rhs_sql}, {lhs_sql})", (*rhs_params, *lhs_params)
169
170
class HasKeyLookup(PostgresOperatorLookup):
    """Base class for has_key / has_keys / has_any_keys lookups."""

    # Joins per-key conditions on non-PostgreSQL backends (" AND " / " OR ").
    logical_operator = None

    def compile_json_path_final_key(self, key_transform):
        # The final key is always a member name — integers are quoted here,
        # never turned into array indexes.
        return f".{json.dumps(key_transform)}"

    def as_sql(self, compiler, connection, template=None):
        # Left-hand side: fold chained key transforms into a JSON path prefix.
        if isinstance(self.lhs, KeyTransform):
            lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
                compiler, connection
            )
            lhs_json_path = compile_json_path(lhs_key_transforms)
        else:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            lhs_json_path = "$"
        sql = template % lhs
        # Right-hand side: one JSON-path parameter per requested key.
        keys = self.rhs if isinstance(self.rhs, list | tuple) else [self.rhs]
        rhs_params = []
        for key in keys:
            if isinstance(key, KeyTransform):
                *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
            else:
                rhs_key_transforms = [key]
            *prefix_keys, final_key = rhs_key_transforms
            rhs_json_path = compile_json_path(prefix_keys, include_root=False)
            rhs_json_path += self.compile_json_path_final_key(final_key)
            rhs_params.append(lhs_json_path + rhs_json_path)
        # Repeat the template once per key, joined with AND/OR when required.
        if self.logical_operator:
            sql = f"({self.logical_operator.join([sql] * len(rhs_params))})"
        return sql, tuple(lhs_params) + tuple(rhs_params)

    def as_mysql(self, compiler, connection):
        return self.as_sql(
            compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
        )

    def as_postgresql(self, compiler, connection):
        # PostgreSQL's key-existence operators act on the parent object, so
        # move all but the last rhs key transform onto the lhs.
        if isinstance(self.rhs, KeyTransform):
            *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
            for key in rhs_key_transforms[:-1]:
                self.lhs = KeyTransform(key, self.lhs)
            self.rhs = rhs_key_transforms[-1]
        return super().as_postgresql(compiler, connection)

    def as_sqlite(self, compiler, connection):
        return self.as_sql(
            compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
        )
225
226
class HasKey(HasKeyLookup):
    # field__has_key="name" — True when the JSON value has the given key.
    lookup_name = "has_key"
    postgres_operator = "?"
    # NOTE(review): prepare_rhs=False appears to keep the key name from being
    # adapted as a JSON value — confirm against the base Lookup implementation.
    prepare_rhs = False
231
232
class HasKeys(HasKeyLookup):
    """``field__has_keys=[...]`` — every listed key must be present."""

    lookup_name = "has_keys"
    postgres_operator = "?&"
    logical_operator = " AND "

    def get_prep_lookup(self):
        # Key names are always compared as strings.
        return list(map(str, self.rhs))
240
241
class HasAnyKeys(HasKeys):
    # field__has_any_keys=[...] — at least one listed key must be present.
    # Inherits string coercion of key names from HasKeys.get_prep_lookup().
    lookup_name = "has_any_keys"
    postgres_operator = "?|"
    logical_operator = " OR "
246
247
class HasKeyOrArrayIndex(HasKey):
    # Variant of HasKey whose final key may be an array index: integer keys
    # compile to "[n]" instead of a quoted member name (used by isnull below).
    def compile_json_path_final_key(self, key_transform):
        return compile_json_path([key_transform], include_root=False)
251
252
class CaseInsensitiveMixin:
    """
    Mixin to allow case-insensitive comparison of JSON values on MySQL.
    MySQL handles strings used in JSON context using the utf8mb4_bin collation.
    Because utf8mb4_bin is a binary collation, comparison of JSON values is
    case-sensitive.
    """

    def process_lhs(self, compiler, connection):
        sql, params = super().process_lhs(compiler, connection)
        if connection.vendor != "mysql":
            return sql, params
        return f"LOWER({sql})", params

    def process_rhs(self, compiler, connection):
        sql, params = super().process_rhs(compiler, connection)
        if connection.vendor != "mysql":
            return sql, params
        return f"LOWER({sql})", params
272
273
class JSONExact(lookups.Exact):
    """Exact match comparing values as JSON documents."""

    can_use_none_as_rhs = True

    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        # A Python None on the rhs means the JSON null literal, not SQL NULL.
        if rhs == "%s" and rhs_params == [None]:
            rhs_params = ["null"]
        if connection.vendor == "mysql":
            # Wrap each placeholder so the rhs is compared as JSON.
            rhs %= tuple("JSON_EXTRACT(%s, '$')" for _ in rhs_params)
        return rhs, rhs_params
286
287
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
    # Case-insensitive containment; behavior comes entirely from the mixin
    # (MySQL LOWER() wrapping) and the base IContains lookup.
    pass
290
291
# Make the lookups above available on JSONField, e.g.
# .filter(data__contains=...), .filter(data__has_key="k").
JSONField.register_lookup(DataContains)
JSONField.register_lookup(ContainedBy)
JSONField.register_lookup(HasKey)
JSONField.register_lookup(HasKeys)
JSONField.register_lookup(HasAnyKeys)
JSONField.register_lookup(JSONExact)
JSONField.register_lookup(JSONIContains)
299
300
class KeyTransform(Transform):
    """Transform extracting a key or index from a JSON value (``data__key``)."""

    postgres_operator = "->"
    postgres_nested_operator = "#>"

    def __init__(self, key_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.key_name = str(key_name)

    def preprocess_lhs(self, compiler, connection):
        """Unwind nested KeyTransforms into (lhs_sql, params, key list)."""
        key_transforms = [self.key_name]
        expression = self.lhs
        # Walk inward, collecting key names so the outermost key ends up last.
        while isinstance(expression, KeyTransform):
            key_transforms.insert(0, expression.key_name)
            expression = expression.lhs
        lhs, params = compiler.compile(expression)
        return lhs, params, key_transforms

    def as_mysql(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        path = compile_json_path(key_transforms)
        return f"JSON_EXTRACT({lhs}, %s)", (*params, path)

    def as_postgresql(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        if len(key_transforms) > 1:
            # Multiple keys use the nested-path operator with a path-array param.
            sql = f"({lhs} {self.postgres_nested_operator} %s)"
            return sql, (*params, key_transforms)
        # Single key: integer-looking keys are passed as array indexes.
        try:
            lookup = int(self.key_name)
        except ValueError:
            lookup = self.key_name
        return f"({lhs} {self.postgres_operator} %s)", (*params, lookup)

    def as_sqlite(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        path = compile_json_path(key_transforms)
        datatype_values = ",".join(
            repr(datatype) for datatype in connection.ops.jsonfield_datatype_values
        )
        # The same lhs/path pair is interpolated three times in the CASE below.
        sql = (
            f"(CASE WHEN JSON_TYPE({lhs}, %s) IN ({datatype_values}) "
            f"THEN JSON_TYPE({lhs}, %s) ELSE JSON_EXTRACT({lhs}, %s) END)"
        )
        return sql, (*params, path) * 3
344
345
class KeyTextTransform(KeyTransform):
    """KeyTransform variant that yields text instead of JSON."""

    postgres_operator = "->>"
    postgres_nested_operator = "#>>"
    output_field = TextField()

    def as_mysql(self, compiler, connection):
        if not connection.mysql_is_mariadb:
            lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
            json_path = compile_json_path(key_transforms)
            return f"({lhs} ->> %s)", (*params, json_path)
        # MariaDB doesn't support -> and ->> operators (see MDEV-13594);
        # emulate ->> by unquoting the JSON_EXTRACT result.
        sql, params = super().as_mysql(compiler, connection)
        return f"JSON_UNQUOTE({sql})", params

    @classmethod
    def from_lookup(cls, lookup):
        """Build nested text transforms from a LOOKUP_SEP-joined string."""
        transform, *keys = lookup.split(LOOKUP_SEP)
        if not keys:
            raise ValueError("Lookup must contain key or index transforms.")
        # The field name seeds the chain; each key wraps the previous result.
        for key in keys:
            transform = cls(key, transform)
        return transform
369
370
# Convenience alias: KT("field__key__0") builds a nested KeyTextTransform.
KT = KeyTextTransform.from_lookup
372
373
class KeyTransformTextLookupMixin:
    """
    Mixin for combining with a lookup expecting a text lhs from a JSONField
    key lookup. On PostgreSQL, make use of the ->> operator instead of casting
    key values to text and performing the lookup on the resulting
    representation.
    """

    def __init__(self, key_transform, *args, **kwargs):
        if not isinstance(key_transform, KeyTransform):
            raise TypeError(
                "Transform should be an instance of KeyTransform in order to "
                "use this lookup."
            )
        # Swap the JSON-returning transform for its text-returning equivalent.
        text_transform = KeyTextTransform(
            key_transform.key_name,
            *key_transform.source_expressions,
            **key_transform.extra,
        )
        super().__init__(text_transform, *args, **kwargs)
394
395
class KeyTransformIsNull(lookups.IsNull):
    # key__isnull=False is the same as has_key='key'
    def as_sqlite(self, compiler, connection):
        if self.rhs:
            # isnull=True: the key must be absent.
            template = "JSON_TYPE(%s, %%s) IS NULL"
        else:
            # isnull=False: the key (or array index) must exist.
            template = "JSON_TYPE(%s, %%s) IS NOT NULL"
        has_key = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
        return has_key.as_sql(compiler, connection, template=template)
407
408
class KeyTransformIn(lookups.In):
    def resolve_expression_parameter(self, compiler, connection, sql, param):
        """Adjust each rhs parameter so it's compared as JSON where needed."""
        sql, params = super().resolve_expression_parameter(
            compiler,
            connection,
            sql,
            param,
        )
        # Expressions render their own SQL; backends with a native JSON
        # column type compare correctly without extra casting.
        if (
            not hasattr(param, "as_sql")
            and not connection.features.has_native_json_field
        ):
            # NOTE(review): on SQLite, params already expressed via
            # jsonfield_datatype_values presumably must not be re-extracted —
            # confirm against the backend's ops implementation.
            if connection.vendor == "mysql" or (
                connection.vendor == "sqlite"
                and params[0] not in connection.ops.jsonfield_datatype_values
            ):
                sql = "JSON_EXTRACT(%s, '$')"
        # MariaDB returns quoted JSON strings; unquote before comparing.
        if connection.vendor == "mysql" and connection.mysql_is_mariadb:
            sql = f"JSON_UNQUOTE({sql})"
        return sql, params
429
430
class KeyTransformExact(JSONExact):
    """Exact match against a value extracted by a KeyTransform."""

    def process_rhs(self, compiler, connection):
        if isinstance(self.rhs, KeyTransform):
            # Comparing two key transforms: skip JSONExact's rhs rewriting
            # and use the plain Exact behavior.
            return super(lookups.Exact, self).process_rhs(compiler, connection)
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor == "sqlite":
            # Values expressed via jsonfield_datatype_values stay bare;
            # everything else is normalized through JSON_EXTRACT.
            placeholders = [
                "%s"
                if value in connection.ops.jsonfield_datatype_values
                else "JSON_EXTRACT(%s, '$')"
                for value in rhs_params
            ]
            rhs %= tuple(placeholders)
        return rhs, rhs_params
445
446
class KeyTransformIExact(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
    """Case-insensitive exact match on a JSON key's text value."""

    pass
451
452
class KeyTransformIContains(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
    """Case-insensitive containment on a JSON key's text value."""

    pass
457
458
class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    """``startswith`` on a JSON key's text value."""

    pass
461
462
class KeyTransformIStartsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
    """Case-insensitive ``startswith`` on a JSON key's text value."""

    pass
467
468
class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    """``endswith`` on a JSON key's text value."""

    pass
471
472
class KeyTransformIEndsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
    """Case-insensitive ``endswith`` on a JSON key's text value."""

    pass
477
478
class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    """Regex match on a JSON key's text value."""

    pass
481
482
class KeyTransformIRegex(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
    """Case-insensitive regex match on a JSON key's text value."""

    pass
487
488
class KeyTransformNumericLookupMixin:
    """Decode JSON-encoded rhs params on backends without a native JSON type
    so numeric comparisons operate on real numbers."""

    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.features.has_native_json_field:
            return rhs, rhs_params
        return rhs, [json.loads(value) for value in rhs_params]
495
496
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
    """Numeric ``<`` on a JSON key's value."""

    pass
499
500
class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
    """Numeric ``<=`` on a JSON key's value."""

    pass
503
504
class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
    """Numeric ``>`` on a JSON key's value."""

    pass
507
508
class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
    """Numeric ``>=`` on a JSON key's value."""

    pass
511
512
# Register key-level lookups so they chain after a key transform, e.g.
# .filter(data__key__icontains="x") or .filter(data__key__isnull=True).
KeyTransform.register_lookup(KeyTransformIn)
KeyTransform.register_lookup(KeyTransformExact)
KeyTransform.register_lookup(KeyTransformIExact)
KeyTransform.register_lookup(KeyTransformIsNull)
KeyTransform.register_lookup(KeyTransformIContains)
KeyTransform.register_lookup(KeyTransformStartsWith)
KeyTransform.register_lookup(KeyTransformIStartsWith)
KeyTransform.register_lookup(KeyTransformEndsWith)
KeyTransform.register_lookup(KeyTransformIEndsWith)
KeyTransform.register_lookup(KeyTransformRegex)
KeyTransform.register_lookup(KeyTransformIRegex)

# Numeric comparison lookups.
KeyTransform.register_lookup(KeyTransformLt)
KeyTransform.register_lookup(KeyTransformLte)
KeyTransform.register_lookup(KeyTransformGt)
KeyTransform.register_lookup(KeyTransformGte)
529
530
class KeyTransformFactory:
    """Callable that builds a KeyTransform for a fixed key name.

    Returned by JSONField.get_transform() for unrecognized lookup names, so
    the ORM can apply the key later once the lhs expression is known.
    """

    def __init__(self, key_name):
        # Kept as-is; KeyTransform.__init__ coerces it to str.
        self.key_name = key_name

    def __call__(self, *args, **kwargs):
        return KeyTransform(self.key_name, *args, **kwargs)