1import json
2
3from plain import exceptions, preflight
4from plain.models import expressions, lookups
5from plain.models.constants import LOOKUP_SEP
6from plain.models.db import NotSupportedError, db_connection
7from plain.models.fields import TextField
8from plain.models.lookups import (
9 FieldGetDbPrepValueMixin,
10 PostgresOperatorLookup,
11 Transform,
12)
13
14from . import Field
15from .mixins import CheckFieldDefaultMixin
16
17__all__ = ["JSONField"]
18
19
class JSONField(CheckFieldDefaultMixin, Field):
    """Model field storing JSON-serializable Python values.

    ``encoder`` and ``decoder`` are optional callables passed as ``cls`` to
    ``json.dumps`` / ``json.loads``. The encoder is also handed to
    ``connection.ops.adapt_json_value`` when preparing values for the
    database. Transform names that are not registered on the field are
    treated as JSON key/index lookups (see ``get_transform``).
    """

    empty_strings_allowed = False
    description = "A JSON object"
    default_error_messages = {
        "invalid": "Value must be valid JSON.",
    }
    # Hint consumed by CheckFieldDefaultMixin's default-value check.
    _default_hint = ("dict", "{}")

    def __init__(
        self,
        *,
        encoder=None,
        decoder=None,
        **kwargs,
    ):
        """Store the optional ``encoder`` / ``decoder`` callables.

        Raises:
            ValueError: if a truthy ``encoder`` or ``decoder`` is not callable.
        """
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        if decoder and not callable(decoder):
            raise ValueError("The decoder parameter must be a callable object.")
        self.encoder = encoder
        self.decoder = decoder
        super().__init__(**kwargs)

    def check(self, **kwargs):
        """Return the base field's preflight errors plus the JSON-support check."""
        errors = super().check(**kwargs)
        database = kwargs.get("database", False)
        errors.extend(self._check_supported(database))
        return errors

    def _check_supported(self, database):
        """Return a ``fields.E180`` error if the connection lacks JSON support.

        The check only runs when ``database`` is truthy. It is skipped for
        models pinned via ``required_db_vendor`` to a vendor other than the
        active connection's. Listing ``supports_json_field`` in the model's
        ``required_db_features`` also suppresses the error.
        """
        errors = []
        if database:
            if (
                self.model._meta.required_db_vendor
                and self.model._meta.required_db_vendor != db_connection.vendor
            ):
                return errors
            if not (
                "supports_json_field" in self.model._meta.required_db_features
                or db_connection.features.supports_json_field
            ):
                errors.append(
                    preflight.Error(
                        f"{db_connection.display_name} does not support JSONFields.",
                        obj=self.model,
                        id="fields.E180",
                    )
                )
        return errors

    def deconstruct(self):
        """Extend the base deconstruction with any non-None encoder/decoder."""
        name, path, args, kwargs = super().deconstruct()
        if self.encoder is not None:
            kwargs["encoder"] = self.encoder
        if self.decoder is not None:
            kwargs["decoder"] = self.decoder
        return name, path, args, kwargs

    def from_db_value(self, value, expression, connection):
        """Decode a value loaded from the database.

        ``None`` is returned as-is. Non-string results of a ``KeyTransform``
        are returned unchanged. Anything else is parsed with ``json.loads``
        using the field's decoder; values that are not valid JSON are
        returned verbatim.
        """
        if value is None:
            return value
        # Some backends (SQLite at least) extract non-string values in their
        # SQL datatypes.
        if isinstance(expression, KeyTransform) and not isinstance(value, str):
            return value
        try:
            return json.loads(value, cls=self.decoder)
        except json.JSONDecodeError:
            return value

    def get_internal_type(self):
        return "JSONField"

    def get_db_prep_value(self, value, connection, prepared=False):
        """Adapt ``value`` for the database via ``connection.ops.adapt_json_value``.

        Unless ``prepared`` is true, ``get_prep_value()`` is applied first.
        This means subclasses overriding it are honoured, including on save,
        because ``get_db_prep_save`` calls this method with the default
        ``prepared=False``.

        A ``Value`` whose output field is a JSONField is unwrapped to its raw
        value. Any other SQL expression (anything with ``as_sql``) is
        returned untouched.
        """
        if not prepared:
            value = self.get_prep_value(value)
        if isinstance(value, expressions.Value) and isinstance(
            value.output_field, JSONField
        ):
            value = value.value
        elif hasattr(value, "as_sql"):
            return value
        return connection.ops.adapt_json_value(value, self.encoder)

    def get_db_prep_save(self, value, connection):
        """Return ``None`` unchanged; otherwise delegate to ``get_db_prep_value``."""
        if value is None:
            return value
        return self.get_db_prep_value(value, connection)

    def get_transform(self, name):
        """Return a registered transform, or a key transform for ``name``."""
        transform = super().get_transform(name)
        if transform:
            return transform
        return KeyTransformFactory(name)

    def validate(self, value, model_instance):
        """Run base validation, then check that ``value`` serializes to JSON.

        Raises:
            ValidationError: with code ``"invalid"`` if ``json.dumps`` (using
                the field's encoder) rejects the value. ``json.dumps`` raises
                ``TypeError`` for unserializable objects and ``ValueError``
                for circular references.
        """
        super().validate(value, model_instance)
        try:
            json.dumps(value, cls=self.encoder)
        except (TypeError, ValueError):
            raise exceptions.ValidationError(
                self.error_messages["invalid"],
                code="invalid",
                params={"value": value},
            )

    def value_to_string(self, obj):
        # Returns whatever value_from_object() yields; the value is not
        # serialized to a JSON string here.
        return self.value_from_object(obj)
126
127
def _json_path_array_index(key_transform):
    """Return ``key_transform`` as an int if it denotes an array index, else None.

    A string only counts as an index when it is plain ASCII decimal digits
    with an optional leading ``-``. ``int()`` alone would also accept
    surrounding whitespace, a ``+`` sign, digit-group underscores (``"1_0"``)
    and non-ASCII digits, all of which are legitimate JSON object keys.
    Non-string keys go straight through ``int()``, as before.
    """
    if isinstance(key_transform, str):
        digits = key_transform.removeprefix("-")
        if not (digits.isascii() and digits.isdigit()):
            return None
    try:
        return int(key_transform)
    except ValueError:  # non-integer
        return None


def compile_json_path(key_transforms, include_root=True):
    """Build a JSON path string (as used by the MySQL/SQLite SQL in this module).

    Integer-like keys become array subscripts (``[0]``, ``[-1]``). Every
    other key becomes a quoted member (``."name"``), escaped with
    ``json.dumps``. With ``include_root`` the path starts with ``$``.

    >>> compile_json_path(["a", "0"])
    '$."a"[0]'
    """
    path = ["$"] if include_root else []
    for key_transform in key_transforms:
        num = _json_path_array_index(key_transform)
        if num is None:
            path.append(".")
            path.append(json.dumps(key_transform))
        else:
            path.append(f"[{num}]")
    return "".join(path)
139
140
class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """``contains`` lookup: the column's JSON contains the rhs JSON value.

    ``postgres_operator`` is ``@>``. On other backends the lookup is rendered
    as ``JSON_CONTAINS(lhs, rhs)``, provided the backend reports
    ``supports_json_field_contains``.
    """

    lookup_name = "contains"
    postgres_operator = "@>"

    def as_sql(self, compiler, connection):
        supported = connection.features.supports_json_field_contains
        if not supported:
            raise NotSupportedError(
                "contains lookup is not supported on this database backend."
            )
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        sql = f"JSON_CONTAINS({lhs_sql}, {rhs_sql})"
        return sql, (*lhs_params, *rhs_params)
154
155
class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """``contained_by`` lookup: the column's JSON is contained in the rhs value.

    ``postgres_operator`` is ``<@``. On other backends the lookup is rendered
    as ``JSON_CONTAINS(rhs, lhs)``, i.e. with the operands swapped relative
    to ``contains``.
    """

    lookup_name = "contained_by"
    postgres_operator = "<@"

    def as_sql(self, compiler, connection):
        if connection.features.supports_json_field_contains:
            lhs_sql, lhs_params = self.process_lhs(compiler, connection)
            rhs_sql, rhs_params = self.process_rhs(compiler, connection)
            # The rhs is the first JSON_CONTAINS argument, so its params lead.
            return (
                f"JSON_CONTAINS({rhs_sql}, {lhs_sql})",
                (*rhs_params, *lhs_params),
            )
        raise NotSupportedError(
            "contained_by lookup is not supported on this database backend."
        )
169
170
class HasKeyLookup(PostgresOperatorLookup):
    """Base class for the has_key / has_keys / has_any_keys lookups.

    On MySQL and SQLite the condition is rendered from a backend-specific
    ``template``. The template is %-formatted with the compiled lhs and keeps
    one ``%s`` placeholder for a JSON path. Subclasses that set
    ``logical_operator`` test several keys by repeating that condition once
    per key. On PostgreSQL, rendering is delegated to
    ``PostgresOperatorLookup.as_postgresql`` (subclasses define
    ``postgres_operator``).
    """

    # Joins the per-key conditions when several keys are checked
    # (e.g. " AND "); None means a single condition.
    logical_operator = None

    def compile_json_path_final_key(self, key_transform):
        # Compile the final key without interpreting ints as array elements.
        return f".{json.dumps(key_transform)}"

    def as_sql(self, compiler, connection, template=None):
        """Render the key-presence check using ``template``.

        ``template`` must be supplied by a backend-specific caller (see
        ``as_mysql`` / ``as_sqlite``). Returns ``(sql, params)``, with params
        ordered to match the placeholders in ``sql``.
        """
        # Process JSON path from the left-hand side.
        if isinstance(self.lhs, KeyTransform):
            lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
                compiler, connection
            )
            lhs_json_path = compile_json_path(lhs_key_transforms)
        else:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            lhs_json_path = "$"
        sql = template % lhs
        # Process JSON path from the right-hand side.
        rhs = self.rhs
        rhs_params = []
        if not isinstance(rhs, list | tuple):
            rhs = [rhs]
        for key in rhs:
            if isinstance(key, KeyTransform):
                *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
            else:
                rhs_key_transforms = [key]
            *rhs_key_transforms, final_key = rhs_key_transforms
            rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
            rhs_json_path += self.compile_json_path_final_key(final_key)
            rhs_params.append(lhs_json_path + rhs_json_path)
        if not self.logical_operator:
            return sql, tuple(lhs_params) + tuple(rhs_params)
        # Add a condition for each key. The lhs SQL appears in every
        # condition, so its params must be repeated as well, interleaved with
        # each condition's path param to match the placeholder order.
        sql = f"({self.logical_operator.join([sql] * len(rhs_params))})"
        params = []
        for json_path in rhs_params:
            params.extend(lhs_params)
            params.append(json_path)
        return sql, tuple(params)

    def as_mysql(self, compiler, connection):
        return self.as_sql(
            compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
        )

    def as_postgresql(self, compiler, connection):
        # A KeyTransform rhs is flattened: all but its last key are folded
        # into the lhs as KeyTransforms, and the last key becomes the rhs.
        if isinstance(self.rhs, KeyTransform):
            *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
            for key in rhs_key_transforms[:-1]:
                self.lhs = KeyTransform(key, self.lhs)
            self.rhs = rhs_key_transforms[-1]
        return super().as_postgresql(compiler, connection)

    def as_sqlite(self, compiler, connection):
        return self.as_sql(
            compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
        )
225
226
class HasKey(HasKeyLookup):
    """``has_key`` lookup: check for a single key in the JSON value."""

    postgres_operator = "?"
    lookup_name = "has_key"
    # The rhs key is used as given (Lookup.prepare_rhs disabled).
    prepare_rhs = False
231
232
class HasKeys(HasKeyLookup):
    """``has_keys`` lookup: every key in the rhs iterable must be present.

    The per-key conditions are joined with ``AND``.
    """

    postgres_operator = "?&"
    lookup_name = "has_keys"
    logical_operator = " AND "

    def get_prep_lookup(self):
        # Coerce each requested key to a string.
        return list(map(str, self.rhs))
240
241
class HasAnyKeys(HasKeys):
    """``has_any_keys`` lookup: at least one of the rhs keys must be present.

    Reuses ``HasKeys`` key preparation but joins the per-key conditions with
    ``OR``.
    """

    postgres_operator = "?|"
    lookup_name = "has_any_keys"
    logical_operator = " OR "
246
247
class HasKeyOrArrayIndex(HasKey):
    """``HasKey`` variant whose final key may also be an array index.

    Unlike the base ``compile_json_path_final_key``, an integer-like final
    key is compiled through ``compile_json_path``. It is therefore emitted as
    an array subscript rather than a quoted member name.
    """

    def compile_json_path_final_key(self, key_transform):
        single_key = (key_transform,)
        return compile_json_path(single_key, include_root=False)
251
252
class CaseInsensitiveMixin:
    """
    Mixin to allow case-insensitive comparison of JSON values on MySQL.
    MySQL handles strings used in JSON context using the utf8mb4_bin collation.
    Because utf8mb4_bin is a binary collation, comparison of JSON values is
    case-sensitive.

    On MySQL both sides of the lookup are wrapped in ``LOWER(...)``; other
    vendors get the parent's SQL unchanged.
    """

    @staticmethod
    def _lower_on_mysql(sql, params, connection):
        # Wrap the SQL fragment in LOWER() for MySQL only; params are untouched.
        if connection.vendor != "mysql":
            return sql, params
        return f"LOWER({sql})", params

    def process_lhs(self, compiler, connection):
        sql, params = super().process_lhs(compiler, connection)
        return self._lower_on_mysql(sql, params, connection)

    def process_rhs(self, compiler, connection):
        sql, params = super().process_rhs(compiler, connection)
        return self._lower_on_mysql(sql, params, connection)
272
273
class JSONExact(lookups.Exact):
    """``exact`` lookup for JSONField that accepts ``None`` as JSON ``null``."""

    can_use_none_as_rhs = True

    def process_rhs(self, compiler, connection):
        sql, params = super().process_rhs(compiler, connection)
        # A lone None parameter is sent as the JSON literal "null".
        if sql == "%s" and params == [None]:
            params = ["null"]
        if connection.vendor == "mysql":
            # Wrap every placeholder in JSON_EXTRACT(%s, '$') on MySQL.
            wrappers = tuple("JSON_EXTRACT(%s, '$')" for _ in params)
            sql %= wrappers
        return sql, params
286
287
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
    """``icontains`` for JSONField; both sides are lowercased on MySQL."""
290
291
# Lookups available directly on JSONField (registered in this order).
for _json_field_lookup in (
    DataContains,
    ContainedBy,
    HasKey,
    HasKeys,
    HasAnyKeys,
    JSONExact,
    JSONIContains,
):
    JSONField.register_lookup(_json_field_lookup)
del _json_field_lookup
299
300
class KeyTransform(Transform):
    """Transform selecting the JSON value stored under ``key_name``.

    Chains of nested key transforms (``field__a__b``) are collapsed into one
    path by ``preprocess_lhs``. The SQL is backend-specific:
    ``JSON_EXTRACT`` on MySQL, the ``->`` / ``#>`` operators on PostgreSQL,
    and a ``JSON_TYPE`` / ``JSON_EXTRACT`` ``CASE`` expression on SQLite.
    """

    postgres_operator = "->"
    postgres_nested_operator = "#>"

    def __init__(self, key_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keys are kept as strings; each backend decides how to treat
        # integer-looking ones.
        self.key_name = str(key_name)

    def preprocess_lhs(self, compiler, connection):
        """Collapse the chain of KeyTransforms ending at ``self``.

        Returns ``(sql, params, keys)``. ``sql`` / ``params`` compile the
        innermost non-KeyTransform expression. ``keys`` lists the key names
        in path order, e.g. ``["a", "b"]`` for ``field__a__b``.
        """
        keys = []
        node = self
        while isinstance(node, KeyTransform):
            keys.append(node.key_name)
            node = node.lhs
        keys.reverse()
        sql, params = compiler.compile(node)
        return sql, params, keys

    def as_mysql(self, compiler, connection):
        lhs, params, keys = self.preprocess_lhs(compiler, connection)
        return f"JSON_EXTRACT({lhs}, %s)", (*params, compile_json_path(keys))

    def as_postgresql(self, compiler, connection):
        lhs, params, keys = self.preprocess_lhs(compiler, connection)
        if len(keys) == 1:
            # Single key: an int is passed when the key parses as one.
            operator = self.postgres_operator
            try:
                path_arg = int(self.key_name)
            except ValueError:
                path_arg = self.key_name
        else:
            # Nested keys: pass the whole key list as one parameter.
            operator = self.postgres_nested_operator
            path_arg = keys
        return f"({lhs} {operator} %s)", (*params, path_arg)

    def as_sqlite(self, compiler, connection):
        lhs, params, keys = self.preprocess_lhs(compiler, connection)
        path = compile_json_path(keys)
        # Return JSON_TYPE's result when it is one of the backend's
        # jsonfield_datatype_values, otherwise JSON_EXTRACT's result.
        type_names = ",".join(map(repr, connection.ops.jsonfield_datatype_values))
        sql = (
            f"(CASE WHEN JSON_TYPE({lhs}, %s) IN ({type_names}) "
            f"THEN JSON_TYPE({lhs}, %s) ELSE JSON_EXTRACT({lhs}, %s) END)"
        )
        return sql, (*params, path) * 3
344
345
class KeyTextTransform(KeyTransform):
    """KeyTransform variant that yields the key's value as text.

    Uses ``->>`` / ``#>>`` on PostgreSQL and ``->>`` on MySQL. On MariaDB it
    emits ``JSON_UNQUOTE`` around the ``JSON_EXTRACT`` built by
    ``KeyTransform``.
    """

    postgres_operator = "->>"
    postgres_nested_operator = "#>>"
    output_field = TextField()

    def as_mysql(self, compiler, connection):
        if not connection.mysql_is_mariadb:
            lhs, params, keys = self.preprocess_lhs(compiler, connection)
            return f"({lhs} ->> %s)", (*params, compile_json_path(keys))
        # MariaDB doesn't support -> and ->> operators (see MDEV-13594).
        extracted, params = super().as_mysql(compiler, connection)
        return f"JSON_UNQUOTE({extracted})", params

    @classmethod
    def from_lookup(cls, lookup):
        """Build nested transforms from a ``base__key__key`` lookup string.

        The string is split on ``LOOKUP_SEP``. The first part is used as-is
        as the innermost lhs (presumably a field name resolved later), and
        each following part wraps the previous result in another ``cls``.

        Raises:
            ValueError: if the lookup has no key/index part.
        """
        parts = lookup.split(LOOKUP_SEP)
        if len(parts) < 2:
            raise ValueError("Lookup must contain key or index transforms.")
        expression = parts[0]
        for key in parts[1:]:
            expression = cls(key, expression)
        return expression


# Shorthand constructor, e.g. KT("field__key__0").
KT = KeyTextTransform.from_lookup
372
373
class KeyTransformTextLookupMixin:
    """
    Mixin for combining with a lookup expecting a text lhs from a JSONField
    key lookup. On PostgreSQL, make use of the ->> operator instead of casting
    key values to text and performing the lookup on the resulting
    representation.

    The incoming KeyTransform is rebuilt as a KeyTextTransform with the same
    key name, source expressions and extra options before being handed to
    the parent lookup.
    """

    def __init__(self, key_transform, *args, **kwargs):
        if not isinstance(key_transform, KeyTransform):
            raise TypeError(
                "Transform should be an instance of KeyTransform in order to "
                "use this lookup."
            )
        super().__init__(
            KeyTextTransform(
                key_transform.key_name,
                *key_transform.source_expressions,
                **key_transform.extra,
            ),
            *args,
            **kwargs,
        )
394
395
class KeyTransformIsNull(lookups.IsNull):
    """``isnull`` on a JSON key.

    On SQLite it is rendered as a ``JSON_TYPE`` key-presence check through
    ``HasKeyOrArrayIndex``: key__isnull=False is the same as has_key='key'.
    """

    def as_sqlite(self, compiler, connection):
        template = (
            "JSON_TYPE(%s, %%s) IS NULL"
            if self.rhs
            else "JSON_TYPE(%s, %%s) IS NOT NULL"
        )
        presence = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
        return presence.as_sql(compiler, connection, template=template)
407
408
class KeyTransformIn(lookups.In):
    """``in`` lookup on a JSON key.

    On backends without a native JSON field, literal parameters are wrapped
    in ``JSON_EXTRACT(%s, '$')``. This applies to every literal on MySQL, and
    on SQLite to those not in ``jsonfield_datatype_values``. On MariaDB the
    resulting SQL is also wrapped in ``JSON_UNQUOTE``.
    """

    def resolve_expression_parameter(self, compiler, connection, sql, param):
        sql, params = super().resolve_expression_parameter(
            compiler, connection, sql, param
        )
        is_literal = not hasattr(param, "as_sql")
        if is_literal and not connection.features.has_native_json_field:
            vendor = connection.vendor
            needs_extract = vendor == "mysql" or (
                vendor == "sqlite"
                and params[0] not in connection.ops.jsonfield_datatype_values
            )
            if needs_extract:
                sql = "JSON_EXTRACT(%s, '$')"
        if connection.vendor == "mysql" and connection.mysql_is_mariadb:
            sql = f"JSON_UNQUOTE({sql})"
        return sql, params
429
430
class KeyTransformExact(JSONExact):
    """``exact`` lookup applied to a JSON key transform."""

    def process_rhs(self, compiler, connection):
        if isinstance(self.rhs, KeyTransform):
            # Comparing against another key transform: bypass the
            # JSONExact/Exact rhs handling and use the next implementation
            # after Exact in the MRO.
            return super(lookups.Exact, self).process_rhs(compiler, connection)
        sql, params = super().process_rhs(compiler, connection)
        if connection.vendor != "sqlite":
            return sql, params
        # On SQLite, values in jsonfield_datatype_values are bound as-is;
        # all others are wrapped in JSON_EXTRACT(%s, '$').
        placeholders = tuple(
            "%s"
            if value in connection.ops.jsonfield_datatype_values
            else "JSON_EXTRACT(%s, '$')"
            for value in params
        )
        sql %= placeholders
        return sql, params
445
446
class KeyTransformIExact(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
    """``iexact`` on a key's text value (lowercased on MySQL)."""


class KeyTransformIContains(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
    """``icontains`` on a key's text value (lowercased on MySQL)."""


class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    """``startswith`` on a key's text value."""


class KeyTransformIStartsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
    """``istartswith`` on a key's text value (lowercased on MySQL)."""


class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    """``endswith`` on a key's text value."""


class KeyTransformIEndsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
    """``iendswith`` on a key's text value (lowercased on MySQL)."""


class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    """``regex`` on a key's text value."""


class KeyTransformIRegex(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
    """``iregex`` on a key's text value (lowercased on MySQL)."""
487
488
class KeyTransformNumericLookupMixin:
    """Mixin for comparison lookups on JSON key values.

    On backends without a native JSON field, each rhs parameter is decoded
    with ``json.loads`` (it is presumably JSON text at this point) before
    being bound. Native-JSON backends get the parent's params unchanged.
    """

    def process_rhs(self, compiler, connection):
        sql, params = super().process_rhs(compiler, connection)
        if connection.features.has_native_json_field:
            return sql, params
        return sql, list(map(json.loads, params))
495
496
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
    """``lt`` comparison against a JSON key value."""


class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
    """``lte`` comparison against a JSON key value."""


class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
    """``gt`` comparison against a JSON key value."""


class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
    """``gte`` comparison against a JSON key value."""
511
512
# Lookups available on JSON key transforms (registered in this order).
for _key_transform_lookup in (
    KeyTransformIn,
    KeyTransformExact,
    KeyTransformIExact,
    KeyTransformIsNull,
    KeyTransformIContains,
    KeyTransformStartsWith,
    KeyTransformIStartsWith,
    KeyTransformEndsWith,
    KeyTransformIEndsWith,
    KeyTransformRegex,
    KeyTransformIRegex,
    KeyTransformLt,
    KeyTransformLte,
    KeyTransformGt,
    KeyTransformGte,
):
    KeyTransform.register_lookup(_key_transform_lookup)
del _key_transform_lookup
529
530
class KeyTransformFactory:
    """Callable that builds ``KeyTransform`` instances for one fixed key.

    Returned by ``JSONField.get_transform`` for transform names that are not
    registered on the field.
    """

    def __init__(self, key_name):
        self.key_name = key_name

    def __call__(self, *args, **kwargs):
        key = self.key_name
        return KeyTransform(key, *args, **kwargs)