1from __future__ import annotations
2
3import json
4from collections.abc import Callable, Sequence
5from typing import TYPE_CHECKING, Any
6
7from plain import exceptions, preflight
8from plain.models import expressions, lookups
9from plain.models.backends.guards import is_mysql_connection, is_sqlite_connection
10from plain.models.constants import LOOKUP_SEP
11from plain.models.db import NotSupportedError, db_connection
12from plain.models.fields import TextField
13from plain.models.lookups import (
14 FieldGetDbPrepValueMixin,
15 Lookup,
16 PostgresOperatorLookup,
17 Transform,
18)
19
20from . import Field
21
22if TYPE_CHECKING:
23 from plain.models.backends.base.base import BaseDatabaseWrapper
24 from plain.models.sql.compiler import SQLCompiler
25 from plain.preflight.results import PreflightResult
26
27__all__ = ["JSONField"]
28
29
class JSONField(Field):
    """A model field that stores arbitrary JSON-encoded data.

    Values are serialized with ``encoder`` and deserialized with ``decoder``
    (optional ``json.JSONEncoder`` / ``json.JSONDecoder`` subclasses).
    """

    empty_strings_allowed = False
    description = "A JSON object"
    default_error_messages = {
        "invalid": "Value must be valid JSON.",
    }
    # (suggested-callable, instance-literal) used in the default preflight hint.
    _default_fix = ("dict", "{}")

    def __init__(
        self,
        *,
        encoder: type[json.JSONEncoder] | None = None,
        decoder: type[json.JSONDecoder] | None = None,
        **kwargs: Any,
    ):
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        if decoder and not callable(decoder):
            raise ValueError("The decoder parameter must be a callable object.")
        self.encoder = encoder
        self.decoder = decoder
        super().__init__(**kwargs)

    def _check_default(self) -> list[PreflightResult]:
        # Warn when a mutable instance (e.g. {}) is shared as the default.
        if not self.has_default() or self.default is None or callable(self.default):
            return []
        suggested, literal = self._default_fix
        return [
            preflight.PreflightResult(
                fix=(
                    f"{self.__class__.__name__} default should be a callable instead of an instance "
                    "so that it's not shared between all field instances. "
                    f"Use a callable instead, e.g., use `{suggested}` instead of "
                    f"`{literal}`."
                ),
                obj=self,
                id="fields.invalid_choice_mixin_default",
                warning=True,
            )
        ]

    def preflight(self, **kwargs: Any) -> list[PreflightResult]:
        """Run the base field checks plus JSON-specific ones."""
        return [
            *super().preflight(**kwargs),
            *self._check_default(),
            *self._check_supported(),
        ]

    def _check_supported(self) -> list[PreflightResult]:
        # Skip the check entirely when the model is pinned to another vendor.
        required_vendor = self.model.model_options.required_db_vendor
        if required_vendor and required_vendor != db_connection.vendor:
            return []
        declared = (
            "supports_json_field" in self.model.model_options.required_db_features
        )
        if declared or db_connection.features.supports_json_field:
            return []
        return [
            preflight.PreflightResult(
                fix=f"{db_connection.display_name} does not support JSONFields. Consider using a TextField with JSON serialization or upgrade to a database that supports JSON fields.",
                obj=self.model,
                id="fields.json_field_unsupported",
            )
        ]

    def deconstruct(self) -> tuple[str | None, str, list[Any], dict[str, Any]]:
        name, path, args, kwargs = super().deconstruct()
        # Only record non-default codec options.
        for option in ("encoder", "decoder"):
            if (codec := getattr(self, option)) is not None:
                kwargs[option] = codec
        return name, path, args, kwargs

    def from_db_value(
        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
    ) -> Any:
        """Convert a database value back into a Python object."""
        if value is None:
            return None
        # Some backends (SQLite at least) extract non-string values in their
        # SQL datatypes.
        if isinstance(expression, KeyTransform) and not isinstance(value, str):
            return value
        try:
            return json.loads(value, cls=self.decoder)
        except json.JSONDecodeError:
            # Not decodable as JSON: hand the raw value back unchanged.
            return value

    def get_internal_type(self) -> str:
        return "JSONField"

    def get_db_prep_value(
        self, value: Any, connection: BaseDatabaseWrapper, prepared: bool = False
    ) -> Any:
        # Unwrap Value() expressions carrying JSON payloads; pass any other
        # compilable expression straight through.
        if isinstance(value, expressions.Value) and isinstance(
            value.output_field, JSONField
        ):
            value = value.value
        elif hasattr(value, "as_sql"):
            return value
        return connection.ops.adapt_json_value(value, self.encoder)

    def get_db_prep_save(self, value: Any, connection: BaseDatabaseWrapper) -> Any:
        # None is stored as SQL NULL rather than the JSON string "null".
        return None if value is None else self.get_db_prep_value(value, connection)

    def get_transform(
        self, lookup_name: str
    ) -> type[Transform] | Callable[..., Any] | None:
        # Unknown names fall back to key/index transforms, enabling chained
        # lookups such as field__key__subkey.
        return super().get_transform(lookup_name) or KeyTransformFactory(lookup_name)

    def validate(self, value: Any, model_instance: Any) -> None:
        super().validate(value, model_instance)
        # A value is valid only if it can be serialized with the configured
        # encoder.
        try:
            json.dumps(value, cls=self.encoder)
        except TypeError:
            raise exceptions.ValidationError(
                self.error_messages["invalid"],
                code="invalid",
                params={"value": value},
            )

    def value_to_string(self, obj: Any) -> Any:
        return self.value_from_object(obj)
166
167
168def compile_json_path(key_transforms: list[Any], include_root: bool = True) -> str:
169 path = ["$"] if include_root else []
170 for key_transform in key_transforms:
171 try:
172 num = int(key_transform)
173 except ValueError: # non-integer
174 path.append(".")
175 path.append(json.dumps(key_transform))
176 else:
177 path.append(f"[{num}]")
178 return "".join(path)
179
180
class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """``contains`` lookup: lhs JSON document contains the rhs fragment."""

    lookup_name = "contains"
    postgres_operator = "@>"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contains lookup is not supported on this database backend."
            )
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        return f"JSON_CONTAINS({lhs_sql}, {rhs_sql})", (*lhs_params, *rhs_params)
196
197
class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """``contained_by`` lookup: the mirror of ``contains`` (rhs holds lhs)."""

    lookup_name = "contained_by"
    postgres_operator = "<@"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contained_by lookup is not supported on this database backend."
            )
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        # Operand order is swapped relative to DataContains.
        return f"JSON_CONTAINS({rhs_sql}, {lhs_sql})", (*rhs_params, *lhs_params)
213
214
class HasKeyLookup(PostgresOperatorLookup):
    """Base class for the has_key / has_keys / has_any_keys lookups.

    The rhs is one key (or a list of keys); each key is compiled into a JSON
    path rooted at the lhs path. Subclasses set ``logical_operator`` to join
    multiple per-key conditions into a single SQL expression.
    """

    # Joins per-key conditions (" AND " / " OR "); None means a single key.
    logical_operator: str | None = None

    def compile_json_path_final_key(self, key_transform: Any) -> str:
        # Compile the final key without interpreting ints as array elements.
        return f".{json.dumps(key_transform)}"

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        # Process JSON path from the left-hand side.
        if isinstance(self.lhs, KeyTransform):
            lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
                compiler, connection
            )
            lhs_json_path = compile_json_path(lhs_key_transforms)
        else:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            lhs_json_path = "$"
        effective_template = template or "%s"
        sql = effective_template % lhs
        # Process JSON path from the right-hand side.
        rhs = self.rhs
        rhs_params = []
        if not isinstance(rhs, list | tuple):
            rhs = [rhs]
        for key in rhs:
            if isinstance(key, KeyTransform):
                *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
            else:
                rhs_key_transforms = [key]
            # The final key is compiled separately so subclasses can control
            # whether it may be treated as an array index.
            *rhs_key_transforms, final_key = rhs_key_transforms
            rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
            rhs_json_path += self.compile_json_path_final_key(final_key)
            rhs_params.append(lhs_json_path + rhs_json_path)
        # Add condition for each key.
        if self.logical_operator:
            # NOTE(review): the joined SQL repeats the lhs once per key while
            # lhs_params appears only once in the returned params -- presumably
            # fine for a plain column lhs; verify for parametrized lhs
            # expressions.
            sql = f"({self.logical_operator.join([sql] * len(rhs_params))})"
        return sql, tuple(lhs_params) + tuple(rhs_params)

    def as_mysql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        # JSON_CONTAINS_PATH(doc, 'one', path) is true when the path exists.
        return self.as_sql(
            compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
        )

    def as_postgresql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        # PostgreSQL key-existence operators apply only to the final key, so
        # fold any leading rhs keys into the lhs as nested KeyTransforms.
        if isinstance(self.rhs, KeyTransform):
            *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
            for key in rhs_key_transforms[:-1]:
                self.lhs = KeyTransform(key, self.lhs)
            self.rhs = rhs_key_transforms[-1]
        return super().as_postgresql(compiler, connection)

    def as_sqlite(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        # JSON_TYPE returns NULL when the path does not exist.
        return self.as_sql(
            compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
        )
281
282
class HasKey(HasKeyLookup):
    """``has_key`` lookup: true when the document contains the given key."""

    lookup_name = "has_key"
    postgres_operator = "?"
    prepare_rhs = False
287
288
class HasKeys(HasKeyLookup):
    """``has_keys`` lookup: true when ALL of the given keys are present."""

    lookup_name = "has_keys"
    postgres_operator = "?&"
    logical_operator = " AND "

    def get_prep_lookup(self) -> list[str]:
        # Keys are always compared as strings.
        return list(map(str, self.rhs))
296
297
class HasAnyKeys(HasKeys):
    """``has_any_keys`` lookup: true when AT LEAST ONE key is present."""

    lookup_name = "has_any_keys"
    postgres_operator = "?|"
    logical_operator = " OR "
302
303
class HasKeyOrArrayIndex(HasKey):
    """HasKey variant whose final key may also be an integer array index."""

    def compile_json_path_final_key(self, key_transform: Any) -> str:
        # Unlike HasKey, integer-looking keys compile to [n] array access.
        return compile_json_path([key_transform], include_root=False)
307
308
class CaseInsensitiveMixin(Lookup):
    """
    Mixin to allow case-insensitive comparison of JSON values on MySQL.
    MySQL handles strings used in JSON context using the utf8mb4_bin collation.
    Because utf8mb4_bin is a binary collation, comparison of JSON values is
    case-sensitive.
    """

    def process_lhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Any = None
    ) -> tuple[str, list[Any]]:
        sql, params = super().process_lhs(compiler, connection, lhs)
        if connection.vendor == "mysql":
            sql = f"LOWER({sql})"
        return sql, params

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        sql, params = super().process_rhs(compiler, connection)
        # A non-string rhs (e.g. a list of placeholders) is left untouched.
        if isinstance(sql, str) and connection.vendor == "mysql":
            sql = f"LOWER({sql})"
        return sql, params
335
336
class JSONExact(lookups.Exact):
    """Exact match on JSON values (maps a None rhs to JSON null)."""

    can_use_none_as_rhs = True

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if not isinstance(rhs, str):
            return rhs, rhs_params
        # Treat None lookup values as null.
        if rhs == "%s" and rhs_params == [None]:
            rhs_params = ["null"]
        if connection.vendor == "mysql":
            # Normalize each placeholder so comparison happens on JSON values.
            placeholders = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
            rhs = rhs % tuple(placeholders)
        return rhs, rhs_params
354
355
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
    """Case-insensitive ``icontains`` lookup for JSONField."""

    pass
358
359
# Register document-level lookups on JSONField itself.
JSONField.register_lookup(DataContains)
JSONField.register_lookup(ContainedBy)
JSONField.register_lookup(HasKey)
JSONField.register_lookup(HasKeys)
JSONField.register_lookup(HasAnyKeys)
JSONField.register_lookup(JSONExact)
JSONField.register_lookup(JSONIContains)
367
368
class KeyTransform(Transform):
    """Transform extracting a single key (or array index) from a JSON value.

    Chained instances (a KeyTransform whose lhs is another KeyTransform) are
    flattened into one JSON path before SQL generation.
    """

    postgres_operator = "->"
    postgres_nested_operator = "#>"

    def __init__(self, key_name: str, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        self.key_name = str(key_name)

    def preprocess_lhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...], list[str]]:
        """Flatten nested transforms into (lhs_sql, params, ordered keys)."""
        keys: list[str] = []
        expr: Any = self
        while isinstance(expr, KeyTransform):
            keys.append(expr.key_name)
            expr = expr.lhs
        keys.reverse()  # outermost transform is the last path segment
        sql, params = compiler.compile(expr)
        return sql, params, keys

    def as_mysql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        sql, params, keys = self.preprocess_lhs(compiler, connection)
        path = compile_json_path(keys)
        return f"JSON_EXTRACT({sql}, %s)", (*params, path)

    def as_postgresql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        sql, params, keys = self.preprocess_lhs(compiler, connection)
        if len(keys) > 1:
            # Multiple keys use the nested-path operator with the key list.
            return (
                f"({sql} {self.postgres_nested_operator} %s)",
                (*params, keys),
            )
        try:
            key: Any = int(self.key_name)
        except ValueError:
            key = self.key_name
        return f"({sql} {self.postgres_operator} %s)", (*params, key)

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        # as_sqlite is only called for SQLite connections
        assert is_sqlite_connection(connection)
        sql, params, keys = self.preprocess_lhs(compiler, connection)
        path = compile_json_path(keys)
        datatypes = ",".join(
            repr(datatype) for datatype in connection.ops.jsonfield_datatype_values
        )
        case_sql = (
            f"(CASE WHEN JSON_TYPE({sql}, %s) IN ({datatypes}) "
            f"THEN JSON_TYPE({sql}, %s) ELSE JSON_EXTRACT({sql}, %s) END)"
        )
        # The lhs (and path) appears three times in the CASE expression.
        return case_sql, (*params, path) * 3
425
426
class KeyTextTransform(KeyTransform):
    """KeyTransform variant that yields the key's value as text."""

    postgres_operator = "->>"
    postgres_nested_operator = "#>>"
    output_field = TextField()

    def as_mysql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        # as_mysql is only called for MySQL connections
        assert is_mysql_connection(connection)
        if not connection.mysql_is_mariadb:
            sql, params, keys = self.preprocess_lhs(compiler, connection)
            path = compile_json_path(keys)
            return f"({sql} ->> %s)", (*params, path)
        # MariaDB doesn't support -> and ->> operators (see MDEV-13594).
        extract_sql, params = super().as_mysql(compiler, connection)
        return f"JSON_UNQUOTE({extract_sql})", params

    @classmethod
    def from_lookup(cls, lookup: str) -> Any:
        """Build nested text transforms from ``field__key1__key2`` notation."""
        root, *keys = lookup.split(LOOKUP_SEP)
        if not keys:
            raise ValueError("Lookup must contain key or index transforms.")
        transform: Any = root
        for key in keys:
            transform = cls(key, transform)
        return transform
454
455
# Shorthand for building a nested text transform from a lookup string,
# e.g. KT("data__a__b").
KT = KeyTextTransform.from_lookup
457
458
class KeyTransformTextLookupMixin(Lookup):
    """
    Mixin for combining with a lookup expecting a text lhs from a JSONField
    key lookup. On PostgreSQL, make use of the ->> operator instead of casting
    key values to text and performing the lookup on the resulting
    representation.
    """

    def __init__(self, key_transform: Any, *args: Any, **kwargs: Any):
        if not isinstance(key_transform, KeyTransform):
            raise TypeError(
                "Transform should be an instance of KeyTransform in order to "
                "use this lookup."
            )
        # Swap the JSON-typed transform for its text-returning counterpart.
        as_text = KeyTextTransform(
            key_transform.key_name,
            *key_transform.source_expressions,
            **key_transform.extra,
        )
        super().__init__(as_text, *args, **kwargs)
479
480
class KeyTransformIsNull(lookups.IsNull):
    # key__isnull=False is the same as has_key='key'
    def as_sqlite(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        # Delegate to the key-existence lookup, flipping the template so that
        # isnull=True matches missing keys.
        if self.rhs:
            template = "JSON_TYPE(%s, %%s) IS NULL"
        else:
            template = "JSON_TYPE(%s, %%s) IS NOT NULL"
        has_key = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
        return has_key.as_sql(compiler, connection, template=template)
494
495
class KeyTransformIn(lookups.In):
    """``in`` lookup for JSON key transforms.

    Adjusts how each resolved rhs parameter is embedded in SQL on backends
    without a native JSON type, so the comparison happens on extracted JSON
    values rather than raw text.
    """

    def resolve_expression_parameter(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        sql: str,
        param: Any,
    ) -> tuple[str, list[Any]]:
        sql, params = super().resolve_expression_parameter(
            compiler,
            connection,
            sql,
            param,
        )
        # Only plain (non-expression) params on non-native-JSON backends need
        # wrapping; expressions already compile to SQL themselves.
        if (
            not hasattr(param, "as_sql")
            and not connection.features.has_native_json_field
        ):
            if connection.vendor == "mysql":
                sql = "JSON_EXTRACT(%s, '$')"
            elif is_sqlite_connection(connection):
                # Type guard narrows connection to SQLiteDatabaseWrapper
                if params[0] not in connection.ops.jsonfield_datatype_values:
                    sql = "JSON_EXTRACT(%s, '$')"
            if is_mysql_connection(connection):
                # Type guard narrows connection to MySQLDatabaseWrapper
                if connection.mysql_is_mariadb:
                    # MariaDB yields quoted JSON strings; unquote before comparing.
                    sql = f"JSON_UNQUOTE({sql})"
        return sql, list(params)
525
526
class KeyTransformExact(JSONExact):
    """Exact match on a JSON key's value."""

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        if isinstance(self.rhs, KeyTransform):
            # Comparing two key transforms: skip JSONExact's rhs handling and
            # go straight to the base lookup behaviour.
            return super(lookups.Exact, self).process_rhs(compiler, connection)
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if isinstance(rhs, str) and is_sqlite_connection(connection):
            # Type guard narrows connection to SQLiteDatabaseWrapper
            placeholders = [
                "%s"
                if value in connection.ops.jsonfield_datatype_values
                else "JSON_EXTRACT(%s, '$')"
                for value in rhs_params
            ]
            rhs = rhs % tuple(placeholders)
        return rhs, rhs_params
547
548
class KeyTransformIExact(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
    """Case-insensitive exact match on the text value of a JSON key."""

    pass
553
554
class KeyTransformIContains(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
    """Case-insensitive substring match on the text value of a JSON key."""

    pass
559
560
class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    """``startswith`` lookup on the text value of a JSON key."""

    pass
563
564
class KeyTransformIStartsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
    """Case-insensitive ``istartswith`` lookup on a JSON key's text value."""

    pass
569
570
class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    """``endswith`` lookup on the text value of a JSON key."""

    pass
573
574
class KeyTransformIEndsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
    """Case-insensitive ``iendswith`` lookup on a JSON key's text value."""

    pass
579
580
class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    """``regex`` lookup on the text value of a JSON key."""

    pass
583
584
class KeyTransformIRegex(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
    """Case-insensitive ``iregex`` lookup on a JSON key's text value."""

    pass
589
590
class KeyTransformNumericLookupMixin(Lookup):
    """Mixin decoding JSON-encoded rhs params on non-native-JSON backends."""

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if isinstance(rhs, str) and not connection.features.has_native_json_field:
            # Params arrive JSON-encoded; decode them so the comparison is
            # numeric rather than textual.
            rhs_params = [json.loads(value) for value in rhs_params]
        return rhs, rhs_params
602
603
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
    """``lt`` comparison on a JSON key's value."""

    pass
606
607
class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
    """``lte`` comparison on a JSON key's value."""

    pass
610
611
class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
    """``gt`` comparison on a JSON key's value."""

    pass
614
615
class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
    """``gte`` comparison on a JSON key's value."""

    pass
618
619
# Register key-level lookups on KeyTransform so they apply to
# field__key__... paths.
KeyTransform.register_lookup(KeyTransformIn)
KeyTransform.register_lookup(KeyTransformExact)
KeyTransform.register_lookup(KeyTransformIExact)
KeyTransform.register_lookup(KeyTransformIsNull)
KeyTransform.register_lookup(KeyTransformIContains)
KeyTransform.register_lookup(KeyTransformStartsWith)
KeyTransform.register_lookup(KeyTransformIStartsWith)
KeyTransform.register_lookup(KeyTransformEndsWith)
KeyTransform.register_lookup(KeyTransformIEndsWith)
KeyTransform.register_lookup(KeyTransformRegex)
KeyTransform.register_lookup(KeyTransformIRegex)

# Numeric comparison lookups.
KeyTransform.register_lookup(KeyTransformLt)
KeyTransform.register_lookup(KeyTransformLte)
KeyTransform.register_lookup(KeyTransformGt)
KeyTransform.register_lookup(KeyTransformGte)
636
637
class KeyTransformFactory:
    """Callable returned by JSONField.get_transform(); builds a KeyTransform
    bound to a fixed key name."""

    def __init__(self, key_name: str):
        self.key_name = key_name

    def __call__(self, *args: Any, **kwargs: Any) -> KeyTransform:
        return KeyTransform(self.key_name, *args, **kwargs)